In [2]:
import sys
sys.path.append('.')
from src.settings import settings
#from src.utils import set_random_seed
import argparse
import torch
from torch import nn
from torch.utils.data import DataLoader
import numpy as np
import random
from src.data.featurizer import Vocab, N_ATOM_TYPES, N_BOND_TYPES
from src.data.finetune_dataset import MoleculeDataset
from src.data.collator import Collator_tune
from src.model.light import LiGhTPredictor as LiGhT
from src.trainer.finetune_trainer import Trainer
from src.trainer.evaluator import Evaluator
from src.trainer.result_tracker import Result_Tracker
from src.model_config import config_dict
from src.data.featurizer import smiles_to_graph_tune
from rdkit import Chem
#import RDKit2DNormalized
from src.data.descriptors.rdNormalizedDescriptors import RDKit2DNormalized



In [2]:
def init_params(module):
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=0.02)
        if module.bias is not None:
            module.bias.data.zero_()
    if isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=0.02)
def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

In [3]:
def get_predictor(d_input_feats, n_tasks, n_layers, predictor_drop, device, d_hidden_feats=None):
    print(f'All parameters: {d_input_feats}, {n_tasks}, {n_layers}, {predictor_drop}, {device}, {d_hidden_feats}')
    if n_layers == 1:
        predictor = nn.Linear(d_input_feats, n_tasks)
    else:
        predictor = nn.ModuleList()
        predictor.append(nn.Linear(d_input_feats, d_hidden_feats))
        predictor.append(nn.Dropout(predictor_drop))
        predictor.append(nn.GELU())
        for _ in range(n_layers-2):
            predictor.append(nn.Linear(d_hidden_feats, d_hidden_feats))
            predictor.append(nn.Dropout(predictor_drop))
            predictor.append(nn.GELU())
        predictor.append(nn.Linear(d_hidden_feats, n_tasks))
        predictor = nn.Sequential(*predictor)
    predictor.apply(lambda module: init_params(module))
    return predictor.to(device)

In [4]:
def finetune(dataset_type,model_path):
    config = config_dict["base"]
    vocab = Vocab(N_ATOM_TYPES, N_BOND_TYPES)
    g = torch.Generator()
    g.manual_seed(42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


    train_dataset = MoleculeDataset(root_path=args.data_path, dataset = args.dataset, dataset_type=dataset_type)
    n_tasks = len(settings[dataset_type]['output_names'])

    # Model Initialization
    model = LiGhT(
        d_node_feats=config['d_node_feats'],
        d_edge_feats=config['d_edge_feats'],
        d_g_feats=config['d_g_feats'],
        d_fp_feats=train_dataset.d_fps,
        d_md_feats=train_dataset.d_mds,
        d_hpath_ratio=config['d_hpath_ratio'],
        n_mol_layers=config['n_mol_layers'],
        path_length=config['path_length'],
        n_heads=config['n_heads'],
        n_ffn_dense_layers=config['n_ffn_dense_layers'],
        input_drop=0,
        attn_drop=0,
        feat_drop=0,
        n_node_types=vocab.vocab_size
    ).to(device)
    
    # Finetuning Setting
    model.predictor = get_predictor(d_input_feats=config['d_g_feats']*3, n_tasks=n_tasks,
                                    
                                     n_layers=2, predictor_drop=0, device=device, d_hidden_feats=256)
    model.load_state_dict({k.replace('module.',''):v for k,v in torch.load(f'{model_path}').items()})
    return model

In [25]:
instanciated_models={}

In [26]:
for model_name in settings:
    print(model_name)
    with torch.no_grad():
        instanciated_models[model_name] = finetune(model_name,
        f'/home/zach/Downloads/downstream/{model_name}/scaffold-0.pth'
        )

bace
All parameters: 2304, 1, 2, 0, cpu, 256
bbbp
All parameters: 2304, 1, 2, 0, cpu, 256
clintox
All parameters: 2304, 2, 2, 0, cpu, 256
esol
All parameters: 2304, 1, 2, 0, cpu, 256
freesolv
All parameters: 2304, 1, 2, 0, cpu, 256
lipo
All parameters: 2304, 1, 2, 0, cpu, 256
metstab
All parameters: 2304, 2, 2, 0, cpu, 256
sider
All parameters: 2304, 27, 2, 0, cpu, 256
tox21
All parameters: 2304, 12, 2, 0, cpu, 256
toxcast
All parameters: 2304, 617, 2, 0, cpu, 256


In [28]:

def get_representation(smiles):
    print('constructing graphs')
    graph = smiles_to_graph_tune(smiles)
    mol = Chem.MolFromSmiles(smiles)
    ecfp= Chem.RDKFingerprint(mol, minPath=1, maxPath=7, fpSize=512)
    ecfp=torch.tensor(ecfp, dtype=torch.float32)

    generator = RDKit2DNormalized()
    md = generator.calculateMol(mol,smiles)
    md = torch.tensor(md, dtype=torch.float32)
    return graph, ecfp, md


In [57]:
from collections import defaultdict 
def gen_all_predictions(smiles):
    result=defaultdict(lambda :{})
    g,ecfp,md=get_representation(smiles)
    with torch.no_grad():
        for model_name in instanciated_models:
            print(model_name)
            
            predictions = instanciated_models[model_name].forward_tune(g.clone(), 
            ecfp,
             md
             )
            for prediction_name,prediction in zip(settings[model_name]['output_names'],predictions[0].tolist()):
                result[model_name][prediction_name]={
                    "value":prediction,
                   
                    "type":settings[model_name]['output_types'][prediction_name]
                }
                if "regression"==settings[model_name]['output_types'][prediction_name]:
                    result[model_name][prediction_name]["units"]=settings[model_name]['units']

                
    return dict(result)



In [58]:
gen_all_predictions("CCN(CC)C(=O)[C@H]1CN([C@@H]2CC3=CNC4=CC=CC(=C34)C2=C1)C")


constructing graphs
bace
bbbp
clintox
esol
freesolv
lipo
metstab
sider
tox21
toxcast


{'bace': {'Class': -2.109135866165161},
 'bbbp': {'p_np': 9.484339714050293},
 'clintox': {'FDA_APPROVED': -4.3052978515625, 'CT_TOX': 3.5269336700439453},
 'esol': {'logSolubility': -1.4553965330123901},
 'freesolv': {'freesolv': -2.0622310638427734},
 'lipo': {'lipo': 0.41645899415016174},
 'metstab': {'high': -2.5849862098693848, 'low': 1.8073099851608276},
 'sider': {'Hepatobiliary disorders': -1.2984468936920166,
  'Metabolism and nutrition disorders': 0.3069888949394226,
  'Product issues': -3.098461627960205,
  'Eye disorders': 0.08531132340431213,
  'Investigations': 0.8364753127098083,
  'Musculoskeletal and connective tissue disorders': 1.5021477937698364,
  'Gastrointestinal disorders': 2.887082576751709,
  'Social circumstances': -1.3231562376022339,
  'Immune system disorders': -0.8708500266075134,
  'Reproductive system and breast disorders': -0.2558687627315521,
  'Neoplasms benign, malignant and unspecified (incl cysts and polyps)': -1.4556807279586792,
  'General disor