In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os

# Image datasets can hold thousands of files per folder; printing every path
# floods the cell output, so only a few names per directory are shown.
MAX_FILES_SHOWN = 5

for dirname, _, filenames in os.walk('/kaggle/input'):
    # Sort so the listed sample is the same on every run.
    for filename in sorted(filenames)[:MAX_FILES_SHOWN]:
        print(os.path.join(dirname, filename))
    hidden = len(filenames) - MAX_FILES_SHOWN
    if hidden > 0:
        print(f"{dirname}: ... and {hidden} more file(s)")

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
!pip install timm
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report
import timm

In [None]:
!pip install torchsummary

In [None]:
import os

# Get the GitHub token from Kaggle environment
token = os.environ.get("GITHUB_TOKEN")

# Format the clone URL with the token
repo_url = f"https://{token}@github.com/Omid-Nejati/MedViT.git"

# Clone the repository
!git clone {repo_url}

In [None]:
%cd /kaggle/working/MedViT
!pip install -r requirements.txt

In [None]:
# List the current working directory (the MedViT checkout after the %cd above).
!ls

In [None]:
import torch
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torch.utils.data import random_split, DataLoader

# Fixed seed so the train/validation membership is identical on every run.
SPLIT_SEED = 42

# Resize every image to 224x224 and convert it to a tensor.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Load the full dataset; ImageFolder derives the labels from the sub-folder names.
dataset = ImageFolder('/kaggle/input/adni-3-class/ADNI_3_class', transform=transform)

# 80/20 split into train/validation.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=split_generator
)

# Only the training loader shuffles; validation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)


In [None]:
!pip install timm
!pip install einops

In [None]:
from MedViT import MedViT_small as tiny

In [None]:
# Build the network with its default constructor arguments
# (`tiny` is this notebook's alias for MedViT_small, see the import above).
# NOTE(review): this instance is discarded when the next cell rebinds `model`.
model = tiny()

In [None]:
import torch

# Use the GPU when one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Size the classification head from the dataset instead of relying on the
# constructor default, so logits and predictions stay within the real label set.
# NOTE(review): assumes the MedViT constructor accepts `num_classes`
# (forwarded through MedViT_small's **kwargs) — confirm against MedViT.py.
num_classes = len(dataset.classes)
model = tiny(num_classes=num_classes).to(device)

In [None]:
from torchsummary import summary

# Layer-by-layer summary for one 3-channel 224x224 input, the size produced
# by the Resize((224, 224)) transform applied to the dataset.
input_shape = (3, 224, 224)
model = model.to(device)
summary(model, input_shape)


In [None]:
def print_shapes(name):
    """Return a forward hook that prints a layer's input and output shapes.

    Args:
        name: Label printed on the first line of every report.

    Returns:
        A callable with the ``(module, inputs, output)`` forward-hook signature.
        It prints ``name``, the shape of the first positional input and the
        shape of the output, and returns ``None``.
    """
    def hook(module, inputs, result):
        print(f"{name}:")
        in_shape = inputs[0].shape  # only the first positional input is reported
        print(f"  Input shape: {in_shape}")
        out_shape = result.shape
        print(f"  Output shape: {out_shape}")
    return hook

# Register shape-printing hooks on the layers of interest (Conv2d, Linear, LayerNorm).
# The handles are kept so the hooks can be detached again: left in place they
# would print on every later forward pass, including each training and
# validation batch.
hook_handles = []
for name, layer in model.named_modules():
    if isinstance(layer, (torch.nn.Conv2d, torch.nn.Linear, torch.nn.LayerNorm)):
        hook_handles.append(layer.register_forward_hook(print_shapes(name)))

# Feed a single dummy image through the network to trigger the hooks once.
dummy_input = torch.randn(1, 3, 224, 224, device=device)
model.eval()  # Important to disable dropout/batchnorm randomness
try:
    with torch.no_grad():
        _ = model(dummy_input)
finally:
    # Always detach the hooks, even when the forward pass raises.
    for handle in hook_handles:
        handle.remove()

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc, precision_recall_curve
from sklearn.manifold import TSNE
from sklearn.preprocessing import label_binarize
import numpy as np
import seaborn as sns

# --- Training configuration ---
# Run on the GPU when one is visible, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# --- Per-epoch history (one entry appended per epoch) ---
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []

# --- Early-stopping state ---
best_acc = 0   # reference accuracy the stopping rule compares against
patience = 5   # epochs without sufficient improvement before stopping
counter = 0    # consecutive epochs without sufficient improvement

# --- Validation outputs kept for the plots further down ---
all_preds = []
all_labels = []
all_probs = []
all_logits = []
all_features = []

# Training loop
NUM_EPOCHS = 10   # upper bound; the early-stopping rule below may end training sooner
MIN_DELTA = 0.03  # validation-accuracy gain over `best_acc` needed to reset the patience counter

# Best-epoch bookkeeping. It is tracked separately from `best_acc`, which only
# moves when the gain reaches MIN_DELTA and can therefore lag behind the true best.
best_snapshot_acc = float("-inf")
best_epoch = None
best_state = None

