In [30]:
from IPython.core.debugger import Tracer
import torch
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models
import argparse
from datetime import datetime
import os
import time
import multiprocessing
import psutil
import json
import itertools
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd  
import h5py
from sklearn.model_selection import train_test_split
from skimage.util import crop,  random_noise
from skimage.transform import   rescale, resize, rotate, AffineTransform, warp
import torch.optim as optim
from tqdm import tqdm
from resnet18 import resnet18
from collections import Counter
from util import get_statistics
from dataset import Dataset_Generator, train_validation_test_split, get_classes_map, number_of_classes, number_of_channels, get_all_object_numbers_labels
from sklearn.metrics import classification_report
from sklearn.metrics import f1_score
import ast

In [85]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batch_size = 300
n_epochs = 100
num_workers = 0
lr = 0.001
n_splits = 10

In [38]:
only_channels = [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
only_classes = None

In [39]:
h5_file = "data/WBC/Lyse fix sample_1_Focused & Singlets & CD45 pos.h5"

In [40]:
label_map = get_classes_map(h5_file)

In [41]:
class_names = list(label_map.values())

In [42]:
class AddGaussianNoise(object):
    def __init__(self, mean=0., std=1., p=0.5):
        self.std = std
        self.mean = mean
        self.p = p
        
    def __call__(self, tensor):
        if torch.rand(1) < self.p:
            return tensor + torch.randn(tensor.size()) * self.std + self.mean
        return tensor
    
    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)

In [43]:
transform = transforms.Compose(
        [transforms.RandomVerticalFlip(),
         transforms.RandomHorizontalFlip(),
         transforms.RandomRotation(45),
        AddGaussianNoise(0., 1., 0.3)])

In [44]:
num_classes = number_of_classes(h5_file, only_classes=only_classes)
num_channels = number_of_channels(h5_file, only_channels=only_channels)

In [46]:
from imblearn.over_sampling import RandomOverSampler
ros = RandomOverSampler(random_state=42, sampling_strategy='all')

In [78]:
class oversampled_Kfold():
    def __init__(self, n_splits, n_repeats=1):
        self.n_splits = n_splits
        self.n_repeats = n_repeats

    def get_n_splits(self, X, y, groups=None):
        return self.n_splits*self.n_repeats

    def split(self, X, y, groups=None):
        splits = np.array_split(np.random.choice(len(X), len(X),replace=False), self.n_splits)
        train, test = [], []
        for repeat in range(self.n_repeats):
            for idx in range(len(splits)):
                trainingIdx = np.delete(splits, idx)
                Xidx_r, y_r = ros.fit_resample(np.hstack(trainingIdx).reshape((-1,1)), np.asarray(y[np.hstack(trainingIdx)]))
                train.append(Xidx_r.flatten())
                test.append(splits[idx])
        return list(zip(train, test))

In [79]:
rkf_search = oversampled_Kfold(n_splits=n_splits, n_repeats=1)

In [80]:
X, y = get_all_object_numbers_labels(h5_file, only_classes) 

In [84]:
print("Start validation")
for train_indx, test_indx in rkf_search.split(X, y):
    train_dataset = Dataset_Generator(h5_file, train_indx, reshape_size=64, transform=transform,
                                      only_channels=only_channels, only_classes=only_classes)
    trainloader = DataLoader(train_dataset,
                             batch_size=batch_size,
                             shuffle=False,
                             num_workers=num_workers)
    statistics = get_statistics(trainloader, only_channels)
    train_dataset = Dataset_Generator(h5_file, train_indx, reshape_size=64, transform=transform,
                                      means=statistics["mean"].div_(len(trainloader)),
                                      stds=statistics["std"].div_(len(trainloader)), only_channels=only_channels,
                                      only_classes=only_classes)
    test_dataset = Dataset_Generator(h5_file, test_indx, reshape_size=64,
                                     means=statistics["mean"].div_(len(trainloader)),
                                     stds=statistics["std"].div_(len(trainloader)), only_channels=only_channels,
                                     only_classes=only_classes)
    trainloader = DataLoader(train_dataset,
                             batch_size=batch_size,
                             shuffle=False,
                             num_workers=num_workers)
    testloader = DataLoader(test_dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=num_workers)
    model = resnet18(pretrained=True)

    # loading the imagenet weights in case it is possible
    if num_channels != 3:
        model.conv1 = nn.Conv2d(num_channels, 64, kernel_size=(7, 7),
                                    stride=(2, 2), padding=(3, 3), bias=False)
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, num_classes)

    model = model.to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    for epoch in range(n_epochs):
        running_loss = 0.0
        print('epoch%d' % epoch)
        for i, data in enumerate(trainloader, 0):
            indx = (data["object_number"] != -1).reshape(-1)
            if indx.sum() > 0:
                inputs, labels = data["image"][indx], data["label"][indx]

                inputs, labels = inputs.to(device), labels.to(device)
                inputs = inputs.float()
                labels = labels.reshape(-1)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward + backward + optimize
                outputs = model(inputs)
                _, predicted = torch.max(outputs.data, 1)

                loss = criterion(outputs, F.one_hot(labels.long(), num_classes).type_as(outputs))
                loss.backward()
                optimizer.step()

                # print statistics
                running_loss += loss.item()
            if i % 500 == 499:  # print every 2000 mini-batches
                print('[%d, %5d] training loss: %.8f' % (epoch + 1, i + 1, running_loss / 2000))
                running_loss = 0.0
    correct = 0.
    total = 0.
    y_true = list()
    y_pred = list()

    with torch.no_grad():
        for data in testloader:
            indx = (data["object_number"] != -1).reshape(-1)
            if indx.sum() > 0:
                inputs, labels = data["image"][indx], data["label"][indx]

                inputs, labels = inputs.to(device), labels.to(device)     
                inputs = inputs.float()
                labels = labels.reshape(-1)

                outputs = model(inputs)
                pred = outputs.argmax(dim=1)
                _ , predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (labels.reshape(-1) == predicted).sum().item()
                for i in range(len(pred)):
                    y_true.append(labels[i].item())
                    y_pred.append(pred[i].item())

    print('Accuracy of the network on the %d test images: %d %%' % (len(test_dataset),
        100 * correct / total))
    print(classification_report(y_true, y_pred, target_names=class_names, digits=4))
    f1_score_original = f1_score(y_true, y_pred, average=None, labels=np.arange(num_classes))
    df = pd.DataFrame(np.atleast_2d(f1_score_original), columns=class_names)
    print(df.to_string())

