In [33]:
# import torch, dgl, accelerate
from txgnn import TxData, TxGNN, TxEval
# import torch.nn as nn
# import torch.nn.functional as F
# from accelerate.utils import set_seed
# from dgl.distributed import partition_graph
import numpy as np
import time
import pandas as pd
# saving_path = './saved_models/'
# split = 'random'
# split = 'complex_disease'
# split = 'cell_proliferation'
# split = 'mental_health'
# split = 'cardiovascular'
# split = 'anemia'
# split = 'adrenal_gland'
# print(split)

'''
    Let's first try one iteration to increase performance.
'''

def obtain_disease_idx(TxData1, deg, kg_path='data/kg.csv'):
    '''
        Returns the node indices of the diseases whose drug-disease degree is below `deg`.

        The degree of a disease is the number of rows of the knowledge-graph CSV
        that connect it to a drug, counted in both directions (disease -> drug
        and drug -> disease). Diseases that never appear next to a drug get degree 0.

        Args:
            TxData1: object exposing `retrieve_id_mapping()`; the returned dict must
                hold an 'idx2id_disease' entry mapping node index -> disease id.
            deg: exclusive upper bound on the drug-disease degree.
            kg_path: path of the knowledge-graph CSV. It must provide the columns
                'x_type', 'x_id', 'y_type' and 'y_id'.

        Returns:
            np.ndarray holding the index of every disease with degree < deg, in
            order of first appearance in the CSV (x side first, then y side).

        Raises:
            KeyError: if a selected disease id has no entry in the id mapping.
    '''
    # Read the id columns as text. Left to type inference, a numeric-looking id
    # comes back as a number and the "'_' in x" test below raises TypeError.
    kg = pd.read_csv(kg_path, dtype={'x_id': str, 'y_id': str})
    x_is_disease = kg['x_type'] == 'disease'
    y_is_disease = kg['y_type'] == 'disease'

    ## extract all disease ids, whichever side of the edge they sit on
    disease_ids = pd.concat([kg.loc[x_is_disease, 'x_id'],
                             kg.loc[y_is_disease, 'y_id']]).unique()

    ## obtain every disease's degree from disease-drug relations only
    drug_linked_ids = pd.concat([
        kg.loc[x_is_disease & (kg['y_type'] == 'drug'), 'x_id'],
        kg.loc[(kg['x_type'] == 'drug') & y_is_disease, 'y_id'],
    ])
    # reindex adds the diseases without any drug edge; fillna turns them into degree 0
    disease_drug_degree = drug_linked_ids.value_counts().reindex(disease_ids).fillna(0).astype(int)
    low_disease = disease_drug_degree[disease_drug_degree < deg]

    id_mapping = TxData1.retrieve_id_mapping()
    id2idx = {disease_id: idx for idx, disease_id in id_mapping['idx2id_disease'].items()}
    print(f"Total number of diseases?: {len(id2idx)}")
    print(f"total number of {deg}> degree diseases?: {len(low_disease)}")

    # ids containing '_' are looked up verbatim; all other ids are looked up with a '.0' suffix
    low_disease_idx = low_disease.index.map(lambda x: id2idx[x] if '_' in x else id2idx[x + '.0'])
    return np.array(low_disease_idx)

def turn_into_dataframe(results, num_labels=5):
    '''
        Flattens disease-centric evaluation results into a dataframe of pseudo-label edges.

        Args:
            results: dict mapping a relation name to a dict with the keys
                'ranked_drug_ids' (disease id -> drug ids), 'ranked_drug_idxs'
                (values: drug indices) and 'dis_idx' (values: disease index).
                The three inner dicts are walked in parallel, so they must list
                the same diseases in the same order.
            num_labels: number of leading drugs kept per disease and relation
                (presumably the best ranked ones; the order is taken as given).

        Returns:
            pd.DataFrame with the columns y_id, y_idx, x_id, x_idx, relation,
            y_type and x_type, one row per kept (disease, drug, relation) triple.
            'x_idx' is cast to float, 'y_type' is always "disease" and 'x_type'
            is always "drug". Empty `results` give an empty frame with the same
            columns.
    '''
    edge_columns = ['y_id', 'y_idx', 'x_id', 'x_idx', 'relation']
    records = []
    for rel, result in results.items():
        per_disease = zip(result['ranked_drug_ids'].items(),
                          result['ranked_drug_idxs'].values(),
                          result['dis_idx'].values())
        for (dis_id, drug_ids), drug_idxs, dis_idx in per_disease:
            for rank, (drug_id, drug_idx) in enumerate(zip(drug_ids, drug_idxs)):
                # stop early instead of scanning the whole ranking
                if rank >= num_labels:
                    break
                records.append({'y_id': dis_id, 'y_idx': dis_idx,
                                'x_id': drug_id, 'x_idx': drug_idx,
                                'relation': rel})

    # naming the columns keeps the frame well-formed when no record was produced
    df = pd.DataFrame(records, columns=edge_columns)
    df["x_idx"] = df["x_idx"].astype(float)
    df["y_type"] = "disease"
    df["x_type"] = "drug"
    return df

