# **Automatic eye diseases classification of optical coherence tomography (OCT) images using machine learning model** <br>



In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
 
import matplotlib.pyplot as plt
from matplotlib.image import imread
import seaborn as sns
import random
import cv2
import copy

import torch
import torch.nn as nn
from collections import Counter
import torch.nn.functional as F
import torchvision
from torchvision import datasets

import torchvision.transforms as tt
import torchvision.models as models
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid
import matplotlib.gridspec as gridspec

from torch.utils.data import random_split, DataLoader

from mlxtend.plotting import plot_confusion_matrix
from sklearn.metrics import confusion_matrix

%matplotlib inline 
import os
print(os.listdir('../input'))



#   Data Loading and Structure



In [None]:
# Root folder of the dataset. Note that the folder name used here ends with a
# space character, which is part of the path.
data_dir = '../input/kermany2018/OCT2017 '

# Names of the three sub-folders (splits) found under data_dir
train, val, test = 'train', 'val', 'test'
data_type = [train, val, test]

print(os.listdir(data_dir))


#  Data Augmentation 


In [None]:
# Model parameters
image_size = 128   # every image ends up as image_size x image_size pixels
patch_size = 128   # used as the DataLoader batch size in later cells


data_transforms = {
    # Training pipeline: random augmentations, re-drawn every time a sample is read.
    train: tt.Compose([
        tt.RandomResizedCrop(image_size),
        tt.RandomHorizontalFlip(),
        tt.RandomRotation(5),
        tt.RandomGrayscale(),
        tt.RandomAffine(translate=(0.05,0.05), degrees=0),
        tt.ToTensor(),
    ]),
    # Evaluation pipelines: deterministic resize + centre crop only.
    # The random grayscale augmentation was removed from val/test so that
    # evaluating the same image twice always feeds the model the same tensor.
    val: tt.Compose([
        tt.Resize(image_size),
        tt.CenterCrop(image_size),
        tt.ToTensor(),
    ]),
    test: tt.Compose([
        tt.Resize(image_size),
        tt.CenterCrop(image_size),
        tt.ToTensor(),
    ]),
}

# One ImageFolder per split; class labels come from the sub-folder names.
image_datasets = {
    x: datasets.ImageFolder(os.path.join(data_dir, x),
                            transform=data_transforms[x])
    for x in data_type
}

# Number of images in each split
dataset_sizes = {x: len(image_datasets[x]) for x in data_type}

Each image is loaded as a tensor with a dimension of [3,128,128]. 3 indicates that each image is transformed as a color image with RGB channels, whereas 128 refers to its height and width respectively.

In [None]:
# Fetch the first training sample (already transformed to a tensor) and
# display its second channel as a grayscale picture.
img, label = image_datasets[train][0]
print(img.shape, label)
plt.imshow(img[1], cmap='gray')
plt.show()


#   Data Inspection and Data Imbalance 



In [None]:
# Count the training images per class (index i <-> image_datasets[train].classes[i]).
train_targets = image_datasets[train].targets
class_counts = Counter(train_targets)            # computed once, not once per class

counter = [class_counts[i] for i in range(len(image_datasets[train].classes))]
per = [count / sum(counter) for count in counter]   # class proportions, reused later for the loss weights

# WeightedRandomSampler expects ONE weight PER SAMPLE and draws indices in
# range(len(weights)). Passing the four class proportions (as before) could only
# ever yield indices 0..3 and favoured the majority classes. Each sample now gets
# the inverse frequency of its class, so every class is drawn equally often.
# The indices refer to image_datasets[train].
sample_weights = [1.0 / class_counts[target] for target in train_targets]
train_weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(
    sample_weights, dataset_sizes[train])

fig = plt.figure(figsize= (16,9),constrained_layout=True)
gs = fig.add_gridspec(1, 2)

# Left panel: absolute counts per class
ax = fig.add_subplot(gs[0, 0])

colors = ['blue', 'orange','green','red']
ax.bar(image_datasets[train].classes, counter, color = colors)
ax.set_title('Distribution of training set');

# Right panel: share of each class
ax = fig.add_subplot(gs[0, 1])

ax.pie(per,labels = tuple(image_datasets[train].classes),autopct='%1.1f%%')
ax.set_title('Distribution of training set')

plt.show()

Next, we plot a series of images with different labels. To semi-trained medical personnel, the difference between normal and abnormal images is noticeable. However, the exact differences among the abnormal classes are hard to express in a coherent manner. Accurate diagnosis requires detecting exact yet subtle differences in medical images. Trained personnel need a considerable amount of time, and thousands of images, to master this skill. Doctors also need a reasonable amount of time to make each diagnosis, and even then we cannot guarantee that the diagnosis is 100% accurate. Hence, it is desirable to train a model that can correctly classify these images in a short time frame.


In [None]:
# Function for plotting samples
def plot_samples(samples):
    """Draw image files on a 4x3 grid, three images per row.

    samples: sequence of image file paths. Entry i is placed at row i // 3,
    column i % 3. The rows are titled, top to bottom, "Normal", "DRUSEN",
    "DME" and "CNV", so the paths are expected in that class order.
    """
    row_titles = ("Normal", "DRUSEN", "DME", "CNV")
    fig, ax = plt.subplots(nrows=4, ncols=3, figsize=(15,15))
    for position, path in enumerate(samples):
        row, col = divmod(position, 3)
        # read the file, run the channel-order conversion, then resize to 256x256
        picture = cv2.resize(cv2.cvtColor(imread(path), cv2.COLOR_BGR2RGB), (256, 256))
        panel = ax[row][col]
        panel.imshow(picture)
        panel.set_title(row_titles[row], fontsize=20)
        panel.axis('off')
        
## Plot training samples
# Pick 3 random files from each class folder of the training split, visiting the
# folders in the same order as the row titles used by plot_samples.
rand_samples = [
    path
    for class_name in ('NORMAL', 'DRUSEN', 'DME', 'CNV')
    for path in random.sample(
        [os.path.join(data_dir + '/train/' + class_name, filename)
         for filename in os.listdir(data_dir + '/train/' + class_name)],
        3)
]

plot_samples(rand_samples)
plt.suptitle('Training Set Samples', fontsize=20)
plt.show()


#   Preparing Train, Validation & Test Data

 

In [None]:
# Use 50% of the images in the training folder for training ...
train_size = round(0.5 * len(image_datasets[train]))
temp = len(image_datasets[train]) - train_size        # ... and set the other 50% aside

train_ds, temps = random_split(image_datasets[train], lengths=[train_size, temp])

# Validation subset: 2% of the set-aside half (i.e. about 1% of the training folder).
# It is drawn from image_datasets[train], so it goes through the training transform.
val_size = round(0.02 * len(temps))
temp = len(temps) - val_size                          # remainder of the set-aside half, unused

val_ds, _ = random_split(temps, lengths=[val_size, temp])

len(train_ds), len(val_ds)

Here we have a total of 41,742 training images and a batch size of 128, which means the training set is divided into batches of 128 samples each; this gives about 327 batches per epoch (41,742 / 128 ≈ 326.1, so the last batch is only partially filled).

In [None]:
# patch_size is used as the batch size.
# The training loader now reshuffles the samples at the start of every epoch;
# without shuffle=True the batches were served in the same fixed order each epoch.
train_dl = DataLoader(train_ds, batch_size=patch_size, shuffle=True,
                      num_workers=8, pin_memory=True)
# Validation order does not matter, so it stays unshuffled.
val_dl = DataLoader(val_ds, batch_size=patch_size, num_workers=8, pin_memory=True)


#   Setting Up GPU


In [None]:
def get_default_device():
    """Return the CUDA device when one is available, otherwise the CPU device."""
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(device_name)

device = get_default_device()
device

In [None]:
def to_device(data, device):
    """Move a tensor, or a (nested) list/tuple of tensors, to `device`.

    Lists and tuples are processed element by element and always come back as
    a list; anything else is moved with a non-blocking `.to(device)` call.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    moved = []
    for item in data:
        moved.append(to_device(item, device))
    return moved

In [None]:
class DeviceDataLoader():
    """Thin wrapper around a dataloader that moves every batch to a device."""

    def __init__(self, dl, device):
        self.dl, self.device = dl, device

    def __len__(self):
        """Number of batches in the wrapped dataloader"""
        return len(self.dl)

    def __iter__(self):
        """Lazily yield the wrapped loader's batches, each moved to the device"""
        # batches are transferred one at a time, only when the consumer asks for them
        yield from (to_device(batch, self.device) for batch in self.dl)


#    Choosing Model Performance Metrics



In [None]:
def accuracy(outputs, labels):
    """Return (fraction of correct predictions as a 0-dim tensor, predicted class indices).

    The predicted class of each row of `outputs` is the index of its largest value.
    """
    preds = outputs.argmax(dim=1)
    n_correct = (preds == labels).sum().item()
    return torch.tensor(n_correct / len(preds)), preds

def F1_score(outputs, labels):
    """Per-class precision, recall and F1 computed from a batch of model outputs.

    outputs: 2-D tensor of class scores, one row per sample, one column per class.
    labels:  1-D tensor of integer class indices.

    Returns (precision, recall, f1, preds). For more than two classes the three
    metrics are numpy arrays with one entry per class (one-vs-rest). For a
    two-class problem the scalars of class 1 are returned, which is what the
    previous binary-only implementation produced. `preds` holds the predicted
    class indices.

    The earlier version unpacked `cm.ravel()` into four values, which only works
    for a 2x2 confusion matrix and therefore raised on this 4-class task. A metric
    whose denominator is zero (e.g. a class that is never predicted) is reported
    as 0.0 instead of NaN.
    """
    _, preds = torch.max(outputs, dim=1)
    n_classes = outputs.shape[1]

    # work on CPU numpy copies so the function also accepts tensors living on the GPU
    y_true = labels.detach().cpu().numpy().astype(np.int64)
    y_pred = preds.detach().cpu().numpy().astype(np.int64)

    # confusion matrix with a fixed (n_classes x n_classes) shape:
    # rows = true class, columns = predicted class
    cm = np.zeros((n_classes, n_classes), dtype=np.int64)
    np.add.at(cm, (y_true, y_pred), 1)

    tp = np.diag(cm).astype(float)
    fp = cm.sum(axis=0) - tp
    fn = cm.sum(axis=1) - tp

    def _safe_div(num, den):
        # element-wise num / den, yielding 0.0 wherever den == 0
        return np.divide(num, den, out=np.zeros_like(num, dtype=float), where=den != 0)

    precision = _safe_div(tp, tp + fp)
    recall = _safe_div(tp, tp + fn)
    f1 = _safe_div(2 * precision * recall, precision + recall)

    if n_classes == 2:
        # keep the historical binary contract: metrics of the positive class only
        return precision[1], recall[1], f1[1], preds

    return precision,recall,f1,preds



#    Selecting Cost Function for the Task





In [None]:
class OCTresbase(nn.Module):
    """Training / evaluation helpers shared by the OCT classifiers.

    Subclasses provide the network and `forward`; this base class only adds the
    per-batch steps and the per-epoch aggregation used by the training loop.
    """

    @staticmethod
    def _epoch_mean(outputs, key):
        """Average the per-batch tensors stored under `key`; returns a Python float."""
        return torch.stack([entry[key] for entry in outputs]).mean().item()

    def training_step(self, batch, weight):
        """Run one training batch; returns its class-weighted loss and its accuracy."""
        images, labels = batch
        logits = self(images)                                    # forward pass
        batch_acc, _ = accuracy(logits, labels)
        # `weight` holds one rescaling factor per class for the cross-entropy
        weighted_loss = F.cross_entropy(logits, labels, weight=weight)
        return {'train_loss': weighted_loss, 'train_acc': batch_acc}

    def train_epoch_end(self, outputs):
        """Mean training loss and accuracy over a list of training_step results."""
        return {'train_loss': self._epoch_mean(outputs, 'train_loss'),
                'train_acc': self._epoch_mean(outputs, 'train_acc')}

    def validation_step(self, batch):
        """Run one val/test batch; returns loss, accuracy, predictions and labels.

        The loss here is the plain (unweighted) cross-entropy. All returned
        tensors are detached from the autograd graph.
        """
        images, labels = batch
        logits = self(images)
        batch_loss = F.cross_entropy(logits, labels)
        batch_acc, batch_preds = accuracy(logits, labels)
        return {'val_loss': batch_loss.detach(), 'val_acc': batch_acc.detach(),
                'preds': batch_preds.detach(), 'labels': labels.detach()}

    def validation_epoch_end(self, outputs):
        """Mean validation loss and accuracy over a list of validation_step results."""
        return {'val_loss': self._epoch_mean(outputs, 'val_loss'),
                'val_acc': self._epoch_mean(outputs, 'val_acc')}

    def epoch_end(self, epoch, train_result, val_result):
        """Print the one-line summary of an epoch (`epoch` is zero-based)."""
        summary = 'Epoch [{}], train_loss: {:.4f}, train_acc: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}'
        print(summary.format(epoch+1,
                             train_result['train_loss'], train_result['train_acc'],
                             val_result['val_loss'], val_result['val_acc']))

    def test_prediction(self, outputs):
        """Aggregate validation_step results of the test set.

        Returns the mean loss and accuracy plus the flat Python lists of all
        predictions and all labels, in batch order.
        """
        all_preds, all_labels = [], []
        for entry in outputs:
            all_preds.extend(entry['preds'].tolist())
            all_labels.extend(entry['labels'].tolist())

        return {'test_loss': self._epoch_mean(outputs, 'val_loss'),
                'test_acc': self._epoch_mean(outputs, 'val_acc'),
                'test_preds': all_preds, 'test_labels': all_labels}



#  8.  Model Selection: ResNET


In [None]:
# Instantiate a ResNet-18 with pretrained weights.
# NOTE(review): the cells visible below never reference this `resnet18` object --
# the OCTres class builds its own network -- so this looks like it is kept for
# inspection only; confirm before removing.
resnet18 = models.resnet18(pretrained=True)



#   Transfer Learning


In [None]:
class OCTres(OCTresbase):
    """ResNet-18 with pretrained weights and a new 4-class output layer.

    freeze_backbone: when True, every pretrained parameter is frozen and only
        the new classifier layer is trained. The default (False) fine-tunes the
        whole network.

    Why the default is False: the earlier "freeze" loop set a misspelled
    attribute (`require_grad`) on the parameters of the original `fc` layer,
    which was replaced right afterwards, so it never froze anything. Keeping
    False as the default preserves that effective behaviour for existing
    `OCTres()` calls while making real freezing available on request.
    """

    def __init__(self, freeze_backbone=False):
        super().__init__()
        # Use a pretrained model
        self.network = models.resnet18(pretrained=True)

        if freeze_backbone:
            # Freeze all pretrained layers; must happen BEFORE the classifier is
            # replaced so that the new layer keeps requires_grad=True.
            for param in self.network.parameters():
                param.requires_grad = False

        num_features = self.network.fc.in_features   # input size of the last layer
        self.network.fc = nn.Linear(num_features, 4) # new classifier with 4 outputs

    def forward(self, xb):
        """Return the network's class scores for the image batch `xb`."""
        return self.network(xb)

To properly evaluate the model, we develop some functions to achieve such goal.

*   **evaluate**: calls the validation functions defined in the base model class above and returns their output. The validation functions compute the average validation loss and accuracy for each epoch, which we use to choose suitable hyperparameters.
*   **get_lr**: returns the learning rate currently set on the optimizer (that of its first parameter group). This is the value one would monitor when using a learning rate scheduler, which changes the learning rate during training to speed it up. Note that no scheduler is created in this notebook, so the learning rate stays at the value passed to **fit**.
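*   **fit**: runs the training loop for the given number of epochs, evaluates the model on the validation set after every epoch, keeps a copy of the weights that gave the lowest validation loss, and loads those weights back into the model at the end, so that the returned model is the best one found for the task at hand.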
*  **model.eval()**: informs the network that nothing new is to be learned and the model is used for testing only. 
* **model.train()**  sets the modules in the network in training mode which means that the model knows it has to learn the layers as opposed to the **model.eval()**
*   **torch.no_grad()**: impacts and deactivates autograd engine. It will reduce memory usage and speed up computations. 

In [None]:
@torch.no_grad()
def evaluate(model, val_loader):
    """Run the model over `val_loader` in eval mode, without gradient tracking.

    Returns whatever `model.validation_epoch_end` produces for the collected
    per-batch `model.validation_step` results.
    """
    model.eval()
    outputs = [model.validation_step(batch) for batch in val_loader]
    return model.validation_epoch_end(outputs)

def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group."""
    for param_group in optimizer.param_groups:
        return param_group['lr']

def fit(epochs, lr, model, train_loader, val_loader,weight, weight_decay=0, grad_clip=None, opt_func=torch.optim.SGD):
    """Train `model` and restore the weights with the lowest validation loss.

    epochs:       number of passes over `train_loader`.
    lr:           learning rate handed to the optimizer.
    model:        module providing training_step / train_epoch_end /
                  validation_step / validation_epoch_end / epoch_end.
    train_loader, val_loader: iterables of batches.
    weight:       per-class weights forwarded to `model.training_step`.
    weight_decay: L2 penalty handed to the optimizer (only passed when non-zero,
                  so optimizers without that option keep working).
    grad_clip:    if truthy, gradients are clipped element-wise to this value.
    opt_func:     optimizer class, called as opt_func(params, lr, ...).

    Returns (history, optimizer, best_loss). `history` maps 'train_loss',
    'train_acc', 'val_loss', 'val_acc' and 'lrs' to one entry per epoch ('lrs'
    holds the learning rate seen after every batch of that epoch).

    Best-model selection ignores the first 4 epochs. If no epoch qualifies
    (e.g. epochs <= 4) the final weights are kept and `best_loss` is the last
    validation loss; with epochs == 0 it stays at infinity.
    """
    torch.cuda.empty_cache() # release all the GPU memory cache
    history = {}

    # weight_decay used to be accepted but silently dropped; forward it now
    opt_kwargs = {'weight_decay': weight_decay} if weight_decay else {}
    optimizer = opt_func(model.parameters(), lr, **opt_kwargs)

    # Start from +inf rather than an arbitrary 1 so that the best epoch is found
    # even when the validation loss never drops below 1.
    best_loss = float('inf')
    best_model_wts = None    # stays None until an epoch is selected
    val_results = None
    warmup_epochs = 4        # epochs 1..4 are never selected as "best"

    for epoch in range(epochs):

        # Training Phase
        model.train()
        train_outputs = []
        lrs = []

        for batch in train_loader:
            outputs = model.training_step(batch,weight)
            loss = outputs['train_loss']                          # get the loss
            loss.backward()                                       # compute gradients

            # Gradient clipping
            if grad_clip:
                nn.utils.clip_grad_value_(model.parameters(), grad_clip)

            optimizer.step()                                      # update weights
            optimizer.zero_grad()                                 # reset gradients
            lrs.append(get_lr(optimizer))

            # store a detached loss so the per-batch autograd history is not
            # kept alive for the whole epoch
            train_outputs.append({'train_loss': loss.detach(),
                                  'train_acc': outputs['train_acc']})

        # average loss / accuracy of the epoch, computed once after the last
        # batch instead of after every batch
        train_results = model.train_epoch_end(train_outputs)

        # Validation phase
        val_results = evaluate(model, val_loader)

        # Save best loss
        if val_results['val_loss'] < best_loss and epoch + 1 > warmup_epochs:
            best_loss = val_results['val_loss']
            best_model_wts = copy.deepcopy(model.state_dict())

        # print results
        model.epoch_end(epoch, train_results, val_results)

        # save results to dictionary
        to_add = {'train_loss': train_results['train_loss'],
                  'train_acc': train_results['train_acc'],
                  'val_loss': val_results['val_loss'],
                  'val_acc': val_results['val_acc'], 'lrs':lrs}

        # update performance dictionary
        for key,val in to_add.items():
            history.setdefault(key, []).append(val)

    if best_model_wts is not None:
        model.load_state_dict(best_model_wts)                     # load best model
    elif val_results is not None:
        # no epoch was eligible: keep the final weights and report their loss
        best_loss = val_results['val_loss']

    return history, optimizer, best_loss
            

In [None]:
# Wrap the loaders so that every batch is moved to `device` when it is iterated.
# Note: the names train_dl / val_dl are rebound here, so re-running this cell
# wraps the already wrapped loaders a second time.
train_dl = DeviceDataLoader(train_dl, device)
val_dl = DeviceDataLoader(val_dl, device)

# Build the classifier and move its parameters to the same device.
model = to_device(OCTres(), device)


#   Model Training and Evaluating




In [None]:
# Training hyperparameters
epochs = 20
lr = 0.0005
grad_clip = None
weight_decay = 1e-4
opt_func = torch.optim.Adam

# Weighted loss for data class imbalance: each class weight is the inverse of
# that class's share of the data stored in `per`, so rarer classes count more.
weight = torch.from_numpy(1 / np.array(per)).float().to(device)

history, optimizer, best_loss = fit(
    epochs, lr, model, train_dl, val_dl, weight,
    grad_clip=grad_clip,
    weight_decay=weight_decay,
    opt_func=opt_func,
)

In [None]:
print('Best loss is:', best_loss)

# Save Model
# The checkpoint bundles a freshly constructed OCTres instance (the object that
# load_checkpoint later fills with weights), the trained weights and the
# optimizer state.
bestmodel = dict(
    model=OCTres(),
    state_dict=model.state_dict(),
    optimizer=optimizer.state_dict(),
)

torch.save(bestmodel, './OCTResnet.pth')

In [None]:
# this is for loading the model from a previously saved one

def load_checkpoint(filepath, map_location='cpu'):
    """Rebuild a frozen, eval-mode model from a checkpoint file.

    filepath:     file holding a dict with a 'model' entry (a module object)
                  and a 'state_dict' entry (its weights).
    map_location: where the stored tensors are placed while loading. The default
                  'cpu' lets a checkpoint written on a GPU machine be opened on
                  a CPU-only one; move the returned model to another device
                  afterwards if needed.

    Returns the module with the weights loaded, every parameter set to
    requires_grad=False and the module switched to eval mode.

    SECURITY: the checkpoint stores a pickled module object, so loading it runs
    the pickle machinery on the file's content. Only open checkpoints that you
    created yourself or otherwise trust.
    """
    try:
        # Recent PyTorch versions default to weights_only=True, which refuses to
        # unpickle the module object stored under 'model'.
        checkpoint = torch.load(filepath, map_location=map_location, weights_only=False)
    except TypeError:
        # older PyTorch versions do not know the weights_only argument
        checkpoint = torch.load(filepath, map_location=map_location)

    model = checkpoint['model']
    model.load_state_dict(checkpoint['state_dict'])
    for parameter in model.parameters():
        parameter.requires_grad = False   # inference only, no further training

    model.eval()
    return model

# Replace the in-memory model with the one restored from the checkpoint file
# written by the previous cell.
model = load_checkpoint('./OCTResnet.pth')
# Alternative kept for reference: start from a fresh, untrained-head model instead
# model = to_device(OCTres(), device)


#   Accuracy and Loss Plots


In [None]:
# Plot Accuracy and Loss
# Two panels side by side: accuracy on the left, loss on the right, each with
# the training and the validation curve over the epochs.
f, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
t = f.suptitle('Performance', fontsize=12)
f.subplots_adjust(top=0.85, wspace=0.3)

epoch_list = list(range(1,epochs+1))

# Accuracy panel
ax1.plot(epoch_list, history['train_acc'], label='Train Accuracy')
ax1.plot(epoch_list, history['val_acc'], label='Validation Accuracy')
ax1.set(xticks=np.arange(0, epochs+1, 5),
        ylabel='Accuracy Value', xlabel='Epoch', title='Accuracy')
l1 = ax1.legend(loc="best")

# Loss panel
ax2.plot(epoch_list, history['train_loss'], label='Train Loss')
ax2.plot(epoch_list, history['val_loss'], label='Validation Loss')
ax2.set(xticks=np.arange(0, epochs+1, 5),
        ylabel='Loss Value', xlabel='Epoch', title='Loss')
l2 = ax2.legend(loc="best")



#  Predicting on Test Set


In [None]:
# Class distribution of the TEST split.
# Dedicated names are used on purpose: this cell used to overwrite the globals
# `counter` and `per`, which the training cell turns into the loss weights, so
# re-running training after this cell silently used test-set statistics. It also
# rebuilt `train_weighted_sampler` from the test proportions; that copy-pasted
# statement has been removed.
test_class_counts = Counter(image_datasets[test].targets)

test_counter = [test_class_counts[i] for i in range(len(image_datasets[test].classes))]
test_per = [count / sum(test_counter) for count in test_counter]

fig = plt.figure(figsize= (16,9),constrained_layout=True)
gs = fig.add_gridspec(1, 2)

# Left panel: absolute counts per class (labels now taken from the test split)
ax = fig.add_subplot(gs[0, 0])

colors = ['blue', 'orange','green','red']
ax.bar(image_datasets[test].classes, test_counter, color = colors)
ax.set_title('Distribution of test set');

# Right panel: share of each class
ax = fig.add_subplot(gs[0, 1])

ax.pie(test_per,labels = tuple(image_datasets[test].classes),autopct='%1.1f%%')
ax.set_title('Distribution of test set')

plt.show()

In [None]:
@torch.no_grad()
def test_predict(model, test_loader):
    model.eval()
    # perform testing for each batch
    outputs = [model.validation_step(batch) for batch in test_loader] 
    results = model.test_prediction(outputs)                          
    print('test_loss: {:.4f}, test_acc: {:.4f}'
          .format(results['test_loss'], results['test_acc']))
    
    return results['test_preds'], results['test_labels']

In [None]:
# Batch the whole test split (batch size = patch_size, original file order) and
# have each batch moved to `device` on iteration.
test_dl = DeviceDataLoader(
    DataLoader(image_datasets[test], batch_size=patch_size, num_workers=8, pin_memory=True),
    device)

# Predictions and true labels for every test image, as flat Python lists
preds,labels = test_predict(model.to(device), test_dl)


#   Confusion Matrix
 

In [None]:
# Plot confusion matrix
# Rows are the true classes, columns the predicted classes.
cm = confusion_matrix(labels, preds)

plt.figure()
plot_confusion_matrix(cm, figsize=(12,8), cmap=plt.cm.Blues)

# same four tick labels on both axes
tick_labels = ['CNV', 'DME','DRUSEN','Normal']
plt.xticks(range(4), tick_labels, fontsize=16)
plt.yticks(range(4), tick_labels, fontsize=16)

plt.xlabel('Predicted Label',fontsize=18)
plt.ylabel('True Label',fontsize=18)
plt.show()

In [None]:
# Compute Performance Metrics
# Per-class (one-vs-rest) counts derived from the confusion matrix `cm`,
# whose rows are true classes and whose columns are predicted classes.
fp = cm.sum(axis=0) - np.diag(cm)   # predicted as the class but belonging to another
fn = cm.sum(axis=1) - np.diag(cm)   # belonging to the class but predicted as another
tp = np.diag(cm)
tn = cm.sum() - (fp + fn + tp)

# Stored as `overall_accuracy` rather than `accuracy`: the old name overwrote the
# accuracy() function that the model's training/validation steps call, so any
# training or evaluation cell re-run after this one failed.
overall_accuracy = (np.array(preds) == np.array(labels)).sum() / len(preds)
precision = tp/(tp+fp)
recall = tp/(tp+fn)
f1 = 2*((precision*recall)/(precision+recall))

recall = recall.astype(float)
precision = precision.astype(float)
f1 = f1.astype(float)

print("Accuracy of the model is %.2f"% overall_accuracy)
print('Recall of the model is {}'.format(recall))
print('precision of the model is {}'.format(precision))
print('F1 score of the model is {}'.format(f1))


#   Predictions Presentation 


In [None]:
# Test-set positions to display: four consecutive images starting at each of
# 0, 300, 600 and 900 (16 images in total, one group per grid row).
idxs = torch.tensor(np.concatenate(
    [np.arange(first, first + 4) for first in (0, 300, 600, 900)]))

fig, ax = plt.subplots(nrows=4, ncols=4, figsize=(12,12),constrained_layout=True)

class_names = image_datasets[test].classes
for c,i in enumerate(idxs):
    img_tensor, label = image_datasets[test][i]
    row, col = divmod(c, 4)
    panel = ax[row][col]
    # show the first channel only, with the true and the predicted class as title
    panel.imshow(img_tensor[0], cmap='gray')
    panel.set_title('Label: {}\nPrediction: {}'
                    .format(class_names[label], class_names[preds[i]]),
                    fontsize=12)
    panel.axis('off')