In [None]:
# Autoreload modules
%load_ext autoreload
%autoreload 2
    
# Accessing moduels
import sys,os
sys.path.append(os.path.realpath('../Modules'))

from dataloader.dataset import ADNI3Channels
from dataloader.dataloader import ADNILoader
from dataloader.transforms import Transforms

from model.model import ViT
from model.train import Trainer

from matplotlib import pyplot as plt
# from utils.report import sklearn_classification_report, custom_classification_report

# Dataset and Dataloader Setup

In [None]:
# Class index <-> diagnosis name. label2id is derived from id2label so the two
# mappings cannot drift apart.
id2label = {0: "CN", 1: "MCI", 2: "AD"}
label2id = {name: idx for idx, name in id2label.items()}

# One Transforms factory supplies both pipelines: .train() for the training split,
# .eval() for the validation and test splits.
transforms = Transforms(p=0.5, image_size=(384, 384))

train_ds = ADNI3Channels(
    "../Data/Training/",
    transforms=transforms.train(),
)
valid_ds = ADNI3Channels(
    "../Data/Validation/",
    transforms=transforms.eval(),
)
test_ds = ADNI3Channels(
    "../Data/Test/",
    transforms=transforms.eval(),
)

In [None]:
# Inspect a single training sample: its shape, decoded label, and the split sizes.
image, label = train_ds[0]

print("Image shape:", image.shape)
print("Label:", id2label[label.item()], "\n")

print("Number of training samples:", len(train_ds))
print("Number of validation samples:", len(valid_ds))
print("Number of test samples:", len(test_ds), "\n")

# Show the first three channels of the sample side by side, one per axis.
fig, axes = plt.subplots(ncols=3, figsize=(6, 2), dpi=300)
for channel_idx, ax in enumerate(axes):
    ax.imshow(image[channel_idx])
    ax.axis("off")

# Pixel-value range of the sample as returned by the dataset's transforms.
print("Min pixel value =", image.min().item())
print("Max pixel value =", image.max().item())

In [None]:
# Build the loader factory once and obtain all three DataLoaders from that single
# instance; constructing three identical ADNILoader objects was redundant.
# NOTE(review): assumes ADNILoader's accessors do not depend on being called on
# separate instances — confirm against dataloader/dataloader.py.
loader_kwargs = {
    'train_ds': train_ds,
    'valid_ds': valid_ds,
    'test_ds': test_ds,
}
adni_loader = ADNILoader(**loader_kwargs)

train_dataloader = adni_loader.train_dataloader()
valid_dataloader = adni_loader.validation_dataloader()
test_dataloader = adni_loader.test_dataloader()

# Sanity check: shapes of the first two elements of one training batch.
batch = next(iter(train_dataloader))
print(batch[0].shape)
print(batch[1].shape)

# Model Development

In [None]:
import torch

# Prefer the second GPU (the original setup), but fall back on machines with fewer
# GPUs instead of failing on a non-existent "cuda:1" device.
if torch.cuda.device_count() > 1:
    DEVICE = "cuda:1"
elif torch.cuda.is_available():
    DEVICE = "cuda:0"
else:
    DEVICE = "cpu"

EPOCHS = 100

model = ViT(
    pretrained=True,
    model_name="google/vit-base-patch32-384",
    device=DEVICE,
)

trainer_kwargs = {
    "epochs": EPOCHS,
    "model": model,
    "train_dataloader": train_dataloader,
    "valid_dataloader": valid_dataloader,
    "test_dataloader": test_dataloader,
}

trainer = Trainer(**trainer_kwargs)

In [None]:
# Training is disabled in this run: the saved best checkpoint is loaded from
# "Best models/" in the next section instead. Uncomment to train again, starting
# from the pretrained ViT weights.
# trainer.train()

# Model Save and Load

In [None]:
# Restore a saved checkpoint ("acc" key, presumably best validation accuracy) from
# "Best models/ViT_Pretrained". The two commented calls show how that file would be
# produced from the in-memory best state right after training.
# model.load_best_state("acc")
# model.save_best_state_file("acc", "Best models/", "ViT_Pretrained")
model.load_best_state_file("acc", "Best models/", "ViT_Pretrained")
# Give the training split the same eval transforms as validation/test, so that
# evaluating on the training data (commented call below) is comparable.
# NOTE(review): this mutates train_ds in place. Presumably train_dataloader reads
# train_ds lazily, so it is affected too, and re-running trainer.train() afterwards
# would train with the eval pipeline — confirm.
train_ds.transforms = transforms.eval()

