In [3]:
from datasets import load_dataset

ds = load_dataset("Falah/Alzheimer_MRI")

In [4]:
f_5 = ds['train'][:5]
print(f_5)

{'image': [<PIL.JpegImagePlugin.JpegImageFile image mode=L size=128x128 at 0x1662026D8D0>, <PIL.JpegImagePlugin.JpegImageFile image mode=L size=128x128 at 0x16621E026B0>, <PIL.JpegImagePlugin.JpegImageFile image mode=L size=128x128 at 0x16621E029E0>, <PIL.JpegImagePlugin.JpegImageFile image mode=L size=128x128 at 0x16621E028F0>, <PIL.JpegImagePlugin.JpegImageFile image mode=L size=128x128 at 0x16621E02860>], 'label': [2, 0, 3, 3, 2]}


In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from datasets import load_dataset
from torchvision import transforms
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import torch.nn.functional as F

In [7]:
torch.manual_seed(42)
np.random.seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using {device}")

Using cuda


In [8]:
ds = load_dataset("Falah/Alzheimer_MRI")


print(f"Available splits: {list(ds.keys())}")
print(f"Training samples: {len(ds['train'])}")


first_sample = ds['train'][0]

print(f"Keys: {list(first_sample.keys())}")

train_labels = [sample['label'] for sample in ds['train']]
unique_labels, counts = np.unique(train_labels, return_counts=True)

if hasattr(ds['train'].features['label'], 'names'):
    label_names = ds['train'].features['label'].names
    print(f"Label names: {label_names}")
    num_classes = len(label_names)
    for i, (label, count) in enumerate(zip(unique_labels, counts)):
        label_name = label_names[label] if label < len(label_names) else f"Label_{label}"
        print(f"  {label_name}: {count} samples")
else:
    num_classes = len(unique_labels)
    label_names = [f"Class_{i}" for i in range(num_classes)]
    for label, count in zip(unique_labels, counts):
        print(f"  Label {label}: {count} samples")

print(f"Number of classes: {num_classes}")

Available splits: ['train', 'test']
Training samples: 5120
Keys: ['image', 'label']
Label names: ['Mild_Demented', 'Moderate_Demented', 'Non_Demented', 'Very_Mild_Demented']
  Mild_Demented: 724 samples
  Moderate_Demented: 49 samples
  Non_Demented: 2566 samples
  Very_Mild_Demented: 1781 samples
Number of classes: 4


In [9]:
sample_image = first_sample['image']
if hasattr(sample_image, 'size'):
    print(f"Image size: {sample_image.size}")
elif hasattr(sample_image, 'shape'):
    print(f"Image shape: {sample_image.shape}")
else:
    img_array = np.array(sample_image)
    print(f"Image shape: {img_array.shape}")


def visualize_samples(dataset, num_samples=6):
    fig, axes = plt.subplots(2, 3, figsize=(12, 8))
    axes = axes.ravel()
    
    for i in range(num_samples):
        sample = dataset[i]
        image = sample['image']
        label = sample['label']
        
        
        if hasattr(image, 'numpy'):
            img_array = image.numpy()
        else:
            img_array = np.array(image)

        if len(img_array.shape) == 3 and img_array.shape[0] in [1, 3]:
            img_array = np.transpose(img_array, (1, 2, 0))
        elif len(img_array.shape) == 3 and img_array.shape[2] == 1:
            img_array = img_array.squeeze()
        
        axes[i].imshow(img_array, cmap='gray' if len(img_array.shape) == 2 else None)
        label_name = label_names[label] if label < len(label_names) else f"Label_{label}"
        axes[i].set_title(f'{label_name} (Label: {label})')
        axes[i].axis('off')
    
    plt.tight_layout()
    plt.show()
    visualize_samples(ds['train'])


Image size: (128, 128)


In [10]:
class AlzheimerMRIDataset(Dataset):
    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform
        
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        sample = self.dataset[idx]
        image = sample['image']
        label = sample['label']
        
        # Convert to RGB if needed
        if hasattr(image, 'convert'):
            if image.mode != 'RGB':
                image = image.convert('RGB')
        
        # Apply transforms
        if self.transform:
            image = self.transform(image)
        else:
            # Default transform
            if not isinstance(image, torch.Tensor):
                image = transforms.ToTensor()(image)
        
        return image, label

In [11]:
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.RandomHorizontalFlip(p=0.3),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

