 # Train The model

In [1]:
import os
import timm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
import json
import random
from PIL import Image
import warnings
import time

# Suppress warnings
warnings.filterwarnings("ignore")

  from .autonotebook import tqdm as notebook_tqdm


# Dataset Functions

In [2]:
# Use DataLoader to load the dataset (json includes the first value as the path to the image and the second value as the label)
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, json_data, label_map, transform):
        self.data = json_data
        self.transform = transform
        self.label_map = label_map

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_sub_path, label = self.data[idx]
        img_path = os.path.join("data", img_sub_path)
        image = Image.open(img_path).convert("RGB")

        # Transform the images
        image = self.transform(image)

        # Convert label to tensor
        label = torch.tensor(self.label_map[label], dtype=torch.long)

        return image, label

def split_dataset(data_path, train_percent=0.7):
    # Split the dataset into train, val, and test sets (4 directories, one for each class. Combine all directories after splitting)
    train, val, test = [], [], []
    val_percent = ((1-train_percent)) / 2

    for dir in os.listdir(data_path):
        # Skip non directories
        if not os.path.isdir(os.path.join(data_path, dir)):
            continue
        dir_path = os.path.join(data_path, dir)
        dir_json_path = os.path.join(dir_path, "data.json")
        data = json.load(open(dir_json_path, "r"))

        # Randomly split the data into train, val, and test sets.
        shuffled_data = data.copy()
        random.shuffle(shuffled_data)

        train_size = int(train_percent * len(shuffled_data))
        val_size = int(val_percent * len(shuffled_data))

        train_data = shuffled_data[:train_size]
        val_data = shuffled_data[train_size:train_size + val_size]
        test_data = shuffled_data[train_size + val_size:]

        # Add each item in train_data to train
        for item in train_data:
            train.append(item)

        for item in val_data:
            val.append(item)

        for item in test_data:
            test.append(item)


    # Save the split data into separate JSON files
    with open(os.path.join(data_path, "train.json"), "w") as f:
        json.dump(train, f)
    with open(os.path.join(data_path, "val.json"), "w") as f:
        json.dump(val, f)
    with open(os.path.join(data_path, "test.json"), "w") as f:
        json.dump(test, f)

    # Print length of each set
    print(f"Train set size: {len(train)}")
    print(f"Validation set size: {len(val)}")
    print(f"Test set size: {len(test)}")

def get_loaders(base_path, label_map, batch_size):
    # Load the dataset
    train_json = json.load(open(os.path.join(base_path, "train.json"), "r"))
    val_json = json.load(open(os.path.join(base_path, "val.json"), "r"))
    test_json = json.load(open(os.path.join(base_path, "test.json"), "r"))

    transform_rules = [transforms.Resize((224,224)), transforms.ToTensor()] # TODO: add normalization
    train_dataset= CustomDataset(train_json, transform=transforms.Compose(transform_rules), label_map=label_map)
    val_dataset= CustomDataset(val_json, transform=transforms.Compose(transform_rules), label_map=label_map)
    test_dataset= CustomDataset(test_json, transform=transforms.Compose(transform_rules), label_map=label_map)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    return train_loader, val_loader, test_loader

# Model Functions

In [3]:
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

def train_model(model, model_name, epochs, lr, train_loader, val_loader):
    start_time = time.time()
    # Prepare learning rate scheduler
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    model_stagnation_count = 0

    # Move model to be in same place as training
    model.to(device)

    # Training loop
    for epoch in range(epochs):
        epoch_start_time = time.time()
        model.train()
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        scheduler.step()

        # Add accuracy calculation
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        accuracy = 100 * correct / total

        print(f"Epoch {epoch+1}/{epochs} completed. | Loss: {loss.item():.4f} | Accuracy: {accuracy:.2f}% | elapsed time: {time.time() - epoch_start_time:.2f} seconds")

        # Save the model if it is the best so far
        if epoch == 0 or accuracy > best_accuracy:
            best_accuracy = accuracy
            torch.save(model.state_dict(), f'{model_name}_best_model_({accuracy:.2f}).pth')
            # model_stagnation_count = 0
        # else:
        #     # print(f"Model not improved. Current best accuracy: {best_accuracy:.2f}%")
        #     model_stagnation_count += 1
        #     if model_stagnation_count >= 3:
        #         print(f"Early stopping at epoch {epoch+1} due to no improvement.")
        #         break

    print(f"Time elapsed: {time.time() - start_time:.2f} seconds")
    # ==== SAVE MODEL ====
    torch.save(model.state_dict(), 'best_vit_model.pth')

