In [1]:
import os
import time
import logging
import time
import warnings
from datetime import datetime

import torch
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter

from dataset import DatasetConfig, MIMIC3SetLoader
from model import CohortNet

from util import *

# Silence every warning category for the rest of the session.
# NOTE(review): this also hides deprecation/user warnings from torch and numpy.
warnings.simplefilter("ignore")

# Data split names; iterated in this order when loading the per-fold datasets and
# handed to Recorders.
sets = ['train', 'valid', 'test']


def set_random_seed(seed=2000):
    """Seed every RNG the training code can touch so a run is repeatable.

    Seeds Python's ``random`` module, NumPy's global RNG, torch's CPU generator and
    the generators of *all* visible GPUs, and switches cuDNN to deterministic kernels.

    NOTE(review): ModelTrainer.set_model later sets ``cudnn.benchmark = True``, which
    lets cuDNN pick kernels by timing and can still make GPU runs non-reproducible.

    Args:
        seed: integer seed (default 2000).
    """
    import random  # local import: the notebook's top import cell does not import it

    logging.info("[*]random seed: %d" % seed)
    # The data loaders live in another module and may use Python's RNG (unverified),
    # so seed it as well.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # cpu
    # manual_seed_all covers every visible GPU (DataParallel may use several);
    # torch.cuda.manual_seed only seeds the current device.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True  # cudnn

