In [1]:
import psutil
import joblib
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from sklearn.metrics import roc_auc_score
import torch.nn as nn
from sklearn.model_selection import KFold
from torch.utils.data import Dataset, DataLoader
import gc
import os
import warnings
warnings.filterwarnings("ignore")

In [25]:
MAX_SEQ = 400
n_part = data['sub_chapter_id'].nunique() + 1
D_MODEL = 128
N_LAYER = 2
DROPOUT = 0.2

In [26]:
class FFN(nn.Module):
    def __init__(self, state_size=200):
        super(FFN, self).__init__()
        self.state_size = state_size

        self.lr1 = nn.Linear(state_size, state_size)
        self.relu = nn.ReLU()
        self.lr2 = nn.Linear(state_size, state_size)
        self.dropout = nn.Dropout(DROPOUT)
    
    def forward(self, x):
        x = self.lr1(x)
        x = self.relu(x)
        x = self.lr2(x)
        return self.dropout(x)

def future_mask(seq_length):
    future_mask = np.triu(np.ones((seq_length, seq_length)), k=1).astype('bool')
    return torch.from_numpy(future_mask)


class SAINTModel(nn.Module):
    def __init__(self, n_skill, n_part, max_seq=MAX_SEQ, embed_dim= D_MODEL, elapsed_time_cat_flag = False):
        super(SAINTModel, self).__init__()

        self.n_skill = n_skill
        self.embed_dim = embed_dim
        self.n_chapter= 39
        self.n_sub_chapter = n_part
        self.elapsed_time_cat_flag = elapsed_time_cat_flag

        self.q_embedding = nn.Embedding(self.n_skill+1, embed_dim) ## exercise
        self.c_embedding = nn.Embedding(self.n_chapter+1, embed_dim) ## category
        self.sc_embedding = nn.Embedding(self.n_sub_chapter, embed_dim) ## category
        self.pos_embedding = nn.Embedding(max_seq+1, embed_dim) ## position
        self.res_embedding = nn.Embedding(2+1, embed_dim) ## response
        self.feat_embedding = nn.Linear(2, embed_dim) ## temporal
    



        self.transformer = nn.Transformer(nhead=8, d_model = embed_dim, num_encoder_layers= N_LAYER, num_decoder_layers= N_LAYER, dropout = DROPOUT)

        self.dropout = nn.Dropout(DROPOUT)
        self.layer_normal = nn.LayerNorm(embed_dim) 
        self.ffn = FFN(embed_dim)
        self.pred = nn.Linear(embed_dim, 1)
    
    def forward(self, question, chapter, schapter, response, user_features):

        device = question.device  
        ## embedding layer
        question = self.q_embedding(question)
        chapter = self.c_embedding(chapter)
        schapter = self.sc_embedding(schapter)
        pos_id = torch.arange(question.size(1)).unsqueeze(0).to(device)
        pos_id = self.pos_embedding(pos_id)
        res = self.res_embedding(response)
        user_features = self.feat_embedding(user_features)
        

        enc = pos_id + question + chapter + schapter 
        dec = pos_id + res + enc + user_features
        enc = enc.permute(1, 0, 2) # x: [bs, s_len, embed] => [s_len, bs, embed]
        dec = dec.permute(1, 0, 2)
        mask = future_mask(enc.size(0)).to(device)
        att_output = self.transformer(enc, dec, src_mask=mask, tgt_mask=mask, memory_mask = mask)
        att_output = self.layer_normal(att_output)
        att_output = att_output.permute(1, 0, 2) # att_output: [s_len, bs, embed] => [bs, s_len, embed]
        
        
        x = self.ffn(att_output)
        x = self.layer_normal(x + att_output)
        x = self.pred(x)

        return x.squeeze(-1)

In [32]:
patience = 5

