In [1]:
import pandas as pd
import numpy as np

In [2]:
df = pd.read_parquet("../data/screening_data/ligand_embeddings.parquet")

In [3]:
df

Unnamed: 0,ID,encoding
0,CCCOc1ccc(C(O)(CC)C(CN2CCOCC2)c2ccccc2)cc1,"[0.4631887376308441, 1.1496385335922241, -0.16..."
1,Cc1cccc(N2C(=O)C(Cl)=C(Nc3ccccc3O)C2=O)c1C,"[0.3505532741546631, 0.6738343238830566, 0.027..."
2,O=C(Cc1cccs1)Nc1cccc(-c2nc3cc4ccccc4cc3[nH]2)c1,"[0.2411477416753769, -0.02009040117263794, -0...."
3,Cn1ncc(N2CCC(C(=O)Nc3cccc(-c4nc5ccccc5[nH]4)c3...,"[0.20003139972686768, -0.06338706612586975, -0..."
4,CCOC(=O)c1c(N2C(=O)C=CC2=O)sc2c1CCCC2,"[0.6468941569328308, 0.5003923773765564, 0.236..."
...,...,...
1269938,CCCN(C)c1ccc2ncc(=O)n(C)c2n1,"[0.6567842960357666, 0.25491511821746826, -0.2..."
1269939,CN(C)C(=O)c1cc(N(C)C)nc2ccccc12,"[0.9823178052902222, 0.6226070523262024, -0.19..."
1269940,CCN1CCN(C(=O)c2cc(C(C)C)n[nH]2)CC1,"[0.423491895198822, 0.19044129550457, 0.111522..."
1269941,O=C(c1ccccc1)N1CC(c2ncco2)C2(CCN(CC3CCCCC3)CC2)C1,"[0.567745566368103, 0.27047884464263916, -0.18..."


In [4]:
import torch
from torch import tensor
from torch.utils.data import DataLoader,TensorDataset

In [5]:
def create_tensor_dataset(df, mol_col, batch_size=1024):
    # Make Tensors
    mol_tensor = torch.tensor(np.array(df[mol_col].tolist())).to(torch.float32)
    
    # Create TensorDataset
    dataset = TensorDataset(mol_tensor)
    
    # Create DataLoader
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    
    return dataloader

In [6]:
data_loader = create_tensor_dataset(df, mol_col='encoding')

In [7]:
batch = next(iter(data_loader))

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class resBlock(nn.Module):
    def __init__(self, in_channels, out_channels, use_conv1=False, strides=1, dropout=0.4):
        super().__init__()
        
        self.process = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=strides, padding=1),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm1d(out_channels)
        )
        
        if use_conv1:
            self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=strides)
        else:
            self.conv1 = None
        
    def forward(self, x):
        left = self.process(x)
        right = x if self.conv1 is None else self.conv1(x)
        
        return F.relu(left + right)

class cnnModule(nn.Module):
    def __init__(self, in_channel, out_channel, hidden_channel=128, dropout=0.4):
        super().__init__()
        
        self.head = nn.Sequential(
            nn.Conv1d(in_channel, hidden_channel, 7, stride=2, padding=3, bias=False),
            nn.BatchNorm1d(hidden_channel),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.MaxPool1d(2)
        )
        
        self.cnn = nn.Sequential(
            resBlock(hidden_channel, out_channel, use_conv1=True, strides=1),
            resBlock(out_channel, out_channel, strides=1),
            resBlock(out_channel, out_channel, strides=1)
        )
    
    def forward(self, x):
        x = self.head(x)
        x = self.cnn(x)
        
        return x