In [2]:
class ModelTrainer():
    """Train and/or evaluate a model on every cross-validation fold in ``args.folds``.

    Constructing the object does all the work:
    1. it configures the visible GPUs;
    2. it creates a timestamped log directory;
    3. it configures logging;
    4. it seeds the RNGs;
    5. it finally calls :meth:`learning`.
    """

    def __init__(self, args):
        timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
        args = calculate_warm(args)
        self.args = args

        # NOTE: CUDA_VISIBLE_DEVICES only takes effect if CUDA has not been initialised
        # yet in this process (e.g. it is ignored on a second run in the same kernel).
        os.environ['CUDA_VISIBLE_DEVICES'] = str(self.args.gpu)

        self.set_config()

        # The log root may be supplied as ``args.log_root``. Otherwise the placeholder
        # below is used, and it must be edited to point at an existing directory.
        log_root = getattr(self.args, 'log_root', None) or "Please update the url for log file"
        if not os.path.exists(log_root):
            raise FileNotFoundError("[x] Log path does not exist.")
        self.log_root = log_root
        self.log_path = self.set_log_path(timestamp)

        # Drop handlers left on the root logger by a previous run in the same kernel.
        # Otherwise basicConfig() is a no-op and this run's log file never gets records
        # (equivalent to basicConfig(force=True) on Python >= 3.8).
        root_logger = logging.getLogger()
        for old_handler in list(root_logger.handlers):
            root_logger.removeHandler(old_handler)
            old_handler.close()
        handlers = [logging.FileHandler(self.log_path + 'log_{}.txt'.format(timestamp), mode='w'),
                    logging.StreamHandler()]
        logging.basicConfig(level=logging.INFO, datefmt='%m-%d-%y %H:%M', format='%(asctime)s:%(message)s',
                            handlers=handlers)
        logging.info("================== Start %s =================="%timestamp)

        # Fixed seed unless a random one is requested via args.random.
        if self.args.random:
            random_seed = np.random.randint(int(1e7))
        else:
            random_seed = 2000
        set_random_seed(random_seed)

        logging.info('Timestamp: {}'.format(timestamp))
        logging.info('Arguments')
        for k, v in sorted(vars(self.args).items()):
            logging.info("%s = %s" % (k, str(v)))

        self.learning()

    def set_config(self):
        """Set the metric names tracked by Recorders and the run-mode tag used in paths/tags."""
        # for binary-label classification
        self.metrics = ['bceloss', 'auroc', 'auprc', 'accu', 'f1', 'minpse']
        self.train_mode = "SP"

    def learning(self):
        """For each fold: build datasets, model and recorders, then train or evaluate.

        Raises:
            NotImplementedError: if ``args.mode`` is neither "train" nor "eval".
        """
        for i_fold in self.args.folds:
            logging.info('============= {}-th fold ============='.format(i_fold))
            self.dataset = {name: self.set_dataset(name, i_fold) for name in sets}
            self.input_dim = self.dataset['train'].input_dim
            self.output_dim = self.dataset['train'].output_dim
            self.set_model(self.args)
            # self.args (not a notebook global) so the class works outside this notebook.
            self.recorders = Recorders(sets, self.metrics, self.args.patience)

            if self.args.mode == "train":
                self._train_fold(i_fold)
            elif self.args.mode == "eval":
                self._eval_fold()
            else:
                raise NotImplementedError("[x] No such mode: %s" % self.args.mode)

    def _train_fold(self, i_fold):
        """Run the epoch loop for one fold, then evaluate the best saved checkpoint."""
        self.criterion = self.set_criterion(type=self.args.criterion)
        self.writer = SummaryWriter(self.log_path)
        try:
            for epoch in range(1, self.args.epochs + 1):
                adjust_learning_rate(self.args, self.optimizer, epoch)

                time1 = time.time()
                loss = self.train(epoch)
                time2 = time.time()
                save_flag, results = self.validate(epoch)
                time3 = time.time()
                logging.info('Epoch {}, lr {:.6f}, train loss {:.4f}, '
                             'train time {:.2f}s, valid time {:.2f}s, total time {:.2f}s.'.format(
                    epoch,
                    self.optimizer.param_groups[0]['lr'],
                    loss,
                    time2 - time1,
                    time3 - time2,
                    time3 - time1
                ))
                logging.info(self.recorders.to_string())
                # save results: one scalar per (split, metric) pair
                for s_idx, subset in enumerate(self.recorders.sets):
                    for m_idx, metric in enumerate(self.recorders.metrics):
                        self.writer.add_scalar(tag="%s/%s/%s" % (self.train_mode, subset, metric),
                                               scalar_value=results[s_idx, m_idx], global_step=epoch)
                if save_flag == 1:
                    self.save_results(epoch, i_fold, results)
                elif save_flag == 0:
                    logging.info("[*] Overfitting... Stop!")
                    break
        finally:
            # Flush pending events and release the event file even if training fails.
            self.writer.close()

        ckpt_path = os.path.join(self.log_path, 'ckpt_{i}.pth'.format(i=i_fold))
        if os.path.exists(ckpt_path):
            self.eval_model(ckpt_path)
        else:
            logging.info("[x] No checkpoint was saved for fold %s; skipping final eval." % i_fold)

    def _eval_fold(self):
        """Evaluate the checkpoint given by ``args.model_path`` on every split.

        Raises:
            FileNotFoundError: if ``args.model_path`` does not exist.
        """
        model_path = self.args.model_path
        assert model_path != "#", "[x] Please provide a valid model path!"
        if not os.path.exists(model_path):
            logging.info("[x] Model path is invalid: %s" % model_path)
            raise FileNotFoundError("[x] Model path is invalid: %s" % model_path)
        self.eval_model(model_path)

    @staticmethod
    def _to_device(tensor):
        """Move ``tensor`` to the default CUDA device when one is available."""
        return tensor.cuda() if torch.cuda.is_available() else tensor

    def _forward(self, batch_x):
        """Run the model on one batch from a dataset generator and return its prediction.

        ``batch_x[1]`` and ``batch_x[2]`` are fed to the model as ``[tdata, tmask]``.
        Presumably these are the time-series values and their observation mask.
        TODO confirm against dataset.py.

        Raises:
            NotImplementedError: for models without a forward pass defined here.
        """
        tdata = self._to_device(torch.tensor(batch_x[1]).float())
        tmask = self._to_device(torch.tensor(batch_x[2]).float())
        if self.args.model in ["CohortNet"]:
            out = self.model([tdata, tmask])
            # out[0] is the prediction. out[1] (a representation) is not used here.
            return out[0]
        raise NotImplementedError("[x] No forward pass defined for model: %s" % self.args.model)

    def train(self, epoch):
        """Run one training epoch and return the mean per-batch loss.

        Args:
            epoch: 1-based epoch index, used for LR warm-up and progress logging.
        """
        self.model.train()
        losses = []
        batch_time = []
        # Batches per epoch, i.e. ceil(sample_size / batch_size).
        total_steps = (self.dataset['train'].sample_size - 1) // self.args.batch_size + 1
        for batch_id, batch_x, batch_y in self.dataset['train'].get_generator(self.args.batch_size, shuffle=True):
            start = time.time()
            warmup_learning_rate(self.args, epoch, batch_id + 1, total_steps, self.optimizer)

            label = self._to_device(torch.tensor(batch_y).float())
            prediction = self._forward(batch_x)
            loss = self.criterion(prediction, label)
            losses.append(loss.item())

            # optimize
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            batch_time.append(time.time() - start)

            if batch_id % self.args.print_freq == 0:
                logging.info('Train: [{0}][{1}/{2}]\t'
                             'BT avg {batch_time:.3f}\t'
                             'loss avg {loss:.4f}'.format(
                    epoch, batch_id, total_steps, batch_time=np.average(batch_time), loss=np.average(losses))
                )
        return np.average(losses)

    def validate(self, epoch=-1):
        """Score the model on every split and record the results.

        Args:
            epoch: unused; kept for interface compatibility.
        Returns:
            ``(save_flag, epoch_results)``, where ``save_flag`` comes from
            ``Recorders.insert`` and ``epoch_results`` from ``Recorders.get_epoch_result``.
        """
        self.model.eval()
        results = {}
        with torch.no_grad():
            for subset in self.recorders.sets:
                sub_pred, sub_label = [], []
                for _, batch_x, batch_y in self.dataset[subset].get_generator(self.args.batch_size, shuffle=False):
                    pred = self._forward(batch_x)
                    sub_pred.extend(list(pred.detach().cpu().numpy()))
                    sub_label.extend(batch_y)
                results[subset] = binary_eval(y_true=sub_label, y_pred=sub_pred)
        save_flag = self.recorders.insert(results)
        return save_flag, self.recorders.get_epoch_result()

    def set_dataset(self, name, i_fold):
        """Return the loader for split ``name`` of fold ``i_fold`` (only MIMIC3 is supported)."""
        assert name in ['train', 'valid', 'test'], "[x] No such dataset mode: %s" % name
        from dataset import DatasetConfig
        config = DatasetConfig(set_name=self.args.dataset)
        if self.args.dataset == "MIMIC3":
            logging.info("[*]loading the MIMIC3 dataset: %s."%name)
            from dataset import MIMIC3SetLoader
            return MIMIC3SetLoader(self.args, name, config, i_fold)
        else:
            raise NotImplementedError("[x] cannot support the dataset: %s"%self.args.dataset)

    def _num_requested_gpus(self):
        """Count the comma-separated device ids in ``args.gpu`` (e.g. "3,4,5" -> 3)."""
        return len([gpu_id for gpu_id in str(self.args.gpu).split(',') if gpu_id.strip()])

    def set_criterion(self, type="bce"):
        """Build the loss function; only binary cross-entropy ("bce") is supported."""
        assert type in ['bce'], "[x] No such criterion: %s" % type
        criterion = torch.nn.BCELoss()

        if torch.cuda.is_available():
            if torch.cuda.device_count() > 1 and self._num_requested_gpus() > 1:
                criterion = criterion.cuda()
        return criterion

    def set_model(self, args):
        """Build the model and optimizer, and optionally load pretrained weights.

        Wraps the model in DataParallel when several GPUs are requested and visible.

        Raises:
            NotImplementedError: for an unsupported dataset mode or model name.
        """
        if args.dataset_mode != "regular":
            raise NotImplementedError("No such mode: %s" % (args.dataset_mode))
        if args.model == "CohortNet":
            # Local import so an autoreloaded `model` module is picked up.
            from model import CohortNet
            model = CohortNet(o_dim=self.output_dim, f_num=self.input_dim, e_dim=args.embed_dim, c_dim=args.compress_dim,
                              h_dim=args.hidden_dim, fusion_dim=args.fusion_dim, cluster_num=args.k,
                              clip_min=args.clip_min, clip_max=args.clip_max, active=args.active)
        else:
            raise NotImplementedError("[x] No such model: %s" % args.model)

        if torch.cuda.is_available():
            if torch.cuda.device_count() > 1 and self._num_requested_gpus() > 1:
                model = torch.nn.DataParallel(model)

            model = model.cuda()
            # NOTE(review): benchmark mode can undo the determinism requested by
            # set_random_seed (cudnn.deterministic); consider disabling when args.random is False.
            cudnn.benchmark = True

        assert args.opt in ['adam', 'sgd'], "[x] No such optimizer: %s" % args.opt
        if args.opt == 'sgd':
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
            logging.info("[*] optimizer: SGD, lr: %f, momentum: %f, weight decay: %f" % (args.lr,
                                                                                         args.momentum,
                                                                                         args.weight_decay))
        elif args.opt == 'adam':
            optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
            logging.info("[*] optimizer: Adam, lr: %f, weight decay: %f" % (args.lr, args.weight_decay))
        else:
            raise NotImplementedError()

        logging.info(model)
        total_num = sum(p.numel() for p in model.parameters())
        trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
        logging.info("Model param num: %d  trainable: %d" % (total_num, trainable_num))
        self.model = model
        self.optimizer = optimizer

        # load pretrained parameters ("#" means none)
        if args.model_path != "#":
            assert os.path.exists(args.model_path), "[x] Model path is invalid: %s" % args.model_path
            self.load_model(args.model_path)

    def set_log_path(self, timestamp):
        """Create (if needed) and return this run's log directory, with a trailing separator.

        Layout:
        - debug:  ``<log_root>/<model>_<timestamp>_debug/``
        - else:   ``<log_root>/<dataset>/<application>/<model>_<train_mode>/<model>_<timestamp>[_eval][_random]/``
        """
        args = self.args
        if args.debug:
            run_dir = os.path.join(self.log_root, "%s_%s_debug" % (args.model, timestamp))
        else:
            # Build the suffixes explicitly. The previous `log_path[:-1]` trimming
            # chopped the final "l" of "_eval" when both eval mode and random were set.
            run_name = "%s_%s" % (args.model, timestamp)
            if args.mode == "eval":
                run_name += "_eval"
            if args.random:
                run_name += "_random"
            run_dir = os.path.join(self.log_root, str(args.dataset), str(args.application),
                                   "%s_%s" % (args.model, self.train_mode), run_name)
        os.makedirs(run_dir, exist_ok=True)
        # Callers concatenate file names directly onto this path, so keep the trailing separator.
        return os.path.join(run_dir, "")

    def save_results(self, epoch, i_fold, results):
        """Save the checkpoint, best results and recorder state for fold ``i_fold``."""
        logging.info("[*] Saving files...")
        state = {
            'args': self.args,
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'epoch': epoch,
        }
        save_file = os.path.join(self.log_path, 'ckpt_{i}.pth'.format(i=i_fold))
        torch.save(state, save_file)
        np.save(os.path.join(self.log_path, "best_valid_{i}".format(i=i_fold)), results)
        self.recorders.save(os.path.join(self.log_path, "recorders_{i}.npz".format(i=i_fold)))
        self.recorders.record_to_csv(self.log_path)

    def load_model(self, model_path):
        """Load matching weights from the checkpoint at ``model_path`` into ``self.model``.

        Keys are matched as-is, with a leading ``module.`` (DataParallel) prefix removed,
        or with that prefix added. Therefore single-GPU and DataParallel checkpoints load
        into either kind of model. Keys that match nothing are logged and skipped.
        """
        # weights_only=False: checkpoints written by save_results also pickle `args`
        # (an argparse Namespace), which torch's weights_only loader rejects.
        # This unpickles arbitrary objects, so only load trusted checkpoints.
        try:
            ckpt = torch.load(model_path, map_location='cpu', weights_only=False)
        except TypeError:  # older torch without the weights_only argument
            ckpt = torch.load(model_path, map_location='cpu')
        state_dict = ckpt['model']
        model_dict = self.model.state_dict()
        prefix = "module."
        new_state_dict = {}
        skipcount, loadedcount = 0, 0
        for k, v in state_dict.items():
            # Strip only a leading prefix; str.replace would also mangle names such as "submodule.".
            stripped = k[len(prefix):] if k.startswith(prefix) else k
            for candidate in (k, stripped, prefix + stripped):
                if candidate in model_dict:
                    new_state_dict[candidate] = v
                    loadedcount += 1
                    break
            else:
                logging.info("skiped: %s"%k)
                skipcount += 1
        model_dict.update(new_state_dict)
        self.model.load_state_dict(model_dict)
        logging.info("[*] Model loaded!")
        logging.info("[*] skipped: %d  loaded: %d" % (skipcount, loadedcount))

    def eval_model(self, model_path):
        """Load ``model_path``, score every split and log the recorder table."""
        logging.info("=============== eval ===============")
        self.load_model(model_path)
        self.validate()
        logging.info(self.recorders.to_string())

