# Imports

In [None]:
import os
import nibabel as nib
from scipy import ndimage
import numpy as np

# Basics
import argparse
import os
import glob 
import random
import numpy as np
import matplotlib.pyplot as plt
import time
import copy 

# torch imports
import torch
import torch.nn as nn
import torch.nn.functional as Functional
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as datasets
from torchvision import transforms, models
import torchvision.utils as vutils

# Utilities
from PIL import Image 
from scipy import ndimage
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import KFold
from sklearn.model_selection import StratifiedShuffleSplit

# Transfer Learning
from models import resnet

# SEED for reproducibility: seeds Python's `random` (used for the random
# rotation angle in CTDataset) and torch's global RNG.
SEED = 25
random.seed(SEED)
torch.manual_seed(SEED)
# Setting the DEVICE & BATCH SIZE
# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
print(device)
# Batch size used by the preview loader and by the K-fold train/validation loaders.
BATCH_SIZE = 10

# Loading Dataset

## Utilities

In [None]:
############################################################################3#
def read_nifti_file(filepath):
    """Load the image stored at `filepath` with nibabel and return its data array."""
    image = nib.load(filepath)
    return image.get_fdata()


def normalize(volume):
    """Clip a volume to a fixed intensity window and rescale it to [0, 1].

    Values below -1000 are raised to -1000 and values above 400 are lowered
    to 400 (presumably Hounsfield units -- confirm against the data source);
    the result is then mapped linearly so -1000 -> 0.0 and 400 -> 1.0.

    The input array is not modified; a new float32 array is returned.
    """
    # Named bounds instead of `min`/`max`, which would shadow the builtins.
    lower, upper = -1000, 400
    # np.clip returns a new array, so the caller's volume is left untouched.
    clipped = np.clip(volume, lower, upper)
    scaled = (clipped - lower) / (upper - lower)
    return scaled.astype("float32")


def resize_volume(img, dDepth=64, dWidth=128, dHeight=128):
    """Rotate a volume by 90 degrees and rescale it to (dWidth, dHeight, dDepth).

    Axis 0 is treated as width, axis 1 as height and the last axis as depth.
    The rotation keeps the array shape (reshape=False); the rescaling uses
    linear interpolation (order=1).
    """
    src_width, src_height = img.shape[0], img.shape[1]
    src_depth = img.shape[-1]
    # Zoom factor per axis = target size / current size, written as the
    # reciprocal of the shrink ratio.
    zoom_factors = (
        1 / (src_width / dWidth),
        1 / (src_height / dHeight),
        1 / (src_depth / dDepth),
    )
    rotated = ndimage.rotate(img, 90, reshape=False)
    return ndimage.zoom(rotated, zoom_factors, order=1)


def process_scan(path):
    """Load the scan at `path`, normalize its intensities, then resize it.

    Returns the volume produced by `resize_volume` with its default target size.
    """
    raw = read_nifti_file(path)
    return resize_volume(normalize(raw))
#############################################################################################

## Loading the data

In [None]:
# Folder "CT-0" consists of CT scans having normal lung tissue, no CT-signs of viral pneumonia.
# os.listdir returns entries in arbitrary order, so the listing is sorted to keep
# the sample order (and therefore the seeded splits below) reproducible.
normal_scan_paths = [
    os.path.join(os.getcwd(), "MosMedData/CT-0", x) for x in sorted(os.listdir("MosMedData/CT-0"))]

# Folder "CT-23" consists of CT scans having several ground-glass opacifications, involvement of lung parenchyma.
abnormal_scan_paths = [
    os.path.join(os.getcwd(), "MosMedData/CT-23", x) for x in sorted(os.listdir("MosMedData/CT-23"))]

print("CT scans with normal lung tissue: " + str(len(normal_scan_paths)))
print("CT scans with abnormal lung tissue: " + str(len(abnormal_scan_paths)))

