In [1]:
import torch
import os
import torchvision
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from glob import glob
import skorch
from imageio import imread
from tqdm import tqdm

In [2]:

from skorch.callbacks import LRScheduler,Checkpoint,EpochScoring,EarlyStopping
import torch.optim as optim
from skorch.helper import predefined_split

from skorch import NeuralNetClassifier

In [3]:

from sklearn.model_selection import train_test_split

def train_validation_test_split(index,
                                y,
                                validation_size=0.20,
                                test_size=0.20,
                                random_state=None):
    """Split ``index`` into stratified train / validation / test parts.

    The test part is carved out first (``test_size`` of all samples). The
    validation part is then carved out of what remains, so ``validation_size``
    is a fraction of the non-test samples, not of the full data.

    Parameters
    ----------
    index : array-like
        Sample identifiers to split.
    y : array-like
        Labels aligned with ``index``; used for stratification at both steps.
    validation_size : float
        Fraction of the remaining (non-test) samples used for validation.
    test_size : float
        Fraction of all samples used for testing.
    random_state : int or None
        Seed forwarded to both ``train_test_split`` calls.

    Returns
    -------
    tuple
        ``(train_index, validation_index, test_index)``.
    """
    # First cut: hold out the test part. Keep the remainder's labels so the
    # second cut can be stratified as well.
    remaining_index, test_index, remaining_y, _ = train_test_split(
        index,
        y,
        test_size=test_size,
        stratify=y,
        random_state=random_state,
    )
    # Second cut: split the remainder into train and validation.
    train_index, validation_index, _, _ = train_test_split(
        remaining_index,
        remaining_y,
        test_size=validation_size,
        stratify=remaining_y,
        random_state=random_state,
    )
    return train_index, validation_index, test_index

In [4]:
def over_sampler(metadata, n, random_state=None):
    """Oversample every training class to exactly ``n`` rows, keep val/test.

    For each label present among rows whose ``set`` column is ``"train"``,
    ``n`` rows of that label are drawn with replacement. The ``"val"`` and
    ``"test"`` rows are then appended unchanged.

    Parameters
    ----------
    metadata : pandas.DataFrame
        Must contain ``label`` and ``set`` columns.
    n : int
        Number of rows to draw per training class.
    random_state : int, numpy RandomState/Generator, or None
        Makes the draws reproducible. An int seeds one RandomState that is
        shared by all classes. None keeps pandas' default global numpy RNG.

    Returns
    -------
    pandas.DataFrame
        The oversampled training rows, then the val rows, then the test rows,
        with a fresh RangeIndex. Columns are ``file, label, dataset, set``
        followed by any extra columns of ``metadata``.
    """
    columns = ["file", "label", "dataset", "set"]
    train_mask = metadata["set"] == "train"
    val_mask = metadata["set"] == "val"
    test_mask = metadata["set"] == "test"

    # One shared RandomState: reproducible when seeded, yet every class
    # still gets its own draws.
    if isinstance(random_state, (int, np.integer)):
        random_state = np.random.RandomState(random_state)

    pieces = []
    for cl in metadata.loc[train_mask, "label"].unique():
        # The parentheses matter: ``&`` binds tighter than ``==`` in Python.
        class_mask = train_mask & (metadata["label"] == cl)
        pieces.append(
            metadata.loc[class_mask, :].sample(n, replace=True, random_state=random_state)
        )

    ## adding the val and test sets
    pieces.append(metadata.loc[val_mask, :])
    pieces.append(metadata.loc[test_mask, :])

    # Concatenate once. DataFrame.append was removed in pandas 2.0 and
    # copied the whole frame on every call.
    result = pd.concat(pieces, ignore_index=True)
    extra = [c for c in result.columns if c not in columns]
    return result.reindex(columns=columns + extra)

In [5]:
from sklearn.metrics import plot_confusion_matrix, matthews_corrcoef, classification_report,confusion_matrix, accuracy_score, balanced_accuracy_score, cohen_kappa_score, f1_score,  precision_score, recall_score
from statsmodels.stats.contingency_tables import mcnemar
from sklearn.metrics import confusion_matrix,ConfusionMatrixDisplay


