In [None]:
import sys
import os
# Put the parent of the current working directory on the import path so that
# packages living one level above the notebook folder (e.g. `drs`) can be imported.
parent_dir = os.path.join(os.getcwd(), '..')
sys.path.append(os.path.abspath(parent_dir))

In [None]:
import os
from functools import partial

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader

from drs import cognet

In [None]:
# Load the three data splits through the project's reader, in train/valid/test order.
df_train, df_valid, df_test = (
    cognet.read_df(split_path)
    for split_path in (
        '../data/data_train_master.pkl',
        '../data/data_valid_master.pkl',
        '../data/data_test_master.pkl',
    )
)

# Auxiliary inputs used further down when building the vocabularies and the DDI graph.
# NOTE(review): read_pickle executes pickle payloads; only load files from a trusted source.
atcmap = pd.read_pickle('../data/l4_map_reverse.pkl')
A_ddi = np.load('../data/adj_ddi_major.npy')

# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("device:", device)

# Output directory for the arrays saved at the end of the notebook.
os.makedirs('../save_model/cognet/', exist_ok=True)

In [None]:
# Development aid: re-execute the already-imported `cognet` module so that edits
# made to its source after the first import take effect without a kernel restart.
# Objects created before the reload keep referring to the old definitions.
import importlib
importlib.reload(cognet)

In [None]:
# Build the vocabularies from the training split only; `atcmap` is forwarded and
# enabled for medications via use_atcmap_for_meds=True.
vocabs = cognet.build_vocabs(df_train, use_atcmap_for_meds=True, atcmap=atcmap)

# EHR graph: sparse adjacency derived from the training data, then normalized.
# NOTE(review): presumably a medication co-prescription graph, judging by the helper name — confirm.
A_ehr = cognet.build_ehr_coprescription_adj_sparse(df_train, vocabs.med2id, device=device)
A_ehr_norm = cognet.normalize_sparse_adj(A_ehr)

# DDI graph: the path taken depends on whether the stored matrix has exactly one
# row per medication in the model vocabulary.
if A_ddi.shape[0] != vocabs.n_med:
    # Row count differs from the vocabulary size: the helper receives atcmap and
    # med2id and returns both the raw and the normalized sparse matrices.
    # NOTE(review): presumably it re-indexes the matrix to the model vocabulary — confirm.
    A_ddi_raw, A_ddi_norm = cognet.ddi_matrix_atcmap_to_model_sparse(
        A_ddi, atcmap, vocabs.med2id, device
    )
else:
    # Row count matches: convert directly, then add self loops before normalizing.
    A_ddi_raw = cognet.ddi_matrix_to_sparse(A_ddi, device=device)
    A_ddi_norm = cognet.normalize_sparse_adj(
        cognet.add_self_loops_sparse(A_ddi_raw)
    )

# Dense float32 NumPy copy of the raw DDI adjacency, moved to host memory.
A_ddi_dense = (
    A_ddi_raw.coalesce()
    .to_dense()
    .detach()
    .cpu()
    .numpy()
    .astype(np.float32)
)

# Medication ranking computed from the training split; consumed by the datasets below.
med_rank = cognet.build_med_rank_rare_first(df_train, vocabs)

# Define Dataset, Dataloader
# max_len is shared by the dataset, the collate function, the model and the scoring calls.
max_len = 45
train_ds = cognet.COGNetFlatDataset(df_train, vocabs, med_rank=med_rank, max_len=max_len)

# Bind the collate arguments with functools.partial instead of a lambda: with
# num_workers > 0 the collate_fn is sent to worker processes, and under the "spawn"
# start method (default on Windows/macOS) it must be picklable. A lambda is not;
# a partial over a module-level function is.
collate_train = partial(cognet.collate_cognet, n_med=vocabs.n_med, max_len=max_len)

train_loader = DataLoader(
    train_ds,
    batch_size=1024,
    shuffle=True,
    num_workers=8,
    pin_memory=False,
    persistent_workers=True,  # keep workers alive between epochs
    prefetch_factor=2,
    collate_fn=collate_train,
)

In [None]:
# Define COGNet: sized from the vocabularies, with the same max_len as the dataset.
model = cognet.COGNet(
    n_diag=vocabs.n_diag,
    n_med=vocabs.n_med,
    emb_dim=64,
    gcn_hidden=64,
    max_len=max_len,
)

# Train COGNet on the training loader, passing both normalized graphs.
cognet.train_cognet(
    model,
    train_loader,
    A_ehr_norm=A_ehr_norm,
    A_ddi_norm=A_ddi_norm,
    device=device,
    lr=1e-4,
    epochs=20,
    grad_clip=5.0,
)

In [None]:
# Arguments shared by the validation and test scoring calls.
score_args = (vocabs, med_rank, A_ehr_norm, A_ddi_norm, device)
score_opts = dict(batch_size=512, num_workers=0, max_len=max_len)

# Score the validation split and tune the decision threshold on it.
scores_valid, y_valid = cognet.get_scores_and_targets(
    model, df_valid, *score_args, **score_opts
)
tuned = cognet.tune_threshold_micro_f1(scores_valid, y_valid)

# Score the held-out test split with the same settings.
scores_test, y_test = cognet.get_scores_and_targets(
    model, df_test, *score_args, **score_opts
)

In [None]:
# Save valid array
np.save('../save_model/cognet/val_pred.npy', scores_valid)
np.save('../save_model/cognet/val_true.npy', y_valid)
# Save test array
np.save('../save_model/cognet/test_pred.npy', scores_test)
np.save('../save_model/cognet/test_true.npy', y_test)
# Save threshold
# `global_threshold` was never assigned anywhere in the notebook, so this cell raised
# NameError on a fresh run; the threshold tuned on the validation split lives in `tuned`.
# NOTE(review): assumes tune_threshold_micro_f1 returns the threshold itself (a scalar
# or array-like); if it returns a tuple/dict, extract the threshold here — confirm.
global_threshold = tuned
np.save('../save_model/cognet/global_threshold.npy', global_threshold)