In [None]:
# Parameters
# Run-level switches and loss weights. Later cells read these as plain globals.
demo: bool = True       # when True, the train, valid and epoch loops each break after one pass
note: bool = True       # copied onto args.note in the configuration cell
ddi_alpha: float = 2.5  # copied onto args.alpha
ddi_beta: float = 1     # copied onto args.beta

In [None]:
import sys
import os
# Put the parent of the current working directory on the import path
# (presumably where the project's `drs` package lives -- confirm against the repo layout).
parent_dir = os.path.abspath(os.path.join(os.getcwd(), '..'))
sys.path.append(parent_dir)

In [None]:
from drs import util, dataset, models

import logging
import argparse
import numpy as np
from datetime import datetime
from tqdm.notebook import tqdm
from transformers import get_linear_schedule_with_warmup

import torch 
from torch.cuda.amp import GradScaler, autocast
from torcheval.metrics import MultilabelAUPRC
from torch.utils.tensorboard import SummaryWriter
# TensorBoard event writer, created with default arguments; the training loop
# logs its per-step train/valid loss scalars through it and flushes it at the end.
writer = SummaryWriter()

import warnings 
# Blanket filter: suppresses every warning category for the rest of the session.
# NOTE(review): this also hides deprecation and numerical warnings from the libraries used below.
warnings.filterwarnings('ignore')

In [None]:
# Run bookkeeping: each run gets a timestamped folder under the model's
# checkpoint directory; the same folder receives the run's log file.
model = 'Transformer'
path_task = f"../save_model/{model}"
# Minute-resolution timestamp; runs started within the same minute share a folder
# (exist_ok=True below makes that a no-op rather than an error).
id_task = datetime.now().strftime('%Y%m%d_%H%M')

os.makedirs(f'{path_task}/{id_task}', exist_ok=True)
logging.basicConfig(
    filename=f'{path_task}/{id_task}/{id_task}.txt',
    level=logging.INFO,
)
print(id_task)

In [None]:
# Collect every run setting on one namespace object that is handed to the
# project helpers (util / dataset / models) in the cells below.
parser = argparse.ArgumentParser()
args = parser.parse_args(args=[])  # empty argv: the namespace is filled in by hand

args.datetime_id = id_task
args.note = note

# Resolve the compute device FIRST so the tensors created below follow it.
# (Previously the DDI matrices were moved with a hard-coded .cuda() before the
# device was chosen, which crashed on CPU-only machines despite the fallback.)
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# embedding path
# NOTE(review): the attribute is spelled `embeddimg_path`; kept unchanged because
# helpers that receive `args` may read it under this exact name -- confirm before renaming.
args.embeddimg_path = '../data/Fasttext_Multi_300.mdl'
# data path
args.train_path = '../data/Data_train.pkl'
args.valid_path = '../data/Data_valid.pkl'

# medication path
args.path_atc = "../data/list_med.pkl"
# adj DDI path
args.major_ddi_path = "../data/adj_ddi_major.npy"
args.moder_ddi_path = "../data/adj_ddi_moder.npy"
# DDI adjacency matrices as float32 tensors on the selected device
args.adj_ddi_major = torch.tensor(np.load(args.major_ddi_path), dtype=torch.float32).to(args.device)
args.adj_ddi_moder = torch.tensor(np.load(args.moder_ddi_path), dtype=torch.float32).to(args.device)

# training parameters
args.seed = 42
args.seq_length = 1024
args.alpha = ddi_alpha          # from the parameters cell
args.beta = ddi_beta            # from the parameters cell
args.epochs = 50
args.early_stop = 10            # epochs without PR-AUC improvement before the loop stops
args.batch_size = 128
args.max_grad_norm = 0.1        # threshold passed to clip_grad_norm_ in the training loop
args.l2_reg_lambda = 0.001      # used as AdamW weight_decay
args.learning_rate = 5e-5

In [None]:
# --- Reproducibility -------------------------------------------------------
# NOTE(review): called without arguments, so args.seed is not passed here;
# check util.set_seed's default if a specific seed is expected.
util.set_seed()

# --- Vocabulary / embeddings and label space ---------------------------------
args.embedding_matrix, args.word2idx = util.word2matrix(args)
(args.mapping,
 args.num_classes_layer,
 args.total_classes) = util.class_label(args.path_atc)

# --- Datasets ----------------------------------------------------------------
train_dataset = dataset.define_ds(args, data_path=args.train_path)
valid_dataset = dataset.define_ds(args, data_path=args.valid_path)

# --- Dataloaders, addressable by split name ------------------------------------
train_loader = dataset.define_dl(train_dataset, batch_size=args.batch_size)
valid_loader = dataset.define_dl(valid_dataset, batch_size=args.batch_size)
dataloader = dict(train=train_loader, valid=valid_loader)

# --- Sanity check: inspect one training sample ---------------------------------
util.debug_result(args, train_dataset[1])

In [None]:
# Width of the output layer and of the PR-AUC metric.
# NOTE(review): entry 3 of args.num_classes_layer is used for both; presumably
# the label level being predicted -- confirm against util.class_label.
n_labels = args.num_classes_layer[3]

# Network, moved to the configured device
net = models.Transformer(
    embedding_matrix=args.embedding_matrix,
    num_classes=n_labels,
).to(args.device)

# Loss: called in the training loop as criterion(logit, target) and returns
# (total, bce, major-DDI, moderate-DDI) terms
criterion = models.Loss(args)

