In [13]:
import os
import numpy as np
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from sklearn.model_selection import KFold
import matplotlib.pyplot as plt
from tqdm import tqdm  
from torchvision.models import alexnet, AlexNet_Weights
from torch.optim.lr_scheduler import ReduceLROnPlateau
from collections import Counter
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from torchvision.models import vit_b_16, ViT_B_16_Weights
from sklearn.metrics import classification_report, confusion_matrix
import pandas as pd

# Dataset Processing and Parameters

In [14]:
# Root of an ImageFolder-style dataset (one sub-directory per class).
data_dir = "/kaggle/input/t1-augmented-testcases-removed/T1_augmented_hflip - Test Cases Removed"

# Optimisation hyper-parameters shared by every cross-validation fold.
batch_size, num_epochs = 16, 30
learning_rate = 1e-4  # initial Adam learning rate


In [15]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")  # GPU when PyTorch sees one, else CPU

In [16]:
# Preprocessing for the ImageNet-pretrained ViT-B/16 backbone:
# collapse each image to a single luminance channel and replicate it into the
# 3 channels the backbone expects, resize to its 224x224 input, and normalise
# with the ImageNet statistics the pretrained weights were trained with.
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),  # luminance copied into 3 identical channels
    transforms.Resize((224, 224)),                # ViT-B/16 input size (aspect ratio not preserved)
    transforms.ToTensor(),                        # PIL image -> float tensor in [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet mean
                         std=[0.229, 0.224, 0.225]),  # ImageNet std
])

In [17]:
# Image dataset: ImageFolder assigns one integer label per class sub-directory.
dataset = datasets.ImageFolder(root=data_dir, transform=transform)

# 5-fold cross-validation; the fixed random_state makes fold membership reproducible.
# NOTE(review): the directory name suggests pre-generated horizontal flips. A plain
# KFold over files can place an image and its flipped copy in different folds,
# inflating validation scores — consider grouping by source image (e.g. GroupKFold).
kf = KFold(n_splits=5, shuffle=True, random_state=42)

# Per-fold accuracy from the last epoch run, keyed "Fold <k>" (validation / training).
results, results_train = {}, {}

# Per-epoch learning curves; one inner list is appended per fold.
all_train_losses, all_val_losses, all_train_accuracies, all_val_accuracies = ([] for _ in range(4))

In [18]:
# Inverse of dataset.class_to_idx: integer label -> class-folder name.
idx_to_class = dict((idx, name) for name, idx in dataset.class_to_idx.items())

In [19]:
# Show how many classes ImageFolder discovered and their names (in label order).
print(f"Number of classes: {len(dataset.classes)}", f"Class names: {dataset.classes}", sep="\n")

Number of classes: 3
Class names: ['AD', 'CN', 'MCI']


In [20]:
# Image tally per integer label (names are in dataset.classes at the same index).
class_counts = Counter()
class_counts.update(dataset.targets)
print(f"Number of images in each class: {class_counts}")

Number of images in each class: Counter({0: 1074, 1: 1005, 2: 998})


In [None]:
class ViTClassifier(nn.Module):
    """ImageNet-pretrained ViT-B/16 whose classification head is resized.

    The backbone takes 3-channel images; grayscale inputs are expanded to
    3 channels by the data transform, not inside this module.

    Args:
        num_classes: number of logits produced by the replacement head.
    """

    def __init__(self, num_classes=3):
        super().__init__()
        backbone = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
        # Replace only the final Linear of `heads`, reusing its input width.
        hidden_dim = backbone.heads.head.in_features
        backbone.heads.head = nn.Linear(hidden_dim, num_classes)
        self.model = backbone

    def forward(self, x):
        """Return raw (un-normalised) class logits for the batch `x`."""
        return self.model(x)

