In [2]:
import numpy as np
import argparse
import os
import imp
import re
import pickle
import datetime
import random
import math
import copy


import torch
from torch import nn
import torch.nn.utils.rnn as rnn_utils
from torch.utils import data
from torch.autograd import Variable
import torch.nn.functional as F


from utils import utils
from utils.readers import InHospitalMortalityReader
from utils.preprocessing import Discretizer, Normalizer
from utils import metrics
from utils import common_utils

In [3]:
data_path = './data/'
file_name = './model/concare0'
small_part = True
arg_timestep = 1.0
batch_size = 256
epochs = 100

In [4]:
# Build readers, discretizers, normalizers
train_reader = InHospitalMortalityReader(dataset_dir=os.path.join(data_path, 'train'),
                                         listfile=os.path.join(data_path, 'train_listfile.csv'),
                                         period_length=48.0)

val_reader = InHospitalMortalityReader(dataset_dir=os.path.join(data_path, 'train'),
                                       listfile=os.path.join(data_path, 'val_listfile.csv'),
                                       period_length=48.0)

discretizer = Discretizer(timestep=arg_timestep,
                          store_masks=True,
                          impute_strategy='previous',
                          start_time='zero')

In [5]:
discretizer_header = discretizer.transform(train_reader.read_example(0)["X"])[1].split(',')
cont_channels = [i for (i, x) in enumerate(discretizer_header) if x.find("->") == -1]

normalizer = Normalizer(fields=cont_channels)  # choose here which columns to standardize
normalizer_state = 'ihm_normalizer'
normalizer_state = os.path.join(os.path.dirname(data_path), normalizer_state)
normalizer.load_params(normalizer_state)

In [6]:
n_trained_chunks = 0
train_raw = utils.load_data(train_reader, discretizer, normalizer, small_part, return_names=True)
val_raw = utils.load_data(val_reader, discretizer, normalizer, small_part, return_names=True)

In [7]:
demographic_data = []
diagnosis_data = []
idx_list = []

demo_path = data_path + 'demographic/'
for cur_name in os.listdir(demo_path):
    cur_id, cur_episode = cur_name.split('_', 1)
    cur_episode = cur_episode[:-4]
    cur_file = demo_path + cur_name

    with open(cur_file, "r") as tsfile:
        header = tsfile.readline().strip().split(',')
        if header[0] != "Icustay":
            continue
        cur_data = tsfile.readline().strip().split(',')

    if len(cur_data) == 1:
        cur_demo = np.zeros(12)
        cur_diag = np.zeros(128)
    else:
        if cur_data[3] == '':
            cur_data[3] = 60.0
        if cur_data[4] == '':
            cur_data[4] = 160
        if cur_data[5] == '':
            cur_data[5] = 60

        cur_demo = np.zeros(12)
        cur_demo[int(cur_data[1])] = 1
        cur_demo[5 + int(cur_data[2])] = 1
        cur_demo[9:] = cur_data[3:6]
        cur_diag = np.array(cur_data[8:], dtype=np.int)

    demographic_data.append(cur_demo)
    diagnosis_data.append(cur_diag)
    idx_list.append(cur_id+'_'+cur_episode)

for each_idx in range(9,12):
    cur_val = []
    for i in range(len(demographic_data)):
        cur_val.append(demographic_data[i][each_idx])
    cur_val = np.array(cur_val)
    _mean = np.mean(cur_val)
    _std = np.std(cur_val)
    _std = _std if _std > 1e-7 else 1e-7
    for i in range(len(demographic_data)):
        demographic_data[i][each_idx] = (demographic_data[i][each_idx] - _mean) / _std

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations


In [8]:
device = torch.device("cuda:0" if torch.cuda.is_available() == True else 'cpu')
#device = torch.device('cpu')
print("available device: {}".format(device))

available device: cpu


In [10]:
class SingleAttention(nn.Module):
    def __init__(self, attention_input_dim, attention_hidden_dim, attention_type='add', demographic_dim=12, time_aware=False, use_demographic=False):
        super(SingleAttention, self).__init__()

        self.attention_type = attention_type
        self.attention_hidden_dim = attention_hidden_dim
        self.attention_input_dim = attention_input_dim
        self.use_demographic = use_demographic
        self.demographic_dim = demographic_dim
        self.time_aware = time_aware

        # batch_time = torch.arange(0, batch_mask.size()[1], dtype=torch.float32).reshape(1, batch_mask.size()[1], 1)
        # batch_time = batch_time.repeat(batch_mask.size()[0], 1, 1)

        if attention_type == 'add':
            if self.time_aware == True:
                # self.Wx = nn.Parameter(torch.randn(attention_input_dim+1, attention_hidden_dim))
                self.Wx = nn.Parameter(torch.randn(attention_input_dim, attention_hidden_dim))
                self.Wtime_aware = nn.Parameter(torch.randn(1, attention_hidden_dim))
                nn.init.kaiming_uniform_(self.Wtime_aware, a=math.sqrt(5))
            else:
                self.Wx = nn.Parameter(torch.randn(attention_input_dim, attention_hidden_dim))
            self.Wt = nn.Parameter(torch.randn(attention_input_dim, attention_hidden_dim))
            self.Wd = nn.Parameter(torch.randn(demographic_dim, attention_hidden_dim))
            self.bh = nn.Parameter(torch.zeros(attention_hidden_dim,))
            self.Wa = nn.Parameter(torch.randn(attention_hidden_dim, 1))
            self.ba = nn.Parameter(torch.zeros(1,))

            nn.init.kaiming_uniform_(self.Wd, a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.Wx, a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.Wt, a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.Wa, a=math.sqrt(5))
        elif attention_type == 'mul':
            self.Wa = nn.Parameter(torch.randn(attention_input_dim, attention_input_dim))
            self.ba = nn.Parameter(torch.zeros(1,))

            nn.init.kaiming_uniform_(self.Wa, a=math.sqrt(5))
        elif attention_type == 'concat':
            if self.time_aware == True:
                self.Wh = nn.Parameter(torch.randn(2*attention_input_dim+1, attention_hidden_dim))
            else:
                self.Wh = nn.Parameter(torch.randn(2*attention_input_dim, attention_hidden_dim))

            self.Wa = nn.Parameter(torch.randn(attention_hidden_dim, 1))
            self.ba = nn.Parameter(torch.zeros(1,))

            nn.init.kaiming_uniform_(self.Wh, a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.Wa, a=math.sqrt(5))

        elif attention_type == 'new':
            self.Wt = nn.Parameter(torch.randn(attention_input_dim, attention_hidden_dim))
            self.Wx = nn.Parameter(torch.randn(attention_input_dim, attention_hidden_dim))

            self.rate = nn.Parameter(torch.zeros(1)+0.8)
            nn.init.kaiming_uniform_(self.Wx, a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.Wt, a=math.sqrt(5))

        else:
            raise RuntimeError('Wrong attention type.')

        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax()
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()

    def forward(self, input, demo=None):

        batch_size, time_step, input_dim = input.size() # batch_size * time_step * hidden_dim(i)
        #assert(input_dim == self.input_dim)

        # time_decays = torch.zeros((time_step,time_step)).to(device)# t*t
        # for this_time in range(time_step):
        #     for pre_time in range(time_step):
        #         if pre_time > this_time:
        #             break
        #         time_decays[this_time][pre_time] = torch.tensor(this_time - pre_time, dtype=torch.float32).to(device)
        # b_time_decays = tile(time_decays, 0, batch_size).view(batch_size,time_step,time_step).unsqueeze(-1).to(device)# b t t 1

        time_decays = torch.tensor(range(47,-1,-1), dtype=torch.float32).unsqueeze(-1).unsqueeze(0).to(device)# 1*t*1
        b_time_decays = time_decays.repeat(batch_size,1,1)+1# b t 1

        if self.attention_type == 'add': #B*T*I  @ H*I
            q = torch.matmul(input[:,-1,:], self.Wt)# b h
            q = torch.reshape(q, (batch_size, 1, self.attention_hidden_dim)) #B*1*H
            if self.time_aware == True:
                # k_input = torch.cat((input, time), dim=-1)
                k = torch.matmul(input, self.Wx)#b t h
                # k = torch.reshape(k, (batch_size, 1, time_step, self.attention_hidden_dim)) #B*1*T*H
                time_hidden = torch.matmul(b_time_decays, self.Wtime_aware)#  b t h
            else:
                k = torch.matmul(input, self.Wx)# b t h
                # k = torch.reshape(k, (batch_size, 1, time_step, self.attention_hidden_dim)) #B*1*T*H
            if self.use_demographic == True:
                d = torch.matmul(demo, self.Wd) #B*H
                d = torch.reshape(d, (batch_size, 1, self.attention_hidden_dim)) # b 1 h
            h = q + k + self.bh # b t h
            if self.time_aware == True:
                h += time_hidden
            h = self.tanh(h) #B*T*H
            e = torch.matmul(h, self.Wa) + self.ba #B*T*1
            e = torch.reshape(e, (batch_size, time_step))# b t
        elif self.attention_type == 'mul':
            e = torch.matmul(input[:,-1,:], self.Wa)#b i
            e = torch.matmul(e.unsqueeze(1), input.permute(0,2,1)).squeeze() + self.ba #b t
        elif self.attention_type == 'concat':
            q = input[:,-1,:].unsqueeze(1).repeat(1,time_step,1)# b t i
            k = input
            c = torch.cat((q, k), dim=-1) #B*T*2I
            if self.time_aware == True:
                c = torch.cat((c, b_time_decays), dim=-1) #B*T*2I+1
            h = torch.matmul(c, self.Wh)
            h = self.tanh(h)
            e = torch.matmul(h, self.Wa) + self.ba #B*T*1
            e = torch.reshape(e, (batch_size, time_step)) # b t

        elif self.attention_type == 'new':

            q = torch.matmul(input[:,-1,:], self.Wt)# b h
            q = torch.reshape(q, (batch_size, 1, self.attention_hidden_dim)) #B*1*H
            k = torch.matmul(input, self.Wx)#b t h
            dot_product = torch.matmul(q, k.transpose(1, 2)).squeeze() # b t
            denominator =  self.sigmoid(self.rate) * (torch.log(2.72 +  (1-self.sigmoid(dot_product)))* (b_time_decays.squeeze()))
            e = self.relu(self.sigmoid(dot_product)/(denominator)) # b * t
        #          * (b_time_decays.squeeze())
        # e = torch.exp(e - torch.max(e, dim=-1, keepdim=True).values)

        # if self.attention_width is not None:
        #     if self.history_only:
        #         lower = torch.arange(0, time_step).to(device) - (self.attention_width - 1)
        #     else:
        #         lower = torch.arange(0, time_step).to(device) - self.attention_width // 2
        #     lower = lower.unsqueeze(-1)
        #     upper = lower + self.attention_width
        #     indices = torch.arange(0, time_step).unsqueeze(0).to(device)
        #     e = e * (lower <= indices).float() * (indices < upper).float()

        # s = torch.sum(e, dim=-1, keepdim=True)
        # mask = subsequent_mask(time_step).to(device) # 1 t t 下三角
        # scores = e.masked_fill(mask == 0, -1e9)# b t t 下三角
        a = self.softmax(e) #B*T
        v = torch.matmul(a.unsqueeze(1), input).squeeze() #B*I

        return v, a

