<a href="https://www.kaggle.com/code/anshitavermas/convnext-oasis?scriptVersionId=246756824" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file found anywhere under the Kaggle input mount.
for root, _subdirs, files in os.walk('/kaggle/input'):
    for full_path in (os.path.join(root, fname) for fname in files):
        print(full_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import os

dataset_path = "/kaggle/input/alzheimer-mri-4-classes-dataset/Alzheimer_MRI_4_classes_dataset"

# Map each sub-directory of the dataset root (one per class) to the number of
# entries it contains. Plain files at the root are ignored; every entry inside a
# class directory is counted, regardless of its type.
class_counts = {
    entry: len(os.listdir(os.path.join(dataset_path, entry)))
    for entry in os.listdir(dataset_path)
    if os.path.isdir(os.path.join(dataset_path, entry))
}
total_images = sum(class_counts.values())

print("Total Images:", total_images)
print("Class Distribution:", class_counts)


In [None]:
import matplotlib.pyplot as plt
import cv2
import random

def show_samples(dataset_path, class_name, num_samples=5):
    """Plot up to ``num_samples`` randomly chosen images from one class folder.

    Args:
        dataset_path: Root directory that contains one sub-directory per class.
        class_name: Name of the class sub-directory to sample from.
        num_samples: Maximum number of images to draw; also the number of
            subplot columns in the figure.

    Entries that OpenCV cannot decode (``cv2.imread`` returns ``None``) are
    skipped and their subplot slot is left empty. Without this check,
    ``cv2.cvtColor`` raises on them.
    """
    class_dir = os.path.join(dataset_path, class_name)
    images = os.listdir(class_dir)
    sample_images = random.sample(images, min(num_samples, len(images)))

    plt.figure(figsize=(10, 5))
    for i, img_name in enumerate(sample_images):
        img_path = os.path.join(class_dir, img_name)
        img = cv2.imread(img_path)
        if img is None:
            # Not a decodable image (e.g. a stray non-image file); skip it.
            continue
        # OpenCV loads channels as BGR; matplotlib expects RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        plt.subplot(1, num_samples, i + 1)
        plt.imshow(img)
        plt.title(class_name)
        plt.axis("off")

    plt.show()

# Preview a few random images from every class directory counted earlier.
for cls in class_counts:
    show_samples(dataset_path, cls)


In [None]:
!pip install timm albumentations torch torchvision --upgrade albumentations

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import timm
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from albumentations import Compose, Normalize, HorizontalFlip, RandomBrightnessContrast, Resize
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
from sklearn.model_selection import train_test_split


In [None]:
# Prefer the GPU when CUDA is available; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Root folder of the Alzheimer MRI dataset; each class lives in its own sub-directory.
dataset_path = "/kaggle/input/alzheimer-mri-4-classes-dataset/Alzheimer_MRI_4_classes_dataset"


In [None]:
# ImageNet channel statistics, used because the backbone is ImageNet-pretrained.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize to 224x224, light augmentation (random horizontal
# flip, occasional brightness/contrast jitter), normalise, convert to a tensor.
train_transform = Compose([
    Resize(224, 224),
    HorizontalFlip(p=0.5),
    RandomBrightnessContrast(p=0.2),
    Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ToTensorV2(),
])

# Validation pipeline: the same resize and normalisation, with no augmentation.
val_transform = Compose([
    Resize(224, 224),
    Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ToTensorV2(),
])


In [None]:
from torchvision.datasets import ImageFolder
from albumentations.core.composition import OneOf

class AlbumentationsDataset(torch.utils.data.Dataset):
    """Expose an ImageFolder directory with an albumentations pipeline applied.

    Each item is ``(transformed_image, class_index)``. The wrapped ImageFolder
    supplies the raw image and label. The image is converted to a NumPy array,
    then passed to ``transform(image=...)``, and the ``"image"`` entry of the
    result is returned.
    """

    def __init__(self, folder_path, transform):
        self.dataset = ImageFolder(folder_path)
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        raw_image, target = self.dataset[idx]
        # albumentations works on NumPy arrays, not PIL images.
        augmented = self.transform(image=np.array(raw_image))
        return augmented["image"], target


In [None]:
# Two views of the same image folder: one with the augmenting training pipeline
# and one with the deterministic validation pipeline. Both Subsets returned by
# random_split wrap the *same* dataset object, so assigning a transform through
# one of them would change it for the other as well. Separate instances keep the
# two pipelines independent. ImageFolder lists files in sorted order, so
# index i refers to the same image in both views.
full_dataset = AlbumentationsDataset(dataset_path, transform=train_transform)
val_view = AlbumentationsDataset(dataset_path, transform=val_transform)

# 80/20 split of sample indices. The fixed seed makes the split reproducible
# when this cell is re-run.
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size

split_generator = torch.Generator().manual_seed(42)
train_dataset, val_split = torch.utils.data.random_split(
    full_dataset, [train_size, val_size], generator=split_generator
)

# Validation samples use the same indices but read from the val-transform view.
val_dataset = torch.utils.data.Subset(val_view, val_split.indices)


In [None]:
# Mini-batch size shared by both loaders.
batch_size = 8

# Common loader settings; only shuffling differs between training and validation.
loader_kwargs = dict(batch_size=batch_size, num_workers=2, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)


In [None]:
# Build timm's ConvNeXt-Base with pretrained weights and a 4-way classifier head.
model = timm.create_model(
    "convnext_base",
    pretrained=True,
    num_classes=4,
)
model.to(device)  # move the parameters to the selected device


In [None]:
# Ensure the classifier head produces 4 outputs. create_model(..., num_classes=4)
# already builds such a head, so it is only replaced if its output size differs.
# Unconditionally replacing it would re-randomise a trained head whenever this
# cell is re-run.
in_features = model.head.fc.in_features
if model.head.fc.out_features != 4:
    model.head.fc = nn.Linear(in_features, 4)  # 4 output classes
model.to(device)


In [None]:
# Multi-class cross-entropy objective optimised with AdamW.
learning_rate = 1e-4
weight_decay = 1e-4
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

In [None]:
def train_one_epoch(model, dataloader, criterion, optimizer):
    """Run one optimisation pass of ``model`` over ``dataloader``.

    Batches are moved to the module-level ``device`` before the forward pass.

    Args:
        model: Network to train; it is put in training mode.
        dataloader: Iterable of ``(images, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.

    Returns:
        ``(mean_loss, accuracy)``: the loss averaged over samples (not over
        batches) and the fraction of correctly classified samples.

    Raises:
        ValueError: If the dataloader yields no samples.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total_samples = 0

    for images, labels in tqdm(dataloader, desc="Training"):
        images, labels = images.to(device), labels.to(device)
        n_batch = labels.size(0)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        # Assumes the criterion averages over the batch (CrossEntropyLoss's
        # default reduction). Weighting by batch size stops a short final batch
        # from skewing the epoch mean.
        total_loss += loss.item() * n_batch
        correct += (outputs.argmax(1) == labels).sum().item()
        total_samples += n_batch

    if total_samples == 0:
        raise ValueError("dataloader yielded no samples")
    return total_loss / total_samples, correct / total_samples


In [None]:
def evaluate(model, dataloader, criterion):
    """Measure loss and accuracy of ``model`` on ``dataloader`` without training.

    The model is put in eval mode and gradients are disabled. Batches are moved
    to the module-level ``device``.

    Args:
        model: Network to evaluate.
        dataloader: Iterable of ``(images, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.

    Returns:
        ``(mean_loss, accuracy)``: the loss averaged over samples (not over
        batches) and the fraction of correctly classified samples.

    Raises:
        ValueError: If the dataloader yields no samples.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total_samples = 0

    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Evaluating"):
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            # Assumes the criterion averages over the batch; weight by batch size
            # so the result is a true per-sample mean.
            n_batch = labels.size(0)
            total_loss += loss.item() * n_batch
            correct += (outputs.argmax(1) == labels).sum().item()
            total_samples += n_batch

    if total_samples == 0:
        raise ValueError("dataloader yielded no samples")
    return total_loss / total_samples, correct / total_samples


In [None]:
num_epochs = 10
# Start below any reachable accuracy so the first epoch always writes a
# checkpoint. With best_acc = 0 and a strict ">", a run whose validation
# accuracy never rose above 0 would save nothing, and the reload would fail.
best_acc = float("-inf")

for epoch in range(num_epochs):
    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer)
    val_loss, val_acc = evaluate(model, val_loader, criterion)

    print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

    # Keep the weights of the epoch with the highest validation accuracy.
    # NOTE(review): val_loader is also used for the final metrics, so those
    # figures come from the same split used for checkpoint selection.
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), "best_convnext_model.pth")


