In [2]:
from tdc.multi_pred import DTI
data = DTI(name = 'KIBA')
print(data.get_data().head())
split = data.get_split()

Found local copy...
Loading...
Done!


         Drug_ID                                     Drug Target_ID  \
0  CHEMBL1087421  COc1cc2c(cc1Cl)C(c1ccc(Cl)c(Cl)c1)=NCC2    O00141   
1  CHEMBL1087421  COc1cc2c(cc1Cl)C(c1ccc(Cl)c(Cl)c1)=NCC2    O14920   
2  CHEMBL1087421  COc1cc2c(cc1Cl)C(c1ccc(Cl)c(Cl)c1)=NCC2    O15111   
3  CHEMBL1087421  COc1cc2c(cc1Cl)C(c1ccc(Cl)c(Cl)c1)=NCC2    P00533   
4  CHEMBL1087421  COc1cc2c(cc1Cl)C(c1ccc(Cl)c(Cl)c1)=NCC2    P04626   

                                              Target     Y  
0  MTVKTEAAKGTLTYSRMRGMVAILIAFMKQRRMGLNDFIQKIANNS...  11.1  
1  MSWSPSLTTQTCGAWEMKERLGTGGFGNVIRWHNQETGEQIAIKQC...  11.1  
2  MERPPGLRPGAGGPWEMRERLGTGGFGNVCLYQHRELDLKIAIKSC...  11.1  
3  MRPSGTAGAALLALLAALCPASRALEEKKVCQGTSNKLTQLGTFED...  11.1  
4  MELAALCRWGLLLALLPPGAASTQVCTGTDMKLRLPASPETHLDML...  11.1  


In [3]:
kiba_data = data.get_data()
print(kiba_data.columns)


Index(['Drug_ID', 'Drug', 'Target_ID', 'Target', 'Y'], dtype='object')


In [1]:
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

In [2]:
import torch
from torch_geometric.data import Data
from rdkit import Chem
from rdkit.Chem import AllChem
from torch_geometric.utils import from_networkx
import networkx as nx


In [3]:
from transformers import BertModel, BertTokenizer


In [35]:
from torch.utils.data import Dataset
from torch_geometric.data import Batch
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter_mean
from torch_geometric.utils import softmax
from torch.utils.data import DataLoader
import swifter
from tqdm import tqdm
import pandas as pd
from sklearn.model_selection import train_test_split
from torch_geometric.nn import global_mean_pool
from torch_geometric.nn import GCNConv
import torch.optim as optim
import numpy as np
from sklearn.metrics import mean_squared_error


