In [None]:
# ONLY RUN THIS CELL IF YOU ARE RUNNING IN COLAB
# Clones the CADRE repository, exposes its data/ directory at /content/data so the
# relative paths used later ("data/input", "data/output/...") resolve from the
# working directory, and makes the repository's modules (utils, collabfilter)
# importable. Every step is guarded so the cell can be re-run safely.
import os
import subprocess
import sys

REPO_URL = "https://github.com/Jiminator/CADRE.git"
REPO_DIR = "/content/CADRE"
DATA_LINK = "/content/data"

# `git clone` errors out if the target directory already exists, so clone only once.
if not os.path.isdir(REPO_DIR):
    subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)

# A bare `ln -s` on re-run would create a nested link *inside* the existing
# /content/data directory instead of being a no-op.
if not os.path.lexists(DATA_LINK):
    os.symlink(os.path.join(REPO_DIR, "data"), DATA_LINK)

os.makedirs(os.path.join(DATA_LINK, "output"), exist_ok=True)

# Sanity checks: expected input files and output directory are reachable.
print(os.path.exists("data/input/rng.txt"))
print(os.path.exists("data/input/exp_emb_gdsc.csv"))
print(os.path.isdir("data/output"))

# Avoid piling up duplicate sys.path entries on every re-run.
if REPO_DIR not in sys.path:
    sys.path.append(REPO_DIR)
# ONLY RUN THIS CELL IF YOU ARE RUNNING IN COLAB

# Import Libraries and Fix Seeds

In [None]:
import argparse
import math
import os
import pickle
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

SEED = 5497

# Seed every RNG the training code may draw from before any project code is imported.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    # Favour run-to-run reproducibility over cuDNN autotuning speed.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Project-local modules from the CADRE repository (must be importable, e.g. via the
# repository root on sys.path).
from utils import fill_mask, bool_ext, load_dataset, split_dataset
from collabfilter import CF

# Define Train and Eval Functions

In [None]:
# Local flag parser. NOTE: this definition shadows the `bool_ext` imported from
# `utils` in the imports cell; every later call in this notebook uses this one.
def bool_ext(val):
    """Interpret a command-line style flag value as a boolean.

    Args:
        val: A bool (returned unchanged) or any value whose string form, after
            stripping surrounding whitespace and lower-casing, is one of
            "yes", "true", "t", "1".

    Returns:
        True for a recognized truthy spelling (or a True bool), False otherwise.
    """
    if isinstance(val, bool):
        return val
    # str() makes ints/None safe instead of raising AttributeError on .lower().
    return str(val).strip().lower() in ("yes", "true", "t", "1")

def train(args, pkl_path):
    args.is_train = True
    args.pkl_path = pkl_path
    model = CF(args)
    model.build(ptw_ids)

    if args.use_cuda:
        model = model.cuda()


    logs = {'args':args, 'iter':[],
            'precision':[], 'recall':[],
            'f1score':[], 'accuracy':[], 'auc':[],
            'precision_train':[], 'recall_train':[],
            'f1score_train':[], 'accuracy_train':[], 'auc_train':[],
            'loss':[], 'ptw_ids':ptw_ids}

    print("Training...")
    logs = model.train(train_set, test_set,
        batch_size=args.batch_size,
        test_batch_size=args.test_batch_size,
        max_iter=args.max_iter,
        max_fscore=args.max_fscore,
        test_inc_size=args.test_inc_size,
        logs=logs
    )

    labels, msks, preds, tmr, amtr = model.test(test_set, test_batch_size=args.test_batch_size)
    labels_train, msks_train, preds_train, tmr_train, amtr_train = model.test_train(train_set, test_batch_size=args.test_batch_size)

    logs["preds"] = preds
    logs["msks"] = msks
    logs["labels"] = labels
    logs['tmr'] = tmr
    logs['amtr'] = amtr

    logs['preds_train'] = preds_train
    logs['msks_train'] = msks_train
    logs['labels_train'] = labels_train
    logs['tmr_train'] = tmr_train
    logs['amtr_train'] = amtr_train


    if args.store_model:
        if not args.pkl_path:
            for trial in range(100):
                trial_path = os.path.join(args.output_dir, f"logs{trial}.pkl")
                if not os.path.exists(trial_path):
                    print(f"Auto-saving model to {trial_path}")
                    with open(trial_path, "wb") as f:
                        pickle.dump(logs, f, protocol=2)
                    break
        else:
            save_path = os.path.join(args.output_dir, args.pkl_path)
            if os.path.exists(save_path):
                print(f"Warning: Overwriting existing file {save_path}")
            else:
                print(f"Saving model to {save_path}")
            with open(save_path, "wb") as f:
                pickle.dump(logs, f, protocol=2)

