# Import

In [None]:
from func.transformer import TransformerDataset, TransformerModel, MyLoss
from func.execution import eval_epoch, fit
import numpy as np

import torch
from torch.utils.data import DataLoader

import pickle

# Seed PyTorch's default random number generator so repeated runs of this
# notebook draw the same random numbers.
torch.manual_seed(0)

# Set up device

In [None]:
# Select the compute device: use Apple's Metal backend (MPS) when it is
# actually usable, otherwise fall back to the CPU.
# `torch.has_mps` is deprecated and only tells whether this PyTorch build was
# compiled with MPS support; `torch.backends.mps.is_available()` also checks
# that the current machine can really use it.
if torch.backends.mps.is_available():
    device = torch.device('mps')
    print('Training on Mac M1! Device was set as "mps"')
else:
    device = torch.device('cpu')
    print('Training on CPU! Device was set as "cpu"')

# Parameter dictionary

In [None]:
# Central configuration for this run. Individual entries are read by the cells
# below, and the whole dictionary is also handed to fit() and eval_epoch().
params = dict(
    # batch size used for all three DataLoaders
    batch_size=64,
    # learning rate given to the Adam optimizer
    lr=0.0001,
    # 'func' and 'version' are used in the file name of the saved model
    func='log10',
    # path prefix under which the best model is saved
    stat_path='stat_test/',
    version='draft',
    # fractions of the dataset used for training / validation;
    # whatever is left over becomes the test set
    train_percent=0.7,
    val_percent=0.15,
    # number of epochs passed to fit()
    epoch=50,
    max_len_i=187,
    # hyper-parameters forwarded to TransformerModel
    max_pos=187,
    emb_size=128,
    num_heads=8,
    num_encoder_layers=2,
    num_decoder_layers=2,
    dropout_p=0.1,
)

# Data

In [None]:
# Three independent, empty tensors acting as placeholders for the data.
# NOTE(review): presumably the real tensors were loaded here originally
# (pickle is imported but never used in the visible cells) - confirm.
combine, id2cost, cost_tensor = (torch.Tensor([]) for _ in range(3))

# Loading data

In [None]:
# Wrap the `combine` tensor in the project's TransformerDataset.
# max_visit is not passed explicitly; per the original note its default is 187.
data = TransformerDataset(combine)

+ Splitting data

In [None]:
# Work out how many samples go into each of the three subsets.
num_patients = len(data)

# Fractions for training and validation come from the configuration;
# the test set receives whatever remains after rounding the other two.
train_percent, val_percent = params['train_percent'], params['val_percent']

num_train, num_val = (
    int(np.around(fraction * num_patients))
    for fraction in (train_percent, val_percent)
)
num_test = num_patients - (num_train + num_val)

for _purpose, _count in (
    ("training", num_train),
    ("validation", num_val),
    ("testing", num_test),
):
    print(f"Number of patients for {_purpose} is: {_count}")

In [None]:
# Randomly partition the dataset into three non-overlapping subsets whose
# sizes were computed in the previous cell.
split_lengths = [num_train, num_val, num_test]
train, val, test = torch.utils.data.random_split(data, split_lengths)

for _label, _subset in (
    ("training", train),
    ("validation", val),
    ("testing", test),
):
    print(f"Length for {_label} dataset is: {len(_subset)}")

+ Batching the splits with DataLoader

In [None]:
BATCH_SIZE = params['batch_size']

