# MRI Classification Using PyTorch
This notebook demonstrates how to load an MRI dataset, split it into training and validation sets, define a custom CNN model, and train the model using PyTorch.

In [1]:
import torch
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, random_split
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

## Data Preparation
Load the dataset and split it into training and validation sets. We will apply data augmentations to the training set.

In [None]:
# Define transformations for training and validation sets
transform_train = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

transform_val = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load the dataset
dataset = ImageFolder(root='dataset_13', transform=transform_train)

# Split the dataset into training and validation sets
train_size = int(0.8 * len(dataset))  # 80% training
val_size = len(dataset) - train_size  # 20% validation
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Apply validation transformations to the validation set
val_dataset.dataset.transform = transform_val

# Data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)

## Transfer Learning

Pre-trained models are typically trained on large datasets, such as ImageNet, which contains millions of images and thousands of classes. Transfer learning leverages this prior knowledge to benefit a new, often smaller dataset.

In [3]:
import torchvision.models as models

# Number of classes in your dataset
num_classes = 4  # glioma, meningioma, notumor, pituitary

# Load a pre-trained ResNet model
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

# Replace the final fully connected layer
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, num_classes)  # num_classes is the number of classes in your dataset

# Move the model to the GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Optionally, freeze all layers except the final layer if needed
for param in model.parameters():
    param.requires_grad = False

# Unfreeze the final layer
for param in model.fc.parameters():
    param.requires_grad = True

# Training and evaluation
num_epochs = 10  # Define the number of epochs
best_val_loss = float('inf')

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    running_corrects = 0
    
    # Training phase
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        _, preds = torch.max(outputs, 1)
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)
    
    epoch_loss = running_loss / len(train_loader.dataset)
    epoch_acc = running_corrects.double() / len(train_loader.dataset)
    
    print(f'Train - Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}')
    
    # Validation phase
    model.eval()
    val_loss = 0.0
    val_corrects = 0
    
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            _, preds = torch.max(outputs, 1)
            val_loss += loss.item() * inputs.size(0)
            val_corrects += torch.sum(preds == labels.data)
    
    val_loss /= len(val_loader.dataset)
    val_acc = val_corrects.double() / len(val_loader.dataset)
    
    print(f'Validation - Epoch {epoch + 1}/{num_epochs}, Loss: {val_loss:.4f}, Accuracy: {val_acc:.4f}')

print('Training complete')

Train - Epoch 1/10, Loss: 1.3653, Accuracy: 0.3802
Validation - Epoch 1/10, Loss: 1.1713, Accuracy: 0.5312
Train - Epoch 2/10, Loss: 0.9598, Accuracy: 0.6667
Validation - Epoch 2/10, Loss: 1.0653, Accuracy: 0.5000
Train - Epoch 3/10, Loss: 0.8083, Accuracy: 0.7188
Validation - Epoch 3/10, Loss: 0.7345, Accuracy: 0.8229
Train - Epoch 4/10, Loss: 0.6856, Accuracy: 0.7917
Validation - Epoch 4/10, Loss: 0.6668, Accuracy: 0.8333
Train - Epoch 5/10, Loss: 0.5995, Accuracy: 0.8047
Validation - Epoch 5/10, Loss: 0.6560, Accuracy: 0.7812
Train - Epoch 6/10, Loss: 0.4872, Accuracy: 0.8750
Validation - Epoch 6/10, Loss: 0.6042, Accuracy: 0.8125
Train - Epoch 7/10, Loss: 0.4486, Accuracy: 0.8776
Validation - Epoch 7/10, Loss: 0.5625, Accuracy: 0.8229
Train - Epoch 8/10, Loss: 0.4281, Accuracy: 0.8854
Validation - Epoch 8/10, Loss: 0.5266, Accuracy: 0.8229
Train - Epoch 9/10, Loss: 0.4103, Accuracy: 0.8906
Validation - Epoch 9/10, Loss: 0.5125, Accuracy: 0.8125
Train - Epoch 10/10, Loss: 0.3904, Ac

In [None]:
model.eval()
val_loss = 0.0
val_corrects = 0

with torch.no_grad():
    for inputs, labels in val_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        
        _, preds = torch.max(outputs, 1)
        val_loss += loss.item() * inputs.size(0)
        val_corrects += torch.sum(preds == labels.data)

val_loss /= len(val_dataset)
val_acc = val_corrects.double() / len(val_dataset)

print(f'Final Validation Loss: {val_loss:.4f}, Final Validation Accuracy: {val_acc:.4f}')

Final Validation Loss: 0.4912, Final Validation Accuracy: 0.8333
