# Notebook Set-up

In [None]:
# Install timm; the model cell in this notebook builds its backbone with timm.create_model.
!pip install timm

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.transforms.functional as TF
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import torchmetrics
import timm

import numpy as np
from matplotlib import pyplot as plt
import pandas as pd
from PIL import Image
import random
import math

from datetime import date
import time
import csv
import os

# Print the GPU status reported by the driver for this session.
!nvidia-smi

# Preprocessing

In [None]:
# Directory holding the "TS40 deluxe clean" sample.
clean_deluxe = '/global/cfs/projectdirs/cosmo/work/users/xhuang/dr10_1/Clean-Samples/TS40_deluxe_clean'
# data_path is the directory the loading cell reads its .npy files from;
# point it at a different sample directory to switch datasets.
data_path = clean_deluxe

 * Clean deluxe is our highest-quality sample. The Clean-Samples directory also contains TS40 Baseline, which has more samples, but some of its positive and negative candidates may be less clear or contain additional noise.

In [None]:
# Images are clamped to [-1, 1] as they are loaded; labels are reshaped
# into a single column so each sample carries a length-1 target.
xtrain = np.clip(np.load(f"{data_path}/train_x.npy"), -1, 1)
xval = np.clip(np.load(f"{data_path}/val_x.npy"), -1, 1)

ytrain = np.load(f"{data_path}/train_y.npy").reshape(-1, 1)
yval = np.load(f"{data_path}/val_y.npy").reshape(-1, 1)

In [None]:
class ReflectiveAugmentor:
    """Random flip / rotation augmentation that leaves no empty corners.

    Before rotating, the image is reflect-padded by just enough that the
    final centre crop only ever samples real (mirrored) pixels instead of
    the fill value that a plain rotation would introduce at the corners.

    Args:
        rotation: Apply a random rotation in ``[-max_angle, max_angle]``.
        flip: Apply independent random horizontal and vertical flips.
        max_angle: Largest rotation magnitude, in degrees.
        pad_margin: Extra padding, as a fraction of the image side, added
            on top of the geometrically required amount.
    """

    def __init__(self, rotation=True, flip=True, max_angle=180, pad_margin=0.01):
        self.rotation = rotation
        self.flip = flip
        self.max_angle = max_angle
        self.pad_margin = pad_margin

    def __call__(self, image):
        """Return an augmented copy of ``image``.

        Args:
            image: A ``(C, H, W)`` tensor, or anything ``torch.as_tensor``
                accepts (e.g. a numpy array of that shape).

        Returns:
            A tensor with the same shape as the input.
        """
        # The torchvision functional ops and F.pad need a tensor; the
        # dataset arrays come from np.load, so convert here.
        image = torch.as_tensor(image)

        if self.flip:
            if random.random() > 0.5:
                image = TF.hflip(image)
            if random.random() > 0.5:
                image = TF.vflip(image)

        if self.rotation:
            angle = random.uniform(-self.max_angle, self.max_angle)

            _, h, w = image.shape
            pad_r = self.pad_margin + self.pad_ratio(angle)
            # Round up so the padding is never short by a pixel, and stay
            # below the image size because reflect padding requires
            # pad < dimension.
            pad_h = max(0, min(math.ceil(pad_r * h), h - 1))
            pad_w = max(0, min(math.ceil(pad_r * w), w - 1))

            # A temporary batch dimension keeps 2-D reflect padding working
            # on PyTorch versions that only accept batched input.
            image = F.pad(
                image.unsqueeze(0), (pad_w, pad_w, pad_h, pad_h), mode='reflect'
            ).squeeze(0)

            image = TF.rotate(image, angle, interpolation=T.InterpolationMode.BILINEAR)

            # Crop back to the original size, discarding the padded border.
            image = TF.center_crop(image, (h, w))

        return image

    @staticmethod
    def pad_ratio(angle):
        """Return the per-side padding fraction needed for ``angle`` degrees.

        A side-``s`` square rotated by ``angle`` has a bounding box of side
        ``s * (|cos| + |sin|)``, so each side needs
        ``(|cos| + |sin| - 1) / 2`` of ``s`` as padding (0 at multiples of
        90 degrees, about 0.207 at 45 degrees). The ratio is exact for
        square images.
        """
        theta = math.radians(angle)
        ratio = (abs(math.cos(theta)) + abs(math.sin(theta)) - 1.0) / 2.0
        # Guard against a tiny negative value from floating-point rounding.
        return max(0.0, ratio)

