## Import necessary packages

In [None]:
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import functional as F
from torchvision.models import resnet18
from torchvision import transforms
from skimage.io import imsave
from skimage.io import imread
from copy import deepcopy
from tqdm import tqdm
from torch import nn
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import pydicom
import random
import torch
import os

## Set hyperparameters and constant values

In [None]:
# Dataset locations. Each directory holds one sub-folder per class; the
# sub-folder names are the values of DATA_LABELS below.
TRAIN_DIR, TEST_DIR = 'Training/', 'Testing/'

# Images are resized to IMG_SIZE x IMG_SIZE by the transform pipelines.
IMG_SIZE = 224

# Optimisation schedule.
EPOCHS = 10
TRAIN_BATCHSIZE = 200
EVAL_BATCHSIZE = 10  # shared by the test and validation loaders

# NOTE(review): these fractions are not what builds the splits -- the split
# cell computes its own sizes. TODO reconcile or remove.
TRAIN_FRACTION, TEST_FRACTION, VALIDATION_FRACTION = 0.8, 0.1, 0.1

# Integer target -> class name (also the per-class sub-folder name).
DATA_LABELS = dict(enumerate(('glioma', 'meningioma', 'no_tumor', 'pituitary')))

## Create class for handling Training and Testing datasets

In [None]:
class BrainTumorDataset(Dataset):
    """Image dataset laid out as ``<data_dir>/<class_name>/<image file>``.

    Class names and their integer targets come from the module-level
    ``DATA_LABELS`` mapping (target -> sub-directory name).

    Args:
        data_dir: Root directory containing one sub-directory per class.
        transform: Optional callable applied to each sample. It receives a
            ``uint8`` tensor of shape ``(1, H, W)`` with values in 0-255.
    """

    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.labels = None  # list of (file_path, target) pairs

        self.create_labels()

    def create_labels(self):
        """Populate ``self.labels`` with ``(file_path, target)`` pairs.

        Files are visited in sorted order so sample indices (and therefore
        any split or sampling built on them) are stable across runs and
        platforms; ``os.listdir`` itself returns entries in arbitrary order.
        """
        labels = []
        for target, target_label in DATA_LABELS.items():
            case_dir = os.path.join(self.data_dir, target_label)
            for fname in sorted(os.listdir(case_dir)):
                labels.append((os.path.join(case_dir, fname), target))
        self.labels = labels

    def normalize(self, img):
        """Rescale *img* so its maximum maps to 255 and return it as ``uint8``.

        An all-zero image is returned as zeros instead of dividing by zero.
        """
        # np.float64 rather than np.float_, which was removed in NumPy 2.0.
        img = img.astype(np.float64)
        peak = img.max()
        if peak != 0:
            img = img * 255. / peak
        return img.astype(np.uint8)

    def __getitem__(self, idx):
        """Return ``(image_tensor, target)`` for sample *idx*.

        With a transform, the transform's output is returned. Without one,
        a float tensor of shape ``(1, H, W)`` with values in 0-255 is returned.
        """
        fpath, target = self.labels[idx]

        img_arr = self.normalize(imread(fpath, as_gray=True))
        # Add the channel axis: (H, W) -> (1, H, W).
        data = torch.from_numpy(img_arr).unsqueeze(0)

        if self.transform:
            # Pass uint8 data: torchvision's ToPILImage treats a floating-point
            # tensor as [0, 1] data and multiplies it by 255 before casting to
            # bytes, which overflows (corrupts) 0-255 float values.
            return self.transform(data), target

        return data.float(), target

    def __len__(self):
        return len(self.labels)

## Create transforms which will augment images in datasets

In [None]:
# Both pipelines receive a (1, H, W) tensor from the dataset, convert it to a
# PIL image for the augmentations, and end with ToTensor so the model gets a
# float tensor of size IMG_SIZE x IMG_SIZE.

