In [None]:
import pandas as pd
from tqdm import tqdm
from lightning.fabric import Fabric
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

import sys
sys.path.append('..')
sys.path.append('../models')

import TCN
import utils

In [None]:
# Hyperparameters passed to TCN.TCN when the model is constructed.
input_size = 16
kernel_size = 21
num_layers = 3
dilation_base = 2
# One entry per layer; every layer uses the same width of 32.
out_channels = [32 for _ in range(num_layers)]
# Author's constraint (written with ** for exponentiation):
#   1 + 2 * (kernel_size - 1) * (2 ** num_layers - 1) should be > 10000
# NOTE(review): with kernel_size = 21 and num_layers = 3 this expression
# evaluates to 1 + 2 * 20 * 7 = 281, well below 10000 — confirm whether the
# constraint is still intended or the values need to change.

In [None]:
# Read the training table and hold out 20% of its rows for validation.
# random_state pins the split so it is identical on every run.
df = pd.read_csv('../hms-train/train.csv')
train_df, val_df = train_test_split(df, test_size=0.2, random_state=8056)

# Two datasets are built from the same training rows with different second
# arguments (4 and (1, 4)); the validation dataset uses the default.
# NOTE(review): presumably the second argument filters rows by vote count —
# confirm against utils.EEGCleanDataset.
train_dataset_high_votes = utils.EEGCleanDataset(train_df, 4)
train_dataset_low_votes = utils.EEGCleanDataset(train_df, (1, 4))
val_dataset = utils.EEGCleanDataset(val_df)

# Single source of truth for the batch size shared by all three loaders.
batch_size = 64
train_loader_high_votes = DataLoader(train_dataset_high_votes, batch_size=batch_size, shuffle=True)
train_loader_low_votes = DataLoader(train_dataset_low_votes, batch_size=batch_size, shuffle=True)
# Validation order does not matter, so it is left unshuffled.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

model = TCN.TCN(input_size, out_channels, kernel_size, dilation_base)
optimizer = optim.AdamW(model.parameters(), lr=0.001)
kl_loss_fn = nn.KLDivLoss(reduction='batchmean')
ce_loss_fn = nn.CrossEntropyLoss()

# Fabric handles device placement and 16-bit mixed precision for the model,
# the optimizer and every dataloader registered with it.
# NOTE(review): no torch seed is set, so weight init and shuffling differ
# between runs even though the train/val split is fixed.
fabric = Fabric(accelerator='cuda', precision='16-mixed')
model, optimizer = fabric.setup(model, optimizer)
train_loader_high_votes, train_loader_low_votes, val_loader = fabric.setup_dataloaders(
    train_loader_high_votes, train_loader_low_votes, val_loader
)

In [None]:
# Bare expression: the notebook displays the repr of the model returned by fabric.setup.
model

In [None]:
epochs = 5
# Optional step decay of the learning rate (currently disabled):
#   lr = 0.001    if epoch < 21
#   lr = 0.0001   if 21 <= epoch < 42
#   lr = 0.00001  if 42 <= epoch < 63
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=21, gamma=0.1)

# zip() stops as soon as the shorter loader is exhausted, so the real number
# of optimisation steps per epoch is the smaller of the two loader lengths.
steps_per_epoch = min(len(train_loader_high_votes), len(train_loader_low_votes))

for epoch in range(epochs):
    # ---- training --------------------------------------------------------
    running_loss = 0
    # Samples actually consumed from the high-votes loader this epoch; used
    # as the denominator so the reported mean stays correct when zip() cuts
    # the epoch short.
    train_samples_seen = 0

    model.train()
    paired_batches = zip(train_loader_high_votes, train_loader_low_votes)
    for batch_high_votes, batch_low_votes in tqdm(paired_batches,
                                                  total=steps_per_epoch,
                                                  desc=f'Train {epoch + 1}/{epochs}'):
        optimizer.zero_grad()

        # Each batch is a 3-tuple; the middle element is unused here.
        # nn.KLDivLoss expects log-probabilities as its first argument.
        # NOTE(review): assumes the TCN output is already log-softmaxed —
        # confirm against the TCN implementation.
        eeg_high_votes, _, class_prob_high_votes = batch_high_votes
        outputs1 = model(eeg_high_votes)
        kl_loss1 = kl_loss_fn(outputs1, class_prob_high_votes)

        eeg_low_votes, _, class_prob_low_votes = batch_low_votes
        outputs2 = model(eeg_low_votes)
        kl_loss2 = kl_loss_fn(outputs2, class_prob_low_votes)

        # Debug aid: if the loss diverges, check torch.isnan(outputs1).any()
        # and torch.isnan(outputs2).any() here.

        # The low-votes batch contributes at one fifth of the weight.
        kl_loss = kl_loss1 + kl_loss2 * 0.2
        # fabric.backward replaces loss.backward() so the configured
        # precision is applied to the backward pass.
        fabric.backward(kl_loss)
        optimizer.step()

        high_batch_size = eeg_high_votes.size(0)
        running_loss += kl_loss.item() * high_batch_size
        train_samples_seen += high_batch_size

    # ---- validation ------------------------------------------------------
    running_val_loss = 0
    val_samples_seen = 0

    model.eval()
    # No gradients are needed for evaluation; disabling autograd avoids
    # building a graph for every validation batch.
    with torch.no_grad():
        for val_eeg, val_label, val_class_prob in tqdm(val_loader, desc=f'Validation {epoch + 1}/{epochs}'):
            val_outputs = model(val_eeg)
            kl_val_loss = kl_loss_fn(val_outputs, val_class_prob)

            val_batch_size = val_eeg.size(0)
            running_val_loss += kl_val_loss.item() * val_batch_size
            val_samples_seen += val_batch_size

    # scheduler.step()  # re-enable together with the scheduler above

    # max(..., 1) keeps the report from raising on an empty loader.
    print(f'Epoch {epoch + 1} Train KL loss: ', running_loss / max(train_samples_seen, 1))
    print(f'Epoch {epoch + 1} Val KL loss: ', running_val_loss / max(val_samples_seen, 1))