In [4]:
from data_utils import *

import os

In [5]:
# Paths
train_pairs_f = "/ix/djishnu/Aaron_F/ES_interact/Results/training_data/out_files/pos_neg_pairs.csv"
kinase_emb_f = "/ix/djishnu/Aaron_F/ES_interact/Results/ser_thr_kinases/out_files/kinase_esm2_embeddings.pkl"
out_dir = "/ix/djishnu/Aaron_F/ES_interact/Results/substrate_attention"

os.makedirs(f"{out_dir}/out_files", exist_ok=True)
os.makedirs(f"{out_dir}/figures", exist_ok=True)

In [6]:
# Define your ID corrections
id_corrections = {
    'C0HM02': 'P24723',
    # Add more mappings as needed
}

df, kinase_embeddings = load_training_data(
    train_pairs_f, 
    kinase_emb_f,
    id_remap=id_corrections
)

Loading training data from /ix/djishnu/Aaron_F/ES_interact/Results/training_data/out_files/pos_neg_pairs.csv

Applying ID remapping for 1 kinase(s):
  C0HM02 -> P24723 (134 examples)

Loading kinase embeddings from /ix/djishnu/Aaron_F/ES_interact/Results/ser_thr_kinases/out_files/kinase_esm2_embeddings.pkl
Loaded 49029 examples (16343 positive, 32686 negative)
Loaded embeddings for 311 kinases


In [7]:
# Test encoding
test_kmer = "SASPYPEHA"
encoded = encode_substrate_sequence(test_kmer)
print(f"\nTest encoding: {test_kmer} -> {encoded}")


Test encoding: SASPYPEHA -> [1 0 1 0 1 0 3 2 0]


In [8]:
# Create splits
train_df, val_df, test_df = create_train_val_test_splits(df)

# Create dataloaders
train_loader, val_loader, test_loader = create_dataloaders(
    train_df, val_df, test_df, kinase_embeddings, batch_size=32
)



Split sizes:
  Train: 34320 (11440 pos, 22880 neg)
  Val:   7354 (2451 pos, 4903 neg)
  Test:  7355 (2452 pos, 4903 neg)


In [10]:
# Test batch loading
batch = next(iter(train_loader))
print(f"\nBatch shapes:")
print(f"  kinase_embedding: {batch['kinase_embedding'].shape}")  # Should be [32, 1280]
print(f"  substrate_encoded: {batch['substrate_encoded'].shape}")  # Should be [32, 9]
print(f"  label: {batch['label'].shape}")  # Should be [32]
print(f"\nFirst example:")
print(f"  Kinase: {batch['kinase_name'][0]}")
print(f"  Substrate: {batch['substrate_seq'][0]}")
print(f"  Encoded: {batch['substrate_encoded'][0]}")
print(f"  Label: {batch['label'][0]}")


Batch shapes:
  kinase_embedding: torch.Size([32, 1280])
  substrate_encoded: torch.Size([32, 9])
  label: torch.Size([32])

First example:
  Kinase: P36897
  Substrate: NGSPRPRRG
  Encoded: tensor([1, 1, 1, 0, 2, 0, 2, 2, 1])
  Label: 0.0


In [11]:
from model import KinaseSubstrateAttentionModel, count_parameters
import torch

In [12]:

# Initialize model
model = KinaseSubstrateAttentionModel(
    num_groups=5,
    substrate_embedding_dim=64,
    kinase_dim=1280,
    attention_heads=4,
    hidden_dim=256,
    dropout=0.1
)

print("Model Architecture:")
print(model)
print(f"\nTotal parameters: {count_parameters(model):,}")

# Test with a batch from your dataloader
batch = next(iter(train_loader))

kinase_emb = batch['kinase_embedding']  # [32, 1280]
substrate_enc = batch['substrate_encoded']  # [32, 9]
labels = batch['label']  # [32]

print(f"\nInput shapes:")
print(f"  Kinase: {kinase_emb.shape}")
print(f"  Substrate: {substrate_enc.shape}")
print(f"  Labels: {labels.shape}")

# Forward pass
predictions, attention_weights = model(kinase_emb, substrate_enc)

print(f"\nOutput shapes:")
print(f"  Predictions: {predictions.shape}")  # Should be [32, 1]
print(f"  Attention weights: {attention_weights.shape}")  # Should be [32, 9, 1]

print(f"\nSample outputs:")
print(f"  First prediction: {predictions[0].item():.4f}")
print(f"  First label: {labels[0].item()}")
print(f"  First attention weights: {attention_weights[0].squeeze()}")

Model Architecture:
KinaseSubstrateAttentionModel(
  (substrate_embedding): Embedding(5, 64)
  (cross_attention): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
  )
  (kinase_projection): Linear(in_features=1280, out_features=64, bias=True)
  (classifier): Sequential(
    (0): Linear(in_features=64, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=256, out_features=128, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.1, inplace=False)
    (6): Linear(in_features=128, out_features=1, bias=True)
    (7): Sigmoid()
  )
)

Total parameters: 148,609

Input shapes:
  Kinase: torch.Size([32, 1280])
  Substrate: torch.Size([32, 9])
  Labels: torch.Size([32])

Output shapes:
  Predictions: torch.Size([32, 1])
  Attention weights: torch.Size([32, 9, 1])

Sample outputs:
  First prediction: 0.5034
  First label: 1.0
  First attention weights: tensor([0.8333, 1.1111, 1.11