class FinalAttentionQKV(nn.Module):
    def __init__(self, attention_input_dim, attention_hidden_dim, attention_type='add', dropout=None):
        super(FinalAttentionQKV, self).__init__()

        self.attention_type = attention_type
        self.attention_hidden_dim = attention_hidden_dim
        self.attention_input_dim = attention_input_dim


        self.W_q = nn.Linear(attention_input_dim, attention_hidden_dim)
        self.W_k = nn.Linear(attention_input_dim, attention_hidden_dim)
        self.W_v = nn.Linear(attention_input_dim, attention_hidden_dim)

        self.W_out = nn.Linear(attention_hidden_dim, 1)

        self.b_in = nn.Parameter(torch.zeros(1,))
        self.b_out = nn.Parameter(torch.zeros(1,))

        nn.init.kaiming_uniform_(self.W_q.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.W_k.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.W_v.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.W_out.weight, a=math.sqrt(5))

        self.Wh = nn.Parameter(torch.randn(2*attention_input_dim, attention_hidden_dim))
        self.Wa = nn.Parameter(torch.randn(attention_hidden_dim, 1))
        self.ba = nn.Parameter(torch.zeros(1,))

        nn.init.kaiming_uniform_(self.Wh, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.Wa, a=math.sqrt(5))

        self.dropout = nn.Dropout(p=dropout)
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax()
        self.sigmoid = nn.Sigmoid()

    def forward(self, input):

        batch_size, time_step, input_dim = input.size() # batch_size * input_dim + 1 * hidden_dim(i)
        input_q = self.W_q(input[:, -1, :]) # b h
        input_k = self.W_k(input)# b t h
        input_v = self.W_v(input)# b t h

        if self.attention_type == 'add': #B*T*I  @ H*I

            q = torch.reshape(input_q, (batch_size, 1, self.attention_hidden_dim)) #B*1*H
            h = q + input_k + self.b_in # b t h
            h = self.tanh(h) #B*T*H
            e = self.W_out(h) # b t 1
            e = torch.reshape(e, (batch_size, time_step))# b t

        elif self.attention_type == 'mul':
            q = torch.reshape(input_q, (batch_size, self.attention_hidden_dim, 1)) #B*h 1
            e = torch.matmul(input_k, q).squeeze()#b t

        elif self.attention_type == 'concat':
            q = input_q.unsqueeze(1).repeat(1,time_step,1)# b t h
            k = input_k
            c = torch.cat((q, k), dim=-1) #B*T*2I
            h = torch.matmul(c, self.Wh)
            h = self.tanh(h)
            e = torch.matmul(h, self.Wa) + self.ba #B*T*1
            e = torch.reshape(e, (batch_size, time_step)) # b t

        a = self.softmax(e) #B*T
        if self.dropout is not None:
            a = self.dropout(a)
        v = torch.matmul(a.unsqueeze(1), input_v).squeeze() #B*I

        return v, a

def clones(module, N):
    "Produce N identical layers."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

def tile(a, dim, n_tile):
    init_dim = a.size(dim)
    repeat_idx = [1] * a.dim()
    repeat_idx[dim] = n_tile
    a = a.repeat(*(repeat_idx))
    order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])).to(device)
    return torch.index_select(a, dim, order_index).to(device)

class PositionwiseFeedForward(nn.Module): # new added
    "Implements FFN equation."
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w_2(self.dropout(F.relu(self.w_1(x)))), None

# class PositionwiseFeedForwardConv(nn.Module):

#     def __init__(self, model_dim=512, ffn_dim=2048, dropout=0.0):
#         super(PositionalWiseFeedForward, self).__init__()
#         self.w1 = nn.Conv1d(model_dim, ffn_dim, 1)
#         self.w2 = nn.Conv1d(model_dim, ffn_dim, 1)
#         self.dropout = nn.Dropout(dropout)
#         self.layer_norm = nn.LayerNorm(model_dim)

#     def forward(self, x):
#         output = x.transpose(1, 2)
#         output = self.w2(F.relu(self.w1(output)))
#         output = self.dropout(output.transpose(1, 2))

#         # add residual and norm layer
#         output = self.layer_norm(x + output)
#         return output

class PositionalEncoding(nn.Module): # new added / not use anymore
    "Implement the PE function."
    def __init__(self, d_model, dropout, max_len=400):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0., max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0., d_model, 2) * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + Variable(self.pe[:, :x.size(1)],
                         requires_grad=False)
        return self.dropout(x)

def subsequent_mask(size):
    "Mask out subsequent positions."
    attn_shape = (1, size, size)
    subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    return torch.from_numpy(subsequent_mask) == 0 # 下三角矩阵

def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)# b h t d_k
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) # b h t t
    if mask is not None:# 1 1 t t
        scores = scores.masked_fill(mask == 0, -1e9)# b h t t 下三角
    p_attn = F.softmax(scores, dim = -1)# b h t t
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn # b h t v (d_k)

class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0):
        "Take in model size and number of heads."
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, self.d_k * self.h), 3)
        self.final_linear = nn.Linear(d_model, d_model)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1) # 1 1 t t

        nbatches = query.size(0)# b
        input_dim = query.size(1)# i+1
        feature_dim = query.size(-1)# i+1

        #input size -> # batch_size * d_input * hidden_dim

        # d_model => h * d_k
        query, key, value = [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
             for l, x in zip(self.linears, (query, key, value))] # b num_head d_input d_k


        x, self.attn = attention(query, key, value, mask=mask,
                                 dropout=self.dropout)# b num_head d_input d_v (d_k)


        x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)# batch_size * d_input * hidden_dim

        #DeCov
        DeCov_contexts = x.transpose(0, 1).transpose(1, 2) # I+1 H B
        Covs = cov(DeCov_contexts[0,:,:])
        DeCov_loss = 0.5 * (torch.norm(Covs, p = 'fro')**2 - torch.norm(torch.diag(Covs))**2 )
        for i in range(feature_dim -1 + 1):
            Covs = cov(DeCov_contexts[i+1,:,:])
            DeCov_loss += 0.5 * (torch.norm(Covs, p = 'fro')**2 - torch.norm(torch.diag(Covs))**2 )


        return self.final_linear(x), DeCov_loss

class LayerNorm(nn.Module):
    def __init__(self, features, eps=1e-7):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2

def cov(m, y=None):
    if y is not None:
        m = torch.cat((m, y), dim=0)
    m_exp = torch.mean(m, dim=1)
    x = m - m_exp[:, None]
    cov = 1 / (x.size(1) - 1) * x.mm(x.t())
    return cov

class SublayerConnection(nn.Module):
    """
    A residual connection followed by a layer norm.
    Note for code simplicity the norm is first as opposed to last.
    """
    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        "Apply residual connection to any sublayer with the same size."
        returned_value = sublayer(self.norm(x))
        return x + self.dropout(returned_value[0]) , returned_value[1]

class ConCare(nn.Module):
    def __init__(self, input_dim, hidden_dim, d_model,  MHD_num_head, d_ff, output_dim, keep_prob=0.5):
        super(ConCare, self).__init__()

        # hyperparameters
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim  # d_model
        self.d_model = d_model
        self.MHD_num_head = MHD_num_head
        self.d_ff = d_ff
        self.output_dim = output_dim
        self.keep_prob = keep_prob

        # layers
        self.PositionalEncoding = PositionalEncoding(self.d_model, dropout = 0, max_len = 400)

        self.GRUs = clones(nn.GRU(1, self.hidden_dim, batch_first = True), self.input_dim)
        self.LastStepAttentions = clones(SingleAttention(self.hidden_dim, 8, attention_type='new', demographic_dim=12, time_aware=True, use_demographic=False),self.input_dim)

        self.FinalAttentionQKV = FinalAttentionQKV(self.hidden_dim, self.hidden_dim, attention_type='mul',dropout = 1 - self.keep_prob)

        self.MultiHeadedAttention = MultiHeadedAttention(self.MHD_num_head, self.d_model,dropout = 1 - self.keep_prob)
        self.SublayerConnection = SublayerConnection(self.d_model, dropout = 1 - self.keep_prob)

        self.PositionwiseFeedForward = PositionwiseFeedForward(self.d_model, self.d_ff, dropout=0.1)

        self.demo_proj_main = nn.Linear(12, self.hidden_dim)
        self.demo_proj = nn.Linear(12, self.hidden_dim)
        self.output0 = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.output1 = nn.Linear(self.hidden_dim, self.output_dim)

        self.dropout = nn.Dropout(p = 1 - self.keep_prob)
        self.tanh=nn.Tanh()
        self.softmax = nn.Softmax()
        self.sigmoid = nn.Sigmoid()
        self.relu=nn.ReLU()

    def forward(self, input, demo_input):
        # input shape [batch_size, timestep, feature_dim]
        demo_main = self.tanh(self.demo_proj_main(demo_input)).unsqueeze(1)# b hidden_dim

        batch_size = input.size(0)
        time_step = input.size(1)
        feature_dim = input.size(2)
        assert(feature_dim == self.input_dim)# input Tensor : 256 * 48 * 76
        assert(self.d_model % self.MHD_num_head == 0)

        # Initialization
        #cur_hs = Variable(torch.zeros(batch_size, self.hidden_dim).unsqueeze(0))

        # forward
        GRU_embeded_input = self.GRUs[0](input[:,:,0].unsqueeze(-1), Variable(torch.zeros(batch_size, self.hidden_dim).unsqueeze(0)).to(device))[0] # b t h
        Attention_embeded_input = self.LastStepAttentions[0](GRU_embeded_input)[0].unsqueeze(1)# b 1 h
        for i in range(feature_dim-1):
            embeded_input = self.GRUs[i+1](input[:,:,i+1].unsqueeze(-1), Variable(torch.zeros(batch_size, self.hidden_dim).unsqueeze(0)).to(device))[0] # b 1 h
            embeded_input = self.LastStepAttentions[i+1](embeded_input)[0].unsqueeze(1)# b 1 h
            Attention_embeded_input = torch.cat((Attention_embeded_input, embeded_input), 1)# b i h

        Attention_embeded_input = torch.cat((Attention_embeded_input, demo_main), 1)# b i+1 h
        posi_input = self.dropout(Attention_embeded_input) # batch_size * d_input+1 * hidden_dim

        #         GRU_embeded_input = self.GRUs[0](input[:,:,0].unsqueeze(-1), Variable(torch.zeros(batch_size, self.hidden_dim).unsqueeze(0)).to(device))[0][:,-1,:].unsqueeze(1) # b 1 h
        #         for i in range(feature_dim-1):
        #             embeded_input = self.GRUs[i+1](input[:,:,i+1].unsqueeze(-1), Variable(torch.zeros(batch_size, self.hidden_dim).unsqueeze(0)).to(device))[0][:,-1,:].unsqueeze(1) # b 1 h
        #             GRU_embeded_input = torch.cat((GRU_embeded_input, embeded_input), 1)

        #         GRU_embeded_input = torch.cat((GRU_embeded_input, demo_main), 1)# b i+1 h
        #         posi_input = self.dropout(GRU_embeded_input) # batch_size * d_input * hidden_dim


        #mask = subsequent_mask(time_step).to(device) # 1 t t 下三角 N to 1任务不用mask
        contexts = self.SublayerConnection(posi_input, lambda x: self.MultiHeadedAttention(posi_input, posi_input, posi_input, None))# # batch_size * d_input * hidden_dim

        DeCov_loss = contexts[1]
        contexts = contexts[0]

        contexts = self.SublayerConnection(contexts, lambda x: self.PositionwiseFeedForward(contexts))[0]# # batch_size * d_input * hidden_dim
        #contexts = contexts.view(batch_size, feature_dim * self.hidden_dim)#
        # contexts = torch.matmul(self.Wproj, contexts) + self.bproj
        # contexts = contexts.squeeze()
        # demo_key = self.demo_proj(demo_input)# b hidden_dim
        # demo_key = self.relu(demo_key)
        # input_dim_scores = torch.matmul(contexts, demo_key.unsqueeze(-1)).squeeze() # b i
        # input_dim_scores = self.dropout(self.sigmoid(input_dim_scores)).unsqueeze(1)# b i

        # weighted_contexts = torch.matmul(input_dim_scores, contexts).squeeze()

        weighted_contexts = self.FinalAttentionQKV(contexts)[0]
        output = self.output1(self.relu(self.output0(weighted_contexts)))# b 1
        output = self.sigmoid(output)

        return output, DeCov_loss
    #, self.MultiHeadedAttention.attn


