In [1]:
# Colab-only setup: mount the user's Google Drive at /content/drive so that
# files stored on Drive are reachable through the local filesystem.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [1]:
# Recursively copy the project's `data` directory from the mounted Drive into
# the current working directory (shell escape; runs outside the Python kernel).
!sudo cp -r "/content/drive/My Drive/UIUC/DLH/Project/data" .

In [6]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import matplotlib.pyplot as plt
import argparse
import os
import pickle
from utils import fill_mask, bool_ext, load_dataset, split_dataset
from collabfilter import CF

In [7]:
# Local definition of bool_ext. NOTE: this shadows the `bool_ext` imported from
# `utils` in the imports cell for every cell executed after this one.
def bool_ext(val):
    """Interpret a flag value as a boolean.

    Args:
        val: Either a string flag or an already-typed value. Strings are
            compared case-insensitively against "yes", "true", "t" and "1".
            Any non-string value (bool, int, None, ...) is converted with
            ``bool(val)`` instead of raising ``AttributeError``.

    Returns:
        bool: True when ``val`` is one of the accepted truthy spellings (or a
        truthy non-string value), False otherwise.
    """
    if isinstance(val, str):
        return val.lower() in ("yes", "true", "t", "1")
    # Non-string input has no .lower(); fall back to ordinary truthiness.
    return bool(val)

# Run configuration. The notebook has no command line, so the options are
# written out as a plain mapping and wrapped in an argparse.Namespace, giving
# the same attribute-style access that argparse.parse_args() would provide.
_config = {
    # run mode and data locations
    "is_train": True,
    "input_dir": "data/input",
    "output_dir": "data/output",
    "repository": "gdsc",
    "drug_id": -1,
    # CUDA is enabled only when torch reports a usable GPU
    "use_cuda": torch.cuda.is_available(),
    # model options
    "use_relu": False,
    "init_gene_emb": True,
    "omic": "exp",
    "use_attention": True,
    "use_cntx_attn": True,
    "embedding_dim": 200,
    "attention_size": 128,
    "attention_head": 8,
    "hidden_dim_enc": 200,
    "use_hid_lyr": False,
    # optimisation / evaluation schedule
    "max_iter": 384000,
    "max_fscore": 0.63,
    "dropout_rate": 0.6,
    "learning_rate": 0.3,
    "weight_decay": 3e-4,
    "batch_size": 8,
    "test_batch_size": 8,
    "test_inc_size": 1024,
    "model_label": "cntx-attn-gdsc",
}
args = argparse.Namespace(**_config)

# Echo the configuration so the run is self-describing in the cell output.
print(args)

Namespace(is_train=True, input_dir='data/input', output_dir='data/output', repository='gdsc', drug_id=-1, use_cuda=True, use_relu=False, init_gene_emb=True, omic='exp', use_attention=True, use_cntx_attn=True, embedding_dim=200, attention_size=128, attention_head=8, hidden_dim_enc=200, use_hid_lyr=False, max_iter=384000, max_fscore=0.63, dropout_rate=0.6, learning_rate=0.3, weight_decay=0.0003, batch_size=8, test_batch_size=8, test_inc_size=1024, model_label='cntx-attn-gdsc')


In [8]:
# Load the dataset selected by the run configuration, then split it into a
# training part and a test part (ratio=0.8 is passed to split_dataset).
print("Loading drug dataset...")
dataset, ptw_ids = load_dataset(
    input_dir=args.input_dir,
    repository=args.repository,
    drug_id=args.drug_id,
)
train_set, test_set = split_dataset(dataset, ratio=0.8)

# Replace the training targets and mask with the pair returned by fill_mask;
# test_set is left untouched.
filled_tgt, filled_msk = fill_mask(train_set['tgt'], train_set['msk'])
train_set['tgt'] = filled_tgt
train_set['msk'] = filled_msk


Loading drug dataset...


In [9]:
# Feature counts for each omic type, read from the second dimension of the
# corresponding array in the loaded dataset.
args.exp_size = dataset['exp_bin'].shape[1]
args.mut_size = dataset['mut_bin'].shape[1]
args.cnv_size = dataset['cnv_bin'].shape[1]

# omc_size is the feature count of the omic type chosen via args.omic.
if args.omic == 'exp':
    args.omc_size = args.exp_size
elif args.omic == 'mut':
    args.omc_size = args.mut_size
elif args.omic == 'cnv':
    args.omc_size = args.cnv_size
else:
    # Fail here with a clear message instead of leaving omc_size undefined and
    # hitting an AttributeError somewhere downstream.
    raise ValueError(
        "Unsupported omic type {!r}; expected 'exp', 'mut' or 'cnv'".format(args.omic)
    )

# Number of target columns and the sizes of the two splits.
args.drg_size = dataset['tgt'].shape[1]
args.train_size = len(train_set['tmr'])
args.test_size = len(test_set['tmr'])

In [10]:
# Show the complete configuration, including the size fields derived from the
# dataset, on its own line below the heading.
print("Hyperparameters:", args, sep="\n")