In [22]:
def smiles_to_pyg_graph(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    
    # Atom features (You can expand this as needed)
    atom_features = []
    for atom in mol.GetAtoms():
        atom_type = atom.GetAtomicNum()  # Atomic number (basic feature)
        formal_charge = atom.GetFormalCharge()  # Formal charge on the atom
        atom_features.append([atom_type, formal_charge])  # Add more features if needed
    
    if len(atom_features) == 0:
        return None

    x = torch.tensor(atom_features, dtype=torch.float)
    
    # Edge features (You can expand this as needed)
    edge_list = []
    bond_types = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        bond_type = bond.GetBondTypeAsDouble()  # Bond type as a numeric value

        edge_list.append([i, j])
        edge_list.append([j, i])
        
        bond_types.append([bond_type])
        bond_types.append([bond_type])  # For undirected graph

    if len(edge_list) == 0:
        return None
    
    edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()
    edge_attr = torch.tensor(bond_types, dtype=torch.float)  # Optional edge attributes

    # Return PyG Data object with node features (x), edge_index, and edge features (edge_attr)
    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

In [64]:
def convert_compound_to_graph(compound_smiles):
    """
    Convert compound SMILES string to a graph representation.
    
    :param compound_smiles: SMILES string of a compound.
    :return: Graph representation (PyTorch Geometric Data object).
    """
    mol = Chem.MolFromSmiles(compound_smiles)
    if mol is None:
        return None
    
    atom_features = []
    for atom in mol.GetAtoms():
        atom_feature = [
            atom.GetAtomicNum(), 
            atom.GetTotalNumHs(),  
            atom.GetFormalCharge(), 
            atom.GetIsAromatic(),
            atom.GetHybridization(),
        ]
        atom_features.append(atom_feature)

    edge_index = []
    edge_types = []
    for bond in mol.GetBonds():
        start_atom = bond.GetBeginAtomIdx()
        end_atom = bond.GetEndAtomIdx()
        bond_type = bond.GetBondTypeAsDouble()
        
        edge_index.append([start_atom, end_atom])
        edge_index.append([end_atom, start_atom])
        
        edge_types.append([bond_type])
        edge_types.append([bond_type])

    x = torch.tensor(atom_features, dtype=torch.float)
    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    edge_attr = torch.tensor(edge_types, dtype=torch.float)
    
    graph_data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
    
    return graph_data

In [6]:
# Initialize ProtBERT model and tokenizer
model_name = 'Rostlab/prot_bert_bfd'
tokenizer = BertTokenizer.from_pretrained(model_name, do_lower_case=False)
protbert_model = BertModel.from_pretrained(model_name)

# Convert protein sequence to ProtBERT embedding
def sequence_to_protbert_embedding(sequence):
    inputs = tokenizer(sequence, return_tensors='pt', truncation=True, padding=True)
    with torch.no_grad():
        outputs = protbert_model(**inputs)
        embedding = outputs.last_hidden_state.mean(dim=1)  # Mean pooling
    return embedding.squeeze(0)

In [7]:

# Create directories for saving tensors
os.makedirs('graph_tensors', exist_ok=True)
os.makedirs('protein_embeddings', exist_ok=True)

In [8]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
protbert_model = protbert_model.to(device)

In [14]:

# Save to a new file after populating graphs and embeddings
#kiba_data.to_pickle('processed_kiba_data.pkl')  # Save as a pickle file

# Load it later
kiba_data = pd.read_pickle('processed_kiba_data.pkl')


In [15]:

class KIBADataset(Dataset):
    def __init__(self, kiba_data):
        self.kiba_data = kiba_data

    def __len__(self):
        return len(self.kiba_data)

    def __getitem__(self, idx):
        # Load graph tensor
        graph_tensor = torch.load(f'graph_tensors/graph_{idx}.pt')
        
        # Load protein embedding tensor
        protein_embedding = torch.load(f'protein_embeddings/protein_{idx}.pt')
        
        # Target label (Y value)
        target = torch.tensor(self.kiba_data.iloc[idx]['Y'], dtype=torch.float)
        
        return graph_tensor, protein_embedding, target


In [16]:


def collate_fn(batch):
    graphs, protein_embeddings, targets = zip(*batch)
    batched_graphs = Batch.from_data_list(graphs)  # Batch graph data
    protein_embeddings = torch.stack(protein_embeddings)  # Stack protein embeddings
    targets = torch.tensor(targets)
    
    return batched_graphs, protein_embeddings, targets


In [18]:

# Dataset and DataLoader (assuming the KIBA dataset is already preprocessed)
train_dataset = KIBADataset(kiba_data)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)


# Training parameters
learning_rate = 0.001
num_epochs = 50
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [19]:
# Check the types of graphs in your DataLoader
for batch in train_loader:
    graphs, protein_embeddings, targets = batch
    print(type(graphs), type(protein_embeddings))
    break


<class 'torch_geometric.data.batch.DataBatch'> <class 'torch.Tensor'>


In [20]:
# Check types of embeddings
for batch in train_loader:
    graphs, protein_embeddings, targets = batch
    print(type(protein_embeddings), type(protein_embeddings[0]))
    break


<class 'torch.Tensor'> <class 'torch.Tensor'>


In [62]:
tqdm.pandas()

# Apply smiles_to_pyg_graph in parallel using swifter
print("Generating graphs from SMILES strings...")
kiba_data['graph'] = kiba_data['Drug'].swifter.apply(convert_compound_to_graph)



Generating graphs from SMILES strings...


Pandas Apply:   0%|          | 0/117657 [00:00<?, ?it/s]

In [24]:
# Apply sequence_to_protbert_embedding in parallel using swifter
print("Generating protein embeddings from sequences...")
kiba_data['protein_embedding'] = kiba_data['Target'].swifter.apply(sequence_to_protbert_embedding)


Generating protein embeddings from sequences...


Pandas Apply:   0%|          | 0/117657 [00:00<?, ?it/s]

In [65]:
# Test SMILES to graph conversion
sample_smiles = kiba_data['Drug'].iloc[0]
print(convert_compound_to_graph(sample_smiles))

# Test protein sequence to embedding conversion
sample_sequence = kiba_data['Target'].iloc[0]
print(sequence_to_protbert_embedding(sample_sequence))


Data(x=[21, 5], edge_index=[2, 46], edge_attr=[46, 1])
tensor([ 0.0306,  0.0243,  0.1363,  ..., -0.0868, -0.1145, -0.0130])


In [24]:
kiba_data.head()

