In [13]:
import torch
import torch.nn as nn
import torch.utils.data as data
import matplotlib.pyplot as plt
import anndata as ad
import scanpy as sc
import numpy as np
from tqdm.notebook import tqdm
import seaborn as sns
import pandas as pd
import shap
from sklearn.metrics import confusion_matrix, f1_score, precision_score, log_loss, accuracy_score, roc_auc_score
from sklearn.model_selection import train_test_split
from torch.utils.tensorboard import SummaryWriter
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import OneHotEncoder

In [2]:
# Run on GPU CUDA_DEVICE_INDEX when it exists, otherwise on the default GPU, otherwise on the CPU.
# torch.cuda.set_device must only be called when CUDA (and that index) is actually available.
CUDA_DEVICE_INDEX = 2
DATA_PATH = '/home/brunopsz/Data/GSE155249_COUNTS_NOT_NORMALIZED.h5ad'  # TODO: make configurable / relative

if torch.cuda.is_available() and torch.cuda.device_count() > CUDA_DEVICE_INDEX:
    torch.cuda.set_device(CUDA_DEVICE_INDEX)
    device = torch.device("cuda", CUDA_DEVICE_INDEX)
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device", device)
adata = ad.read_h5ad(DATA_PATH)
adata

Device cuda


AnnData object with n_obs × n_vars = 77146 × 20692
    obs: 'barcode_name', 'Sample', 'Cluster'
    var: 'gene_name'
    layers: 'counts'

In [3]:
# Library-size normalisation, log1p transform and z-scoring of adata.X, all in place.
# NOTE(review): scaling uses statistics from every cell, including cells that later land in
# the CV test folds, and re-running this cell re-applies all three transforms to adata.X.
for preprocessing_step in (sc.pp.normalize_total, sc.pp.log1p, sc.pp.scale):
    preprocessing_step(adata)

In [None]:
# Select the 5000 most variable genes. The seurat_v3 flavour works on raw counts, so the
# "counts" layer is used instead of the normalised/scaled adata.X; genes are ranked within
# each "Sample" batch. subset=True drops every other gene from adata in place.
# NOTE(review): this cell is marked `In [None]` (never executed), so the saved results
# below were produced on all genes — confirm whether the subset was intended.
sc.pp.highly_variable_genes(
    adata,
    layer="counts",
    flavor="seurat_v3",
    batch_key="Sample",
    n_top_genes=5000,
    subset=True,
)

In [4]:
# One-hot encode the 'Cluster' annotation: one dense column per observed cell type.
cluster_labels = adata.obs['Cluster'].to_numpy().reshape(-1, 1)
try:
    # scikit-learn >= 1.2 calls this argument `sparse_output`; `sparse` was removed in 1.4.
    ohe = OneHotEncoder(handle_unknown='ignore', sparse_output=False)
except TypeError:
    # Older scikit-learn only accepts `sparse`.
    ohe = OneHotEncoder(handle_unknown='ignore', sparse=False)
ohe.fit(cluster_labels)
print(ohe.categories_)

print('\n')

labels_one_hot = ohe.transform(cluster_labels)
print(labels_one_hot)

[array(['AT2, AT1 cells', 'B cells', 'CD4 CM T cells',
       'CD4 cytotoxic T cells', 'CD4 prolif. T cells',
       'CD8 cytotoxic T cells', 'CD8 cytotoxic TRM T cells',
       'CD8 prolif. T cells', 'Ciliated cells', 'Club, Basal cells',
       'DC1', 'DC2', 'Infected AT2, AT1 cells', 'Ionocytes', 'Mast cells',
       'Migratory DC', 'Mixed myeloid', 'MoAM1', 'MoAM2', 'MoAM3',
       'MoAM4', 'Plasma cells', 'Prolif. AM', 'TRAM1', 'TRAM2', 'Treg',
       'iNKT cells', 'pDC'], dtype=object)]