def eval(args, pkl_path):
    """Print test-set metrics for a logs pickle previously written by `train`.

    Args:
        args: Namespace providing ``output_dir``; ``args.pkl_path`` is set to `pkl_path`.
        pkl_path: File name of the pickle, relative to ``args.output_dir``.

    Returns:
        dict with keys "accuracy", "f1", "precision", "recall", "auc_roc",
        "auc_pr", taken from ``utils.evaluate_all`` on the stored test-set
        labels/masks/predictions.

    Notes:
        The name shadows the built-in ``eval`` in this notebook; it is kept for
        the existing call sites. ``pickle.load`` can execute arbitrary code, so
        only load files this notebook produced.
    """
    from utils import evaluate_all

    args.pkl_path = pkl_path
    # os.path.join works whether or not output_dir ends with a separator.
    load_path = os.path.join(args.output_dir, args.pkl_path)

    print(f"Evaluating from saved logs at: {args.pkl_path}")
    with open(load_path, "rb") as f:
        logs = pickle.load(f)

    precision, recall, f1, acc, auc_roc, auc_pr = evaluate_all(
        logs["labels"], logs["msks"], logs["preds"])

    print(f"\nEvaluation Metrics from {args.pkl_path}:")
    print(f"Accuracy: {acc:.4f}")
    print(f"F1 Score: {f1:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"AUC-ROC: {auc_roc:.4f}")
    print(f"AUC-PR: {auc_pr:.4f}")

    return {"accuracy": acc, "f1": f1, "precision": precision,
            "recall": recall, "auc_roc": auc_roc, "auc_pr": auc_pr}

Namespace(is_train=True, input_dir='data/input', output_dir='data/output', repository='gdsc', drug_id=-1, use_cuda=True, use_relu=False, init_gene_emb=True, omic='exp', use_attention=True, use_cntx_attn=True, embedding_dim=200, attention_size=128, attention_head=8, hidden_dim_enc=200, use_hid_lyr=False, max_iter=384000, max_fscore=0.6, dropout_rate=0.6, learning_rate=0.3, weight_decay=0.0003, batch_size=8, test_batch_size=8, test_inc_size=1024, model_label='cntx-attn-gdsc')


# Set Default Arguments

In [None]:
# Default (original CADRE) hyperparameters, gathered in a plain dict and wrapped in an
# argparse.Namespace so downstream code gets the same attribute access as parse_args().
# Keyword order is preserved, so print(args) lists options in exactly this order.
default_config = dict(
    seed=SEED,
    is_train=True,
    eval=False,
    pkl_path=None,
    store_model=True,
    input_dir="data/input",
    output_dir="data/output/cf/",
    repository="gdsc",
    drug_id=-1,
    use_cuda=torch.cuda.is_available(),  # falls back to CPU when no GPU is present
    use_relu=True,
    init_gene_emb=True,
    scheduler='onecycle',
    shuffle=False,
    omic="exp",
    use_attention=True,
    use_cntx_attn=True,
    embedding_dim=200,
    attention_size=128,
    attention_head=8,
    hidden_dim_enc=200,
    use_hid_lyr=True,
    max_iter=48_000,
    max_fscore=-1,  # NOTE(review): presumably disables F1-based early stopping — confirm in CF.train
    dropout_rate=0.6,
    learning_rate=0.3,
    weight_decay=3e-4,
    batch_size=8,
    test_batch_size=8,
    test_inc_size=1024,
    model_label="cntx-attn-gdsc",
    focal=False,
    alpha=0.6,
    gamma=2.0,
    adam=False,
    mlp=False,
    norm_strategy='None',  # NOTE(review): the string 'None', not the None object — confirm what CF expects
    use_residual=False,
)
args = argparse.Namespace(**default_config)

print(args)

# Preprocess Dataset

In [None]:
print("Loading drug dataset...")
dataset, ptw_ids = load_dataset(input_dir=args.input_dir, repository=args.repository, drug_id=args.drug_id)
train_set, test_set = split_dataset(dataset, ratio=0.8)

# fill_mask is applied to the training split only; the test split keeps its
# original targets and masks.
train_set['tgt'], train_set['msk'] = fill_mask(train_set['tgt'], train_set['msk'])

# Number of columns in each omic's binary matrix.
omic_sizes = {
    'exp': dataset['exp_bin'].shape[1],
    'mut': dataset['mut_bin'].shape[1],
    'cnv': dataset['cnv_bin'].shape[1],
}
args.exp_size = omic_sizes['exp']
args.mut_size = omic_sizes['mut']
args.cnv_size = omic_sizes['cnv']

# Fail fast on an unknown omic instead of leaving omc_size unset, or silently
# stale from an earlier run of this cell.
if args.omic not in omic_sizes:
    raise ValueError(f"Unknown omic {args.omic!r}; expected one of {sorted(omic_sizes)}")
args.omc_size = omic_sizes[args.omic]

args.drg_size = dataset['tgt'].shape[1]
args.train_size = len(train_set['tmr'])
args.test_size = len(test_set['tmr'])

print("Hyperparameters:")
print(args)

Loading drug dataset...


# Train and Evaluate Original CADRE Model

In [None]:
# Train with the default hyperparameters and save the logs as default.pkl.
train(args, pkl_path="default.pkl")

In [None]:
# Report metrics for the default model's saved logs.
eval(args, pkl_path="default.pkl");

# Train and Evaluate Improved CADRE Model

In [None]:
# Overrides that turn the default configuration into the improved model. They are
# applied to `args` in place, so re-run the "Set Default Arguments" cell to get the
# original settings back.
improved_overrides = {
    "dropout_rate": 0.0,
    "use_relu": False,
    "focal": True,
    "alpha": 0.7,
    "scheduler": "cosine",
    "use_residual": True,
}
for option, value in improved_overrides.items():
    setattr(args, option, value)
print(args)

Hyperparameters:
Namespace(is_train=True, input_dir='data/input', output_dir='data/output', repository='gdsc', drug_id=-1, use_cuda=True, use_relu=False, init_gene_emb=True, omic='exp', use_attention=True, use_cntx_attn=True, embedding_dim=200, attention_size=128, attention_head=8, hidden_dim_enc=200, use_hid_lyr=False, max_iter=384000, max_fscore=0.6, dropout_rate=0.6, learning_rate=0.3, weight_decay=0.0003, batch_size=8, test_batch_size=8, test_inc_size=1024, model_label='cntx-attn-gdsc', exp_size=3000, mut_size=1000, cnv_size=1000, omc_size=3000, drg_size=260, train_size=676, test_size=170)


In [None]:
# Train the improved configuration and save the logs as final.pkl.
train(args, pkl_path="final.pkl")

Training...
[0,0] | tst acc:53.8, f1:34.0, auc:56.0 | trn acc:49.2, f1:39.9, auc:50.3 | loss:6.368
[1,348] | tst acc:60.6, f1:34.4, auc:62.4 | trn acc:55.8, f1:42.1, auc:56.5 | loss:4.194
[3,20] | tst acc:65.7, f1:41.9, auc:64.4 | trn acc:63.5, f1:47.8, auc:62.5 | loss:0.773
[4,368] | tst acc:64.3, f1:41.8, auc:67.7 | trn acc:64.1, f1:48.5, auc:64.3 | loss:0.712
[6,40] | tst acc:65.7, f1:40.0, auc:61.2 | trn acc:64.0, f1:48.4, auc:64.1 | loss:0.724
[7,388] | tst acc:63.2, f1:38.3, auc:62.3 | trn acc:64.5, f1:48.7, auc:64.7 | loss:0.720
[9,60] | tst acc:67.3, f1:43.9, auc:68.0 | trn acc:64.7, f1:49.1, auc:65.1 | loss:0.711
[10,408] | tst acc:65.2, f1:42.9, auc:67.0 | trn acc:65.0, f1:49.3, auc:65.3 | loss:0.713
[12,80] | tst acc:67.5, f1:47.4, auc:69.4 | trn acc:65.2, f1:49.3, auc:65.4 | loss:0.709
[13,428] | tst acc:67.0, f1:43.7, auc:69.7 | trn acc:65.5, f1:49.9, auc:66.2 | loss:0.660
[15,100] | tst acc:70.0, f1:49.0, auc:75.8 | trn acc:67.0, f1:50.8, auc:69.1 | loss:0.626
[16,448] | 

In [None]:
# Report metrics for the improved model's saved logs.
eval(args, pkl_path="final.pkl");