Unnamed: 0,Drug_ID,Drug,Target_ID,Target,Y,graph,protein_embedding
0,CHEMBL1087421,COc1cc2c(cc1Cl)C(c1ccc(Cl)c(Cl)c1)=NCC2,O00141,MTVKTEAAKGTLTYSRMRGMVAILIAFMKQRRMGLNDFIQKIANNS...,11.1,"[(x, [tensor([6., 3., 0., 0., 4.]), tensor([8....","[tensor(0.0306), tensor(0.0243), tensor(0.1363..."
1,CHEMBL1087421,COc1cc2c(cc1Cl)C(c1ccc(Cl)c(Cl)c1)=NCC2,O14920,MSWSPSLTTQTCGAWEMKERLGTGGFGNVIRWHNQETGEQIAIKQC...,11.1,"[(x, [tensor([6., 3., 0., 0., 4.]), tensor([8....","[tensor(0.0306), tensor(0.0243), tensor(0.1363..."
2,CHEMBL1087421,COc1cc2c(cc1Cl)C(c1ccc(Cl)c(Cl)c1)=NCC2,O15111,MERPPGLRPGAGGPWEMRERLGTGGFGNVCLYQHRELDLKIAIKSC...,11.1,"[(x, [tensor([6., 3., 0., 0., 4.]), tensor([8....","[tensor(0.0306), tensor(0.0243), tensor(0.1363..."
3,CHEMBL1087421,COc1cc2c(cc1Cl)C(c1ccc(Cl)c(Cl)c1)=NCC2,P00533,MRPSGTAGAALLALLAALCPASRALEEKKVCQGTSNKLTQLGTFED...,11.1,"[(x, [tensor([6., 3., 0., 0., 4.]), tensor([8....","[tensor(0.0306), tensor(0.0243), tensor(0.1363..."
4,CHEMBL1087421,COc1cc2c(cc1Cl)C(c1ccc(Cl)c(Cl)c1)=NCC2,P04626,MELAALCRWGLLLALLPPGAASTQVCTGTDMKLRLPASPETHLDML...,11.1,"[(x, [tensor([6., 3., 0., 0., 4.]), tensor([8....","[tensor(0.0306), tensor(0.0243), tensor(0.1363..."


In [None]:
def evaluate(model, test_loader, loss_fn, device):
    """
    Evaluate the model on the test dataset.
    
    Parameters:
    - model: the trained DTIModel
    - test_loader: DataLoader for the test set
    - loss_fn: Loss function (e.g., MSELoss)
    - device: torch device ('cuda' or 'cpu')
    
    Returns:
    - Average test loss
    """
    model.eval()  # Set the model to evaluation mode
    total_loss = 0.0

    with torch.no_grad():  # Disable gradient computation
        for batch_data in test_loader:
            graphs, protein_embeddings, targets = batch_data
            graphs = graphs.to(device)
            protein_embeddings = protein_embeddings.to(device)
            targets = targets.to(device)

         # Forward pass, passing the batch index
            out = dti_model(graphs, protein_embeddings, graphs.batch)

            # Compute loss
            loss = loss_fn(out.squeeze(), targets)
            loss.backward()
            total_loss += loss.item()

    # Average loss over all test samples
    avg_test_loss = total_loss / len(test_loader)
    return avg_test_loss


In [25]:
kiba_data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 117657 entries, 0 to 117656
Data columns (total 7 columns):
 #   Column             Non-Null Count   Dtype  
---  ------             --------------   -----  
 0   Drug_ID            117657 non-null  object 
 1   Drug               117657 non-null  object 
 2   Target_ID          117657 non-null  object 
 3   Target             117657 non-null  object 
 4   Y                  117657 non-null  float64
 5   graph              117657 non-null  object 
 6   protein_embedding  117657 non-null  object 
dtypes: float64(1), object(6)
memory usage: 6.3+ MB


In [26]:


# Custom dataset and data loader functions
class KIBADataset(torch.utils.data.Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        graph = row['graph']
        protein_embedding = row['protein_embedding']
        target = row['Y']
        return graph, protein_embedding, target

# Initialize dataset
train_dataset = KIBADataset(kiba_data)

# Assuming `kiba_data` is your entire dataset
train_data, test_data = train_test_split(kiba_data, test_size=0.2, random_state=42)

# Create DataLoader for batching
train_dataset = KIBADataset(train_data)
test_dataset = KIBADataset(test_data)



In [27]:
def collate_fn(batch):
    graphs, protein_embeddings, targets = zip(*batch)
    batched_graphs = Batch.from_data_list(graphs)
    protein_embeddings = torch.stack(protein_embeddings)
    targets = torch.tensor(targets, dtype=torch.float)
    return batched_graphs, protein_embeddings, targets

In [44]:
from torch_scatter import scatter_add


In [53]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_scatter import scatter_add

class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, in_dim, out_dim, num_heads, use_bias=True):
        super(MultiHeadAttentionLayer, self).__init__()
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.Q = nn.Linear(in_dim, out_dim * num_heads, bias=use_bias)
        self.K = nn.Linear(in_dim, out_dim * num_heads, bias=use_bias)
        self.V = nn.Linear(in_dim, out_dim * num_heads, bias=use_bias)
        self.fc = nn.Linear(out_dim * num_heads, out_dim)  # Final projection layer

    def forward(self, x, edge_index):
        # Linear projections for queries, keys, values
        Q_h = self.Q(x)  # Shape: [num_nodes, num_heads * out_dim]
        K_h = self.K(x)  # Shape: [num_nodes, num_heads * out_dim]
        V_h = self.V(x)  # Shape: [num_nodes, num_heads * out_dim]

        # Reshape for multi-head attention
        Q_h = Q_h.view(-1, self.num_heads, self.out_dim)  # Shape: [num_nodes, num_heads, out_dim]
        K_h = K_h.view(-1, self.num_heads, self.out_dim)  # Shape: [num_nodes, num_heads, out_dim]
        V_h = V_h.view(-1, self.num_heads, self.out_dim)  # Shape: [num_nodes, num_heads, out_dim]

        # Compute attention scores
        scores = torch.einsum('bhd,bhd->bh', Q_h, K_h) / (self.out_dim ** 0.5)

        # Apply softmax over the edges for each head
        attention = F.softmax(scores, dim=-1)

        # Multiply attention scores with the values
        out = V_h * attention.unsqueeze(-1)  # Shape: [num_nodes, num_heads, out_dim]

        # Aggregate across the heads and reshape
        out = out.view(-1, self.out_dim * self.num_heads)  # Shape: [num_nodes, num_heads * out_dim]

        # Final projection to collapse heads back to out_dim
        out = self.fc(out)  # Shape: [num_nodes, out_dim]

        return out

class GraphTransformerNet(nn.Module):
    def __init__(self, input_dim, hidden_dim, out_dim, num_heads, num_layers, dropout=0.1):
        super(GraphTransformerNet, self).__init__()
        self.input_fc = nn.Linear(input_dim, hidden_dim)  # Transformation from input_dim -> hidden_dim
        self.layers = nn.ModuleList()
        self.num_layers = num_layers
        self.multihead_attn = MultiHeadAttentionLayer(hidden_dim, hidden_dim, num_heads)

        # Add graph convolutional layers (GCNConv)
        for _ in range(num_layers):
            self.layers.append(GCNConv(hidden_dim, hidden_dim))

        # Final linear layer to project from hidden_dim to out_dim
        self.final_layer = nn.Linear(hidden_dim, out_dim)  # Projection back to out_dim
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index):
        # Initial transformation of input node features
        x = F.relu(self.input_fc(x))  # x is transformed from input_dim -> hidden_dim

        # Apply graph convolutional layers
        for layer in self.layers:
            x = F.relu(layer(x, edge_index))
            x = self.dropout(x)

        # Apply multi-head attention layer after the GCN layers
        x = self.multihead_attn(x, edge_index)

        # Project the result to the final output dimension
        x = self.final_layer(x)

        return x


In [28]:


class DTIModel(nn.Module):
    def __init__(self, graph_transformer, protein_embedding_dim, hidden_dim):
        super(DTIModel, self).__init__()
        self.graph_transformer = graph_transformer
        self.protein_fc = nn.Linear(protein_embedding_dim, hidden_dim)
        self.final_fc = nn.Linear(hidden_dim * 2, 1)  # Combine drug and protein features

    def forward(self, graphs, protein_embeddings, batch):
        # Get drug graph features from Graph Transformer
        x = graphs.x  # Node features
        edge_index = graphs.edge_index  # Edge index

        drug_features = self.graph_transformer(x, edge_index)

        # Apply mean pooling to graph node features
        drug_features_pooled = global_mean_pool(drug_features, batch)  # Mean pool over the nodes for each graph
        
        
        # Reduce protein embeddings to the same dimension
        protein_features = F.relu(self.protein_fc(protein_embeddings))

        #drug_features_pooled = drug_features.mean(dim=0) #for model summary
        
        # Combine both features (drug + protein)
        combined_features = torch.cat([drug_features_pooled, protein_features], dim=-1)

        # Final prediction
        out = self.final_fc(combined_features)

        return out


In [46]:

# # class GraphTransformerNet(nn.Module):
# #     def __init__(self, input_dim, hidden_dim, out_dim, num_layers, dropout=0.1):
# #         super(GraphTransformerNet, self).__init__()
# #         self.layers = nn.ModuleList()
# #         self.input_fc = nn.Linear(input_dim, hidden_dim)  # Transformation from 5 -> 128 features
# #         self.num_layers = num_layers
# #         self.dropout = nn.Dropout(dropout)