Hyperparameters:
Namespace(is_train=True, input_dir='data/input', output_dir='data/output', repository='gdsc', drug_id=-1, use_cuda=True, use_relu=False, init_gene_emb=True, omic='exp', use_attention=True, use_cntx_attn=True, embedding_dim=200, attention_size=128, attention_head=8, hidden_dim_enc=200, use_hid_lyr=False, max_iter=384000, max_fscore=0.63, dropout_rate=0.6, learning_rate=0.3, weight_decay=0.0003, batch_size=8, test_batch_size=8, test_inc_size=1024, model_label='cntx-attn-gdsc', exp_size=3000, mut_size=1000, cnv_size=1000, omc_size=3000, drg_size=260, train_size=676, test_size=170)


In [11]:
# Create the model from the run configuration and build it with ptw_ids.
model = CF(args)
model.build(ptw_ids)

# Move the model to the GPU only when args.use_cuda is set.
model = model.cuda() if args.use_cuda else model


# Log container handed to the training routine. Key order: configuration and
# iteration counter, the test metrics, the matching *_train metrics, then the
# loss history and ptw_ids. Every metric starts with its own empty list.
_metric_names = ('precision', 'recall', 'f1score', 'accuracy', 'auc')
logs = {'args': args, 'iter': []}
logs.update({name: [] for name in _metric_names})
logs.update({name + '_train': [] for name in _metric_names})
logs['loss'] = []
logs['ptw_ids'] = ptw_ids

# Keyword arguments that model.train and model.find_lr both receive.
run_kwargs = dict(
    batch_size=args.batch_size,
    test_batch_size=args.test_batch_size,
    max_iter=args.max_iter,
    max_fscore=args.max_fscore,
    test_inc_size=args.test_inc_size,
    logs=logs,
)

if not args.is_train:
    # LR-finder mode: only the logs returned by find_lr are kept.
    print("LR finding...")
    logs = model.find_lr(train_set, test_set, **run_kwargs)
else:
    print("Training...")
    logs = model.train(train_set, test_set, **run_kwargs)

    # After training, evaluate on the test split and on the training split.
    labels, msks, preds, tmr, amtr = model.test(
        test_set, test_batch_size=args.test_batch_size)
    labels_train, msks_train, preds_train, tmr_train, amtr_train = model.test_train(
        train_set, test_batch_size=args.test_batch_size)

    # Attach the evaluation outputs to the logs (test split first, then train).
    logs.update({
        "preds": preds,
        "msks": msks,
        "labels": labels,
        'tmr': tmr,
        'amtr': amtr,
        'preds_train': preds_train,
        'msks_train': msks_train,
        'labels_train': labels_train,
        'tmr_train': tmr_train,
        'amtr_train': amtr_train,
    })

# Persist the logs under the first unused trial index (0-99) so results from
# earlier runs are not overwritten.
log_dir = os.path.join(args.output_dir, "cf")
# open(..., "wb") does not create missing directories, so make sure it exists.
os.makedirs(log_dir, exist_ok=True)

for trial in range(0, 100):
    # The existence check must look at exactly the path that is written below.
    # Checking a different directory ("cf-rep") than the one written to ("cf")
    # meant the check never matched and logs0.pkl was overwritten every run.
    log_path = os.path.join(log_dir, "logs" + str(trial) + ".pkl")
    if os.path.exists(log_path):
        continue
    print(trial)
    with open(log_path, "wb") as f:
        # protocol=2 is kept from the original code.
        pickle.dump(logs, f, protocol=2)
    break
else:
    # All 100 slots are taken: report it instead of silently dropping the logs.
    print("No free log slot in {}; logs were not saved.".format(log_dir))

Training...
[0,0] | tst acc:49.4, f1:42.4, auc:49.2 | trn acc:51.2, f1:40.7, auc:49.8 | loss:6.015
[1,348] | tst acc:63.8, f1:53.5, auc:65.3 | trn acc:56.4, f1:42.1, auc:56.8 | loss:4.105
[3,20] | tst acc:64.3, f1:52.3, auc:66.0 | trn acc:64.4, f1:48.4, auc:63.1 | loss:0.901
[4,368] | tst acc:64.8, f1:50.3, auc:62.4 | trn acc:64.9, f1:48.7, auc:65.2 | loss:0.730
[6,40] | tst acc:70.7, f1:58.7, auc:72.9 | trn acc:65.4, f1:48.6, auc:65.5 | loss:0.724
[7,388] | tst acc:65.3, f1:52.4, auc:69.0 | trn acc:65.4, f1:48.9, auc:65.6 | loss:0.724
[9,60] | tst acc:62.3, f1:50.0, auc:61.8 | trn acc:65.7, f1:49.0, auc:66.0 | loss:0.721
[10,408] | tst acc:61.5, f1:48.2, auc:64.6 | trn acc:65.8, f1:48.9, auc:66.1 | loss:0.720
[12,80] | tst acc:68.5, f1:56.4, auc:71.6 | trn acc:66.4, f1:49.5, auc:65.3 | loss:0.694
[13,428] | tst acc:69.2, f1:57.2, auc:72.5 | trn acc:67.3, f1:50.1, auc:68.9 | loss:0.634
[15,100] | tst acc:69.2, f1:56.3, auc:72.5 | trn acc:69.2, f1:52.2, auc:71.1 | loss:0.615
[16,448] | 