In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Collect every file under the read-only input mount, then print one path per line.
input_paths = [
    os.path.join(root, name)
    for root, _, names in os.walk('/kaggle/input')
    for name in names
]
for path in input_paths:
    print(path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/ppispd/train_2140.pkl
/kaggle/input/ppispd/test_2140.pkl


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
import os
import random
import pickle

# Set seed for reproducibility
# Pin every RNG this notebook touches so runs are repeatable.
seed = 2024
# PYTHONHASHSEED is read at interpreter start-up, so this only affects child processes.
os.environ['PYTHONHASHSEED'] = str(seed)
for seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_rng(seed)

# Trade cuDNN autotuning speed for deterministic kernel selection.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Prefer the GPU when one is visible.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(device)

cuda


In [3]:
class PPIDataset(Dataset):
    """Map-style dataset over pre-computed protein-pair records.

    Each record in ``protein_pairs`` is a dict with keys ``'l_bert_features'``,
    ``'r_bert_features'`` and ``'label'``. ``label`` is treated as a 2-D array
    whose column 2 is the interaction target; the training/eval loops unpack
    each row as ``(l_idx, r_idx, target)`` (presumably ligand/receptor residue
    indices — confirm against the data producer).

    The input records are never modified.
    """

    def __init__(self, protein_pairs):
        self.protein_pairs = protein_pairs

    def __len__(self):
        return len(self.protein_pairs)

    def __getitem__(self, idx):
        """Return float feature tensors and a long label tensor with column 2 binarised.

        Column 2 becomes 1 where the stored target equals 1 and 0 otherwise.
        """
        pair = self.protein_pairs[idx]

        # Binarise on a private copy: writing back into pair['label'] would
        # mutate the caller's dataset (train and test lists) on every access.
        labels = np.array(pair['label'], copy=True)
        labels[:, 2] = labels[:, 2] == 1

        return {
            'l_features': torch.tensor(np.asarray(pair['l_bert_features']), dtype=torch.float32),
            'r_features': torch.tensor(np.asarray(pair['r_bert_features']), dtype=torch.float32),
            'labels': torch.tensor(labels, dtype=torch.long),
        }

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PPIPredictor(nn.Module):
    """Residue-pair interaction scorer built from bidirectional cross-attention.

    Ligand (``l``) and receptor (``r``) per-residue features are projected to
    ``hidden_dim``, refined by ``num_layers`` rounds of ligand->receptor and
    receptor->ligand cross-attention, and every (ligand residue, receptor
    residue) pair is then scored by an MLP that emits a single logit.

    Args:
        input_dim: size of the last axis of the incoming per-residue features.
        hidden_dim: model width; must be divisible by ``num_heads``.
        num_heads: heads in each ``nn.MultiheadAttention``.
        num_layers: number of l2r/r2l cross-attention rounds.
        dropout: dropout probability for attention, residual branches and the MLP.
    """

    def __init__(self, 
                 input_dim=1024,
                 hidden_dim=512,
                 num_heads=8,
                 num_layers=2,
                 dropout=0.1):
        super().__init__()

        # Per-protein projections from input_dim down to hidden_dim.
        self.l_feature_projection = nn.Linear(input_dim, hidden_dim)
        self.r_feature_projection = nn.Linear(input_dim, hidden_dim)
        self.l_ln = nn.LayerNorm(hidden_dim)
        self.r_ln = nn.LayerNorm(hidden_dim)

        # Cross-attention stacks (batch_first: tensors are (batch, seq, dim)).
        self.l2r_attention_layers = nn.ModuleList([
            nn.MultiheadAttention(hidden_dim, num_heads, dropout=dropout, batch_first=True)
            for _ in range(num_layers)
        ])
        self.r2l_attention_layers = nn.ModuleList([
            nn.MultiheadAttention(hidden_dim, num_heads, dropout=dropout, batch_first=True)
            for _ in range(num_layers)
        ])

        # Post-residual LayerNorms, one per attention layer.
        self.l2r_norm_layers = nn.ModuleList([
            nn.LayerNorm(hidden_dim) for _ in range(num_layers)
        ])
        self.r2l_norm_layers = nn.ModuleList([
            nn.LayerNorm(hidden_dim) for _ in range(num_layers)
        ])

        # Pair-scoring MLP: concat(l, r) -> hidden -> hidden/2 -> 1 logit.
        # Indices in this Sequential are part of the saved state_dict keys.
        self.interaction_predictor = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.LayerNorm(hidden_dim // 2),
            nn.Linear(hidden_dim // 2, 1)
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, l_features, r_features):
        """Score every ligand/receptor residue pair.

        Args:
            l_features: ``(batch, L, input_dim)`` ligand features.
            r_features: ``(batch, R, input_dim)`` receptor features.

        Returns:
            ``(batch, L, R)`` tensor of raw logits (no sigmoid applied).
        """
        # Project features to hidden_dim.
        l_hidden = self.l_ln(self.dropout(F.relu(self.l_feature_projection(l_features))))
        r_hidden = self.r_ln(self.dropout(F.relu(self.r_feature_projection(r_features))))

        for l2r_attn, r2l_attn, l2r_norm, r2l_norm in zip(self.l2r_attention_layers, 
                                                         self.r2l_attention_layers,
                                                         self.l2r_norm_layers, 
                                                         self.r2l_norm_layers):
            # Ligand queries attend over receptor keys/values (residual + norm).
            l2r_output, _ = l2r_attn(l_hidden, r_hidden, r_hidden)
            l_hidden = l2r_norm(l_hidden + self.dropout(l2r_output))

            # Receptor queries attend over the ligand state already updated in
            # this round (the two directions are sequential, not parallel).
            r2l_output, _ = r2l_attn(r_hidden, l_hidden, l_hidden)
            r_hidden = r2l_norm(r_hidden + self.dropout(r2l_output))

        # Broadcast to a (batch, L, R, hidden) grid for each side. The concat
        # below materialises (batch, L, R, 2*hidden), so memory grows with L*R.
        l_hidden_expanded = l_hidden.unsqueeze(2).expand(-1, -1, r_hidden.size(1), -1)
        r_hidden_expanded = r_hidden.unsqueeze(1).expand(-1, l_hidden.size(1), -1, -1)
        pair_repr = torch.cat([l_hidden_expanded, r_hidden_expanded], dim=-1)

        # Drop the trailing size-1 logit axis -> (batch, L, R).
        interaction_scores = self.interaction_predictor(pair_repr).squeeze(-1)

        return interaction_scores


In [5]:
class WeightedBCELoss(nn.Module):
    """Mean binary cross-entropy on raw logits with an optional positive-class weight.

    Args:
        pos_weight: optional tensor forwarded as ``pos_weight`` to
            ``F.binary_cross_entropy_with_logits``; ``None`` gives unweighted BCE.
            It is stored as a plain attribute (not a registered buffer), so
            ``.to(device)`` on this module does not move it; create it on the
            target device.
    """

    def __init__(self, pos_weight=None):
        super().__init__()
        self.pos_weight = pos_weight

    def forward(self, pred, target):
        """Return the scalar mean weighted BCE between logits ``pred`` and ``target``."""
        return F.binary_cross_entropy_with_logits(
            input=pred,
            target=target,
            pos_weight=self.pos_weight,
            reduction='mean',
        )

In [6]:
def train_epoch(model, dataloader, optimizer, criterion, device, gradient_clip_val=1.0):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    Each batch is a dict with ``'l_features'``, ``'r_features'`` and
    ``'labels'``. ``labels`` is indexed as ``(batch, n_pairs, 3)`` whose rows
    are unpacked as ``(l_idx, r_idx, target)``; the model output is indexed as
    ``pred[b, l_idx, r_idx]``. Only those labelled pairs contribute to the
    loss, and every sample in the batch contributes.

    Args:
        model: module called as ``model(l_features, r_features)``.
        dataloader: iterable of batches (must support ``len`` is not required).
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss taking ``(logits, float_targets)`` 1-D tensors.
        device: device the batch tensors are moved to.
        gradient_clip_val: max total gradient norm applied before each step.

    Returns:
        Mean of the per-batch losses over batches that had at least one
        labelled pair (0.0 if none had any). Label-free batches are skipped.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    for batch in dataloader:
        labels = batch['labels'].to(device)
        if labels.numel() == 0:
            # Nothing to supervise: stacking zero pairs raises and a mean over
            # zero pairs is NaN, so skip instead of poisoning the epoch loss.
            continue

        l_features = batch['l_features'].to(device)
        r_features = batch['r_features'].to(device)

        optimizer.zero_grad()
        pred = model(l_features, r_features)

        # Gather labelled logits with one advanced-indexing op per sample
        # instead of a Python-level index per pair.
        interaction_preds = []
        interaction_targets = []
        for sample_pred, sample_labels in zip(pred, labels):
            l_idx, r_idx, target = sample_labels.unbind(dim=-1)
            interaction_preds.append(sample_pred[l_idx, r_idx])
            interaction_targets.append(target.float())
        loss = criterion(torch.cat(interaction_preds), torch.cat(interaction_targets))

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip_val)
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    return total_loss / max(num_batches, 1)

In [7]:
def init_weights(m):
    """Re-initialise one module in place; intended for ``model.apply(init_weights)``.

    Any ``weight`` with two or more dimensions gets Xavier-uniform init, and
    any non-``None`` ``bias`` is zeroed. 1-D weights (e.g. LayerNorm scales)
    and modules without these attributes are left alone.
    """
    weight = getattr(m, 'weight', None)
    if weight is not None and weight.dim() >= 2:
        nn.init.xavier_uniform_(weight)

    bias = getattr(m, 'bias', None)
    if bias is not None:
        nn.init.constant_(bias, 0)


In [8]:
from torch.optim.lr_scheduler import CosineAnnealingLR
def _seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN kernels."""
    # PYTHONHASHSEED is read at interpreter start-up, so this only affects child processes.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def train(config, protein_pairs, device=None):
    """Train a fresh ``PPIPredictor`` on ``protein_pairs`` and return it.

    Args:
        config: hyper-parameter dict. Required keys: ``batch_size``,
            ``hidden_dim``, ``num_heads``, ``num_layers``, ``dropout``,
            ``learning_rate``, ``weight_decay``, ``pos_weight``, ``epochs``,
            ``eta_min``, ``gradient_clip_val``. Optional: ``input_dim``
            (default 1024) and ``seed`` (default 2024).
        protein_pairs: records accepted by ``PPIDataset``.
        device: device to train on. Defaults to CUDA when available, else CPU,
            so the function no longer depends on a global from another cell.

    Returns:
        The trained model, on ``device``.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Re-seed on every call so results don't depend on what ran before.
    _seed_everything(config.get('seed', 2024))

    train_dataset = PPIDataset(protein_pairs)
    train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)

    model = PPIPredictor(input_dim=config.get('input_dim', 1024), hidden_dim=config['hidden_dim'],
                         num_heads=config['num_heads'], num_layers=config['num_layers'],
                         dropout=config['dropout']).to(device)
    # Re-initialise after .to(device) (init draws from the device's RNG) and
    # before building the optimizer; nothing in between consumes randomness.
    model.apply(init_weights)

    optimizer = torch.optim.Adam(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])
    criterion = WeightedBCELoss(pos_weight=torch.tensor(config['pos_weight'], dtype=torch.float).to(device))
    # Cosine decay from learning_rate to eta_min over the full run, stepped once per epoch.
    scheduler = CosineAnnealingLR(optimizer, T_max=config['epochs'], eta_min=config['eta_min'])

    for epoch in range(config['epochs']):
        train_loss = train_epoch(model, train_loader, optimizer, criterion, device,
                                 gradient_clip_val=config['gradient_clip_val'])
        scheduler.step()

        print(f"Epoch {epoch + 1}/{config['epochs']}, Loss: {train_loss:.4f}")

    return model