def classification_complete_report(y_true, y_pred, labels=None):
    """Print a battery of classification metrics and plot the confusion matrix.

    ``labels`` is forwarded to ``classification_report`` and
    ``confusion_matrix``, and is used as the confusion-matrix tick labels.
    The scalar metrics are computed without it.
    """
    divider = 15 * "----"

    print(classification_report(y_true, y_pred, labels=labels))
    print(divider)
    print("matthews correlation coeff: %.4f" % (matthews_corrcoef(y_true, y_pred)))
    print("Cohen Kappa score: %.4f" % (cohen_kappa_score(y_true, y_pred)))
    print("Accuracy: %.4f & balanced Accuracy: %.4f"
          % (accuracy_score(y_true, y_pred), balanced_accuracy_score(y_true, y_pred)))

    # Each of these is reported twice: macro average first, then micro average.
    for title, metric in (("F1", f1_score),
                          ("Precision", precision_score),
                          ("Recall", recall_score)):
        macro = metric(y_true, y_pred, average="macro")
        micro = metric(y_true, y_pred, average="micro")
        print("macro %s score: %.4f & micro %s score: %.4f" % (title, macro, title, micro))

    matrix = confusion_matrix(y_true, y_pred, labels=labels)
    ConfusionMatrixDisplay(confusion_matrix=matrix, display_labels=labels).plot(cmap=plt.cm.Blues)
    plt.show()
    print(divider)

In [6]:
# setting device on GPU if available, else CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print()

#Additional Info when using cuda
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')

Using device: cuda

Tesla P100-SXM2-16GB
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB


In [7]:
# Root folder of each source dataset; every sub-folder is read as one class.
# NOTE(review): "MLL" and "PBC" point to the same directory, so every image
# under it is listed twice, once per dataset tag. Train/val/test are split per
# row and both train and val draw from PBC+MLL, so the same file can end up in
# several splits. Confirm the intended paths.
data_path = dict(
    AML="/pstore/data/DS4/ssl_vs_al/data/matek/train/",
    MLL="/pstore/data/DS4/ssl_vs_al/data/matek/test/",
    PBC="/pstore/data/DS4/ssl_vs_al/data/matek/test/",
)

# Maps class-folder abbreviations onto the shared label vocabulary. Several
# abbreviations collapse onto one label (e.g. LYA/LYT -> lymphocyte).
# "unknown" rows are filtered out after loading.
equivalent_classes = dict(
    BAS='basophil',
    EBO='erythroblast',
    EOS='eosinophil',
    KSC="unknown",
    LYA='lymphocyte',
    LYT='lymphocyte',
    MMZ='ig',
    MOB='monocyte',
    MON='monocyte',
    MYB='ig',
    MYO='ig',
    NGB='neutrophil',
    NGS='neutrophil',
    PMB='ig',
    PMO='ig',
)



def finding_classes(data_dir):
    """
    Return the sorted names of the class folders directly under ``data_dir``.

    Only sub-directories count as classes. Plain files (e.g. ``.DS_Store`` or
    a README) and hidden entries whose name starts with ``"."`` (e.g.
    Jupyter's ``.ipynb_checkpoints``) are skipped, because each of them would
    otherwise be treated as an extra class.

    Parameters
    ----------
    data_dir : str or os.PathLike
        Root folder containing one sub-folder per class.

    Returns
    -------
    list of str
        Class folder names in lexicographic order.
    """
    with os.scandir(data_dir) as entries:
        return sorted(
            entry.name
            for entry in entries
            if entry.is_dir() and not entry.name.startswith(".")
        )