# 1 for abnormal cases & 0 for the normal ones
abnormal_labels = [1] * len(abnormal_scan_paths)
normal_labels = [0] * len(normal_scan_paths)
# Concatenation (normal scans first, so labels stay aligned with paths)
allPaths = normal_scan_paths + abnormal_scan_paths
allLabels = normal_labels + abnormal_labels
# NumPy arrays so the split index arrays can be used for fancy indexing
allPaths = np.array(allPaths)
allLabels = np.array(allLabels, dtype=np.int64)
print(len(allPaths), len(allLabels))

## Pytorch Dataset

In [None]:
class AddGaussianNoise(object):
    """
        Customized transform for adding Gaussian noise to a tensor.

        Calling the transform returns `tensor + N(0, 1) * std + mean`, where the
        noise has the same shape as `tensor`.
    """
    def __init__(self, mean=0., std=1.):
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        # Create the noise on the input's device; noise built on the default
        # (CPU) device cannot be added to a tensor living on another device.
        noise = torch.randn(tensor.size(), device=tensor.device)
        return tensor + noise * self.std + self.mean

    def __repr__(self):
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"

In [None]:
class CTDataset(Dataset):
    """
        PyTorch dataset over CT scan files; meant to be wrapped in a DataLoader.

        Each item is loaded from disk on access via `process_scan`, given a
        leading channel axis, optionally rotated by a random angle, and
        optionally passed through `transform`.

        Args:
            paths: sequence of scan file paths (list or NumPy array).
            labels: sequence of labels, aligned with `paths`.
            transform: optional callable applied to the scan tensor.
            augment: when True (default, the previous behaviour) every item is
                rotated by a random integer angle in [-max_angle, max_angle].
                Pass False for deterministic items, e.g. for evaluation.
            max_angle: bound, in degrees, of the random rotation.
    """
    def __init__(self, paths, labels, transform=None, augment=True, max_angle=25):
        self.paths = paths
        self.labels = labels
        self.transform = transform
        self.augment = augment
        self.max_angle = max_angle

    def __nii2tensorarray__(self, data):
        """Add a leading channel axis: (w, h, d) -> (1, w, h, d)."""
        [w, h, d] = data.shape
        return np.reshape(data, [1, w, h, d])

    def __rotate3D(self, volume, angle=65):
        """Rotate the volume by `angle` degrees (shape preserved) and return a (1, ...) tensor."""
        volume = np.squeeze(volume)
        rotatedVol = ndimage.rotate(volume, angle, reshape=False)
        return torch.tensor(np.expand_dims(rotatedVol, axis=0))

    def __len__(self):
        # len() works for lists as well as NumPy arrays (the old `.shape[0]`
        # required an array).
        return len(self.paths)

    def __getitem__(self, idx):
        scan = process_scan(self.paths[idx])
        scan = self.__nii2tensorarray__(scan)
        if self.augment:
            # Random rotation augmentation, driven by Python's `random` module.
            angle = random.randint(-self.max_angle, self.max_angle)
            scan = self.__rotate3D(scan, angle)
        else:
            scan = torch.tensor(scan)
        if self.transform:
            scan = self.transform(scan)

        return scan, self.labels[idx]

In [None]:
# Creating the dataset object over every scan (normal + abnormal);
# it feeds the preview loader in the visualization cell.
tempDataset = CTDataset(allPaths, allLabels)

## Visualization

