In [None]:
import os
import sys

# Make the project root (the parent of the notebook's working directory)
# importable so that `from drs import gamenet` in the next cell resolves.
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), os.pardir)))

In [None]:
import os
from functools import partial

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader

from drs import gamenet

In [None]:
# Pre-split visit tables: train for fitting, valid for threshold tuning, test held out.
df_train = gamenet.read_df('../data/data_train_master.pkl')
df_valid = gamenet.read_df('../data/data_valid_master.pkl')
# Test split, consumed by the evaluation cell (`get_probs_and_targets(..., df_test, ...)`).
# NOTE(review): filename inferred from the train/valid naming convention — confirm it exists.
df_test = gamenet.read_df('../data/data_test_master.pkl')

# Raw DDI adjacency and the ATC map; both are passed to
# `gamenet.ddi_matrix_atcmap_to_model_sparse` in the graph-construction cell.
A_ddi  = np.load('../data/adj_ddi_major.npy')
# pd.read_pickle can execute arbitrary code — only load trusted, locally produced files.
atcmap = pd.read_pickle('../data/l4_map_reverse.pkl')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Output directory for predictions / targets / threshold saved at the end.
os.makedirs('../save_model/gamenet/', exist_ok=True)

In [None]:
# Vocabularies (providing `med2id`, `n_diag`, `n_med`) are built from the training split only.
vocabs = gamenet.build_vocabs_from_train_df(df_train)

# EHR co-prescription adjacency over the medication vocabulary, built from training
# visits only, then normalised for the graph model.
A_ehr = gamenet.build_ehr_coprescription_adj_sparse(df_train, vocabs.med2id)
A_ehr_norm = gamenet.normalize_sparse_adj(A_ehr)

# Convert the raw DDI matrix into the model's medication index space (raw + normalised).
# NOTE(review): this rebinds `A_ddi` (loaded with np.load in the data cell) to the
# converted matrix, so re-running this cell on its own feeds the converted matrix back
# into the converter. Re-run the data-loading cell first, or keep the raw matrix under
# a separate name.
A_ddi, A_ddi_norm = gamenet.ddi_matrix_atcmap_to_model_sparse(A_ddi, atcmap, vocabs.med2id, device)

In [None]:
# Seed the RNGs that drive weight initialisation and DataLoader shuffling so that
# Restart & Run All reproduces the same model.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA devices

train_ds = gamenet.FlatVisitDataset(df_train, vocabs)
# With num_workers > 0, DataLoader pickles `collate_fn` to send it to worker processes
# when they are started with "spawn" (the default on Windows and macOS). A lambda cannot
# be pickled; a functools.partial of an importable function can.
train_loader = DataLoader(
    train_ds,
    batch_size=128,
    shuffle=True,
    num_workers=4,
    collate_fn=partial(gamenet.collate_flat_visits, n_med=vocabs.n_med),
)

model = gamenet.GAMENet(
    n_diag=vocabs.n_diag,
    n_med=vocabs.n_med,
    emb_dim=64,
    rnn_hidden=128,
    gcn_hidden=64,
    ddi_lambda=0.1,
)
gamenet.train_gamenet(model, train_loader, A_ehr_norm, A_ddi_norm, A_ddi, device, lr=1e-3, epochs=20)

In [None]:
# Make sure training-only behaviour (e.g. dropout) is off before scoring.
# NOTE(review): harmless if `get_probs_and_targets` already calls `model.eval()` — confirm.
model.eval()

# Dense float32 NumPy copy of the model-space DDI matrix (A_ddi is a torch sparse
# tensor at this point, hence `.coalesce().to_dense()`).
# NOTE(review): not used by any visible cell — drop it if no later analysis needs it.
A_ddi_dense = A_ddi.coalesce().to_dense().cpu().numpy().astype(np.float32)

# Tune the decision threshold on the validation split only...
scores_valid, y_valid = gamenet.get_probs_and_targets(model, df_valid, vocabs, A_ehr_norm, A_ddi_norm, device, batch_size=256)
global_threshold = gamenet.tune_threshold_f1(scores_valid, y_valid)

# ...then score the held-out test split with the same model, so the threshold is
# never fitted on test data.
scores_test, y_test = gamenet.get_probs_and_targets(model, df_test, vocabs, A_ehr_norm, A_ddi_norm, device, batch_size=256)

In [None]:
# Persist validation / test scores and targets plus the tuned threshold, so metrics
# can be recomputed later without re-running inference. Order: valid, test, threshold.
for out_path, out_arr in (
    ('../save_model/gamenet/val_pred.npy', scores_valid),
    ('../save_model/gamenet/val_true.npy', y_valid),
    ('../save_model/gamenet/test_pred.npy', scores_test),
    ('../save_model/gamenet/test_true.npy', y_test),
    ('../save_model/gamenet/global_threshold.npy', global_threshold),
):
    np.save(out_path, out_arr)