for epoch in range(NUM_EPOCHS):
    model.train()
    train_loss, correct, total = 0, 0, 0

    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}", leave=False)
    for inputs, labels in loop:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

        loop.set_postfix(loss=loss.item())

    # Loss is averaged over batches, accuracy over samples.
    avg_train_loss = train_loss / len(train_loader)
    train_acc = correct / total
    train_losses.append(avg_train_loss)
    train_accuracies.append(train_acc)

    # Validation
    model.eval()
    val_loss, correct, total = 0, 0, 0

    temp_preds, temp_labels = [], []
    temp_probs, temp_logits, temp_features = [], [], []

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            val_loss += loss.item()
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

            temp_preds.extend(preds.cpu().numpy())
            temp_labels.extend(labels.cpu().numpy())
            probs = F.softmax(outputs, dim=1)
            temp_probs.extend(probs.cpu().numpy())
            temp_logits.extend(outputs.cpu().numpy())

            # Feature extraction for t-SNE
            if hasattr(model, 'forward_features'):
                features = model.forward_features(inputs)
            else:
                # NOTE(review): without `forward_features` the raw input pixels
                # are stored, not learned features.
                features = inputs  # fallback

            # Flatten to one vector per sample so np.array(all_features) is 2-D,
            # which is what TSNE().fit_transform expects.
            temp_features.extend(features.flatten(start_dim=1).cpu().numpy())

    avg_val_loss = val_loss / len(val_loader)
    val_acc = correct / total
    val_losses.append(avg_val_loss)
    val_accuracies.append(val_acc)

    # Keep the outputs and the weights of the best validation epoch so far, so
    # the metrics computed after training describe the weights that get saved.
    if val_acc > best_snapshot_acc:
        best_snapshot_acc = val_acc
        best_epoch = epoch + 1
        all_preds = temp_preds
        all_labels = temp_labels
        all_probs = temp_probs
        all_logits = temp_logits
        all_features = temp_features
        # CPU copy, so the snapshot does not occupy accelerator memory.
        best_state = {
            key: value.detach().cpu().clone()
            for key, value in model.state_dict().items()
        }

    print(f"Epoch {epoch+1}: Train Loss={avg_train_loss:.4f}, Train Acc={train_acc:.4f}, "
          f"Val Loss={avg_val_loss:.4f}, Val Acc={val_acc:.4f}")

    # Early stopping logic: an epoch only counts as an improvement when it beats
    # the reference accuracy by at least MIN_DELTA.
    if val_acc < best_acc + MIN_DELTA:
        counter += 1
        if counter >= patience:
            print(f"🛑 Early stopping at epoch {epoch+1}")
            break
    else:
        best_acc = val_acc
        counter = 0

# Put the best-epoch weights back into the model so it matches all_preds / all_labels.
if best_state is not None:
    model.load_state_dict(best_state)
    print(f"Restored weights from epoch {best_epoch} (Val Acc={best_snapshot_acc:.4f})")

# Persist the model's parameters and buffers (state dict only, not the module object).
weights = model.state_dict()
torch.save(weights, "medvit_model.pth")
print("✅ Model saved as 'medvit_model.pth'")

# === Evaluation Metrics ===
# confusion_matrix puts true labels on rows and predicted labels on columns.
cf = confusion_matrix(all_labels, all_preds)
report = classification_report(all_labels, all_preds, output_dict=True)
acc = report['accuracy']

# Use the report's own macro averages instead of indexing report[str(i)],
# which raises KeyError when the labels present are not exactly 0..n-1.
precision = report['macro avg']['precision']
recall = report['macro avg']['recall']

# Per-class one-vs-rest counts derived from the confusion matrix.
tp = np.diag(cf)
fp = cf.sum(axis=0) - tp   # predicted as the class, but truly another class
fn = cf.sum(axis=1) - tp   # truly the class, but predicted as another
tn = cf.sum() - tp - fp - fn

# Specificity = TN / (TN + FP). The denominator is every sample that is truly
# negative for the class; dividing by TN + FN instead gives negative predictive value.
specificity = np.mean(tn / (tn + fp))

print("\nClassification Report:")
print(classification_report(all_labels, all_preds))
print(f"Accuracy: {acc:.4f}")
print(f"Precision (macro): {precision:.4f}")
print(f"Recall (macro): {recall:.4f}")
print(f"Specificity (macro): {specificity:.4f}")

# === Confusion Matrix ===
# Annotated heatmap of the raw counts: predicted class on x, true class on y.
fig, ax = plt.subplots(figsize=(6, 5))
sns.heatmap(cf, annot=True, fmt="d", cmap="Blues", ax=ax)
ax.set_title("Confusion Matrix")
ax.set_xlabel("Predicted")
ax.set_ylabel("True")
plt.show()

# === Accuracy & Loss Curves ===
# One x position per completed epoch (fewer than planned if training stopped early).
epochs_range = range(1, len(train_losses) + 1)

fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

# Left panel: training vs. validation loss.
loss_ax.plot(epochs_range, train_losses, label="Train Loss")
loss_ax.plot(epochs_range, val_losses, label="Val Loss")
loss_ax.set_xlabel("Epochs")
loss_ax.set_ylabel("Loss")
loss_ax.set_title("Loss Curve")
loss_ax.legend()

# Right panel: training vs. validation accuracy.
acc_ax.plot(epochs_range, train_accuracies, label="Train Acc")
acc_ax.plot(epochs_range, val_accuracies, label="Val Acc")
acc_ax.set_xlabel("Epochs")
acc_ax.set_ylabel("Accuracy")
acc_ax.set_title("Accuracy Curve")
acc_ax.legend()

fig.tight_layout()
plt.show()

# === Ready for Additional Plots ===
# All data for t-SNE, ROC, and PR curves are now stored in:
# - all_labels
# - all_preds
# - all_probs
# - all_logits
# - all_features

# Example of use:
# tsne = TSNE().fit_transform(np.array(all_features))
# fpr, tpr, _ = roc_curve(...)
# precision, recall, _ = precision_recall_curve(...)