# Optimizer
optimizer = torch.optim.AdamW(
    net.parameters(),
    lr=args.learning_rate,
    weight_decay=args.l2_reg_lambda,
    eps=1e-8,
)
logging.info(f'Optimizer: {optimizer}')

# Linear learning-rate schedule over every optimizer step of the run, no warm-up
total_steps = len(train_loader) * args.epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps,
)
# Loss scaler for mixed-precision training
scaler = GradScaler()

# Validation metric
auprc_metric = MultilabelAUPRC(num_labels=n_labels)

In [None]:
# Train / validate loop: mixed-precision training, per-step TensorBoard logging,
# best-checkpoint saving and early stopping, both driven by validation PR-AUC.
args.best_auprc = 0
# Epochs since the last PR-AUC improvement. It has to exist before the loop:
# otherwise a first epoch that does not improve on best_auprc would execute
# `args.early_stop_c += 1` on a missing attribute.
args.early_stop_c = 0
# Global step counters, used as the x-axis of the TensorBoard scalars.
train_iter, valid_iter = 0, 0

for epoch in range(args.epochs):
    # reset evaluation metric
    auprc_metric.reset()
    # NOTE(review): the train-side DDI counters are reset every epoch but never
    # accumulated in this loop; kept in case other cells read them.
    sum_train_major_ddi = 0
    sum_n_train_major_ddi_pair = 0
    sum_train_moder_ddi = 0
    sum_n_train_moder_ddi_pair = 0

    # ---------------------------- training ----------------------------
    net.train()
    with tqdm(dataloader['train'], total=len(dataloader['train'])) as train_process:
        for x_train, y_train in train_process:
            # load input
            x_train, y_train = x_train.to(args.device), y_train.to(args.device)
            optimizer.zero_grad()
            # forward pass and loss under autocast
            with autocast():
                logit, scores = net(x_train)
                l_total, l_bce, l_major_ddi, l_moder_ddi = criterion(logit, y_train)
            # tensorboard visualization
            writer.add_scalar('Train/BCEwithlogits_loss', l_bce, train_iter)
            writer.add_scalar('Train/Major DDI loss', l_major_ddi, train_iter)
            writer.add_scalar('Train/Moder DDI loss', l_moder_ddi, train_iter)
            writer.add_scalar('Train/Total_loss', l_total, train_iter)
            train_iter += 1

            # backpropagation on the scaled loss
            scaler.scale(l_total).backward()
            # Unscale before clipping so max_grad_norm is compared against the real
            # gradient magnitudes, and clip BEFORE the optimizer step so the clipped
            # gradients are the ones applied. (Clipping after scaler.step() had no
            # effect on the update.)
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(net.parameters(), args.max_grad_norm)
            # step optimizer, scaler and scheduler
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            train_process.set_postfix(loss=l_total.item())
            if demo:
                break

    # --------------------------- validation ---------------------------
    net.eval()
    sum_valid_major_ddi = 0
    sum_n_valid_major_ddi_pair = 0
    sum_valid_moder_ddi = 0
    sum_n_valid_moder_ddi_pair = 0
    with torch.no_grad():
        for x_val, y_val in dataloader['valid']:
            x_val, y_val = x_val.to(args.device), y_val.to(args.device)

            logit, scores = net(x_val)
            l_total, l_bce, l_major_ddi, l_moder_ddi = criterion(logit, y_val)
            # accumulate DDI numerators / pair counts over the validation batches
            valid_major_ddi, n_valid_major_ddi_pair = util.cal_ddi_rate_train(scores, args.adj_ddi_major)
            valid_moder_ddi, n_valid_moder_ddi_pair = util.cal_ddi_rate_train(scores, args.adj_ddi_moder)
            sum_valid_major_ddi += valid_major_ddi
            sum_n_valid_major_ddi_pair += n_valid_major_ddi_pair
            sum_valid_moder_ddi += valid_moder_ddi
            sum_n_valid_moder_ddi_pair += n_valid_moder_ddi_pair

            auprc_metric.update(scores, y_val)

            writer.add_scalar('Valid/BCEwithlogits_loss', l_bce, valid_iter)
            writer.add_scalar('Valid/Major DDI loss', l_major_ddi, valid_iter)
            writer.add_scalar('Valid/Moder DDI loss', l_moder_ddi, valid_iter)
            writer.add_scalar('Valid/Total_loss', l_total, valid_iter)

            valid_iter += 1
            if demo:
                break

    pr_auc = auprc_metric.compute().item()

    # print log
    util.printlog(f'epoch {epoch}: PRAUC: {pr_auc:.4f}, Major DDI rate {sum_valid_major_ddi/sum_n_valid_major_ddi_pair:.4f}, Moder DDI rate {sum_valid_moder_ddi/sum_n_valid_moder_ddi_pair:.4f}')

    # save best model / early stopping
    is_best = pr_auc > args.best_auprc
    if is_best:
        # Only a strictly better score replaces the best one, so a NaN PR-AUC
        # can never overwrite it (max(nan, best) would return nan).
        args.best_auprc = pr_auc
        args.early_stop_c = 0
        # Saves the whole module object, not just its state_dict.
        torch.save(net, f'{path_task}/{id_task}/model_best_{id_task}.pth')
    else:
        args.early_stop_c += 1
        if args.early_stop_c >= args.early_stop:
            break
    if demo:
        break
writer.flush()