In [1]:
import numpy as np
import argparse
import os
import imp
import re
import pickle5 as pickle
import datetime
import random
import math
import logging
import copy
import matplotlib.pyplot as plt
import sklearn
import logging
from sklearn.cluster import KMeans
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import KFold
from sklearn.neighbors import kneighbors_graph

from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence

import torch
from torch import nn
import torch.nn.utils.rnn as rnn_utils
from torch.utils import data
from torch.autograd import Variable
import torch.nn.functional as F
from torch.nn import Parameter

from utils import utils
from utils.readers import InHospitalMortalityReader
from utils.preprocessing import Discretizer, Normalizer
from utils import metrics
from utils import common_utils

  after removing the cwd from sys.path.


In [2]:
# Select the target dataset (only 'PD' is handled by the cells below).
target_dataset = 'PD'

# Use CUDA if available
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
print("available device: {}".format(device))
# When True, each K-fold iteration trains on the held-out split and evaluates on the rest.
reverse = True
model_name = 'distcare_change4'

available device: cuda:0


In [3]:
# Log file name: log_file_<model>_<dataset>[_reverse].log
file_name = 'log_file' + '_' + model_name + '_' + target_dataset
if reverse:
    file_name += '_' + 'reverse'
file_name += '.log'
def get_logger(name, file_name):
    """Return logger `name` writing INFO+ records to stdout and to `file_name`.

    Safe to call repeatedly, e.g. when the notebook cell is re-run. Handlers attached
    by a previous call are closed and removed first, so records are not emitted
    twice and old file handles are not leaked.

    Args:
        name: logger name passed to `logging.getLogger`.
        file_name: path of the log file; it is truncated (mode 'w') on each call.
    Returns:
        The configured `logging.Logger`.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)

    # getLogger returns the same object on every call, so drop handlers left over
    # from an earlier call before adding fresh ones.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    # Raise the root handler's threshold so Jupyter does not echo INFO records a
    # second time through the root logger.
    if logger.root.handlers:
        logger.root.handlers[0].setLevel(logging.WARNING)

    handler_stdout = logging.StreamHandler()
    handler_stdout.setLevel(logging.INFO)
    handler_stdout.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
    logger.addHandler(handler_stdout)

    handler_file = logging.FileHandler(filename=file_name, mode='w', encoding='utf-8')
    handler_file.setLevel(logging.DEBUG)
    handler_file.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    logger.addHandler(handler_file)

    return logger

logger = get_logger(name=__name__, file_name=file_name)

# Smoke-test one message per level ("this is the debug/info/warning content we want
# to see"). The logger's own level is INFO, so the debug message is dropped.
for _log_call, _log_text in ((logger.debug, '这是希望输出的debug内容'),
                             (logger.info, '这是希望输出的info内容'),
                             (logger.warning, '这是希望输出的warning内容')):
    _log_call(_log_text)

2023-08-18 16:46:11,275 - INFO - 这是希望输出的info内容


In [4]:
def get_loss(y_pred, y_true):
    """Mean binary cross-entropy between probabilities `y_pred` and targets `y_true`."""
    return F.binary_cross_entropy(y_pred, y_true, reduction='mean')

In [5]:
def get_re_loss(y_pred, y_true):
    """Mean squared error between `y_pred` and `y_true` (the regression loss)."""
    return F.mse_loss(y_pred, y_true, reduction='mean')

In [6]:
def get_kl_loss(x_pred, x_target):
    """KL-divergence loss of `x_target` relative to the log-probabilities `x_pred`.

    Follows the `nn.KLDivLoss` convention: `x_pred` holds log-probabilities and
    `x_target` holds probabilities. The pointwise terms are averaged over all
    elements (`reduction='mean'`), which is what the deprecated
    `reduce=True, size_average=True` arguments selected. The per-sample KL
    would need `reduction='batchmean'`. 'mean' is kept here so existing loss
    values do not change.
    """
    loss = torch.nn.KLDivLoss(reduction='mean')
    return loss(x_pred, x_target)

In [7]:
def get_wass_dist(x_pred, x_target):
    """Closed-form 2-Wasserstein distance between diagonal Gaussians fitted to two batches.

    Per-column means and (unbiased) variances are taken over dim 0. The result is
    sqrt(||m_pred - m_target||^2 + ||sqrt(v_pred) - sqrt(v_target)||^2).
    """
    mean_gap = x_pred.mean(dim=0) - x_target.mean(dim=0)
    std_gap = x_pred.var(dim=0).pow(0.5) - x_target.var(dim=0).pow(0.5)
    squared_total = mean_gap.pow(2).sum() + std_gap.pow(2).sum()
    return squared_total.pow(0.5)

In [8]:
def pad_sents(sents, pad_token):
    """Right-pad every sequence in `sents` with `pad_token` to the longest length.

    Args:
        sents: non-empty list of sequences (lists/arrays of steps).
        pad_token: value appended for each missing step (may itself be an array).
    Returns:
        np.ndarray stacking the padded sequences.
    """
    longest = max(len(seq) for seq in sents)
    return np.array([
        np.array(list(seq) + [pad_token] * (longest - len(seq)))
        for seq in sents
    ])

In [9]:
def batch_iter(x, y, mask, lens, batch_size, shuffle=False):
    """ Yield mini-batches (batch_x, batch_y, batch_mask_x, batch_lens), longest sequence first.

    @param x (list): input sequences; len(x[i]) is the sort key inside each batch
    @param y, mask, lens (list): per-sample values aligned with x
    @param batch_size (int): batch size; the final batch may be smaller (ceil division)
    @param shuffle (boolean): shuffle the sample order with numpy's global RNG
    """
    order = list(range(len(x)))
    if shuffle:
        np.random.shuffle(order)

    for b in range(math.ceil(len(x) / batch_size)):
        window = order[b * batch_size:(b + 1) * batch_size]
        rows = sorted(((x[k], y[k], mask[k], lens[k]) for k in window),
                      key=lambda row: len(row[0]), reverse=True)
        xs, ys, ms, ls = ([row[col] for row in rows] for col in range(4))
        yield xs, ys, ms, ls

In [10]:
def length_to_mask(length, max_len=None, dtype=None):
    """Build a (B, max_len) mask that is True where position < length[b].

    `length` must be 1-D. When `max_len` is falsy, the largest value in
    `length` is used. When `dtype` is given, the boolean mask is cast to it.
    """
    assert len(length.shape) == 1, 'Length shape should be 1 dimensional.'
    if not max_len:
        max_len = length.max().item()
    positions = torch.arange(max_len, device=length.device, dtype=length.dtype)
    mask = positions.expand(len(length), max_len) < length.unsqueeze(1)
    if dtype is None:
        return mask
    return torch.as_tensor(mask, dtype=dtype, device=length.device)

In [11]:
# Reproducibility: seed every RNG the notebook uses (python, numpy, torch CPU + GPU).
RANDOM_SEED = 43
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed(RANDOM_SEED)
torch.backends.cudnn.deterministic = True

# Default training / model hyper-parameters (some are overridden by later cells).
epochs, batch_size = 150, 256
input_dim, hidden_dim, d_model = 34, 32, 32
MHD_num_head, d_ff, output_dim = 4, 64, 1



In [12]:
if target_dataset == 'PD':
    data_path = './data/PD/'
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(data_path + 'x.pkl', 'rb') as fh:
        all_x = pickle.load(fh)
    with open(data_path + 'y_z.pkl', 'rb') as fh:
        all_time = pickle.load(fh)
    all_x_len = [len(i) for i in all_x]

    # Reorder feature columns: the columns listed in `tar_subset_idx` come first,
    # followed by every remaining column in its original order.
    tar_subset_idx = [0, 2, 3, 4, 5, 7, 8, 9, 12, 16, 17, 19, 20, 56, 57, 58]
    n_features = len(all_x[0][0])  # taken from the data instead of hard-coding 69
    subset_cols = set(tar_subset_idx)
    tar_other_idx = [c for c in range(n_features) if c not in subset_cols]
    column_order = tar_subset_idx + tar_other_idx
    all_x = [np.asarray(patient, dtype=float)[:, column_order].tolist() for patient in all_x]

print(all_x[0])
print(len(all_x[0][0]))
print(len(all_x))
logger.info(all_x[0])
logger.info(len(all_x[0][0]))
logger.info(len(all_x))

2023-08-18 16:46:11,447 - INFO - [[-0.8427648213988651, 0.3744020210323477, 0.6796123704286434, -1.398975413973587, -0.4831419202847951, -0.2120300841305121, 1.5887596625600091, 0.7945789587268225, -0.8612693268611251, -0.4819729243606949, -0.6745224313841819, 0.7137208435891645, -1.447089954740047, -0.7710128163592748, -1.4231815568069368, -0.5851405270139463, -0.5641898854144399, 0.5775106850669863, 0.3939858698913405, -0.2032969502372001, -0.2890718868318484, 0.1700684310067274, -0.2031244129114749, -0.9752387057279804, -0.995631448716658, -0.7346136214669141, 0.2047416529938912, -0.7879162404292406, -0.4658827214597087, -0.0343615044915247, -1.3314821107815475, 0.3379315521074886, -0.3880554131475662, 0.8285543981909917, 2.770245567717861, 0.0776143028335215, -0.1259757336723928, 0.0863841325254607, -0.1474826847594624, -0.3999358135056357, -0.0116277505661367, -0.0886088512738706, -0.1491049589303728, 0.0552791522161626, -0.0357876785866217, -0.211245000207003, -0.1237195765566573

2023-08-18 16:46:11,451 - INFO - 69
2023-08-18 16:46:11,452 - INFO - 325


[[-0.8427648213988651, 0.3744020210323477, 0.6796123704286434, -1.398975413973587, -0.4831419202847951, -0.2120300841305121, 1.5887596625600091, 0.7945789587268225, -0.8612693268611251, -0.4819729243606949, -0.6745224313841819, 0.7137208435891645, -1.447089954740047, -0.7710128163592748, -1.4231815568069368, -0.5851405270139463, -0.5641898854144399, 0.5775106850669863, 0.3939858698913405, -0.2032969502372001, -0.2890718868318484, 0.1700684310067274, -0.2031244129114749, -0.9752387057279804, -0.995631448716658, -0.7346136214669141, 0.2047416529938912, -0.7879162404292406, -0.4658827214597087, -0.0343615044915247, -1.3314821107815475, 0.3379315521074886, -0.3880554131475662, 0.8285543981909917, 2.770245567717861, 0.0776143028335215, -0.1259757336723928, 0.0863841325254607, -0.1474826847594624, -0.3999358135056357, -0.0116277505661367, -0.0886088512738706, -0.1491049589303728, 0.0552791522161626, -0.0357876785866217, -0.211245000207003, -0.1237195765566573, -0.1259757336723928, 0.04217015

In [13]:
# NOTE(review): this cell repeats the previous data-loading cell. Re-running it is
# harmless because it reloads from disk, and one of the two copies can be deleted.
if target_dataset == 'PD':
    data_path = './data/PD/'
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(data_path + 'x.pkl', 'rb') as fh:
        all_x = pickle.load(fh)
    with open(data_path + 'y_z.pkl', 'rb') as fh:
        all_time = pickle.load(fh)
    all_x_len = [len(i) for i in all_x]

    # Reorder feature columns: the columns listed in `tar_subset_idx` come first,
    # followed by every remaining column in its original order.
    tar_subset_idx = [0, 2, 3, 4, 5, 7, 8, 9, 12, 16, 17, 19, 20, 56, 57, 58]
    n_features = len(all_x[0][0])  # taken from the data instead of hard-coding 69
    subset_cols = set(tar_subset_idx)
    tar_other_idx = [c for c in range(n_features) if c not in subset_cols]
    column_order = tar_subset_idx + tar_other_idx
    all_x = [np.asarray(patient, dtype=float)[:, column_order].tolist() for patient in all_x]

print(all_x[0])
print(len(all_x[0][0]))
print(len(all_x))
logger.info(all_x[0])
logger.info(len(all_x[0][0]))
logger.info(len(all_x))

2023-08-18 16:46:11,663 - INFO - [[-0.8427648213988651, 0.3744020210323477, 0.6796123704286434, -1.398975413973587, -0.4831419202847951, -0.2120300841305121, 1.5887596625600091, 0.7945789587268225, -0.8612693268611251, -0.4819729243606949, -0.6745224313841819, 0.7137208435891645, -1.447089954740047, -0.7710128163592748, -1.4231815568069368, -0.5851405270139463, -0.5641898854144399, 0.5775106850669863, 0.3939858698913405, -0.2032969502372001, -0.2890718868318484, 0.1700684310067274, -0.2031244129114749, -0.9752387057279804, -0.995631448716658, -0.7346136214669141, 0.2047416529938912, -0.7879162404292406, -0.4658827214597087, -0.0343615044915247, -1.3314821107815475, 0.3379315521074886, -0.3880554131475662, 0.8285543981909917, 2.770245567717861, 0.0776143028335215, -0.1259757336723928, 0.0863841325254607, -0.1474826847594624, -0.3999358135056357, -0.0116277505661367, -0.0886088512738706, -0.1491049589303728, 0.0552791522161626, -0.0357876785866217, -0.211245000207003, -0.1237195765566573

2023-08-18 16:46:11,669 - INFO - 69
2023-08-18 16:46:11,670 - INFO - 325


[[-0.8427648213988651, 0.3744020210323477, 0.6796123704286434, -1.398975413973587, -0.4831419202847951, -0.2120300841305121, 1.5887596625600091, 0.7945789587268225, -0.8612693268611251, -0.4819729243606949, -0.6745224313841819, 0.7137208435891645, -1.447089954740047, -0.7710128163592748, -1.4231815568069368, -0.5851405270139463, -0.5641898854144399, 0.5775106850669863, 0.3939858698913405, -0.2032969502372001, -0.2890718868318484, 0.1700684310067274, -0.2031244129114749, -0.9752387057279804, -0.995631448716658, -0.7346136214669141, 0.2047416529938912, -0.7879162404292406, -0.4658827214597087, -0.0343615044915247, -1.3314821107815475, 0.3379315521074886, -0.3880554131475662, 0.8285543981909917, 2.770245567717861, 0.0776143028335215, -0.1259757336723928, 0.0863841325254607, -0.1474826847594624, -0.3999358135056357, -0.0116277505661367, -0.0886088512738706, -0.1491049589303728, 0.0552791522161626, -0.0357876785866217, -0.211245000207003, -0.1237195765566573, -0.1259757336723928, 0.04217015

In [14]:
# Aliases consumed by the K-fold loop (same list objects, no copy).
long_x, long_time = all_x, all_time

In [15]:
def get_n2n_data(x, y, x_len):
    """Expand each sequence into all of its prefixes (one sample per time step).

    For sequence i with T steps, emits x[i][:1], ..., x[i][:T], each paired with
    the label y[i][t] at the prefix's last step t and with the prefix length t + 1.
    `x_len` is only checked for matching size; the emitted lengths come from len(x[i]).

    Returns:
        (new_x, new_y, new_x_len) as parallel lists.
    """
    n = len(x)
    assert n == len(y)
    assert n == len(x_len)
    samples = [(seq[:end], labels[end - 1], end)
               for seq, labels in zip(x, y)
               for end in range(1, len(seq) + 1)]
    new_x = [s[0] for s in samples]
    new_y = [s[1] for s in samples]
    new_x_len = [s[2] for s in samples]
    return new_x, new_y, new_x_len

In [16]:
def clones(module, N):
    "Return an nn.ModuleList of N deep copies of `module` (all copies start with the same weights)."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


class distcare_target(nn.Module):
    """Per-feature GRU encoder with a learned feature gate and a linear head.

    Each of the `input_dim` feature channels is encoded by its own single-layer GRU
    (input size 1). The final hidden states are scaled by sigmoid(self.weight),
    projected to one scalar per feature, and `Linear_los` maps those `input_dim`
    scalars to `output_dim` outputs.

    Args:
        input_dim: number of input features (one GRU per feature).
        hidden_dim: GRU hidden size.
        d_model, MHD_num_head, d_ff: stored only. `forward` merely asserts
            `d_model % MHD_num_head == 0`.
        output_dim: size of the prediction.
        keep_prob: 1 - dropout probability.
    """

    def __init__(self, input_dim, hidden_dim, d_model,  MHD_num_head, d_ff, output_dim, keep_prob=0.5):
        super(distcare_target, self).__init__()

        # hyperparameters
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.d_model = d_model
        self.MHD_num_head = MHD_num_head
        self.d_ff = d_ff
        self.output_dim = output_dim
        self.keep_prob = keep_prob

        # Layers are created in the original order so initialisation under a fixed
        # seed is unchanged. Unused layers are kept for state_dict compatibility.
        self.GRUs = clones(nn.GRU(1, self.hidden_dim, batch_first=True), self.input_dim)
        self.output = nn.Linear(self.hidden_dim, self.output_dim)  # NOTE(review): unused in forward
        self.dropout = nn.Dropout(p=1 - self.keep_prob)
        # Created on the CPU; `model.to(device)` moves it with the other parameters
        # (no dependency on a global `device`).
        self.weight = nn.Parameter(torch.normal(0, 1, (1, self.input_dim)))
        self.tanh = nn.Tanh()
        self.Linear = nn.Linear(self.hidden_dim, 1)
        self.Linear_los = nn.Linear(self.input_dim, self.output_dim)
        self.softmax = nn.Softmax(dim=-1)
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()

    def forward(self, input, lens):
        """Run the model on a padded batch.

        Args:
            input: float tensor (batch, time, input_dim). Steps past each sample's
                length are padding and are ignored via packing.
            lens: 1-D tensor of sequence lengths (any device, any order).
        Returns:
            (output, None, None) where output has shape (batch, output_dim).
        """
        lens = lens.to('cpu')
        feature_dim = input.size(2)
        assert feature_dim == self.input_dim
        assert self.d_model % self.MHD_num_head == 0

        # Encode each feature channel separately. h_n is (num_layers=1, batch, hidden).
        # h_n[-1] keeps the batch axis even when batch == 1, which `.squeeze()` did not.
        # enforce_sorted=False lets callers pass unsorted lengths. h_n comes back in input order.
        feature_states = []
        for i, gru in enumerate(self.GRUs):
            packed = pack_padded_sequence(input[:, :, i].unsqueeze(-1), lens,
                                          batch_first=True, enforce_sorted=False)
            _, h_n = gru(packed)
            feature_states.append(h_n[-1])
        GRU_embeded_input = torch.stack(feature_states, dim=1)  # (batch, input_dim, hidden)

        weight_embeded_input = self.sigmoid(self.weight).unsqueeze(-1) * GRU_embeded_input
        posi_input = self.dropout(weight_embeded_input)       # (batch, input_dim, hidden)
        contexts = self.Linear(posi_input).squeeze(-1)        # (batch, input_dim)
        output = self.Linear_los(self.dropout(contexts))      # (batch, output_dim)
        return output, None, None




In [17]:
# Model hyper-parameters for this run (they override the defaults set earlier).
if target_dataset == 'PD':
    input_dim = 69  # features per visit in the PD data

cell = 'GRU'
hidden_dim, d_model, d_ff = 32, 32, 64
MHD_num_head = 4
output_dim = 1

In [18]:
def ckd_batch_iter(x, y, lens, batch_size, shuffle=False):
    """ Yield mini-batches (batch_x, batch_y, batch_lens), longest sequence first.

    @param x (list): input sequences; len(x[i]) is the sort key inside each batch
    @param y, lens (list): per-sample labels and lengths aligned with x
    @param batch_size (int): batch size; the final batch may be smaller (ceil division)
    @param shuffle (boolean): shuffle the sample order with numpy's global RNG
    """
    sample_ids = list(range(len(x)))
    if shuffle:
        np.random.shuffle(sample_ids)

    n_batches = math.ceil(len(x) / batch_size)
    for batch_no in range(n_batches):
        window = sample_ids[batch_no * batch_size:(batch_no + 1) * batch_size]
        rows = [(x[k], y[k], lens[k]) for k in window]
        rows.sort(key=lambda row: len(row[0]), reverse=True)
        # Every window is non-empty (ceil division), so zip(*rows) yields three columns.
        batch_x, batch_y, batch_lens = map(list, zip(*rows))
        yield batch_x, batch_y, batch_lens

In [19]:
class TargetMultitaskLoss(nn.Module):
    """Weighted sum of an MSE (regression) term and a BCE (classification) term.

    `alpha` holds one learnable weight per task, initialised to 1.
    NOTE(review): `alpha` is unconstrained. If it were optimised jointly with a
    model, it could be driven negative to make the loss arbitrarily small.
    """

    def __init__(self, task_num=2):
        super().__init__()
        self.task_num = task_num
        self.alpha = nn.Parameter(torch.ones((task_num)), requires_grad=True)
        self.mse = nn.MSELoss()
        self.bce = nn.BCELoss()

    def forward(self, opt_student, los, outcome, outcome_y):
        regression_term = self.mse(opt_student, los) * self.alpha[0]
        classification_term = self.bce(outcome, outcome_y) * self.alpha[1]
        return regression_term + classification_term

def get_target_multitask_loss(opt_student, los, outcome, outcome_y):
    """MSE(opt_student, los) + BCE(outcome, outcome_y).

    A fresh `TargetMultitaskLoss` (alpha == 1) is built on every call, so the task
    weights are never learned here.
    """
    criterion = TargetMultitaskLoss(task_num=2)
    return criterion(opt_student, los, outcome, outcome_y)

def reverse_los(y, los_info):
    """Undo z-score normalisation of a LOS value: y * los_std + los_mean."""
    mean, std = los_info["los_mean"], los_info["los_std"]
    return y * std + mean

In [20]:
# Normalisation statistics of the LOS target (consumed by `reverse_los`).
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open(data_path + 'los_info.pkl', 'rb') as los_file:
    los_info = pickle.load(los_file)
print(los_info)
logger.info(los_info)

2023-08-18 16:46:11,775 - INFO - {'los_mean': 1055.0307777880782, 'los_std': 799.0879849276147}


{'los_mean': 1055.0307777880782, 'los_std': 799.0879849276147}


In [None]:
# K-fold training / evaluation of `distcare_target` on the LOS regression task.
# With `reverse=True` the KFold roles are swapped: each fold trains on the single
# held-out split and evaluates on the remaining n_splits - 1 splits.
if target_dataset == 'PD':
    n_splits = 5
    epochs = 30
    data_str = 'pd'
else:
    raise ValueError('Unsupported target_dataset: %r' % (target_dataset,))

# NOTE(review): these flags are not read below; checkpoint transfer is not implemented.
teacher_flag = True
transfer_flag = True
kfold = KFold(n_splits=n_splits, shuffle=True, random_state=RANDOM_SEED)

batch_size = 256

fold_count = 0
total_train_loss = []
total_valid_loss = []

global_best = 10000
mse = []
mad = []
history = []

pad_token = np.zeros(input_dim)


def _n2n_split(indices):
    """Gather the patients in `indices` and expand them into per-step prefix samples."""
    return get_n2n_data([long_x[i] for i in indices],
                        [long_time[i] for i in indices],
                        [all_x_len[i] for i in indices])


def _batch_to_tensors(batch_x, batch_y, batch_lens):
    """Pad one batch and convert it to model inputs.

    Returns (x, y, lens). x holds padded float32 features and y float32 targets,
    both on `device`. lens holds int64 lengths on the CPU, where the model needs
    them for packing.
    """
    x = torch.tensor(pad_sents(batch_x, pad_token), dtype=torch.float32).to(device)
    y = torch.tensor(batch_y, dtype=torch.float32).to(device)
    lens = torch.tensor(batch_lens, dtype=torch.int64)
    return x, y, lens


for train, test in kfold.split(long_x):
    if reverse:
        train, test = test, train

    model = distcare_target(input_dim=input_dim, output_dim=output_dim, d_model=d_model,
                            MHD_num_head=MHD_num_head, d_ff=d_ff, hidden_dim=hidden_dim).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    fold_count += 1

    train_x, train_y, train_x_len = _n2n_split(train)
    test_x, test_y, test_x_len = _n2n_split(test)

    # A trailing batch of size 1 can break `.squeeze()`-based models. Check the splits
    # actually used for training and evaluation, with the real batch size.
    for split_name, split_x in (('train', train_x), ('test', test_x)):
        if len(split_x) % batch_size == 1:
            print(len(split_x))
            print('wrong squeeze!')
            logger.warning('Fold %d: %s split has a final batch of size 1 (%d samples)'
                           % (fold_count, split_name, len(split_x)))

    os.makedirs('./model/' + data_str, exist_ok=True)

    fold_train_loss = []
    fold_valid_loss = []
    best_mse = 10000
    best_mad = 0

    for each_epoch in range(epochs):
        # ---- training ----
        epoch_loss = []
        model.train()
        for step, batch in enumerate(ckd_batch_iter(train_x, train_y, train_x_len, batch_size, shuffle=True)):
            optimizer.zero_grad()
            batch_x, batch_y, batch_lens = _batch_to_tensors(*batch)

            opt, _, _ = model(batch_x, batch_lens)
            loss = get_re_loss(opt, batch_y.unsqueeze(-1))

            epoch_loss.append(loss.item())
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 20)
            optimizer.step()

            if step % 50 == 0:
                msg = 'Fold %d Epoch %d Batch %d: Train Loss = %.4f' % (fold_count, each_epoch, step, loss.item())
                print(msg)
                logger.info(msg)

        fold_train_loss.append(np.mean(epoch_loss))

        # ---- validation ----
        y_pred_flatten = []
        y_true_flatten = []
        valid_loss = []
        with torch.no_grad():
            model.eval()
            for batch in ckd_batch_iter(test_x, test_y, test_x_len, batch_size):
                batch_x, batch_y, batch_lens = _batch_to_tensors(*batch)
                opt, _, _ = model(batch_x, batch_lens)
                valid_loss.append(get_re_loss(opt, batch_y.unsqueeze(-1)).item())

                # Undo the LOS normalisation, then divide by 30
                # (presumably days -> months; TODO confirm the unit).
                y_pred_flatten += [reverse_los(v, los_info) / 30 for v in opt.cpu().numpy().flatten()]
                y_true_flatten += [reverse_los(v, los_info) / 30 for v in batch_y.cpu().numpy().flatten()]

        fold_valid_loss.append(np.mean(valid_loss))
        ret = metrics.print_metrics_regression(y_true_flatten, y_pred_flatten, verbose=0)
        history.append(ret)

        if each_epoch % 10 == 0:
            msg = 'Fold %d, epoch %d: Loss = %.4f Valid loss = %.4f MSE = %.4f' % (
                fold_count, each_epoch, fold_train_loss[-1], fold_valid_loss[-1], ret['mse'])
            print(msg, flush=True)
            logger.info(msg)

        cur_mse = ret['mse']
        if cur_mse < best_mse:
            msg = '------------ Save FOLD-BEST model - MSE: %.4f ------------' % cur_mse
            print(msg, flush=True)
            logger.info(msg)
            metrics.print_metrics_regression(y_true_flatten, y_pred_flatten)
            best_mse = cur_mse
            best_mad = ret['mad']
            # NOTE(review): despite the message, no checkpoint is written; call
            # torch.save({'net': model.state_dict(), ...}, path) here to persist it.
            global_best = min(global_best, cur_mse)

        msg = 'Fold %d, mse = %.4f, mad = %.4f' % (fold_count, ret['mse'], ret['mad'])
        print(msg, flush=True)
        logger.info(msg)

    mse.append(best_mse)
    mad.append(best_mad)
    total_train_loss.append(fold_train_loss)
    total_valid_loss.append(fold_valid_loss)


print('mse %.4f(%.4f)' % (np.mean(mse), np.std(mse)))
print('mad %.4f(%.4f)' % (np.mean(mad), np.std(mad)))
logger.info('mse %.4f(%.4f)' % (np.mean(mse), np.std(mse)))
logger.info('mad %.4f(%.4f)' % (np.mean(mad), np.std(mad)))

2023-08-18 16:46:16,710 - INFO - Fold 1 Epoch 0 Batch 0: Train Loss = 1.1376


Fold 1 Epoch 0 Batch 0: Train Loss = 1.1376
Fold 1, epoch 0: Loss = 1.0231 Valid loss = 1.0170 MSE = 710.5267


2023-08-18 16:46:29,927 - INFO - Fold 1, epoch 0: Loss = 1.0231 Valid loss = 1.0170 MSE = 710.5267


------------ Save FOLD-BEST model - MSE: 710.5267 ------------


2023-08-18 16:46:29,930 - INFO - ------------ Save FOLD-BEST model - MSE: 710.5267 ------------


Custom bins confusion matrix:
[[   0  964  131    0]
 [   0 3395  484    0]
 [   0 2034  306    0]
 [   0 1176  222    0]]
Mean absolute deviation (MAD) = 21.626783204536547
Mean squared error (MSE) = 710.5266594785155
Mean absolute percentage error (MAPE) = 242.8918876026172
Cohen kappa score = 0.011756493412058866
Fold 1, mse = 710.5267, mad = 21.6268


2023-08-18 16:46:29,976 - INFO - Fold 1, mse = 710.5267, mad = 21.6268
2023-08-18 16:46:30,533 - INFO - Fold 1 Epoch 1 Batch 0: Train Loss = 1.0159


Fold 1 Epoch 1 Batch 0: Train Loss = 1.0159
------------ Save FOLD-BEST model - MSE: 707.0658 ------------


2023-08-18 16:46:44,278 - INFO - ------------ Save FOLD-BEST model - MSE: 707.0658 ------------


Custom bins confusion matrix:
[[   0  968  127    0]
 [   0 3418  461    0]
 [   0 1991  349    0]
 [   0 1138  260    0]]
Mean absolute deviation (MAD) = 21.566105543661447
Mean squared error (MSE) = 707.065784608237
Mean absolute percentage error (MAPE) = 241.82238311788439
Cohen kappa score = 0.029812156108557808
Fold 1, mse = 707.0658, mad = 21.5661


2023-08-18 16:46:44,321 - INFO - Fold 1, mse = 707.0658, mad = 21.5661
2023-08-18 16:46:44,773 - INFO - Fold 1 Epoch 2 Batch 0: Train Loss = 0.9549


Fold 1 Epoch 2 Batch 0: Train Loss = 0.9549
------------ Save FOLD-BEST model - MSE: 703.2587 ------------


2023-08-18 16:46:59,124 - INFO - ------------ Save FOLD-BEST model - MSE: 703.2587 ------------


Custom bins confusion matrix:
[[   0 1022   73    0]
 [   0 3547  332    0]
 [   0 2048  292    0]
 [   0 1165  233    0]]
Mean absolute deviation (MAD) = 21.45108856267566
Mean squared error (MSE) = 703.2587458518238
Mean absolute percentage error (MAPE) = 238.20384292766937
Cohen kappa score = 0.03959522857038955
Fold 1, mse = 703.2587, mad = 21.4511


2023-08-18 16:46:59,201 - INFO - Fold 1, mse = 703.2587, mad = 21.4511
2023-08-18 16:46:59,723 - INFO - Fold 1 Epoch 3 Batch 0: Train Loss = 1.0568


Fold 1 Epoch 3 Batch 0: Train Loss = 1.0568
------------ Save FOLD-BEST model - MSE: 698.4095 ------------


2023-08-18 16:47:12,833 - INFO - ------------ Save FOLD-BEST model - MSE: 698.4095 ------------


Custom bins confusion matrix:
[[   0 1011   84    0]
 [   0 3472  407    0]
 [   0 1933  407    0]
 [   0 1063  335    0]]
Mean absolute deviation (MAD) = 21.33897924998676
Mean squared error (MSE) = 698.4094561942762
Mean absolute percentage error (MAPE) = 235.30636420154457
Cohen kappa score = 0.06648991927030423
Fold 1, mse = 698.4095, mad = 21.3390


2023-08-18 16:47:12,876 - INFO - Fold 1, mse = 698.4095, mad = 21.3390
2023-08-18 16:47:13,449 - INFO - Fold 1 Epoch 4 Batch 0: Train Loss = 0.8620


Fold 1 Epoch 4 Batch 0: Train Loss = 0.8620
------------ Save FOLD-BEST model - MSE: 689.4669 ------------


2023-08-18 16:47:27,271 - INFO - ------------ Save FOLD-BEST model - MSE: 689.4669 ------------


Custom bins confusion matrix:
[[   0  823  272    0]
 [   0 2541 1338    0]
 [   0 1364  976    0]
 [   0  561  837    0]]
Mean absolute deviation (MAD) = 21.312802320217152
Mean squared error (MSE) = 689.4669245360717
Mean absolute percentage error (MAPE) = 238.92766448170403
Cohen kappa score = 0.10252641616807912
Fold 1, mse = 689.4669, mad = 21.3128


2023-08-18 16:47:27,315 - INFO - Fold 1, mse = 689.4669, mad = 21.3128
2023-08-18 16:47:27,927 - INFO - Fold 1 Epoch 5 Batch 0: Train Loss = 1.0007


Fold 1 Epoch 5 Batch 0: Train Loss = 1.0007
------------ Save FOLD-BEST model - MSE: 678.4155 ------------


2023-08-18 16:47:42,099 - INFO - ------------ Save FOLD-BEST model - MSE: 678.4155 ------------


Custom bins confusion matrix:
[[   0  681  414    0]
 [   0 2077 1802    0]
 [   0  901 1439    0]
 [   0  357 1041    0]]
Mean absolute deviation (MAD) = 21.237870157416214
Mean squared error (MSE) = 678.4154988995704
Mean absolute percentage error (MAPE) = 241.02729830610028
Cohen kappa score = 0.13487095283494777
Fold 1, mse = 678.4155, mad = 21.2379


2023-08-18 16:47:42,151 - INFO - Fold 1, mse = 678.4155, mad = 21.2379
2023-08-18 16:47:42,762 - INFO - Fold 1 Epoch 6 Batch 0: Train Loss = 0.7586


Fold 1 Epoch 6 Batch 0: Train Loss = 0.7586
------------ Save FOLD-BEST model - MSE: 667.8120 ------------


2023-08-18 16:47:55,675 - INFO - ------------ Save FOLD-BEST model - MSE: 667.8120 ------------


Custom bins confusion matrix:
[[   0  787  308    0]
 [   0 2433 1446    0]
 [   0 1213 1127    0]
 [   0  483  915    0]]
Mean absolute deviation (MAD) = 20.931643248572936
Mean squared error (MSE) = 667.8119696514948
Mean absolute percentage error (MAPE) = 231.92779937873672
Cohen kappa score = 0.1220983442712813
Fold 1, mse = 667.8120, mad = 20.9316


2023-08-18 16:47:55,719 - INFO - Fold 1, mse = 667.8120, mad = 20.9316
2023-08-18 16:47:56,267 - INFO - Fold 1 Epoch 7 Batch 0: Train Loss = 0.9418


Fold 1 Epoch 7 Batch 0: Train Loss = 0.9418


In [None]:
# Persist the per-epoch metric history: history_<model>_<dataset>[_reverse].pkl
history_name = 'history' + '_' + model_name + '_' + target_dataset
if reverse:
    history_name += '_' + 'reverse'
history_name += '.pkl'
with open(history_name, 'wb') as history_file:
    pickle.dump(history, history_file)