[[0. 0. 0. ... 0. 1. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]


In [5]:
class SimpleRNAseqClassifier(nn.Module):
    """One-hidden-layer MLP (Linear -> ReLU -> Linear) producing unnormalised class scores."""

    def __init__(self, num_inputs, num_hidden, num_outputs):
        super().__init__()
        # Registration order is kept as-is: it fixes parameter order and the order in
        # which the layers draw their random initial weights.
        self.linear1 = nn.Linear(num_inputs, num_hidden)
        self.act_fn = nn.ReLU()
        self.linear2 = nn.Linear(num_hidden, num_outputs)

    def forward(self, x):
        """Map inputs whose last dimension is num_inputs to num_outputs logits."""
        hidden = self.act_fn(self.linear1(x))
        return self.linear2(hidden)
    
class scRNAseqDataset(data.Dataset):
    """Map-style dataset pairing row ``idx`` of ``data`` with row ``idx`` of ``labels``."""

    def __init__(self, data, labels):
        # Note: the `data` parameter shadows the torch.utils.data module inside __init__ only.
        super().__init__()
        # Fail fast on misaligned inputs: a shorter `labels` would otherwise surface as an
        # IndexError mid-epoch, and a longer one would be silently truncated.
        if len(data) != len(labels):
            raise ValueError(
                f"data has {len(data)} rows but labels has {len(labels)}"
            )
        self.data = data
        self.label = labels

    def __len__(self):
        # Number of samples = size of the first dimension of `data`.
        return self.data.shape[0]

    def __getitem__(self, idx):
        # Return the (sample, label) pair at position `idx`.
        return self.data[idx], self.label[idx]

In [6]:
cell_types = adata.obs['Cluster'].unique()
number_of_cell_types = len(cell_types)
# The classifier head, index_to_cell_type and the one-hot targets all assume this width.
assert number_of_cell_types == labels_one_hot.shape[1], (
    f"{number_of_cell_types} clusters but {labels_one_hot.shape[1]} one-hot columns"
)
print(f"Number of cell types = {number_of_cell_types}")

Number of cell types = 28


In [8]:
def _predicted_and_true_indices(y_pred, y_test):
    """Return (predicted class indices, true class indices) for one batch.

    `y_pred` holds per-class scores (logits or log-probabilities) and `y_test` one-hot
    targets, both with classes along dim 1. log_softmax is monotonic, so applying it
    before the argmax would not change the result.
    """
    return torch.argmax(y_pred, dim=1), torch.argmax(y_test, dim=1)


def multi_acc(y_pred, y_test):
    """Batch accuracy in percent, rounded to the nearest integer, as a 0-d tensor."""
    y_pred_tags, actual = _predicted_and_true_indices(y_pred, y_test)
    correct_pred = (y_pred_tags == actual).float()
    acc = correct_pred.sum() / len(correct_pred)
    return torch.round(acc * 100)


def multi_f1(y_pred, y_test):
    """Support-weighted F1 for one batch, scored over every class column of `y_pred`.

    Classes that occur in `y_test` but were never predicted count as F1 = 0 instead of
    being dropped from the average (which inflated the score).
    """
    y_pred_tags, actual = _predicted_and_true_indices(y_pred, y_test)
    all_classes = np.arange(y_pred.shape[1])
    return f1_score(actual.cpu().numpy(), y_pred_tags.cpu().numpy(),
                    average='weighted', labels=all_classes, zero_division=0)


def multi_precision(y_pred, y_test):
    """Support-weighted precision for one batch, scored over every class column of `y_pred`.

    Classes with no predictions contribute precision 0 (weighted by their support)
    rather than being excluded from the average.
    """
    y_pred_tags, actual = _predicted_and_true_indices(y_pred, y_test)
    all_classes = np.arange(y_pred.shape[1])
    return precision_score(actual.cpu().numpy(), y_pred_tags.cpu().numpy(),
                           average='weighted', labels=all_classes, zero_division=0)

In [9]:
# Metric histories keyed by split; the training loop appends one value per epoch to 'train'.
accuracy_stats, loss_stats, f1_scores, precision = (
    {split: [] for split in ('train', 'val')} for _ in range(4)
)

In [10]:
LEARNING_RATE = 0.0007  # SGD step size used for every CV fold

# CrossEntropyLoss expects unnormalised scores and applies log_softmax internally;
# Module.to returns the module itself, so it can be chained.
loss_module = nn.CrossEntropyLoss().to(device)
loss_module

CrossEntropyLoss()

In [11]:
def index_to_cell_type(index):
    """Return the cell-type name of one-hot column `index` of the fitted encoder `ohe`.

    Equivalent to inverse-transforming a one-hot row with a 1 at `index`, but reads the
    name straight from `ohe.categories_` instead of building and decoding an array per call.
    Raises IndexError for an out-of-range index.
    """
    return ohe.categories_[0][index]

In [17]:
EPOCHS = 50
BATCH_SIZE = 32
N_SPLITS = 10
CV_SEED = 77
LOG_DIR = 'runs/model_CV_on_base_model'

skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=CV_SEED)
cv_test_accuracy = []
cv_test_precision = []
cv_test_f1 = []
cv_test_roc = []
# Cell-type name for each one-hot column / model output index.
cell_type_names = ohe.categories_[0]


