# Imports and Data

In [2]:
!pip install -q torchvision
!pip install -q pytorch_lightning
!pip install -q torchmetrics
!pip install -q tensorboard

In [3]:
import numpy as np
import random
import os
from PIL import Image as im
import torch
import torch.nn as nn
from torch.optim import Adam
from torchmetrics import Accuracy
from torchvision import transforms
from torch.utils.data import TensorDataset, DataLoader, random_split
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import pytorch_lightning as pl

%load_ext autoreload
%autoreload 2

  device: torch.device = torch.device("cpu"),


In [16]:
def set_seed(worker_ID = 0):
  """Seed NumPy's and Python's global RNGs from torch's initial seed.

  The seed used is ``(torch.initial_seed() + worker_ID) mod 2**32``, so passing
  distinct ``worker_ID`` values yields distinct streams. The signature matches a
  DataLoader ``worker_init_fn``, but nothing in this notebook wires it up yet.
  """
  derived_seed = (torch.initial_seed() + worker_ID) % (1 << 32)
  # NumPy first, then the stdlib RNG (same order as before).
  for seed_fn in (np.random.seed, random.seed):
    seed_fn(derived_seed)

def load_data(g, PATH_arr, batch_size=1, train_frac=0.8):
  """Load the preprocessed .npy arrays and wrap them in train/valid/test DataLoaders.

  Args:
    g: torch.Generator (or None) that drives the train/valid split and the
      shuffling of the training loader, so a fixed seed gives a fixed split/order.
    PATH_arr: sequence of exactly six .npy paths, in the order
      [X_train, y_train, mask_train, X_test, y_test, mask_test].
    batch_size: batch size used by all three loaders (default 1, as before).
    train_frac: fraction of the training arrays kept for training; the remainder
      becomes the validation split (default 0.8, as before).

  Returns:
    (train_dataloader, valid_dataloader, test_dataloader). Every batch is an
    (X, y, mask) triple; torch.Tensor(...) casts each array to the default
    float dtype (float32 unless changed).

  Raises:
    ValueError: if PATH_arr does not hold six paths or train_frac is not in (0, 1).
  """
  if len(PATH_arr) != 6:
    raise ValueError(f"PATH_arr must contain 6 paths, got {len(PATH_arr)}")
  if not 0.0 < train_frac < 1.0:
    raise ValueError(f"train_frac must be in (0, 1), got {train_frac}")

  # Load preprocessed data (np.load order matches the documented PATH_arr order).
  tensors = [torch.Tensor(np.load(path)) for path in PATH_arr]
  X_train, y_train, mask_train, X_test, y_test, mask_test = tensors

  # TODO(author): decide whether the masks need reformatting before training.

  full_train_dataset = TensorDataset(X_train, y_train, mask_train)
  num_data = len(full_train_dataset)
  # Take validation as the remainder so the two lengths always sum to num_data
  # (two independent float roundings could disagree with random_split's check).
  num_train = int(num_data * train_frac)
  num_valid = num_data - num_train
  train_dataset, valid_dataset = random_split(
      full_train_dataset, [num_train, num_valid], generator=g)
  test_dataset = TensorDataset(X_test, y_test, mask_test)

  # Only the training loader is shuffled; it reuses `g` so the order is reproducible.
  # Still open: num_workers / pin_memory / worker_init_fn.
  train_dataloader = DataLoader(train_dataset, batch_size=batch_size,
                                shuffle=True, generator=g)
  valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size)
  test_dataloader = DataLoader(test_dataset, batch_size=batch_size)
  return train_dataloader, valid_dataloader, test_dataloader

In [None]:
# TODO: add data augmentation for the training set (e.g. torchvision `transforms`, already imported).

# Training

In [None]:
torch.use_deterministic_algorithms(True, warn_only=True)

# Registry of model_name -> model, filled in by the Parameters cell. Re-running this
# cell resets it, so re-run the Parameters cell afterwards.
model_dict = {}
# to look over: torch.Generator, TensorBoardLogger params, EarlyStopping params, pl.Trainer params


def _make_checkpoint_callback():
    """Return a fresh ModelCheckpoint keeping the 3 lowest-val_loss checkpoints.

    ModelCheckpoint is stateful (it tracks its best checkpoints), so one instance
    is created per run instead of sharing a module-level object across runs.
    """
    # Metric placeholders must be inside braces to be formatted; with Lightning's
    # default auto_insert_metric_name this renders like "epoch=03-val_loss=0.12".
    return ModelCheckpoint(save_top_k=3, monitor="val_loss", mode="min",
                           filename="best_val-{epoch:02d}-{val_loss:.2f}")


def test(model_name, PATH_arr, seed, max_epochs=300, patience=10,
         accelerator="gpu", devices=1):
    """Fit model_dict[model_name] with early stopping and checkpointing.

    Despite its name, this only trains; evaluating on the returned test loader is
    left to the caller (e.g. ``trainer.test(..., ckpt_path="best")``).

    Args:
      model_name: key into the module-level ``model_dict``; also the TensorBoard
        run name under ``test_logs/``.
      PATH_arr: six .npy paths forwarded to ``load_data``.
      seed: seed for ``pl.seed_everything`` and for the data generator.
      max_epochs, patience, accelerator, devices: Trainer / EarlyStopping settings;
        the defaults equal the previously hard-coded values.

    Returns:
      (trainer, test_dataloader): the fitted Trainer and the test-set loader.
    """
    pl.seed_everything(seed, workers=True)
    g = torch.Generator()
    g.manual_seed(seed)
    train_dataloader, valid_dataloader, test_dataloader = load_data(g, PATH_arr)
    logger = pl.loggers.TensorBoardLogger('test_logs', name=model_name)
    model_naive = model_dict[model_name]
    early_stopping_callback = EarlyStopping(monitor='val_loss', mode='min', patience=patience)
    # auto_lr_find is omitted: it only acted through trainer.tune() (never called
    # here), and the argument no longer exists in Lightning 2.x.
    # deterministic="warn" matches the warn_only=True set above; deterministic=True
    # would switch PyTorch back to strict mode and error on non-deterministic ops.
    trainer = pl.Trainer(max_epochs=max_epochs, logger=logger,
                         callbacks=[early_stopping_callback, _make_checkpoint_callback()],
                         accelerator=accelerator, devices=devices,
                         deterministic="warn", gradient_clip_val=0.5)
    trainer.fit(model_naive, train_dataloader, valid_dataloader)
    return trainer, test_dataloader

# Parameters

In [None]:
# TODO: import the Python files that define the models, and put instances in model_dict below.

# .npy paths in the order load_data reads them:
# X_train, y_train, mask_train, X_test, y_test, mask_test  (placeholders - fill in)
PATH_arr = [''] * 6

# Not referenced elsewhere in this notebook yet.
learning_rate = 3e-4
# model_name -> model. The 1s are placeholders; trainer.fit needs real
# LightningModule instances here before training can run.
model_dict = {f"model_{idx}": 1 for idx in range(1, 6)}

model_name = "model_1"

In [None]:
# `test` returns (trainer, test_dataloader); the trained module is the object stored
# in model_dict, which the evaluation cell below needs as `model`. Seed = 0.
trainer, test_dataloader = test(model_name, PATH_arr, 0)
model = model_dict[model_name]

In [None]:
# Evaluate the best checkpoint (by val_loss) on the held-out test set.
trainer.test(model=model, dataloaders=test_dataloader, ckpt_path="best", verbose=True)

In [None]:
# Show TensorBoard inline for the runs that TensorBoardLogger wrote under test_logs/.
%load_ext tensorboard
%tensorboard --logdir test_logs/