In [11]:
def get_loss(y_pred, y_true):
    loss = torch.nn.BCELoss()
    return loss(y_pred, y_true)

In [12]:
class Dataset(data.Dataset):
    def __init__(self, x, y, name):
        self.x = x
        self.y = y
        self.name = name

    def __getitem__(self, index):#返回的是tensor
        return self.x[index], self.y[index], self.name[index]

    def __len__(self):
        return len(self.x)

In [13]:
train_dataset = Dataset(train_raw['data'][0], train_raw['data'][1], train_raw['names'])
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_dataset = Dataset(val_raw['data'][0], val_raw['data'][1], val_raw['names'])
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

In [15]:
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED) #numpy
random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED) # cpu
torch.cuda.manual_seed(RANDOM_SEED) #gpu
torch.backends.cudnn.deterministic=True # cudnn

model = ConCare(input_dim = 76, hidden_dim = 64, d_model = 64,  MHD_num_head = 4 , d_ff = 256, output_dim = 1).to(device)
# input_dim, d_model, d_k, d_v, MHD_num_head, d_ff, output_dim
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

max_roc = 0
max_prc = 0
train_loss = []
train_model_loss = []
train_decov_loss = []
valid_loss = []
valid_model_loss = []
valid_decov_loss = []
history = []
np.set_printoptions(threshold=np.inf)
np.set_printoptions(precision=2)
np.set_printoptions(suppress=True)

for each_epoch in range(100):
    batch_loss = []
    model_batch_loss = []
    decov_batch_loss = []

    model.train()

    for step, (batch_x, batch_y, batch_name) in enumerate(train_loader):
        optimizer.zero_grad()
        batch_x = batch_x.float().to(device)
        batch_y = batch_y.float().to(device)

        batch_demo = []
        for i in range(len(batch_name)):
            cur_id, cur_ep, _ = batch_name[i].split('_', 2)
            cur_idx = cur_id + '_' + cur_ep
            cur_demo = torch.tensor(demographic_data[idx_list.index(cur_idx)], dtype=torch.float32)
            batch_demo.append(cur_demo)

        batch_demo = torch.stack(batch_demo).to(device)
        output, decov_loss = model(batch_x, batch_demo)


        model_loss = get_loss(output, batch_y.unsqueeze(-1))
        loss = model_loss + 800* decov_loss

        batch_loss.append(loss.cpu().detach().numpy())
        model_batch_loss.append(model_loss.cpu().detach().numpy())
        decov_batch_loss.append(decov_loss.cpu().detach().numpy())
        loss.backward()
        optimizer.step()

        if step % 30 == 0:
            print('Epoch %d Batch %d: Train Loss = %.4f'%(each_epoch, step, np.mean(np.array(batch_loss))))
            print('Model Loss = %.4f, Decov Loss = %.4f'%(np.mean(np.array(model_batch_loss)), np.mean(np.array(decov_batch_loss))))
    train_loss.append(np.mean(np.array(batch_loss)))
    train_model_loss.append(np.mean(np.array(model_batch_loss)))
    train_decov_loss.append(np.mean(np.array(decov_batch_loss)))

    batch_loss = []
    model_batch_loss = []
    decov_batch_loss = []

    y_true = []
    y_pred = []
    with torch.no_grad():
        model.eval()
        for step, (batch_x, batch_y, batch_name) in enumerate(valid_loader):
            batch_x = batch_x.float().to(device)
            batch_y = batch_y.float().to(device)
            batch_demo = []
            for i in range(len(batch_name)):
                cur_id, cur_ep, _ = batch_name[i].split('_', 2)
                cur_idx = cur_id + '_' + cur_ep
                cur_demo = torch.tensor(demographic_data[idx_list.index(cur_idx)], dtype=torch.float32)
                batch_demo.append(cur_demo)

            batch_demo = torch.stack(batch_demo).to(device)
            output,decov_loss = model(batch_x, batch_demo)

            model_loss = get_loss(output, batch_y.unsqueeze(-1))

            loss = model_loss + 10* decov_loss
            batch_loss.append(loss.cpu().detach().numpy())
            model_batch_loss.append(model_loss.cpu().detach().numpy())
            decov_batch_loss.append(decov_loss.cpu().detach().numpy())
            y_pred += list(output.cpu().detach().numpy().flatten())
            y_true += list(batch_y.cpu().numpy().flatten())

    valid_loss.append(np.mean(np.array(batch_loss)))
    valid_model_loss.append(np.mean(np.array(model_batch_loss)))
    valid_decov_loss.append(np.mean(np.array(decov_batch_loss)))

    print("\n==>Predicting on validation")
    print('Valid Loss = %.4f'%(valid_loss[-1]))
    print('valid_model Loss = %.4f'%(valid_model_loss[-1]))
    print('valid_decov Loss = %.4f'%(valid_decov_loss[-1]))
    y_pred = np.array(y_pred)
    y_pred = np.stack([1 - y_pred, y_pred], axis=1)
    ret = metrics.print_metrics_binary(y_true, y_pred)
    history.append(ret)
    print()

    cur_auroc = ret['auroc']

    if cur_auroc > max_roc:
        max_roc = cur_auroc
        state = {
            'net': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': each_epoch
        }
        torch.save(state, file_name)
        print('\n------------ Save best model ------------\n')



Epoch 0 Batch 0: Train Loss = 0.9473
Model Loss = 0.7090, Decov Loss = 0.0003

==>Predicting on validation
Valid Loss = 0.6479
valid_model Loss = 0.6479
valid_decov Loss = 0.0000
confusion matrix:
[[859   0]
 [141   0]]
accuracy = 0.859000027179718
precision class 0 = 0.859000027179718
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.5325671447089226
AUC of PRC = 0.15720638865426895
min(+P, Se) = 0.19742489270386265
f1_score = nan


------------ Save best model ------------



  prec1 = cf[1][1] / (cf[1][1] + cf[0][1])


Epoch 1 Batch 0: Train Loss = 0.7230
Model Loss = 0.6472, Decov Loss = 0.0001

==>Predicting on validation
Valid Loss = 0.5955
valid_model Loss = 0.5955
valid_decov Loss = 0.0000
confusion matrix:
[[859   0]
 [141   0]]
accuracy = 0.859000027179718
precision class 0 = 0.859000027179718
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.45294297343934475
AUC of PRC = 0.13560424385323075
min(+P, Se) = 0.14393939393939395
f1_score = nan



  prec1 = cf[1][1] / (cf[1][1] + cf[0][1])


Epoch 2 Batch 0: Train Loss = 0.6205
Model Loss = 0.5914, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.5189
valid_model Loss = 0.5189
valid_decov Loss = 0.0000
confusion matrix:
[[859   0]
 [141   0]]
accuracy = 0.859000027179718
precision class 0 = 0.859000027179718
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.44409217381253147
AUC of PRC = 0.1354248010115906
min(+P, Se) = 0.14583333333333334
f1_score = nan



  prec1 = cf[1][1] / (cf[1][1] + cf[0][1])


Epoch 3 Batch 0: Train Loss = 0.5476
Model Loss = 0.5317, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.4315
valid_model Loss = 0.4315
valid_decov Loss = 0.0000
confusion matrix:
[[859   0]
 [141   0]]
accuracy = 0.859000027179718
precision class 0 = 0.859000027179718
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.5365343174894113
AUC of PRC = 0.1833143056728533
min(+P, Se) = 0.23076923076923078
f1_score = nan


------------ Save best model ------------



  prec1 = cf[1][1] / (cf[1][1] + cf[0][1])


Epoch 4 Batch 0: Train Loss = 0.4521
Model Loss = 0.4397, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.4061
valid_model Loss = 0.4061
valid_decov Loss = 0.0000
confusion matrix:
[[859   0]
 [141   0]]
accuracy = 0.859000027179718
precision class 0 = 0.859000027179718
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.6151759839496693
AUC of PRC = 0.23040968356819963
min(+P, Se) = 0.2905405405405405
f1_score = nan


------------ Save best model ------------



  prec1 = cf[1][1] / (cf[1][1] + cf[0][1])


Epoch 5 Batch 0: Train Loss = 0.4034
Model Loss = 0.3934, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.4062
valid_model Loss = 0.4062
valid_decov Loss = 0.0000
confusion matrix:
[[859   0]
 [141   0]]
accuracy = 0.859000027179718
precision class 0 = 0.859000027179718
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.6319776418233307
AUC of PRC = 0.24631389290127276
min(+P, Se) = 0.2789115646258503
f1_score = nan


------------ Save best model ------------



  prec1 = cf[1][1] / (cf[1][1] + cf[0][1])


Epoch 6 Batch 0: Train Loss = 0.4352
Model Loss = 0.4257, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.4093
valid_model Loss = 0.4093
valid_decov Loss = 0.0000
confusion matrix:
[[859   0]
 [141   0]]
accuracy = 0.859000027179718
precision class 0 = 0.859000027179718
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.6506700022292127
AUC of PRC = 0.26060765910272843
min(+P, Se) = 0.3120567375886525
f1_score = nan


------------ Save best model ------------


  prec1 = cf[1][1] / (cf[1][1] + cf[0][1])







Epoch 7 Batch 0: Train Loss = 0.4018
Model Loss = 0.3940, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.4144
valid_model Loss = 0.4144
valid_decov Loss = 0.0000
confusion matrix:
[[859   0]
 [141   0]]
accuracy = 0.859000027179718
precision class 0 = 0.859000027179718
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.6713149877393307
AUC of PRC = 0.2818520127437688
min(+P, Se) = 0.3380281690140845
f1_score = nan


------------ Save best model ------------



  prec1 = cf[1][1] / (cf[1][1] + cf[0][1])


Epoch 8 Batch 0: Train Loss = 0.4087
Model Loss = 0.4010, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.4101
valid_model Loss = 0.4101
valid_decov Loss = 0.0000
confusion matrix:
[[859   0]
 [141   0]]
