In [None]:
from datetime import datetime
import itertools
import logging
import numpy as np
import os
from sklearn.datasets import make_classification
from sklearn.metrics import classification_report
from skorch import NeuralNetClassifier
from skorch.callbacks import LRScheduler
from skorch.helper import predefined_split
from skorch.callbacks import Checkpoint, TrainEndCheckpoint
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

In [None]:
import iflai
from iflai.dl.util import calculate_weights, train_validation_test_split, get_statistics
from iflai.dl.dataset import DatasetGenerator
from iflai.dl.models import PretrainedModel, resnet18

In [None]:
seed_value = 42

# Seed every random number generator the notebook relies on so runs are repeatable.
# NOTE: assigning PYTHONHASHSEED at runtime does not change hash randomization of the
# already-running kernel; it only affects child processes started afterwards
# (NOTE(review): presumably meant for DataLoader worker processes — confirm).
os.environ['PYTHONHASHSEED'] = str(seed_value)
import random
random.seed(seed_value)

np.random.seed(seed_value)
# torch.manual_seed seeds the torch RNG on all devices; the trailing ';' keeps the
# returned Generator object from being echoed as cell output.
torch.manual_seed(seed_value);

#### Define all necessary parameters

In [None]:
dataset_name = "wbc"
num_of_all_channels = 12
# Portable relative path. The former literal "..\..\data/WBC" relied on invalid "\."
# escape sequences (DeprecationWarning / SyntaxWarning) and mixed separators, so it
# only resolved on Windows.
path_to_data = os.path.join("..", "..", "data", "WBC")
model_dir = "models"
log_dir = "logs"
# Passed to DatasetGenerator as ``scaling_factor``.
# NOTE(review): 4095 = 2**12 - 1, presumably the maximum raw 12-bit intensity — confirm.
scaling_factor = 4095.
reshape_size = 64
# Random augmentations for training images; normalization is appended later,
# once the per-channel statistics are known.
train_transform_init = [transforms.RandomVerticalFlip(), transforms.RandomHorizontalFlip(), transforms.RandomRotation(45)]
# Validation / test images receive no augmentation.
test_transform_init = list()

# Channel indices 0..num_of_all_channels-1 and their display names ("Ch0", "Ch1", ...).
all_channels = np.arange(num_of_all_channels)
channels = np.asarray(["Ch{}".format(idx) for idx in all_channels])

# Number of times a model is retrained for the same set of channels.
n_retrain = 5
# How many channels should be removed from the dataset.
#number_removed_channels = 3

In [None]:
# Batch size, DataLoader worker count and torch device used by the cells below.
batch_size, num_workers = 64, 2
device = "cpu"

In [None]:
# Optimization hyperparameters shared by every trained model.
number_epochs = 10
lr = 0.001
momentum = 0.9
optimizer = optim.SGD
# skorch wrapper around torch's StepLR: multiply the learning rate by 0.5 every 7 scheduler steps.
lrscheduler = LRScheduler(policy='StepLR', step_size=7, gamma=0.5)

In [None]:
# Make sure the checkpoint and log folders exist (no error if they already do).
for _output_dir in (model_dir, log_dir):
    os.makedirs(_output_dir, exist_ok=True)

In [None]:
# Initialize logging: experiment results are written to a timestamped file in ``log_dir``.
now = datetime.now()
timestamp = datetime.timestamp(now)
log_path = os.path.join(log_dir, 'remove_and_retrain_{}_{}.txt'.format(dataset_name, timestamp))
# force=True (Python 3.8+) removes handlers already attached to the root logger
# (e.g. from an earlier run of this cell, or possibly installed by the notebook
# kernel). Without it basicConfig silently does nothing in that case and the
# results never reach this file.
logging.basicConfig(filename=log_path, level=logging.DEBUG, force=True)

#### Load data

In [None]:
%time

metadata = iflai.metadata_generator(path_to_data)

In [None]:
# Drop samples labelled "unknown". Labels may be stored as bytes (the label-map cell
# decodes them with .decode("utf-8")), in which case comparing against the str
# "unknown" alone never matches and nothing was filtered; check both spellings.
indx = ~metadata["label"].isin([b"unknown", "unknown"])
metadata = metadata.loc[indx, :].reset_index(drop=True)