In [None]:
# Loader used only to preview a few samples.
temploader = DataLoader(tempDataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
# Example batch
images, _ = next(iter(temploader))
# Example 3D volume (first sample of the batch; leading axis is the channel axis)
image = images[0]
fig = plt.figure(figsize=(10, 10))  
# One subplot per slice along the last axis. The 8x8 grid holds at most 64
# slices, which matches the default dDepth=64 of resize_volume.
for s in range(image.shape[3]):
    plt.subplot(8, 8, s+1)
    plt.imshow(np.squeeze(image[:,:,:,s]), cmap="gray")
    plt.axis('off')

## Separation of a test set & KFold

In [None]:
# Test Set!
# Single stratified shuffle split: 15% of the samples are held out for testing,
# the remaining 85% form the development set used for K-fold training.
splitting_dev_test = StratifiedShuffleSplit(n_splits=1, test_size=0.15, random_state=SEED)
dev_idx, test_idx = next(splitting_dev_test.split(allPaths, allLabels))

# Development set and its loader (used later to evaluate on the training data).
devPaths, devLabels = allPaths[dev_idx], allLabels[dev_idx]
devSet = CTDataset(devPaths, devLabels)
devLoader = DataLoader(devSet, batch_size=5, shuffle=True, num_workers=0)

# Held-out test set and its loader.
# NOTE(review): shuffle=True is not needed for evaluation; confirm whether it is intentional.
testPaths, testLabels = allPaths[test_idx], allLabels[test_idx]
testSet = CTDataset(testPaths, testLabels)
testLoader = DataLoader(testSet, batch_size=5, shuffle=True, num_workers=0)

In [None]:
# 5-fold splitter (shuffled, seeded) passed to the training procedure.
kfold = KFold(n_splits=5, shuffle=True, random_state=SEED)

# Experiment 1

**In this experiment, we train the model starting from pretrained weights (transfer learning). We use K-fold cross-validation because the dataset is small.**


        * Experiment #0
        - In this experiment, we train the model from scratch.
        - We use K-fold cross-validation because the dataset is small.


    * Experiment #1
    - In this experiment, we train the model starting from pretrained weights.
    - We use K-fold cross-validation because the dataset is small.


## Model

In [None]:
class CTModel(nn.Module):
    """
        Two-class 3D CT classifier built on the MedicalNet ResNet-10 architecture.

        The backbone (`tf_model`) is created with `num_seg_classes=2`; its output
        is fed through a small stack of 3D convolutions and three fully
        connected layers. `forward` returns softmax probabilities of shape
        (batch, 2).

        Args:
            input_shape: (W, H, D) of the input volumes, forwarded to the backbone.
            transferLearning: when True, the backbone is initialised from the
                pretrained MedicalNet checkpoint (see `update_transfer_weights`).
    """
    def __init__(self, input_shape, transferLearning=False):
        super().__init__()
        self.input_shape = input_shape
        # Import the same architecture from MedicalNet
        self.tf_model = resnet.resnet10(
                sample_input_W=self.input_shape[0],
                sample_input_H=self.input_shape[1],
                sample_input_D=self.input_shape[2],
                shortcut_type='A',
                num_seg_classes=2)
        if transferLearning:
            self.update_transfer_weights()
        # Reminder of the signature used below:
        # Conv3d(in_channels, out_channels, kernel_size, stride=1, padding=0,
        #        dilation=1, groups=1, bias=True, padding_mode='zeros', device=None, dtype=None)
        # Adding extra classifier layers on top of the 2-channel backbone output
        self.classifier = nn.Sequential(
            nn.Conv3d(2, 4, 3, 1),
            nn.ReLU(inplace=True),
            nn.Conv3d(4, 4, 3, 1),
            nn.ReLU(inplace=True),
            nn.Conv3d(4, 4, 3, 2),
            nn.ReLU(inplace=True),
            nn.Conv3d(4, 8, 3, 2),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(2),
            nn.Flatten()
        )

        # NOTE(review): 8*3*3 input features presumably matches the flattened
        # classifier output for the (128, 128, 64) input used in this file -- confirm
        # before changing input_shape.
        self.fc1 = nn.Linear(8*3*3, 20)
        self.fc2 = nn.Linear(20, 5)
        self.fc3 = nn.Linear(5, 2)

    def forward(self, x):
        """Return class probabilities of shape (batch, 2)."""
        x = self.tf_model(x)
        # `classifier` ends with nn.Flatten(), so x is already (batch, features).
        x = self.classifier(x)
        x = Functional.relu(self.fc1(x))
        x = Functional.relu(self.fc2(x))
        x = self.fc3(x)
        return Functional.softmax(x, dim=1)

    def update_transfer_weights(self, weights_path=None):
        """
            This function updates the backbone weights with pretrained saved weights from MedicalNet.
            LINK: https://github.com/Tencent/MedicalNet/tree/18c8bb6cd564eb1b964bffef1f4c2283f1ae6e7b

            Args:
                weights_path: checkpoint file; defaults to
                    "./models/resnet_10_23dataset.pth".

            Only checkpoint entries whose key starts with "module." and whose
            stripped name exists in the backbone are copied; all other backbone
            parameters keep their current values.
        """
        self.tf_layers = self.tf_model.state_dict()
        if not weights_path:
            weights_path = "./models/resnet_10_23dataset.pth"
        # map_location="cpu" lets a checkpoint saved from a GPU be read on a
        # CPU-only machine; load_state_dict then copies into the existing parameters.
        # NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
        checkpoint = torch.load(weights_path, map_location="cpu")
        prefix = "module."
        # Strip the "module." prefix from the checkpoint keys.
        stripped = {k[len(prefix):]: v for k, v in checkpoint['state_dict'].items() if k.startswith(prefix)}
        matching = {k: v for k, v in stripped.items() if k in self.tf_layers}
        self.tf_layers.update(matching)
        self.tf_model.load_state_dict(self.tf_layers)
        
# Transfer-learning model for (128, 128, 64) volumes, moved to the selected device.
model_tf = CTModel((128,128,64), transferLearning=True)
model_tf.to(device)

## Training

In [None]:
def _save_history_csv(values, name):
    """Write a list of per-epoch metric values to `<name>.csv`, one value per line."""
    with open(name+".csv","w") as fe:
        np.savetxt(fe, np.asarray(values), delimiter=",")


def _run_epoch(net, loader, criterion, optimizer=None):
    """Run one pass over `loader` and return (mean loss per sample, accuracy).

    With an `optimizer` the network is put in training mode and updated after
    every batch; without one it is put in eval mode and run without gradients.
    """
    training = optimizer is not None
    # eval mode matters for validation: it disables dropout and stops
    # batch-norm layers from updating their running statistics.
    net.train(training)
    total_loss, n_correct, n_seen = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            if training:
                optimizer.zero_grad()
            probs = net(inputs)
            # The network already outputs probabilities, so take their log and
            # use NLLLoss; clamp avoids log(0).
            loss = criterion(torch.log(probs.clamp_min(1e-8)), targets)
            if training:
                loss.backward()
                optimizer.step()
            batch = targets.size(0)
            # Weight by batch size so a short last batch does not skew the mean.
            total_loss += loss.item() * batch
            n_correct += (probs.argmax(dim=1) == targets).sum().item()
            n_seen += batch
    n_seen = max(n_seen, 1)  # guard against an empty loader
    return total_loss / n_seen, n_correct / n_seen


def train_procedure_withKFolds(net, paths, labels, kfold, stopPatience=3, n_epochs=20, verbose=True, bestModelPath="best0",transfer="no"):
    """
        This function implements the training procedure with KFold cross validation.

        Args:
            net: model to train. It must return class probabilities (as both
                models in this file do via softmax); the loss is the negative
                log-likelihood of those probabilities.
            paths, labels: NumPy arrays of scan paths and integer labels.
            kfold: splitter providing `split(paths, labels)` and `get_n_splits()`.
            stopPatience: a fold stops early after this many consecutive epochs
                whose validation loss is higher than the previous epoch's.
            n_epochs: maximum number of epochs per fold.
            verbose: print per-epoch metrics and checkpoint/early-stop messages.
            bestModelPath: checkpoint path without extension; the state dict with
                the lowest validation loss over all folds is saved to
                "./<bestModelPath>.pth".
            transfer: tag inserted into the per-fold CSV file names.

        Every fold starts from the weights `net` had on entry, with a fresh
        AdamW optimizer and fresh early-stopping state, so no fold is validated
        on samples the model was trained on in an earlier fold.

        Side effects: rewrites "Fold_<k>_{train,Val} {accuracy,loss}_<transfer>.csv"
        in the working directory after every epoch.

        Returns:
            (net, history, trainLoader): the network as left after the last fold,
            the per-fold metric history, and the last fold's training loader
            (None when the splitter yields no folds).
    """
    best_loss = float("inf")
    criterion = nn.NLLLoss()
    # Snapshot of the starting weights, restored at the beginning of each fold.
    initial_state = copy.deepcopy(net.state_dict())
    checkpoint_path = "./{}.pth".format(bestModelPath)
    # torch.save does not create missing directories.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    history = {x: {'train': {'loss': [], 'accuracy':[]},
                   'val': {'loss': [], 'accuracy':[]}} for x in range(kfold.get_n_splits())}
    trainLoader = None
    for nfold, (train_idx, valid_idx) in enumerate(kfold.split(paths, labels)):
        net.load_state_dict(initial_state)
        optimizer = optim.AdamW(net.parameters(), lr=0.0001)
        last_valLoss = float("inf")
        stoppingCounter = 0

        trainSet = CTDataset(paths[train_idx], labels[train_idx])
        validSet = CTDataset(paths[valid_idx], labels[valid_idx])

        trainLoader = DataLoader(trainSet, batch_size=BATCH_SIZE, shuffle=True)
        # Shuffling is pointless for validation.
        validLoader = DataLoader(validSet, batch_size=BATCH_SIZE, shuffle=False)

        fold_hist = history[nfold]
        # Start Training!
        for epoch in range(n_epochs):
            epoch_train_loss, epoch_train_acc = _run_epoch(net, trainLoader, criterion, optimizer)
            fold_hist['train']['accuracy'].append(epoch_train_acc)
            fold_hist['train']['loss'].append(epoch_train_loss)
            _save_history_csv(fold_hist['train']['accuracy'], "Fold_{}_train accuracy_{}".format(nfold,transfer))
            _save_history_csv(fold_hist['train']['loss'], "Fold_{}_train loss_{}".format(nfold,transfer))

            epoch_valid_loss, epoch_valid_acc = _run_epoch(net, validLoader, criterion)
            fold_hist['val']['accuracy'].append(epoch_valid_acc)
            fold_hist['val']['loss'].append(epoch_valid_loss)
            _save_history_csv(fold_hist['val']['accuracy'], "Fold_{}_Val accuracy_{}".format(nfold,transfer))
            _save_history_csv(fold_hist['val']['loss'], "Fold_{}_Val loss_{}".format(nfold,transfer))

            if verbose:
                print("epoch: {}, train_loss: {:0.3f}, train_acc: {:0.3f}, valid_loss: {:0.3f}, valid_acc: {:0.3f}".format(epoch+1, epoch_train_loss, epoch_train_acc, epoch_valid_loss, epoch_valid_acc))

            if epoch_valid_loss <= best_loss:
                # Save the model with the lowest (mean) validation loss.
                torch.save(net.state_dict(), checkpoint_path)
                best_loss = epoch_valid_loss
                if verbose:
                    print("model is saved...")

            # Early stopping counts consecutive epochs that got worse than the previous one.
            if epoch_valid_loss <= last_valLoss:
                stoppingCounter = 0
            else:
                stoppingCounter += 1
            last_valLoss = epoch_valid_loss

            if stoppingCounter >= stopPatience:
                if verbose:
                    print("Early stopped at Fold {}, Epoch {}".format(nfold, epoch))
                break

    return net, history, trainLoader