In [8]:
def metadata_generator(data_path):
    """Build one metadata row per file found in each class folder of each dataset.

    Parameters
    ----------
    data_path : dict
        Maps a dataset name to its root folder. Each class folder under the
        root (as returned by ``finding_classes``) is scanned for files.

    Returns
    -------
    pandas.DataFrame
        Columns ``file``, ``label`` (class folder name), ``dataset`` (the
        ``data_path`` key) and ``set`` (initialised to ``"train"``), with a
        fresh RangeIndex.

    Warns
    -----
    UserWarning
        When two dataset names share one root directory, because every file
        there is then listed twice under different dataset tags.
    """
    import warnings

    columns = ["file", "label", "dataset", "set"]

    seen_roots = {}
    for ds, root in data_path.items():
        key = os.path.normpath(os.path.abspath(root))
        if key in seen_roots:
            warnings.warn(
                "datasets %r and %r share the directory %s; its files will be "
                "listed twice" % (seen_roots[key], ds, root)
            )
        else:
            seen_roots[key] = ds

    frames = []
    for ds, root in data_path.items():
        for cl in finding_classes(root):
            # glob order is filesystem-dependent. Sorting gives a stable row
            # order, which the seeded split downstream needs to be reproducible.
            files = sorted(glob(os.path.join(root, cl, "*")))
            frames.append(pd.DataFrame(
                {"file": files, "label": cl, "dataset": ds, "set": "train"},
                columns=columns,
            ))

    # Concatenate once. DataFrame.append was removed in pandas 2.0 and
    # copied the whole frame on every call.
    if not frames:
        return pd.DataFrame(columns=columns)
    return pd.concat(frames, ignore_index=True)
    



In [9]:
metadata = metadata_generator(data_path=data_path)  # one row per image file

In [10]:
# Collapse class-folder abbreviations onto the shared label names.
metadata["label"] = metadata["label"].replace(equivalent_classes)

In [11]:
## leave unknown class out
known_index = metadata["label"].ne("unknown")
metadata = metadata.loc[known_index].reset_index(drop=True)

In [12]:

# Seeded, stratified split of the row index. With the default sizes this is
# roughly 64/16/20 train/val/test (validation comes out of the non-test part).
train_index, val_index, test_index = train_validation_test_split(
    metadata.index, y=metadata.label, random_state=314
)

metadata.loc[train_index, "set"] = "train"
metadata.loc[val_index, "set"] = "val"
metadata.loc[test_index, "set"] = "test"

In [13]:
# Inspect the labelled metadata; pandas truncates the display of long frames.
metadata

Unnamed: 0,file,label,dataset,set
0,/pstore/data/DS4/ssl_vs_al/data/matek/train/BA...,basophil,AML,train
1,/pstore/data/DS4/ssl_vs_al/data/matek/train/BA...,basophil,AML,train
2,/pstore/data/DS4/ssl_vs_al/data/matek/train/BA...,basophil,AML,train
3,/pstore/data/DS4/ssl_vs_al/data/matek/train/BA...,basophil,AML,train
4,/pstore/data/DS4/ssl_vs_al/data/matek/train/BA...,basophil,AML,test
...,...,...,...,...
22015,/pstore/data/DS4/ssl_vs_al/data/matek/test/PMO...,ig,PBC,train
22016,/pstore/data/DS4/ssl_vs_al/data/matek/test/PMO...,ig,PBC,test
22017,/pstore/data/DS4/ssl_vs_al/data/matek/test/PMO...,ig,PBC,train
22018,/pstore/data/DS4/ssl_vs_al/data/matek/test/PMO...,ig,PBC,train


In [14]:
# Integer class id per label, in order of first appearance in ``metadata``.
label_map = {cl: i for i, cl in enumerate(metadata.label.unique())}

label_map

{'basophil': 0,
 'erythroblast': 1,
 'eosinophil': 2,
 'lymphocyte': 3,
 'ig': 4,
 'monocyte': 5,
 'neutrophil': 6}