In [None]:
# Map each distinct label (in sorted order) to a contiguous class id 0..num_classes-1.
unique_labels = sorted(set(metadata["label"]))
label_map = dict(zip(unique_labels, np.arange(len(unique_labels))))
num_classes = len(label_map)
# Human-readable class names in class-id order; labels may be bytes or str.
class_names_targets = [c.decode("utf-8") if isinstance(c, bytes) else str(c) for c in label_map.keys()]

In [None]:
def split_load_normalize_data(random_state=seed_value, selected_channels=[]):
    """Split ``metadata`` into train/validation/test sets and build normalized datasets.

    The split comes from ``train_validation_test_split`` applied to ``metadata.index``
    together with ``metadata["label"]`` (NOTE(review): presumably for stratification —
    confirm in iflai). Per-channel normalization statistics are computed on the
    training split only.

    Args:
        random_state: seed forwarded to ``train_validation_test_split``.
        selected_channels: indices of the image channels to keep; also used to index
            the per-channel statistics. Not mutated.

    Returns:
        ``(train_dataset, validation_dataset, test_dataset)`` as ``DatasetGenerator``
        instances. Training applies the random augmentations of
        ``train_transform_init`` followed by normalization; validation and test apply
        ``test_transform_init`` followed by normalization.
    """
    # Honor the ``random_state`` argument (the global ``seed_value`` used to be passed
    # here, silently ignoring the parameter).
    train_index, validation_index, test_index = train_validation_test_split(
        metadata.index, metadata["label"], random_state=random_state)

    def _make_dataset(index, transform_list):
        # Build a DatasetGenerator over the given metadata rows with the given transforms.
        return DatasetGenerator(metadata=metadata.loc[index, :],
                                label_map=label_map,
                                selected_channels=selected_channels,
                                scaling_factor=scaling_factor,
                                reshape_size=reshape_size,
                                transform=transforms.Compose(transform_list))

    # Calculate statistics on the *un-augmented* training data so they are deterministic
    # and not skewed by the zero-filled corners RandomRotation introduces.
    stats_loader = DataLoader(_make_dataset(train_index, list(test_transform_init)),
                              batch_size=batch_size, shuffle=False, num_workers=num_workers)
    statistics = get_statistics(stats_loader, selected_channels=selected_channels)

    # Normalize with the channel mean AND standard deviation (the mean used to be passed
    # as ``std`` as well).
    # NOTE(review): indexing by ``selected_channels`` assumes the statistics are indexed
    # by absolute channel id — confirm against get_statistics.
    normalize = transforms.Normalize(mean=statistics["mean"][selected_channels],
                                     std=statistics["std"][selected_channels])
    # New lists, so repeated calls never accumulate Normalize steps in the module-level lists.
    train_transform = train_transform_init + [normalize]
    test_transform = test_transform_init + [normalize]

    train_dataset = _make_dataset(train_index, train_transform)
    validation_dataset = _make_dataset(validation_index, test_transform)
    test_dataset = _make_dataset(test_index, test_transform)
    return train_dataset, validation_dataset, test_dataset

In [None]:
def train_model(train_dataset, validation_dataset, num_channels, selected_channels, seed):
    """Train a ``PretrainedModel`` with skorch and checkpoint its best parameters.

    Args:
        train_dataset: dataset used for fitting; labels come from the dataset itself,
            so ``y=None`` is passed to ``fit``.
        validation_dataset: held-out dataset used through ``predefined_split``; the
            checkpoint is written whenever its loss reaches a new best.
        num_channels: number of input channels the network is built for.
        selected_channels: channel indices, used only to name the checkpoint file.
        seed: retrain iteration identifier, used only to name the checkpoint file.

    Returns:
        The checkpoint file name, relative to ``model_dir``.
    """
    model_saved_name = '{}_net_{}_seed_{}.pth'.format(dataset_name, '_'.join(map(str, selected_channels)), seed)
    # Save into ``model_dir`` (previously a hard-coded 'models'): that is where
    # load_and_eval_model and the cleanup code look for the file.
    checkpoint = Checkpoint(f_params=model_saved_name, monitor='valid_loss_best', dirname=model_dir)
    net = NeuralNetClassifier(
        PretrainedModel,
        criterion=nn.CrossEntropyLoss,
        lr=lr,
        batch_size=batch_size,
        max_epochs=number_epochs,
        # skorch forwards every module__* entry as a keyword argument to PretrainedModel.
        module__output_features=num_classes,
        module__num_classes=num_classes,
        module__num_channels=num_channels,
        optimizer=optimizer,
        optimizer__momentum=momentum,
        # NOTE(review): training batches keep the same order every epoch; consider True.
        iterator_train__shuffle=False,
        iterator_train__num_workers=num_workers,
        iterator_valid__shuffle=False,
        iterator_valid__num_workers=num_workers,
        # NOTE(review): the module-level ``lrscheduler`` instance is shared by every net.
        callbacks=[lrscheduler, checkpoint],
        train_split=predefined_split(validation_dataset),
        # Train on the device configured in the notebook settings.
        device=device,
    )
    net.fit(train_dataset, y=None)

    return model_saved_name