In [None]:
# class ViTClassifier(nn.Module):
#     def __init__(self, num_classes=3):
#         super(ViTClassifier, self).__init__()
#         # Load the pre-trained ViT model
#         self.model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
#         # Freeze the pre-trained layers
#         for param in self.model.parameters():
#             param.requires_grad = False
#         # Replace the classification head with a custom head
#         self.model.heads = nn.Sequential(
#             nn.Linear(self.model.heads.head.in_features, 512),
#             nn.ReLU(),
#             nn.Dropout(0.5),
#             nn.Linear(512, 256),
#             nn.ReLU(),
#             nn.Dropout(0.5),
#             nn.Linear(256, num_classes)
#         )

#     def forward(self, x):
#         return self.model(x)


In [None]:
# Early-stopping patience: end a fold after this many consecutive epochs
# without a decrease in validation loss.
patience = 5

screen_width = 80


def _train_one_epoch(model, loader, criterion, optimizer):
    """Run one optimisation pass over `loader`.

    Returns:
        (mean_loss, accuracy) over all samples seen. Batch-mean losses are
        re-weighted by batch size so a short final batch is not over-counted.
    """
    model.train()
    running_loss, correct_preds, total_samples = 0.0, 0, 0
    for inputs, labels in tqdm(loader, desc="Training", leave=False):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)
        _, preds = torch.max(outputs, 1)
        correct_preds += (preds == labels).sum().item()
        total_samples += labels.size(0)
    return running_loss / total_samples, correct_preds / total_samples


def _evaluate(model, loader, criterion):
    """Score `model` on `loader` in eval mode without gradient tracking.

    Returns:
        (mean_loss, accuracy, labels, preds) where labels/preds are flat lists
        of class indices (numpy integer scalars) in loader order.
    """
    model.eval()
    running_loss, correct_preds, total_samples = 0.0, 0, 0
    all_labels, all_preds = [], []
    with torch.no_grad():
        for inputs, labels in tqdm(loader, desc="Validation", leave=False):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            running_loss += loss.item() * inputs.size(0)
            _, preds = torch.max(outputs, 1)
            correct_preds += (preds == labels).sum().item()
            total_samples += labels.size(0)

            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
    return running_loss / total_samples, correct_preds / total_samples, all_labels, all_preds


def _print_fold_report(all_labels, all_preds):
    """Print the confusion matrix and per-class classification report.

    `labels=` is passed explicitly so there is always one row/column per entry
    of `dataset.classes`, even if a class is absent from both the labels and
    the predictions of a fold (otherwise the DataFrame index / target_names
    would not line up with the matrix).
    """
    class_indices = list(range(len(dataset.classes)))
    conf_matrix = confusion_matrix(all_labels, all_preds, labels=class_indices)
    conf_matrix_df = pd.DataFrame(
        conf_matrix,
        index=[f"True: {label}" for label in dataset.classes],    # True labels
        columns=[f"Pred: {label}" for label in dataset.classes],  # Predicted labels
    )
    print("\nConfusion Matrix:")
    print(conf_matrix_df)
    print(classification_report(all_labels, all_preds, labels=class_indices,
                                target_names=dataset.classes))


