In [None]:
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet18
from torchvision.transforms import functional as F
from torchvision import transforms
from skimage.io import imsave
from skimage.io import imread
from copy import deepcopy
from tqdm import tqdm
from torch import nn
import pandas as pd
import numpy as np
import torch
import random
import pydicom
import os

In [None]:
# --- Data ---------------------------------------------------------------
# Root folder; the dataset joins it with one sub-directory name per class.
DATA_DIR = 'Training/'
# Side length (in pixels) passed to the Resize step of the transform.
IMG_SIZE = 128

# --- Batching / schedule ------------------------------------------------
TRAIN_BATCHSIZE, EVAL_BATCHSIZE = 200, 10
EPOCHS = 10

# --- Dataset split (fractions passed to random_split) -------------------
TRAIN_FRACTION, TEST_FRACTION, VALIDATION_FRACTION = 0.8, 0.1, 0.1

In [None]:
class BrainTumorDataset(Dataset):
    """Grayscale image dataset read from one sub-directory per class.

    Every file found in ``<data_dir>/<class name>/`` becomes one sample whose
    integer label is the position of the class name in ``CLASS_NAMES``.

    Args:
        transform: Optional callable applied to each image. It receives a
            ``uint8`` tensor of shape ``(1, H, W)`` with values in 0-255.
        data_dir: Root folder holding the class sub-directories. Defaults to
            the module-level ``DATA_DIR``.
        img_size: Stored as ``self.IMG_SIZE``; not used by the class itself.
            Defaults to the module-level ``IMG_SIZE``.
    """

    # Sub-directory names; the index in this tuple is the class label.
    CLASS_NAMES = ('glioma', 'meningioma', 'no_tumor', 'pituitary')

    def __init__(self, transform=None, data_dir=None, img_size=None):
        self.DATA_DIR = DATA_DIR if data_dir is None else data_dir
        self.IMG_SIZE = IMG_SIZE if img_size is None else img_size
        self.transform = transform

        self.labels = None
        self.create_labels()

    def create_labels(self):
        """Scan the class folders and store ``(file path, label)`` pairs.

        Raises:
            FileNotFoundError: If one of the class sub-directories is missing.
        """
        labels = []
        for target, target_label in enumerate(self.CLASS_NAMES):
            case_dir = os.path.join(self.DATA_DIR, target_label)
            # os.listdir returns entries in arbitrary order; sorting keeps the
            # index -> file mapping stable across runs and machines.
            for fname in sorted(os.listdir(case_dir)):
                fpath = os.path.join(case_dir, fname)
                # Skip nested directories, which cannot be read as images.
                if os.path.isfile(fpath):
                    labels.append((fpath, target))
        self.labels = labels

    def normalize(self, img):
        """Rescale ``img`` so its maximum maps to 255 and return it as uint8.

        An image whose maximum is not positive (e.g. all black) is only cast,
        which avoids a division by zero.
        """
        # np.float64 is used because the np.float_ alias was removed in NumPy 2.0.
        scaled = img.astype(np.float64)
        peak = img.max()
        if peak > 0:
            scaled = scaled * 255. / peak
        return scaled.astype(np.uint8)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        With a transform, the image is whatever the transform returns.
        Without one, it is a float32 tensor of shape ``(1, H, W)`` with
        values in 0-255.
        """
        fpath, target = self.labels[idx]

        # Read the file as a single-channel (grayscale) array.
        img_arr = imread(fpath, as_gray=True)
        img_arr = self.normalize(img_arr)

        # (H, W) uint8 array -> (1, H, W) uint8 tensor (channel dimension added).
        data = torch.from_numpy(img_arr).unsqueeze(0)

        if self.transform is not None:
            # Hand the transform the uint8 tensor: torchvision's ToPILImage
            # treats floating-point tensors as [0, 1] and multiplies them by
            # 255, so a float tensor holding 0-255 values would overflow.
            return self.transform(data), target

        return data.type(torch.FloatTensor), target

    def __len__(self):
        """Number of samples found by ``create_labels``."""
        return len(self.labels)

In [None]:
# Augmentation pipeline for training images.
# NOTE: ToPILImage multiplies floating-point tensors by 255, so the input must
# be a uint8 tensor (0-255) or a float tensor already scaled to [0, 1].
train_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomRotation(10),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    # InterpolationMode.BICUBIC is the enum equivalent of the deprecated
    # integer code 3.
    transforms.RandomPerspective(
        distortion_scale=0.2,
        p=0.5,
        interpolation=transforms.InterpolationMode.BICUBIC,
    ),
    transforms.GaussianBlur(3, sigma=(0.1, 2.0)),
    # ToTensor converts the PIL image back to a float tensor in [0, 1].
    transforms.ToTensor(),
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
])

In [None]:
# Build the complete dataset; every sample goes through train_transform.
dataset = BrainTumorDataset(
    transform=train_transform,
)
print(f'Amount of data in dataset: {len(dataset)}')

In [None]:
# Dedicated, seeded generator so the split is identical on every run,
# independent of when (or whether) the global RNG seed is set.
SPLIT_SEED = 42

# The fractions are listed in the same order as the unpacked names
# (train, validation, test).
# NOTE: the three subsets wrap the same `dataset` object, so validation and
# test samples pass through the same transform as training samples.
train_dataset, validation_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [TRAIN_FRACTION, VALIDATION_FRACTION, TEST_FRACTION],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f'Running on {device}')

In [None]:
# Training batches are reshuffled every epoch.
train_loader = DataLoader(dataset=train_dataset, batch_size=TRAIN_BATCHSIZE, shuffle=True)

# Evaluation loaders iterate in a fixed order (shuffle=False is the default).
validation_loader = DataLoader(dataset=validation_dataset, batch_size=EVAL_BATCHSIZE, shuffle=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=EVAL_BATCHSIZE, shuffle=False)

In [None]:
# Seed every random number generator the notebook's libraries can draw from.
# NOTE: this cell runs after the dataset split above, so the seed set here
# does not influence that split.
seed = 42
random.seed(seed)
np.random.seed(seed)  # NumPy keeps its own global RNG, separate from `random`
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

In [None]:

# One output logit per class folder (glioma, meningioma, no_tumor, pituitary).
# Without num_classes, resnet18() builds its default 1000-way output layer.
NUM_CLASSES = 4

net = resnet18(num_classes=NUM_CLASSES)

# Replace the stem convolution so the network accepts 1-channel (grayscale)
# input; the remaining arguments are unchanged from the previous values.
net.conv1 = nn.Conv2d(
    1,
    64,
    kernel_size=(7, 7),
    stride=(2, 2),
    padding=(3, 3),
    bias=False,
)

net = net.to(device)

In [None]:
# Multi-class classification loss computed on the raw network outputs.
criterion = nn.CrossEntropyLoss()

# Plain stochastic gradient descent over all network parameters.
error_minimizer = torch.optim.SGD(
    params=net.parameters(),
    lr=1e-4,
)

In [None]:
# Independent snapshot of the network in its current state. The training loop
# replaces it with a fresh copy whenever validation accuracy improves.
net_final = deepcopy(net)

In [None]:
# Train for EPOCHS epochs, tracking accuracy on the training and validation
# sets, and keep a copy of the model with the best validation accuracy.
best_validation_accuracy = 0.
train_accs = []
val_accs = []

for epoch in range(EPOCHS):
    print(f"# Epoch {epoch + 1}:")

    # ---- Training pass -------------------------------------------------
    net.train()
    total_train_examples = 0
    num_correct_train = 0

    # len(loader) counts the final partial batch, which
    # len(dataset) // batch_size would miss.
    for inputs, targets in tqdm(train_loader, total=len(train_loader)):
        inputs = inputs.to(device)
        targets = targets.to(device)

        error_minimizer.zero_grad()

        predictions = net(inputs)

        loss = criterion(predictions, targets)
        loss.backward()

        error_minimizer.step()

        # The predicted class is the index of the largest logit.
        _, predicted_class = predictions.max(1)
        total_train_examples += predicted_class.size(0)
        num_correct_train += predicted_class.eq(targets).sum().item()

    train_acc = num_correct_train / total_train_examples
    print(f"Training accuracy: {train_acc}")
    train_accs.append(train_acc)

    # ---- Validation pass -----------------------------------------------
    net.eval()
    total_val_examples = 0
    num_correct_val = 0

    # No gradients are needed for evaluation.
    with torch.no_grad():
        for inputs, targets in tqdm(validation_loader, total=len(validation_loader)):
            inputs = inputs.to(device)
            targets = targets.to(device)
            predictions = net(inputs)

            _, predicted_class = predictions.max(1)
            total_val_examples += predicted_class.size(0)
            num_correct_val += predicted_class.eq(targets).sum().item()

    val_acc = num_correct_val / total_val_examples
    print(f"Validation accuracy: {val_acc}")
    val_accs.append(val_acc)

    # Keep an independent copy of the best model seen so far.
    if val_acc > best_validation_accuracy:
        best_validation_accuracy = val_acc
        print("Validation accuracy was improved. Saving new model.")
        net_final = deepcopy(net)

In [None]:
import matplotlib.pyplot as plt

# Build the x-axis from the recorded history so the lengths always match,
# and number epochs from 1 to agree with the "# Epoch N" lines printed
# during training.
epochs_list = list(range(1, len(train_accs) + 1))

fig, ax = plt.subplots()
ax.plot(epochs_list, train_accs, 'b-', label='training set accuracy')
ax.plot(epochs_list, val_accs, 'r-', label='validation set accuracy')
ax.set_xlabel('epoch')
ax.set_ylabel('prediction accuracy')
# Show the full accuracy range: a lower limit of 0.5 would hide any curve
# that stays below 50 %.
ax.set_ylim(0, 1)
ax.set_title('Classifier training evolution:\nprediction accuracy over time')
ax.legend()
plt.show()