test_transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [12]:
train_dataset = AlzheimerMRIDataset(ds['train'], transform=transform)
from torch.utils.data import random_split
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])


val_dataset.dataset = AlzheimerMRIDataset(ds['train'], transform=test_transform)


In [13]:
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

In [14]:
print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")

Training batches: 128
Validation batches: 32


In [15]:
class AlzheimerCNN(nn.Module):
    def __init__(self, num_classes):
        super(AlzheimerCNN, self).__init__()
        
        self.conv_layers = nn.Sequential(
            # First conv block
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
            
            # Second conv block
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            
            # Third conv block
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),
            
            # Fourth conv block
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(256 * 8 * 8, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )
        
    def forward(self, x):
        x = self.conv_layers(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

# Initialize model
model = AlzheimerCNN(num_classes=num_classes).to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\n=== Model Summary ===")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)


=== Model Summary ===
Total parameters: 8,780,548
Trainable parameters: 8,780,548


In [16]:
def train_epoch(model, loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    
    epoch_loss = running_loss / len(loader)
    epoch_acc = 100 * correct / total
    return epoch_loss, epoch_acc

In [17]:
def validate_epoch(model, loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    all_predictions = []
    all_labels = []
    
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    epoch_loss = running_loss / len(loader)
    epoch_acc = 100 * correct / total
    return epoch_loss, epoch_acc, all_predictions, all_labels

In [18]:
print("\n=== Starting Training ===")
num_epochs = 15
best_val_acc = 0.0
train_losses, val_losses = [], []
train_accs, val_accs = [], []

for epoch in range(num_epochs):
    # Training
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    
    # Validation
    val_loss, val_acc, _, _ = validate_epoch(model, val_loader, criterion, device)
    
    # Update scheduler
    scheduler.step()
    
    # Store metrics
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accs.append(train_acc)
    val_accs.append(val_acc)
    
    print(f'Epoch [{epoch+1}/{num_epochs}]')
    print(f'  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
    print(f'  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')


=== Starting Training ===


RuntimeError: DataLoader worker (pid(s) 10248, 8316) exited unexpectedly

In [None]:
if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), 'best_alzheimer_cnn.pth')
        print(f'  ↳ New best model saved! (Val Acc: {val_acc:.2f}%)')

print(f"\nBest validation accuracy: {best_val_acc:.2f}%")

In [None]:
plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Val Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(train_accs, label='Train Acc')
plt.plot(val_accs, label='Val Acc')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()

plt.tight_layout()
plt.show()

In [None]:
print("\n=== Final Evaluation ===")
model.load_state_dict(torch.load('best_alzheimer_cnn.pth'))
val_loss, val_acc, predictions, true_labels = validate_epoch(model, val_loader, criterion, device)

print(f"Final Validation Accuracy: {val_acc:.2f}%")
print(f"Final Validation Loss: {val_loss:.4f}")

# Classification report
print("\n=== Classification Report ===")
print(classification_report(true_labels, predictions, target_names=label_names[:num_classes]))

In [None]:
plt.figure(figsize=(8, 6))
cm = confusion_matrix(true_labels, predictions)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=label_names[:num_classes], 
            yticklabels=label_names[:num_classes])
plt.title('Confusion Matrix')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.tight_layout()
plt.show()

In [None]:
print("\n=== Testing on Sample Images ===")
model.eval()
fig, axes = plt.subplots(2, 3, figsize=(12, 8))
axes = axes.ravel()

with torch.no_grad():
    for i in range(6):
        image, true_label = val_dataset[i]
        image = image.unsqueeze(0).to(device)
        
        output = model(image)
        _, predicted = torch.max(output, 1)
        predicted_label = predicted.item()
        
        # Denormalize for visualization
        img = image.squeeze().cpu().numpy()
        img = np.transpose(img, (1, 2, 0))
        img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
        img = np.clip(img, 0, 1)
        
        axes[i].imshow(img)
        true_name = label_names[true_label] if true_label < len(label_names) else f"Label_{true_label}"
        pred_name = label_names[predicted_label] if predicted_label < len(label_names) else f"Label_{predicted_label}"
        
        color = 'green' if true_label == predicted_label else 'red'
        axes[i].set_title(f'True: {true_name}\nPred: {pred_name}', color=color)
        axes[i].axis('off')

plt.tight_layout()
plt.show()

print("\n=== Training Complete ===")
print(f"Best model saved as: 'best_alzheimer_cnn.pth'")
print(f"Final validation accuracy: {val_acc:.2f}%")