In [None]:
def load_and_eval_model(num_channels, test_dataset, path_to_the_cp=""):
    """Load a checkpoint written by ``train_model`` and log its test-set report.

    Args:
        num_channels: number of input channels the checkpointed network was built with.
        test_dataset: dataset whose items provide inputs at position 0 and labels at
            position 1.
        path_to_the_cp: checkpoint file name, relative to ``model_dir``.

    The sklearn ``classification_report`` (4 digits) is logged at INFO level;
    nothing is returned.
    """
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    # Rebuild the network with the same keyword arguments skorch used in train_model
    # (module__*), so the saved state_dict matches the architecture.
    model = PretrainedModel(num_classes=num_classes, num_channels=num_channels,
                            output_features=num_classes)
    # map_location lets a checkpoint written on another device (e.g. a GPU) load here.
    state_dict = torch.load(os.path.join(model_dir, path_to_the_cp), map_location=device)
    model.load_state_dict(state_dict)
    model = model.to(device)
    # Inference mode: layers such as BatchNorm/Dropout (if present) must not behave as in
    # training while evaluating.
    model.eval()

    y_true = []
    y_pred = []
    with torch.no_grad():
        for batch in test_loader:
            inputs = batch[0].to(device).float()
            labels = batch[1].to(device).long().reshape(-1)
            predicted = model(inputs).argmax(dim=1)
            y_true.extend(labels.tolist())
            y_pred.extend(predicted.tolist())

    # Explicit labels keep the report aligned with ``class_names_targets`` even when a
    # class is absent from both y_true and y_pred (class ids are 0..num_classes-1).
    logging.info(classification_report(y_true, y_pred,
                                       labels=list(range(num_classes)),
                                       target_names=class_names_targets, digits=4))

In [None]:
def findsubsets(s, n_elements):
    """Return every ``n_elements``-sized combination of the items of ``s`` as a list.

    Combinations are tuples produced in ``itertools.combinations`` order, which
    follows the iteration order of ``s``.
    """
    combos = itertools.combinations(s, n_elements)
    return [combo for combo in combos]

In [None]:
s = set(all_channels)
# Numbers of channels to drop from the full channel set.
# NOTE(review): 3 is absent — presumably it was run separately (cf. the commented-out
# ``number_removed_channels = 3`` in the parameter cell); confirm before relying on a
# complete sweep.
removed_channel_counts = [1, 2, 4, 5, 6, 7, 8, 9, 10, 11]
for number_removed_channels in removed_channel_counts:
    # sorted() makes the combination order explicit instead of depending on set iteration order.
    all_combinations = findsubsets(sorted(s), num_of_all_channels - number_removed_channels)
    for channel_comb in all_combinations:
        channel_comb = np.asarray(channel_comb)
        num_channels = len(channel_comb)
        # The split is fixed by ``seed_value``, so all retrains of this channel set can
        # share one set of datasets; building them once avoids recomputing the
        # normalization statistics n_retrain times.
        train_dataset, val_dataset, test_dataset = split_load_normalize_data(random_state=seed_value, selected_channels=channel_comb)
        for n in range(n_retrain):
            logging.info("Train new model: iteration {}, channels: {}".format(str(n), '_'.join(map(str, channel_comb))))
            model_path = train_model(train_dataset, val_dataset, num_channels, channel_comb, n)
            load_and_eval_model(num_channels, test_dataset, model_path)
            # Only the logged report is kept; delete the checkpoint to save disk space.
            os.remove(os.path.join(model_dir, model_path))