In [None]:
# Checkpoint path without extension (".pth" is appended when saving/loading).
bestModelPath = "./LastResults/Transfer/bestModel"
# Train the transfer-learning model with K-fold cross-validation on the development set.
trainedModel_tf, history_tf, tLoader = train_procedure_withKFolds(model_tf, devPaths, devLabels, kfold, stopPatience=5, n_epochs=20, verbose=True, bestModelPath=bestModelPath, transfer="yes")

In [None]:
# Restore the checkpoint written during training (lowest validation loss seen).
trainedModel_tf.load_state_dict(torch.load("{}.pth".format(bestModelPath)))

## Evaluation

In [None]:
def _plot_metric_per_fold(history, metric, epochs, title, ylabel):
    """Draw one validation curve per fold for `metric` on a new figure."""
    plt.figure(figsize=(10, 5))
    legends = []  # built per figure so labels never pile up across figures
    for iFold in history:
        y = list(history[iFold]['val'][metric])
        if not y:
            # Nothing recorded for this fold; skip it instead of indexing y[-1].
            continue
        # Folds that stopped early are padded with their last value so every
        # curve spans the same x-axis.
        y = (y + [y[-1]] * len(epochs))[:len(epochs)]
        plt.plot(epochs, y)
        legends.append("{}".format(iFold+1))
    plt.title(title)
    plt.ylabel(ylabel)
    plt.xlabel("Epoch")
    plt.legend(legends)