for fold_idx, (train_idx, val_idx) in enumerate(kf.split(dataset)):
    # Centred fold header padded with dashes.
    fold_header = f" Fold {fold_idx+1}/{kf.n_splits} "
    padding = (screen_width - len(fold_header)) // 2
    print(f"{'-' * padding}{fold_header}{'-' * padding}")

    # Seed per fold so head initialisation and DataLoader shuffling are
    # reproducible across runs (some CUDA kernels remain nondeterministic).
    torch.manual_seed(42 + fold_idx)

    train_subset = torch.utils.data.Subset(dataset, train_idx)
    val_subset = torch.utils.data.Subset(dataset, val_idx)
    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)

    # Fresh pretrained model per fold so folds do not share weights.
    model = ViTClassifier(num_classes=len(dataset.classes)).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Multiply the LR by 0.1 every 3 epochs.
    # NOTE(review): with learning_rate=1e-4 the LR is 1e-7 from the 10th epoch
    # on, so most of the epoch budget trains at a negligible LR; consider a
    # larger step_size (or ReduceLROnPlateau keyed on validation loss).
    scheduler = StepLR(optimizer, step_size=3, gamma=0.1)

    best_val_loss = float('inf')
    epochs_without_improvement = 0

    # Per-epoch metrics for this fold, kept for later plotting.
    train_losses, val_losses = [], []
    train_accuracies, val_accuracies = [], []

    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")

        epoch_train_loss, epoch_train_acc = _train_one_epoch(model, train_loader, criterion, optimizer)
        train_losses.append(epoch_train_loss)
        train_accuracies.append(epoch_train_acc)
        print(f"Train Loss: {epoch_train_loss:.4f}, Train Accuracy: {epoch_train_acc:.4f}")

        epoch_val_loss, epoch_val_acc, all_labels, all_preds = _evaluate(model, val_loader, criterion)
        val_losses.append(epoch_val_loss)
        val_accuracies.append(epoch_val_acc)
        print(f"Val Loss: {epoch_val_loss:.4f}, Val Accuracy: {epoch_val_acc:.4f}")

        scheduler.step()
        current_lr = scheduler.get_last_lr()
        print(f"Epoch {epoch+1}: Current Learning Rate: {current_lr}")

        if epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss
            epochs_without_improvement = 0
            print("Validation loss improved. ")
            # torch.save(model.state_dict(), f"best_model_fold_{fold_idx+1}.pth")  # checkpoint best model
        else:
            epochs_without_improvement += 1
            print(f"No improvement in validation loss for {epochs_without_improvement} epoch(s).")

        stop_early = epochs_without_improvement >= patience
        if stop_early:
            print("Early stopping triggered.")
        # Report exactly once, for the fold's final epoch — whether that is the
        # last scheduled epoch or the one that triggered early stopping.
        if stop_early or epoch == num_epochs - 1:
            _print_fold_report(all_labels, all_preds)
        if stop_early:
            break

    all_train_losses.append(train_losses)
    all_val_losses.append(val_losses)
    all_train_accuracies.append(train_accuracies)
    all_val_accuracies.append(val_accuracies)

    # Accuracies from the last epoch actually run in this fold (under early
    # stopping that is `patience` epochs after the best-val-loss epoch).
    results[f"Fold {fold_idx+1}"] = epoch_val_acc
    results_train[f"Fold {fold_idx+1}"] = epoch_train_acc

In [24]:
# torch.save(model.state_dict(), "/kaggle/working/ViTmodel.pth")

In [25]:
# Cross-validation summary: validation accuracy per fold, then the mean over folds.
print("\nValidation Results:")
for fold_name in results:
    accuracy = results[fold_name]
    print(f"{fold_name}: {accuracy:.4f}")

average_accuracy = np.mean(list(results.values()))
print(f"Average Validation Accuracy: {average_accuracy:.4f}")


Validation Results:
Fold 1: 0.6380
Fold 2: 0.5812
Fold 3: 0.5821
Fold 4: 0.6797
Fold 5: 0.6374
Average Validation Accuracy: 0.6237


In [26]:
# Cross-validation summary: training accuracy per fold, then the mean over folds.
print("\nTrain Results:")
for fold_name in results_train:
    accuracy = results_train[fold_name]
    print(f"{fold_name}: {accuracy:.4f}")

average_train_accuracy = np.mean(list(results_train.values()))
print(f"Average Train Accuracy: {average_train_accuracy:.4f}")


Train Results:
Fold 1: 0.6526
Fold 2: 0.6664
Fold 3: 0.6470
Fold 4: 0.6592
Fold 5: 0.6483
Average Train Accuracy: 0.6547


# Train on Entire Dataset and Save the Model

In [None]:
# data_dir = "/kaggle/input/t1-augmented-testcases-removed/T1_augmented_hflip - Test Cases Removed"
# batch_size = 16
# num_epochs = 30
# learning_rate = 0.0001  # Learning rate parameter