# #         # Add graph convolutional layers (GCNConv or others as needed)
# #         for _ in range(num_layers):
# #             self.layers.append(GCNConv(hidden_dim, hidden_dim))

# #         self.final_layer = GCNConv(hidden_dim, out_dim)  # Final layer to output dimension

# #     def forward(self, x, edge_index):
# #         # Initial transformation of input node features (e.g., atomic numbers)
# #         x = F.relu(self.input_fc(x))  # x is transformed from 5D -> 128D
        
# #         # Apply the graph convolutional layers
# #         for layer in self.layers:
# #             x = F.relu(layer(x, edge_index))
# #             x = self.dropout(x)

# #         # Final graph convolution layer
# #         x = self.final_layer(x, edge_index)
        
# #         return x

# class GraphTransformerNet(nn.Module):
#     def __init__(self, input_dim, hidden_dim, out_dim, num_layers, num_heads=8, dropout=0.1):
#         super(GraphTransformerNet, self).__init__()
#         self.layers = nn.ModuleList()
#         self.input_fc = nn.Linear(input_dim, hidden_dim)  # Transformation from 5 -> 128 features
#         self.num_layers = num_layers
#         self.num_heads = num_heads
#         self.dropout = nn.Dropout(dropout)

#         # Add graph convolutional layers (GCNConv or others as needed)
#         for _ in range(num_layers):
#             self.layers.append(GCNConv(hidden_dim, hidden_dim))

#         # Multi-head attention layer
#         self.multihead_attn = MultiHeadAttentionLayer(hidden_dim, hidden_dim, num_heads)

#         # Final graph convolution layer to output dimension
#         self.final_layer = GCNConv(hidden_dim, out_dim)

#     def forward(self, x, edge_index):
#         # Initial transformation of input node features (e.g., atomic numbers)
#         x = F.relu(self.input_fc(x))  # x is transformed from input_dim -> hidden_dim
        
#         # Apply the graph convolutional layers
#         for layer in self.layers:
#             x = F.relu(layer(x, edge_index))
#             x = self.dropout(x)

#         # Apply multi-head attention layer after the GCN layers
#         x = self.multihead_attn(x, edge_index)

#         # Final graph convolution layer
#         x = self.final_layer(x, edge_index)

#         return x

In [56]:


# Training parameters
learning_rate = 0.001
batch_size = 32
num_epochs = 100  # You can adjust this
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize DTI Model
# Define model with correct arguments
in_dim = 5  # Input dimension (e.g., atom features)
hidden_dim = 128  # Hidden dimension for transformer layers
out_dim = 128  # Output dimension from transformer layers
num_heads = 8
num_layers = 12  # Number of layers in the graph transformer
dropout = 0.1
protein_embedding_dim = 1024 

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

# Initialize GraphTransformerNet and DTIModel
graph_transformer = GraphTransformerNet(in_dim, hidden_dim, out_dim,num_heads, num_layers, dropout).to(device)
dti_model = DTIModel(graph_transformer, protein_embedding_dim, hidden_dim).to(device)  # Pass hidden_dim here


optimizer = torch.optim.Adam(dti_model.parameters(), lr=learning_rate)
loss_fn = nn.MSELoss()




In [57]:

for epoch in range(num_epochs):
    dti_model.train()
    epoch_loss = 0

    # Initialize tqdm progress bar for each epoch
    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f'Epoch {epoch+1}/{num_epochs}', leave=False)

    for batch_idx, batch_data in progress_bar:
        graphs, protein_embeddings, targets = batch_data
        graphs = graphs.to(device)
        protein_embeddings = protein_embeddings.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()

        # Forward pass, passing the batch index
        out = dti_model(graphs, protein_embeddings, graphs.batch)

        # Compute loss
        loss = loss_fn(out.squeeze(), targets)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

        # Update tqdm description with current loss for each batch
        progress_bar.set_postfix({'Batch': batch_idx + 1, 'Loss': epoch_loss / (batch_idx + 1)})

    # Print epoch-level statistics after each epoch
    print(f'Epoch {epoch+1}/{num_epochs}, Avg Loss: {epoch_loss / len(train_loader)}')


                                                                                       

Epoch 1/100, Avg Loss: 1.0938499811441618


                                                                                        

Epoch 2/100, Avg Loss: 0.714787721892225


                                                                                        

Epoch 3/100, Avg Loss: 0.7158366828538949


                                                                                        

Epoch 4/100, Avg Loss: 0.7152901258519683


                                                                                        

Epoch 5/100, Avg Loss: 0.7159166746120077


                                                                                        

Epoch 6/100, Avg Loss: 0.7165546642678438


                                                                                        

Epoch 7/100, Avg Loss: 0.7132910434068019


                                                                                        