def test_model(model, test_loader):
    model.to(device)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy


In [None]:
# Hyperparameters
lrs = [
    1e-4,   # 92.31
    2e-4,   # 76.14
    3e-4,   #
    1e-5,   #
    2e-5,   #
    3e-5]   #
# lr = 3e-4
num_classes = 4
batch_size = 64
epochs = 10
data_path = "data/EDC/"
label_map = {
    "n": 0, # Normal
    "c": 1, # Cataract
    "d": 2, # Diabetic Retinopathy
    "g": 3  # Glaucoma
}

# Split dataset
split_dataset(data_path, train_percent=0.7) # Turn this off to reuse the same dataset split

# Get Loaders
train_loader, val_loader, test_loader = get_loaders(data_path, label_map=label_map, batch_size=batch_size)

first_run = True

for lr in lrs:
    # Load & test generic model
    generic_model = timm.create_model('vit_base_patch16_224', pretrained=True)
    generic_model.head = nn.Linear(generic_model.head.in_features, num_classes)

    if first_run:
        test_model(generic_model, test_loader)

    # Train generic model
    print(f"\nTraining with learning rate: {lr}")
    train_model(generic_model, 'timm', epochs, lr, train_loader, val_loader)
    acc = test_model(generic_model, test_loader)

    metric = (lr, acc)
    # append this metric to json file
    with open('metrics.json', 'a') as f:
        json.dump(metric, f)
        f.write('\n')

    # Clean up
    del generic_model # unload model
    torch.cuda.empty_cache() # Clear cache state to make each lr run independent
    torch.cuda.synchronize() # Wait for all kernels in all streams on a device to finish (may not need, but why not right?)


# TODO: Test MONAI model by loading it here then training on top of it.
# model = ViT(in_channels=3, img_size=(224, 224), patch_size=(16, 16), classification=True, num_classes=_num_classes) # pos_embed='conv', # dim=768,# depth=12, # heads=12, # mlp_dim=3072, # dropout=0.1)

Train set size: 2949
Validation set size: 631
Test set size: 637
Test Accuracy: 26.37%

Training with learning rate: 0.0001
Epoch 1/10 completed. | Loss: 0.3009 | Accuracy: 83.36% | elapsed time: 28.66 seconds
Epoch 2/10 completed. | Loss: 0.2759 | Accuracy: 88.91% | elapsed time: 28.21 seconds
Epoch 3/10 completed. | Loss: 0.0077 | Accuracy: 91.28% | elapsed time: 28.30 seconds
Epoch 4/10 completed. | Loss: 0.0934 | Accuracy: 86.37% | elapsed time: 28.13 seconds
Epoch 5/10 completed. | Loss: 0.0060 | Accuracy: 89.22% | elapsed time: 28.34 seconds
Epoch 6/10 completed. | Loss: 0.0003 | Accuracy: 90.97% | elapsed time: 28.55 seconds
Epoch 7/10 completed. | Loss: 0.3302 | Accuracy: 88.91% | elapsed time: 28.40 seconds
Epoch 8/10 completed. | Loss: 0.0004 | Accuracy: 92.39% | elapsed time: 28.42 seconds
Epoch 9/10 completed. | Loss: 0.0000 | Accuracy: 92.71% | elapsed time: 28.42 seconds
Epoch 10/10 completed. | Loss: 0.0001 | Accuracy: 92.71% | elapsed time: 28.24 seconds
Time elapsed: 2