def _train_one_epoch(model, loader, optimizer, writer, log_graph=False):
    """Run one SGD pass over `loader`; return per-batch means of (loss, accuracy %, F1, precision)."""
    model.train()
    totals = np.zeros(4)
    for batch_idx, (X_batch, y_batch) in enumerate(loader):
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        if log_graph and batch_idx == 0:
            writer.add_graph(model, X_batch)

        optimizer.zero_grad()
        # Raw logits go straight into the loss: nn.CrossEntropyLoss applies log_softmax
        # itself, so an extra log_softmax here would be redundant.
        logits = model(X_batch)
        loss = loss_module(logits, y_batch)
        loss.backward()
        optimizer.step()

        # float() rather than .item(): sklearn metrics may return plain Python floats.
        totals += (
            loss.item(),
            float(multi_acc(logits, y_batch)),
            float(multi_f1(logits, y_batch)),
            float(multi_precision(logits, y_batch)),
        )
    return totals / len(loader)


def _predict(model, loader):
    """Return (true indices, predicted indices, one-hot targets, softmax probabilities) over `loader`."""
    model.eval()
    true_idx, pred_idx, targets, probs = [], [], [], []
    # Inference only: no autograd graph is needed (or kept alive) for the test batches.
    with torch.no_grad():
        for X_batch, y_batch in loader:
            batch_probs = torch.softmax(model(X_batch.to(device)), dim=1).cpu()
            probs.append(batch_probs.numpy())
            targets.append(y_batch.numpy())
            true_idx.append(y_batch.argmax(dim=1).numpy())
            pred_idx.append(batch_probs.argmax(dim=1).numpy())
    return (np.concatenate(true_idx), np.concatenate(pred_idx),
            np.concatenate(targets), np.concatenate(probs))


for i, (train_index, test_index) in enumerate(skf.split(adata.X, adata.obs['Cluster'])):
    # Seed per fold so weight initialisation and batch shuffling are reproducible.
    torch.manual_seed(CV_SEED + i)
    model = SimpleRNAseqClassifier(num_inputs=adata.shape[1], num_hidden=64,
                                   num_outputs=number_of_cell_types).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)

    print(f'Begin training split {i}')
    train_dataset = scRNAseqDataset(torch.from_numpy(adata.X[train_index]), torch.from_numpy(labels_one_hot[train_index]))
    test_dataset = scRNAseqDataset(torch.from_numpy(adata.X[test_index]), torch.from_numpy(labels_one_hot[test_index]))

    train_loader = data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    # One TensorBoard writer per fold, in its own sub-directory, always closed.
    writer = SummaryWriter(f'{LOG_DIR}/fold_{i}')
    try:
        for epoch in tqdm(range(1, EPOCHS + 1)):
            epoch_loss, epoch_acc, epoch_f1, epoch_precision = _train_one_epoch(
                model, train_loader, optimizer, writer, log_graph=(epoch == 1))
            for history, tag, value in (
                (loss_stats, 'Loss/training', epoch_loss),
                (f1_scores, 'F1_Score/training', epoch_f1),
                (precision, 'Precision/training', epoch_precision),
                (accuracy_stats, 'Accuracy/training', epoch_acc),
            ):
                history['train'].append(float(value))
                writer.add_scalar(tag, float(value), global_step=epoch)
            # TODO: no validation split is evaluated per epoch; the 'val' histories stay empty.
            print(f'({i}) Epoch {epoch:03}: | Train Loss: {epoch_loss:.5f} | Train Acc: {epoch_acc:.3f}')
    finally:
        writer.close()

    # TEST on the held-out fold.
    true_idx, pred_idx, y_true_prob, y_pred_prob = _predict(model, test_loader)
    y_true_cells = list(cell_type_names[true_idx])
    y_pred_cells = list(cell_type_names[pred_idx])
    acc = accuracy_score(y_true_cells, y_pred_cells)

    print(f"({i}) test accuracy = {acc}")
    cv_test_accuracy.append(acc)
    # zero_division=0 makes the default "set to 0.0 for classes never predicted" explicit.
    cv_test_precision.append(precision_score(y_true_cells, y_pred_cells, average='weighted', zero_division=0))
    cv_test_f1.append(f1_score(y_true_cells, y_pred_cells, average='weighted', zero_division=0))
    cv_test_roc.append(roc_auc_score(y_true_prob, y_pred_prob, average="weighted", multi_class="ovr"))