Epoch 8/100, Avg Loss: 0.7158813619192241


                                                                                        

Epoch 9/100, Avg Loss: 0.7127568095081684


                                                                                         

Epoch 10/100, Avg Loss: 0.7149880324736653


                                                                                         

Epoch 11/100, Avg Loss: 0.7133726887404919


                                                                                         

Epoch 12/100, Avg Loss: 0.7122642323729297


                                                                                         

Epoch 13/100, Avg Loss: 0.7112233366495609


                                                                                         

Epoch 14/100, Avg Loss: 0.70837590270337


                                                                                         

Epoch 15/100, Avg Loss: 0.7088977608406811


                                                                                         

Epoch 16/100, Avg Loss: 0.709479608487548


                                                                                         

Epoch 17/100, Avg Loss: 0.7084490371188041


                                                                                         

Epoch 18/100, Avg Loss: 0.7080223110335405


                                                                                         

Epoch 19/100, Avg Loss: 0.709024003984813


                                                                                         

Epoch 20/100, Avg Loss: 0.7078859442699369


                                                                                         

Epoch 21/100, Avg Loss: 0.7078688636138518


                                                                                         

Epoch 22/100, Avg Loss: 0.706981605522836


                                                                                         

Epoch 23/100, Avg Loss: 0.7073676005806103


                                                                                         

Epoch 24/100, Avg Loss: 0.7079271285064686


                                                                                         

Epoch 25/100, Avg Loss: 0.7080382304880224


                                                                                         

Epoch 26/100, Avg Loss: 0.7069547850458822


                                                                                         

Epoch 27/100, Avg Loss: 0.7069157339698355


                                                                                         

Epoch 28/100, Avg Loss: 0.7071160144653148


                                                                                         

Epoch 29/100, Avg Loss: 0.707062168010111


                                                                                         

Epoch 30/100, Avg Loss: 0.7081355244974795


                                                                                         

Epoch 31/100, Avg Loss: 0.7067081532678501


                                                                                         

Epoch 32/100, Avg Loss: 0.7080790278844328


                                                                                         

Epoch 33/100, Avg Loss: 0.7073651463309978


                                                                                         

Epoch 34/100, Avg Loss: 0.7048777053696131


                                                                                         

Epoch 35/100, Avg Loss: 0.7058981478031925


                                                                                         

Epoch 36/100, Avg Loss: 0.7061244502187505


                                                                                         

Epoch 37/100, Avg Loss: 0.7068897550792657


                                                                                         

Epoch 38/100, Avg Loss: 0.7062398880861958


                                                                                         

Epoch 39/100, Avg Loss: 0.7052883181156228


                                                                                         

Epoch 40/100, Avg Loss: 0.7061360322830428


                                                                                         

Epoch 41/100, Avg Loss: 0.705421927726285


                                                                                         

Epoch 42/100, Avg Loss: 0.7063294597512034


                                                                                         

Epoch 43/100, Avg Loss: 0.7046167190705206


                                                                                         

Epoch 44/100, Avg Loss: 0.7046734937609134


                                                                                         

Epoch 45/100, Avg Loss: 0.70621199979821


                                                                                         

Epoch 46/100, Avg Loss: 0.7062813058687991


                                                                                         

Epoch 47/100, Avg Loss: 0.703788040442624


                                                                                         

Epoch 48/100, Avg Loss: 0.7042859899352433


                                                                                         

Epoch 49/100, Avg Loss: 0.7042252462126826


                                                                                         

Epoch 50/100, Avg Loss: 0.7048069543321063


                                                                                         

Epoch 51/100, Avg Loss: 0.7037923487016006


                                                                                         

Epoch 52/100, Avg Loss: 0.7036582242360656


                                                                                         

Epoch 53/100, Avg Loss: 0.7048376851761365


                                                                                         

Epoch 54/100, Avg Loss: 0.7040916169119238


                                                                                         

Epoch 55/100, Avg Loss: 0.703953132681665


                                                                                         

Epoch 56/100, Avg Loss: 0.7052330413920631


                                                                                         

Epoch 57/100, Avg Loss: 0.704677948521806


                                                                                         

Epoch 58/100, Avg Loss: 0.7048944579798254


                                                                                         

Epoch 59/100, Avg Loss: 0.7040899945948499


                                                                                         

Epoch 60/100, Avg Loss: 0.7057660726702473


                                                                                         

Epoch 61/100, Avg Loss: 0.7030926981825637


                                                                                         

Epoch 62/100, Avg Loss: 0.7037735815577204


                                                                                         

Epoch 63/100, Avg Loss: 0.7043788250685143


                                                                                         

Epoch 64/100, Avg Loss: 0.7043588170636114


                                                                                         