def plotKFoldsHistory(history, type="Scratch", name="img", plot=False):
    """Plot per-fold validation loss and validation accuracy on two figures.

    Args:
        history: mapping fold index -> {'train'|'val': {'loss'|'accuracy': [...]}}
            as returned by the training procedure. Fold indices must be integers.
        type, name: currently unused; kept so existing calls keep working.
        plot: when True, `plt.show()` is called after each figure.
    """
    # The x-axis covers the longest fold (epochs are numbered from 1).
    n_epochs = max((len(history[i]['train']['loss']) for i in history), default=0)
    epochs = list(range(1, n_epochs + 1))

    figures = (
        ("loss", "Model Validation Loss", "Loss"),
        ("accuracy", "Model Validation Accuracy", "Accuracy"),
    )
    for metric, title, ylabel in figures:
        _plot_metric_per_fold(history, metric, epochs, title, ylabel)
        if plot:
            plt.show()

In [None]:
import pandas as pd
import seaborn as sn

def getConfusionMatrix(dataloader, model, type, name, plot=False):
    """Evaluate `model` on `dataloader` and draw a row-normalised confusion matrix.

    Args:
        dataloader: yields (inputs, labels) batches.
        model: network whose output has one column per class; the predicted
            class is the argmax over dim 1.
        type, name: currently unused; kept so existing calls keep working.
        plot: when True, `plt.show()` is called after drawing the heatmap.

    Prints the overall accuracy.

    Returns:
        (cfMatrix, truths, preds): the raw 2x2 confusion matrix (rows = true
        class, columns = predicted class, classes 0 and 1) and the flat lists
        of true and predicted labels.
    """
    model.eval()
    preds, truths = [], []
    with torch.no_grad():
        for inputs, labels in dataloader:
            predicted = torch.argmax(model(inputs.to(device)), dim=1)
            preds.extend(predicted.cpu().numpy())
            truths.extend(labels.cpu().numpy())

    accuracy = accuracy_score(truths, preds)
    print(accuracy)
    classes = [0, 1]
    # Fix the label set so the matrix is always 2x2, even if one class is
    # missing from this loader.
    cfMatrix = confusion_matrix(truths, preds, labels=classes)
    rowTotals = cfMatrix.sum(axis=1, keepdims=True)
    # Rows with no samples stay at 0 instead of dividing by zero.
    normalized = np.divide(cfMatrix, rowTotals,
                           out=np.zeros(cfMatrix.shape, dtype=float),
                           where=rowTotals != 0)
    cfFrame = pd.DataFrame(normalized, index=classes, columns=classes)
    plt.figure(figsize = (12,7))
    sn.heatmap(cfFrame, annot=True)
    if plot:
        plt.show()
    return cfMatrix, truths, preds

