In [None]:
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split, ConcatDataset, Subset
import random
import os
from pathlib import Path

# Configuration: dataset location plus a global seed so that weight init of the
# new classifier head and DataLoader shuffling repeat across Restart & Run All.
DATA_ROOT = Path("/kaggle/input/augmented-alzheimer-mri-dataset")
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# ImageNet statistics: the ResNet weights loaded later (pretrained=True) were
# trained on inputs normalised this way, so fine-tuning starts from matching
# input statistics instead of the generic 0.5/0.5 scaling.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# One ImageFolder per directory; class indices come from the sorted folder names.
aug_dataset = datasets.ImageFolder(root=str(DATA_ROOT / "AugmentedAlzheimerDataset"), transform=transform)
org_dataset = datasets.ImageFolder(root=str(DATA_ROOT / "OriginalDataset"), transform=transform)

In [None]:
allowed_classes = ['NonDemented', 'VeryMildDemented']

def filtered_classes(dataset, classes=None):
    """Restrict an ImageFolder-style dataset to a subset of its classes.

    Args:
        dataset: dataset exposing ``class_to_idx`` (e.g. ``ImageFolder``).
            If it also exposes ``targets``, labels are read from there;
            otherwise every sample is loaded once to read its label.
        classes: class names to keep. Defaults to the module-level
            ``allowed_classes``.

    Returns:
        ``(Subset(dataset, indices), allowed_indices)``. ``indices`` are the
        ascending positions of the kept samples. ``allowed_indices`` are the
        original label indices of ``classes``, in the same order. Labels are
        NOT remapped, so downstream models still see the original indices.

    Raises:
        KeyError: if a requested class name is not in ``dataset.class_to_idx``.
    """
    if classes is None:
        classes = allowed_classes
    idx_dict = dataset.class_to_idx
    allowed_indices = [idx_dict[c] for c in classes]
    keep = set(allowed_indices)

    # ImageFolder stores every label in ``targets``; reading it avoids decoding
    # and transforming each image just to learn its class.
    targets = getattr(dataset, "targets", None)
    if targets is None:
        targets = (label for _, label in dataset)
    # int() so tensor-valued labels still hash-match the int set.
    indices = [i for i, label in enumerate(targets) if int(label) in keep]
    return Subset(dataset, indices), allowed_indices

# Keep the raw ImageFolders under their own names and bind the filtered views
# to new names. Rebinding aug_dataset/org_dataset to a Subset would make a
# re-run of this cell fail, because Subset has no class_to_idx.
aug_subset, _ = filtered_classes(aug_dataset)
org_subset, _ = filtered_classes(org_dataset)

full_dataset = ConcatDataset([aug_subset, org_subset])
assert len(full_dataset) > 0, "no samples left after class filtering"

In [None]:
# 80/20 train/validation split over the combined dataset. A fixed generator
# seed makes the split identical on every Restart & Run All.
# NOTE(review): if the augmented images are derived from the originals,
# near-duplicates can land on both sides of this split and inflate validation
# accuracy. Consider validating on held-out original images only.
SPLIT_SEED = 42
BATCH_SIZE = 32

train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size

train_ds, val_ds = random_split(
    full_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE)

In [None]:
# Prefer the GPU, but fall back to CPU so the notebook still runs without one.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# filtered_classes keeps the original ImageFolder label indices (no remapping),
# so the head needs one logit per original class folder, not per kept class.
# NOTE(review): 4 assumes the dataset directories contain four class folders
# — confirm against aug_dataset.classes.
NUM_CLASSES = 4
LEARNING_RATE = 1e-4

model = models.resnet34(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [None]:
epochs = 5

for epoch in range(epochs):
    model.train()
    total, correct = 0, 0
    # Sum of per-sample losses; divided by `total` below to give the epoch mean.
    running_loss = 0.0

    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        output = model(imgs)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

        batch_n = labels.size(0)
        # criterion returns the batch mean; weight by batch size so a short
        # final batch does not skew the epoch average.
        running_loss += loss.item() * batch_n
        total += batch_n
        correct += (output.argmax(1) == labels).sum().item()

    # Report the mean loss per sample; the old sum of batch means grew with
    # the number of batches. Guard against an empty loader.
    epoch_loss = running_loss / total if total else float("nan")
    acc = correct / total if total else float("nan")
    print(f"Epoch {epoch+1}/{epochs}, Mean Loss: {epoch_loss:.4f}, Train Accuracy: {acc:.4f}")

In [None]:
# Validation pass: switch off training-mode layers and gradient tracking, count
# correct top-1 predictions per batch, then report accuracy over val_ds.
model.eval()
with torch.no_grad():
    batch_hits = []
    for imgs, labels in val_loader:
        imgs = imgs.to(device)
        labels = labels.to(device)
        outputs = model(imgs)
        predictions = outputs.argmax(1)
        batch_hits.append((predictions == labels).sum().item())
correct = sum(batch_hits)

val_acc = correct / len(val_ds)
print(f"Validation Accuracy: {val_acc:.4f}")