# Imports

In [None]:
import datetime
import os
import time
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from tqdm.notebook import tqdm

# Helper functions

In [None]:
# Plotting helpers
def random_plots(dataset):
    """
    Draws five randomly chosen samples from a dataset in a 2x5 grid of axes.
    Each column shows one sample: the top axis renders it with imshow (gray colormap)
    and carries the label as title, the bottom axis passes the same squeezed array to plot.
    Parameters:
        dataset: Object supporting len() and integer indexing that returns (img, label) pairs
    """
    _, axes = plt.subplots(2, 5, layout="tight", figsize=(15, 10))
    # Transposing the axes grid yields one (top, bottom) pair per column
    for image_ax, curve_ax in axes.T:
        sample_idx = np.random.randint(len(dataset))
        img, label = dataset[sample_idx]
        image_ax.imshow(img.squeeze(), cmap="gray")
        image_ax.set_title(f"Label: {label}")
        curve_ax.plot(img.squeeze())

In [None]:
# Model trainers and testers
def modeltrainer(
    model,
    optimizer,
    trainloader: DataLoader,
    valloader: DataLoader = None,
    epochs: int = 5,
    criterion=nn.MSELoss(),
    scheduler=None,
    meta: bool = False,
    pre_trained_models: dict = None,
    meta_pred_func=None,
) -> tuple:
    """
    Trains a model and returns it together with the per-epoch loss lists. Can also train meta models.
    The model is moved to the GPU when CUDA is available, otherwise to the CPU.
    Parameters:
        model (PyTorch Model): An instance of a PyTorch model
        optimizer (torch.optim): An instance of an optimizer linked to the model's parameters
        trainloader (DataLoader): Train dataloader yielding (features, target) batches
        valloader (DataLoader, optional, None): Validation loader yielding (features, target) batches.
            Validation is skipped when it is None or empty
        epochs (int, recommended, 5): No. of epochs to be trained for
        criterion (optional, nn.MSELoss()): Loss function called as criterion(pred, target)
        scheduler (optional, None): Learning rate scheduler; step() is called without arguments once per epoch
        meta (bool, optional, False): Whether the model to be trained is a meta model
        pre_trained_models (dict, optional, None): Pre-trained models whose outputs are to be used in meta training
        meta_pred_func (function, optional, None): Called as
            meta_pred_func(pre_trained_models=..., X=features, weights=model(features)) to produce the prediction
    Returns:
        model: Trained PyTorch model (left in training mode)
        train_loss: Mean training batch loss of each epoch
        val_loss: Mean validation batch loss of each epoch (empty if validation was skipped)
    """
    train_loss = []
    val_loss = []

    # Establishing valid meta parameters for meta mode
    use_meta = meta is True
    if use_meta:
        assert pre_trained_models is not None, "Provide pre_trained_models!"
        assert meta_pred_func is not None, "Provide meta_pred_func!"

    # Using GPU if available; the move is loop-invariant so it is done once
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    def predict(features):
        """Shared forward pass for training and validation, returned on the training device."""
        if use_meta:
            weights = model(features)
            pred = meta_pred_func(
                pre_trained_models=pre_trained_models, X=features, weights=weights
            )
        else:
            pred = model(features)
        return pred.to(device)

    # A loader with a single batch is still a valid validation set
    run_validation = (valloader is not None) and (len(valloader) > 0)

    for epoch in range(epochs):
        model.train()  # prep model for training
        running_loss = 0.0
        epoch_start = time.time()
        # Context manager guarantees the bar is closed even if a batch raises
        with tqdm(total=len(trainloader), leave=True) as pbar:
            for features, target in trainloader:
                features, target = features.to(device), target.to(device)
                optimizer.zero_grad()
                loss = criterion(predict(features), target)
                loss.backward()
                optimizer.step()
                batch_loss = loss.item()
                running_loss += batch_loss
                pbar.update()
                pbar.set_description(f"Train loss: {batch_loss} | EP({epoch})")
        train_loss.append(running_loss / len(trainloader))
        if scheduler is not None:
            scheduler.step()
        epoch_end = time.time()
        epoch_time = time.strftime("%H:%M:%S", time.gmtime(epoch_end - epoch_start))
        print(f"Epoch finished in {epoch_time}")

        if run_validation:
            # Evaluation mode so Dropout/BatchNorm layers do not behave as in training
            # (and BatchNorm running statistics are not updated by validation data)
            model.eval()
            running_val_loss = 0.0
            with torch.no_grad():
                for features, target in valloader:
                    features, target = features.to(device), target.to(device)
                    loss = criterion(predict(features), target)
                    running_val_loss += loss.item()
            val_loss.append(running_val_loss / len(valloader))
            # Hand the model back in training mode, as callers received it before
            model.train()
    return model, train_loss, val_loss