def generate_psuedo_labels(pre_trained_dir='pre_trained_model_ckpt/1', split='complex_disease', size=100, seed=1, deg=1, max_diseases=30):
    '''
        Loads a pre-trained model, generates psuedo_labels for diseases whose
        drug-disease degree is below 'deg' and turns them into a dataframe.

        Args:
            pre_trained_dir: checkpoint directory handed to `load_pretrained`.
            split: split name handed to `TxData.prepare_split`.
            size: value used for n_hid, n_inp and n_out of the model.
            seed: seed handed to `TxData.prepare_split`.
            deg: degree threshold handed to `obtain_disease_idx`.
            max_diseases: how many of the low-degree diseases are evaluated
                (the first ones returned by `obtain_disease_idx`); pass None
                to evaluate all of them.

        Returns:
            The dataframe built by `turn_into_dataframe` from the 'indication'
            and 'contraindication' evaluations.

        Also prints the wall-clock time the whole generation took.
    '''
    strt = time.time()
    TxData1 = TxData(data_folder_path = './data/')
    TxData1.prepare_split(split=split, seed=seed, no_kg=False)
    low_disease_idx = obtain_disease_idx(TxData1=TxData1, deg=deg)

    txGNN = TxGNN(
                data = TxData1,
                weight_bias_track = False,
                proj_name = 'TxGNN',
                exp_name = 'TxGNN'
            )

    # NOTE(review): these settings presumably have to match the ones the
    # checkpoint was trained with -- confirm against train_w_psuedo_labels.
    txGNN.model_initialize(n_hid = size,
                            n_inp = size,
                            n_out = size,
                            proto = True,
                            proto_num = 3,
                            attention = False,
                            sim_measure = 'all_nodes_profile',
                            bert_measure = 'disease_name',
                            agg_measure = 'rarity',
                            num_walks = 200,
                            walk_mode = 'bit',
                            path_length = 2)
    txGNN.load_pretrained(pre_trained_dir)

    # a None bound keeps every low-degree disease
    disease_idxs = low_disease_idx[:max_diseases]
    txEval = TxEval(model = txGNN)

    # same evaluation for both relations, in the original order
    results = {
        relation: txEval.eval_disease_centric(disease_idxs = disease_idxs,
                                              relation = relation,
                                              save_name = None,
                                              return_raw = "concise",
                                              save_result = False)
        for relation in ("indication", "contraindication")
    }
    psuedo_training_df = turn_into_dataframe(results)
    psuedo_end = time.time()
    print(f"time it took to generate psuedo_labels: {psuedo_end - strt}")
    return psuedo_training_df

def train_w_psuedo_labels(size=100, pre_trained_dir='pre_trained_model_ckpt/1', split='complex_disease', additional_train=None, create_psuedo_edges=False, seed=1, save_dir=None):
    '''
        Builds a TxGNN model on the given split (optionally extended with
        psuedo-labelled training edges), runs pretrain followed by finetune,
        and optionally saves the result.

        Args:
            size: value used for n_hid, n_inp and n_out of the model.
            pre_trained_dir: accepted but not used by this function.
            split, seed: handed to `TxData.prepare_split`.
            additional_train: extra training rows handed to `prepare_split`.
            create_psuedo_edges: flag handed to `prepare_split`.
            save_dir: when not None, the model is saved under
                './Noisy_student/' + save_dir.

        Returns:
            None. Prints the wall-clock time of the training iteration.
    '''
    strt = time.time()

    ## data, including the additional psuedo-training labels
    dataset = TxData(data_folder_path = './data/')
    dataset.prepare_split(
        split=split,
        seed=seed,
        no_kg=False,
        additional_train=additional_train,
        create_psuedo_edges=create_psuedo_edges,
    )

    student = TxGNN(data = dataset, weight_bias_track = True, proj_name = 'TxGNN', exp_name = 'TxGNN')
    architecture = dict(
        n_hid = size,
        n_inp = size,
        n_out = size,
        proto = True,
        proto_num = 3,
        attention = False,
        sim_measure = 'all_nodes_profile',
        bert_measure = 'disease_name',
        agg_measure = 'rarity',
        num_walks = 200,
        walk_mode = 'bit',
        path_length = 2,
    )
    student.model_initialize(**architecture)

    ## Train: pretrain first, then finetune
    student.pretrain(n_epoch = 1, learning_rate = 1e-3, batch_size = 1024, train_print_per_n = 20)
    student.finetune(n_epoch = 500, learning_rate = 5e-4, train_print_per_n = 5, valid_per_n = 20)
    print(f"time it took for this training iteration: {time.time() - strt}")

    # nothing to persist unless a target directory was requested
    if save_dir is None:
        return
    student.save_model(path = './Noisy_student/' + save_dir)

    # low_disease_idx = 