class LensDataset(Dataset):
    """Map-style dataset pairing each image with its label.

    Args:
        X: Indexable collection of images (e.g. a numpy array whose first
            axis is the sample axis).
        y: Indexable collection of labels, the same length as ``X``.
        augmentor: Optional callable applied to each image tensor as it is
            fetched.

    Raises:
        ValueError: If ``X`` and ``y`` have different lengths.
    """

    def __init__(self, X, y, augmentor=None):
        if len(X) != len(y):
            raise ValueError(
                f"X and y must have the same length, got {len(X)} and {len(y)}"
            )
        self.X = X
        self.y = y
        self.augmentor = augmentor

    def __len__(self):
        # The samples are stored on self.X; there is no self.images.
        return len(self.X)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx`` as tensors."""
        # Convert before augmenting: the augmentor operates on tensors.
        image = torch.as_tensor(self.X[idx])
        label = torch.as_tensor(self.y[idx])

        if self.augmentor is not None:
            image = self.augmentor(image)

        return image, label

In [None]:
# Loader-wide settings.
batch_size = 1024
seed = 42

# Only the training split is augmented; validation images are used as stored.
train_dataset = LensDataset(xtrain, ytrain, augmentor=ReflectiveAugmentor())
val_dataset = LensDataset(xval, yval)

# Seeded generator (manual_seed returns the generator it was called on).
generator = torch.Generator().manual_seed(seed)
def worker_seed_func(worker_id):
    """Seed Python's and NumPy's RNGs from the current torch seed.

    Intended as a DataLoader ``worker_init_fn``. ``worker_id`` is accepted
    to match that callback signature but is not used.
    """
    # NumPy only accepts 32-bit seeds, so fold the torch seed down to 32 bits.
    derived_seed = torch.initial_seed() % (1 << 32)
    random.seed(derived_seed)
    np.random.seed(derived_seed)

def make_shuffled_dataloader(dataset, seed, epoch):
    """Build a shuffling DataLoader whose RNG is seeded with ``seed + epoch``.

    Args:
        dataset: Dataset to iterate over.
        seed: Base seed.
        epoch: Epoch index added to ``seed`` to seed the loader's generator.

    Returns:
        A DataLoader using the module-level ``batch_size``, four workers,
        pinned memory, and dropping the last incomplete batch.
    """
    epoch_rng = torch.Generator().manual_seed(seed + epoch)
    loader_options = {
        "batch_size": batch_size,
        "shuffle": True,
        "generator": epoch_rng,
        "num_workers": 4,
        "pin_memory": True,
        "drop_last": True,
        "worker_init_fn": worker_seed_func,
    }
    return DataLoader(dataset=dataset, **loader_options)

# Validation must iterate over the held-out split (val_dataset), not the
# augmented training set. No shuffling, and every sample is kept.
val_loader = torch.utils.data.DataLoader(
    dataset = val_dataset,
    batch_size = batch_size,
    num_workers = 4,
    pin_memory = True,
    drop_last = False,
)

# Training

##### Create CSV training log

In [None]:
import shutil

### To resume training, set epoch to last save epoch and set LR to last known LR.
START_EPOCH = 0
lr_stopped_at = 0.0

run_name = "F1"
today = date.today()
d1 = today.strftime("%d_%m_%Y")
# save_dir is built only from today's date and run_name, so starting a fresh
# run (START_EPOCH == 0) DELETES any earlier run with the same name from the
# same day. Change run_name to keep several runs from one day.

parent_dir = "_Time_Trials"
save_dir = parent_dir + "/" + d1 + run_name

if START_EPOCH == 0:
    # os/shutil instead of shell magics: no error when parent_dir already
    # exists, and the cell also works outside IPython.
    os.makedirs(parent_dir, exist_ok=True)
    shutil.rmtree(save_dir, ignore_errors=True)
    os.makedirs(save_dir)
    print("CREATED DIRECTORY")

In [None]:
log_file = f"{save_dir}/training_log.csv"
log_fields = ["epoch", "train_loss", "train_auc", "train_precision", "train_recall", "val_loss", "val_auc", "val_precision", "val_recall"]

# Make sure the run directory exists before touching the log.
os.makedirs(save_dir, exist_ok=True)

# Write the header only for a fresh run (or when no log exists yet).
# Opening with mode='w' unconditionally would wipe the existing log when
# resuming with START_EPOCH > 0.
if START_EPOCH == 0 or not os.path.exists(log_file):
    with open(log_file, mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(log_fields)

##### Model Set Up

In [None]:
class MaxViTClassifier(nn.Module):
    """timm MaxViT backbone followed by a small sigmoid classification head.

    Args:
        num_classes: Number of output units (1 for binary classification).
        pretrained: Whether timm should load pretrained backbone weights.
        model_name: timm model identifier for the backbone.
            NOTE(review): confirm the default is a registered timm model
            (``timm.list_models('maxvit*')``) and pick the intended variant.

    The forward pass returns probabilities in ``[0, 1]`` with shape
    ``(batch, num_classes)``.
    """

    def __init__(self, num_classes=1, pretrained=True, model_name='maxvit_small_256'):
        super().__init__()
        # num_classes=0 drops timm's own classifier layer; this module
        # supplies its own head below.
        self.backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0)

        # The head is kept outside the backbone and fed from
        # forward_features, so it does not depend on how timm's built-in
        # head pools or is called.
        self.head = self._build_head(self.backbone.num_features, num_classes)

    @staticmethod
    def _build_head(num_features, num_classes):
        """Build the head: pool -> flatten -> BN -> Linear -> ReLU -> Linear -> Sigmoid."""
        return nn.Sequential(
            # Pool the spatial map to one vector per image so that
            # BatchNorm1d / Linear receive a (batch, num_features) input.
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.BatchNorm1d(num_features),
            nn.Linear(num_features, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # Assumes forward_features returns a (batch, num_features, H, W)
        # feature map for this backbone — TODO confirm for the chosen model.
        features = self.backbone.forward_features(x)
        return self.head(features)

In [None]:
# The metrics are moved to `device`, so define it in this cell rather than
# relying on the later training cell having been run first.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Current torchmetrics requires an explicit task; with task="binary" the
# positive class is label 1. Precision and recall count a prediction as
# positive only when the predicted probability exceeds 0.9.
auc_metric = torchmetrics.AUROC(task="binary").to(device)
precision_metric = torchmetrics.Precision(task="binary", threshold=0.9).to(device)
recall_metric = torchmetrics.Recall(task="binary", threshold=0.9).to(device)

In [None]:
def train_epoch(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over ``dataloader``.

    Uses the module-level ``auc_metric``, ``precision_metric`` and
    ``recall_metric`` objects, which are reset before returning.

    Args:
        model: Network returning one probability per sample.
        dataloader: Yields ``(x_batch, y_batch)`` pairs.
        optimizer: Optimizer stepping ``model``'s parameters.
        criterion: Loss taking ``(outputs, targets)``.
        device: Device the batches are moved to.

    Returns:
        ``(loss, auc, precision, recall)`` for the epoch, where ``loss`` is
        the mean per-sample loss over the samples actually processed.
    """
    model.train()
    running_loss = 0.0
    samples_seen = 0

    for x_batch, y_batch in dataloader:
        # NOTE(review): .float() assumes the model runs in float32 — confirm.
        x_batch = x_batch.to(device).float()
        # Flatten targets and outputs to the same (batch,) shape: labels
        # arrive as (batch, 1), and BCELoss rejects mismatched shapes.
        # reshape(-1) is also safe for a batch of one, unlike squeeze().
        y_batch = y_batch.to(device).float().reshape(-1)

        optimizer.zero_grad()
        outputs = model(x_batch).reshape(-1)
        loss = criterion(outputs, y_batch)
        loss.backward()
        optimizer.step()

        batch_len = x_batch.size(0)
        running_loss += loss.item() * batch_len
        samples_seen += batch_len

        # Metrics do not need gradients.
        probs = outputs.detach()
        targets = y_batch.int()
        auc_metric.update(probs, targets)
        precision_metric.update(probs, targets)
        recall_metric.update(probs, targets)

    # Divide by the samples actually seen: with drop_last=True the loader
    # skips the final partial batch, so len(dataloader.dataset) overcounts.
    epoch_loss = running_loss / max(samples_seen, 1)
    epoch_auc = auc_metric.compute().item()
    epoch_precision = precision_metric.compute().item()
    epoch_recall = recall_metric.compute().item()

    auc_metric.reset()
    precision_metric.reset()
    recall_metric.reset()

    return epoch_loss, epoch_auc, epoch_precision, epoch_recall