Begin training split 0


  0%|          | 0/50 [00:00<?, ?it/s]

(0) Epoch 001: | Train Loss: 1.62181 | Train Acc: 57.725
(0) Epoch 002: | Train Loss: 0.71866 | Train Acc: 80.159
(0) Epoch 003: | Train Loss: 0.48861 | Train Acc: 86.442
(0) Epoch 004: | Train Loss: 0.37577 | Train Acc: 89.935
(0) Epoch 005: | Train Loss: 0.30633 | Train Acc: 92.065
(0) Epoch 006: | Train Loss: 0.25806 | Train Acc: 93.514
(0) Epoch 007: | Train Loss: 0.22174 | Train Acc: 94.676
(0) Epoch 008: | Train Loss: 0.19303 | Train Acc: 95.598
(0) Epoch 009: | Train Loss: 0.16963 | Train Acc: 96.324
(0) Epoch 010: | Train Loss: 0.15008 | Train Acc: 96.971
(0) Epoch 011: | Train Loss: 0.13350 | Train Acc: 97.505
(0) Epoch 012: | Train Loss: 0.11937 | Train Acc: 97.965
(0) Epoch 013: | Train Loss: 0.10721 | Train Acc: 98.284
(0) Epoch 014: | Train Loss: 0.09672 | Train Acc: 98.587
(0) Epoch 015: | Train Loss: 0.08765 | Train Acc: 98.844
(0) Epoch 016: | Train Loss: 0.07974 | Train Acc: 99.034
(0) Epoch 017: | Train Loss: 0.07281 | Train Acc: 99.206
(0) Epoch 018: | Train Loss: 0.

  0%|          | 0/50 [00:00<?, ?it/s]

(1) Epoch 001: | Train Loss: 1.62282 | Train Acc: 60.571
(1) Epoch 002: | Train Loss: 0.70221 | Train Acc: 79.847
(1) Epoch 003: | Train Loss: 0.48417 | Train Acc: 86.513
(1) Epoch 004: | Train Loss: 0.37430 | Train Acc: 90.014
(1) Epoch 005: | Train Loss: 0.30505 | Train Acc: 92.063
(1) Epoch 006: | Train Loss: 0.25648 | Train Acc: 93.581
(1) Epoch 007: | Train Loss: 0.21995 | Train Acc: 94.725
(1) Epoch 008: | Train Loss: 0.19098 | Train Acc: 95.652
(1) Epoch 009: | Train Loss: 0.16741 | Train Acc: 96.443
(1) Epoch 010: | Train Loss: 0.14779 | Train Acc: 97.100
(1) Epoch 011: | Train Loss: 0.13131 | Train Acc: 97.594
(1) Epoch 012: | Train Loss: 0.11725 | Train Acc: 98.018
(1) Epoch 013: | Train Loss: 0.10521 | Train Acc: 98.384
(1) Epoch 014: | Train Loss: 0.09483 | Train Acc: 98.685
(1) Epoch 015: | Train Loss: 0.08592 | Train Acc: 98.895
(1) Epoch 016: | Train Loss: 0.07811 | Train Acc: 99.091
(1) Epoch 017: | Train Loss: 0.07131 | Train Acc: 99.241
(1) Epoch 018: | Train Loss: 0.

  0%|          | 0/50 [00:00<?, ?it/s]

