<font color='green' style='font-size:32px'><b>Walking You Through Task 1A - Classifying MRI Brain Tumours</b></font>
***
<font color='green' style='font-size:20px'><b>Reading Images and Storing Data</b></font>

I used torchvision.datasets.ImageFolder to load the MRI images and applied a sequence of transformations to standardise them: converting to grayscale to reduce complexity (the files are stored in RGB format despite appearing greyscale), resizing to 64×64 pixels to reduce computational load while preserving relevant features, and normalising pixel values to the range [-1, 1] to stabilise training. I also tested smaller sizes such as 28×28, but validation accuracy was noticeably lower (86%) than with 64×64 (96.49%). Tensors were used to store the image data and labels because they are the core data structure in PyTorch, optimised for efficient computation, especially on GPUs. By converting the entire dataset into tensors, I could take advantage of fast matrix operations during training and evaluation. Tensors also allow for easy slicing, batching, and reshaping, which simplifies feeding data into the model and applying additional processing if needed.
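
As a rough illustration of the normalisation step, the arithmetic applied to a single pixel can be written out by hand. This is a plain-Python sketch rather than the torchvision implementation; the function name `to_model_range` is made up for this example, and it assumes 8-bit input pixels that `ToTensor` first scales to [0, 1].

```python
def to_model_range(pixel_8bit, mean=0.5, std=0.5):
    """Mimic ToTensor (divide by 255) followed by Normalize(mean=0.5, std=0.5)."""
    scaled = pixel_8bit / 255.0      # ToTensor: [0, 255] -> [0, 1]
    return (scaled - mean) / std     # Normalize: [0, 1] -> [-1, 1]

# Hypothetical sample values: black and white pixels
print(to_model_range(0), to_model_range(255))  # -1.0 1.0
```

With a mean and standard deviation of 0.5, black maps to -1 and white maps to 1, which is where the [-1, 1] range comes from.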

<font color='green' style='font-size:20px'><b>Machine Learning Model</b></font>

For this task, I chose to implement a Convolutional Neural Network (CNN) because CNNs are particularly effective for image classification problems. Unlike traditional machine learning models such as Support Vector Machines (SVMs) or k-Nearest Neighbours (k-NN), which often require manual preprocessing or feature extraction, CNNs automatically learn to identify relevant visual patterns - such as edges, shapes, and textures - directly from raw pixel data during training. This ability to capture both local and global image structure is especially useful when classifying MRI tumours. By using multiple layers of convolution followed by pooling, the network builds a hierarchical understanding of the image, moving from low-level features to more complex representations relevant for distinguishing between tumour types.
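
To make the idea of a filter responding to an edge concrete, here is a small plain-Python sketch. The 3×3 kernel is a hand-written Sobel-style filter chosen for illustration, whereas a CNN learns its kernel values during training; the helper name `correlate2d` and the toy images are assumptions made for this example.

```python
def correlate2d(image, kernel):
    """Slide a square kernel over the image (no padding, stride 1)."""
    k = len(kernel)
    out_h = len(image) - k + 1
    out_w = len(image[0]) - k + 1
    return [
        [
            sum(kernel[i][j] * image[r + i][c + j]
                for i in range(k) for j in range(k))
            for c in range(out_w)
        ]
        for r in range(out_h)
    ]

# Hand-written vertical-edge kernel (illustrative, not a learned filter)
edge_kernel = [[-1, 0, 1],
               [-2, 0, 2],
               [-1, 0, 1]]

edge_image = [[0, 0, 1, 1] for _ in range(4)]  # dark left half, bright right half
flat_image = [[1, 1, 1, 1] for _ in range(4)]  # uniform brightness

print(correlate2d(edge_image, edge_kernel))  # [[4, 4], [4, 4]]
print(correlate2d(flat_image, edge_kernel))  # [[0, 0], [0, 0]]
```

The filter responds strongly where brightness changes from left to right and not at all on a uniform region, which is the kind of low-level pattern the first convolutional layers tend to pick up.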

The model begins with two convolutional layers, each with 32 filters and a 3×3 kernel. These initial layers are designed to capture basic visual patterns, such as edges and textures, from the grayscale input images. Padding is used to maintain the original spatial dimensions, which ensures that deeper layers receive feature maps that still contain detailed spatial information. This helps the network progressively learn more complex and meaningful representations from the image as it goes deeper. Max pooling is applied after the second convolution to reduce dimensionality, retain important features, and improve translation invariance. This pattern is repeated with two additional convolutional layers, now using 64 filters to increase the depth of learned features and capture more complex patterns in the data. A second max pooling layer further condenses the representation and reduces computational load. After these layers, the resulting feature maps are flattened into a one-dimensional vector and passed into a linear layer with 256 neurons. This layer combines all the spatial features learned so far into a single compact representation that captures the most important information needed for classification. The final output (linear) layer has 4 neurons, each representing one of the four possible tumour classes in the dataset. ReLU activations are used throughout to introduce non-linearity, and the dynamic computation of the fully connected input size ensures flexibility, especially if changes are made to the input image size or convolutional structure in future experiments.
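
The spatial sizes described above can be checked with a few lines of arithmetic. This sketch uses the standard output-size formula for convolution and pooling windows; the helper name `conv_out` is made up for this example, and the layer settings are the ones described in the paragraph above (3×3 kernels with padding 1, 2×2 max pooling).

```python
def conv_out(size, kernel, padding=0, stride=1):
    """Output length of a convolution or pooling window along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1

size = 64                                    # input images are 64x64
size = conv_out(size, kernel=3, padding=1)   # conv1: stays 64
size = conv_out(size, kernel=3, padding=1)   # conv2: stays 64
size = conv_out(size, kernel=2, stride=2)    # max pool: 32
size = conv_out(size, kernel=3, padding=1)   # conv3: stays 32
size = conv_out(size, kernel=3, padding=1)   # conv4: stays 32
size = conv_out(size, kernel=2, stride=2)    # max pool: 16

flat_features = 64 * size * size             # 64 channels after conv4
fc1_params = flat_features * 256 + 256       # weights plus biases of the 256-neuron layer

print(size, flat_features, fc1_params)  # 16 16384 4194560
```

Most of the model's parameters therefore sit in the first linear layer, which is worth keeping in mind if the input size is increased.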

<font color='green' style='font-size:20px'><b>Training the Model</b></font>

The model is trained using a supervised learning approach, where the CNN learns to map input MRI images to one of four tumour classes. I use the cross-entropy loss function, which is standard for multi-class classification tasks as it measures the difference between the predicted class probabilities and the true labels. For optimisation, I use the Adam optimiser with a learning rate of 0.001, as it provides efficient and adaptive gradient updates, helping the model converge faster without extensive tuning. During training, the model is set to training mode, which switches on training-specific behaviour in layers such as dropout (which randomly drops activations to encourage more robust features) and batch normalisation (which normalises activations using batch statistics). The network described above contains neither layer, so the call has no effect here, but it keeps the loop correct if such layers are added later. The loss is computed for each batch, and gradients are calculated using backpropagation, which adjusts the model’s weights to reduce the error over time. The training loop tracks both loss and accuracy to monitor learning progress. To detect overfitting, the model's performance is also evaluated on the validation set after each epoch. The model reached a high of 99.63% accuracy on the training data, so barring overfitting it should also perform well on the validation data.
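
For intuition about what the cross-entropy loss rewards, the per-sample calculation can be sketched in plain Python. This is an illustration of the formula rather than the PyTorch implementation (`nn.CrossEntropyLoss` takes raw logits and averages over the batch by default); the function name and the logit values are made up for this example.

```python
import math

def cross_entropy(logits, label):
    """Loss for one sample: -log(softmax(logits)[label])."""
    peak = max(logits)  # subtract the max before exponentiating for numerical stability
    log_sum_exp = peak + math.log(sum(math.exp(z - peak) for z in logits))
    return log_sum_exp - logits[label]

# Hypothetical logits for the four tumour classes
confident = cross_entropy([4.0, 0.0, 0.0, 0.0], label=0)  # correct and confident
unsure = cross_entropy([0.0, 0.0, 0.0, 0.0], label=0)     # equal scores for all classes
wrong = cross_entropy([4.0, 0.0, 0.0, 0.0], label=1)      # confident but incorrect

print(confident < unsure < wrong)  # True
```

A confident correct prediction gives a loss close to zero, an undecided one gives log(4) for four classes, and a confident mistake is penalised most heavily.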

<font color='green' style='font-size:20px'><b>Evaluating the Model</b></font>

I evaluate the model on a separate validation set at the end of each epoch to monitor how well it generalises to unseen data. Before evaluation, the model is set to evaluation mode, which disables training-specific behaviours like dropout and ensures that layers such as batch normalisation use consistent, learned statistics rather than batch-specific values. This provides a more stable and accurate estimate of the model’s true performance. During validation, no gradients are computed, which reduces memory usage and speeds up computation. After each epoch, I calculate the validation accuracy and compare it to the best seen so far. If the current model performs better, its weights are saved. This approach ensures that I retain the most effective version of the model throughout training, helping to prevent overfitting and improve performance on new data. After 10 epochs, the best validation accuracy obtained was 97.71%.
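
The keep-the-best rule can be replayed without any training. The sketch below feeds the ten validation accuracies printed in the training log further down through the same comparison; it only records which epochs would trigger a save, and the variable names are chosen for this example.

```python
# Validation accuracies copied from the Task 1A training log (epochs 1 to 10)
val_acc_per_epoch = [0.9542, 0.9687, 0.9535, 0.9634, 0.9710,
                     0.9588, 0.9695, 0.9764, 0.9771, 0.9764]

best_accuracy = 0.0
saved_epochs = []
for epoch, val_acc in enumerate(val_acc_per_epoch, start=1):
    if val_acc > best_accuracy:      # strictly better than anything seen so far
        best_accuracy = val_acc
        saved_epochs.append(epoch)   # the real loop calls torch.save here

print(saved_epochs, best_accuracy)  # [1, 2, 5, 8, 9] 0.9771
```

The saved epochs match the "Model saved!" lines in the log, and the final checkpoint is the epoch 9 model rather than the last one trained.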

In [7]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
%matplotlib inline
from tqdm import tqdm, trange
import matplotlib.pyplot as plt

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

#---------------------- Loading Data ----------------------#

# Converting to grayscale, resizing, converting to tensor, normalising
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])  # Scales pixels from [0,1] to [-1,1]
])


# Load Images
train_dataset = datasets.ImageFolder(root='./brain-tumor/train', transform=transform)
valid_dataset = datasets.ImageFolder(root='./brain-tumor/valid', transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=32, shuffle=False)

In [10]:
#---------------------- Creating Tensors ----------------------#

train_size = len(train_dataset)  # Number of training images (5712)
valid_size = len(valid_dataset)  # Number of validation/test images (1311)

height = width = 64

# Initialize X and y for storing images and labels as tensors
X_train = torch.zeros((train_size, 1, height, width))  # 1 channel for grayscale images
y_train = torch.zeros((train_size,), dtype=torch.long)
X_valid = torch.zeros((valid_size, 1, height, width))  # 1 channel for grayscale images
y_valid = torch.zeros((valid_size,), dtype=torch.long)

# Convert train data to tensors
for i, (img, label) in enumerate(train_loader.dataset):
    X_train[i] = img  # Directly store the tensor (size: [1, 64, 64] for grayscale)
    y_train[i] = label

# Convert validation data to tensors
for i, (img, label) in enumerate(valid_loader.dataset):
    X_valid[i] = img  # Directly store the tensor (size: [1, 64, 64] for grayscale)
    y_valid[i] = label

data_train_iter = iter(train_loader)
images, labels = next(data_train_iter)

In [11]:
#---------------------- Creating CNN Model ----------------------#

class MRI_CNN(nn.Module):
    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)

        # Define the output size after convolutions and pooling dynamically
        self.fc1 = None  # To be defined after determining the output size
        self.fc2 = nn.Linear(256, 4)  # 4 output classes

    def forward(self, x):
        # convolution layer 1
        x = self.conv1(x)
        x = F.relu(x)

        # convolution layer 2
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, kernel_size=2)  # output: 32x32

        # convolution layer 3
        x = self.conv3(x)
        x = F.relu(x)

        # convolution layer 4
        x = self.conv4(x)
        x = F.relu(x)
        x = F.max_pool2d(x, kernel_size=2)  # output: 16x16

        # Dynamically define the input size for the fully connected layer
        if self.fc1 is None:
            # Determine the size of the input to the fully connected layer dynamically
            self.fc1 = nn.Linear(x.size(1) * x.size(2) * x.size(3), 256)

        # Flatten the output from convolution layers before feeding to fully connected layers
        x = x.view(-1, self.fc1.in_features)  # Flatten dynamically

        # fully connected layer 1
        x = self.fc1(x)
        x = F.relu(x)

        # fully connected layer 2 (output layer)
        x = self.fc2(x)  # 4 output classes

        return x


In [12]:
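# Test the forward pass on the CPU. This first call also creates fc1,
# so it must run before the optimizer is built and before the model is moved.
model = MRI_CNN()
y = model(images)
print(f"output.shape: {y.shape}")
# Move every layer (including the newly created fc1) to the training device
model = model.to(device)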

output.shape: torch.Size([32, 4])


In [13]:
# --------------------- Training Function ---------------------

from pathlib import Path
import time
# Path to save/load model weights
path_to_state = "mri_cnn_best.pt"

def train_model(model, path_to_state, criterion, optimizer, scheduler=None, epochs=10):
    my_file = Path(path_to_state)
    if my_file.is_file():
        print("✅ Loading pretrained weights...")
        # map_location lets a checkpoint saved on a GPU load on a CPU-only machine
        model.load_state_dict(torch.load(path_to_state, map_location=device))

    best_accuracy = 0.0
    since = time.time()

    for epoch in range(epochs):
        print(f"\n🔁 Epoch [{epoch+1}/{epochs}]")

        # ---------- Training ----------
        model.train()
        running_loss, running_corrects = 0.0, 0

        for inputs, labels in tqdm(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()

            outputs = model(inputs)
            preds = torch.argmax(outputs, dim=1)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

        train_loss = running_loss / train_size
        train_acc = running_corrects.double() / train_size
        print(f"Train Loss: {train_loss:.4f}, Accuracy: {train_acc:.4f}")

        if scheduler:
            scheduler.step()

        # ---------- Validation ----------
        model.eval()
        running_loss, running_corrects = 0.0, 0

        with torch.no_grad():
            for inputs, labels in valid_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                preds = torch.argmax(outputs, dim=1)
                loss = criterion(outputs, labels)

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

        val_loss = running_loss / valid_size
        val_acc = running_corrects.double() / valid_size
        print(f"Valid Loss: {val_loss:.4f}, Accuracy: {val_acc:.4f}")

        # Save best model
        if val_acc > best_accuracy:
            best_accuracy = val_acc
            torch.save(model.state_dict(), path_to_state)
            print("Model saved!")

    time_elapsed = time.time() - since
    print(f"\n⏱ Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s")
    print(f"Best Validation Accuracy: {best_accuracy:.4f}")

    return model


# --------------------- Train the Model ---------------------

trained_model = train_model(
    model=model,
    path_to_state='best_model.pt',
    criterion=nn.CrossEntropyLoss(),
    optimizer=torch.optim.Adam(model.parameters(), lr=0.001),
    scheduler=None,  # Or add a scheduler like StepLR
    epochs=10
)

✅ Loading pretrained weights...

🔁 Epoch [1/10]


100%|██████████| 179/179 [01:38<00:00,  1.81it/s]


Train Loss: 0.0227, Accuracy: 0.9923
Valid Loss: 0.2143, Accuracy: 0.9542
Model saved!

🔁 Epoch [2/10]


100%|██████████| 179/179 [01:23<00:00,  2.14it/s]


Train Loss: 0.0177, Accuracy: 0.9951
Valid Loss: 0.1668, Accuracy: 0.9687
Model saved!

🔁 Epoch [3/10]


100%|██████████| 179/179 [01:33<00:00,  1.91it/s]


Train Loss: 0.0096, Accuracy: 0.9968
Valid Loss: 0.1901, Accuracy: 0.9535

🔁 Epoch [4/10]


100%|██████████| 179/179 [01:22<00:00,  2.18it/s]


Train Loss: 0.0420, Accuracy: 0.9855
Valid Loss: 0.1386, Accuracy: 0.9634

🔁 Epoch [5/10]


100%|██████████| 179/179 [01:24<00:00,  2.12it/s]


Train Loss: 0.0157, Accuracy: 0.9944
Valid Loss: 0.1491, Accuracy: 0.9710
Model saved!

🔁 Epoch [6/10]


100%|██████████| 179/179 [01:33<00:00,  1.92it/s]


Train Loss: 0.0032, Accuracy: 0.9989
Valid Loss: 0.3262, Accuracy: 0.9588

🔁 Epoch [7/10]


100%|██████████| 179/179 [01:30<00:00,  1.98it/s]


Train Loss: 0.0317, Accuracy: 0.9942
Valid Loss: 0.1517, Accuracy: 0.9695

🔁 Epoch [8/10]


100%|██████████| 179/179 [01:29<00:00,  1.99it/s]


Train Loss: 0.0008, Accuracy: 1.0000
Valid Loss: 0.1363, Accuracy: 0.9764
Model saved!

🔁 Epoch [9/10]


100%|██████████| 179/179 [01:44<00:00,  1.71it/s]


Train Loss: 0.0001, Accuracy: 1.0000
Valid Loss: 0.1425, Accuracy: 0.9771
Model saved!

🔁 Epoch [10/10]


100%|██████████| 179/179 [02:40<00:00,  1.11it/s]


Train Loss: 0.0001, Accuracy: 1.0000
Valid Loss: 0.1476, Accuracy: 0.9764

⏱ Training complete in 18m 39s
Best Validation Accuracy: 0.9771


<font color='green' style='font-size:32px'><b>Walking You Through Task 1B - Part 1</b></font>
***
This mini dataset contained far fewer training and validation images, which significantly impacted the model’s performance. After training for 10 epochs, the model achieved a best validation accuracy of 70.00%, compared with 97.71% for the same model trained on the full dataset. Although the model is still able to learn some features and patterns, the limited number of examples in the mini dataset restricts its exposure to the full variability present in brain tumour MRI images. This likely causes the model to overfit to the small training set, learning patterns that do not generalise well to new, unseen data. As a result, the validation accuracy remains relatively low.
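
One difference from Task 1A is that the training cell for this part passes a StepLR scheduler with `step_size=7` and `gamma=0.1`. The learning rate this produces for each epoch can be sketched in plain Python; the closed-form expression below is the usual way to describe a step schedule and is not the PyTorch code itself, and the function name `step_lr` is made up for this example.

```python
def step_lr(base_lr, epoch, step_size=7, gamma=0.1):
    """Learning rate used during a given epoch (0-indexed) under a step schedule."""
    return base_lr * gamma ** (epoch // step_size)

schedule = [step_lr(0.001, epoch) for epoch in range(10)]
print([f"{lr:.4f}" for lr in schedule])
# ['0.0010', '0.0010', '0.0010', '0.0010', '0.0010', '0.0010', '0.0010', '0.0001', '0.0001', '0.0001']
```

With 10 epochs, only the last three run at the reduced learning rate, so the scheduler has a limited effect on this experiment.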

In [16]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import time
from pathlib import Path

#---------------------- Device ----------------------#
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

#---------------------- Transform ----------------------#
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

#---------------------- Datasets & Loaders ----------------------#
train_dataset = datasets.ImageFolder(root='./brain-tumor-mini/train-mini', transform=transform)
valid_dataset = datasets.ImageFolder(root='./brain-tumor-mini/valid-mini', transform=transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False)

train_dataset_size = len(train_dataset)
valid_dataset_size = len(valid_dataset)

#---------------------- Model ----------------------#
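model = MRI_CNN()
# fc1 is created on the first forward pass, so run one dummy batch on the CPU
# before moving the model and building the optimizer. Otherwise fc1's weights
# are missing from model.parameters() and are never updated.
with torch.no_grad():
    model(torch.zeros(1, 1, 64, 64))
model.to(device)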

#---------------------- Loss, Optimizer, Scheduler ----------------------#
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)  # Optional

#---------------------- Training Function ----------------------#
def train_model(model, path_to_state, criterion, optimizer, scheduler, epochs=10):
    my_file = Path(path_to_state)
    if my_file.is_file():
        print("✅ Loading pretrained weights...")
        model.load_state_dict(torch.load(path_to_state), strict=False)

    best_accuracy = 0.0
    since = time.time()

    for epoch in range(epochs):
        print(f"\n🔁 Epoch [{epoch+1}/{epochs}]")

        # ---------- Training ----------
        model.train()
        running_loss, running_corrects = 0.0, 0

        for inputs, labels in tqdm(train_loader, desc='Training'):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()

            outputs = model(inputs)
            preds = torch.argmax(outputs, dim=1)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

        train_loss = running_loss / train_dataset_size
        train_acc = running_corrects.double() / train_dataset_size
        print(f"Train Loss: {train_loss:.4f}, Accuracy: {train_acc:.4f}")

        scheduler.step()

        # ---------- Validation ----------
        model.eval()
        running_loss, running_corrects = 0.0, 0

        with torch.no_grad():
            for inputs, labels in tqdm(valid_loader, desc='Validation'):
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                preds = torch.argmax(outputs, dim=1)
                loss = criterion(outputs, labels)

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

        val_loss = running_loss / valid_dataset_size
        val_acc = running_corrects.double() / valid_dataset_size
        print(f"Valid Loss: {val_loss:.4f}, Accuracy: {val_acc:.4f}")

        # Save best model
        if val_acc > best_accuracy:
            best_accuracy = val_acc
            torch.save(model.state_dict(), path_to_state)
            print("Model saved!")

    time_elapsed = time.time() - since
    print(f"\n⏱ Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s")
    print(f"Best Validation Accuracy: {best_accuracy:.4f}")

    return model

#---------------------- Run Training ----------------------#
trained_model = train_model(
    model=model,
    path_to_state='mri_cnn_mini.pt',
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    epochs=10
)


Using device: cpu
✅ Loading pretrained weights...

🔁 Epoch [1/10]


Training: 100%|██████████| 15/15 [00:14<00:00,  1.03it/s]


Train Loss: 1.7770, Accuracy: 0.3494


Validation: 100%|██████████| 4/4 [00:02<00:00,  1.61it/s]


Valid Loss: 1.2440, Accuracy: 0.5182
Model saved!

🔁 Epoch [2/10]


Training: 100%|██████████| 15/15 [00:07<00:00,  2.01it/s]


Train Loss: 1.1511, Accuracy: 0.5146


Validation: 100%|██████████| 4/4 [00:01<00:00,  3.26it/s]


Valid Loss: 1.1315, Accuracy: 0.5545
Model saved!

🔁 Epoch [3/10]


Training: 100%|██████████| 15/15 [00:07<00:00,  2.13it/s]


Train Loss: 0.9786, Accuracy: 0.6444


Validation: 100%|██████████| 4/4 [00:01<00:00,  3.85it/s]


Valid Loss: 1.0197, Accuracy: 0.5273

🔁 Epoch [4/10]


Training: 100%|██████████| 15/15 [00:06<00:00,  2.23it/s]


Train Loss: 0.8654, Accuracy: 0.6695


Validation: 100%|██████████| 4/4 [00:01<00:00,  3.85it/s]


Valid Loss: 1.0179, Accuracy: 0.6000
Model saved!

🔁 Epoch [5/10]


Training: 100%|██████████| 15/15 [00:06<00:00,  2.22it/s]


Train Loss: 0.7825, Accuracy: 0.6946


Validation: 100%|██████████| 4/4 [00:00<00:00,  4.01it/s]


Valid Loss: 0.9894, Accuracy: 0.6636
Model saved!

🔁 Epoch [6/10]


Training: 100%|██████████| 15/15 [00:07<00:00,  2.10it/s]


Train Loss: 0.7028, Accuracy: 0.7448


Validation: 100%|██████████| 4/4 [00:01<00:00,  3.72it/s]


Valid Loss: 0.9260, Accuracy: 0.6909
Model saved!

🔁 Epoch [7/10]


Training: 100%|██████████| 15/15 [00:07<00:00,  2.10it/s]


Train Loss: 0.6531, Accuracy: 0.7824


Validation: 100%|██████████| 4/4 [00:01<00:00,  3.81it/s]


Valid Loss: 1.0130, Accuracy: 0.6455

🔁 Epoch [8/10]


Training: 100%|██████████| 15/15 [00:06<00:00,  2.24it/s]


Train Loss: 0.5919, Accuracy: 0.7908


Validation: 100%|██████████| 4/4 [00:01<00:00,  3.68it/s]


Valid Loss: 0.9299, Accuracy: 0.7000
Model saved!

🔁 Epoch [9/10]


Training: 100%|██████████| 15/15 [00:06<00:00,  2.19it/s]


Train Loss: 0.5752, Accuracy: 0.7929


Validation: 100%|██████████| 4/4 [00:01<00:00,  2.87it/s]


Valid Loss: 0.9613, Accuracy: 0.6727

🔁 Epoch [10/10]


Training: 100%|██████████| 15/15 [00:07<00:00,  2.04it/s]


Train Loss: 0.5695, Accuracy: 0.7887


Validation: 100%|██████████| 4/4 [00:01<00:00,  3.94it/s]

Valid Loss: 0.9448, Accuracy: 0.6818

⏱ Training complete in 1m 31s
Best Validation Accuracy: 0.7000





<font color='green' style='font-size:32px'><b>Walking You Through Task 1B - Part 2</b></font>
***
I chose to use a pretrained ResNet18 model, a well-known convolutional neural network architecture originally trained on the large-scale ImageNet dataset. ResNet18 is a good choice in this context due to its balanced depth and efficiency. It is deep enough to capture meaningful hierarchical features but still lightweight enough to train quickly and reliably on limited hardware and data.

ResNet18 was trained on different data, so I had to modify the model slightly to make it work with mine. ResNet18 was trained on RGB images rather than the greyscale images used here, so I adjusted the first layer by changing its input channels from 3 to 1 and initialising it with the pretrained filters averaged across the three colour channels. I then modified the final fully connected layer of ResNet18 to output predictions for the 4 brain tumour classes instead of the original 1000 ImageNet classes. To further adapt the model, I froze all layers except for layer4 and the final layer. This means that the earlier layers remained fixed during training, retaining their pretrained weights. These early layers are responsible for detecting low-level features such as edges and textures, which are common across many image types and do not need retraining. By only fine-tuning layer4 and the fc layer, I allowed the model to learn higher-level patterns that are more specific to brain tumour images while reducing the risk of overfitting.
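
The freezing rule is a simple substring test on parameter names. The sketch below applies that test to a handful of sample ResNet18 parameter names in plain Python, without loading the model; the helper name `is_trainable` and the choice of sample names are assumptions made for this example.

```python
def is_trainable(name):
    """A parameter stays trainable only if it belongs to layer4 or the fc layer."""
    return "layer4" in name or "fc" in name

# A few sample parameter names in the style of ResNet18's named_parameters()
sample_names = [
    "conv1.weight",
    "bn1.weight",
    "layer1.0.conv1.weight",
    "layer3.1.bn2.bias",
    "layer4.0.conv1.weight",
    "layer4.1.bn2.bias",
    "fc.weight",
    "fc.bias",
]

trainable = [name for name in sample_names if is_trainable(name)]
frozen = [name for name in sample_names if not is_trainable(name)]
print(trainable)
# ['layer4.0.conv1.weight', 'layer4.1.bn2.bias', 'fc.weight', 'fc.bias']
```

Note that the replaced first convolution also falls on the frozen side of this rule, so it keeps its channel-averaged pretrained filters throughout training.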

After training the modified ResNet18 model for 10 epochs on the mini dataset, I achieved a validation accuracy of 84.55%, which is a substantial improvement over the 70.00% obtained when training the custom CNN from scratch on the same data. This suggests that transfer learning provides a strong foundation by reusing pretrained knowledge, especially when working with small datasets where training from scratch is prone to overfitting and poor generalisation.

In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from tqdm import tqdm
import time
from pathlib import Path

# --------------------- Setup ---------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# --------------------- Transforms ---------------------
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.Grayscale(num_output_channels=1),  # Convert to grayscale (1 channel)
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])   # Normalize for 1 channel
])

# --------------------- Load Data ---------------------
train_dataset = datasets.ImageFolder('./brain-tumor-mini/train-mini', transform=transform)
valid_dataset = datasets.ImageFolder('./brain-tumor-mini/valid-mini', transform=transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False)

train_dataset_size = len(train_dataset)
valid_dataset_size = len(valid_dataset)

# --------------------- Load Pretrained ResNet18 ---------------------
model = models.resnet18(pretrained=True)

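# Modify first conv layer to accept 1-channel input.
# ResNet18's first layer uses 7x7 filters, so kernel_size must be 7 to match
# the pretrained weights copied in below.
original_weights = model.conv1.weight.data
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
# Average the RGB filters into a single grayscale filter per output channel
model.conv1.weight.data = original_weights.mean(dim=1, keepdim=True)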

# Freeze early layers
for name, param in model.named_parameters():
    if "layer4" not in name and "fc" not in name:
        param.requires_grad = False

# Modify final layer for 4 tumor classes
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 4)

model = model.to(device)

# --------------------- Loss, Optimizer, Scheduler ---------------------#
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# --------------------- Training Function ---------------------#
def train_model(model, path_to_state, criterion, optimizer, scheduler, epochs=10):
    my_file = Path(path_to_state)
    if my_file.is_file():
        try:
            print("✅ Loading pretrained weights...")
            model.load_state_dict(torch.load(path_to_state), strict=False)
        except RuntimeError as e:
            print("Error loading weights:", e)
            print("Continuing with randomly initialized weights.")

    best_accuracy = 0.0
    since = time.time()

    for epoch in range(epochs):
        print(f"\nEpoch [{epoch+1}/{epochs}]")

        # ---------- Training ----------#
        model.train()
        running_loss, running_corrects = 0.0, 0

        for inputs, labels in tqdm(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()

            outputs = model(inputs)
            preds = torch.argmax(outputs, dim=1)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

        train_loss = running_loss / train_dataset_size
        train_acc = running_corrects.double() / train_dataset_size
        print(f"Train Loss: {train_loss:.4f}, Accuracy: {train_acc:.4f}")

        scheduler.step()

        # ---------- Validation ----------#
        model.eval()
        running_loss, running_corrects = 0.0, 0

        with torch.no_grad():
            for inputs, labels in valid_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                preds = torch.argmax(outputs, dim=1)
                loss = criterion(outputs, labels)

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

        val_loss = running_loss / valid_dataset_size
        val_acc = running_corrects.double() / valid_dataset_size
        print(f"Valid Loss: {val_loss:.4f}, Accuracy: {val_acc:.4f}")

        # Save best model
        if val_acc > best_accuracy:
            best_accuracy = val_acc
            torch.save(model.state_dict(), path_to_state)
            print("✅ Model saved!")

    time_elapsed = time.time() - since
    print(f"\n⏱ Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s")
    print(f"🏆 Best Validation Accuracy: {best_accuracy:.4f}")

    # 🔄 Reload best model weights
    model.load_state_dict(torch.load(path_to_state))

    # 🔍 Print best model's parameters
    print("\nBest Model's state_dict:")
    for param_tensor in model.state_dict():
        print(f"{param_tensor}\t{model.state_dict()[param_tensor].size()}")

    print("\nOptimizer's state_dict:")
    for var_name in optimizer.state_dict():
        print(f"{var_name}\t{optimizer.state_dict()[var_name]}")

    return model


# --------------------- Train the Model ---------------------
trained_model = train_model(
    model=model,
    path_to_state='resnet18_brain_tumor.pt',
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    epochs=10
)


Using device: cpu
✅ Loading pretrained weights...

Epoch [1/10]


100%|██████████| 15/15 [00:14<00:00,  1.03it/s]


Train Loss: 0.0462, Accuracy: 0.9874
Valid Loss: 1.1176, Accuracy: 0.8364
✅ Model saved!

Epoch [2/10]


100%|██████████| 15/15 [00:14<00:00,  1.03it/s]


Train Loss: 0.0600, Accuracy: 0.9812
Valid Loss: 1.1102, Accuracy: 0.8455
✅ Model saved!

Epoch [3/10]


100%|██████████| 15/15 [00:13<00:00,  1.11it/s]


Train Loss: 0.0203, Accuracy: 0.9958
Valid Loss: 1.4297, Accuracy: 0.8000

Epoch [4/10]


100%|██████████| 15/15 [00:12<00:00,  1.19it/s]


Train Loss: 0.0317, Accuracy: 0.9916
Valid Loss: 1.4378, Accuracy: 0.8000

Epoch [5/10]


100%|██████████| 15/15 [00:13<00:00,  1.15it/s]


Train Loss: 0.0275, Accuracy: 0.9854
Valid Loss: 1.3195, Accuracy: 0.8364

Epoch [6/10]


100%|██████████| 15/15 [00:13<00:00,  1.11it/s]


Train Loss: 0.0248, Accuracy: 0.9874
Valid Loss: 1.3762, Accuracy: 0.8182

Epoch [7/10]


100%|██████████| 15/15 [00:12<00:00,  1.16it/s]


Train Loss: 0.0357, Accuracy: 0.9895
Valid Loss: 1.0751, Accuracy: 0.8182

Epoch [8/10]


100%|██████████| 15/15 [00:13<00:00,  1.11it/s]


Train Loss: 0.0188, Accuracy: 0.9937
Valid Loss: 1.0433, Accuracy: 0.8273

Epoch [9/10]


100%|██████████| 15/15 [00:13<00:00,  1.14it/s]


Train Loss: 0.0027, Accuracy: 1.0000
Valid Loss: 1.0643, Accuracy: 0.8182

Epoch [10/10]


100%|██████████| 15/15 [00:12<00:00,  1.17it/s]


Train Loss: 0.0127, Accuracy: 0.9958
Valid Loss: 1.0576, Accuracy: 0.8182

⏱ Training complete in 2m 37s
🏆 Best Validation Accuracy: 0.8455

Best Model's state_dict:
conv1.weight	torch.Size([64, 1, 7, 7])
bn1.weight	torch.Size([64])
bn1.bias	torch.Size([64])
bn1.running_mean	torch.Size([64])
bn1.running_var	torch.Size([64])
bn1.num_batches_tracked	torch.Size([])
layer1.0.conv1.weight	torch.Size([64, 64, 3, 3])
layer1.0.bn1.weight	torch.Size([64])
layer1.0.bn1.bias	torch.Size([64])
layer1.0.bn1.running_mean	torch.Size([64])
layer1.0.bn1.running_var	torch.Size([64])
layer1.0.bn1.num_batches_tracked	torch.Size([])
layer1.0.conv2.weight	torch.Size([64, 64, 3, 3])
layer1.0.bn2.weight	torch.Size([64])
layer1.0.bn2.bias	torch.Size([64])
layer1.0.bn2.running_mean	torch.Size([64])
layer1.0.bn2.running_var	torch.Size([64])
layer1.0.bn2.num_batches_tracked	torch.Size([])
layer1.1.conv1.weight	torch.Size([64, 64, 3, 3])
layer1.1.bn1.weight	torch.Size([64])
layer1.1.bn1.bias	torch.Size([64])
layer1