# Optim Func - Change optimizer here if required
def optim(model):
    """
    Returns an Adam optimizer (lr=0.001) over the parameters of the given model.
    The model is first moved to the GPU when CUDA is available, otherwise to the CPU.
    Parameters:
        model (PyTorch Model): Model whose parameters are to be optimized
    Returns:
        torch.optim.Adam: Optimizer linked to the model's parameters
    """
    if torch.cuda.is_available():
        target_device = torch.device("cuda")
    else:
        target_device = torch.device("cpu")
    placed_model = model.to(target_device)
    learnable = placed_model.parameters()
    return torch.optim.Adam(learnable, lr=0.001)

# DataLoaders and Datasets for loading datasets

In [None]:
class CT1Set(Dataset):
    """
    Creates a PyTorch Dataset given a set of file paths and a label mapping.
    Each item is loaded with np.load and labelled by looking up the name of the
    file's parent directory in the label mapping.
    """

    def __init__(self, file_paths, label_map, transform=None, target_transform=None):
        """
        Parameters:
            file_paths: Indexable collection of paths (pathlib.Path or str) readable by np.load
            label_map: Mapping from parent directory name to label
            transform (optional, None): Callable applied to the loaded array
            target_transform (optional, None): Callable applied to the label
        """
        self.files = file_paths
        self.transform = transform
        self.label_map = label_map
        # Previously hard-coded to None, which silently dropped the caller's transform
        self.target_transform = target_transform

    def __len__(self):
        """Returns the number of files in the dataset"""
        return len(self.files)

    def __getitem__(self, idx):
        """
        Loads the file at the given index.
        Returns:
            image: Loaded array, after transform if one was given
            label: Label of the parent directory, after target_transform if one was given
        Raises:
            KeyError: If the parent directory name is missing from label_map
        """
        # Path() is a no-op for Path inputs and lets plain string paths work too
        file = Path(self.files[idx])
        image = np.load(file)
        label = self.label_map[file.parent.name]

        # Explicit None checks: a callable may itself be falsy (e.g. an empty container module)
        if self.transform is not None:
            image = self.transform(image)

        if self.target_transform is not None:
            label = self.target_transform(label)

        return image, label

# Model Classes