In [15]:
def over_sampler(metadata, n, random_state=None):
    """Oversample every training class to exactly ``n`` rows.

    For each label present among rows whose ``set`` column is ``"train"``,
    ``n`` rows of that label are drawn with replacement. Rows of other sets
    are not included.

    NOTE(review): this redefines an ``over_sampler`` from an earlier cell,
    which also appended the val/test rows. Keep a single definition.

    Parameters
    ----------
    metadata : pandas.DataFrame
        Must contain ``label`` and ``set`` columns.
    n : int
        Number of rows to draw per training class.
    random_state : int, numpy RandomState/Generator, or None
        Makes the draws reproducible. An int seeds one RandomState that is
        shared by all classes. None keeps pandas' default global numpy RNG.

    Returns
    -------
    pandas.DataFrame
        ``n * number_of_training_classes`` rows with a fresh RangeIndex.
        Columns are ``file, label, dataset, set`` followed by any extra
        columns of ``metadata``. If there are no training rows, the result
        is empty with those four columns.
    """
    columns = ["file", "label", "dataset", "set"]
    train_mask = metadata["set"] == "train"

    # One shared RandomState: reproducible when seeded, yet every class
    # still gets its own draws.
    if isinstance(random_state, (int, np.integer)):
        random_state = np.random.RandomState(random_state)

    pieces = []
    for cl in metadata.loc[train_mask, "label"].unique():
        class_mask = train_mask & (metadata["label"] == cl)
        pieces.append(
            metadata.loc[class_mask, :].sample(n, replace=True, random_state=random_state)
        )

    if not pieces:
        return pd.DataFrame(columns=columns)

    # Concatenate once. DataFrame.append was removed in pandas 2.0 and
    # copied the whole frame on every call.
    result = pd.concat(pieces, ignore_index=True)
    extra = [c for c in result.columns if c not in columns]
    return result.reindex(columns=columns + extra)

In [16]:
# Draw 100 training rows per class, with replacement.
metadata_over_sampled = over_sampler(metadata, n=100)

In [17]:
# Inspect the oversampled training frame (100 rows per training class).
metadata_over_sampled

Unnamed: 0,file,label,dataset,set
0,/pstore/data/DS4/ssl_vs_al/data/matek/test/BAS...,basophil,PBC,train
1,/pstore/data/DS4/ssl_vs_al/data/matek/train/BA...,basophil,AML,train
2,/pstore/data/DS4/ssl_vs_al/data/matek/train/BA...,basophil,AML,train
3,/pstore/data/DS4/ssl_vs_al/data/matek/train/BA...,basophil,AML,train
4,/pstore/data/DS4/ssl_vs_al/data/matek/train/BA...,basophil,AML,train
...,...,...,...,...
695,/pstore/data/DS4/ssl_vs_al/data/matek/test/NGS...,neutrophil,MLL,train
696,/pstore/data/DS4/ssl_vs_al/data/matek/train/NG...,neutrophil,AML,train
697,/pstore/data/DS4/ssl_vs_al/data/matek/test/NGS...,neutrophil,MLL,train
698,/pstore/data/DS4/ssl_vs_al/data/matek/train/NG...,neutrophil,AML,train


In [21]:
from torch.utils.data import Dataset, DataLoader
import copy