accuracy = 0.859000027179718
precision class 0 = 0.859000027179718
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.6962326307185496
AUC of PRC = 0.3112755264831296
min(+P, Se) = 0.38028169014084506
f1_score = nan


------------ Save best model ------------



  prec1 = cf[1][1] / (cf[1][1] + cf[0][1])


Epoch 9 Batch 0: Train Loss = 0.4097
Model Loss = 0.4030, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.4060
valid_model Loss = 0.4060
valid_decov Loss = 0.0000
confusion matrix:
[[859   0]
 [141   0]]
accuracy = 0.859000027179718
precision class 0 = 0.859000027179718
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.7158662142190738
AUC of PRC = 0.3359543371864415
min(+P, Se) = 0.41134751773049644
f1_score = nan


------------ Save best model ------------



  prec1 = cf[1][1] / (cf[1][1] + cf[0][1])


Epoch 10 Batch 0: Train Loss = 0.4078
Model Loss = 0.4017, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.4056
valid_model Loss = 0.4056
valid_decov Loss = 0.0000
confusion matrix:
[[859   0]
 [141   0]]
accuracy = 0.859000027179718
precision class 0 = 0.859000027179718
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.7269792518102032
AUC of PRC = 0.35590406261365576
min(+P, Se) = 0.4025974025974026
f1_score = nan


------------ Save best model ------------



  prec1 = cf[1][1] / (cf[1][1] + cf[0][1])


Epoch 11 Batch 0: Train Loss = 0.4268
Model Loss = 0.4212, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.4052
valid_model Loss = 0.4052
valid_decov Loss = 0.0000
confusion matrix:
[[859   0]
 [141   0]]
accuracy = 0.859000027179718
precision class 0 = 0.859000027179718
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.7328082299226379
AUC of PRC = 0.3670721772705472
min(+P, Se) = 0.3945578231292517
f1_score = nan


------------ Save best model ------------


  prec1 = cf[1][1] / (cf[1][1] + cf[0][1])







Epoch 12 Batch 0: Train Loss = 0.4454
Model Loss = 0.4403, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.4063
valid_model Loss = 0.4063
valid_decov Loss = 0.0000
confusion matrix:
[[859   0]
 [141   0]]
accuracy = 0.859000027179718
precision class 0 = 0.859000027179718
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.7371758353354965
AUC of PRC = 0.37676800230614305
min(+P, Se) = 0.3900709219858156
f1_score = nan


------------ Save best model ------------


  prec1 = cf[1][1] / (cf[1][1] + cf[0][1])







Epoch 13 Batch 0: Train Loss = 0.4525
Model Loss = 0.4480, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.4073
valid_model Loss = 0.4073
valid_decov Loss = 0.0000
confusion matrix:
[[859   0]
 [141   0]]
accuracy = 0.859000027179718
precision class 0 = 0.859000027179718
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.7409077023423246
AUC of PRC = 0.3888267359171856
min(+P, Se) = 0.40425531914893614
f1_score = nan


------------ Save best model ------------



  prec1 = cf[1][1] / (cf[1][1] + cf[0][1])


Epoch 14 Batch 0: Train Loss = 0.3917
Model Loss = 0.3876, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.4057
valid_model Loss = 0.4057
valid_decov Loss = 0.0000
confusion matrix:
[[859   0]
 [141   0]]
accuracy = 0.859000027179718
precision class 0 = 0.859000027179718
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.7430130697908668
AUC of PRC = 0.39440824246316214
min(+P, Se) = 0.3971631205673759
f1_score = nan


------------ Save best model ------------



  prec1 = cf[1][1] / (cf[1][1] + cf[0][1])


Epoch 15 Batch 0: Train Loss = 0.4125
Model Loss = 0.4087, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.4051
valid_model Loss = 0.4051
valid_decov Loss = 0.0000
confusion matrix:
[[859   0]
 [141   0]]
accuracy = 0.859000027179718
precision class 0 = 0.859000027179718
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.7427571231598676
AUC of PRC = 0.39359207172215593
min(+P, Se) = 0.3945578231292517
f1_score = nan



  prec1 = cf[1][1] / (cf[1][1] + cf[0][1])


Epoch 16 Batch 0: Train Loss = 0.4006
Model Loss = 0.3971, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.4048
valid_model Loss = 0.4048
valid_decov Loss = 0.0000
confusion matrix:
[[859   0]
 [141   0]]
accuracy = 0.859000027179718
precision class 0 = 0.859000027179718
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.7421255129253048
AUC of PRC = 0.3915784120074775
min(+P, Se) = 0.3971631205673759
f1_score = nan



  prec1 = cf[1][1] / (cf[1][1] + cf[0][1])


Epoch 17 Batch 0: Train Loss = 0.3718
Model Loss = 0.3684, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.4047
valid_model Loss = 0.4047
valid_decov Loss = 0.0000
confusion matrix:
[[859   0]
 [141   0]]
accuracy = 0.859000027179718
precision class 0 = 0.859000027179718
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.7409407277140664
AUC of PRC = 0.3877741672840581
min(+P, Se) = 0.3972602739726027
f1_score = nan



  prec1 = cf[1][1] / (cf[1][1] + cf[0][1])


Epoch 18 Batch 0: Train Loss = 0.3665
Model Loss = 0.3634, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.4050
valid_model Loss = 0.4050
valid_decov Loss = 0.0000
confusion matrix:
[[859   0]
 [141   0]]
accuracy = 0.859000027179718
precision class 0 = 0.859000027179718
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.7391573576400069
AUC of PRC = 0.38083411703974357
min(+P, Se) = 0.40425531914893614
f1_score = nan



  prec1 = cf[1][1] / (cf[1][1] + cf[0][1])


Epoch 19 Batch 0: Train Loss = 0.4033
Model Loss = 0.4002, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.4059
valid_model Loss = 0.4059
valid_decov Loss = 0.0000
confusion matrix:
[[859   0]
 [141   0]]
accuracy = 0.859000027179718
precision class 0 = 0.859000027179718
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.7357392316647265
AUC of PRC = 0.37265444943956394
min(+P, Se) = 0.4041095890410959
f1_score = nan



  prec1 = cf[1][1] / (cf[1][1] + cf[0][1])


Epoch 20 Batch 0: Train Loss = 0.3617
Model Loss = 0.3589, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.4043
valid_model Loss = 0.4043
valid_decov Loss = 0.0000
confusion matrix:
[[859   0]
 [141   0]]
accuracy = 0.859000027179718
precision class 0 = 0.859000027179718
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.732056902715511
AUC of PRC = 0.36246923855130014
min(+P, Se) = 0.3986013986013986
f1_score = nan



  prec1 = cf[1][1] / (cf[1][1] + cf[0][1])


Epoch 21 Batch 0: Train Loss = 0.4099
Model Loss = 0.4073, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.4044
valid_model Loss = 0.4044
valid_decov Loss = 0.0000
confusion matrix:
[[859   0]
 [141   0]]
accuracy = 0.859000027179718
precision class 0 = 0.859000027179718
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.7311156796208688
AUC of PRC = 0.36365953747405444
min(+P, Se) = 0.3931034482758621
f1_score = nan



  prec1 = cf[1][1] / (cf[1][1] + cf[0][1])


Epoch 22 Batch 0: Train Loss = 0.4297
Model Loss = 0.4271, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.4036
valid_model Loss = 0.4036
valid_decov Loss = 0.0000
confusion matrix:
[[859   0]
 [141   0]]
accuracy = 0.859000027179718
precision class 0 = 0.859000027179718
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.725790338427497
AUC of PRC = 0.3607634561381753
min(+P, Se) = 0.3900709219858156
f1_score = nan



  prec1 = cf[1][1] / (cf[1][1] + cf[0][1])


Epoch 23 Batch 0: Train Loss = 0.3973
Model Loss = 0.3947, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.4030
valid_model Loss = 0.4030
valid_decov Loss = 0.0000
confusion matrix:
[[859   0]
 [141   0]]
accuracy = 0.859000027179718
precision class 0 = 0.859000027179718
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.7183637579570504
AUC of PRC = 0.34946221400413335
min(+P, Se) = 0.3829787234042553
f1_score = nan



  prec1 = cf[1][1] / (cf[1][1] + cf[0][1])


Epoch 24 Batch 0: Train Loss = 0.4132
Model Loss = 0.4107, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.4030
valid_model Loss = 0.4030
valid_decov Loss = 0.0000
confusion matrix:
[[859   0]
 [141   0]]
accuracy = 0.859000027179718
precision class 0 = 0.859000027179718
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.7158290606758642
AUC of PRC = 0.3400345233548855
min(+P, Se) = 0.3829787234042553
f1_score = nan



  prec1 = cf[1][1] / (cf[1][1] + cf[0][1])


Epoch 25 Batch 0: Train Loss = 0.4275
Model Loss = 0.4251, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.4024
valid_model Loss = 0.4024
valid_decov Loss = 0.0000
confusion matrix:
[[859   0]
 [141   0]]
accuracy = 0.859000027179718
precision class 0 = 0.859000027179718
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.709731751418027
AUC of PRC = 0.3289783043964702
min(+P, Se) = 0.36879432624113473
f1_score = nan



  prec1 = cf[1][1] / (cf[1][1] + cf[0][1])


Epoch 26 Batch 0: Train Loss = 0.5395
Model Loss = 0.5371, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.4019
valid_model Loss = 0.4019
valid_decov Loss = 0.0000
confusion matrix:
[[859   0]
 [141   0]]
accuracy = 0.859000027179718
precision class 0 = 0.859000027179718
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.7031514460984651
AUC of PRC = 0.317940579301071
min(+P, Se) = 0.35664335664335667
f1_score = nan



  prec1 = cf[1][1] / (cf[1][1] + cf[0][1])


Epoch 27 Batch 0: Train Loss = 0.4127
Model Loss = 0.4104, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.4003
valid_model Loss = 0.4003
valid_decov Loss = 0.0000
confusion matrix:
[[859   0]
 [141   0]]
accuracy = 0.859000027179718
precision class 0 = 0.859000027179718
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.6920796902220131
AUC of PRC = 0.30316246659454443
min(+P, Se) = 0.3561643835616438
f1_score = nan



  prec1 = cf[1][1] / (cf[1][1] + cf[0][1])


Epoch 28 Batch 0: Train Loss = 0.4324
Model Loss = 0.4299, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3984
valid_model Loss = 0.3984
valid_decov Loss = 0.0000
confusion matrix:
[[859   0]
 [141   0]]
accuracy = 0.859000027179718
precision class 0 = 0.859000027179718
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.6861020979367399
AUC of PRC = 0.30873568336711754
min(+P, Se) = 0.3475177304964539
f1_score = nan



  prec1 = cf[1][1] / (cf[1][1] + cf[0][1])


Epoch 29 Batch 0: Train Loss = 0.4069
Model Loss = 0.4044, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3927
valid_model Loss = 0.3927
valid_decov Loss = 0.0000
confusion matrix:
[[859   0]
 [141   0]]
