 # 1. Setup and Initializations

In [1]:
import copy
import os
import timm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
import json
import random
from PIL import Image
import warnings
import time
from monai.networks.nets import ViT

# Silence every warning category for the remainder of the session
warnings.filterwarnings("ignore")

# Choose the compute backend in priority order: CUDA, then Apple MPS, then CPU
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Dataset over a JSON list of [image_path, label] pairs (first value is the path to the image, second is the label); wrap it in a DataLoader to batch it
class CustomDataset(torch.utils.data.Dataset):
    """Image dataset backed by a list of ``[image_sub_path, label]`` pairs.

    Args:
        json_data: Indexable collection whose items unpack into
            ``(image_sub_path, label)``.
        label_map: Mapping from each raw label to its integer class index.
        custom_transform: Optional callable applied to the RGB PIL image. When
            omitted, images are resized to 224x224, converted to a tensor and
            normalized per channel.
        root: Directory that every ``image_sub_path`` is joined onto.
            Defaults to ``"data"``.
    """

    def __init__(self, json_data, label_map, custom_transform=None, root="data"):
        self.data = json_data
        self.label_map = label_map
        self.root = root
        if custom_transform is not None:
            self.transform = custom_transform
        else:
            # Only build the default pipeline when it is actually needed
            self.transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(transformed_image, label_index_tensor)`` for sample *idx*."""
        img_sub_path, label = self.data[idx]
        img_path = os.path.join(self.root, img_sub_path)

        # The context manager releases the file handle as soon as the pixels
        # have been decoded into the converted copy.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        # Transform the image
        image = self.transform(image)

        # Convert label to tensor
        label = torch.tensor(self.label_map[label], dtype=torch.long)

        return image, label

# ------------------
# Dataset Functions
# ------------------
def split_dataset(data_path, train_percent=0.7, seed=None):
    """Split every class directory into train/val/test and write combined JSON files.

    Each sub-directory of *data_path* must contain a ``data.json`` holding a
    list of samples. Every directory is shuffled and split independently, so
    each class keeps the same proportions. The per-class pieces are then
    concatenated and written to ``train.json``, ``val.json`` and ``test.json``
    inside *data_path*. The remainder after the train share is divided equally
    between validation and test; rounding leftovers go to the test set.

    Args:
        data_path: Directory containing one sub-directory per class.
        train_percent: Fraction (0..1) of each class assigned to training.
        seed: Optional seed for a private random generator, making the split
            reproducible. When ``None`` the global ``random`` state is used.

    Raises:
        ValueError: If *train_percent* is outside ``[0, 1]``.
    """
    if not 0 <= train_percent <= 1:
        raise ValueError(f"train_percent must be within [0, 1], got {train_percent}")

    rng = random if seed is None else random.Random(seed)
    val_percent = (1 - train_percent) / 2
    splits = {"train": [], "val": [], "test": []}

    # Sorted so that a fixed seed yields the same split regardless of the
    # order in which the OS lists the directories.
    for entry in sorted(os.listdir(data_path)):
        class_dir = os.path.join(data_path, entry)
        # Skip non directories (including previously written split files)
        if not os.path.isdir(class_dir):
            continue

        with open(os.path.join(class_dir, "data.json"), "r") as f:
            samples = json.load(f)

        # Shuffle a copy so the loaded list itself is left untouched
        shuffled = list(samples)
        rng.shuffle(shuffled)

        train_size = int(train_percent * len(shuffled))
        val_size = int(val_percent * len(shuffled))

        splits["train"].extend(shuffled[:train_size])
        splits["val"].extend(shuffled[train_size:train_size + val_size])
        splits["test"].extend(shuffled[train_size + val_size:])

    # Save the split data into separate JSON files
    for split_name, items in splits.items():
        with open(os.path.join(data_path, f"{split_name}.json"), "w") as f:
            json.dump(items, f)

    # Print length of each set
    print(f"Train set size: {len(splits['train'])}")
    print(f"Validation set size: {len(splits['val'])}")
    print(f"Test set size: {len(splits['test'])}")

def get_loaders(base_path, label_map, batch_size):
    """Build the train/val/test ``DataLoader`` objects from the split JSON files.

    Args:
        base_path: Directory containing ``train.json``, ``val.json`` and
            ``test.json``.
        label_map: Mapping from raw label to class index, forwarded to
            ``CustomDataset``.
        batch_size: Batch size used by all three loaders.

    Returns:
        ``(train_loader, val_loader, test_loader)``. Only the training loader
        shuffles its samples.
    """
    def _read_split(split_name):
        # Context manager guarantees the file handle is closed after parsing
        with open(os.path.join(base_path, f"{split_name}.json"), "r") as f:
            return json.load(f)

    train_dataset = CustomDataset(_read_split("train"), label_map=label_map)
    val_dataset = CustomDataset(_read_split("val"), label_map=label_map)
    test_dataset = CustomDataset(_read_split("test"), label_map=label_map)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    return train_loader, val_loader, test_loader

# -------------------
# Training Functions
# -------------------
def train_model(model, model_name, epochs, lr, train_loader, val_loader):
    """Fine-tune *model* with AdamW and cosine LR annealing, validating each epoch.

    After every epoch the model is evaluated on *val_loader* via
    ``test_model`` and a summary line is printed. When training finishes the
    final weights are saved to ``best_vit_model.pth`` in the working directory.

    Args:
        model: Network to train; moved to the module-level ``device``.
        model_name: Label for the model. Currently only referenced by the
            disabled best-checkpoint code below.
        epochs: Number of passes over *train_loader*; also the cosine period.
        lr: Initial learning rate.
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Iterable of ``(images, labels)`` validation batches.
    """
    start_time = time.time()

    # Move model to be in same place as training
    model.to(device)

    # Only hand trainable parameters to the optimizer. Frozen parameters can
    # still carry a stale (or zeroed) .grad from an earlier run, which would
    # otherwise let AdamW's weight decay keep modifying them.
    trainable_params = [p for p in model.parameters() if p.requires_grad]

    # Prepare loss, optimizer and learning rate scheduler
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(trainable_params, lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    # Training loop
    for epoch in range(epochs):
        epoch_start_time = time.time()
        model.train()
        loss_sum = 0.0      # sum of per-sample losses over the epoch
        sample_count = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            # Fix for monai models
            if isinstance(outputs, tuple):
                outputs = outputs[0]
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Accumulate on-device; converting to a Python float is deferred
            # to the end of the epoch to avoid a sync on every batch.
            batch_n = labels.size(0)
            loss_sum = loss_sum + loss.detach() * batch_n
            sample_count += batch_n
        scheduler.step()

        # Sample-weighted mean loss of the epoch (NaN when the loader is empty)
        epoch_loss = float(loss_sum) / sample_count if sample_count else float("nan")

        # Validation loop
        accuracy = test_model(model, val_loader, print_results=False)
        print(f"Epoch {epoch+1}/{epochs} completed. | Loss: {epoch_loss:.4f} | Accuracy: {accuracy:.2f}% | elapsed time: {time.time() - epoch_start_time:.2f} seconds")

        # TODO: add this back in once hyperparameter tuning is done
        # Save the model if it is the best so far
        # if epoch == 0 or accuracy > best_accuracy:
        #     best_accuracy = accuracy
        #     torch.save(model.state_dict(), f'{model_name}_best_model_({accuracy:.2f}).pth')

    print(f"Time elapsed: {time.time() - start_time:.2f} seconds")
    # Save model as pth file (final-epoch weights, despite the file name)
    # TODO: consider saving this as a LoRA.
    torch.save(model.state_dict(), 'best_vit_model.pth')

def test_model(model, test_loader, print_results=True):
    """Evaluate *model* on *test_loader* and return top-1 accuracy in percent.

    The model is moved to the module-level ``device`` and left in eval mode.

    Args:
        model: Network to evaluate. If it returns a tuple, the first element
            is taken as the logits.
        test_loader: Iterable of ``(images, labels)`` batches.
        print_results: When true, print the accuracy before returning it.

    Returns:
        Accuracy as a percentage, or ``0.0`` when the loader yields no samples.
    """
    model.to(device)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # Fix for monai models
            if isinstance(outputs, tuple):
                outputs = outputs[0]
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Guard against an empty loader instead of dividing by zero
    accuracy = 100 * correct / total if total else 0.0

    if print_results:
        print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

def free_up_memory(models:list=None):
    """Release accelerator memory between runs.

    Args:
        models: Optional list of modules to offload. Each one is moved to the
            CPU so its weights stop occupying accelerator memory. Deleting the
            loop variable, as was done previously, cannot free anything
            because the caller still holds its own reference.

    Safe to call on machines without CUDA: the CUDA calls are skipped there.
    """
    import gc

    if models is not None:
        for model in models:
            if hasattr(model, "to"):
                model.to("cpu")  # unload models from the accelerator

    # Drop unreachable Python objects first so their tensors can be released
    gc.collect()

    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # Clear cache state to make each lr run independent
        torch.cuda.synchronize()  # Wait for all kernels in all streams on a device to finish
    elif torch.backends.mps.is_available():
        # torch.mps is only present on newer PyTorch builds
        mps_empty_cache = getattr(getattr(torch, "mps", None), "empty_cache", None)
        if mps_empty_cache is not None:
            mps_empty_cache()

def _matches_frozen_prefix(param_name, frozen_prefixes):
    """Return True when *param_name* equals, or is nested under, a frozen prefix."""
    # The trailing "." keeps "blocks.1" from also matching "blocks.10"/"blocks.11".
    return any(
        param_name == prefix or param_name.startswith(prefix + ".")
        for prefix in frozen_prefixes
        if prefix
    )


def train_and_test_on_model(model, model_name:str, lrs: list, batch_size: int, epochs: int, train_percent=0.7, data_path="data/EDC", new_dataset_split=True, freeze_layers=("",)):
    """Train and test *model* once per (learning rate, frozen-layer set) combination.

    Before every run the model is reset to the weights it had on entry. After
    each run the test accuracy and the run's settings are appended as one JSON
    line to ``metrics_{model_name}_X.json``. The untouched model is also
    evaluated once on the test set before the first run.

    Args:
        model: Network to train; its incoming weights are the reset point.
        model_name: Label used in log output and the metrics file name.
        lrs: Learning rates to sweep.
        batch_size: Batch size for all data loaders.
        epochs: Training epochs per run.
        train_percent: Train fraction used when a new split is generated.
        data_path: Dataset directory holding the class folders / split files.
        new_dataset_split: Re-split the dataset first; pass ``False`` to
            reuse the existing ``train/val/test.json`` files.
        freeze_layers: Iterable of freeze specifications, one run per entry.
            Each entry is a list of parameter-name prefixes (e.g.
            ``["blocks.0", "blocks.1"]``) or a single prefix string. Every
            parameter whose dotted name equals or sits under a prefix is
            frozen. An empty string freezes nothing.
    """
    label_map = {
        "n": 0, # Normal
        "c": 1, # Cataract
        "d": 2, # Diabetic Retinopathy
        "g": 3  # Glaucoma
    }

    # Copy of the model to reload between runs
    clean_model = copy.deepcopy(model.state_dict())

    # Split dataset (can be disabled to test on same set once it's been done once)
    if new_dataset_split:
        split_dataset(data_path, train_percent=train_percent) # Turn this off to reuse the same dataset split

    # Get Loaders
    train_loader, val_loader, test_loader = get_loaders(data_path, label_map=label_map, batch_size=batch_size)
    first_run = True

    # Run for each lr & frozen layer combo
    for lr in lrs:
        if first_run:
            # Baseline accuracy of the untouched model
            test_model(model, test_loader)
            first_run = False
        print("---------------------\n\n---------------------")
        for blocks in freeze_layers:
            # Reload the model to reset it
            model.load_state_dict(clean_model)
            free_up_memory() # Clean up memory before each run

            # Freeze subset of blocks. Parameter names are dotted paths such as
            # "blocks.0.attn.qkv.weight", so they must be matched by prefix;
            # plain list membership never matches and freezes nothing.
            frozen_prefixes = [blocks] if isinstance(blocks, str) else list(blocks)
            for name, param in model.named_parameters():
                param.requires_grad = not _matches_frozen_prefix(name, frozen_prefixes)

            print(f"Training model:{model_name} on learning rate: {lr} | batch size: {batch_size} | epochs: {epochs} | Frozen Layers: {blocks}")
            train_model(model, model_name, epochs, lr, train_loader, val_loader)
            acc = test_model(model, test_loader)

            hyps = ({"LR" : lr}, {"batch_size":batch_size}, {"epochs": epochs})
            modl = ({"model": model_name},{"freeze_layers": blocks}, {"accuracy": acc})
            metric = ({"hyperparameters": hyps}, {"model": modl})
            # Append this metric to the json-lines file
            with open(f'metrics_{model_name}_X.json', 'a') as f:
                json.dump(metric, f)
                f.write('\n')

    # Clean up
    free_up_memory([model])

  from .autonotebook import tqdm as notebook_tqdm


# 2. Load and Test Models

In [None]:
# These lrs are for generic timm model only
lrs = [
    #         64 batch                32 batch
    #         10 ep       8 ep        8ep
    # LRs     run1        run2        run1
    1e-4,   # 92.31 |   92.94       91.68
    # 2e-4,   # 76.14
    # 3e-4,   # 74.73
    1e-5,     # 92.94 |   92.78       94.19
    2e-5,   # 92.78 |   93.88       93.25
    # 3e-5    # 93.41 |   93.41       93.72
    ]

# Experiment configuration
is_monai = False  # True: fine-tune the MONAI ViT checkpoint; False: timm pretrained ViT
# lrs = [1e-5]
epochs = 7
batch_size = 32
train_percent = 0.7
num_classes = 4
model_name = "UNSET"

free_up_memory()
if is_monai:
    batch_size = 16
    model_name = "monai"
    model = ViT(
        spatial_dims=2,           # Specifies 2D images since MONAI kept doing 3D by default
        in_channels=3, img_size=(224,224), patch_size=(16,16),
        hidden_size=768,          # Embedding dimension (standard for base ViT)
        mlp_dim=3072,             # MLP hidden layer size
        num_layers=16,            # Number of transformer layers (NOTE: standard ViT-Base uses 12 - confirm this matches the checkpoint)
        num_heads=16,             # Number of attention heads
        classification=True,      # Enable classification head
        num_classes=num_classes,  # Your target classes
        dropout_rate=0.1
    )

    # Load weights. map_location="cpu" lets a checkpoint saved on a GPU be
    # loaded on CPU/MPS-only machines; the model is moved to `device` later.
    state_dict = torch.load('MODEL_NAME_HERE.pth', map_location="cpu")
    model.load_state_dict(state_dict)
else:
    model_name = "timm"
    model = timm.create_model('vit_base_patch16_224', pretrained=True)
    # Replace the classification head with one sized for our classes
    model.head = nn.Linear(model.head.in_features, num_classes)

    # Blocks 0-11

    #        w/o normalization      w/ normalization
    # b0-1 = 93.56%                 92.78%
    # b0-2 = 93.88%                 %
    # b0-3 = 94.35% # best          92.46%
    # b0-5 = 94.03%                 93.41%
    # b0-8 = 92.78%                 %


    # Freeze subset of blocks
    # blocks = blocks[:2]
    #
    # # Freeze all blocks in the blocks list
    # for name, param in model.named_parameters():
    #     if name in blocks:
    #         param.requires_grad = False
    #     else:
    #         param.requires_grad = True
    # Freeze all layers except the head (low accuracy 48%)
    # for param in model.parameters():
    #     param.requires_grad = False
    # for param in model.head.parameters():
    #     param.requires_grad = True

# Parameter-name prefixes of the 12 transformer blocks: "blocks.0" ... "blocks.11"
blocks = [f"blocks.{i}" for i in range(12)]
freeze_layers = [
    blocks[:2], # first 2 frozen
    blocks[:6], # 50% frozen
    blocks[:9], # 75% Frozen
    # blocks[:12], # 100% frozen
]

# TODO: remove new_dataset_split once ready for full testing again
train_and_test_on_model(model, model_name, lrs=lrs, batch_size=batch_size, epochs=epochs, train_percent=train_percent, new_dataset_split=False, freeze_layers=freeze_layers)

# Final cleanup
free_up_memory([model])