In [None]:
# # Instantiate full dataset again
# full_dataset = datasets.ImageFolder(root=data_dir, transform=transform)
# full_loader = DataLoader(full_dataset, batch_size=batch_size, shuffle=True)

In [None]:
# final_model = ViTClassifier(num_classes=3).to(device)

In [None]:
# # Set up optimizer and loss
# criterion = nn.CrossEntropyLoss()
# optimizer = optim.Adam(final_model.parameters(), lr=learning_rate)

In [31]:
# Initialize the StepLR scheduler
# scheduler = StepLR(optimizer, step_size=3, gamma=0.1)  # Reduces LR by a factor of 0.1 every 3 epochs

In [None]:
# # Training loop
# final_model.train()
# for epoch in range(num_epochs):
#     running_loss, correct, total = 0.0, 0, 0
#     loop = tqdm(full_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)
#     for images, labels in loop:
#         images, labels = images.to(device), labels.to(device)
#         optimizer.zero_grad()
#         outputs = final_model(images)
#         loss = criterion(outputs, labels)
#         loss.backward()
#         optimizer.step()

#         running_loss += loss.item()
#         _, preds = torch.max(outputs, 1)
#         correct += (preds == labels).sum().item()
#         total += labels.size(0)
#         loop.set_postfix(loss=loss.item(), acc=100. * correct / total)

#     # print(f"Epoch {epoch+1}: Loss = {running_loss / len(full_loader):.6f}, Accuracy = {100. * correct / total:.4f}%")

#     # Step the scheduler at the end of each epoch
#     # scheduler.step()

#     # Optionally, print the current learning rate
#     current_lr = optimizer.param_groups[0]['lr']
#     print(f"Epoch {epoch+1}: Loss = {running_loss / len(full_loader):.6f}, "
#           f"Accuracy = {100. * correct / total:.4f}%, LR = {current_lr:.15f}")

                                                                                   

Epoch 1: Loss = 1.075899, Accuracy = 40.7540%, LR = 0.000100000000000


                                                                                   

Epoch 2: Loss = 1.038421, Accuracy = 45.7914%, LR = 0.000100000000000


                                                                                   

Epoch 3: Loss = 1.021553, Accuracy = 46.8313%, LR = 0.000100000000000


                                                                                   

Epoch 4: Loss = 1.006558, Accuracy = 49.0738%, LR = 0.000100000000000


                                                                                   

Epoch 5: Loss = 0.987804, Accuracy = 51.6412%, LR = 0.000100000000000


                                                                                   

Epoch 6: Loss = 0.968808, Accuracy = 51.9337%, LR = 0.000100000000000


                                                                                   

Epoch 7: Loss = 0.940862, Accuracy = 54.5661%, LR = 0.000100000000000


                                                                                   

Epoch 8: Loss = 0.925815, Accuracy = 56.8411%, LR = 0.000100000000000


                                                                                   

Epoch 9: Loss = 0.901555, Accuracy = 57.8486%, LR = 0.000100000000000


                                                                                    

Epoch 10: Loss = 0.864301, Accuracy = 60.9685%, LR = 0.000100000000000


                                                                                    

Epoch 11: Loss = 0.847324, Accuracy = 61.5210%, LR = 0.000100000000000


                                                                                    

Epoch 12: Loss = 0.815477, Accuracy = 63.6009%, LR = 0.000100000000000


                                                                                    

Epoch 13: Loss = 0.781264, Accuracy = 66.1033%, LR = 0.000100000000000


                                                                                    

Epoch 14: Loss = 0.756522, Accuracy = 66.4608%, LR = 0.000100000000000


                                                                                    

Epoch 15: Loss = 0.720296, Accuracy = 69.4833%, LR = 0.000100000000000




In [None]:
# # Save the trained model
# torch.save(final_model.state_dict(), "vit_final_model_freezed_layers_modified_layers.pth")
# print("Model trained on full dataset and saved as vit_final_model.pth ✅")

Model trained on full dataset and saved as vit_final_model.pth ✅