(2) Epoch 001: | Train Loss: 1.66664 | Train Acc: 61.141
(2) Epoch 002: | Train Loss: 0.67644 | Train Acc: 80.994
(2) Epoch 003: | Train Loss: 0.47277 | Train Acc: 86.759
(2) Epoch 004: | Train Loss: 0.36840 | Train Acc: 90.012
(2) Epoch 005: | Train Loss: 0.30189 | Train Acc: 92.153
(2) Epoch 006: | Train Loss: 0.25480 | Train Acc: 93.629
(2) Epoch 007: | Train Loss: 0.21914 | Train Acc: 94.739
(2) Epoch 008: | Train Loss: 0.19088 | Train Acc: 95.669
(2) Epoch 009: | Train Loss: 0.16764 | Train Acc: 96.439
(2) Epoch 010: | Train Loss: 0.14830 | Train Acc: 97.094
(2) Epoch 011: | Train Loss: 0.13187 | Train Acc: 97.625
(2) Epoch 012: | Train Loss: 0.11782 | Train Acc: 98.052
(2) Epoch 013: | Train Loss: 0.10572 | Train Acc: 98.390
(2) Epoch 014: | Train Loss: 0.09527 | Train Acc: 98.702
(2) Epoch 015: | Train Loss: 0.08623 | Train Acc: 98.927
(2) Epoch 016: | Train Loss: 0.07837 | Train Acc: 99.096
(2) Epoch 017: | Train Loss: 0.07149 | Train Acc: 99.255
(2) Epoch 018: | Train Loss: 0.

  0%|          | 0/50 [00:00<?, ?it/s]

(3) Epoch 001: | Train Loss: 1.61131 | Train Acc: 60.212
(3) Epoch 002: | Train Loss: 0.68407 | Train Acc: 81.022
(3) Epoch 003: | Train Loss: 0.47410 | Train Acc: 86.812
(3) Epoch 004: | Train Loss: 0.36526 | Train Acc: 90.124
(3) Epoch 005: | Train Loss: 0.29795 | Train Acc: 92.253
(3) Epoch 006: | Train Loss: 0.25101 | Train Acc: 93.759
(3) Epoch 007: | Train Loss: 0.21573 | Train Acc: 94.859
(3) Epoch 008: | Train Loss: 0.18772 | Train Acc: 95.781
(3) Epoch 009: | Train Loss: 0.16490 | Train Acc: 96.498
(3) Epoch 010: | Train Loss: 0.14577 | Train Acc: 97.068
(3) Epoch 011: | Train Loss: 0.12967 | Train Acc: 97.637
(3) Epoch 012: | Train Loss: 0.11586 | Train Acc: 98.016
(3) Epoch 013: | Train Loss: 0.10408 | Train Acc: 98.356
(3) Epoch 014: | Train Loss: 0.09389 | Train Acc: 98.664
(3) Epoch 015: | Train Loss: 0.08505 | Train Acc: 98.854
(3) Epoch 016: | Train Loss: 0.07736 | Train Acc: 99.028
(3) Epoch 017: | Train Loss: 0.07063 | Train Acc: 99.209
(3) Epoch 018: | Train Loss: 0.

  0%|          | 0/50 [00:00<?, ?it/s]

(4) Epoch 001: | Train Loss: 1.61725 | Train Acc: 59.700
(4) Epoch 002: | Train Loss: 0.70548 | Train Acc: 79.426
(4) Epoch 003: | Train Loss: 0.49001 | Train Acc: 86.126
(4) Epoch 004: | Train Loss: 0.37850 | Train Acc: 89.853
(4) Epoch 005: | Train Loss: 0.30846 | Train Acc: 92.087
(4) Epoch 006: | Train Loss: 0.25962 | Train Acc: 93.553
(4) Epoch 007: | Train Loss: 0.22289 | Train Acc: 94.747
(4) Epoch 008: | Train Loss: 0.19387 | Train Acc: 95.653
(4) Epoch 009: | Train Loss: 0.17010 | Train Acc: 96.418
(4) Epoch 010: | Train Loss: 0.15038 | Train Acc: 97.070
(4) Epoch 011: | Train Loss: 0.13358 | Train Acc: 97.609
(4) Epoch 012: | Train Loss: 0.11924 | Train Acc: 98.030
(4) Epoch 013: | Train Loss: 0.10694 | Train Acc: 98.389
(4) Epoch 014: | Train Loss: 0.09633 | Train Acc: 98.669
(4) Epoch 015: | Train Loss: 0.08715 | Train Acc: 98.913
(4) Epoch 016: | Train Loss: 0.07916 | Train Acc: 99.094
(4) Epoch 017: | Train Loss: 0.07215 | Train Acc: 99.251
(4) Epoch 018: | Train Loss: 0.

  0%|          | 0/50 [00:00<?, ?it/s]

