In [10]:
# Core Libraries
import os, random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# pytorch models
import pytorch_lightning as L
import torch.nn as nn
import torch
from torchvision import models
from model.model_v1 import *

# dataset imports
from torch.utils.data import Dataset, DataLoader
from dataset_loader import *

# Scikit Learn
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import classification_report, confusion_matrix

In [None]:
# Visualization of dataset properties.

In [None]:
# Setting the model.
# NOTE(review): ResNetFineTuner is defined in a later cell ("Training and fine
# tuning"). On a fresh kernel this cell only works if the class is provided by
# `from model.model_v1 import *`. TODO: move the class-definition cell above this one.
SEED = 42
# TODO confirm: must equal the number of label classes produced by BRISCDataModule
# (BRISC classification is presumably 4 classes: glioma, meningioma, pituitary, no tumor).
NUM_CLASSES = 4

# Seed python/numpy/torch (and dataloader workers) before the new classification
# head is randomly initialised, so Restart & Run All is reproducible.
L.seed_everything(SEED, workers=True)
model = ResNetFineTuner(num_classes=NUM_CLASSES, lr=1e-4, freeze_backbone=True)

In [None]:
# Load the BRISC dataset and train the model.
DATA_DIR = "data"  # relative to the notebook's working directory
MAX_EPOCHS = 50

data_path = DATA_DIR
dm = BRISCDataModule(data_path)
trainer = L.Trainer(
    accelerator="auto",
    devices=1,
    max_epochs=MAX_EPOCHS,
    # TODO: add callbacks=[loss_tracker] once a loss-tracking callback is defined.
)
# Note: re-running only this cell continues training the already-fitted `model`;
# re-run the model cell first to start from the pretrained weights again.
trainer.fit(model, datamodule=dm)

In [None]:
# Performing data augmentation.

In [None]:
# Training and fine tuning.
class ResNetFineTuner(L.LightningModule):
    """Fine-tune an ImageNet-pretrained ResNet-50 for image classification.

    The torchvision ResNet-50 backbone is loaded with ImageNet weights, and its
    final fully-connected layer (``fc``) is replaced by a freshly initialised
    ``nn.Linear`` head with ``num_classes`` outputs. Optionally, the pretrained
    parameters are frozen so that only the new head receives gradients.

    Args:
        num_classes: Number of outputs of the new classification head.
        lr: Learning rate for the Adam optimizer.
        freeze_backbone: If True, set ``requires_grad=False`` on every pretrained
            parameter before the new head is attached, so only the head trains.
    """

    def __init__(self, num_classes=10, lr=1e-3, freeze_backbone=True):
        super().__init__()
        # Records num_classes / lr / freeze_backbone in self.hparams (used by
        # configure_optimizers and saved in checkpoints).
        self.save_hyperparameters()

        backbone = self._load_pretrained_resnet50()
        if freeze_backbone:
            # NOTE(review): requires_grad=False stops weight updates, but BatchNorm
            # running statistics are still updated while the module is in train
            # mode. Put the backbone in eval mode if they should stay fixed.
            for param in backbone.parameters():
                param.requires_grad = False

        # Replace the ImageNet classifier head. The new layer is created after
        # freezing, so its parameters stay trainable.
        num_ftrs = backbone.fc.in_features
        backbone.fc = nn.Linear(num_ftrs, num_classes)
        self.model = backbone

        self.criterion = nn.CrossEntropyLoss()

    @staticmethod
    def _load_pretrained_resnet50():
        """Return a ResNet-50 loaded with the original ImageNet-1k (V1) weights.

        ``pretrained=True`` is deprecated since torchvision 0.13 in favour of the
        ``weights=`` API. ``IMAGENET1K_V1`` is the checkpoint that
        ``pretrained=True`` maps to, so both branches load the same weights.
        """
        weights_enum = getattr(models, "ResNet50_Weights", None)
        if weights_enum is not None:
            return models.resnet50(weights=weights_enum.IMAGENET1K_V1)
        return models.resnet50(pretrained=True)  # torchvision < 0.13

    def forward(self, x):
        """Return raw (unnormalised) class logits for a batch of images."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute, log (``train_loss``) and return the cross-entropy training loss."""
        # assumes batch is an (images, integer_class_labels) pair — TODO confirm
        # against BRISCDataModule
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        self.log('train_loss', loss)
        return loss

    def _shared_eval_step(self, batch, stage):
        """Log ``{stage}_loss`` and ``{stage}_acc`` (top-1 accuracy) for one batch."""
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        self.log(f'{stage}_loss', loss, prog_bar=True)
        self.log(f'{stage}_acc', acc, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss and accuracy as ``val_loss`` / ``val_acc``."""
        self._shared_eval_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        """Log test loss and accuracy as ``test_loss`` / ``test_acc`` (used by ``trainer.test``)."""
        self._shared_eval_step(batch, 'test')

    def configure_optimizers(self):
        """Adam over all parameters.

        Frozen parameters never receive gradients, so Adam skips them. They stay
        registered, though, so later unfreezing needs no optimizer rebuild.
        """
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

In [None]:
# Validation and Test.

In [None]:
# Plotting statistics (seaborn).