In [None]:
import torch
import torchvision
import os
import numpy as np
import skimage.transform as image_transform
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import matplotlib.pyplot as pyplot
%matplotlib inline

## Create a custom Dataset class

In [None]:
class TumorDataset(Dataset):
    """Image-folder dataset for tumor classification.

    ``root`` must contain one sub-directory per class; every regular file inside
    a class directory is treated as one image. Labels are assigned 0, 1, ... by
    the *sorted* order of the sub-directory names, and files within a class are
    also sorted, so indices and labels are identical on every platform
    (``os.listdir`` alone returns an arbitrary order).

    Each item is ``(image, label)``: ``image`` is a float array of shape
    ``(1, *size)`` (a single channel; for multi-channel files only channel 0 is
    kept) and ``label`` is the int class index.

    Args:
        root: directory of one dataset split, e.g. ``'brain_tumor_dataset/train/'``.
            The trailing slash is optional.
        size: ``(height, width)`` every image is resized to.
    """

    def __init__(self, root, size=(128, 128)):
        self.root = root
        # Only directories are classes; stray files such as .DS_Store would
        # otherwise make os.listdir(...) fail on them later.
        self.class_names = sorted(
            name for name in os.listdir(self.root)
            if os.path.isdir(os.path.join(self.root, name))
        )
        self.size = size

        # Index (path, label) pairs once instead of re-listing every class
        # directory on each __len__ / __getitem__ call.
        self.samples = []
        for class_label, cur_class in enumerate(self.class_names):
            class_dir = os.path.join(self.root, cur_class)
            for file_name in sorted(os.listdir(class_dir)):
                file_path = os.path.join(class_dir, file_name)
                if os.path.isfile(file_path):
                    self.samples.append((file_path, class_label))

    def __len__(self):
        """Total number of image files across all classes."""
        return len(self.samples)

    def __getitem__(self, index):
        """Load, reduce to one channel, and resize the image at ``index``.

        Supports negative indices like a Python sequence.

        Raises:
            IndexError: if ``index`` is out of range. Raising (instead of
                returning None) also lets plain ``for x in dataset`` terminate.
        """
        num_samples = len(self.samples)
        if index < 0:
            index += num_samples
        if not 0 <= index < num_samples:
            raise IndexError('index %d out of range for dataset of size %d' % (index, num_samples))

        file_path, class_label = self.samples[index]

        # Context manager closes the underlying file handle; np.array copies
        # the pixels so the array stays valid after the file is closed.
        with Image.open(file_path) as img:
            image = np.array(img, dtype=np.double)

        if image.ndim == 3:
            image = image[:, :, 0]

        image = np.expand_dims(image_transform.resize(image, self.size), axis=0)
        return (image, class_label)

### Set up the training and validation datasets and show one example with a tumor and one without

In [None]:
path_train_set  = 'brain_tumor_dataset/train/'
path_val_set    = 'brain_tumor_dataset/val/'

dataset_train    = TumorDataset(path_train_set)
dataset_validate = TumorDataset(path_val_set)

print('Training Dataset Length:    %d' % (len(dataset_train)))
print('Validation  Dataset Length: %d' % (len(dataset_validate)))


def show_example(dataset, index, title):
    """Plot one sample of ``dataset`` and print its label and class-folder name.

    The sample is loaded once and reused for both the image and the label.
    The previous version indexed the dataset twice, reading and resizing the file twice.
    """
    image, label = dataset[index]
    print(title)
    print(' image:')
    pyplot.imshow(np.squeeze(image), cmap='gray')
    pyplot.show()
    print(' label: %d' % label)
    # Which folder a label maps to depends on the dataset's class ordering, so
    # print the folder name to confirm the title is actually right.
    print(' class: %s' % dataset.class_names[label])


# NOTE(review): indices 10 / 200 are assumed to fall in the tumor / non-tumor
# folders respectively; check the printed class names.
show_example(dataset_train, 10, 'Tumor Example')
show_example(dataset_train, 200, 'Non-Tumor Example')

## Create the classification model