Epoch 65/100, Avg Loss: 0.7029625084238503


                                                                                         

Epoch 66/100, Avg Loss: 0.7036585587423161


                                                                                         

Epoch 67/100, Avg Loss: 0.702548815412314


                                                                                         

Epoch 68/100, Avg Loss: 0.7022728517443894


                                                                                         

Epoch 69/100, Avg Loss: 0.7017939069810406


                                                                                         

Epoch 70/100, Avg Loss: 0.7021545416173443


                                                                                         

Epoch 71/100, Avg Loss: 0.7035569844198908


                                                                                         

Epoch 72/100, Avg Loss: 0.7030411841035051


                                                                                         

Epoch 73/100, Avg Loss: 0.7027664857249857


                                                                                         

Epoch 74/100, Avg Loss: 0.7014754640810916


                                                                                         

Epoch 75/100, Avg Loss: 0.7023533479919003


                                                                                         

Epoch 76/100, Avg Loss: 0.7017858584470322


                                                                                         

Epoch 77/100, Avg Loss: 0.7018293808616399


                                                                                         

Epoch 78/100, Avg Loss: 0.7008649016433554


                                                                                         

Epoch 79/100, Avg Loss: 0.7008595068557257


                                                                                         

Epoch 80/100, Avg Loss: 0.7004353156722297


                                                                                         

Epoch 81/100, Avg Loss: 0.7017780157283485


                                                                                         

Epoch 82/100, Avg Loss: 0.7007540671001641


                                                                                         

Epoch 83/100, Avg Loss: 0.7013800591508966


                                                                                         

Epoch 84/100, Avg Loss: 0.7012498691313369


                                                                                         

Epoch 85/100, Avg Loss: 0.7008698466863865


                                                                                         

Epoch 86/100, Avg Loss: 0.7016270305938578


                                                                                         

Epoch 87/100, Avg Loss: 0.7017881336927738


                                                                                         

Epoch 88/100, Avg Loss: 0.7016898205745674


                                                                                         

Epoch 89/100, Avg Loss: 0.7015240429454236


                                                                                         

Epoch 90/100, Avg Loss: 0.7003001640243728


                                                                                         

Epoch 91/100, Avg Loss: 0.7012457099042322


                                                                                         

Epoch 92/100, Avg Loss: 0.7020125189083195


                                                                                         

Epoch 93/100, Avg Loss: 0.7024282247922115


                                                                                         

Epoch 94/100, Avg Loss: 0.7007481717105061


                                                                                         

Epoch 95/100, Avg Loss: 0.7003922407407083


                                                                                         

Epoch 96/100, Avg Loss: 0.7012512849341194


                                                                                         

Epoch 97/100, Avg Loss: 0.7011095807614738


                                                                                         

Epoch 98/100, Avg Loss: 0.7016466249522508


                                                                                         

Epoch 99/100, Avg Loss: 0.7014241894867896


                                                                                          

Epoch 100/100, Avg Loss: 0.7005498836729941




In [58]:
def evaluate(model, test_loader, loss_fn, device):
    """
    Evaluate the model on the test dataset.
    
    Parameters:
    - model: the trained DTIModel
    - test_loader: DataLoader for the test set
    - loss_fn: Loss function (e.g., MSELoss)
    - device: torch device ('cuda' or 'cpu')
    
    Returns:
    - Average test loss
    """
    model.eval()  # Set the model to evaluation mode
    total_loss = 0.0

    with torch.no_grad():  # Disable gradient computation
        for batch_data in test_loader:
            graphs, protein_embeddings, targets = batch_data
            graphs = graphs.to(device)
            protein_embeddings = protein_embeddings.to(device)
            targets = targets.to(device)

         # Forward pass, passing the batch index
            out = dti_model(graphs, protein_embeddings, graphs.batch)

            # Compute loss
            loss = loss_fn(out.squeeze(), targets)
            loss.backward()
            total_loss += loss.item()

    # Average loss over all test samples
    avg_test_loss = total_loss / len(test_loader)
    return avg_test_loss


In [59]:


def evaluate_model_mse(model, data_loader, device):
    model.eval()  # Set the model to evaluation mode
    all_targets = []
    all_predictions = []
    epoch_loss = 0

    with torch.no_grad():  # Disable gradient computation for evaluation
        for batch_data in data_loader:
            graphs, protein_embeddings, targets = batch_data
            graphs = graphs.to(device)
            protein_embeddings = protein_embeddings.to(device)
            targets = targets.to(device)

            # Forward pass
            out = model(graphs, protein_embeddings, graphs.batch)

            # Compute loss
            loss = loss_fn(out.squeeze(), targets)
            epoch_loss += loss.item()

            # Collect predictions and true targets
            all_predictions.append(out.cpu().numpy())
            all_targets.append(targets.cpu().numpy())

    # Concatenate all predictions and targets across batches
    all_targets = np.concatenate(all_targets)
    all_predictions = np.concatenate(all_predictions)

    # Calculate MSE using sklearn
    mse = mean_squared_error(all_targets, all_predictions)

    # Print MSE
    print(f"Evaluation MSE: {mse}")

    return mse


