In [1]:
import os
import pandas as pd
import numpy as np

import torch
import deepchem as dc
from pytorch_lightning import seed_everything

import wandb
os.environ["WANDB_SILENT"] = "true"
# SECURITY: never hard-code the W&B API key in a notebook -- it ends up in
# version control and in every shared copy. The key is read from the
# WANDB_API_KEY environment variable instead; when it is unset, `key=None`
# lets wandb fall back to its own credential lookup (e.g. a prior `wandb login`).
# NOTE(review): the key that was previously committed here must be revoked and
# rotated in the W&B account settings.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

import torch.multiprocessing as mp
import gc
# Use the 'file_system' strategy for sharing tensors between processes
# (set once, process-wide, before any dataloaders are created below).
mp.set_sharing_strategy("file_system")

from utils.data_utils import scafoldsplit_train_test, convert_to_dataframe, get_initial_set_with_main_and_aux_samples
from utils.data_utils import convert_dataframe_to_dataloader

from utils.utils import wandb_init_model, compute_binary_classification_metrics_MT, active_learning_loop
from utils.model_utils import get_pred_with_uncertainities

Skipped loading modules with pytorch-geometric dependency, missing a dependency. No module named 'torch_geometric'
Skipped loading modules with pytorch-geometric dependency, missing a dependency. cannot import name 'DMPNN' from 'deepchem.models.torch_models' (/home/mmasood1/.conda/envs/env_arslan/lib/python3.9/site-packages/deepchem/models/torch_models/__init__.py)
Skipped loading some Jax models, missing a dependency. No module named 'jax'


In [2]:
# Single experiment configuration dictionary. It is passed to the project
# helpers (splitting, dataloaders, training, active learning) and is extended
# with derived entries further down in the notebook.
config = {
    # --- directories / input files ---
    "project_name": "Test",
    "metadata_dir": '/projects/home/mmasood1/trained_model_predictions/Tox21/Frozen_BERT/Main_aux_Tasks/',
    "target_file": "/projects/home/mmasood1/arslan_data_repository/Tox21/complete_Tox21.csv",
    "BERT_features_file": "/projects/home/mmasood1/arslan_data_repository/Tox21/Tox21_BERT_features.csv",
    "model_weights_dir": '/projects/home/mmasood1/Model_weights/Tox21/',
    "pos_weights": "/projects/home/mmasood1/arslan_data_repository/Tox21/pos_weights.csv",
    "class_weights": "/projects/home/mmasood1/arslan_data_repository/Tox21/target_weights.csv",

    # --- data ---
    "features_type": "BERT",
    "FP_size": 1024,
    "train_frac": 0.8,

    # --- architecture ---
    "input_dim": 768,
    "hidden_dim": 128,
    "depth": 1,
    "dropout_p": 0.2,
    "BatchNorm1d": True,
    "use_skip_connection": True,

    # --- training ---
    "optim": 'Adam',  # alternative noted by the author: SGD
    "lr_schedulers": "CosineAnnealingLR",
    "lr": 1e-3,
    "l2_lambda": 0.0,
    "optm_l2_lambda": 1e-3,
    "epochs": 10,
    "compute_metric_after_n_epochs": 5,
    "batch_size": 512,
    "EarlyStopping": False,
    "pretrained_model": False,

    # --- loss ---
    "missing": 'nan',
    "alpha": 0.0,
    "beta": 0.0,
    "gamma": 0.0,

    # --- hardware / run behaviour ---
    # NOTE(review): "gpu" names device index 1 while "device" is an un-indexed
    # torch.device("cuda") -- confirm both end up on the same GPU.
    "gpu": [1],
    "accelerator": "gpu",
    "device": torch.device("cuda"),
    "return_trainer": True,
    "save_predicitons": True,
    "Final_model": False,

    # --- active learning ---
    "num_forward_passes": 2,
    "num_itterations": 1000,
    "sampling_strategy": "EPIG_MT",
    "n_query": 100,
}

# Derive the task layout from the label table.
data = pd.read_csv(config["target_file"])
# Task names are the columns in the label-based (inclusive) slice "NR-AR" .. "SR-p53".
all_tasks = list(data.loc[:, "NR-AR":"SR-p53"].columns)

# One main task; every other task column is treated as auxiliary.
config["main_task"] = ["NR-AR"]
config["aux_task"] = [task for task in all_tasks if task != config["main_task"][0]]

# Positions of the main / auxiliary tasks within `all_tasks` (CSV column order).
config["main_task_index"] = all_tasks.index(config["main_task"][0])
config["aux_task_index"] = [
    position
    for position, task in enumerate(all_tasks)
    if task != config["main_task"][0]
]

# Main task first, then the auxiliary tasks.
target_names = config["main_task"] + config["aux_task"]
config["project_name"] = f'{config["project_name"]}_{config["main_task"][0]}'

# Sample counts for the main and auxiliary tasks.
config["main_task_samples"] = 498
config["aux_task_samples"] = 100

config["num_of_tasks"] = len(target_names)
config["selected_tasks"] = target_names

config["sample_only_from_aux"] = True

# Options noted by the author: "BCE", "Focal_loss", "Focal_loss_v2".
config["loss_type"] = "BCE"
config["seed"] = 42


In [3]:
# Train/test split via the project helper, with NumPy's global RNG seeded first.
np.random.seed(config["seed"])
train_set, test_set = scafoldsplit_train_test(config)

# Main-task sample count = twice the number of positive main-task labels in
# the training split (this replaces the value set in the configuration cell).
# NOTE(review): this indexes train_set.y with the position of the main task in
# the CSV column order -- confirm the helper keeps that column order.
config["main_task_samples"] = (train_set.y[:, config["main_task_index"]] == 1).sum() * 2