In [None]:
class TumorClassificationModel(torch.nn.Module):
    """Small LeNet-style CNN producing one probability per single-channel image.

    Architecture:
    - two stages of (unpadded conv -> ReLU -> 2x2 max-pool);
    - three fully connected layers (-> 120 -> 60 -> 1);
    - a sigmoid at the end.

    The output is ``(batch, 1)`` with values in ``[0, 1]``, which is what
    ``torch.nn.BCELoss`` expects.

    Args:
        kernel_size: side length of both square convolution kernels.
        input_size: ``(height, width)`` of the input images. Used to size the
            first fully connected layer, so it must match the size the dataset
            resizes images to. Defaults to 128x128.

    Raises:
        ValueError: if ``input_size`` is too small for two conv/pool stages
            with this ``kernel_size``.
    """

    def __init__(self, kernel_size=5, input_size=(128, 128)):
        super().__init__()

        self.maxpool    = torch.nn.MaxPool2d(2, 2)

        self.convlayer1 = torch.nn.Conv2d(in_channels=1, out_channels=6, kernel_size=kernel_size)
        self.convlayer2 = torch.nn.Conv2d(in_channels=6, out_channels=15, kernel_size=kernel_size)

        # Spatial side length after one unpadded conv followed by a 2x2 pool.
        # The flattened size used to be hard-coded as 15*29*29, which is only
        # correct for kernel_size=5 and 128x128 inputs.
        def _after_stage(length):
            return (length - kernel_size + 1) // 2

        height = _after_stage(_after_stage(input_size[0]))
        width  = _after_stage(_after_stage(input_size[1]))
        if height < 1 or width < 1:
            raise ValueError('input_size %r is too small for kernel_size %d' % (tuple(input_size), kernel_size))
        self.flat_features = self.convlayer2.out_channels * height * width

        self.fully_connected1 = torch.nn.Linear(self.flat_features, 120)
        self.fully_connected2 = torch.nn.Linear(120, 60)
        self.fully_connected3 = torch.nn.Linear(60, 1)

    def forward(self, x):
        """Map a ``(batch, 1, H, W)`` image batch to ``(batch, 1)`` probabilities."""
        x = self.maxpool(torch.nn.functional.relu(self.convlayer1(x)))
        x = self.maxpool(torch.nn.functional.relu(self.convlayer2(x)))

        # flatten keeps the batch dimension explicit; view(-1, N) would silently
        # re-batch the data if the feature size ever mismatched.
        x   = torch.nn.functional.relu(self.fully_connected1(torch.flatten(x, 1)))
        x   = torch.nn.functional.relu(self.fully_connected2(x))
        out = torch.sigmoid(self.fully_connected3(x))

        return out

## Train the classification model 

In [None]:
num_epochs     = 4
learning_rate  = .00001
batch          = 10

# Seed before building the model and loaders so weight initialisation and
# shuffling are reproducible under Restart & Run All.
torch.manual_seed(0)

model = TumorClassificationModel().double()

loss_function  = torch.nn.BCELoss()
optimizer      = torch.optim.Adam(model.parameters(), lr=learning_rate)

dataloader_training = DataLoader(dataset_train, batch_size=batch, shuffle=True)
dataloader_validate = DataLoader(dataset_validate, batch_size=batch)

# Mean per-batch loss for each epoch. Raw sums would not be comparable because
# the two sets have different numbers of batches.
training_loss   = np.zeros(num_epochs)
validation_loss = np.zeros(num_epochs)

print('~~~~~~~~~~~~~~~~~')
print('Starting Training')
print('~~~~~~~~~~~~~~~~~')

for epoch in range(num_epochs):
    print("Epoch %d" % (epoch + 1))

    # Training pass: update weights batch by batch.
    model.train()
    for cur_images, batch_labels in dataloader_training:
        cur_labels = torch.unsqueeze(batch_labels, 1).double()

        optimizer.zero_grad()

        outputs = model(cur_images)
        loss    = loss_function(outputs, cur_labels)

        loss.backward()
        optimizer.step()

        training_loss[epoch] += loss.item()
    training_loss[epoch] /= max(len(dataloader_training), 1)

    # Validation pass. This must iterate the *validation* loader; the old code
    # re-used the training loader. Gradients are not needed here.
    model.eval()
    with torch.no_grad():
        for cur_images, batch_labels in dataloader_validate:
            cur_labels = torch.unsqueeze(batch_labels, 1).double()

            outputs = model(cur_images)
            loss    = loss_function(outputs, cur_labels)

            # Accumulate (the old code had `=+`, which overwrote the total).
            validation_loss[epoch] += loss.item()
    validation_loss[epoch] /= max(len(dataloader_validate), 1)

    print("  training   loss: %.4f" % (training_loss[epoch]))
    print("  validation loss: %.4f" % (validation_loss[epoch]))

## Compute validation and training accuracy

In [None]:
def compute_accuracy(model, dataloader, threshold=0.5):
    """Return the fraction of samples whose thresholded prediction matches the label.

    A prediction counts as class 1 when the model output is strictly greater
    than ``threshold``; otherwise it counts as class 0.

    Args:
        model: network mapping an image batch to ``(batch, 1)`` probabilities.
        dataloader: yields ``(images, labels)`` batches with integer labels.
        threshold: decision threshold on the model output.

    Returns:
        Accuracy in ``[0, 1]``, or ``nan`` if the loader yields no samples.
    """
    model.eval()
    correct = 0
    total   = 0
    # Inference only: no_grad avoids building the autograd graph.
    with torch.no_grad():
        for cur_images, batch_labels in dataloader:
            predictions = (model(cur_images) > threshold).double()
            targets     = torch.unsqueeze(batch_labels, 1).double()
            correct += (predictions == targets).sum().item()
            total   += targets.shape[0]
    return correct / total if total else float('nan')


correct_validation = compute_accuracy(model, dataloader_validate)
correct_training   = compute_accuracy(model, dataloader_training)

print('Training   Accuracy:    %.2f' % (correct_training * 100))
print('Validation Accuracy:    %.2f' % (correct_validation * 100))