# Training-time augmentation: small geometric and intensity perturbations.
train_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomRotation(10),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(p=0.5),
    # No saturation jitter: the images are single-channel, so there is no
    # colour saturation to perturb.
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.RandomPerspective(
        distortion_scale=0.2,
        p=0.5,
        # Named enum instead of the deprecated integer code 3 (bicubic).
        interpolation=transforms.InterpolationMode.BICUBIC,
    ),
    transforms.GaussianBlur(3, sigma=(0.1, 2.0)),
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
])

# Evaluation pipeline: deterministic resize only, no augmentation.
test_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
])

## Define function for plotting images in datasets

In [None]:

def plot_sample_images(dataset, num_samples_per_class=4):
    """Show up to *num_samples_per_class* raw images per class in a grid.

    Rows follow the order in which classes first appear in
    ``dataset.labels``. The first image in each row is titled with its
    ``DATA_LABELS`` name. Images are read straight from disk, so the
    dataset's transform is not applied.

    Args:
        dataset: Object exposing ``labels`` as ``(file_path, target)`` pairs.
        num_samples_per_class: Number of grid columns / images per class.
    """
    samples = {label: [] for _, label in dataset.labels}
    if not samples:
        return  # nothing to plot; plt.subplots(0, n) would raise

    for img_path, label in dataset.labels:
        if len(samples[label]) < num_samples_per_class:
            samples[label].append(imread(img_path, as_gray=True))
        # Stop reading files once every class has enough samples.
        if all(len(imgs) >= num_samples_per_class for imgs in samples.values()):
            break

    # squeeze=False keeps `axes` 2-D even with a single class or a single
    # column, so axes[i, j] indexing always works.
    _, axes = plt.subplots(
        len(samples), num_samples_per_class, figsize=(15, 10), squeeze=False
    )
    for i, (label, imgs) in enumerate(samples.items()):
        for j in range(num_samples_per_class):
            ax = axes[i, j]
            # Hide every cell, including empty ones left when a class has
            # fewer images than requested.
            ax.axis('off')
            if j < len(imgs):
                ax.imshow(imgs[j], cmap='gray')
        axes[i, 0].set_title(DATA_LABELS[label])

    plt.show()

## Instantiate dataset classes

In [None]:
# The "Training" directory feeds training; the "Testing" directory is split
# further into test and validation sets in the next cell.
train_dataset = BrainTumorDataset(data_dir=TRAIN_DIR, transform=train_transform)
test_dataset_full = BrainTumorDataset(data_dir=TEST_DIR, transform=test_transform)

# Preview a few raw training images per class.
plot_sample_images(train_dataset)

## Create testing and validation datasets

In [None]:
# Split the held-out "Testing" directory into a test set (70%) and a
# validation set (the remaining 30%). Sizes get their own names rather than
# overwriting the *_FRACTION constants. A dedicated seeded generator makes the
# split reproducible, because the global seeds are only set in a later cell.
total_test_samples = len(test_dataset_full)
num_test_samples = int(total_test_samples * 0.7)
num_validation_samples = total_test_samples - num_test_samples

test_dataset, validation_dataset = torch.utils.data.random_split(
    test_dataset_full,
    [num_test_samples, num_validation_samples],
    generator=torch.Generator().manual_seed(42),
)

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f'Running on {device}')

## Create data loaders for training, testing, and validation

In [None]:
# Training batches are reshuffled every epoch. The evaluation loaders keep
# dataset order (no shuffling) and use the smaller evaluation batch size.
train_loader = DataLoader(train_dataset, batch_size=TRAIN_BATCHSIZE, shuffle=True)

test_loader = DataLoader(test_dataset, batch_size=EVAL_BATCHSIZE)
validation_loader = DataLoader(validation_dataset, batch_size=EVAL_BATCHSIZE)

In [None]:
# Fix the random number generators used from here on (model weight init, the
# DataLoader shuffle, and presumably the torchvision random augmentations).
seed = 42
torch.manual_seed(seed)           # torch's global generator
torch.cuda.manual_seed_all(seed)  # every visible GPU; ignored without CUDA
random.seed(seed)                 # Python's stdlib RNG