Start validation


  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)


epoch0
epoch1
epoch2
epoch3
epoch4
Accuracy of the network on the 10427 test images: 5 %
                   precision    recall  f1-score   support

          unknown     0.0000    0.0000    0.0000       412
           CD4+ T     1.0000    0.0035    0.0070      1428
           CD8+ T     0.0000    0.0000    0.0000       615
 CD15+ neutrophil     0.0000    0.0000    0.0000      6103
   CD14+ monocyte     0.0000    0.0000    0.0000       445
          CD19+ B     0.0000    0.0000    0.0000       318
         CD56+ NK     0.0000    0.0000    0.0000       230
              NKT     0.0000    0.0000    0.0000       331
       eosinophil     0.0523    1.0000    0.0994       545

         accuracy                         0.0527     10427
        macro avg     0.1169    0.1115    0.0118     10427
     weighted avg     0.1397    0.0527    0.0062     10427

    unknown    CD4+ T   CD8+ T   CD15+ neutrophil   CD14+ monocyte   CD19+ B   CD56+ NK   NKT   eosinophil
0       0.0  0.006978      0.0    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


epoch0
epoch1
epoch2
epoch3
epoch4
Accuracy of the network on the 10427 test images: 6 %
                   precision    recall  f1-score   support

          unknown     0.0000    0.0000    0.0000       414
           CD4+ T     0.9839    0.0404    0.0777      1509
           CD8+ T     0.8421    0.0530    0.0997       604
 CD15+ neutrophil     1.0000    0.0003    0.0006      6163
   CD14+ monocyte     0.0000    0.0000    0.0000       423
          CD19+ B     0.0000    0.0000    0.0000       287
         CD56+ NK     0.0000    0.0000    0.0000       203
              NKT     0.2079    0.1775    0.1915       355
       eosinophil     0.0468    1.0000    0.0894       469

         accuracy                         0.0601     10427
        macro avg     0.3423    0.1412    0.0510     10427
     weighted avg     0.7914    0.0601    0.0279     10427

    unknown    CD4+ T    CD8+ T   CD15+ neutrophil   CD14+ monocyte   CD19+ B   CD56+ NK       NKT   eosinophil
0       0.0  0.077658  0.0996

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


epoch0
epoch1
epoch2
epoch3
epoch4
Accuracy of the network on the 10426 test images: 5 %
                   precision    recall  f1-score   support

          unknown     0.0000    0.0000    0.0000       460
           CD4+ T     0.0000    0.0000    0.0000      1539
           CD8+ T     0.0000    0.0000    0.0000       569
 CD15+ neutrophil     0.0000    0.0000    0.0000      6059
   CD14+ monocyte     0.0000    0.0000    0.0000       414
          CD19+ B     0.0000    0.0000    0.0000       284
         CD56+ NK     0.0000    0.0000    0.0000       236
              NKT     0.2400    0.0175    0.0327       342
       eosinophil     0.0503    1.0000    0.0958       523

         accuracy                         0.0507     10426
        macro avg     0.0323    0.1131    0.0143     10426
     weighted avg     0.0104    0.0507    0.0059     10426

    unknown   CD4+ T   CD8+ T   CD15+ neutrophil   CD14+ monocyte   CD19+ B   CD56+ NK       NKT   eosinophil
0       0.0      0.0      0.0  

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
