In [1]:
from src.dataset._dataset_utils import create_datasets


# Names of the datasets to build; each one is passed to create_datasets below.
ds_names = ['atnf', 'biaf', 'bivi', 'cycc', 'vtak']

SEQ_LEN = 30
LOG_SPLITS = False
# NOTE(review): presumably (column index, scale) pairs consumed by
# create_datasets -- confirm against src.dataset._dataset_utils.
FIXED_SCALING = [(7, 3000.), (8, 12.), (9, 31.)]
ROOT = './data/clean'

# Map every dataset name to whatever create_datasets returns for it.
datasets = {
  ds_name: create_datasets(ds_name, ROOT, FIXED_SCALING, LOG_SPLITS)
  for ds_name in ds_names
}

In [2]:
import torch

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device('cuda')
else:
  device = torch.device('cpu')
print(device)

cuda


In [None]:
# --- Optimizer hyperparameters ---
LR = 1e-4
BETAS = (0.9, 0.999)
EPS = 1e-8
WEIGHT_DECAY = 1e-3

# --- Learning-rate scheduler hyperparameters ---
GAMMA = 0.1
# StepLR's step_size is a whole number of epochs between decays. The previous
# value (0.1) made the decay exponent `epoch // step_size` jump by ~10 per
# epoch, driving the learning rate to ~0 after the first scheduler step.
# 5 matches the spacing of MILESTONES below; tune as needed.
STEP_SIZE = 5
MILESTONES = [5, 10, 15]
MIN_LR = 1e-7

# --- Loss hyperparameters ---
CRITERION_GAMMA = 2.0

OPTIMIZER = 'adam' # 'adam' or 'sgd'
SCHEDULER = 'step' # 'plateau', 'step', or 'multi'
CRITERION = 'ce'   # 'ce' or 'cb_focal'

BATCH_SIZE = 32
HIDDEN_SIZE = 128

EPOCHS = 20

# Key into `datasets` selecting which dataset this run trains on.
data = 'atnf'

# datasets[data][0] is the training split (the setup cell unpacks the value as
# train, valid, test).
# NOTE(review): assumes the first element of a training sample is a 2-D
# tensor whose second dimension is the feature count -- confirm.
input_size = datasets[data][0][0][0].shape[1]
train_label_ct = datasets[data][0].target_counts


In [None]:
from typing import Optional
from src.cbfocal_loss import FocalLoss
from src.models.abstract_model import AbstractModel


def get_optimizer(type: str, model: AbstractModel, lr, **kwargs):
  """Create the optimizer named by ``type`` over the parameters of ``model``.

  Args:
    type: ``'adam'`` or ``'sgd'``.
    model: Model whose ``parameters()`` are handed to the optimizer.
    lr: Learning rate passed to the optimizer.
    **kwargs: Forwarded unchanged to the optimizer constructor.

  Returns:
    A ``torch.optim.Adam`` or ``torch.optim.SGD`` instance.

  Raises:
    ValueError: If ``type`` is not one of the supported names.
  """
  if type == 'adam':
    optimizer_cls = torch.optim.Adam
  elif type == 'sgd':
    optimizer_cls = torch.optim.SGD
  else:
    raise ValueError(f'Unknown optimizer: {type}')

  return optimizer_cls(model.parameters(), lr=lr, **kwargs)
  

def get_scheduler(type: str, optimizer: torch.optim.Optimizer, **kwargs):
  """Create the learning-rate scheduler named by ``type`` for ``optimizer``.

  Args:
    type: ``'plateau'``, ``'step'`` or ``'multi'``.
    optimizer: Optimizer whose learning rate the scheduler controls.
    **kwargs: Forwarded unchanged to the scheduler constructor.

  Returns:
    A ``ReduceLROnPlateau``, ``StepLR`` or ``MultiStepLR`` instance.

  Raises:
    ValueError: If ``type`` is not one of the supported names.
  """
  schedulers = torch.optim.lr_scheduler

  if type == 'plateau':
    return schedulers.ReduceLROnPlateau(optimizer, **kwargs)
  if type == 'step':
    return schedulers.StepLR(optimizer, **kwargs)
  if type == 'multi':
    return schedulers.MultiStepLR(optimizer, **kwargs)

  raise ValueError(f'Unknown scheduler: {type}')
    