In [None]:
# Restore the best checkpoint written during training.
# map_location places the tensors on the current device even if the file was
# saved from another one. weights_only=True restricts unpickling to tensors and
# plain containers, which is all a state_dict needs.
model.load_state_dict(torch.load("best_convnext_model.pth", map_location=device, weights_only=True))
model.to(device)
model.eval()  # Set to evaluation mode

In [None]:
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score

def specificity_from_cm(cm):
    """Return the specificity implied by a square confusion matrix.

    Rows of ``cm`` are true labels and columns are predicted labels (the layout
    of ``sklearn.metrics.confusion_matrix``).

    * 2x2 matrix: the binary definition TN / (TN + FP), with class 1 as the
      positive class (``cm[0, 0]`` is TN, ``cm[0, 1]`` is FP).
    * Any other size: one-vs-rest specificity per class, then the unweighted
      (macro) mean, consistent with ``average="macro"`` for precision and recall.

    Classes with TN + FP == 0 contribute 0. An empty matrix yields 0.0.
    """
    cm = np.asarray(cm, dtype=float)
    if cm.size == 0:
        return 0.0
    if cm.shape == (2, 2):
        tn, fp = cm[0, 0], cm[0, 1]
        return float(tn / (tn + fp)) if (tn + fp) != 0 else 0.0

    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp  # predicted as k but truly another class
    fn = cm.sum(axis=1) - tp  # truly k but predicted as another class
    tn = cm.sum() - tp - fp - fn
    denom = tn + fp
    per_class = np.divide(tn, denom, out=np.zeros_like(tn), where=denom != 0)
    return float(per_class.mean())