class DatasetGenerator(Dataset):
    """Map-style dataset that reads the image files listed in a metadata frame.

    Item ``i`` is ``(image, label)``:

    * ``image`` is a float32 tensor laid out channels-first (C x H x W),
      built from the channels in ``selected_channels`` of the decoded image
      and passed through ``transform`` if one is given.
    * ``label`` is the int64 tensor ``label_map[row.label]``.

    Parameters
    ----------
    metadata : pandas.DataFrame
        Needs ``file``, ``label`` and ``dataset`` columns.
    reshape_size : int
        Stored as ``self.reshape_size`` only. No resizing happens here, so
        apply it through ``transform`` if needed.
    label_map : dict
        Maps label strings to integer class ids.
    dataset : list of str
        Only rows whose ``dataset`` value is in this list are kept; an empty
        list keeps nothing.
    transform : callable or None
        Applied to the float image tensor.
    selected_channels : sequence of int
        Indices along the last (channel) axis of the decoded image to keep.
        The default ``(0, 1, 2)`` replaces a ``[0]`` default that was never
        used: the old ``__getitem__`` read a notebook-global
        ``selected_channels`` (``[0, 1, 2]``) instead of this argument.
    """

    def __init__(self,
                 metadata,
                 reshape_size=64,
                 label_map=[],
                 dataset=[],
                 transform=None,
                 selected_channels=(0, 1, 2)):
        keep = metadata.dataset.isin(dataset)
        # reset_index already returns a new frame, so self.metadata is
        # independent of the caller's DataFrame.
        self.metadata = metadata.loc[keep, :].reset_index(drop=True)

        self.reshape_size = reshape_size
        self.label_map = label_map
        self.transform = transform
        self.selected_channels = list(selected_channels)

    def __len__(self):
        return len(self.metadata)

    def _read_image(self, path):
        """Decode the image file at ``path`` into an H x W x C array."""
        return imread(path)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        ## get image and label
        image_path = self.metadata.loc[idx, "file"]
        label = self.metadata.loc[idx, "label"]

        image = self._read_image(image_path)[:, :, self.selected_channels]
        # H x W x C -> C x H x W. Make it contiguous before handing it to torch.
        image = np.ascontiguousarray(np.transpose(image, (2, 0, 1)))
        image = torch.from_numpy(image).float()

        if self.transform:
            image = self.transform(image)

        label = torch.tensor(self.label_map[label]).long()
        # The transform may change the dtype, so cast back to float32.
        return image.float(), label
        
        

def get_statistics(dataloader, nmb_channels):
    """Compute per-channel mean and standard deviation over a whole dataloader.

    The moments are accumulated over every pixel of every batch instead of
    averaging per-batch statistics. A smaller last batch therefore does not
    skew the result, and the std includes the variation between batches.

    Parameters
    ----------
    dataloader : iterable
        Yields ``(images, labels)`` pairs. ``images`` is indexed as
        (batch, channel, ...).
    nmb_channels : int
        Number of leading channels to summarise.

    Returns
    -------
    dict
        ``{"mean": Tensor[nmb_channels], "std": Tensor[nmb_channels]}`` as
        float32. ``std`` is the unbiased (N - 1) estimate, the same
        convention as ``torch.Tensor.std``'s default.

    Raises
    ------
    ValueError
        If a batch has fewer than ``nmb_channels`` channels, or if the
        dataloader yields no pixels at all.
    """
    channel_sum = torch.zeros(nmb_channels, dtype=torch.float64)
    channel_sq_sum = torch.zeros(nmb_channels, dtype=torch.float64)
    pixel_count = 0

    for image, _ in tqdm(dataloader):
        if image.shape[1] < nmb_channels:
            raise ValueError(
                "batch has %d channels, expected at least %d" % (image.shape[1], nmb_channels)
            )
        # (B, C, ...) -> (C, B*...). float64 keeps the sum of squares accurate.
        flat = image[:, :nmb_channels].transpose(0, 1).flatten(start_dim=1)
        flat = flat.to("cpu", torch.float64)
        channel_sum += flat.sum(dim=1)
        channel_sq_sum += flat.pow(2).sum(dim=1)
        pixel_count += flat.shape[1]

    if pixel_count == 0:
        raise ValueError("dataloader yielded no pixels to compute statistics from")

    mean = channel_sum / pixel_count
    # Unbiased variance from the running sums. The clamp guards against tiny
    # negative round-off values.
    variance = (channel_sq_sum - pixel_count * mean.pow(2)) / max(pixel_count - 1, 1)
    statistics = {
        "mean": mean.float(),
        "std": variance.clamp_min(0).sqrt().float(),
    }

    print('statistics used: %s' % (str(statistics)))

    return statistics

In [22]:
# Boolean row masks, one per split. They rebind the names that previously
# held the index arrays returned by train_validation_test_split.
train_index, val_index, test_index = (
    metadata.set == split for split in ("train", "val", "test")
)