(5) Epoch 001: | Train Loss: 1.62370 | Train Acc: 60.496
(5) Epoch 002: | Train Loss: 0.68716 | Train Acc: 81.075
(5) Epoch 003: | Train Loss: 0.48071 | Train Acc: 86.319
(5) Epoch 004: | Train Loss: 0.37376 | Train Acc: 89.739
(5) Epoch 005: | Train Loss: 0.30618 | Train Acc: 91.925
(5) Epoch 006: | Train Loss: 0.25861 | Train Acc: 93.509
(5) Epoch 007: | Train Loss: 0.22256 | Train Acc: 94.644
(5) Epoch 008: | Train Loss: 0.19406 | Train Acc: 95.623
(5) Epoch 009: | Train Loss: 0.17072 | Train Acc: 96.330
(5) Epoch 010: | Train Loss: 0.15123 | Train Acc: 96.943
(5) Epoch 011: | Train Loss: 0.13476 | Train Acc: 97.504
(5) Epoch 012: | Train Loss: 0.12069 | Train Acc: 97.931
(5) Epoch 013: | Train Loss: 0.10854 | Train Acc: 98.257
(5) Epoch 014: | Train Loss: 0.09804 | Train Acc: 98.591
(5) Epoch 015: | Train Loss: 0.08893 | Train Acc: 98.805
(5) Epoch 016: | Train Loss: 0.08100 | Train Acc: 99.008
(5) Epoch 017: | Train Loss: 0.07399 | Train Acc: 99.179
(5) Epoch 018: | Train Loss: 0.

Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.


  0%|          | 0/50 [00:00<?, ?it/s]

(6) Epoch 001: | Train Loss: 1.60170 | Train Acc: 60.921
(6) Epoch 002: | Train Loss: 0.67826 | Train Acc: 81.066
(6) Epoch 003: | Train Loss: 0.47220 | Train Acc: 86.834
(6) Epoch 004: | Train Loss: 0.36781 | Train Acc: 90.093
(6) Epoch 005: | Train Loss: 0.30196 | Train Acc: 92.197
(6) Epoch 006: | Train Loss: 0.25540 | Train Acc: 93.733
(6) Epoch 007: | Train Loss: 0.21996 | Train Acc: 94.789
(6) Epoch 008: | Train Loss: 0.19171 | Train Acc: 95.775
(6) Epoch 009: | Train Loss: 0.16859 | Train Acc: 96.475
(6) Epoch 010: | Train Loss: 0.14921 | Train Acc: 97.117
(6) Epoch 011: | Train Loss: 0.13279 | Train Acc: 97.613
(6) Epoch 012: | Train Loss: 0.11877 | Train Acc: 98.041
(6) Epoch 013: | Train Loss: 0.10670 | Train Acc: 98.347
(6) Epoch 014: | Train Loss: 0.09628 | Train Acc: 98.629
(6) Epoch 015: | Train Loss: 0.08723 | Train Acc: 98.854
(6) Epoch 016: | Train Loss: 0.07937 | Train Acc: 99.059
(6) Epoch 017: | Train Loss: 0.07251 | Train Acc: 99.241
(6) Epoch 018: | Train Loss: 0.

  0%|          | 0/50 [00:00<?, ?it/s]

(7) Epoch 001: | Train Loss: 1.64859 | Train Acc: 57.459
(7) Epoch 002: | Train Loss: 0.71121 | Train Acc: 79.875
(7) Epoch 003: | Train Loss: 0.49212 | Train Acc: 86.489
(7) Epoch 004: | Train Loss: 0.37949 | Train Acc: 89.905
(7) Epoch 005: | Train Loss: 0.30853 | Train Acc: 92.089
(7) Epoch 006: | Train Loss: 0.25831 | Train Acc: 93.616
(7) Epoch 007: | Train Loss: 0.22084 | Train Acc: 94.816
(7) Epoch 008: | Train Loss: 0.19151 | Train Acc: 95.785
(7) Epoch 009: | Train Loss: 0.16767 | Train Acc: 96.484
(7) Epoch 010: | Train Loss: 0.14799 | Train Acc: 97.141
(7) Epoch 011: | Train Loss: 0.13140 | Train Acc: 97.628
(7) Epoch 012: | Train Loss: 0.11727 | Train Acc: 98.045
(7) Epoch 013: | Train Loss: 0.10523 | Train Acc: 98.388
(7) Epoch 014: | Train Loss: 0.09485 | Train Acc: 98.694
(7) Epoch 015: | Train Loss: 0.08588 | Train Acc: 98.929
(7) Epoch 016: | Train Loss: 0.07807 | Train Acc: 99.119
(7) Epoch 017: | Train Loss: 0.07126 | Train Acc: 99.265
(7) Epoch 018: | Train Loss: 0.

  0%|          | 0/50 [00:00<?, ?it/s]