def evaluate_metrics(model, dataloader):
    """Collect predictions on ``dataloader`` and compute classification metrics.

    Inputs are moved to the module-level ``device``. The model is put in eval
    mode and gradients are disabled.

    Returns:
        ``(accuracy, precision, sensitivity, specificity, cm)``. Precision and
        sensitivity (recall) are macro-averaged over classes. Specificity comes
        from ``specificity_from_cm``, macro-averaged for more than two classes.
        ``cm`` is the sklearn confusion matrix (rows = true, cols = predicted).
    """
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in dataloader:
            outputs = model(images.to(device))
            preds = torch.argmax(outputs, dim=1)  # predicted class index

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    cm = confusion_matrix(all_labels, all_preds)

    accuracy = accuracy_score(all_labels, all_preds)
    # zero_division=0 returns the same value the default gives, without warning.
    precision = precision_score(all_labels, all_preds, average="macro", zero_division=0)
    sensitivity = recall_score(all_labels, all_preds, average="macro", zero_division=0)  # Sensitivity = Recall
    specificity = specificity_from_cm(cm)

    return accuracy, precision, sensitivity, specificity, cm



In [None]:
accuracy, precision, sensitivity, specificity, cm = evaluate_metrics(model, val_loader)

# Print each scalar metric to four decimals, then the raw confusion matrix.
report = [
    ("Accuracy", accuracy),
    ("Precision", precision),
    ("Sensitivity (Recall)", sensitivity),
    ("Specificity", specificity),
]
for metric_name, metric_value in report:
    print(f"{metric_name}: {metric_value:.4f}")

print("\nConfusion Matrix:")
print(cm)


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Label the axes with the real class folder names from the wrapped ImageFolder,
# whose `classes` list is ordered by label index. Fall back to generic
# positional names if the matrix size does not match the number of classes
# (e.g. a class absent from both labels and predictions).
tick_labels = list(full_dataset.dataset.classes)
if len(tick_labels) != cm.shape[0]:
    tick_labels = [f"Class {i}" for i in range(cm.shape[0])]

plt.figure(figsize=(6, 6))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=tick_labels, yticklabels=tick_labels)
plt.xlabel("Predicted Label")
plt.ylabel("True Label")
plt.title("Confusion Matrix")
plt.show()

In [None]:
from torchvision import datasets

# Re-scan the dataset with torchvision's ImageFolder and report how many images
# fall under each class name / label index.
from collections import Counter

dataset = datasets.ImageFolder(root=dataset_path)
class_names = dataset.classes
class_counts = Counter(label for _path, label in dataset.samples)

for class_idx, n_images in class_counts.items():
    print(f"Class '{class_names[class_idx]}' has {n_images} images.")
