# Prediction on your own molecules

This notebook assumes that you already have docked complexes of your molecules with METTL3.

# Load Libraries

In [None]:
import pandas as pd
import numpy as np
from oddt.fingerprints import PLEC
import oddt
from joblib import Parallel, delayed
from tqdm import tqdm
import glob
import os
import tempfile
from openbabel import openbabel
from rdkit import Chem
from rdkit.Chem import AllChem
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset,TensorDataset
from sklearn.model_selection import KFold
import matplotlib.pyplot as plt
# Evaluate metrics on the test set
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score, average_precision_score, matthews_corrcoef
import deepchem as dc
from deepchem.utils.vina_utils import prepare_inputs
from deepchem.utils import download_url, load_from_disk
from deepchem.feat import RdkitGridFeaturizer

In [None]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU so the
# notebook still works on machines without a usable GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define function for voxel features

In [None]:
# Receptor structure that every docked ligand is featurized against.
protein = '/home/juni/working/mettl3/notebooks/Attention_4DCNN/example/receptor.pdb'

# Grid featurizer restricted to SPLIF features, kept as an unflattened grid.
# NOTE(review): box_width=24 with voxel_width=6 presumably yields 4 voxels per
# axis, matching the 4x4x4 input the CNN below is sized for -- confirm.
featurizer = RdkitGridFeaturizer(
    box_width=24,
    voxel_width=6,
    feature_types=["splif"],
    ecfp_power=9,
    splif_power=9,
    flatten=False,
    verbose=False,
)


def extract_grid_feature(ligand_file):
    """Return the grid feature of one docked ligand paired with the receptor.

    Parameters
    ----------
    ligand_file : str
        Path to the docked ligand file.

    Returns
    -------
    Whatever ``featurizer._featurize`` returns for the
    ``(ligand_file, protein)`` pair.
    """
    # The private single-complex entry point is called on purpose: it takes
    # one (ligand, protein) pair rather than a list of complexes.
    ligand_receptor_pair = (ligand_file, protein)
    return featurizer._featurize(ligand_receptor_pair)

# Generate voxel features

In [None]:
def _numeric_sort_key(filename):
    """Return a sort key ordering files by the digits embedded in their name.

    All digit characters of ``filename`` are concatenated and read as a single
    integer (e.g. ``'mol_12.sdf'`` -> 12). Names without any digit sort first
    instead of raising ``ValueError`` from ``int('')``. The name itself breaks
    ties, so the order never depends on the arbitrary order glob returns.
    """
    digits = ''.join(filter(str.isdigit, filename))
    return (int(digits) if digits else -1, filename)


# Work inside the directory holding the docked complexes; relative paths used
# afterwards (including the ligand file names below) resolve against it.
os.chdir('/home/juni/working/mettl3/notebooks/Attention_4DCNN/example/docked_complexes/')
docked_sdf_active = sorted(glob.glob('*.sdf'), key=_numeric_sort_key)
if not docked_sdf_active:
    # Fail early with a clear message rather than later with a tensor-shape error.
    raise FileNotFoundError(f"No .sdf files found in {os.getcwd()}")

# Featurize every ligand in parallel; results keep the order of docked_sdf_active.
screening_features = Parallel(n_jobs=60, backend="multiprocessing")(
    delayed(extract_grid_feature)(mol) for mol in tqdm(docked_sdf_active)
)

# Create a DataLoader for the screening molecules

In [None]:
# Stack the per-molecule grids into one float32 array before handing them to
# PyTorch: building a tensor directly from a Python list of ndarrays is the
# slow path PyTorch warns about.
feature_array = np.array(screening_features, dtype=np.float32)
input_features = torch.from_numpy(feature_array)
# Move the last axis to position 1, where Conv3d expects the channel axis.
input_features = input_features.permute(0, 4, 1, 2, 3)
screening_dataset = TensorDataset(input_features)  # No target values
# shuffle=False keeps predictions in the same order as the input molecules.
screening_loader = DataLoader(screening_dataset, batch_size=32, shuffle=False)

# Load the model

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class MultiheadAttention3D(nn.Module):
    """Self-attention over the spatial positions of a 5-D feature map.

    Each of the ``D*H*W`` positions is treated as one token whose embedding is
    its channel vector. The tokens attend to each other, the result is
    layer-normalised over the channel axis, and the original layout is restored.

    Args:
        embed_dim: Number of channels of the incoming feature map.
        num_heads: Number of attention heads passed to ``nn.MultiheadAttention``.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (batch, channels, D, H, W).

        Returns a tensor with the same shape as ``x``.
        """
        batch_size, channels, D, H, W = x.shape

        # (batch, channels, D, H, W) -> (D*H*W, batch, channels).
        # reshape (not view) so that non-contiguous inputs are accepted too.
        tokens = x.reshape(batch_size, channels, -1).permute(2, 0, 1)

        # Query, key and value are all the same token sequence (self-attention).
        attn_output, _ = self.multihead_attn(tokens, tokens, tokens)
        attn_output = self.norm(attn_output)

        # (D*H*W, batch, channels) -> (batch, channels, D, H, W)
        return attn_output.permute(1, 2, 0).reshape(batch_size, channels, D, H, W)

