In [1]:
import numpy as np
import torch.nn as nn
from skorch import NeuralNetClassifier

ModuleNotFoundError: No module named 'skorch'

In [3]:
import umap

In [None]:
import warnings
warnings.filterwarnings("ignore")
import os
import torch
from torchvision import transforms
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader, Dataset

In [None]:
from sklearn.metrics import classification_report

In [None]:
from imblearn.over_sampling import RandomOverSampler
from iflai.dl.util import read_data, get_statistics_h5, calculate_weights
from iflai.dl.dataset import train_validation_test_split_wth_augmentation, Dataset_Generator_Preprocessed_h5
from iflai.ml.feature_extractor import AmnisData
from iflai.dl.models import PretrainedModel

In [None]:
from skorch.callbacks import LRScheduler, Checkpoint
import torch.optim as optim
from skorch.helper import predefined_split

In [None]:
# Reproducibility: one seed shared by every random number generator used below.
import random

seed_value = 42

# Environment variables must be strings, hence the explicit conversion.
os.environ['PYTHONHASHSEED'] = str(seed_value)

# Seed Python's and NumPy's global generators with the same value.
for _seed_fn in (random.seed, np.random.seed):
    _seed_fn(seed_value)

# Seed PyTorch last; this stays the final expression so the cell output is unchanged.
torch.manual_seed(seed_value)

In [None]:
# --- Dataset configuration -------------------------------------------------
dataset_name = "wbc"
path_to_data = "../../data/WBC"

# Use all twelve channels, indexed 0..11.
only_channels = list(range(12))
num_channels = len(only_channels)

# Values forwarded to the dataset constructors as scaling_factor / reshape_size.
scaling_factor = 255.
reshape_size = 64
# Training-time augmentation: random vertical/horizontal flips followed by a
# random rotation configured with 45 degrees.
train_transform = transforms.Compose([
    transforms.RandomVerticalFlip(),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(45),
])

# Validation/test data go through an empty pipeline (no augmentation).
test_transform = transforms.Compose([])
# --- Loader / device configuration -----------------------------------------
batch_size = 256
num_workers = 2

# Prefer the GPU, but fall back to the CPU when no CUDA device is available so
# that `.to(dev)` and the network's `device=dev` do not fail on CPU-only machines.
dev = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# amnis_data = AmnisData(path_to_data, None)

In [None]:
# Load the dataset description found under `path_to_data`. The four results are:
#   X, y        - passed together to the train/validation/test split below;
#   CLASS_NAMES - not referenced again in the visible cells;
#   data_map    - mapping used as data_map.get(y[i]) to turn an entry of `y`
#                 into a training label; its number of keys is the class count.
# NOTE(review): the exact structure of X and y is defined by read_data — confirm there.
X, y, CLASS_NAMES, data_map = read_data(path_to_data)

In [None]:
# One output class per entry in the label mapping.
num_classes = len(data_map)

# Index sets for the three partitions. only_classes=None presumably means
# "keep every class" — confirm against the helper's documentation.
train_indx, validation_indx, test_indx = train_validation_test_split_wth_augmentation(
    X,
    y,
    only_classes=None,
)

In [None]:
# First pass over the training split. No means/stds are supplied here: this
# dataset/loader pair is what the statistics computation in the next cell consumes.
_stats_dataset_kwargs = dict(
    path_to_data=path_to_data,
    set_indx=train_indx,
    scaling_factor=scaling_factor,
    reshape_size=reshape_size,
    transform=train_transform,
    data_map=data_map,
    only_channels=only_channels,
    num_channels=num_channels,
)
train_dataset = Dataset_Generator_Preprocessed_h5(**_stats_dataset_kwargs)

# Sequential (unshuffled) loader over that dataset.
trainloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
)

In [None]:
# Compute statistics over the training loader (whose dataset was built without
# means/stds). The result is indexed with "mean" and "std" when the normalised
# datasets are constructed further down.
# NOTE(review): the third positional argument is passed as None; its meaning is
# not visible in this notebook — confirm against get_statistics_h5's signature.
statistics = get_statistics_h5(trainloader, only_channels, None, num_channels)

In [None]:
# Labels of the training samples, looked up through data_map.
y_train = [data_map.get(y[sample_idx]) for sample_idx in train_indx]

# Class weights are computed from the labels *before* oversampling and moved to
# the training device; they are later handed to the loss via criterion__weight.
weights = calculate_weights(y_train)
class_weights = torch.FloatTensor(weights).to(dev)