class DeepLPI(nn.Module):
    def __init__(self, molshape, seqshape, dropout=0.4):
        super().__init__()
        
        self.molshape = molshape
        self.seqshape = seqshape

        self.molcnn = cnnModule(1, 64)  # Adjusted out_channel
        self.seqcnn = cnnModule(1, 64)  # Adjusted out_channel
        
        self.pool = nn.AvgPool1d(5, stride=3)
        self.lstm = nn.LSTM(64, 64, num_layers=3, batch_first=True, bidirectional=True)  # Adjusted hidden size and num_layers
        
        self.mlp = nn.Sequential(
            nn.Linear(round(((molshape + seqshape) / 4 - 2) * 2 / 3) * 64, 4096),  # Adjusted hidden units
            nn.BatchNorm1d(4096),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            
            nn.Linear(4096, 2048),  # Adjusted hidden units
            nn.BatchNorm1d(2048),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            
            nn.Linear(2048, 512),  # Adjusted hidden units
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            
            nn.Linear(512, 1),
        )

    def forward(self, mol, seq):
        mol = self.molcnn(mol.reshape(-1, 1, self.molshape))
        seq = self.seqcnn(seq.reshape(-1, 1, self.seqshape))
        
        # Concatenate along the sequence dimension
        x = torch.cat((mol, seq), 2)
        x = self.pool(x)
        
        # Reshape for LSTM
        batch_size = x.size(0)
        x = x.reshape(batch_size, -1, 64)
        x, _ = self.lstm(x)
        
        # Fully connected layer
        x = self.mlp(x.flatten(1))
        
        x = x.flatten()
        
        return x

# Example usage
molshape = 768
seqshape = 320

In [9]:
# Load the model parameters
save_path = "../models/production/deeplpi_model_v2.pth"
model = DeepLPI(molshape, seqshape)
model.load_state_dict(torch.load(save_path))
model = model.to("cuda")

In [10]:
import torch
import numpy as np
from tqdm import tqdm

def inference_model(model, dataloader, target_protein):
    model.eval()
    predictions = []
    
    with torch.no_grad():
        for step, batch in enumerate(tqdm(dataloader, desc="Inference Progress")):
            step_mol = batch[0]
            step_seq = target_protein
            
            # Repeat the protein sequence for each molecule in the batch
            step_seq = step_seq.unsqueeze(0).repeat(step_mol.size(0), 1)
            
            step_mol, step_seq = step_mol.to("cuda"), step_seq.to("cuda")
            
            logits = model(step_mol, step_seq)
            
            # Apply exponential transformation to convert back to original scale
            preds = torch.exp(logits) - 1
            preds = preds.cpu().numpy()
            
            predictions.extend(preds)
    
    return np.array(predictions)

In [11]:
# To load the data back
target_protein = np.load('../data/proteins/embeddings/6VKV_GAG_embeddings.npy')
seq_tensor = torch.tensor(target_protein.tolist()).to(torch.float32)

In [12]:
predictions_6VKV_GAG = inference_model(model=model,dataloader=data_loader,target_protein=seq_tensor)

Inference Progress: 100%|██████████| 1241/1241 [00:32<00:00, 38.53it/s]


In [13]:
# To load the data back
target_protein = np.load('../data/proteins/embeddings/7L5E_XPO1_embeddings.npy')
seq_tensor = torch.tensor(target_protein.tolist()).to(torch.float32)

In [14]:
predictions_7L5E_XPO1 = inference_model(model=model,dataloader=data_loader,target_protein=seq_tensor)

Inference Progress: 100%|██████████| 1241/1241 [00:29<00:00, 42.33it/s]


In [15]:
# To load the data back
target_protein = np.load('../data/proteins/embeddings/8QYR_MYH7_embeddings.npy')
seq_tensor = torch.tensor(target_protein.tolist()).to(torch.float32)

In [16]:
predictions_8QYR_MYH7 = inference_model(model=model,dataloader=data_loader,target_protein=seq_tensor)

Inference Progress: 100%|██████████| 1241/1241 [00:29<00:00, 42.33it/s]


In [17]:
# To load the data back
target_protein = np.load('../data/proteins/embeddings/105M_FPT_embeddings.npy')
seq_tensor = torch.tensor(target_protein.tolist()).to(torch.float32)