accuracy = 0.859000027179718
precision class 0 = 0.859000027179718
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.6881166456129921
AUC of PRC = 0.31282560265261194
min(+P, Se) = 0.3422818791946309
f1_score = nan



  prec1 = cf[1][1] / (cf[1][1] + cf[0][1])


Epoch 30 Batch 0: Train Loss = 0.3961
Model Loss = 0.3935, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3980
valid_model Loss = 0.3980
valid_decov Loss = 0.0000
confusion matrix:
[[832  27]
 [122  19]]
accuracy = 0.8510000109672546
precision class 0 = 0.8721174001693726
precision class 1 = 0.41304346919059753
recall class 0 = 0.9685680866241455
recall class 1 = 0.13475176692008972
AUC of ROC = 0.6479990752895911
AUC of PRC = 0.28330105604643796
min(+P, Se) = 0.2978723404255319
f1_score = 0.20320854808233624





Epoch 31 Batch 0: Train Loss = 0.4321
Model Loss = 0.4282, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3817
valid_model Loss = 0.3817
valid_decov Loss = 0.0000
confusion matrix:
[[853   6]
 [131  10]]
accuracy = 0.8629999756813049
precision class 0 = 0.8668699264526367
precision class 1 = 0.625
recall class 0 = 0.9930151104927063
recall class 1 = 0.07092198729515076
AUC of ROC = 0.6813216753771085
AUC of PRC = 0.30963368023745486
min(+P, Se) = 0.3191489361702128
f1_score = 0.12738853196323216





Epoch 32 Batch 0: Train Loss = 0.3544
Model Loss = 0.3520, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3839
valid_model Loss = 0.3839
valid_decov Loss = 0.0000
confusion matrix:
[[846  13]
 [130  11]]
accuracy = 0.8569999933242798
precision class 0 = 0.8668032884597778
precision class 1 = 0.4583333432674408
recall class 0 = 0.9848661422729492
recall class 1 = 0.07801418751478195
AUC of ROC = 0.6859287147350952
AUC of PRC = 0.27776684004633057
min(+P, Se) = 0.3120567375886525
f1_score = 0.13333334386348755





Epoch 33 Batch 0: Train Loss = 0.3581
Model Loss = 0.3556, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3837
valid_model Loss = 0.3837
valid_decov Loss = 0.0000
confusion matrix:
[[832  27]
 [117  24]]
accuracy = 0.8560000061988831
precision class 0 = 0.8767123222351074
precision class 1 = 0.47058823704719543
recall class 0 = 0.9685680866241455
recall class 1 = 0.1702127605676651
AUC of ROC = 0.6969013944963217
AUC of PRC = 0.28060802059884904
min(+P, Se) = 0.296875
f1_score = 0.24999998862040246





Epoch 34 Batch 0: Train Loss = 0.3471
Model Loss = 0.3447, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3801
valid_model Loss = 0.3801
valid_decov Loss = 0.0000
confusion matrix:
[[828  31]
 [116  25]]
accuracy = 0.8529999852180481
precision class 0 = 0.8771186470985413
precision class 1 = 0.4464285671710968
recall class 0 = 0.9639115333557129
recall class 1 = 0.1773049682378769
AUC of ROC = 0.7060081407541344
AUC of PRC = 0.30091357201683466
min(+P, Se) = 0.3191489361702128
f1_score = 0.25380711576414555





Epoch 35 Batch 0: Train Loss = 0.4007
Model Loss = 0.3983, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3756
valid_model Loss = 0.3756
valid_decov Loss = 0.0000
confusion matrix:
[[842  17]
 [121  20]]
accuracy = 0.8619999885559082
precision class 0 = 0.8743509650230408
precision class 1 = 0.5405405163764954
recall class 0 = 0.9802095293998718
recall class 1 = 0.1418439745903015
AUC of ROC = 0.7152800138706561
AUC of PRC = 0.3329984403356832
min(+P, Se) = 0.33098591549295775
f1_score = 0.22471910274897133





Epoch 36 Batch 0: Train Loss = 0.4014
Model Loss = 0.3992, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3736
valid_model Loss = 0.3736
valid_decov Loss = 0.0000
confusion matrix:
[[843  16]
 [124  17]]
accuracy = 0.8600000143051147
precision class 0 = 0.8717683553695679
precision class 1 = 0.5151515007019043
recall class 0 = 0.98137366771698
recall class 1 = 0.12056737393140793
AUC of ROC = 0.722471288567442
AUC of PRC = 0.3463234329801723
min(+P, Se) = 0.3617021276595745
f1_score = 0.1954022929533058





Epoch 37 Batch 0: Train Loss = 0.3226
Model Loss = 0.3203, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3671
valid_model Loss = 0.3671
valid_decov Loss = 0.0000
confusion matrix:
[[835  24]
 [117  24]]
accuracy = 0.859000027179718
precision class 0 = 0.8771008253097534
precision class 1 = 0.5
recall class 0 = 0.9720605611801147
recall class 1 = 0.1702127605676651
AUC of ROC = 0.7279122185619102
AUC of PRC = 0.35882662243180097
min(+P, Se) = 0.36551724137931035
f1_score = 0.2539682536153416





Epoch 38 Batch 0: Train Loss = 0.3617
Model Loss = 0.3596, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3722
valid_model Loss = 0.3722
valid_decov Loss = 0.0000
confusion matrix:
[[823  36]
 [109  32]]
accuracy = 0.8550000190734863
precision class 0 = 0.8830472230911255
precision class 1 = 0.47058823704719543
recall class 0 = 0.9580907821655273
recall class 1 = 0.22695034742355347
AUC of ROC = 0.7302487636126453
AUC of PRC = 0.3684724756075986
min(+P, Se) = 0.3618421052631579
f1_score = 0.30622007644006527





Epoch 39 Batch 0: Train Loss = 0.4018
Model Loss = 0.3992, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3673
valid_model Loss = 0.3673
valid_decov Loss = 0.0000
confusion matrix:
[[837  22]
 [122  19]]
accuracy = 0.8560000061988831
precision class 0 = 0.8727841377258301
precision class 1 = 0.46341463923454285
recall class 0 = 0.974388837814331
recall class 1 = 0.13475176692008972
AUC of ROC = 0.7320486463725757
AUC of PRC = 0.3577763216261109
min(+P, Se) = 0.36363636363636365
f1_score = 0.20879120194973846





Epoch 40 Batch 0: Train Loss = 0.3044
Model Loss = 0.3022, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3738
valid_model Loss = 0.3738
valid_decov Loss = 0.0000
confusion matrix:
[[845  14]
 [124  17]]
accuracy = 0.8619999885559082
precision class 0 = 0.8720329999923706
precision class 1 = 0.5483871102333069
recall class 0 = 0.9837020039558411
recall class 1 = 0.12056737393140793
AUC of ROC = 0.7347732395412776
AUC of PRC = 0.358784851826515
min(+P, Se) = 0.3493150684931507
f1_score = 0.1976744146496783





Epoch 41 Batch 0: Train Loss = 0.3402
Model Loss = 0.3378, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3642
valid_model Loss = 0.3642
valid_decov Loss = 0.0000
confusion matrix:
[[833  26]
 [114  27]]
accuracy = 0.8600000143051147
precision class 0 = 0.879619836807251
precision class 1 = 0.5094339847564697
recall class 0 = 0.9697322249412537
recall class 1 = 0.19148936867713928
AUC of ROC = 0.7331632526688628
AUC of PRC = 0.3683633236230595
min(+P, Se) = 0.375886524822695
f1_score = 0.2783505380255018





Epoch 42 Batch 0: Train Loss = 0.3217
Model Loss = 0.3196, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3755
valid_model Loss = 0.3755
valid_decov Loss = 0.0000
confusion matrix:
[[812  47]
 [ 99  42]]
accuracy = 0.8539999723434448
precision class 0 = 0.8913282155990601
precision class 1 = 0.47191011905670166
recall class 0 = 0.9452852010726929
recall class 1 = 0.2978723347187042
AUC of ROC = 0.7247748082464353
AUC of PRC = 0.37363653775123795
min(+P, Se) = 0.3546099290780142
f1_score = 0.3652174031599261





Epoch 43 Batch 0: Train Loss = 0.4034
Model Loss = 0.4013, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3661
valid_model Loss = 0.3661
valid_decov Loss = 0.0000
confusion matrix:
[[842  17]
 [119  22]]
accuracy = 0.8640000224113464
precision class 0 = 0.8761706352233887
precision class 1 = 0.5641025900840759
recall class 0 = 0.9802095293998718
recall class 1 = 0.1560283750295639
AUC of ROC = 0.7313963952806745
AUC of PRC = 0.3792233704158958
min(+P, Se) = 0.3732394366197183
f1_score = 0.24444444947772567





Epoch 44 Batch 0: Train Loss = 0.3511
Model Loss = 0.3491, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3679
valid_model Loss = 0.3679
valid_decov Loss = 0.0000
confusion matrix:
[[851   8]
 [126  15]]
accuracy = 0.8659999966621399
precision class 0 = 0.871033787727356
precision class 1 = 0.6521739363670349
recall class 0 = 0.99068683385849
recall class 1 = 0.10638298094272614
AUC of ROC = 0.7398426341036501
AUC of PRC = 0.38144972195175675
min(+P, Se) = 0.36879432624113473
f1_score = 0.1829268370601551





Epoch 45 Batch 0: Train Loss = 0.3541
Model Loss = 0.3519, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3599
valid_model Loss = 0.3599
valid_decov Loss = 0.0000
confusion matrix:
[[847  12]
 [122  19]]
accuracy = 0.8659999966621399
precision class 0 = 0.8740969896316528
precision class 1 = 0.6129032373428345
recall class 0 = 0.9860302805900574
recall class 1 = 0.13475176692008972
AUC of ROC = 0.7434671686523171
AUC of PRC = 0.38769898547198317
min(+P, Se) = 0.36879432624113473
f1_score = 0.22093021626276535


------------ Save best model ------------





Epoch 46 Batch 0: Train Loss = 0.3722
Model Loss = 0.3701, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3622
valid_model Loss = 0.3622
valid_decov Loss = 0.0000
confusion matrix:
[[829  30]
 [111  30]]
accuracy = 0.859000027179718
precision class 0 = 0.8819149136543274
precision class 1 = 0.5
recall class 0 = 0.965075671672821
recall class 1 = 0.21276596188545227
AUC of ROC = 0.7422699989266754
AUC of PRC = 0.3918299311244321
min(+P, Se) = 0.38666666666666666
f1_score = 0.2985074795362657





Epoch 47 Batch 0: Train Loss = 0.3117
Model Loss = 0.3089, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3526
valid_model Loss = 0.3526
valid_decov Loss = 0.0000
confusion matrix:
[[844  15]
 [119  22]]
accuracy = 0.8659999966621399
precision class 0 = 0.8764278292655945
precision class 1 = 0.5945945978164673
recall class 0 = 0.9825378060340881
recall class 1 = 0.1560283750295639
AUC of ROC = 0.7477687233216919
AUC of PRC = 0.39971616891558875
min(+P, Se) = 0.4125874125874126
f1_score = 0.24719101443215155


------------ Save best model ------------





Epoch 48 Batch 0: Train Loss = 0.3610
Model Loss = 0.3590, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3563
valid_model Loss = 0.3563
valid_decov Loss = 0.0000
confusion matrix:
[[859   0]
 [141   0]]