In [None]:
# Validation curves of the transfer-learning run, one line per fold.
plotKFoldsHistory(history_tf, type="TF", name="exp1", plot=True)

In [None]:
# Confusion matrix of the transfer-learning model on the held-out test set.
cfMatrix_test1, truths_test1, preds_test1 = getConfusionMatrix(testLoader, trainedModel_tf, type="TL", name="exp1_testing", plot=False)

In [None]:
# Confusion matrix of the transfer-learning model on the development (training) set.
cfMatrix_train1, truths_train1, preds_train1 = getConfusionMatrix(devLoader, trainedModel_tf, type="TL", name="exp1_training", plot=False)

# Experiment 0

## Model

In [None]:
class ConvBlock(nn.Module):
    """Conv3d (3x3x3, padding 1) -> ReLU -> MaxPool3d(2) -> BatchNorm3d."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=2)
        self.bn = nn.BatchNorm3d(out_channels)

    def forward(self, x):
        # Convolve and activate, halve each spatial size, then normalise.
        pooled = self.maxpool(self.relu(self.conv(x)))
        return self.bn(pooled)

class get_model(nn.Module):
    """
        3D CNN trained from scratch: four ConvBlocks (1 -> 64 -> 64 -> 128 -> 256
        channels), global average pooling, and a two-layer classifier head.

        `forward` returns softmax probabilities of shape (batch, 2).

        Args:
            width, height, depth: currently unused (the adaptive pooling makes
                the head independent of the input size); kept so existing calls
                keep working.
    """
    def __init__(self, width=128, height=128, depth=64):
        super(get_model, self).__init__()
        self.conv1 = ConvBlock(1, 64)
        self.conv2 = ConvBlock(64, 64)
        self.conv3 = ConvBlock(64, 128)
        self.conv4 = ConvBlock(128, 256)
        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.fc1 = nn.Linear(256, 512)
        self.dropout = nn.Dropout(0.3)
        self.fc2 = nn.Linear(512, 2)
        # Not used in forward; kept so existing attribute access keeps working.
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        # ReLU between the two linear layers: without it fc1 and fc2 collapse
        # into a single linear map. CTModel in this file uses the same pattern.
        x = Functional.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return Functional.softmax(x, dim=1)
# Instantiate the from-scratch model and move it to the selected device.
Model_0 = get_model().to(device)
#count_parameters(model)

## Training

In [None]:
# Checkpoint path without extension (".pth" is appended when saving/loading).
# Note: this rebinds the `bestModelPath` global used by the previous experiment.
bestModelPath = "./LastResults/Base/bestModel"
# Train the from-scratch model with K-fold cross-validation on the development set.
trainedModel_0, history_0, tLoader = train_procedure_withKFolds(Model_0, devPaths, devLabels, kfold, stopPatience=5, n_epochs=20, verbose=True, bestModelPath=bestModelPath, transfer="no")

In [None]:
# Restore the checkpoint written during training (lowest validation loss seen).
trainedModel_0.load_state_dict(torch.load("{}.pth".format(bestModelPath)))

## Evaluation

In [None]:
# Validation curves of the from-scratch run, one line per fold.
plotKFoldsHistory(history_0, type="Base", name="exp0", plot=True)

In [None]:
# Test-set evaluation of the from-scratch model:
# C_test0 = confusion matrix, T_test0 = true labels, L_test0 = predicted labels.
C_test0, T_test0, L_test0 = getConfusionMatrix(testLoader, trainedModel_0, type="Base", name="exp0_final", plot=False)

In [None]:
# Save the from-scratch model's test results as CSV files in the working directory.
# The evaluation cell binds its results to T_test0 / L_test0 / C_test0; the
# unsuffixed names used here before were never defined and raised NameError.

# True labels
a1 = np.asarray(T_test0)
name1 = ("Scratch Best Model Testing_{}".format("Truth"))
with open(name1+".csv", "w") as fe1:
    np.savetxt(fe1, a1, delimiter=",")

# Predicted labels
a2 = np.asarray(L_test0)
name2 = ("Scratch Best Model Testing_{}".format("Prediction"))
with open(name2+".csv", "w") as fe2:
    np.savetxt(fe2, a2, delimiter=",")

# Confusion matrix
a3 = np.asarray(C_test0)
name3 = ("Scratch Best Model Testing_{}".format("ConvM"))
with open(name3+".csv", "w") as fe3:
    np.savetxt(fe3, a3, delimiter=",")