# Carve the initial labelled set out of the training split; the helper also
# returns the remaining training data.
initial_set, train_set = get_initial_set_with_main_and_aux_samples(train_set, config)

# Stratified 85/15 split of the remaining training data into pool and validation.
randomstratifiedsplitter = dc.splits.RandomStratifiedSplitter()
pool_set, val_set = randomstratifiedsplitter.train_test_split(
    train_set,
    frac_train=0.85,
    seed=config["seed"],
)

# Per-task label sums (NaN entries ignored), sorted, for a quick sanity check.
for split_name, split in (
    ("train_set", train_set),
    ("test_set", test_set),
    ("pool_set", pool_set),
    ("val_set", val_set),
    ("initial_set", initial_set),
):
    print(split_name, sorted(np.nansum(split.y, axis=0)))

# From here on the notebook works with the output of convert_to_dataframe
# instead of the deepchem dataset objects.
initial_set, val_set, pool_set, test_set = (
    convert_to_dataframe(split, config["selected_tasks"])
    for split in (initial_set, val_set, pool_set, test_set)
)

train_test_features (6121, 768) (1531, 768)
train_test_targets (6121, 12) (1531, 12)
train_set [0.0, 114.0, 183.0, 183.0, 201.0, 246.0, 253.0, 276.0, 573.0, 628.0, 673.0, 675.0]
test_set [42.0, 47.0, 52.0, 55.0, 63.0, 88.0, 91.0, 139.0, 140.0, 177.0, 204.0, 218.0]
pool_set [0.0, 93.0, 152.0, 153.0, 172.0, 215.0, 216.0, 231.0, 489.0, 528.0, 562.0, 568.0]
val_set [0.0, 21.0, 29.0, 30.0, 31.0, 31.0, 37.0, 45.0, 84.0, 100.0, 107.0, 111.0]
initial_set [0.0, 0.0, 2.0, 2.0, 3.0, 3.0, 4.0, 12.0, 12.0, 14.0, 14.0, 249.0]


In [4]:
from utils.models import Vanilla_MLP_classifier

# Iteration 0: fix all seeds and build the run/model name from the hyperparameters.
itteration = 0
seed_everything(seed=config["seed"])
config["itteration"] = itteration
config["model_name"] = (
    f'itteration_{config["itteration"]}'
    f'_s{config["seed"]}'
    f'_alpha_{config["alpha"]}'
    f'_gamma_{config["gamma"]}'
    f'_loss_type_{config["loss_type"]}'
    f'_λ{config["optm_l2_lambda"]}'
)

# Label statistics of the initial training set: counts of 1s (active) and
# 0s (inactive) for the main task and, summed over all columns, for the
# auxiliary tasks.
train_set_main_active = (initial_set[config["main_task"]] == 1).sum().values
train_set_main_inactive = (initial_set[config["main_task"]] == 0).sum().values
train_set_aux_active = (initial_set[config["aux_task"]] == 1).sum().sum()
train_set_aux_inactive = (initial_set[config["aux_task"]] == 0).sum().sum()
train_set_main_total = train_set_main_active + train_set_main_inactive
train_set_aux_total = train_set_aux_active + train_set_aux_inactive

# Dataloaders: only the training loader shuffles and drops the last incomplete batch.
train_dataloader = convert_dataframe_to_dataloader(
    dataframe=initial_set, config=config, shuffle=True, drop_last=True
)
val_dataloader = convert_dataframe_to_dataloader(
    dataframe=val_set, config=config, shuffle=False, drop_last=False
)
test_dataloader = convert_dataframe_to_dataloader(
    dataframe=test_set, config=config, shuffle=False, drop_last=False
)
pool_dataloader = convert_dataframe_to_dataloader(
    dataframe=pool_set, config=config, shuffle=False, drop_last=False
)

# Train the MLP on the initial set; the number of batches per epoch is stored
# in the config before the model is built.
config["training_steps"] = len(train_dataloader)
trained_model, run, trainer = wandb_init_model(
    model=Vanilla_MLP_classifier,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    config=config,
    model_type='MLP',
)

Global seed set to 42


After merging (594, 768)
After merging (916, 768)
After merging (1531, 768)
After merging (5192, 768)


  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name                   | Type              | Params
-------------------------------------------------------------
0 | weighted_creterien     | BCEWithLogitsLoss | 0     
1 | non_weighted_creterian | BCEWithLogitsLoss | 0     
2 | FL                     | FocalLoss         | 0     
3 | input_layer            | Linear            | 98.4 K
4 | Hidden_block           | ModuleList        | 16.8 K
5 | output_layer           | Linear            | 1.5 K 
6 | dropout                | Dropout           | 0     
7 | batchnorm1             | BatchNorm1d       | 256   
-------------------------------------------------------------
117 K     Trainable params
0         Non-trainable params
117 K     Total params

Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
`Trainer.fit` stopped: `max_epochs=10` reached.


In [5]:
# Finish the active W&B run (presumably the one opened by wandb_init_model
# in the previous cell) before moving on to the active-learning step.
wandb.finish()

In [6]:
# One acquisition step with the strategy configured in
# config["sampling_strategy"] ("EPIG_MT" in this notebook).
# The model is switched to eval mode before it is handed to the loop.
trained_model = trained_model.eval()

# Returns the queried samples plus the training set and pool after the query.
query_set, updated_training_set, updated_poolset = active_learning_loop(
    trained_model,
    pool_dataloader,
    initial_set,
    pool_set,
    config,
    query_set=None,
    test_dataloader=test_dataloader,
    test_set=test_set,
)

######### EPIG SAMPLING ############




0.005179359635581759
initial_counts 1434 initial_poolset_count 51715 query_counts 100 updated_training_counts 1534 updated_poolset_counts 51615