accuracy = 0.859000027179718
precision class 0 = 0.859000027179718
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.7500722430006853
AUC of PRC = 0.39956566663755017
min(+P, Se) = 0.4326241134751773
f1_score = nan


------------ Save best model ------------



  prec1 = cf[1][1] / (cf[1][1] + cf[0][1])


Epoch 49 Batch 0: Train Loss = 0.3626
Model Loss = 0.3605, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3541
valid_model Loss = 0.3541
valid_decov Loss = 0.0000
confusion matrix:
[[859   0]
 [141   0]]
accuracy = 0.859000027179718
precision class 0 = 0.859000027179718
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.7500144486001371
AUC of PRC = 0.3908995559953438
min(+P, Se) = 0.3900709219858156
f1_score = nan



  prec1 = cf[1][1] / (cf[1][1] + cf[0][1])


Epoch 50 Batch 0: Train Loss = 0.3031
Model Loss = 0.3012, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3544
valid_model Loss = 0.3544
valid_decov Loss = 0.0000
confusion matrix:
[[847  12]
 [126  15]]
accuracy = 0.8619999885559082
precision class 0 = 0.8705036044120789
precision class 1 = 0.5555555820465088
recall class 0 = 0.9860302805900574
recall class 1 = 0.10638298094272614
AUC of ROC = 0.7524913514807752
AUC of PRC = 0.388874580769702
min(+P, Se) = 0.375886524822695
f1_score = 0.17857143708637782


------------ Save best model ------------





Epoch 51 Batch 0: Train Loss = 0.3505
Model Loss = 0.3487, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3565
valid_model Loss = 0.3565
valid_decov Loss = 0.0000
confusion matrix:
[[834  25]
 [115  26]]
accuracy = 0.8600000143051147
precision class 0 = 0.8788198232650757
precision class 1 = 0.5098039507865906
recall class 0 = 0.9708963632583618
recall class 1 = 0.1843971610069275
AUC of ROC = 0.7525408895383879
AUC of PRC = 0.38773663329221153
min(+P, Se) = 0.375886524822695
f1_score = 0.27083333517657565


------------ Save best model ------------





Epoch 52 Batch 0: Train Loss = 0.2942
Model Loss = 0.2921, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3591
valid_model Loss = 0.3591
valid_decov Loss = 0.0000
confusion matrix:
[[824  35]
 [111  30]]
accuracy = 0.8539999723434448
precision class 0 = 0.8812834024429321
precision class 1 = 0.4615384638309479
recall class 0 = 0.9592549204826355
recall class 1 = 0.21276596188545227
AUC of ROC = 0.7528298615411291
AUC of PRC = 0.385236640182836
min(+P, Se) = 0.38461538461538464
f1_score = 0.2912621405377713


------------ Save best model ------------





Epoch 53 Batch 0: Train Loss = 0.3182
Model Loss = 0.3162, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3567
valid_model Loss = 0.3567
valid_decov Loss = 0.0000
confusion matrix:
[[828  31]
 [113  28]]
accuracy = 0.8560000061988831
precision class 0 = 0.8799149990081787
precision class 1 = 0.47457626461982727
recall class 0 = 0.9639115333557129
recall class 1 = 0.19858156144618988
AUC of ROC = 0.7563553199745705
AUC of PRC = 0.3880484238671806
min(+P, Se) = 0.3829787234042553
f1_score = 0.2800000062108042


------------ Save best model ------------





Epoch 54 Batch 0: Train Loss = 0.3446
Model Loss = 0.3427, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3535
valid_model Loss = 0.3535
valid_decov Loss = 0.0000
confusion matrix:
[[859   0]
 [141   0]]
accuracy = 0.859000027179718
precision class 0 = 0.859000027179718
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.7616806611679423
AUC of PRC = 0.38974558707181167
min(+P, Se) = 0.39072847682119205
f1_score = nan


------------ Save best model ------------



  prec1 = cf[1][1] / (cf[1][1] + cf[0][1])


Epoch 55 Batch 0: Train Loss = 0.3653
Model Loss = 0.3635, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3544
valid_model Loss = 0.3544
valid_decov Loss = 0.0000
confusion matrix:
[[859   0]
 [141   0]]
accuracy = 0.859000027179718
precision class 0 = 0.859000027179718
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.765734525549253
AUC of PRC = 0.3894573847247702
min(+P, Se) = 0.38028169014084506
f1_score = nan


------------ Save best model ------------



  prec1 = cf[1][1] / (cf[1][1] + cf[0][1])


Epoch 56 Batch 0: Train Loss = 0.3618
Model Loss = 0.3600, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3493
valid_model Loss = 0.3493
valid_decov Loss = 0.0000
confusion matrix:
[[859   0]
 [141   0]]
accuracy = 0.859000027179718
precision class 0 = 0.859000027179718
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.7704406410224655
AUC of PRC = 0.39325867659040065
min(+P, Se) = 0.3972602739726027
f1_score = nan


------------ Save best model ------------



  prec1 = cf[1][1] / (cf[1][1] + cf[0][1])


Epoch 57 Batch 0: Train Loss = 0.3488
Model Loss = 0.3468, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3508
valid_model Loss = 0.3508
valid_decov Loss = 0.0000
confusion matrix:
[[859   0]
 [141   0]]
accuracy = 0.859000027179718
precision class 0 = 0.859000027179718
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.7655776550334794
AUC of PRC = 0.3964715986539719
min(+P, Se) = 0.4097222222222222
f1_score = nan



  prec1 = cf[1][1] / (cf[1][1] + cf[0][1])


Epoch 58 Batch 0: Train Loss = 0.3501
Model Loss = 0.3482, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3504
valid_model Loss = 0.3504
valid_decov Loss = 0.0000
confusion matrix:
[[844  15]
 [121  20]]
accuracy = 0.8640000224113464
precision class 0 = 0.8746113777160645
precision class 1 = 0.5714285969734192
recall class 0 = 0.9825378060340881
recall class 1 = 0.1418439745903015
AUC of ROC = 0.7656271930910922
AUC of PRC = 0.4019584407036672
min(+P, Se) = 0.4225352112676056
f1_score = 0.22727273309156912





Epoch 59 Batch 0: Train Loss = 0.3089
Model Loss = 0.3072, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3464
valid_model Loss = 0.3464
valid_decov Loss = 0.0000
confusion matrix:
[[832  27]
 [114  27]]
accuracy = 0.859000027179718
precision class 0 = 0.8794925808906555
precision class 1 = 0.5
recall class 0 = 0.9685680866241455
recall class 1 = 0.19148936867713928
AUC of ROC = 0.7697883899305641
AUC of PRC = 0.40512636331680096
min(+P, Se) = 0.42953020134228187
f1_score = 0.2769230961517476





Epoch 60 Batch 0: Train Loss = 0.3332
Model Loss = 0.3315, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3475
valid_model Loss = 0.3475
valid_decov Loss = 0.0000
confusion matrix:
[[837  22]
 [115  26]]
accuracy = 0.8629999756813049
precision class 0 = 0.8792017102241516
precision class 1 = 0.5416666865348816
recall class 0 = 0.974388837814331
recall class 1 = 0.1843971610069275
AUC of ROC = 0.7746431195766148
AUC of PRC = 0.4035920554761965
min(+P, Se) = 0.41843971631205673
f1_score = 0.27513227534252066


------------ Save best model ------------





Epoch 61 Batch 0: Train Loss = 0.3589
Model Loss = 0.3571, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3459
valid_model Loss = 0.3459
valid_decov Loss = 0.0000
confusion matrix:
[[827  32]
 [106  35]]
accuracy = 0.8619999885559082
precision class 0 = 0.8863880038261414
precision class 1 = 0.5223880410194397
recall class 0 = 0.9627473950386047
recall class 1 = 0.24822695553302765
AUC of ROC = 0.7833948430882025
AUC of PRC = 0.3986456036446075
min(+P, Se) = 0.40425531914893614
f1_score = 0.33653846892842354


------------ Save best model ------------





Epoch 62 Batch 0: Train Loss = 0.3427
Model Loss = 0.3408, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3469
valid_model Loss = 0.3469
valid_decov Loss = 0.0000
confusion matrix:
[[824  35]
 [106  35]]
accuracy = 0.859000027179718
precision class 0 = 0.8860214948654175
precision class 1 = 0.5
recall class 0 = 0.9592549204826355
recall class 1 = 0.24822695553302765
AUC of ROC = 0.784748883329618
AUC of PRC = 0.398800971003945
min(+P, Se) = 0.3916083916083916
f1_score = 0.3317535657342192


------------ Save best model ------------





Epoch 63 Batch 0: Train Loss = 0.3178
Model Loss = 0.3160, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3477
valid_model Loss = 0.3477
valid_decov Loss = 0.0000
confusion matrix:
[[831  28]
 [110  31]]
accuracy = 0.8619999885559082
precision class 0 = 0.88310307264328
precision class 1 = 0.5254237055778503
recall class 0 = 0.9674039483070374
recall class 1 = 0.21985815465450287
AUC of ROC = 0.7876881414146418
AUC of PRC = 0.4002676687123199
min(+P, Se) = 0.4012345679012346
f1_score = 0.3099999883919954


------------ Save best model ------------





Epoch 64 Batch 0: Train Loss = 0.3300
Model Loss = 0.3282, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3434
valid_model Loss = 0.3434
valid_decov Loss = 0.0000
confusion matrix:
[[826  33]
 [108  33]]
accuracy = 0.859000027179718
precision class 0 = 0.8843683004379272
precision class 1 = 0.5
recall class 0 = 0.9615832567214966
recall class 1 = 0.23404255509376526
AUC of ROC = 0.7863506138590974
AUC of PRC = 0.39600938128685026
min(+P, Se) = 0.4011627906976744
f1_score = 0.31884059442038654





Epoch 65 Batch 0: Train Loss = 0.3289
Model Loss = 0.3268, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3471
valid_model Loss = 0.3471
valid_decov Loss = 0.0000
confusion matrix:
[[843  16]
 [120  21]]
accuracy = 0.8640000224113464
precision class 0 = 0.8753893971443176
precision class 1 = 0.5675675868988037
recall class 0 = 0.98137366771698
recall class 1 = 0.1489361673593521
AUC of ROC = 0.7776979664627349
AUC of PRC = 0.3948237282473789
min(+P, Se) = 0.4097222222222222
f1_score = 0.23595505917656212





Epoch 66 Batch 0: Train Loss = 0.3203
Model Loss = 0.3186, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3508
valid_model Loss = 0.3508
valid_decov Loss = 0.0000
confusion matrix:
[[859   0]
 [141   0]]
accuracy = 0.859000027179718
precision class 0 = 0.859000027179718
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.775848545645192
AUC of PRC = 0.3954490843392007
min(+P, Se) = 0.3972602739726027
f1_score = nan



  prec1 = cf[1][1] / (cf[1][1] + cf[0][1])


Epoch 67 Batch 0: Train Loss = 0.3387
Model Loss = 0.3370, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3494
valid_model Loss = 0.3494
valid_decov Loss = 0.0000
confusion matrix:
[[859   0]
 [141   0]]