In [34]:
X = np.array(group.keys())
kfold = KFold(n_splits=5, shuffle=True)
train_losses = list()
train_aucs = list()
train_accs = list()
val_losses = list()
val_aucs = list()
val_accs = list()
test_losses = list()
test_aucs = list()
test_accs = list()
for train, test in kfold.split(X):
    users_train, users_test =  X[train], X[test]
    n = len(users_test)//2
    users_test, users_val = users_test[:n], users_test[n: ]
    train = PRACTICE_DATASET(group[users_train])
    valid = PRACTICE_DATASET(group[users_val])
    test = PRACTICE_DATASET(group[users_test])
    train_dataloader = DataLoader(train, batch_size=32, shuffle=True, num_workers=8)
    val_dataloader = DataLoader(valid, batch_size=32, shuffle=True, num_workers=8)
    test_dataloader = DataLoader(test, batch_size=32, shuffle=True, num_workers=8)
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    saint = SAINTModel(n_skill, n_part)
    epochs = 100
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(saint.parameters(), betas=(0.9, 0.999), lr = 0.0005, eps=1e-8)
    saint.to(device)
    criterion.to(device)
    
    def train_epoch(model=saint, train_iterator=train_dataloader, optim=optimizer, criterion=criterion, device=device):
        model.train()

        train_loss = []
        num_corrects = 0
        num_total = 0
        labels = []
        outs = []
        tbar = tqdm(train_iterator)
        for item in tbar:
            question_id = item[0].to(device).long()
            chapter = item[1].to(device).long()
            schapter = item[2].to(device).long()
            responses = item[3].to(device).long()
            user_feats = item[4].to(device).float()
            label = item[5].to(device).float()            
            target_mask = (question_id!=0)
            optim.zero_grad()
            output = model(question_id, chapter, schapter, responses, user_feats)
            output = torch.reshape(output, label.shape)

            output = torch.masked_select(output, target_mask)
            label = torch.masked_select(label, target_mask)

            loss = criterion(output, label)
            loss.backward()
            optim.step()
            train_loss.append(loss.item())
            pred = (torch.sigmoid(output) >= 0.5).long()

            num_corrects += (pred == label).sum().item()
            num_total += len(label)

            labels.extend(label.view(-1).data.cpu().numpy())
            outs.extend(output.view(-1).data.cpu().numpy())

            tbar.set_description('loss - {:.4f}'.format(loss))
        acc = num_corrects / num_total
        auc = roc_auc_score(labels, outs)
        loss = np.mean(train_loss)

        return loss, acc, auc
   

    def val_epoch(model=saint, val_iterator=test_dataloader, 
              criterion=criterion, device=device):
        model.eval()

        train_loss = []
        num_corrects = 0
        num_total = 0
        labels = []
        outs = []
        tbar = tqdm(val_iterator)
        for item in tbar:
            question_id = item[0].to(device).long()
            chapter = item[1].to(device).long()
            schapter = item[2].to(device).long()
            responses = item[3].to(device).long()
            user_feats = item[4].to(device).float()
            label = item[5].to(device).float()            
            target_mask = (question_id!=0)
            with torch.no_grad():
                output = model(question_id, chapter, schapter, responses, user_feats)

            output = torch.reshape(output, label.shape)
            output = torch.masked_select(output, target_mask)
            label = torch.masked_select(label, target_mask)

            loss = criterion(output, label)
            train_loss.append(loss.item())

            pred = (torch.sigmoid(output) >= 0.5).long()
            num_corrects += (pred == label).sum().item()
            num_total += len(label)

            labels.extend(label.view(-1).data.cpu().numpy())
            outs.extend(output.view(-1).data.cpu().numpy())

            tbar.set_description('valid loss - {:.4f}'.format(loss))

        acc = num_corrects / num_total
        auc = roc_auc_score(labels, outs)
        loss = np.average(train_loss)

        return loss, acc, auc
    
    MIN_VAL = 1000000000
    count = 0
    print('----------------------------------------------------------------------------')
    for epoch in range(epochs):
        train_loss, train_acc, train_auc = train_epoch(model=saint, device=device)
        print("epoch - {} train_loss - {:.2f} acc - {:.3f} auc - {:.3f}".format(epoch, train_loss, train_acc, train_auc))
        val_loss, val_acc, val_auc = val_epoch(model=saint, val_iterator= val_dataloader, device=device)
        print("epoch - {} val_loss - {:.2f} val acc - {:.3f} val auc - {:.3f}".format(epoch, val_loss, val_acc, val_auc))
        if val_loss < MIN_VAL:
            count = 0
            MIN_VAL = val_loss
        else:
            count += 1

        if count == patience:
            print('Val Loss does not improve for {} consecutive epochs'.format(patience))
            break
    test_loss, test_acc, test_auc = val_epoch(model=saint, device=device)
    print("epoch - {} test_loss - {:.2f} acc - {:.3f} auc - {:.3f}".format(epoch, test_loss, test_acc, test_auc))
    test_losses.append(test_loss)
    test_aucs.append(test_auc)
    test_accs.append(test_acc)
    train_aucs.append(train_auc)
    train_losses.append(train_loss)
    train_accs.append(train_acc)

  0%|          | 0/28 [00:00<?, ?it/s]

----------------------------------------------------------------------------


