In [None]:
from PIL import Image
import os
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import torchvision.transforms as transforms
from transformers import ViTForImageClassification, ViTImageProcessor
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm
from sklearn.preprocessing import LabelEncoder
import numpy as np

In [None]:
# Sample table for the experiment. The cross-validation cell reads its 'X'
# column as image file paths and its 'labels' column as class labels.
DATASET_CSV = 'evaluation_dataset/lesion_dataset.csv'
dataset = pd.read_csv(DATASET_CSV)

# Converts class labels to integer ids for training and back for reporting.
label_encoder = LabelEncoder()

In [None]:
class LesionDataset(Dataset):
    """Map-style dataset pairing lesion image files with their labels.

    Args:
        image_paths: Sequence of image file paths, one per sample.
        labels: Sequence of labels aligned index-for-index with
            ``image_paths``.
        image_processor: Optional callable invoked as
            ``image_processor(image, return_tensors='pt')``; its result is
            indexed with ``'pixel_values'``. When ``None`` the RGB PIL image
            is returned unchanged.

    Raises:
        ValueError: If ``image_paths`` and ``labels`` differ in length.
    """

    def __init__(self, image_paths, labels, image_processor=None):
        # Catch misaligned inputs up front instead of failing (or silently
        # ignoring surplus labels) in the middle of an epoch.
        if len(image_paths) != len(labels):
            raise ValueError(
                f'image_paths and labels must have the same length, '
                f'got {len(image_paths)} and {len(labels)}'
            )
        self.image_paths = image_paths
        self.labels = labels
        self.image_processor = image_processor

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_paths)

    def __getitem__(self, index):
        """Load sample ``index`` and return it as ``(image, label)``.

        ``image`` is the processor's ``'pixel_values'`` entry with its leading
        batch dimension squeezed out, or the RGB PIL image when no processor
        was given.
        """
        # convert() hands back a loaded copy of the pixel data, so the source
        # file can be closed immediately rather than leaking one open handle
        # per sample until garbage collection.
        with Image.open(self.image_paths[index]) as source:
            image = source.convert('RGB')
        label = self.labels[index]

        # Explicit None check: do not depend on the processor's truthiness.
        if self.image_processor is not None:
            pixel_values = self.image_processor(image, return_tensors='pt')['pixel_values']
            image = pixel_values.squeeze(0)  # Squeeze out the batch dim.

        return image, label

In [None]:
# Image preprocessor loaded from the same checkpoint id that the
# cross-validation cell uses for the ViT classifier.
VIT_CHECKPOINT = 'google/vit-base-patch16-224'
image_processor = ViTImageProcessor.from_pretrained(VIT_CHECKPOINT)

In [None]:
class MLPHead(nn.Module):
    """Two-layer perceptron classification head: Linear -> ReLU -> Linear.

    Args:
        in_features: Size of each input feature vector.
        hidden_size: Width of the hidden layer.
        num_classes: Number of output logits.
    """

    def __init__(self, in_features, hidden_size, num_classes):
        super().__init__()
        # Layers are created in forward order so parameter initialisation
        # draws from the RNG in the same sequence as the layers are applied.
        hidden_layer = nn.Linear(in_features, hidden_size)
        activation = nn.ReLU()
        output_layer = nn.Linear(hidden_size, num_classes)
        self.mlp = nn.Sequential(hidden_layer, activation, output_layer)

    def forward(self, x):
        """Return the logits produced by the head for input ``x``."""
        logits = self.mlp(x)
        return logits

In [None]:
# Leave-one-out cross-validation (LOOCV): each sample is held out once while a
# fresh MLP head is trained on all remaining samples on top of a frozen ViT
# backbone; the held-out prediction and its true label are collected per fold.
SEED = 42  # fixed seed so head initialisation and batch shuffling are repeatable

n_samples = len(dataset)
predictions = []
true_labels = []

# Train on the GPU when one is available, otherwise stay on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Fit the encoder ONCE on every label. Re-fitting on each fold's training
# subset raises ValueError when the held-out sample is the only member of its
# class, and lets the label -> id mapping (and num_classes) change between
# folds, which would make the pooled predictions/true_labels inconsistent.
image_paths = dataset['X'].tolist()
encoded_labels = label_encoder.fit_transform(dataset['labels'].tolist())
num_classes = len(label_encoder.classes_)

# The backbone is frozen and therefore identical in every fold, so the
# pretrained weights are loaded once and only the head is rebuilt per fold.
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224',
    num_labels=num_classes,
    ignore_mismatched_sizes=True
)
in_features = model.classifier.in_features  # read before the head is swapped out
for param in model.vit.parameters():
    param.requires_grad = False
model.to(device)

criterion = nn.CrossEntropyLoss()
num_epochs = 5

for test_idx in tqdm(range(n_samples), desc='LOOCV Folds', leave=True):
    print(f'LOOCV Fold {test_idx+1}/{n_samples}')

    # Same seed for every fold: each head starts from the same initial weights
    # and a fold's result does not depend on which folds ran before it.
    torch.manual_seed(SEED)

    # Positional split, independent of the DataFrame's index labels.
    train_images = image_paths[:test_idx] + image_paths[test_idx + 1:]
    train_labels = np.delete(encoded_labels, test_idx)
    test_images = [image_paths[test_idx]]
    test_labels = encoded_labels[test_idx:test_idx + 1]

    train_dataset = LesionDataset(
        train_images,
        train_labels,
        image_processor
    )
    test_dataset = LesionDataset(
        test_images,
        test_labels,
        image_processor
    )

    train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=0)
    test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)

    # Fresh, randomly initialised head for this fold; nothing learned in an
    # earlier fold carries over.
    model.classifier = MLPHead(in_features=in_features, hidden_size=256, num_classes=num_classes).to(device)

    # Only the head has requires_grad=True, so only the head is optimised.
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)

    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        for images, labels in tqdm(train_dataloader, desc=f'Epoch {epoch+1}/{num_epochs}', leave=False):
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(images).logits
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_dataloader):.4f}')

    model.eval()
    with torch.no_grad():
        for images, labels in test_dataloader:
            outputs = model(images.to(device)).logits
            _, predicted = torch.max(outputs, dim=1)
            predictions.append(predicted.item())
            true_labels.append(labels.item())
        print(f'Evaluation: True label: {true_labels[-1]}, Predicted label: {predictions[-1]}')

    # Drop per-fold objects before the next fold allocates its own.
    del optimizer, train_dataloader, test_dataloader

del model
torch.cuda.empty_cache()

In [None]:
# LOOCV accuracy: share of folds whose held-out prediction equals the true label.
predicted_ids = np.array(predictions)
actual_ids = np.array(true_labels)
accuracy = np.mean(predicted_ids == actual_ids)
print(f'LOOCV Accuracy: {100 * accuracy:.2f}%')

# Translate the integer class ids back into the original label values.
decoded_true_labels = label_encoder.inverse_transform(true_labels)
decoded_predictions = label_encoder.inverse_transform(predictions)
print("Sample predictions:", decoded_predictions)
print("Sample true labels:", decoded_true_labels)