def validate_epoch(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without updating weights.

    Uses the module-level ``auc_metric``, ``precision_metric`` and
    ``recall_metric`` objects, which are reset before returning.

    Args:
        model: Network returning one probability per sample.
        dataloader: Yields ``(x_batch, y_batch)`` pairs.
        criterion: Loss taking ``(outputs, targets)``.
        device: Device the batches are moved to.

    Returns:
        ``(loss, auc, precision, recall)`` for the epoch, where ``loss`` is
        the mean per-sample loss over the samples actually processed.
    """
    model.eval()
    running_loss = 0.0
    samples_seen = 0

    with torch.no_grad():
        for x_batch, y_batch in dataloader:
            # NOTE(review): .float() assumes the model runs in float32 — confirm.
            x_batch = x_batch.to(device).float()
            # Flatten targets and outputs to the same (batch,) shape:
            # labels arrive as (batch, 1), and BCELoss rejects mismatched
            # shapes. reshape(-1) is also safe for a batch of one.
            y_batch = y_batch.to(device).float().reshape(-1)

            outputs = model(x_batch).reshape(-1)
            loss = criterion(outputs, y_batch)

            batch_len = x_batch.size(0)
            running_loss += loss.item() * batch_len
            samples_seen += batch_len

            targets = y_batch.int()
            auc_metric.update(outputs, targets)
            precision_metric.update(outputs, targets)
            recall_metric.update(outputs, targets)

    epoch_loss = running_loss / max(samples_seen, 1)
    epoch_auc = auc_metric.compute().item()
    epoch_precision = precision_metric.compute().item()
    epoch_recall = recall_metric.compute().item()

    auc_metric.reset()
    precision_metric.reset()
    recall_metric.reset()

    return epoch_loss, epoch_auc, epoch_precision, epoch_recall

##### Train

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = MaxViTClassifier()
model = torch.nn.DataParallel(model)
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=5e-4)
# The model ends in a Sigmoid, so plain BCELoss on probabilities is used.
criterion = nn.BCELoss()

num_epochs = 160
# (epoch index, val_auc) of the best checkpoint so far; None until the first epoch.
best_epoch = None

# NOTE(review): START_EPOCH and lr_stopped_at from the run-setup cell are not
# consumed here; resuming also needs a checkpoint load — confirm the intended flow.
for epoch in range(num_epochs):
    # Build a fresh loader every epoch so the shuffle order is seeded with seed + epoch.
    train_loader = make_shuffled_dataloader(train_dataset, seed, epoch)

    train_loss, train_auc, train_prec, train_rec = train_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_auc, val_prec, val_rec = validate_epoch(model, val_loader, criterion, device)

    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"Train | loss: {train_loss:.4f} AUC: {train_auc:.4f} Prec: {train_prec:.4f} Rec: {train_rec:.4f}")
    print(f"Val   | loss: {val_loss:.4f} AUC: {val_auc:.4f} Prec: {val_prec:.4f} Rec: {val_rec:.4f}")

    if best_epoch is None or val_auc >= best_epoch[1]:
        torch.save(model.state_dict(), f"{save_dir}/chkpt_epoch_{epoch}.pt")
        # Report against the previous best before overwriting it.
        if best_epoch is None:
            print(f"val_auc of {val_auc:.4f} is the first recorded value. Checkpoint Saved.")
        else:
            print(f"val_auc of {val_auc:.4f} beat previous best {best_epoch[1]:.4f}. Checkpoint Saved.")
        best_epoch = (epoch, val_auc)
    print("--------")

    # Append this epoch's metrics; the column order matches log_fields.
    with open(log_file, mode='a', newline='') as file:
        writer = csv.writer(file)
        writer.writerow([
            epoch + 1,
            train_loss,
            train_auc,
            train_prec,
            train_rec,
            val_loss,
            val_auc,
            val_prec,
            val_rec,
        ])

In [None]:
# Save the weights as they stand at the end of the run.
# NOTE(review): `model` is wrapped in torch.nn.DataParallel in the training cell, so the
# saved keys presumably carry a "module." prefix — confirm before loading into an unwrapped model.
torch.save(model.state_dict(), f"{save_dir}/endrun.pt")