# Training batches are reshuffled on every pass over the data.
train_DataLoader = DataLoader(dataset=train, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

# Evaluation loaders are NOT shuffled. Combined with drop_last=True, shuffling
# would discard a different set of samples on every pass, so validation and
# test metrics would be computed on a changing subset of the data.
# NOTE(review): drop_last=True is kept because fit()/eval_epoch() may rely on
# a fixed batch size - confirm. It means the final partial batch of the
# validation/test sets is never evaluated.
val_DataLoader = DataLoader(dataset=val, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)
test_DataLoader = DataLoader(dataset=test, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)

# Set up model

In [None]:
# Collect the values that the TransformerModel constructor needs.
# NOTE(review): `id2cost_type` and `dict_vocab_size` are not defined in any
# visible cell - presumably they are created by a data-loading step that is
# not shown here; confirm before running.
max_len_i = params['max_len_i']

# Vocabulary sizes. Expected values according to the original notes:
# cost 53, age 93, gender 2, diff 5714, department 15, specialist 34,
# visit_type 8.
cost_vocab_size = len(id2cost_type)
(
    age_vocab_size,
    gender_vocab_size,
    diff_vocab_size,
    department_vocab_size,
    specialist_vocab_size,
    visit_type_vocab_size,
) = (
    dict_vocab_size[_field]
    for _field in ('age', 'gender', 'diff', 'department', 'specialist', 'visit_type')
)

# Architecture settings taken from the parameter dictionary.
(
    max_pos,
    emb_size,
    num_heads,
    num_encoder_layers,
    num_decoder_layers,
    dropout_p,
) = (
    params[_key]
    for _key in (
        'max_pos',
        'emb_size',
        'num_heads',
        'num_encoder_layers',
        'num_decoder_layers',
        'dropout_p',
    )
)

In [None]:
# Build the transformer from the settings gathered in the previous cell and
# place it on the selected device.
_model_kwargs = {
    # vocabulary sizes
    'cost_vocab_size': cost_vocab_size,
    'age_vocab_size': age_vocab_size,
    'gender_vocab_size': gender_vocab_size,
    'diff_vocab_size': diff_vocab_size,
    'department_vocab_size': department_vocab_size,
    'specialist_vocab_size': specialist_vocab_size,
    'visit_type_vocab_size': visit_type_vocab_size,
    # architecture settings
    'max_pos': max_pos,
    'emb_size': emb_size,
    'num_heads': num_heads,
    'num_encoder_layers': num_encoder_layers,
    'num_decoder_layers': num_decoder_layers,
    'dropout_p': dropout_p,
}
model = TransformerModel(**_model_kwargs)
model = model.to(device)

# Training and validating

In [None]:
# Training configuration: learning rate and epoch count come from `params`.
lr, epochs = params['lr'], params['epoch']

# Project-defined loss and an Adam optimizer over all model parameters.
loss_function = MyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

In [None]:
# Run training and validation with the project's fit() helper
# (func.execution). It returns three values, unpacked here as the training
# summary, the validation summary and the best model.
train_summary, val_summary, best_model = fit(
    train_DataLoader,
    val_DataLoader,
    model,
    optimizer,
    loss_function,
    id2cost,
    cost_tensor,
    params,
    device,
    epochs,
)

In [None]:
# Store the best model
import os

# File name is built from the configuration: <stat_path>model_<func>_<version>
PATH_model = params['stat_path'] + 'model_' + params['func'] + '_' + params['version']

# torch.save() does not create missing directories, so make sure the target
# directory exists before writing (nothing to create for a bare file name).
_model_dir = os.path.dirname(PATH_model)
if _model_dir:
    os.makedirs(_model_dir, exist_ok=True)

# Only the weights (state_dict) are saved, not the whole model object.
torch.save(best_model.state_dict(), PATH_model)

# Test

In [None]:
# Load the best model's state_dict
# A fresh model with the same configuration as the trained one is created,
# then filled with the weights saved in PATH_model.
loaded_model = TransformerModel(cost_vocab_size=cost_vocab_size,
                                 age_vocab_size=age_vocab_size,
                                 gender_vocab_size=gender_vocab_size,
                                 diff_vocab_size=diff_vocab_size,
                                 department_vocab_size=department_vocab_size,
                                 specialist_vocab_size=specialist_vocab_size,
                                 visit_type_vocab_size=visit_type_vocab_size,
                                 max_pos=max_pos,
                                 emb_size=emb_size,
                                 num_heads=num_heads,
                                 num_encoder_layers=num_encoder_layers,
                                 num_decoder_layers=num_decoder_layers,
                                 dropout_p=dropout_p,
                                ).to(device)
# A newly constructed module starts in training mode; switch to evaluation
# mode so layers such as dropout are disabled when this model is tested.
loaded_model.eval()
# map_location makes the checkpoint load onto the current device even if it
# was written from a different one.
loaded_model.load_state_dict(torch.load(PATH_model, map_location=device))

In [None]:
# Evaluate on the held-out test set using the reloaded best checkpoint.
# The original passed `model` (the network as left at the end of training),
# which meant the checkpoint loaded into `loaded_model` was never used.
test_results = eval_epoch(test_DataLoader, loaded_model, loss_function, id2cost, cost_tensor, params, device)

In [None]:
# print results
# eval_epoch's result is read by position: loss, top-3/5/10 accuracy,
# then MAE, MSE, RMSE and R2.
(
    epoch_loss_test,
    epoch_top3_test,
    epoch_top5_test,
    epoch_top10_test,
    epoch_mae_test,
    epoch_mse_test,
    epoch_rmse_test,
    epoch_r2_test,
) = (test_results[_i] for _i in range(8))

# The trailing spaces inside the fragments are intentional: they keep the
# printed text identical to the original backslash-continued f-string.
_report_parts = (
    "Test summary:        ",
    f"\n\tavg loss: {epoch_loss_test:.3f}        ",
    f"\n\tMAE:{epoch_mae_test:.3f}, MSE:{epoch_mse_test:.3f}, RMSE:{epoch_rmse_test:.3f}, R2: {epoch_r2_test:.3f}         ",
    f"\n\ttop3 acc: {epoch_top3_test:.2f}%, top5 acc: {epoch_top5_test:.2f}%, top10 acc: {epoch_top10_test:.2f}%",
)
print("".join(_report_parts))