## self-supervised learning
# additional_train = [{'x_type':'gene/protein', 'x_id': '9796.0',	'relation':'protein_protein',	'y_type':'gene/protein',	'y_id':'56992.0',	'x_idx':27422.0,	'y_idx':19536.0,},
#          {'x_type':'gene/protein', 'x_id': '9796.0',	'relation':'protein_protein',	'y_type':'gene/protein',	'y_id':'56992.0',	'x_idx':27609.0,	'y_idx':19536.0,},
#          {'x_type':'gene/protein', 'x_id': '9796.0',	'relation':'protein_protein',	'y_type':'gene/protein',	'y_id':'56342.0',	'x_idx':27609.0,	'y_idx':19536.0,},
#          {'x_type':'gene/protein', 'x_id': '9796.0',	'relation':'protein_protein',	'y_type':'gene/protein',	'y_id':'24234.0',	'x_idx':24609.0,	'y_idx':22222.0,},
#          {'x_type':'gene/protein', 'x_id': '9796.0',	'relation':'protein_protein',	'y_type':'gene/protein',	'y_id':'324343.0',	'x_idx':11111.0,	'y_idx':19536.0,}]
# create_psuedo_edges = True

# data = [{'x_type':'gene/protein', 'x_id': '9796.0',	'relation':'protein_protein',	'y_type':'gene/protein',	'y_id':'56992.0',	'x_idx':27422.0,	'y_idx':19536.0,},
#          {'x_type':'gene/protein', 'x_id': '9796.0',	'relation':'protein_protein',	'y_type':'gene/protein',	'y_id':'56992.0',	'x_idx':34345.0,	'y_idx':19536.0,},
#          {'x_type':'gene/protein', 'x_id': '9796.0',	'relation':'protein_protein',	'y_type':'gene/protein',	'y_id':'56342.0',	'x_idx':27422.0,	'y_idx':19536.0,},
#          {'x_type':'gene/protein', 'x_id': '9796.0',	'relation':'protein_protein',	'y_type':'gene/protein',	'y_id':'24234.0',	'x_idx':33333.0,	'y_idx':22222.0,},
#          {'x_type':'gene/protein', 'x_id': '9796.0',	'relation':'protein_protein',	'y_type':'gene/protein',	'y_id':'324343.0',	'x_idx':11111.0,	'y_idx':19536.0,}]
# # TxData1.df_train.append(data1, ignore_index=True).drop_duplicates()
# TxData1.df_train[TxData1.df_train.append(data1, ignore_index=True).duplicated()]

## we need the pseudo-labels to be generated like this:

In [34]:
# One noisy-student round: generate psuedo-labels with the pre-trained model,
# then train a new model on the training set extended with them.
split = 'complex_disease'
psuedo_labels = generate_psuedo_labels(split=split)
print(f"Total psuedo_labels generated: {len(psuedo_labels)}")
# Pass the split explicitly so training always uses the split the labels were
# generated for, rather than relying on the two defaults happening to agree.
train_w_psuedo_labels(split=split, save_dir="The_First_Student/", additional_train=psuedo_labels, create_psuedo_edges=True)

Found local copy...
Found local copy...
Found local copy...
Found saved processed KG... Loading...
Splits detected... Loading splits....
Creating DGL graph....
additional 
Done!
Total number of diseases?: 17080
total number of 1> degree diseases?: 15026


  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

time it took to generate psuedo_labels: 360.2800166606903
Found local copy...
Found local copy...
Found local copy...
Found saved processed KG... Loading...
Splits detected... Loading splits....
7467
7470
Creating DGL graph....
additional 
Done!


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33meuku[0m ([33mkuetal[0m). Use [1m`wandb login --relogin`[0m to force relogin


Creating minibatch pretraining dataloader...
Start pre-training with #param: 1015000
Epoch: 0 Step: 0 LR: 0.00100 Loss 0.6940, Pretrain Micro AUROC 0.4528 Pretrain Micro AUPRC 0.4758 Pretrain Macro AUROC 0.4563 Pretrain Macro AUPRC 0.5941
Epoch: 0 Step: 20 LR: 0.00100 Loss 0.6820, Pretrain Micro AUROC 0.5847 Pretrain Micro AUPRC 0.5829 Pretrain Macro AUROC 0.6195 Pretrain Macro AUPRC 0.6964


KeyboardInterrupt: 