accuracy = 0.859000027179718
precision class 0 = 0.859000027179718
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.7704984354230137
AUC of PRC = 0.39051676323756535
min(+P, Se) = 0.3829787234042553
f1_score = nan



  prec1 = cf[1][1] / (cf[1][1] + cf[0][1])


Epoch 68 Batch 0: Train Loss = 0.2848
Model Loss = 0.2831, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3523
valid_model Loss = 0.3523
valid_decov Loss = 0.0000
confusion matrix:
[[859   0]
 [141   0]]
accuracy = 0.859000027179718
precision class 0 = 0.859000027179718
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.768492144089697
AUC of PRC = 0.3885168512274815
min(+P, Se) = 0.39864864864864863
f1_score = nan



  prec1 = cf[1][1] / (cf[1][1] + cf[0][1])


Epoch 69 Batch 0: Train Loss = 0.3063
Model Loss = 0.3045, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3497
valid_model Loss = 0.3497
valid_decov Loss = 0.0000
confusion matrix:
[[837  22]
 [117  24]]
accuracy = 0.8610000014305115
precision class 0 = 0.8773584961891174
precision class 1 = 0.52173912525177
recall class 0 = 0.974388837814331
recall class 1 = 0.1702127605676651
AUC of ROC = 0.7760632105615137
AUC of PRC = 0.39472100938762067
min(+P, Se) = 0.40853658536585363
f1_score = 0.25668449075051775





Epoch 70 Batch 0: Train Loss = 0.3529
Model Loss = 0.3511, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3405
valid_model Loss = 0.3405
valid_decov Loss = 0.0000
confusion matrix:
[[825  34]
 [104  37]]
accuracy = 0.8619999885559082
precision class 0 = 0.8880516886711121
precision class 1 = 0.5211267471313477
recall class 0 = 0.9604191184043884
recall class 1 = 0.26241135597229004
AUC of ROC = 0.7849718045888754
AUC of PRC = 0.4070365127637307
min(+P, Se) = 0.4225352112676056
f1_score = 0.34905660824021967





Epoch 71 Batch 0: Train Loss = 0.4068
Model Loss = 0.4050, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3362
valid_model Loss = 0.3362
valid_decov Loss = 0.0000
confusion matrix:
[[821  38]
 [100  41]]
accuracy = 0.8619999885559082
precision class 0 = 0.8914223909378052
precision class 1 = 0.5189873576164246
recall class 0 = 0.955762505531311
recall class 1 = 0.29078012704849243
AUC of ROC = 0.7918410819111783
AUC of PRC = 0.4134253014092765
min(+P, Se) = 0.4258064516129032
f1_score = 0.3727272646584781


------------ Save best model ------------





Epoch 72 Batch 0: Train Loss = 0.2993
Model Loss = 0.2974, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3438
valid_model Loss = 0.3438
valid_decov Loss = 0.0000
confusion matrix:
[[825  34]
 [108  33]]
accuracy = 0.8579999804496765
precision class 0 = 0.8842443823814392
precision class 1 = 0.49253731966018677
recall class 0 = 0.9604191184043884
recall class 1 = 0.23404255509376526
AUC of ROC = 0.7924520512884023
AUC of PRC = 0.4095857172892909
min(+P, Se) = 0.425531914893617
f1_score = 0.31730768233317297


------------ Save best model ------------





Epoch 73 Batch 0: Train Loss = 0.2768
Model Loss = 0.2749, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3376
valid_model Loss = 0.3376
valid_decov Loss = 0.0000
confusion matrix:
[[820  39]
 [105  36]]
accuracy = 0.8560000061988831
precision class 0 = 0.8864864706993103
precision class 1 = 0.47999998927116394
recall class 0 = 0.9545983672142029
recall class 1 = 0.25531914830207825
AUC of ROC = 0.7909328841882776
AUC of PRC = 0.4078915510520219
min(+P, Se) = 0.41843971631205673
f1_score = 0.3333333302059291





Epoch 74 Batch 0: Train Loss = 0.3495
Model Loss = 0.3473, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3427
valid_model Loss = 0.3427
valid_decov Loss = 0.0000
confusion matrix:
[[833  26]
 [111  30]]
accuracy = 0.8629999756813049
precision class 0 = 0.882415235042572
precision class 1 = 0.5357142686843872
recall class 0 = 0.9697322249412537
recall class 1 = 0.21276596188545227
AUC of ROC = 0.7922951807726285
AUC of PRC = 0.4078663965165144
min(+P, Se) = 0.4084507042253521
f1_score = 0.30456854184122567





Epoch 75 Batch 0: Train Loss = 0.3422
Model Loss = 0.3402, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3431
valid_model Loss = 0.3431
valid_decov Loss = 0.0000
confusion matrix:
[[859   0]
 [141   0]]
accuracy = 0.859000027179718
precision class 0 = 0.859000027179718
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.7949372105119759
AUC of PRC = 0.41030726370981113
min(+P, Se) = 0.41721854304635764
f1_score = nan


------------ Save best model ------------


  prec1 = cf[1][1] / (cf[1][1] + cf[0][1])







Epoch 76 Batch 0: Train Loss = 0.2678
Model Loss = 0.2661, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3340
valid_model Loss = 0.3340
valid_decov Loss = 0.0000
confusion matrix:
[[855   4]
 [132   9]]
accuracy = 0.8640000224113464
precision class 0 = 0.8662614226341248
precision class 1 = 0.692307710647583
recall class 0 = 0.9953434467315674
recall class 1 = 0.06382978707551956
AUC of ROC = 0.8001056811895739
AUC of PRC = 0.41674056451513447
min(+P, Se) = 0.4225352112676056
f1_score = 0.11688311803042334


------------ Save best model ------------





Epoch 77 Batch 0: Train Loss = 0.3025
Model Loss = 0.3006, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3354
valid_model Loss = 0.3354
valid_decov Loss = 0.0000
confusion matrix:
[[846  13]
 [121  20]]
accuracy = 0.8659999966621399
precision class 0 = 0.8748707175254822
precision class 1 = 0.6060606241226196
recall class 0 = 0.9848661422729492
recall class 1 = 0.1418439745903015
AUC of ROC = 0.8032018097903715
AUC of PRC = 0.4187225582570351
min(+P, Se) = 0.4276729559748428
f1_score = 0.22988506265684483


------------ Save best model ------------





Epoch 78 Batch 0: Train Loss = 0.3268
Model Loss = 0.3249, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3340
valid_model Loss = 0.3340
valid_decov Loss = 0.0000
confusion matrix:
[[832  27]
 [112  29]]
accuracy = 0.8610000014305115
precision class 0 = 0.8813559412956238
precision class 1 = 0.5178571343421936
recall class 0 = 0.9685680866241455
recall class 1 = 0.20567375421524048
AUC of ROC = 0.7976865727094841
AUC of PRC = 0.4167952136940998
min(+P, Se) = 0.42045454545454547
f1_score = 0.2944162375145103





Epoch 79 Batch 0: Train Loss = 0.4011
Model Loss = 0.3992, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3355
valid_model Loss = 0.3355
valid_decov Loss = 0.0000
confusion matrix:
[[824  35]
 [103  38]]
accuracy = 0.8619999885559082
precision class 0 = 0.8888888955116272
precision class 1 = 0.5205479264259338
recall class 0 = 0.9592549204826355
recall class 1 = 0.26950353384017944
AUC of ROC = 0.7935584012417539
AUC of PRC = 0.4150582555877373
min(+P, Se) = 0.4154929577464789
f1_score = 0.35514017190149394





Epoch 80 Batch 0: Train Loss = 0.3268
Model Loss = 0.3249, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3389
valid_model Loss = 0.3389
valid_decov Loss = 0.0000
confusion matrix:
[[818  41]
 [101  40]]
accuracy = 0.8579999804496765
precision class 0 = 0.8900979161262512
precision class 1 = 0.4938271641731262
recall class 0 = 0.9522700905799866
recall class 1 = 0.283687949180603
AUC of ROC = 0.7910237039605676
AUC of PRC = 0.4115259744991822
min(+P, Se) = 0.4295774647887324
f1_score = 0.3603603661147411





Epoch 81 Batch 0: Train Loss = 0.3374
Model Loss = 0.3355, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3396
valid_model Loss = 0.3396
valid_decov Loss = 0.0000
confusion matrix:
[[816  43]
 [100  41]]
accuracy = 0.8569999933242798
precision class 0 = 0.8908296823501587
precision class 1 = 0.488095223903656
recall class 0 = 0.9499418139457703
recall class 1 = 0.29078012704849243
AUC of ROC = 0.7914860591649535
AUC of PRC = 0.40897690835991024
min(+P, Se) = 0.425531914893617
f1_score = 0.3644444288677639





Epoch 82 Batch 0: Train Loss = 0.3302
Model Loss = 0.3281, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3390
valid_model Loss = 0.3390
valid_decov Loss = 0.0000
confusion matrix:
[[817  42]
 [101  40]]
accuracy = 0.8569999933242798
precision class 0 = 0.8899782299995422
precision class 1 = 0.4878048896789551
recall class 0 = 0.9511059522628784
recall class 1 = 0.283687949180603
AUC of ROC = 0.7920144651128229
AUC of PRC = 0.4053371119294605
min(+P, Se) = 0.4236111111111111
f1_score = 0.3587444024959626





Epoch 83 Batch 0: Train Loss = 0.2848
Model Loss = 0.2829, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3382
valid_model Loss = 0.3382
valid_decov Loss = 0.0000
confusion matrix:
[[815  44]
 [104  37]]
accuracy = 0.8519999980926514
precision class 0 = 0.8868334889411926
precision class 1 = 0.45679011940956116
recall class 0 = 0.9487776756286621
recall class 1 = 0.26241135597229004
AUC of ROC = 0.791263137905696
AUC of PRC = 0.4024342901113853
min(+P, Se) = 0.4166666666666667
f1_score = 0.33333335288952065





Epoch 84 Batch 0: Train Loss = 0.2910
Model Loss = 0.2891, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3386
valid_model Loss = 0.3386
valid_decov Loss = 0.0000
confusion matrix:
[[814  45]
 [101  40]]
accuracy = 0.8539999723434448
precision class 0 = 0.8896175026893616
precision class 1 = 0.47058823704719543
recall class 0 = 0.9476134777069092
recall class 1 = 0.283687949180603
AUC of ROC = 0.7905283233844401
AUC of PRC = 0.410387399839898
min(+P, Se) = 0.41843971631205673
f1_score = 0.35398231997440693





Epoch 85 Batch 0: Train Loss = 0.2852
Model Loss = 0.2833, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3415
valid_model Loss = 0.3415
valid_decov Loss = 0.0000
confusion matrix:
[[817  42]
 [101  40]]
accuracy = 0.8569999933242798
precision class 0 = 0.8899782299995422
precision class 1 = 0.4878048896789551
recall class 0 = 0.9511059522628784
recall class 1 = 0.283687949180603
AUC of ROC = 0.7923447188302413
AUC of PRC = 0.41053855470821604
min(+P, Se) = 0.41843971631205673
f1_score = 0.3587444024959626