In [None]:
# CNN Generator
class CNN(nn.Module):
    """
    Creates a 1-D CNN followed by an MLP head based on specifications.
    The input of forward is a batch of flat feature vectors of shape (batch, inp_size);
    a single channel dimension is added internally before the convolutions.
    """

    def __init__(
        self,
        inp_size: int,
        n_convs: int,
        n_lin: int,
        hid_size: int,
        out_size: int,
        batnorm: bool = False,
        c_start: int = 1,
        increasing: bool = False,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 0,
        activation=nn.LeakyReLU(),
        last_activation=nn.LeakyReLU(),
    ):
        """
        Initializes the CNN with the given parameters.
        Parameters:
            inp_size (int): Number of features given as input to the model. Must be at least 1
            n_convs (int): Length of the base channel sequence [4**0, ..., 4**(n_convs-1)]. Must be at least 1.
                The number of convolutional layers is one less than the length of the final channel
                sequence: n_convs - 1 when increasing, otherwise 2*n_convs - 1 (even n_convs)
                or 2*n_convs - 2 (odd n_convs). n_convs=1 therefore yields no convolution at all
            n_lin (int): Number of hidden-to-hidden linear layers. Must be at least 1.
                The MLP additionally has one input layer and one output layer
            hid_size (int): Number of neurons in each hidden layer. Must be at least 1
            out_size (int): Number of variables to be predicted. Must be at least 1
            batnorm (bool, optional, False): Whether to use batch normalization after each convolution
            c_start (int, optional, 1): Validated but currently unused; kept for interface compatibility
            increasing (bool, optional, False): Whether convolutional depth must keep increasing (True)
                or be pyramidal, i.e. increase and then mirror back down to 1 channel (False)
            kernel_size (int, optional, 3): Kernel size for each convolutional layer
            stride (int, optional, 1): Stride for kernel in each convolutional layer
            padding (int, optional, 0): Padding in each convolutional layer
            activation: Activation module used after every convolution and every non-final linear layer
            last_activation: Activation module applied after the output layer
        Raises:
            ValueError: If any of the integer parameters checked below is smaller than 1
        """
        super().__init__()
        must_be_positive = (
            inp_size,
            n_convs,
            n_lin,
            hid_size,
            out_size,
            c_start,
            kernel_size,
            stride,
        )
        if any(value < 1 for value in must_be_positive):
            raise ValueError(
                "Please enter a value greater than or equal to 1 for all the integer parameters!"
            )

        self.inp_size = inp_size

        # Channel sequence; starts at 1 because forward adds a single channel dimension
        channels = [4**i for i in range(n_convs)]
        if not increasing:
            # Pyramidal: mirror the sequence back down (the peak is not repeated for odd n_convs)
            mirrored = channels if n_convs % 2 == 0 else channels[:-1]
            channels = channels + mirrored[::-1]

        # Conv part generation. Conv1d matches the (batch, 1, inp_size) tensors produced in
        # forward and the BatchNorm1d layers; Conv2d rejected those 3-D inputs at construction.
        conv_part = []
        for c_in, c_out in zip(channels, channels[1:]):
            conv_part.append(
                nn.Conv1d(
                    c_in,
                    c_out,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                )
            )
            if batnorm:
                conv_part.append(nn.BatchNorm1d(c_out))
            conv_part.append(activation)
        self.cnn = nn.Sequential(*conv_part)

        # Size of the flattened convolutional output, i.e. input width of the MLP
        conv_out = self._calc_first_hid()

        # Linear part generation (MLP). Every hidden layer gets its own nn.Linear instance;
        # multiplying a list would repeat one instance and tie the weights of all hidden layers.
        lin_part = [nn.Linear(conv_out, hid_size), activation]
        for _ in range(n_lin):
            lin_part.extend([nn.Linear(hid_size, hid_size), activation])
        lin_part.extend([nn.Linear(hid_size, out_size), last_activation])
        self.mlp = nn.Sequential(*lin_part)

    def _calc_first_hid(self):
        """
        Returns the number of features per sample after the convolutional part,
        measured by passing a dummy input of shape (1, 1, inp_size) through it.
        """
        was_training = self.cnn.training
        # Eval mode + no_grad so the probe neither updates BatchNorm running
        # statistics nor builds an autograd graph
        self.cnn.eval()
        try:
            with torch.no_grad():
                probe = self.cnn(torch.zeros(1, 1, self.inp_size))
        finally:
            self.cnn.train(was_training)
        return probe.flatten(1).shape[1]

    def forward(self, x):
        """
        Forward passes the input tensor.
        Parameters:
            x (Tensor): Batch of inputs of shape (batch, inp_size)
        Returns:
            Tensor of shape (batch,) when out_size is 1, otherwise (batch, out_size)
        """
        x = x.unsqueeze(1)  # add the channel dimension expected by Conv1d
        cnn_out = self.cnn(x).flatten(1)
        # Squeeze only the output dimension so a batch of one keeps its batch dimension
        output = self.mlp(cnn_out).squeeze(-1)
        return output