In [3]:
import argparse
from argument import base_train_args, regular_dataset, CohortNet_parse_args

# Merge the option groups returned by the `argument` module's builders into one parser.
# add_help=False means no -h/--help option is added to this parser.
base_parser = argparse.ArgumentParser(
    description="ECHO Framework.",
    add_help=False,
    parents=[base_train_args(), regular_dataset(), CohortNet_parse_args()],
)
# Parse an explicit empty list so only the defaults are used; inside Jupyter, sys.argv
# holds the kernel's own command-line flags.
args = base_parser.parse_args([])
args

Namespace(active='relu', application='inhos_mortality', batch_size=128, clip_max=3.0, clip_min=-3.0, compress_dim=24, cosine=True, criterion='bce', data_clip=False, data_clip_max=inf, data_clip_min=-inf, dataset='MIMIC3', dataset_mode='regular', debug=False, embed_dim=24, epochs=200, ffill=True, ffill_steps=48, fix=False, folds=[3], fusion_dim=32, gpu='0,1,2', hidden_dim=32, inter_type='mul', k=7, lr=0.001, lr_decay_epochs='150,350,500', lr_decay_rate=0.1, max_cohort_size=8000, max_timesteps=48, min_freq=10, min_sample_freq=5, mode='train', model='CohortNet', model_path='#', momentum=0.9, opt='adam', patience=10, print_freq=20, random=True, standardization=True, topn=2, warm=True, weight_decay=0)