def get_criterion(type: str, train_label_ct: Optional[torch.Tensor] = None, device='cpu', **kwargs):
  match type:
    case 'ce':
      weight = None
      if train_label_ct is not None:
        weight = train_label_ct.max() / train_label_ct
        weight = weight / weight.sum()
        weight = weight.to(device)

      return torch.nn.CrossEntropyLoss(
        weight=weight,
      )
    case 'cb_focal':
      return FocalLoss(
        class_counts=train_label_ct.to(device),
        **kwargs,
      )


In [None]:
from src.models.lnn import LNN
from torch.utils.data import DataLoader


# The training loop moves every batch to `device`, so the model has to live
# there too; otherwise the forward pass fails with a device mismatch on CUDA.
model = LNN(BATCH_SIZE, input_size, HIDDEN_SIZE).to(device)

# Pass each optimizer only the keyword arguments its constructor accepts
# (SGD has no betas / eps). Unknown names fall through with no kwargs so that
# get_optimizer raises its own ValueError.
optimizer_kwargs = {
  'adam': dict(betas=BETAS, eps=EPS, weight_decay=WEIGHT_DECAY),
  'sgd': dict(weight_decay=WEIGHT_DECAY),
}.get(OPTIMIZER, {})
optimizer = get_optimizer(OPTIMIZER, model, LR, **optimizer_kwargs)

# Same idea for the scheduler: StepLR takes step_size, MultiStepLR takes
# milestones, ReduceLROnPlateau takes factor / min_lr.
scheduler_kwargs = {
  'step': dict(step_size=STEP_SIZE, gamma=GAMMA),
  'multi': dict(milestones=MILESTONES, gamma=GAMMA),
  'plateau': dict(factor=GAMMA, min_lr=MIN_LR),
}.get(SCHEDULER, {})
scheduler = get_scheduler(SCHEDULER, optimizer, **scheduler_kwargs)

criterion = get_criterion(CRITERION, train_label_ct, device)

# NOTE(review): the name `train` is rebound by `def train(...)` in a later
# cell, so only the loaders below should be used to reach the training split.
train, valid, test = datasets[data]

# Only the training split is shuffled; evaluation order stays fixed.
train_loader = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(valid, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import numpy as np


def train(
    model: AbstractModel,
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: torch.nn.CrossEntropyLoss | FocalLoss,
    device: torch.device,
    epoch: int,
    writer: SummaryWriter = None,
  ):
  model.train()

  losses = np.zeros(len(dataloader))

  for idx, data in enumerage(dataloader):
    x = data[0].to(device)
    y = data[1].to(device)

    optimizer.zero_grad()
    logits = model(x)

    if len(logits.shape) == 1:
      logits = logits.unsqueeze(0)

    loss = criterion(logits, y)

    loss.backward()
    th.nn.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()

    if writer is not None:
      for l, (name, param) in enumerate(model.named_parameters()):
        if param.grad is not None:
          writer.add_scalar(f'Gradients/{l:02}_{name}', param.grad.norm().item(), epoch * len(dataloader) + idx)

    losses[idx] = loss.item()

  return losses.sum(), losses.mean()

# Per-epoch loss history, sized by the EPOCHS constant from the config cell.
train_losses = np.zeros(EPOCHS)
valid_losses = np.zeros(EPOCHS)

pb = tqdm(total=EPOCHS, desc="Epochs")
for epoch in range(EPOCHS):
  train_loss, train_loss_avg = train(model, train_loader, optimizer, criterion, device, epoch, writer=None)
  # Only ReduceLROnPlateau consumes a metric; for StepLR / MultiStepLR the
  # positional argument of step() is the (deprecated) epoch index, so passing
  # the loss there would corrupt the schedule.
  if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
    scheduler.step(train_loss)
  else:
    scheduler.step()

  