In [20]:
# Image channels (last axis of the decoded image) used by the datasets and
# the normalisation statistics.
selected_channels = list(range(3))

In [24]:
reshape_size = 64

# Untransformed PBC + MLL training rows, used to estimate the normalisation
# statistics. Pass metadata_over_sampled instead to use the class-balanced frame.
train_dataset = DatasetGenerator(
    metadata.loc[train_index, :],
    reshape_size=reshape_size,
    label_map=label_map,
    dataset=["PBC", "MLL"],
    transform=None,
    selected_channels=selected_channels,
)

Unnamed: 0,file,label,dataset,set
0,/pstore/data/DS4/ssl_vs_al/data/matek/test/BAS...,basophil,MLL,train
1,/pstore/data/DS4/ssl_vs_al/data/matek/test/BAS...,basophil,MLL,train
2,/pstore/data/DS4/ssl_vs_al/data/matek/test/BAS...,basophil,MLL,train
3,/pstore/data/DS4/ssl_vs_al/data/matek/test/BAS...,basophil,MLL,train
4,/pstore/data/DS4/ssl_vs_al/data/matek/test/BAS...,basophil,MLL,train
...,...,...,...,...
4703,/pstore/data/DS4/ssl_vs_al/data/matek/test/PMO...,ig,PBC,train
4704,/pstore/data/DS4/ssl_vs_al/data/matek/test/PMO...,ig,PBC,train
4705,/pstore/data/DS4/ssl_vs_al/data/matek/test/PMO...,ig,PBC,train
4706,/pstore/data/DS4/ssl_vs_al/data/matek/test/PMO...,ig,PBC,train


In [None]:
train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,
)

In [None]:
# Per-channel normalisation statistics of the untransformed training images.
stats = get_statistics(dataloader=train_loader, nmb_channels=len(selected_channels))

In [None]:
class AddGaussianNoise(object):
    """Tensor transform that adds Gaussian noise.

    Returns ``tensor + randn_like_shape * std + mean``: i.i.d. standard-normal
    noise scaled by ``std`` and shifted by ``mean``.

    Parameters
    ----------
    mean : float
        Constant offset added together with the noise.
    std : float
        Scale of the noise.
    """

    def __init__(self, mean=0., std=1.):
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        # Allocate the noise on the input's device. torch.randn(size) alone
        # always lands on the CPU and fails against CUDA tensors. The dtype
        # stays the default float, as before.
        noise = torch.randn(tensor.size(), device=tensor.device)
        return tensor + noise * self.std + self.mean

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)

In [None]:
import torch.nn as nn
from torchvision.models import resnet18
import torch.nn.functional as F

class ResNet18Pretrained(nn.Module):
    """torchvision ResNet-18 with a resized classification head.

    Parameters
    ----------
    num_channels : int
        Number of input channels. For anything other than 3, the first
        convolution is replaced by a freshly initialised one, so its
        pretrained weights are discarded.
    num_classes : int
        Output size of the new final linear layer, which is always freshly
        initialised.
    pretrained : bool
        Forwarded to ``torchvision.models.resnet18``. The previous code
        ignored it and always passed ``True``; the default is unchanged.
    **kwargs
        Accepted for call compatibility and ignored.
    """

    def __init__(self,
                 num_channels=3,
                 num_classes=3,
                 pretrained=True, **kwargs):
        super().__init__()
        model = resnet18(pretrained=pretrained)
        if num_channels != 3:
            model.conv1 = nn.Conv2d(num_channels, 64, kernel_size=(7, 7),
                                    stride=(2, 2), padding=(3, 3), bias=False)

        model.fc = nn.Linear(model.fc.in_features, num_classes)
        self.model = model

    def forward(self, x):
        return self.model(x)

In [None]:
# Three input channels, one output per distinct label in the metadata.
resnet_pretrained = ResNet18Pretrained(
    num_channels=3,
    num_classes=len(metadata["label"].unique()),
)

In [None]:

from torchvision import transforms