In [4]:
import importlib
%reload_ext autoreload
%autoreload 1
%aimport model

args.debug = False
args.gpu = "3,4,5"

ModelTrainer(args)

07-18-23 07:10:[*]random seed: 9093840
07-18-23 07:10:Timestamp: 20230718071054
07-18-23 07:10:Arguments
07-18-23 07:10:active = relu
07-18-23 07:10:application = inhos_mortality
07-18-23 07:10:batch_size = 128
07-18-23 07:10:clip_max = 3.0
07-18-23 07:10:clip_min = -3.0
07-18-23 07:10:compress_dim = 24
07-18-23 07:10:cosine = True
07-18-23 07:10:criterion = bce
07-18-23 07:10:data_clip = False
07-18-23 07:10:data_clip_max = inf
07-18-23 07:10:data_clip_min = -inf
07-18-23 07:10:dataset = MIMIC3
07-18-23 07:10:dataset_mode = regular
07-18-23 07:10:debug = False
07-18-23 07:10:embed_dim = 24
07-18-23 07:10:epochs = 200
07-18-23 07:10:ffill = True
07-18-23 07:10:ffill_steps = 48
07-18-23 07:10:fix = False
07-18-23 07:10:folds = [3]
07-18-23 07:10:fusion_dim = 32
07-18-23 07:10:gpu = 3,4,5
07-18-23 07:10:hidden_dim = 32
07-18-23 07:10:inter_type = mul
07-18-23 07:10:k = 7
07-18-23 07:10:lr = 0.001
07-18-23 07:10:lr_decay_epochs = 150,350,500
07-18-23 07:10:lr_decay_rate = 0.1
07-18-23 07:

07-18-23 07:12:Model param num: 995942  trainable: 995942
07-18-23 07:12:Train: [1][0/133]	BT avg 8.678	loss avg 0.7005
07-18-23 07:12:Train: [1][20/133]	BT avg 0.984	loss avg 0.6401
07-18-23 07:12:Train: [1][40/133]	BT avg 0.796	loss avg 0.5842
07-18-23 07:12:Train: [1][60/133]	BT avg 0.730	loss avg 0.5397
07-18-23 07:13:Train: [1][80/133]	BT avg 0.697	loss avg 0.5081
07-18-23 07:13:Train: [1][100/133]	BT avg 0.678	loss avg 0.4853
07-18-23 07:13:Train: [1][120/133]	BT avg 0.665	loss avg 0.4705
07-18-23 07:14:Epoch 1, lr 0.000144, train loss 0.4641, train time 87.90s, valid time 43.67s, total time 131.57s.
07-18-23 07:14:
      bceloss  auroc    auprc    accu     f1       minpse   
train 0.3890   0.6539   0.2602   0.8677   0.0000   0.2987   
valid 0.3896   0.6167   0.2451   0.8675   0.0000   0.3036   
test  0.3884   0.6500   0.2578   0.8680   0.0000   0.3176   

07-18-23 07:14:[*] Saving files...
07-18-23 07:14:Train: [2][0/133]	BT avg 0.578	loss avg 0.3010
07-18-23 07:14:Train: [2][20

07-18-23 07:33:
      bceloss  auroc    auprc    accu     f1       minpse   
train 0.2816   0.8570   0.5404   0.8879   0.4428   0.5094   
valid 0.2988   0.8378   0.4809   0.8732   0.3558   0.4737   
test  0.2850   0.8521   0.5344   0.8855   0.4292   0.5071   

07-18-23 07:33:[*] Saving files...
07-18-23 07:33:Train: [11][0/133]	BT avg 0.581	loss avg 0.2481
07-18-23 07:33:Train: [11][20/133]	BT avg 0.603	loss avg 0.2731
07-18-23 07:33:Train: [11][40/133]	BT avg 0.604	loss avg 0.2787
07-18-23 07:33:Train: [11][60/133]	BT avg 0.603	loss avg 0.2843
07-18-23 07:33:Train: [11][80/133]	BT avg 0.604	loss avg 0.2827
07-18-23 07:34:Train: [11][100/133]	BT avg 0.603	loss avg 0.2830
07-18-23 07:34:Train: [11][120/133]	BT avg 0.602	loss avg 0.2843
07-18-23 07:35:Epoch 11, lr 0.001000, train loss 0.2846, train time 80.30s, valid time 43.55s, total time 123.86s.
07-18-23 07:35:
      bceloss  auroc    auprc    accu     f1       minpse   
train 0.2891   0.8601   0.5490   0.8856   0.3027   0.5226   
va

07-18-23 07:52:Train: [20][120/133]	BT avg 0.604	loss avg 0.2739
07-18-23 07:53:Epoch 20, lr 0.000993, train loss 0.2731, train time 80.52s, valid time 43.31s, total time 123.83s.
07-18-23 07:53:
      bceloss  auroc    auprc    accu     f1       minpse   
train 0.2650   0.8755   0.5873   0.8931   0.4293   0.5460   
valid 0.2828   0.8624   0.5145   0.8798   0.3420   0.4982   
test  0.2751   0.8616   0.5572   0.8898   0.4190   0.4947   

07-18-23 07:53:[*] Saving files...
07-18-23 07:53:Train: [21][0/133]	BT avg 0.577	loss avg 0.3107
07-18-23 07:53:Train: [21][20/133]	BT avg 0.599	loss avg 0.2653
07-18-23 07:54:Train: [21][40/133]	BT avg 0.601	loss avg 0.2699
07-18-23 07:54:Train: [21][60/133]	BT avg 0.602	loss avg 0.2708
07-18-23 07:54:Train: [21][80/133]	BT avg 0.601	loss avg 0.2717
07-18-23 07:54:Train: [21][100/133]	BT avg 0.600	loss avg 0.2703
07-18-23 07:54:Train: [21][120/133]	BT avg 0.600	loss avg 0.2728
07-18-23 07:55:Epoch 21, lr 0.000992, train loss 0.2700, train time 79.92s,

07-18-23 08:12:Train: [30][60/133]	BT avg 0.602	loss avg 0.2637
07-18-23 08:13:Train: [30][80/133]	BT avg 0.602	loss avg 0.2626
07-18-23 08:13:Train: [30][100/133]	BT avg 0.602	loss avg 0.2623
07-18-23 08:13:Train: [30][120/133]	BT avg 0.602	loss avg 0.2582
07-18-23 08:14:Epoch 30, lr 0.000973, train loss 0.2585, train time 80.26s, valid time 43.74s, total time 124.00s.
07-18-23 08:14:
      bceloss  auroc    auprc    accu     f1       minpse   
train 0.2524   0.8883   0.6271   0.9003   0.4767   0.5650   
valid 0.2842   0.8637   0.5138   0.8794   0.3411   0.4912   
test  0.2768   0.8597   0.5577   0.8865   0.4203   0.5155   

07-18-23 08:14:Train: [31][0/133]	BT avg 0.576	loss avg 0.2695
07-18-23 08:14:Train: [31][20/133]	BT avg 0.600	loss avg 0.2477
07-18-23 08:14:Train: [31][40/133]	BT avg 0.601	loss avg 0.2650
07-18-23 08:14:Train: [31][60/133]	BT avg 0.601	loss avg 0.2565
07-18-23 08:15:Train: [31][80/133]	BT avg 0.601	loss avg 0.2592
07-18-23 08:15:Train: [31][100/133]	BT avg 0.60

07-18-23 08:34:
      bceloss  auroc    auprc    accu     f1       minpse   
train 0.2541   0.8865   0.6230   0.9003   0.5076   0.5693   
valid 0.2779   0.8663   0.5250   0.8855   0.4155   0.5194   
test  0.2749   0.8624   0.5617   0.8879   0.4527   0.5122   



<__main__.ModelTrainer at 0x7f44383a2e48>