# Randomly oversample the training indices. The sampler is fed a 2-D array, so
# the indices are reshaped into a single column first.
oversample = RandomOverSampler(sampling_strategy='all', random_state=seed_value)
_indx_column = np.asarray(train_indx).reshape(-1, 1)
train_indx, y_train = oversample.fit_resample(_indx_column, np.asarray(y_train))

# Flatten the single column back into a 1-D index array.
train_indx = train_indx[:, 0]

# Rebuild the label list from the resampled indices (plain Python list again).
y_train = [data_map.get(y[sample_idx]) for sample_idx in train_indx]

In [None]:
def _build_normalised_dataset(set_indx, transform):
    """Create a dataset over `set_indx` using the shared notebook configuration.

    Every split uses the same data path, scaling, reshape size, channel
    selection, label map and the statistics' "mean"/"std" entries; only the
    index set and the transform pipeline differ between splits.
    """
    return Dataset_Generator_Preprocessed_h5(
        path_to_data=path_to_data,
        set_indx=set_indx,
        scaling_factor=scaling_factor,
        reshape_size=reshape_size,
        transform=transform,
        data_map=data_map,
        only_channels=only_channels,
        num_channels=num_channels,
        means=statistics["mean"],
        stds=statistics["std"],
        return_only_image=True,
    )


# Training split (oversampled indices) with augmentation.
train_dataset = _build_normalised_dataset(train_indx, train_transform)

# Validation and test splits use the empty (non-augmenting) transform.
validation_dataset = _build_normalised_dataset(validation_indx, test_transform)
test_dataset = _build_normalised_dataset(test_indx, test_transform)

In [None]:
# Learning-rate schedule: StepLR policy with step_size=7 and gamma=0.5.
lrscheduler = LRScheduler(
    policy='StepLR',
    gamma=0.5,
    step_size=7,
)

# Checkpoint callback monitoring 'valid_loss_best'; parameters are written to
# the file named by f_params.
checkpoint = Checkpoint(
    monitor='valid_loss_best',
    f_params='wbs_net_all.pth',
)

In [None]:
# skorch classifier wrapping PretrainedModel.
# batch_size and the loader worker counts reference the notebook-level
# configuration (batch_size, num_workers) instead of repeating the literals
# 256 and 2, so changing the config cell changes training consistently.
# NOTE(review): iterator_train__shuffle=False trains on the oversampled indices
# in their stored order; if the resampled duplicates are not interleaved with the
# original samples, late batches may be dominated by a few classes — confirm
# this is intended.
net = NeuralNetClassifier(
    PretrainedModel,
    # Arguments forwarded to the module constructor.
    module__output_features=num_classes,
    module__num_channels=num_channels,
    # Class-weighted cross-entropy loss.
    criterion=nn.CrossEntropyLoss,
    criterion__weight=class_weights,
    # Optimisation settings.
    optimizer=optim.Adam,
    lr=1e-5,
    max_epochs=10,
    batch_size=batch_size,
    # Data-loading settings for the training and validation iterators.
    iterator_train__shuffle=False,
    iterator_train__num_workers=num_workers,
    iterator_valid__shuffle=False,
    iterator_valid__num_workers=num_workers,
    callbacks=[lrscheduler, checkpoint],
    # Validate on the fixed validation_dataset instead of an internal split.
    train_split=predefined_split(validation_dataset),
    device=dev,
)

In [None]:
# Train the network on the (oversampled) training dataset. No separate target is
# passed (y=None) — presumably each dataset item supplies its own label; confirm
# against Dataset_Generator_Preprocessed_h5. Validation uses the predefined
# validation_dataset configured on the net.
net.fit(train_dataset, y=None)

In [None]:
# Predict on the held-out test dataset; the predictions feed the
# classification report in the next cell.
y_pred_net = net.predict(test_dataset)

In [None]:
# `y_true` and `class_names` were never defined in this notebook, so this cell
# failed on a fresh kernel. Both are derived here from data that already exists.

# Ground-truth labels for the test split, encoded exactly like the training
# labels (data_map lookup on the raw label of each test index).
# NOTE(review): assumes net.predict(test_dataset) returns predictions in
# test_indx order — confirm against the dataset implementation.
y_true = [data_map.get(y[i]) for i in test_indx]

# Pair every label value with its name and order by label value, so that
# `labels` and `target_names` are guaranteed to line up in the report.
_names_by_label = sorted(data_map.items(), key=lambda item: item[1])
class_labels = [label for _, label in _names_by_label]
class_names = [str(name) for name, _ in _names_by_label]

print(classification_report(y_true, y_pred_net, labels=class_labels,
                            target_names=class_names, digits=4))