(8) Epoch 001: | Train Loss: 1.61838 | Train Acc: 58.876
(8) Epoch 002: | Train Loss: 0.69385 | Train Acc: 80.758
(8) Epoch 003: | Train Loss: 0.47784 | Train Acc: 86.666
(8) Epoch 004: | Train Loss: 0.36888 | Train Acc: 89.968
(8) Epoch 005: | Train Loss: 0.30164 | Train Acc: 92.115
(8) Epoch 006: | Train Loss: 0.25457 | Train Acc: 93.577
(8) Epoch 007: | Train Loss: 0.21894 | Train Acc: 94.688
(8) Epoch 008: | Train Loss: 0.19067 | Train Acc: 95.667
(8) Epoch 009: | Train Loss: 0.16753 | Train Acc: 96.382
(8) Epoch 010: | Train Loss: 0.14816 | Train Acc: 97.063
(8) Epoch 011: | Train Loss: 0.13184 | Train Acc: 97.608
(8) Epoch 012: | Train Loss: 0.11790 | Train Acc: 98.009
(8) Epoch 013: | Train Loss: 0.10595 | Train Acc: 98.356
(8) Epoch 014: | Train Loss: 0.09560 | Train Acc: 98.631
(8) Epoch 015: | Train Loss: 0.08669 | Train Acc: 98.887
(8) Epoch 016: | Train Loss: 0.07891 | Train Acc: 99.085
(8) Epoch 017: | Train Loss: 0.07210 | Train Acc: 99.209
(8) Epoch 018: | Train Loss: 0.

  0%|          | 0/50 [00:00<?, ?it/s]

(9) Epoch 001: | Train Loss: 1.63113 | Train Acc: 59.811
(9) Epoch 002: | Train Loss: 0.70406 | Train Acc: 80.969
(9) Epoch 003: | Train Loss: 0.48579 | Train Acc: 86.459
(9) Epoch 004: | Train Loss: 0.37368 | Train Acc: 89.870
(9) Epoch 005: | Train Loss: 0.30411 | Train Acc: 91.941
(9) Epoch 006: | Train Loss: 0.25549 | Train Acc: 93.578
(9) Epoch 007: | Train Loss: 0.21913 | Train Acc: 94.774
(9) Epoch 008: | Train Loss: 0.19059 | Train Acc: 95.709
(9) Epoch 009: | Train Loss: 0.16730 | Train Acc: 96.439
(9) Epoch 010: | Train Loss: 0.14796 | Train Acc: 97.107
(9) Epoch 011: | Train Loss: 0.13159 | Train Acc: 97.628
(9) Epoch 012: | Train Loss: 0.11770 | Train Acc: 98.003
(9) Epoch 013: | Train Loss: 0.10575 | Train Acc: 98.341
(9) Epoch 014: | Train Loss: 0.09545 | Train Acc: 98.630
(9) Epoch 015: | Train Loss: 0.08651 | Train Acc: 98.851
(9) Epoch 016: | Train Loss: 0.07870 | Train Acc: 99.069
(9) Epoch 017: | Train Loss: 0.07189 | Train Acc: 99.229
(9) Epoch 018: | Train Loss: 0.

In [18]:
cv_test_roc

[0.9944082318810038,
 0.9945067036905018,
 0.9940895178350934,
 0.9943080241115064,
 0.993986281305609,
 0.9939387059432859,
 0.994417169876028,
 0.9940096865711178,
 0.9942670833312208,
 0.9942455829537971]

In [19]:
cv_test_f1

[0.8833304984852983,
 0.8862925357536315,
 0.8857507995152686,
 0.8840204642926202,
 0.8837404533736724,
 0.8832387608811492,
 0.8858592420851722,
 0.8796493319598641,
 0.8858261161543948,
 0.8862438887676163]