In [None]:
# ResNet-18 trained from scratch. torchvision's default head has 1000 outputs
# (ImageNet). Size it to the number of tumour classes so argmax can only ever
# predict a valid label.
net = resnet18(num_classes=len(DATA_LABELS))

# Replace the stem convolution with a 1-input-channel version: the dataset
# produces single-channel (grayscale) images, while the stock stem takes 3.
net.conv1 = nn.Conv2d(
    1,
    64,
    kernel_size=(7, 7),
    stride=(2, 2),
    padding=(3, 3),
    bias=False,
)

net = net.to(device)

criterion = nn.CrossEntropyLoss()
# Plain SGD at lr=1e-4 makes very slow progress for a network trained from
# scratch over only EPOCHS epochs; momentum speeds up convergence.
error_minimizer = torch.optim.SGD(net.parameters(), lr=0.0001, momentum=0.9)

# Best-so-far snapshot; the training loop replaces it whenever validation
# accuracy improves.
net_final = deepcopy(net)

## Training loop

In [None]:
# Train for EPOCHS epochs. After each epoch, measure validation accuracy and
# keep a deep copy of the best-performing model in `net_final`.
best_validation_accuracy = 0.
train_accs = []
val_accs = []

for epoch in range(EPOCHS):
    net.train()  # Set the network in training mode
    print(f"# Epoch {epoch + 1}:")

    # --- Training pass ---------------------------------------------------
    total_train_examples = 0
    num_correct_train = 0
    for inputs, targets in train_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        error_minimizer.zero_grad()
        predictions = net(inputs)
        loss = criterion(predictions, targets)
        loss.backward()
        error_minimizer.step()

        # Accuracy on the (augmented) batch, from the pre-update logits.
        predicted_class = predictions.argmax(dim=1)
        total_train_examples += targets.size(0)
        num_correct_train += predicted_class.eq(targets).sum().item()

    train_acc = num_correct_train / total_train_examples
    print(f"Training accuracy: {train_acc}")
    train_accs.append(train_acc)

    # --- Validation pass -------------------------------------------------
    net.eval()
    total_val_examples = 0
    num_correct_val = 0
    with torch.no_grad():
        for inputs, targets in validation_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            predicted_class = net(inputs).argmax(dim=1)
            total_val_examples += targets.size(0)
            num_correct_val += predicted_class.eq(targets).sum().item()

    val_acc = num_correct_val / total_val_examples
    print(f"Validation accuracy: {val_acc}")
    val_accs.append(val_acc)

    # Snapshot the model whenever it beats the best validation accuracy so far.
    if val_acc > best_validation_accuracy:
        best_validation_accuracy = val_acc
        print("Validation accuracy was improved. Saving new model.")
        net_final = deepcopy(net)


## Testing loop

In [None]:
# Evaluate the checkpoint with the best validation accuracy (`net_final`),
# not the last-epoch weights that remain in `net`.
net_final.eval()

total_test_examples = 0
num_correct_test = 0

with torch.no_grad():
    for inputs, targets in test_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        predicted_class = net_final(inputs).argmax(dim=1)
        total_test_examples += targets.size(0)
        num_correct_test += predicted_class.eq(targets).sum().item()

test_acc = num_correct_test / total_test_examples
print(f"Test accuracy: {test_acc}")


## Plot results

In [None]:
import matplotlib.pyplot as plt

# One point per recorded epoch, numbered from 1 to match the "# Epoch N" lines
# printed during training. Using len(train_accs) rather than EPOCHS keeps the
# x-values aligned with the data if training was stopped early.
epochs_list = list(range(1, len(train_accs) + 1))

plt.figure()
plt.plot(epochs_list, train_accs, 'b-', label='training set accuracy')
plt.plot(epochs_list, val_accs, 'r-', label='validation set accuracy')
plt.xlabel('epoch')
plt.ylabel('prediction accuracy')
# Show the full [0, 1] range. A lower bound of 0.5 would hide early epochs,
# where a from-scratch model is often below 50% accuracy.
plt.ylim(0, 1)
plt.title('Classifier training evolution:\nprediction accuracy over time')
plt.legend()
plt.show()