In [60]:
# Assuming dti_model and test_loader are already defined
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

mse = evaluate_model_mse(dti_model, test_loader, device)


Evaluation MSE: 0.7019504308700562


In [122]:
# Print the model architecture
print(dti_model)

# If you want to print more detailed info about the model, including parameters:
for name, param in dti_model.named_parameters():
    print(f"Layer: {name} | Size: {param.size()} | Values : {param[:2]} \n")


DTIModel(
  (graph_transformer): GraphTransformerNet(
    (layers): ModuleList(
      (0-2): 3 x GCNConv(128, 128)
    )
    (input_fc): Linear(in_features=5, out_features=128, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (final_layer): GCNConv(128, 128)
  )
  (protein_fc): Linear(in_features=1024, out_features=128, bias=True)
  (final_fc): Linear(in_features=256, out_features=1, bias=True)
)
Layer: graph_transformer.layers.0.bias | Size: torch.Size([128]) | Values : tensor([-0.0317, -0.0417], grad_fn=<SliceBackward0>) 

Layer: graph_transformer.layers.0.lin.weight | Size: torch.Size([128, 128]) | Values : tensor([[ 4.1892e-02,  4.8769e-03,  3.7488e-02, -1.7063e-01, -1.2858e-01,
         -5.4119e-02, -1.4338e-01, -1.3337e-01, -2.8230e-02,  4.6697e-02,
          1.0246e-01, -5.4214e-02,  2.7489e-02,  2.5758e-02,  6.0505e-04,
          6.7776e-02, -1.3753e-01,  6.6764e-02,  4.5799e-02,  9.4719e-02,
          1.4857e-01, -8.5030e-02,  3.3389e-02, -8.1973e-02,  5.6698e-02,
 

In [61]:
dti_model

DTIModel(
  (graph_transformer): GraphTransformerNet(
    (input_fc): Linear(in_features=5, out_features=128, bias=True)
    (layers): ModuleList(
      (0-11): 12 x GCNConv(128, 128)
    )
    (multihead_attn): MultiHeadAttentionLayer(
      (Q): Linear(in_features=128, out_features=1024, bias=True)
      (K): Linear(in_features=128, out_features=1024, bias=True)
      (V): Linear(in_features=128, out_features=1024, bias=True)
      (fc): Linear(in_features=1024, out_features=128, bias=True)
    )
    (final_layer): Linear(in_features=128, out_features=128, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (protein_fc): Linear(in_features=1024, out_features=128, bias=True)
  (final_fc): Linear(in_features=256, out_features=1, bias=True)
)

In [139]:
# This will print a basic summary of the model
print(dti_model)


DTIModel(
  (graph_transformer): GraphTransformerNet(
    (layers): ModuleList(
      (0-2): 3 x GCNConv(128, 128)
    )
    (input_fc): Linear(in_features=5, out_features=128, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (final_layer): GCNConv(128, 128)
  )
  (protein_fc): Linear(in_features=1024, out_features=128, bias=True)
  (final_fc): Linear(in_features=256, out_features=1, bias=True)
)


In [62]:
# Print the basic model summary
print("Model Summary:")


# Print detailed parameter information
print("\nDetailed Parameter Information:")
print("="*86)
print(f"{'Layer':<40} {'Size':<30} {'Requires Grad'}")
print("-"*86)

for name, param in dti_model.named_parameters():
    print(f"{name:<40} {str(param.size()):<30} {param.requires_grad}")
print("="*86)


Model Summary:

Detailed Parameter Information:
Layer                                    Size                           Requires Grad
--------------------------------------------------------------------------------------
graph_transformer.input_fc.weight        torch.Size([128, 5])           True
graph_transformer.input_fc.bias          torch.Size([128])              True
graph_transformer.layers.0.bias          torch.Size([128])              True
graph_transformer.layers.0.lin.weight    torch.Size([128, 128])         True
graph_transformer.layers.1.bias          torch.Size([128])              True
graph_transformer.layers.1.lin.weight    torch.Size([128, 128])         True
graph_transformer.layers.2.bias          torch.Size([128])              True
graph_transformer.layers.2.lin.weight    torch.Size([128, 128])         True
graph_transformer.layers.3.bias          torch.Size([128])              True
graph_transformer.layers.3.lin.weight    torch.Size([128, 128])         True
graph_tra