In [18]:
predictions_105M_FPT = inference_model(model=model,dataloader=data_loader,target_protein=seq_tensor)

Inference Progress: 100%|██████████| 1241/1241 [00:29<00:00, 42.39it/s]


In [19]:
import pandas as pd

def create_final_df(smiles_id, predictions_dict):
    # Create a DataFrame with the smiles_id
    final_df = pd.DataFrame({'Ligand_SMILE': smiles_id})
    
    # Add each prediction array to the DataFrame
    for name, predictions in predictions_dict.items():
        final_df[name] = predictions
    
    return final_df

# Example usage
smiles_id = df['ID']
predictions_dict = {
    'Predictions_6VKV_GAG': predictions_6VKV_GAG,
    'Predictions_7L5E_XPO1': predictions_7L5E_XPO1,
    'Predictions_8QYR_MYH7': predictions_8QYR_MYH7,
    'Predictions_105M_FPT': predictions_105M_FPT
}

final_df = create_final_df(smiles_id, predictions_dict)

print(final_df.head())

                                        Ligand_SMILE  Predictions_6VKV_GAG   
0         CCCOc1ccc(C(O)(CC)C(CN2CCOCC2)c2ccccc2)cc1             16.940990  \
1         Cc1cccc(N2C(=O)C(Cl)=C(Nc3ccccc3O)C2=O)c1C             10.951564   
2    O=C(Cc1cccs1)Nc1cccc(-c2nc3cc4ccccc4cc3[nH]2)c1              1.337074   
3  Cn1ncc(N2CCC(C(=O)Nc3cccc(-c4nc5ccccc5[nH]4)c3...             21.190752   
4              CCOC(=O)c1c(N2C(=O)C=CC2=O)sc2c1CCCC2              9.321757   

   Predictions_7L5E_XPO1  Predictions_8QYR_MYH7  Predictions_105M_FPT  
0              20.698540               1.010973              1.212017  
1               4.073555              18.035702              4.752003  
2               0.888778               8.434605              0.338604  
3               3.953398               2.300498              1.069123  
4               2.715251              12.506530              0.318364  


In [20]:
final_df.to_parquet('../data/screening_data/inference_v1/results.parquet')

In [21]:
final_df

Unnamed: 0,Ligand_SMILE,Predictions_6VKV_GAG,Predictions_7L5E_XPO1,Predictions_8QYR_MYH7,Predictions_105M_FPT
0,CCCOc1ccc(C(O)(CC)C(CN2CCOCC2)c2ccccc2)cc1,16.940990,20.698540,1.010973,1.212017
1,Cc1cccc(N2C(=O)C(Cl)=C(Nc3ccccc3O)C2=O)c1C,10.951564,4.073555,18.035702,4.752003
2,O=C(Cc1cccs1)Nc1cccc(-c2nc3cc4ccccc4cc3[nH]2)c1,1.337074,0.888778,8.434605,0.338604
3,Cn1ncc(N2CCC(C(=O)Nc3cccc(-c4nc5ccccc5[nH]4)c3...,21.190752,3.953398,2.300498,1.069123
4,CCOC(=O)c1c(N2C(=O)C=CC2=O)sc2c1CCCC2,9.321757,2.715251,12.506530,0.318364
...,...,...,...,...,...
1269938,CCCN(C)c1ccc2ncc(=O)n(C)c2n1,20.866096,4.209562,1.549488,0.767108
1269939,CN(C)C(=O)c1cc(N(C)C)nc2ccccc12,29.075506,6.171589,25.340958,1.025443
1269940,CCN1CCN(C(=O)c2cc(C(C)C)n[nH]2)CC1,16.378925,0.207296,13.468922,0.537754
1269941,O=C(c1ccccc1)N1CC(c2ncco2)C2(CCN(CC3CCCCC3)CC2)C1,11.630421,4.542598,0.976811,0.635975