# Evaluation

In [None]:
# Evaluate the restored model on the test split. The commented calls run the same
# evaluation on the training and validation splits.
# trainer.test(trainer.train_dataloader)
# trainer.test(trainer.valid_dataloader)
trainer.test(trainer.test_dataloader)

# 95% Bootstrap Confidence Intervals

In [None]:
import torch
import numpy as np
from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score

# Per-class metric (average=None -> one value per class, in id2label order) that is
# recomputed on every bootstrap replicate.
# To bootstrap F1 instead, swap in MulticlassF1Score(num_classes=3, average=None).
metric = MulticlassAccuracy(average=None, num_classes=len(id2label))

In [None]:
N_BOOTSTRAP = 1000   # number of bootstrap replicates
BOOTSTRAP_SEED = 42  # fixed so the confidence intervals are reproducible

# 1) Predict every test sample exactly once. The resampling below reuses these
#    cached predictions instead of re-running the model on every draw
#    (previously N_BOOTSTRAP * len(test_ds) forward passes).
#    This assumes predictions are deterministic: no randomness in test_ds's eval
#    transforms and no active dropout.
#    NOTE(review): confirm the ViT wrapper is in eval mode here.
y_true_all = []
y_pred_all = []
with torch.no_grad():  # inference only: don't build autograd graphs
    for i in range(len(test_ds)):
        x, y = test_ds[i]
        logits, _, _ = model(x)
        y_true_all.append(y.item())
        y_pred_all.append(logits.argmax(1).cpu().item())
y_true_all = torch.tensor(y_true_all)
y_pred_all = torch.tensor(y_pred_all)

# 2) Bootstrap: resample test indices with replacement and recompute the per-class
#    metric on each replicate.
rng = np.random.default_rng(BOOTSTRAP_SEED)
n_test = len(test_ds)
metrics = []
for itr in range(N_BOOTSTRAP):
    indices = torch.from_numpy(rng.integers(0, n_test, n_test))
    metric_value = metric(y_pred_all[indices], y_true_all[indices])
    metrics.append(metric_value)
    if itr % 100 == 0:  # occasional progress instead of N_BOOTSTRAP lines of output
        print(f"{itr}: {metric_value}")

In [None]:
# Collect the bootstrap results into a single tensor: one row per bootstrap
# replicate, one column per class (the metric uses average=None).
metric_tensor = torch.stack(metrics, dim=0)
print(metric_tensor)

# Uncomment to cache the bootstrap results on disk:
# torch.save(metric_tensor, "accuracy_tensor.pt")

In [None]:
# Reload previously cached bootstrap results instead of recomputing them:
# metric_tensor = torch.load("accuracy_tensor.pt")

In [None]:
# Split the bootstrap results into one NumPy array per class; the column order
# follows id2label (0 = CN, 1 = MCI, 2 = AD).
per_class_scores = metric_tensor.numpy()
CN_tensor = per_class_scores[:, 0]
MCI_tensor = per_class_scores[:, 1]
AD_tensor = per_class_scores[:, 2]

In [None]:
def _print_ci(name, scores):
    """Print the bootstrap mean and the 95% percentile interval (2.5th-97.5th) of `scores`."""
    lower = np.percentile(scores, 2.5)
    upper = np.percentile(scores, 97.5)
    print(f"{name:<3} -> {np.mean(scores)}: {lower}, {upper}")


for class_name, class_scores in (("CN", CN_tensor), ("MCI", MCI_tensor), ("AD", AD_tensor)):
    _print_ci(class_name, class_scores)

In [None]:
# Macro average over classes for each bootstrap replicate. With MulticlassAccuracy
# (as configured above) this is the mean per-class accuracy, not an F1 score, hence
# the neutral name. mean(dim=1) also avoids hard-coding the number of classes.
macro_scores = metric_tensor.mean(dim=1).numpy()
macro_low, macro_high = np.percentile(macro_scores, [2.5, 97.5])
print(f"Whole  -> {np.mean(macro_scores)}: {macro_low}, {macro_high}")