# 1.1 Import Libraries

In [1]:
import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import matplotlib.pyplot as plt

# 1.2 Define Paths

In [9]:
MRI_DATASET_PATH = "/kaggle/input/imagesoasis/Data"

# 1.3 Data Transforms

In [11]:
transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],
                         [0.229,0.224,0.225])
])

# 1.4 Load Data

In [12]:
dataset = ImageFolder(MRI_DATASET_PATH, transform=transform)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

train_ds, val_ds = torch.utils.data.random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=32)

# 1.5 Load ResNet-50

In [13]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model_mri = models.resnet50(pretrained=True)
model_mri.fc = nn.Linear(model_mri.fc.in_features, 4)
model_mri = model_mri.to(device)


In [14]:
print("Total images:", len(dataset))
print("Train size:", len(train_ds))
print("Batches:", len(train_loader))
print("Classes:", dataset.classes)


Total images: 86437
Train size: 69149
Batches: 2161
Classes: ['Mild Dementia', 'Moderate Dementia', 'Non Demented', 'Very mild Dementia']


In [15]:
from torchvision.datasets import ImageFolder

dataset = ImageFolder(MRI_DATASET_PATH)

print("Classes found:", dataset.classes)
print("Total images:", len(dataset))

Classes found: ['Mild Dementia', 'Moderate Dementia', 'Non Demented', 'Very mild Dementia']
Total images: 86437


In [16]:
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.Resize((160,160)),   # faster
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],
                         [0.229,0.224,0.225])
])

dataset = ImageFolder(MRI_DATASET_PATH, transform=transform)

train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

train_ds, val_ds = torch.utils.data.random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=2)
val_loader = DataLoader(val_ds, batch_size=32, num_workers=2)

# 1.6 Train MRI Model

In [17]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model_mri.parameters(), lr=1e-4)

for epoch in range(5):
    model_mri.train()
    running_loss = 0

    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()

        outputs = model_mri(x)
        loss = criterion(outputs, y)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch [{epoch+1}/5], Loss: {running_loss/len(train_loader):.4f}")


Epoch [1/5], Loss: 0.0945
Epoch [2/5], Loss: 0.0199
Epoch [3/5], Loss: 0.0130
Epoch [4/5], Loss: 0.0109
Epoch [5/5], Loss: 0.0083


In [22]:
def accuracy(model, loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            outputs = model(x)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == y).sum().item()
            total += y.size(0)
    return correct / total

In [23]:
for epoch in range(5):
    model_mri.train()
    running_loss = 0

    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        outputs = model_mri(x)
        loss = criterion(outputs, y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # Compute validation accuracy
    val_acc = accuracy(model_mri, val_loader)

    print(f"Epoch [{epoch+1}/5], Loss: {running_loss/len(train_loader):.4f}, Val Accuracy: {val_acc*100:.2f}%")


Epoch [1/5], Loss: 0.0073, Val Accuracy: 99.97%
Epoch [2/5], Loss: 0.0054, Val Accuracy: 99.85%
Epoch [3/5], Loss: 0.0059, Val Accuracy: 99.92%
Epoch [4/5], Loss: 0.0042, Val Accuracy: 99.98%
Epoch [5/5], Loss: 0.0016, Val Accuracy: 100.00%


## Data Augmentation

In [24]:
transforms.RandomRotation(15),
transforms.RandomHorizontalFlip(),
transforms.RandomResizedCrop(224, scale=(0.8,1.0))

RandomResizedCrop(size=(224, 224), scale=(0.8, 1.0), ratio=(0.75, 1.3333), interpolation=bilinear, antialias=True)

## Patient-Level Prediction Example

In [26]:
import numpy as np

def patient_level_predict(model, patient_slices):
    model.eval()
    outputs = []
    with torch.no_grad():
        for x in patient_slices:
            x = x.to(device).unsqueeze(0)
            out = model(x)
            outputs.append(out.cpu().numpy())
    avg_output = np.mean(outputs, axis=0)
    return np.argmax(avg_output)

# 1.7 Save MRI Model

In [18]:
torch.save(model_mri.state_dict(), "mri_resnet50.pth")

## HANDWRITING MODEL

# 2.1 Load Dataset

In [None]:
HW_PATH = "/kaggle/input/handwriting-data-to-detect-alzheimers-disease"

hw_dataset = ImageFolder(HW_PATH, transform=transform)
train_hw, val_hw = torch.utils.data.random_split(hw_dataset, [80, 20])

train_hw_loader = DataLoader(train_hw, batch_size=16, shuffle=True)


# 2.2 Load ResNet-18

In [None]:
model_hw = models.resnet18(pretrained=True)
model_hw.fc = nn.Linear(model_hw.fc.in_features, 2)
model_hw = model_hw.to(device)

# 2.3 Train Handwriting Model

In [None]:
optimizer = torch.optim.Adam(model_hw.parameters(), lr=1e-4)

for epoch in range(10):
    for x, y in train_hw_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = criterion(model_hw(x), y)
        loss.backward()
        optimizer.step()
    print(f"HW Epoch {epoch+1} done")

# 2.4 Save Model

In [None]:
torch.save(model_hw.state_dict(), "handwriting_model.pth")

## FEATURE EXTRACTION

# 3.1 Remove Final Layers

In [None]:
mri_feature_extractor = nn.Sequential(*list(model_mri.children())[:-1])
hw_feature_extractor = nn.Sequential(*list(model_hw.children())[:-1])

# 3.2 Extract Features

In [None]:
def extract_features(model, loader):
    features = []
    labels = []
    model.eval()
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            f = model(x).squeeze()
            features.append(f.cpu())
            labels.append(y)
    return torch.cat(features), torch.cat(labels)

## FUSION NETWORK

# 4.1 Fusion Model

In [None]:
class FusionNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(512+512, 256),
            nn.ReLU(),
            nn.Linear(256, 4)
        )

    def forward(self, mri, hw):
        x = torch.cat([mri, hw], dim=1)
        return self.fc(x)