loss - 0.6070: 100%|██████████| 28/28 [00:05<00:00,  4.69it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 0 train_loss - 0.63 acc - 0.696 auc - 0.545


valid loss - 0.6843: 100%|██████████| 4/4 [00:00<00:00,  6.75it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 0 val_loss - 0.65 val acc - 0.652 val auc - 0.640


loss - 0.5555: 100%|██████████| 28/28 [00:05<00:00,  5.39it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 1 train_loss - 0.55 acc - 0.721 auc - 0.709


valid loss - 0.5890: 100%|██████████| 4/4 [00:00<00:00,  6.60it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 1 val_loss - 0.60 val acc - 0.677 val auc - 0.690


loss - 0.5007: 100%|██████████| 28/28 [00:05<00:00,  5.37it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 2 train_loss - 0.52 acc - 0.735 auc - 0.748


valid loss - 0.5749: 100%|██████████| 4/4 [00:00<00:00,  7.13it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 2 val_loss - 0.59 val acc - 0.686 val auc - 0.715


loss - 0.5144: 100%|██████████| 28/28 [00:05<00:00,  5.48it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 3 train_loss - 0.50 acc - 0.744 auc - 0.771


valid loss - 0.5673: 100%|██████████| 4/4 [00:00<00:00,  7.60it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 3 val_loss - 0.57 val acc - 0.693 val auc - 0.730


loss - 0.5131: 100%|██████████| 28/28 [00:05<00:00,  5.23it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 4 train_loss - 0.49 acc - 0.754 auc - 0.790


valid loss - 0.5359: 100%|██████████| 4/4 [00:00<00:00,  7.70it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 4 val_loss - 0.56 val acc - 0.700 val auc - 0.746


loss - 0.4458: 100%|██████████| 28/28 [00:05<00:00,  5.39it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 5 train_loss - 0.47 acc - 0.763 auc - 0.806


valid loss - 0.5912: 100%|██████████| 4/4 [00:00<00:00,  7.59it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 5 val_loss - 0.57 val acc - 0.695 val auc - 0.755


loss - 0.3918: 100%|██████████| 28/28 [00:05<00:00,  5.54it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 6 train_loss - 0.46 acc - 0.770 auc - 0.817


valid loss - 0.5174: 100%|██████████| 4/4 [00:00<00:00,  7.23it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 6 val_loss - 0.55 val acc - 0.713 val auc - 0.762


loss - 0.4548: 100%|██████████| 28/28 [00:05<00:00,  5.18it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 7 train_loss - 0.45 acc - 0.776 auc - 0.827


valid loss - 0.5702: 100%|██████████| 4/4 [00:00<00:00,  7.08it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 7 val_loss - 0.55 val acc - 0.713 val auc - 0.764


loss - 0.4466: 100%|██████████| 28/28 [00:05<00:00,  5.53it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 8 train_loss - 0.44 acc - 0.782 auc - 0.836


valid loss - 0.5508: 100%|██████████| 4/4 [00:00<00:00,  6.33it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 8 val_loss - 0.55 val acc - 0.714 val auc - 0.768


loss - 0.4110: 100%|██████████| 28/28 [00:05<00:00,  5.30it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 9 train_loss - 0.43 acc - 0.787 auc - 0.844


valid loss - 0.5687: 100%|██████████| 4/4 [00:00<00:00,  7.08it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 9 val_loss - 0.55 val acc - 0.718 val auc - 0.770


loss - 0.4283: 100%|██████████| 28/28 [00:05<00:00,  5.31it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 10 train_loss - 0.43 acc - 0.791 auc - 0.849


valid loss - 0.5367: 100%|██████████| 4/4 [00:00<00:00,  7.11it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 10 val_loss - 0.55 val acc - 0.721 val auc - 0.772


loss - 0.4256: 100%|██████████| 28/28 [00:05<00:00,  5.39it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 11 train_loss - 0.42 acc - 0.794 auc - 0.855


valid loss - 0.5701: 100%|██████████| 4/4 [00:00<00:00,  6.78it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 11 val_loss - 0.56 val acc - 0.718 val auc - 0.772


loss - 0.3865: 100%|██████████| 28/28 [00:05<00:00,  5.36it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 12 train_loss - 0.41 acc - 0.799 auc - 0.860


valid loss - 0.5896: 100%|██████████| 4/4 [00:00<00:00,  7.21it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 12 val_loss - 0.56 val acc - 0.718 val auc - 0.773


loss - 0.4013: 100%|██████████| 28/28 [00:05<00:00,  5.42it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 13 train_loss - 0.41 acc - 0.802 auc - 0.865


valid loss - 0.5562: 100%|██████████| 4/4 [00:00<00:00,  7.25it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 13 val_loss - 0.54 val acc - 0.723 val auc - 0.774


loss - 0.4086: 100%|██████████| 28/28 [00:05<00:00,  5.02it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 14 train_loss - 0.40 acc - 0.805 auc - 0.869


valid loss - 0.5526: 100%|██████████| 4/4 [00:00<00:00,  7.27it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 14 val_loss - 0.55 val acc - 0.722 val auc - 0.774


loss - 0.3061: 100%|██████████| 28/28 [00:05<00:00,  5.39it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 15 train_loss - 0.39 acc - 0.809 auc - 0.873


valid loss - 0.5563: 100%|██████████| 4/4 [00:00<00:00,  7.50it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 15 val_loss - 0.55 val acc - 0.714 val auc - 0.773


loss - 0.3862: 100%|██████████| 28/28 [00:05<00:00,  5.52it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 16 train_loss - 0.39 acc - 0.813 auc - 0.879


valid loss - 0.5758: 100%|██████████| 4/4 [00:00<00:00,  7.39it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 16 val_loss - 0.56 val acc - 0.718 val auc - 0.770


loss - 0.3941: 100%|██████████| 28/28 [00:05<00:00,  5.54it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 17 train_loss - 0.38 acc - 0.816 auc - 0.882


valid loss - 0.5167: 100%|██████████| 4/4 [00:00<00:00,  7.08it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 17 val_loss - 0.55 val acc - 0.720 val auc - 0.772


loss - 0.4052: 100%|██████████| 28/28 [00:05<00:00,  5.44it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 18 train_loss - 0.38 acc - 0.818 auc - 0.885


valid loss - 0.5547: 100%|██████████| 4/4 [00:00<00:00,  7.42it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 18 val_loss - 0.55 val acc - 0.717 val auc - 0.771
Val Loss does not improve for 5 consecutive epochs


valid loss - 0.3941: 100%|██████████| 4/4 [00:00<00:00,  7.11it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 18 test_loss - 0.39 acc - 0.817 auc - 0.854
----------------------------------------------------------------------------


loss - 0.5828: 100%|██████████| 28/28 [00:05<00:00,  5.13it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 0 train_loss - 0.65 acc - 0.692 auc - 0.521


valid loss - 0.6981: 100%|██████████| 4/4 [00:00<00:00,  6.83it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 0 val_loss - 0.68 val acc - 0.632 val auc - 0.652


loss - 0.5072: 100%|██████████| 28/28 [00:05<00:00,  5.29it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 1 train_loss - 0.56 acc - 0.718 auc - 0.683


valid loss - 0.6122: 100%|██████████| 4/4 [00:00<00:00,  5.37it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 1 val_loss - 0.61 val acc - 0.669 val auc - 0.684


loss - 0.5060: 100%|██████████| 28/28 [00:06<00:00,  4.50it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 2 train_loss - 0.53 acc - 0.730 auc - 0.736


valid loss - 0.5971: 100%|██████████| 4/4 [00:00<00:00,  5.82it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 2 val_loss - 0.60 val acc - 0.677 val auc - 0.705


loss - 0.5144: 100%|██████████| 28/28 [00:06<00:00,  4.53it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 3 train_loss - 0.51 acc - 0.741 auc - 0.759


valid loss - 0.6114: 100%|██████████| 4/4 [00:00<00:00,  7.24it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 3 val_loss - 0.59 val acc - 0.685 val auc - 0.720


loss - 0.4935: 100%|██████████| 28/28 [00:05<00:00,  5.53it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 4 train_loss - 0.49 acc - 0.752 auc - 0.781


valid loss - 0.5678: 100%|██████████| 4/4 [00:00<00:00,  5.73it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 4 val_loss - 0.58 val acc - 0.690 val auc - 0.734


loss - 0.4233: 100%|██████████| 28/28 [00:06<00:00,  4.64it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 5 train_loss - 0.48 acc - 0.761 auc - 0.800


valid loss - 0.5751: 100%|██████████| 4/4 [00:00<00:00,  7.17it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 5 val_loss - 0.57 val acc - 0.699 val auc - 0.748


loss - 0.4607: 100%|██████████| 28/28 [00:05<00:00,  5.37it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 6 train_loss - 0.46 acc - 0.769 auc - 0.814


valid loss - 0.5811: 100%|██████████| 4/4 [00:00<00:00,  6.81it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 6 val_loss - 0.57 val acc - 0.693 val auc - 0.753


loss - 0.4336: 100%|██████████| 28/28 [00:05<00:00,  5.54it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 7 train_loss - 0.45 acc - 0.776 auc - 0.825


valid loss - 0.5863: 100%|██████████| 4/4 [00:00<00:00,  6.91it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 7 val_loss - 0.57 val acc - 0.704 val auc - 0.758


loss - 0.4142: 100%|██████████| 28/28 [00:05<00:00,  5.38it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 8 train_loss - 0.44 acc - 0.781 auc - 0.834


valid loss - 0.5342: 100%|██████████| 4/4 [00:00<00:00,  6.98it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 8 val_loss - 0.56 val acc - 0.707 val auc - 0.764


loss - 0.3720: 100%|██████████| 28/28 [00:05<00:00,  5.47it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 9 train_loss - 0.44 acc - 0.785 auc - 0.840


valid loss - 0.5913: 100%|██████████| 4/4 [00:00<00:00,  6.79it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 9 val_loss - 0.57 val acc - 0.708 val auc - 0.764


loss - 0.3760: 100%|██████████| 28/28 [00:05<00:00,  4.81it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 10 train_loss - 0.43 acc - 0.790 auc - 0.847


valid loss - 0.5209: 100%|██████████| 4/4 [00:00<00:00,  6.72it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 10 val_loss - 0.55 val acc - 0.710 val auc - 0.767


loss - 0.4102: 100%|██████████| 28/28 [00:05<00:00,  5.41it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 11 train_loss - 0.42 acc - 0.793 auc - 0.852


valid loss - 0.5663: 100%|██████████| 4/4 [00:00<00:00,  6.37it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 11 val_loss - 0.56 val acc - 0.706 val auc - 0.768


loss - 0.4186: 100%|██████████| 28/28 [00:05<00:00,  5.26it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 12 train_loss - 0.42 acc - 0.796 auc - 0.857


valid loss - 0.6059: 100%|██████████| 4/4 [00:00<00:00,  6.50it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 12 val_loss - 0.56 val acc - 0.712 val auc - 0.769


loss - 0.3914: 100%|██████████| 28/28 [00:05<00:00,  5.36it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 13 train_loss - 0.41 acc - 0.801 auc - 0.862


valid loss - 0.5148: 100%|██████████| 4/4 [00:00<00:00,  6.66it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 13 val_loss - 0.56 val acc - 0.708 val auc - 0.768


loss - 0.3755: 100%|██████████| 28/28 [00:05<00:00,  5.46it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 14 train_loss - 0.40 acc - 0.804 auc - 0.867


valid loss - 0.6396: 100%|██████████| 4/4 [00:00<00:00,  7.16it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 14 val_loss - 0.58 val acc - 0.709 val auc - 0.770


loss - 0.3826: 100%|██████████| 28/28 [00:05<00:00,  5.06it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 15 train_loss - 0.40 acc - 0.807 auc - 0.871


valid loss - 0.5682: 100%|██████████| 4/4 [00:00<00:00,  7.13it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 15 val_loss - 0.57 val acc - 0.710 val auc - 0.769
Val Loss does not improve for 5 consecutive epochs


valid loss - 0.3234: 100%|██████████| 4/4 [00:00<00:00,  7.29it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 15 test_loss - 0.38 acc - 0.826 auc - 0.859
----------------------------------------------------------------------------


loss - 0.6115: 100%|██████████| 28/28 [00:05<00:00,  5.49it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 0 train_loss - 0.64 acc - 0.687 auc - 0.519


valid loss - 0.6493: 100%|██████████| 4/4 [00:00<00:00,  7.28it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 0 val_loss - 0.65 val acc - 0.663 val auc - 0.655


loss - 0.4999: 100%|██████████| 28/28 [00:05<00:00,  5.37it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 1 train_loss - 0.57 acc - 0.713 auc - 0.655


valid loss - 0.5921: 100%|██████████| 4/4 [00:00<00:00,  7.28it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 1 val_loss - 0.59 val acc - 0.683 val auc - 0.684


loss - 0.5014: 100%|██████████| 28/28 [00:05<00:00,  5.50it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 2 train_loss - 0.53 acc - 0.724 auc - 0.727


valid loss - 0.5471: 100%|██████████| 4/4 [00:00<00:00,  6.39it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 2 val_loss - 0.57 val acc - 0.696 val auc - 0.716


loss - 0.5132: 100%|██████████| 28/28 [00:05<00:00,  5.48it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 3 train_loss - 0.52 acc - 0.735 auc - 0.751


valid loss - 0.5866: 100%|██████████| 4/4 [00:00<00:00,  6.95it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 3 val_loss - 0.56 val acc - 0.705 val auc - 0.735


loss - 0.5117: 100%|██████████| 28/28 [00:05<00:00,  5.26it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 4 train_loss - 0.50 acc - 0.746 auc - 0.774


valid loss - 0.5729: 100%|██████████| 4/4 [00:00<00:00,  6.58it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 4 val_loss - 0.55 val acc - 0.711 val auc - 0.751


loss - 0.5121: 100%|██████████| 28/28 [00:05<00:00,  5.22it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 5 train_loss - 0.49 acc - 0.755 auc - 0.791


valid loss - 0.5118: 100%|██████████| 4/4 [00:00<00:00,  6.87it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 5 val_loss - 0.54 val acc - 0.716 val auc - 0.765


loss - 0.4821: 100%|██████████| 28/28 [00:05<00:00,  5.34it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 6 train_loss - 0.47 acc - 0.763 auc - 0.807


valid loss - 0.5250: 100%|██████████| 4/4 [00:00<00:00,  6.89it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 6 val_loss - 0.53 val acc - 0.720 val auc - 0.772


loss - 0.4773: 100%|██████████| 28/28 [00:05<00:00,  5.44it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 7 train_loss - 0.46 acc - 0.768 auc - 0.818


valid loss - 0.5929: 100%|██████████| 4/4 [00:00<00:00,  6.93it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 7 val_loss - 0.54 val acc - 0.731 val auc - 0.778


loss - 0.4003: 100%|██████████| 28/28 [00:05<00:00,  5.26it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 8 train_loss - 0.45 acc - 0.775 auc - 0.828


valid loss - 0.4701: 100%|██████████| 4/4 [00:00<00:00,  6.90it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 8 val_loss - 0.52 val acc - 0.731 val auc - 0.781


loss - 0.4528: 100%|██████████| 28/28 [00:05<00:00,  5.30it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 9 train_loss - 0.44 acc - 0.779 auc - 0.833


valid loss - 0.5238: 100%|██████████| 4/4 [00:00<00:00,  7.28it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 9 val_loss - 0.53 val acc - 0.733 val auc - 0.783


loss - 0.4643: 100%|██████████| 28/28 [00:05<00:00,  5.54it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 10 train_loss - 0.44 acc - 0.784 auc - 0.840


valid loss - 0.5178: 100%|██████████| 4/4 [00:00<00:00,  7.13it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 10 val_loss - 0.52 val acc - 0.735 val auc - 0.786


loss - 0.4614: 100%|██████████| 28/28 [00:05<00:00,  5.25it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 11 train_loss - 0.43 acc - 0.788 auc - 0.847


valid loss - 0.4928: 100%|██████████| 4/4 [00:00<00:00,  6.81it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 11 val_loss - 0.52 val acc - 0.736 val auc - 0.788


loss - 0.4479: 100%|██████████| 28/28 [00:05<00:00,  5.53it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 12 train_loss - 0.42 acc - 0.792 auc - 0.852


valid loss - 0.5709: 100%|██████████| 4/4 [00:00<00:00,  6.93it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 12 val_loss - 0.52 val acc - 0.738 val auc - 0.790


loss - 0.4323: 100%|██████████| 28/28 [00:05<00:00,  5.46it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 13 train_loss - 0.42 acc - 0.796 auc - 0.857


valid loss - 0.5273: 100%|██████████| 4/4 [00:00<00:00,  7.27it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 13 val_loss - 0.53 val acc - 0.737 val auc - 0.791


loss - 0.3979: 100%|██████████| 28/28 [00:05<00:00,  5.43it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 14 train_loss - 0.41 acc - 0.799 auc - 0.862


valid loss - 0.4930: 100%|██████████| 4/4 [00:00<00:00,  5.83it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 14 val_loss - 0.52 val acc - 0.737 val auc - 0.788


loss - 0.4261: 100%|██████████| 28/28 [00:05<00:00,  5.25it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 15 train_loss - 0.41 acc - 0.801 auc - 0.866


valid loss - 0.5429: 100%|██████████| 4/4 [00:00<00:00,  6.41it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 15 val_loss - 0.53 val acc - 0.734 val auc - 0.790
Val Loss does not improve for 5 consecutive epochs


valid loss - 0.3515: 100%|██████████| 4/4 [00:00<00:00,  6.63it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 15 test_loss - 0.37 acc - 0.829 auc - 0.871
----------------------------------------------------------------------------


loss - 0.5930: 100%|██████████| 28/28 [00:05<00:00,  5.36it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 0 train_loss - 0.63 acc - 0.692 auc - 0.548


valid loss - 0.5853: 100%|██████████| 4/4 [00:00<00:00,  6.77it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 0 val_loss - 0.62 val acc - 0.659 val auc - 0.646


loss - 0.5422: 100%|██████████| 28/28 [00:05<00:00,  5.35it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 1 train_loss - 0.55 acc - 0.720 auc - 0.690


valid loss - 0.5510: 100%|██████████| 4/4 [00:00<00:00,  6.80it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 1 val_loss - 0.59 val acc - 0.680 val auc - 0.684


loss - 0.5228: 100%|██████████| 28/28 [00:05<00:00,  5.33it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 2 train_loss - 0.53 acc - 0.732 auc - 0.735


valid loss - 0.5817: 100%|██████████| 4/4 [00:00<00:00,  6.74it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 2 val_loss - 0.58 val acc - 0.692 val auc - 0.711


loss - 0.4989: 100%|██████████| 28/28 [00:05<00:00,  5.47it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 3 train_loss - 0.51 acc - 0.740 auc - 0.759


valid loss - 0.5685: 100%|██████████| 4/4 [00:00<00:00,  5.79it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 3 val_loss - 0.57 val acc - 0.692 val auc - 0.731


loss - 0.4986: 100%|██████████| 28/28 [00:05<00:00,  5.33it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 4 train_loss - 0.49 acc - 0.751 auc - 0.780


valid loss - 0.5897: 100%|██████████| 4/4 [00:00<00:00,  6.77it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 4 val_loss - 0.56 val acc - 0.709 val auc - 0.749


loss - 0.4871: 100%|██████████| 28/28 [00:05<00:00,  5.36it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 5 train_loss - 0.48 acc - 0.760 auc - 0.799


valid loss - 0.5517: 100%|██████████| 4/4 [00:00<00:00,  5.32it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 5 val_loss - 0.54 val acc - 0.717 val auc - 0.765


loss - 0.4446: 100%|██████████| 28/28 [00:05<00:00,  5.32it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 6 train_loss - 0.46 acc - 0.768 auc - 0.813


valid loss - 0.5326: 100%|██████████| 4/4 [00:00<00:00,  6.86it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 6 val_loss - 0.53 val acc - 0.724 val auc - 0.774


loss - 0.4588: 100%|██████████| 28/28 [00:05<00:00,  5.38it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 7 train_loss - 0.45 acc - 0.774 auc - 0.823


valid loss - 0.5026: 100%|██████████| 4/4 [00:00<00:00,  6.72it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 7 val_loss - 0.53 val acc - 0.724 val auc - 0.780


loss - 0.4625: 100%|██████████| 28/28 [00:05<00:00,  5.39it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 8 train_loss - 0.45 acc - 0.779 auc - 0.831


valid loss - 0.5605: 100%|██████████| 4/4 [00:00<00:00,  5.52it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 8 val_loss - 0.53 val acc - 0.730 val auc - 0.784


loss - 0.4748: 100%|██████████| 28/28 [00:05<00:00,  5.26it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 9 train_loss - 0.44 acc - 0.784 auc - 0.839


valid loss - 0.5152: 100%|██████████| 4/4 [00:00<00:00,  6.90it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 9 val_loss - 0.53 val acc - 0.726 val auc - 0.786


loss - 0.4232: 100%|██████████| 28/28 [00:05<00:00,  5.51it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 10 train_loss - 0.43 acc - 0.789 auc - 0.846


valid loss - 0.5352: 100%|██████████| 4/4 [00:00<00:00,  6.96it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 10 val_loss - 0.52 val acc - 0.731 val auc - 0.788


loss - 0.4366: 100%|██████████| 28/28 [00:05<00:00,  5.32it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 11 train_loss - 0.42 acc - 0.793 auc - 0.851


valid loss - 0.5619: 100%|██████████| 4/4 [00:00<00:00,  7.17it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 11 val_loss - 0.53 val acc - 0.733 val auc - 0.789


loss - 0.4025: 100%|██████████| 28/28 [00:05<00:00,  5.47it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 12 train_loss - 0.42 acc - 0.796 auc - 0.856


valid loss - 0.5096: 100%|██████████| 4/4 [00:00<00:00,  7.35it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 12 val_loss - 0.52 val acc - 0.732 val auc - 0.789


loss - 0.4123: 100%|██████████| 28/28 [00:05<00:00,  5.43it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 13 train_loss - 0.41 acc - 0.799 auc - 0.861


valid loss - 0.5197: 100%|██████████| 4/4 [00:00<00:00,  7.05it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 13 val_loss - 0.52 val acc - 0.732 val auc - 0.789


loss - 0.3843: 100%|██████████| 28/28 [00:05<00:00,  5.11it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 14 train_loss - 0.41 acc - 0.803 auc - 0.865


valid loss - 0.5661: 100%|██████████| 4/4 [00:00<00:00,  6.52it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 14 val_loss - 0.54 val acc - 0.735 val auc - 0.790


loss - 0.3933: 100%|██████████| 28/28 [00:05<00:00,  5.35it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 15 train_loss - 0.40 acc - 0.806 auc - 0.870


valid loss - 0.5053: 100%|██████████| 4/4 [00:00<00:00,  6.94it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 15 val_loss - 0.52 val acc - 0.735 val auc - 0.790


loss - 0.4604: 100%|██████████| 28/28 [00:05<00:00,  5.30it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 16 train_loss - 0.39 acc - 0.810 auc - 0.875


valid loss - 0.5113: 100%|██████████| 4/4 [00:00<00:00,  6.60it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 16 val_loss - 0.52 val acc - 0.734 val auc - 0.790


loss - 0.4034: 100%|██████████| 28/28 [00:05<00:00,  5.44it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 17 train_loss - 0.39 acc - 0.813 auc - 0.878


valid loss - 0.5030: 100%|██████████| 4/4 [00:00<00:00,  6.90it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 17 val_loss - 0.52 val acc - 0.734 val auc - 0.791


loss - 0.3814: 100%|██████████| 28/28 [00:05<00:00,  5.21it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 18 train_loss - 0.38 acc - 0.816 auc - 0.882


valid loss - 0.5638: 100%|██████████| 4/4 [00:00<00:00,  4.95it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 18 val_loss - 0.54 val acc - 0.731 val auc - 0.790


loss - 0.4262: 100%|██████████| 28/28 [00:05<00:00,  5.08it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 19 train_loss - 0.38 acc - 0.819 auc - 0.885


valid loss - 0.5548: 100%|██████████| 4/4 [00:00<00:00,  6.86it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 19 val_loss - 0.54 val acc - 0.729 val auc - 0.788


loss - 0.4263: 100%|██████████| 28/28 [00:05<00:00,  5.51it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 20 train_loss - 0.37 acc - 0.821 auc - 0.888


valid loss - 0.5211: 100%|██████████| 4/4 [00:00<00:00,  6.90it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 20 val_loss - 0.53 val acc - 0.732 val auc - 0.787


loss - 0.3412: 100%|██████████| 28/28 [00:05<00:00,  5.48it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 21 train_loss - 0.37 acc - 0.825 auc - 0.892


valid loss - 0.5789: 100%|██████████| 4/4 [00:00<00:00,  6.80it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 21 val_loss - 0.55 val acc - 0.732 val auc - 0.786
Val Loss does not improve for 5 consecutive epochs


valid loss - 0.3309: 100%|██████████| 4/4 [00:00<00:00,  6.93it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 21 test_loss - 0.41 acc - 0.807 auc - 0.853
----------------------------------------------------------------------------


loss - 0.5829: 100%|██████████| 28/28 [00:05<00:00,  5.47it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 0 train_loss - 0.63 acc - 0.696 auc - 0.541


valid loss - 0.5888: 100%|██████████| 4/4 [00:00<00:00,  6.91it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 0 val_loss - 0.64 val acc - 0.640 val auc - 0.648


loss - 0.5568: 100%|██████████| 28/28 [00:05<00:00,  5.36it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 1 train_loss - 0.56 acc - 0.717 auc - 0.680


valid loss - 0.6094: 100%|██████████| 4/4 [00:00<00:00,  6.87it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 1 val_loss - 0.60 val acc - 0.674 val auc - 0.692


loss - 0.5213: 100%|██████████| 28/28 [00:05<00:00,  5.39it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 2 train_loss - 0.53 acc - 0.732 auc - 0.730


valid loss - 0.5605: 100%|██████████| 4/4 [00:00<00:00,  5.05it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 2 val_loss - 0.58 val acc - 0.688 val auc - 0.716


loss - 0.5185: 100%|██████████| 28/28 [00:05<00:00,  5.21it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 3 train_loss - 0.51 acc - 0.740 auc - 0.755


valid loss - 0.5752: 100%|██████████| 4/4 [00:00<00:00,  6.77it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 3 val_loss - 0.57 val acc - 0.695 val auc - 0.735


loss - 0.4931: 100%|██████████| 28/28 [00:05<00:00,  5.38it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 4 train_loss - 0.49 acc - 0.750 auc - 0.779


valid loss - 0.5472: 100%|██████████| 4/4 [00:00<00:00,  6.81it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 4 val_loss - 0.56 val acc - 0.703 val auc - 0.752


loss - 0.5115: 100%|██████████| 28/28 [00:05<00:00,  5.42it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 5 train_loss - 0.48 acc - 0.760 auc - 0.798


valid loss - 0.5560: 100%|██████████| 4/4 [00:00<00:00,  6.63it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 5 val_loss - 0.55 val acc - 0.710 val auc - 0.764


loss - 0.5011: 100%|██████████| 28/28 [00:05<00:00,  5.36it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 6 train_loss - 0.46 acc - 0.768 auc - 0.813


valid loss - 0.5189: 100%|██████████| 4/4 [00:00<00:00,  6.60it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 6 val_loss - 0.54 val acc - 0.713 val auc - 0.771


loss - 0.4369: 100%|██████████| 28/28 [00:05<00:00,  5.46it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 7 train_loss - 0.45 acc - 0.774 auc - 0.823


valid loss - 0.5443: 100%|██████████| 4/4 [00:00<00:00,  5.24it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 7 val_loss - 0.54 val acc - 0.716 val auc - 0.776


loss - 0.4062: 100%|██████████| 28/28 [00:05<00:00,  5.06it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 8 train_loss - 0.45 acc - 0.779 auc - 0.831


valid loss - 0.5446: 100%|██████████| 4/4 [00:00<00:00,  6.55it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 8 val_loss - 0.54 val acc - 0.717 val auc - 0.777


loss - 0.4116: 100%|██████████| 28/28 [00:05<00:00,  5.42it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 9 train_loss - 0.44 acc - 0.784 auc - 0.839


valid loss - 0.5274: 100%|██████████| 4/4 [00:00<00:00,  7.09it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 9 val_loss - 0.54 val acc - 0.717 val auc - 0.779


loss - 0.4586: 100%|██████████| 28/28 [00:05<00:00,  5.53it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 10 train_loss - 0.43 acc - 0.788 auc - 0.845


valid loss - 0.5575: 100%|██████████| 4/4 [00:00<00:00,  6.88it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 10 val_loss - 0.54 val acc - 0.723 val auc - 0.781


loss - 0.4188: 100%|██████████| 28/28 [00:05<00:00,  5.49it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 11 train_loss - 0.42 acc - 0.791 auc - 0.850


valid loss - 0.5321: 100%|██████████| 4/4 [00:00<00:00,  7.08it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 11 val_loss - 0.54 val acc - 0.722 val auc - 0.783


loss - 0.4592: 100%|██████████| 28/28 [00:05<00:00,  5.47it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 12 train_loss - 0.42 acc - 0.795 auc - 0.855


valid loss - 0.5038: 100%|██████████| 4/4 [00:00<00:00,  6.88it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 12 val_loss - 0.54 val acc - 0.721 val auc - 0.783


loss - 0.4409: 100%|██████████| 28/28 [00:05<00:00,  5.01it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 13 train_loss - 0.41 acc - 0.798 auc - 0.860


valid loss - 0.5860: 100%|██████████| 4/4 [00:00<00:00,  6.03it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 13 val_loss - 0.54 val acc - 0.725 val auc - 0.784


loss - 0.4196: 100%|██████████| 28/28 [00:05<00:00,  5.39it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 14 train_loss - 0.41 acc - 0.801 auc - 0.864


valid loss - 0.5327: 100%|██████████| 4/4 [00:00<00:00,  6.38it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 14 val_loss - 0.55 val acc - 0.714 val auc - 0.783


loss - 0.3643: 100%|██████████| 28/28 [00:05<00:00,  5.39it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 15 train_loss - 0.40 acc - 0.804 auc - 0.868


valid loss - 0.5277: 100%|██████████| 4/4 [00:00<00:00,  6.52it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 15 val_loss - 0.53 val acc - 0.724 val auc - 0.786


loss - 0.4414: 100%|██████████| 28/28 [00:05<00:00,  5.30it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 16 train_loss - 0.40 acc - 0.808 auc - 0.873


valid loss - 0.5410: 100%|██████████| 4/4 [00:00<00:00,  6.46it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 16 val_loss - 0.55 val acc - 0.718 val auc - 0.783


loss - 0.3856: 100%|██████████| 28/28 [00:05<00:00,  5.35it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 17 train_loss - 0.39 acc - 0.810 auc - 0.876


valid loss - 0.5460: 100%|██████████| 4/4 [00:00<00:00,  6.75it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 17 val_loss - 0.56 val acc - 0.718 val auc - 0.784


loss - 0.4190: 100%|██████████| 28/28 [00:05<00:00,  5.35it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 18 train_loss - 0.39 acc - 0.813 auc - 0.879


valid loss - 0.5500: 100%|██████████| 4/4 [00:00<00:00,  5.80it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 18 val_loss - 0.55 val acc - 0.716 val auc - 0.783


loss - 0.3995: 100%|██████████| 28/28 [00:05<00:00,  5.27it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 19 train_loss - 0.38 acc - 0.816 auc - 0.883


valid loss - 0.5481: 100%|██████████| 4/4 [00:00<00:00,  6.42it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

epoch - 19 val_loss - 0.55 val acc - 0.723 val auc - 0.782


loss - 0.3469: 100%|██████████| 28/28 [00:05<00:00,  5.01it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 20 train_loss - 0.38 acc - 0.820 auc - 0.887


valid loss - 0.5380: 100%|██████████| 4/4 [00:00<00:00,  6.96it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

epoch - 20 val_loss - 0.55 val acc - 0.714 val auc - 0.780
Val Loss does not improve for 5 consecutive epochs


valid loss - 0.3299: 100%|██████████| 4/4 [00:00<00:00,  6.79it/s]


epoch - 20 test_loss - 0.38 acc - 0.819 auc - 0.860


In [35]:
print("test avg loss: ", np.mean(test_losses), np.std(test_losses))
print("test avg acc: ", np.mean(test_accs), np.std(test_accs))
print("test avg auc: ", np.mean(test_aucs), np.std(test_aucs))

test avg loss:  0.38554222881793976 0.014184064903042477
test avg acc:  0.8194767648406174 0.007848861411297061
test avg auc:  0.8592138838147333 0.006289587944095433


In [36]:
print("train avg loss: ", np.mean(train_losses), np.std(train_losses))
print("train avg acc: ", np.mean(train_accs), np.std(train_accs))
print("train avg auc: ", np.mean(train_aucs), np.std(train_aucs))

train avg loss:  0.38508885779551105 0.014740751426031942
train avg acc:  0.8141624949222173 0.008583210668538632
train avg auc:  0.8801929897471057 0.010153732797921613