In [9]:
from sklearn import metrics
def eval(model, protein_pairs, device=None):
    """Evaluate ``model`` on every labelled residue pair in ``protein_pairs``.

    Note: this name shadows Python's builtin ``eval`` in the notebook; it is
    kept for the existing call site.

    Args:
        model: module called as ``model(l_features, r_features)`` returning
            logits indexed as ``pred[b, l_idx, r_idx]``.
        protein_pairs: records accepted by ``PPIDataset``.
        device: device for inference. Defaults to the device of the model's
            first parameter (CPU if it has none), instead of a hard-coded
            ``'cuda'`` that fails on CPU-only runtimes.

    Returns:
        ``(accuracy, precision, recall, f1, pr_auc, roc_auc)``. Threshold
        metrics use ``sigmoid(logit) >= 0.5``. ``roc_auc`` is NaN when the
        labels contain a single class.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    dataloader = DataLoader(PPIDataset(protein_pairs), batch_size=1)
    label_chunks = []
    prob_chunks = []

    model.eval()

    with torch.no_grad():
        for batch in dataloader:
            l_features = batch['l_features'].to(device)
            r_features = batch['r_features'].to(device)
            labels = batch['labels'].to(device)

            logits = model(l_features, r_features)

            # One advanced-indexing gather per sample instead of a .item()
            # host/device sync per labelled pair.
            for sample_logits, sample_labels in zip(logits, labels):
                l_idx, r_idx, target = sample_labels.unbind(dim=-1)
                label_chunks.append(target.cpu().numpy())
                # Sigmoid in the model's dtype, then widen to float64 for sklearn.
                prob_chunks.append(torch.sigmoid(sample_logits[l_idx, r_idx]).double().cpu().numpy())

    all_labels = np.concatenate(label_chunks) if label_chunks else np.empty(0, dtype=np.int64)
    all_preds = np.concatenate(prob_chunks) if prob_chunks else np.empty(0, dtype=np.float64)

    binary_preds = (all_preds >= 0.5).astype(int)
    accuracy = metrics.accuracy_score(all_labels, binary_preds)
    precision = metrics.precision_score(all_labels, binary_preds)
    recall = metrics.recall_score(all_labels, binary_preds)
    f1 = metrics.f1_score(all_labels, binary_preds)
    pr_auc = metrics.average_precision_score(all_labels, all_preds)
    # roc_auc_score raises ValueError when only one class is present.
    if np.unique(all_labels).size < 2:
        roc_auc = float('nan')
    else:
        roc_auc = metrics.roc_auc_score(all_labels, all_preds)

    return accuracy, precision, recall, f1, pr_auc, roc_auc

In [10]:
# Hyper-parameters consumed by train().
config = dict(
    batch_size=1,            # DataLoader batch size
    hidden_dim=768,          # model width after projecting the 1024-d input features
    num_heads=16,            # heads per cross-attention layer (must divide hidden_dim)
    num_layers=2,            # stacked l2r/r2l cross-attention rounds
    dropout=0.2,
    learning_rate=1e-4,      # Adam initial learning rate
    weight_decay=1e-5,       # Adam weight decay
    eta_min=1e-6,            # floor of the cosine LR schedule
    epochs=58,               # also used as the cosine schedule's T_max
    pos_weight=10.0,         # BCE positive-class weight
    gradient_clip_val=1.0,   # max total grad norm before each optimizer step
)

In [11]:
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
# Only element [1] of the pickled object is used (TODO: document what [0] holds).
train_pkl = "/kaggle/input/ppispd/train_2140.pkl"
with open(train_pkl, "rb") as f:
    data = pickle.load(f)
    train_protein_pairs = data[1]

In [12]:
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
# Same layout as the training pickle: the pairs live at index [1].
test_pkl = "/kaggle/input/ppispd/test_2140.pkl"
with open(test_pkl, "rb") as f:
    data = pickle.load(f)
    test_protein_pairs = data[1]

In [13]:
model = train(config, train_protein_pairs)
# Whole-module pickle, kept for existing consumers of model.pth; it can only be
# un-pickled where PPIPredictor is importable under the same name.
torch.save(model, "model.pth")
# Portable weights-only checkpoint: rebuild PPIPredictor with the same
# hyper-parameters, then call load_state_dict(torch.load("model_state_dict.pth")).
torch.save(model.state_dict(), "model_state_dict.pth")

Epoch 1/58, Loss: 1.1789
Epoch 2/58, Loss: 1.0934
Epoch 3/58, Loss: 1.0496
Epoch 4/58, Loss: 0.9980
Epoch 5/58, Loss: 0.9419
Epoch 6/58, Loss: 0.8734
Epoch 7/58, Loss: 0.8095
Epoch 8/58, Loss: 0.7502
Epoch 9/58, Loss: 0.7017
Epoch 10/58, Loss: 0.6579
Epoch 11/58, Loss: 0.6191
Epoch 12/58, Loss: 0.5875
Epoch 13/58, Loss: 0.5563
Epoch 14/58, Loss: 0.5301
Epoch 15/58, Loss: 0.5076
Epoch 16/58, Loss: 0.4876
Epoch 17/58, Loss: 0.4668
Epoch 18/58, Loss: 0.4519
Epoch 19/58, Loss: 0.4337
Epoch 20/58, Loss: 0.4173
Epoch 21/58, Loss: 0.4042
Epoch 22/58, Loss: 0.3909
Epoch 23/58, Loss: 0.3790
Epoch 24/58, Loss: 0.3657
Epoch 25/58, Loss: 0.3549
Epoch 26/58, Loss: 0.3457
Epoch 27/58, Loss: 0.3352
Epoch 28/58, Loss: 0.3250
Epoch 29/58, Loss: 0.3167
Epoch 30/58, Loss: 0.3078
Epoch 31/58, Loss: 0.3003
Epoch 32/58, Loss: 0.2926
Epoch 33/58, Loss: 0.2846
Epoch 34/58, Loss: 0.2775
Epoch 35/58, Loss: 0.2704
Epoch 36/58, Loss: 0.2642
Epoch 37/58, Loss: 0.2566
Epoch 38/58, Loss: 0.2517
Epoch 39/58, Loss: 0.

In [14]:
# Threshold metrics (first line) use sigmoid >= 0.5; the AUCs are threshold-free.
(accuracy, precision, recall, f1,
 pr_auc, roc_auc) = eval(model, test_protein_pairs)
summary_lines = (
    f"Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}",
    f"PR AUC: {pr_auc:.4f}, ROC AUC: {roc_auc:.4f}",
)
for line in summary_lines:
    print(line)

Accuracy: 0.8583, Precision: 0.2292, Recall: 0.2363, F1 Score: 0.2327
PR AUC: 0.1998, ROC AUC: 0.6967
