In [1]:
# Get DataFrame
from src.tdc_constant import TDC
from src.tdc_data import get_data_df, get_mtl_data_df

get_data_df(TDC.Clearance, datasetType = 'train')
get_mtl_data_df([TDC.Clearance, TDC.BBB], datasetType = 'train', scaled=True)


Unnamed: 0,Drug,BBB,Clearance
0,CC(C)(C)OC(=O)CCCc1ccc(N(CCCl)CCCl)cc1,1.0,
1,CC(C)(C)OC(=O)CCCc1ccc(N(CCCl)CCCl)cc1,1.0,
2,CC1COc2c(N3CCN(C)CC3)c(F)cc3c(=O)c(C(=O)O)cn1c23,1.0,
3,CC1COc2c(N3CCN(C)CC3)c(F)cc3c(=O)c(C(=O)O)cn1c23,1.0,
4,Cc1onc(-c2ccccc2Cl)c1C(=O)N[C@@H]1C(=O)N2[C@@H...,1.0,
...,...,...,...
2220,CCc1cccc2c3c([nH]c12)C(CC)(CC(=O)O)OCC3,,-0.601052
2221,Cc1cc(CN2Cc3ccccc3C2C(=O)Nc2ccc(Cl)cc2Cl)ccc1O...,,-0.558425
2222,COc1cccc2c1c(NS(=O)(=O)c1ccc(Cl)s1)nn2Cc1cccc(...,,-0.457888
2223,CO[C@H]1CC[C@]2(CC1)Cc1ccc(-c3cc(Cl)cc(C#N)c3)...,,-0.606883


## Get Embedding Vector

In [6]:
# Get embedding vector
model_name = 'MTL2_DropRatio0.2/MolCLR_[BBB, CYP3A4, Clearance, Solubility]_sc-12.13_1830.pt'
modelf=f'ckpts/{model_name}'

device = 'cuda:0'

In [2]:
from src.ginet_finetune import *
import torch.nn.functional as F

class GINet_Feat(nn.Module):
    """
    Args:
        num_layer (int): the number of GNN layers
        emb_dim (int): dimensionality of embeddings
        drop_ratio (float): dropout rate
        gnn_type: gin, gcn, graphsage, gat
    Output:
        node representations
    """
    def __init__(self, 
        drop_ratio=0, pool='mean',
    ):
        super(GINet_Feat, self).__init__()
        self.num_layer = 5 # pretrained
        self.emb_dim = 300 # pretrained
        self.feat_dim = 512 # pretrained
        self.drop_ratio = drop_ratio

        self.x_embedding1 = nn.Embedding(num_atom_type, self.emb_dim) # num_atom_type -> src/ginet_finetune.py
        self.x_embedding2 = nn.Embedding(num_chirality_tag, self.emb_dim)
        nn.init.xavier_uniform_(self.x_embedding1.weight.data)
        nn.init.xavier_uniform_(self.x_embedding2.weight.data)

        # List of MLPs
        self.gnns = nn.ModuleList()
        for layer in range(self.num_layer):
            self.gnns.append(GINEConv(self.emb_dim)) # -> src/ginet_finetune.py

        # List of batchnorms
        self.batch_norms = nn.ModuleList()
        for layer in range(self.num_layer):
            self.batch_norms.append(nn.BatchNorm1d(self.emb_dim))

        if pool == 'mean':
            self.pool = global_mean_pool
        elif pool == 'max':
            self.pool = global_max_pool
        elif pool == 'add':
            self.pool = global_add_pool
        self.feat_lin = nn.Linear(self.emb_dim, self.feat_dim)

    def forward(self, data):
        x = data.x
        edge_index = data.edge_index
        edge_attr = data.edge_attr
        
        h = self.x_embedding1(x[:,0]) + self.x_embedding2(x[:,1])

        for layer in range(self.num_layer):
            h = self.gnns[layer](h, edge_index, edge_attr)
            h = self.batch_norms[layer](h)
            if layer == self.num_layer - 1:
                h = F.dropout(h, self.drop_ratio, training=self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)

        h = self.pool(h, data.batch)
        h = self.feat_lin(h)  # just before prediction head 
        
        return h

    def load_my_state_dict(self, state_dict):
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            if isinstance(param, nn.parameter.Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            own_state[name].copy_(param)

def _load_pre_trained_weights(model,device,location=None):
    if location:
        checkpoint_file = location
    else:
        try:
            checkpoints_folder = os.path.join('./ckpt', 'pretrained_gin', 'checkpoints')
            checkpoint_file = os.path.join(checkpoints_folder, 'model.pth')
        except FileNotFoundError:
            print("Pre-trained weights not found. Training from scratch.")
            
    state_dict = torch.load(checkpoint_file, map_location=device)
    model.load_my_state_dict(state_dict)
    print("Loaded pre-trained model with success.")

    return model

In [7]:
from src.ginet_finetune import GINet_Feat, load_pre_trained_weights
model = GINet_Feat(
    pool='mean',
    drop_ratio=0
)
model = load_pre_trained_weights(model, device, modelf)

Loaded pre-trained model with success.


In [15]:
import torch
import pandas as pd
from src.dataset_mtl import smilesToGeometric

drug_dataset = '/data/project/aigenintern/2023-2/DCC/DrugMAP_approved_smallmolecule_drug.csv'
drugs_df=pd.read_csv(drug_dataset)

df = pd.DataFrame()
df["SMILES"] = pd.Series(dtype='object')
df["isDrug"] = pd.Series(dtype='bool')

for smiles in [ 'C=COC=C' ]:
    data = smilesToGeometric(smiles)
    data.batch = torch.tensor([0] * len(data.x))
    embedding_vector = model(data).flatten().flatten()
    df.at[smiles, "isDrug"] = True

df.to_csv("chem_space.csv")

  df.at[smiles, "1"] = "1"