Epoch 86 Batch 0: Train Loss = 0.3065
Model Loss = 0.3045, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3356
valid_model Loss = 0.3356
valid_decov Loss = 0.0000
confusion matrix:
[[809  50]
 [ 97  44]]
accuracy = 0.8529999852180481
precision class 0 = 0.8929359912872314
precision class 1 = 0.4680851101875305
recall class 0 = 0.9417927861213684
recall class 1 = 0.3120567500591278
AUC of ROC = 0.798495694317159
AUC of PRC = 0.4097186827111009
min(+P, Se) = 0.4326241134751773
f1_score = 0.37446810960769705





Epoch 87 Batch 0: Train Loss = 0.2873
Model Loss = 0.2852, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3347
valid_model Loss = 0.3347
valid_decov Loss = 0.0000
confusion matrix:
[[814  45]
 [100  41]]
accuracy = 0.8550000190734863
precision class 0 = 0.8905907869338989
precision class 1 = 0.4767441749572754
recall class 0 = 0.9476134777069092
recall class 1 = 0.29078012704849243
AUC of ROC = 0.802566071384341
AUC of PRC = 0.40855436865595335
min(+P, Se) = 0.4326241134751773
f1_score = 0.36123346557608665





Epoch 88 Batch 0: Train Loss = 0.3147
Model Loss = 0.3126, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3382
valid_model Loss = 0.3382
valid_decov Loss = 0.0000
confusion matrix:
[[830  29]
 [112  29]]
accuracy = 0.859000027179718
precision class 0 = 0.881104052066803
precision class 1 = 0.5
recall class 0 = 0.9662398099899292
recall class 1 = 0.20567375421524048
AUC of ROC = 0.8000809121607675
AUC of PRC = 0.4004777843340391
min(+P, Se) = 0.4166666666666667
f1_score = 0.29145728176324814





Epoch 89 Batch 0: Train Loss = 0.3619
Model Loss = 0.3598, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3389
valid_model Loss = 0.3389
valid_decov Loss = 0.0000
confusion matrix:
[[829  30]
 [109  32]]
accuracy = 0.8610000014305115
precision class 0 = 0.8837953209877014
precision class 1 = 0.5161290168762207
recall class 0 = 0.965075671672821
recall class 1 = 0.22695034742355347
AUC of ROC = 0.7952592078864588
AUC of PRC = 0.39345483457587666
min(+P, Se) = 0.4161073825503356
f1_score = 0.3152709261569006





Epoch 90 Batch 0: Train Loss = 0.3216
Model Loss = 0.3195, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3384
valid_model Loss = 0.3384
valid_decov Loss = 0.0000
confusion matrix:
[[822  37]
 [106  35]]
accuracy = 0.8569999933242798
precision class 0 = 0.8857758641242981
precision class 1 = 0.4861111044883728
recall class 0 = 0.9569266438484192
recall class 1 = 0.24822695553302765
AUC of ROC = 0.7964976593267777
AUC of PRC = 0.40123145671127114
min(+P, Se) = 0.43356643356643354
f1_score = 0.328638507346269





Epoch 91 Batch 0: Train Loss = 0.2581
Model Loss = 0.2559, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3447
valid_model Loss = 0.3447
valid_decov Loss = 0.0000
confusion matrix:
[[824  35]
 [104  37]]
accuracy = 0.8610000014305115
precision class 0 = 0.8879310488700867
precision class 1 = 0.5138888955116272
recall class 0 = 0.9592549204826355
recall class 1 = 0.26241135597229004
AUC of ROC = 0.7944005482211708
AUC of PRC = 0.40922062696189004
min(+P, Se) = 0.4267515923566879
f1_score = 0.34741784929874464





Epoch 92 Batch 0: Train Loss = 0.3282
Model Loss = 0.3257, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3424
valid_model Loss = 0.3424
valid_decov Loss = 0.0000
confusion matrix:
[[825  34]
 [106  35]]
accuracy = 0.8600000143051147
precision class 0 = 0.8861439228057861
precision class 1 = 0.5072463750839233
recall class 0 = 0.9604191184043884
recall class 1 = 0.24822695553302765
AUC of ROC = 0.7928318430634335
AUC of PRC = 0.4115873676243935
min(+P, Se) = 0.4225352112676056
f1_score = 0.33333334420408534





Epoch 93 Batch 0: Train Loss = 0.2852
Model Loss = 0.2832, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3392
valid_model Loss = 0.3392
valid_decov Loss = 0.0000
confusion matrix:
[[823  36]
 [103  38]]
accuracy = 0.8610000014305115
precision class 0 = 0.8887689113616943
precision class 1 = 0.5135135054588318
recall class 0 = 0.9580907821655273
recall class 1 = 0.26950353384017944
AUC of ROC = 0.7947142892527185
AUC of PRC = 0.4163654008485822
min(+P, Se) = 0.4195804195804196
f1_score = 0.353488359639553





Epoch 94 Batch 0: Train Loss = 0.3071
Model Loss = 0.3051, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3334
valid_model Loss = 0.3334
valid_decov Loss = 0.0000
confusion matrix:
[[817  42]
 [100  41]]
accuracy = 0.8579999804496765
precision class 0 = 0.8909487724304199
precision class 1 = 0.4939759075641632
recall class 0 = 0.9511059522628784
recall class 1 = 0.29078012704849243
AUC of ROC = 0.8007579322814753
AUC of PRC = 0.40224227170179466
min(+P, Se) = 0.40425531914893614
f1_score = 0.36607140402917854





Epoch 95 Batch 0: Train Loss = 0.3051
Model Loss = 0.3029, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3311
valid_model Loss = 0.3311
valid_decov Loss = 0.0000
confusion matrix:
[[820  39]
 [101  40]]
accuracy = 0.8600000143051147
precision class 0 = 0.8903365731239319
precision class 1 = 0.5063291192054749
recall class 0 = 0.9545983672142029
recall class 1 = 0.283687949180603
AUC of ROC = 0.8040109313980465
AUC of PRC = 0.3987716825765999
min(+P, Se) = 0.41379310344827586
f1_score = 0.3636363698603693


------------ Save best model ------------





Epoch 96 Batch 0: Train Loss = 0.3119
Model Loss = 0.3099, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3318
valid_model Loss = 0.3318
valid_decov Loss = 0.0000
confusion matrix:
[[812  47]
 [101  40]]
accuracy = 0.8519999980926514
precision class 0 = 0.8893756866455078
precision class 1 = 0.4597701132297516
recall class 0 = 0.9452852010726929
recall class 1 = 0.283687949180603
AUC of ROC = 0.8049356418068181
AUC of PRC = 0.4024180520928805
min(+P, Se) = 0.4429530201342282
f1_score = 0.3508772110757411


------------ Save best model ------------





Epoch 97 Batch 0: Train Loss = 0.2891
Model Loss = 0.2866, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3333
valid_model Loss = 0.3333
valid_decov Loss = 0.0000
confusion matrix:
[[815  44]
 [ 98  43]]
accuracy = 0.8579999804496765
precision class 0 = 0.8926615715026855
precision class 1 = 0.49425286054611206
recall class 0 = 0.9487776756286621
recall class 1 = 0.304964542388916
AUC of ROC = 0.8040274440839175
AUC of PRC = 0.40369318781100705
min(+P, Se) = 0.41843971631205673
f1_score = 0.37719298125221445





Epoch 98 Batch 0: Train Loss = 0.3274
Model Loss = 0.3254, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3364
valid_model Loss = 0.3364
valid_decov Loss = 0.0000
confusion matrix:
[[814  45]
 [ 97  44]]
accuracy = 0.8579999804496765
precision class 0 = 0.8935235738754272
precision class 1 = 0.49438202381134033
recall class 0 = 0.9476134777069092
recall class 1 = 0.3120567500591278
AUC of ROC = 0.8008322393678944
AUC of PRC = 0.3948456559946868
min(+P, Se) = 0.41843971631205673
f1_score = 0.3826086912871762





Epoch 99 Batch 0: Train Loss = 0.2375
Model Loss = 0.2354, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3399
valid_model Loss = 0.3399
valid_decov Loss = 0.0000
confusion matrix:
[[817  42]
 [102  39]]
accuracy = 0.8560000061988831
precision class 0 = 0.8890097737312317
precision class 1 = 0.48148149251937866
recall class 0 = 0.9511059522628784
recall class 1 = 0.27659574151039124
AUC of ROC = 0.800072655817832
AUC of PRC = 0.39406705931981517
min(+P, Se) = 0.4397163120567376
f1_score = 0.3513513379196259



### Run for test

In [None]:
checkpoint = torch.load(file_name)
save_epoch = checkpoint['epoch']
model.load_state_dict(checkpoint['net'])
optimizer.load_state_dict(checkpoint['optimizer'])
model.eval()

test_reader = InHospitalMortalityReader(dataset_dir=os.path.join(data_path, 'test'),
                                        listfile=os.path.join(data_path, 'test_listfile.csv'),
                                        period_length=48.0)
test_raw = utils.load_data(test_reader, discretizer, normalizer, small_part, return_names=True)
test_dataset = Dataset(test_raw['data'][0], test_raw['data'][1], test_raw['names'])
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
batch_loss = []
y_true = []
y_pred = []
with torch.no_grad():
    model.eval()
    for step, (batch_x, batch_y, batch_name) in enumerate(test_loader):
        batch_x = batch_x.float().to(device)
        batch_y = batch_y.float().to(device)
        batch_demo = []
        for i in range(len(batch_name)):
            cur_id, cur_ep, _ = batch_name[i].split('_', 2)
            cur_idx = cur_id + '_' + cur_ep
            cur_demo = torch.tensor(demographic_data[idx_list.index(cur_idx)], dtype=torch.float32)
            batch_demo.append(cur_demo)

        batch_demo = torch.stack(batch_demo).to(device)
        output = model(batch_x, batch_demo)[0]

        loss = get_loss(output, batch_y.unsqueeze(-1))
        batch_loss.append(loss.cpu().detach().numpy())
        y_pred += list(output.cpu().detach().numpy().flatten())
        y_true += list(batch_y.cpu().numpy().flatten())

print("\n==>Predicting on test")
print('Test Loss = %.4f'%(np.mean(np.array(batch_loss))))
y_pred = np.array(y_pred)
y_pred = np.stack([1 - y_pred, y_pred], axis=1)
test_res = metrics.print_metrics_binary(y_true, y_pred)

In [None]:
# Bootstrap
N = len(y_true)
N_idx = np.arange(N)
K = 1000

auroc = []
auprc = []
minpse = []
for i in range(K):
    boot_idx = np.random.choice(N_idx, N, replace=True)
    boot_true = np.array(y_true)[boot_idx]
    boot_pred = y_pred[boot_idx, :]
    test_ret = metrics.print_metrics_binary(boot_true, boot_pred, verbose=0)
    auroc.append(test_ret['auroc'])
    auprc.append(test_ret['auprc'])
    minpse.append(test_ret['minpse'])
    print('%d/%d'%(i+1,K))

print('auroc %.4f(%.4f)'%(np.mean(auroc), np.std(auroc)))
print('auprc %.4f(%.4f)'%(np.mean(auprc), np.std(auprc)))
print('minpse %.4f(%.4f)'%(np.mean(minpse), np.std(minpse)))