# eps keeps Normalize from dividing by zero. It is added to a new tensor
# instead of being written back into ``stats``, so re-running this cell does
# not keep growing the stored std.
eps = 1e-16
normalize = transforms.Normalize(mean=stats["mean"], std=stats["std"] + eps)

train_transform = transforms.Compose([
    normalize,
    transforms.RandomResizedCrop(reshape_size, scale=(0.8, 1.0), ratio=(0.8, 1.2)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    AddGaussianNoise(mean=0., std=0.02),
])

val_transform = transforms.Compose([
    normalize,
    transforms.Resize(reshape_size),
])

# selected_channels is passed explicitly so every dataset reads the same
# channels the normalisation statistics were computed on. Train and val draw
# from PBC + MLL; the test set uses only the AML source.
train_dataset = DatasetGenerator(metadata.loc[train_index, :],
                                 reshape_size=reshape_size,
                                 dataset=["PBC", "MLL"],
                                 label_map=label_map,
                                 transform=train_transform,
                                 selected_channels=selected_channels)

val_dataset = DatasetGenerator(metadata.loc[val_index, :],
                               reshape_size=reshape_size,
                               dataset=["PBC", "MLL"],
                               label_map=label_map,
                               transform=val_transform,
                               selected_channels=selected_channels)

test_dataset = DatasetGenerator(metadata.loc[test_index, :],
                                reshape_size=reshape_size,
                                dataset=["AML"],
                                label_map=label_map,
                                transform=val_transform,
                                selected_channels=selected_channels)

In [None]:
# Shape of the first transformed test image, channels first.
test_dataset[0][0].size()

In [None]:
# First channel of the first test image, after normalisation.
first_test_image = test_dataset[0][0]
plt.imshow(first_test_image[0])

In [None]:


# StepLR schedule: step_size=5, gamma=0.6.
lr_scheduler = LRScheduler(policy='StepLR', step_size=5, gamma=0.6)

# NOTE(review): the checkpoint keeps the epoch with the best validation
# accuracy, while early stopping below tracks validation macro-F1. Confirm
# which metric should select the saved weights.
checkpoint = Checkpoint(f_params='resnet_18_aml.pth', monitor='valid_acc_best')

# Validation macro-F1, recorded in the history as "valid_f1_macro".
epoch_scoring = EpochScoring("f1_macro",
                             name="valid_f1_macro",
                             on_train=False,
                             lower_is_better=False)

early_stopping = EarlyStopping(monitor='valid_f1_macro',
                               patience=50,
                               threshold=0.0001,
                               threshold_mode='rel',
                               lower_is_better=False)

model = NeuralNetClassifier(
    resnet_pretrained,
    criterion=nn.CrossEntropyLoss,
    lr=0.001,
    batch_size=64,
    max_epochs=1000,
    optimizer=optim.Adam,
    iterator_train__shuffle=True,
    iterator_train__num_workers=3,
    iterator_valid__shuffle=False,
    iterator_valid__num_workers=1,
    callbacks=[lr_scheduler, checkpoint, epoch_scoring, early_stopping],
    train_split=predefined_split(val_dataset),
    # Use the device detected at the top of the notebook ("cuda" or "cpu")
    # instead of hard-coding "cuda", so the notebook also runs without a GPU.
    device=device.type,
    warm_start=True)

In [None]:
# Train on train_dataset; validation uses the predefined val_dataset split.
model = model.fit(X=train_dataset, y=None)

In [None]:
# Restore the weights saved by the Checkpoint callback. skorch's load_params
# loads into the fitted module and maps tensors to the net's device, whereas a
# bare torch.load of CUDA-saved weights fails on a CPU-only machine.
model.load_params(f_params='resnet_18_aml.pth')


In [None]:
preds = model.predict(val_dataset)
val_true = [label_map[label] for label in val_dataset.metadata["label"]]
classification_complete_report(val_true, preds)

In [None]:
preds = model.predict(test_dataset)
test_true = [label_map[label] for label in test_dataset.metadata["label"]]
classification_complete_report(test_true, preds)