class CNN3D(nn.Module):
    """3D CNN with stacked self-attention producing one sigmoid score per input.

    Input:  (batch_size, 1536, 4, 4, 4) voxel grid. The 4x4x4 spatial size is
            required by ``fc1``, which expects 128 * 2 * 2 * 2 features after
            the single 2x max-pooling step.
    Output: (batch_size, 1) values in [0, 1] (binary classification, e.g.
            active/inactive).
    """

    def __init__(self):
        super().__init__()
        # (batch, 1536, 4, 4, 4) -> (batch, 32, 4, 4, 4)
        self.conv1 = nn.Conv3d(1536, 32, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm3d(32)

        # Six self-attention layers applied back to back on the conv1 features.
        self.attn_layers = nn.ModuleList(
            [MultiheadAttention3D(embed_dim=32, num_heads=8) for _ in range(6)]
        )

        # (batch, 32, 4, 4, 4) -> (batch, 64, 4, 4, 4)
        self.conv2 = nn.Conv3d(32, 64, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm3d(64)

        # Not applied in forward().
        # NOTE(review): presumably kept so parameter names still match saved
        # checkpoints -- confirm before removing.
        self.attn2 = MultiheadAttention3D(embed_dim=64, num_heads=8)

        # (batch, 64, 4, 4, 4) -> (batch, 128, 4, 4, 4)
        self.conv3 = nn.Conv3d(64, 128, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm3d(128)

        # Not applied in forward(); see the note on attn2.
        self.attn3 = MultiheadAttention3D(embed_dim=128, num_heads=8)

        # (batch, 128, 4, 4, 4) -> (batch, 128, 2, 2, 2)
        self.pool = nn.MaxPool3d(kernel_size=2, stride=2, padding=0)
        self.dropout_conv = nn.Dropout3d(p=0.5)

        # Fully connected head
        self.fc1 = nn.Linear(128 * 2 * 2 * 2, 256)
        self.dropout_fc1 = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(256, 128)
        self.dropout_fc2 = nn.Dropout(p=0.5)
        self.fc3 = nn.Linear(128, 1)  # Single logit for binary classification

    def forward(self, x):
        """Return sigmoid scores of shape (batch_size, 1) for voxel grids ``x``."""
        # Conv1 + BatchNorm + ReLU, then the six attention layers in sequence.
        x = F.relu(self.bn1(self.conv1(x)))
        for attn_layer in self.attn_layers:
            x = attn_layer(x)

        # Conv2 and Conv3 stages (each Conv + BatchNorm + ReLU).
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))

        # Channel dropout, then halve every spatial dimension.
        x = self.dropout_conv(x)
        x = self.pool(x)

        # Flatten everything except the batch axis for the dense head.
        x = torch.flatten(x, start_dim=1)
        x = self.dropout_fc1(F.relu(self.fc1(x)))
        x = self.dropout_fc2(F.relu(self.fc2(x)))

        # Sigmoid turns the single logit into a score in [0, 1].
        return torch.sigmoid(self.fc3(x))
# Example usage: build a freshly initialised CNN3D and print its layer summary.
model = CNN3D()
print(model)

# Initialize the model

In [None]:
# Build the network and move its parameters to `device`. The prediction loop
# sends each input batch to `device`, so the weights must live there too;
# otherwise the forward pass fails with a device-mismatch error on GPU.
model = CNN3D().to(device)

# Define the path to the saved model

In [None]:
import os
# Directory and full path of the checkpoint file read by torch.load below.
save_dir = '/home/juni/working/mettl3/notebooks/Attention_4DCNN/models/'  # Ensure this matches the directory where you saved the model
save_path = os.path.join(save_dir, 'model_checkpoint_predict_100epochs.pth')

# Load the saved checkpoint

In [None]:
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU can also be loaded on a CPU-only machine.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load(save_path, map_location=device)

# Load the model state_dict

In [None]:
# Copy the weights stored under 'model_state_dict' into the network.
model.load_state_dict(checkpoint['model_state_dict'])

# Load the optimizer state_dict

In [None]:
# Recreate an Adam optimizer and restore its state from the checkpoint.
# NOTE(review): the prediction code visible in this notebook never uses the
# optimizer; presumably it is restored only for resuming training -- confirm.
optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-4)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# Set the model to evaluation mode

In [None]:
# Switch the Dropout and BatchNorm layers to inference behaviour before predicting.
model.eval()

# Make predictions

In [None]:
import pandas as pd
import torch

# Decision cut-off: scores at or above it are labelled 1, all others 0.
threshold = 0.5

# Accumulators filled batch by batch, in loader order.
predicted_scores = []
predicted_classes = []

# Inference only, so gradient tracking is switched off.
with torch.no_grad():
    for batch in screening_loader:
        # The dataset holds a single tensor, so each batch is a 1-element
        # sequence; its only entry is the input tensor.
        voxels = batch[0].to(device)

        # Model output -> flat float vector on the CPU.
        scores = model(voxels).cpu().numpy().astype(float).flatten()
        classes = (scores >= threshold).astype(int)

        predicted_scores += scores.tolist()
        predicted_classes += classes.tolist()

# One row per screened molecule, in the order the loader produced them.
screening_results = pd.DataFrame({
    "Predicted_score": predicted_scores,
    "Predicted_class": predicted_classes,
})

# Written relative to the current working directory.
screening_results.to_csv("./screening_results.csv", index=False)

print("Predictions for screening molecules saved to 'screening_results.csv'")