In [None]:
#@title Setup **SPPIDER-seq** (~5min)
%%time

print("Installing libraries and importing dependencies...")

# Install Required Libraries
!pip install transformers biopython torch --quiet

# Import Dependencies
import os
import re
import torch
import numpy as np
import pandas as pd
from transformers import EsmTokenizer, EsmModel
from Bio import SeqIO
from io import StringIO
from itertools import product
import ipywidgets as ipw
from IPython.display import display, Markdown
import matplotlib.pyplot as plt
import torch.nn as nn
import warnings
from datetime import datetime
import urllib.request

# Silence UserWarnings attributed to the huggingface_hub module.
warnings.filterwarnings(
    "ignore",
    category=UserWarning,
    module="huggingface_hub",
)

# Fetch the pretrained PPI models

print("Copying the PPI prediction models...")

# Working folders: model checkpoints go to models/, per-run results to outputs/
os.makedirs("models", exist_ok=True)
os.makedirs("outputs", exist_ok=True)

# Remote checkpoint locations and the local paths they are saved under
peptide_model_url = "https://raw.githubusercontent.com/aporollo-lab/SPPIDER-seq/main/models/best_model-AH16_EP10-PEP2REC_DR01_CHUNKS-Iter22.pt"
receptor_model_url = "https://raw.githubusercontent.com/aporollo-lab/SPPIDER-seq/main/models/best_model-AH16_EP10-REC2PEP_DR01_CHUNKS-Iter24.pt"
peptide_model_path = "models/peptide_model.pt"
receptor_model_path = "models/receptor_model.pt"

# Download both checkpoints (the return value of urlretrieve is not needed)
_ = urllib.request.urlretrieve(peptide_model_url, peptide_model_path)
_ = urllib.request.urlretrieve(receptor_model_url, receptor_model_path)


# Module-level state shared with the prediction workflow

# Folder of the most recent run (None until the first run) and the list of
# all run folders created so far
current_output_folder = None
output_sessions = []


# Define functions

# Embedding Utilities
def parse_fasta(text):
    """Parse FASTA-formatted text into a list of records.

    Args:
        text: FASTA content as a string; surrounding whitespace is stripped
            before parsing.

    Returns:
        A list of ``(index, record_id, sequence)`` tuples, one per record, in
        input order. ``index`` is the zero-based record position.
    """
    handle = StringIO(text.strip())
    parsed = []
    for index, record in enumerate(SeqIO.parse(handle, "fasta")):
        parsed.append((index, record.id, str(record.seq)))
    return parsed

def embed_sequence_chunks(model, tokenizer, sequence, seq_name, max_len=1024, stride=512):
    """Embed a sequence in overlapping token windows.

    The sequence is tokenized once without truncation. The token ids are then
    sliced into windows of at most ``max_len`` tokens that start every
    ``stride`` tokens. Each window is run through ``model`` and the first and
    last positions of the output are dropped.

    Args:
        model: Model whose output exposes ``last_hidden_state``.
        tokenizer: Callable tokenizer returning ``input_ids``; must also
            provide ``convert_ids_to_tokens``.
        sequence: Sequence string to embed.
        seq_name: Identifier of the sequence (currently unused; kept so
            existing callers keep working).
        max_len: Maximum number of tokens per window.
        stride: Offset in tokens between consecutive window starts.

    Returns:
        A list of dicts, one per non-empty window, with keys ``chunk_start``
        (token offset of the window), ``chunk_end`` (``chunk_start + max_len``,
        i.e. the nominal end, which may exceed the token count for the last
        window), ``chunk_seq`` (decoded tokens of the kept positions) and
        ``embedding`` (CPU tensor with one row per kept position).
    """
    model.eval()
    device = next(model.parameters()).device
    tokens = tokenizer(sequence, return_tensors='pt', truncation=False)['input_ids'][0]

    chunk_data = []
    for start in range(0, len(tokens), stride):
        chunk = tokens[start:start + max_len]
        # The first and last token of every window are trimmed below, so a
        # window of two or fewer tokens would yield an empty embedding.
        # Skip it instead of handing a zero-length chunk to downstream code.
        if chunk.size(0) <= 2:
            continue

        input_ids = chunk.unsqueeze(0).to(device)
        with torch.no_grad():
            # Only the final layer is used, so per-layer hidden states are
            # not requested.
            out = model(input_ids)
        emb = out.last_hidden_state[:, 1:-1, :]  # Remove CLS and EOS

        tokens_decoded = tokenizer.convert_ids_to_tokens(chunk[1:-1])
        # Strip the U+2581 marker character from decoded tokens, if present
        aa_seq = "".join(t.replace("\u2581", "") for t in tokens_decoded)

        chunk_data.append({
            "chunk_start": start,
            "chunk_end": start + max_len,
            "chunk_seq": aa_seq,
            "embedding": emb.squeeze(0).cpu()
        })

    return chunk_data

# Embedding Workflow
def run_embedding_workflow(btn):
    """Button callback: embed every query/partner pair and run predictions.

    Reads FASTA text from ``query_seq_input`` and ``partner_seq_input``,
    embeds each sequence of every query x partner combination with ESM-2,
    stores the results in the global ``all_pair_chunks`` and then calls
    ``run_ppi_predictions``, which writes into a new timestamped folder under
    ``outputs/``. All printed output is captured by ``output_box``.

    Args:
        btn: The widget that triggered the callback (unused).
    """
    global current_output_folder, output_sessions
    global all_pair_chunks
    all_pair_chunks = []

    with output_box:
        output_box.clear_output()

        # Validate the input before the expensive model load, so an empty
        # form neither downloads ESM-2 nor leaves an empty run folder.
        print("\nParsing FASTA inputs...")
        seqs1 = parse_fasta(query_seq_input.value)
        seqs2 = parse_fasta(partner_seq_input.value)
        if not seqs1 or not seqs2:
            print("No sequences to process: both the Query and the Partner(s) "
                  "box must contain at least one FASTA record.")
            return

        # Create timestamped output folder
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        current_output_folder = f"outputs/{timestamp}"
        os.makedirs(current_output_folder, exist_ok=True)
        output_sessions.append(current_output_folder)

        print("\nLoading ESM-2 model...")
        model_name = "facebook/esm2_t33_650M_UR50D"
        device = "cuda" if torch.cuda.is_available() else "cpu"
        tokenizer = EsmTokenizer.from_pretrained(model_name)
        model = EsmModel.from_pretrained(model_name).eval().to(device)

        pairs = list(product(seqs1, seqs2))

        for (_, id1, s1), (_, id2, s2) in pairs:
            print(f"\nGenerating embeddings for the pair: {id1} vs {id2}")
            rec_chunks = embed_sequence_chunks(model, tokenizer, s1, id1)
            lig_chunks = embed_sequence_chunks(model, tokenizer, s2, id2)

            print(f"  → {len(rec_chunks)} receptor chunks")
            print(f"  → {len(lig_chunks)} ligand chunks")

            # The query is stored under the "receptor" keys and the partner
            # under the "ligand"/"peptide" keys.
            all_pair_chunks.append({
                "receptor_id": id1,
                "ligand_id": id2,
                "receptor_seq": s1,
                "ligand_seq": s2,
                "receptor_chunks": rec_chunks,
                "peptide_chunks": lig_chunks
            })

        print("\nRunning PPI predictions...")
        run_ppi_predictions(peptide_model_path="models/peptide_model.pt",
                            receptor_model_path="models/receptor_model.pt",
                            output_dir=current_output_folder)

# Load Model Architecture
class CrossAttentionLayer(nn.Module):
    """Multi-head cross-attention with a residual connection and LayerNorm.

    The query attends over the context (used as both key and value); the
    attention output is added back to the query and layer-normalized.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        # Submodule names and creation order are part of the checkpoint
        # layout (state_dict keys), so they are kept as-is.
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            batch_first=True,
        )
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, query, context, context_mask=None):
        """Return ``LayerNorm(query + Attention(query, context, context))``.

        ``context_mask`` is forwarded as ``key_padding_mask``.
        """
        attended = self.cross_attn(
            query=query,
            key=context,
            value=context,
            key_padding_mask=context_mask,
        )[0]  # second element holds the attention weights, which are unused
        residual = query + attended
        return self.norm(residual)

class ChunkwiseInteractionModel(nn.Module):
    """Per-position interaction logits for chunked embeddings.

    Every chunk of A cross-attends to every chunk of B; an MLP turns the
    attended representation into one logit per position of the A chunk, and
    the logits are max-pooled over all B chunks.
    """

    def __init__(self, embed_dim=1280, num_heads=16, initial_bias=None):
        super().__init__()
        # Submodule names and creation order define the state_dict keys of
        # saved checkpoints, so they must not change.
        self.cross_attn = CrossAttentionLayer(embed_dim, num_heads)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(embed_dim, 1)
        )
        if initial_bias is not None:
            self.mlp[-1].bias.data.fill_(initial_bias)

    def forward(self, chunks_A, chunks_B, stride=512):
        """Compute pooled logits for every position covered by ``chunks_A``.

        Args:
            chunks_A: Iterable of embedding tensors for the sequence being
                scored; 2-D tensors get a leading batch dimension added.
            chunks_B: Iterable of embedding tensors used as attention
                context; handled the same way. Must not be empty.
            stride: Position offset between consecutive ``chunks_A`` entries;
                chunk ``i`` is assumed to start at position ``i * stride``.

        Returns:
            Dict mapping position -> list of scalar logit tensors, one entry
            per A chunk that covers that position.
        """
        position_logits = {}
        for i, a_chunk in enumerate(chunks_A):
            if a_chunk.ndim == 2:
                a_chunk = a_chunk.unsqueeze(0)

            all_logits = []
            for b_chunk in chunks_B:
                if b_chunk.ndim == 2:
                    b_chunk = b_chunk.unsqueeze(0)

                # Context positions whose embedding is entirely zero are
                # treated as padding and masked out of the attention.
                context_mask = (b_chunk.abs().sum(dim=-1) == 0)
                x = self.cross_attn(a_chunk, b_chunk, context_mask=context_mask)
                # The logits length follows the query (A) chunk, so it is the
                # same for every B chunk and no padding is needed before
                # stacking.
                all_logits.append(self.mlp(x).squeeze(0).squeeze(-1))

            pooled_logits = torch.max(torch.stack(all_logits, dim=0), dim=0).values
            start = i * stride
            for j, logit in enumerate(pooled_logits):
                position_logits.setdefault(start + j, []).append(logit)

        return position_logits

# Predict PPI from a Given Pair
def predict_probs(chunks_A, chunks_B, model_path, seq_len, stride=512):
    """Predict a per-position probability for the sequence behind ``chunks_A``.

    Args:
        chunks_A: Embedding tensors of the sequence being scored.
        chunks_B: Embedding tensors of the partner sequence.
        model_path: Path to a saved ``ChunkwiseInteractionModel`` state dict.
        seq_len: Length of the returned probability array.
        stride: Position offset between consecutive chunks in ``chunks_A``.

    Returns:
        ``numpy`` array of length ``seq_len``. Each position holds the maximum
        sigmoid probability over the chunks covering it; positions covered by
        no chunk stay 0.
    """
    model = ChunkwiseInteractionModel().to("cuda" if torch.cuda.is_available() else "cpu")
    load_result = model.load_state_dict(torch.load(model_path, map_location="cpu"), strict=False)
    # strict=False tolerates key mismatches silently; parameters that were not
    # found in the checkpoint keep their random initialization, which would
    # make the predictions meaningless, so surface that case.
    if load_result.missing_keys:
        warnings.warn(
            f"Checkpoint {model_path} is missing {len(load_result.missing_keys)} "
            f"model parameter(s) (e.g. {load_result.missing_keys[0]}); "
            "those remain randomly initialized."
        )
    model.eval()

    device = next(model.parameters()).device
    A_embs = [x.to(device) for x in chunks_A]
    B_embs = [x.to(device) for x in chunks_B]

    # Inference only: skip autograd bookkeeping. The caller's stride is
    # forwarded so position offsets match how the chunks were produced.
    with torch.no_grad():
        position_logits = model(A_embs, B_embs, stride=stride)

    full_probs = np.zeros(seq_len)
    for pos, logits in position_logits.items():
        if 0 <= pos < seq_len:
            full_probs[pos] = torch.sigmoid(torch.stack(logits)).max().item()
    return full_probs

# File saving functions
def sanitize_filename(text):
    """Return ``text`` with every character outside ``a-zA-Z0-9_.-`` replaced by ``_``."""
    unsafe_chars = re.compile(r'[^a-zA-Z0-9_.-]')
    return unsafe_chars.sub('_', text)

def save_probabilities(filename, seq_id, partner_id, model_type, sequence, probabilities):
    """Write per-residue probabilities to a tab-separated text file.

    The file starts with a ``#`` comment line naming the query, partner and
    model type, followed by a column header and one row per residue with its
    1-based position, residue letter and probability (three decimals). Rows
    are produced by pairing ``sequence`` with ``probabilities``, so output
    stops at the shorter of the two.
    """
    with open(filename, 'w') as out:
        out.write(f"# Query:{seq_id} Partner:{partner_id} Model:{model_type}-centric\n")
        out.write("Position\tAminoAcid\tProbability\n")
        out.writelines(
            f"{position}\t{residue}\t{score:.3f}\n"
            for position, (residue, score) in enumerate(zip(sequence, probabilities), start=1)
        )

# Run PPI Predictions
def run_ppi_predictions(peptide_model_path="models/peptide_model.pt",
                        receptor_model_path="models/receptor_model.pt",
                        output_dir="outputs"):
    """Run both prediction models on every embedded pair and save the results.

    Iterates over the global ``all_pair_chunks`` list (filled by
    ``run_embedding_workflow``). For each pair, both models score the query
    sequence (stored under the ``receptor_*`` keys) against the partner. Two
    probability tables and one combined plot are written to ``output_dir``,
    and the plot is also displayed.

    Args:
        peptide_model_path: Checkpoint used for the "peptide-centric" output.
        receptor_model_path: Checkpoint used for the "receptor-centric" output.
        output_dir: Folder receiving the result files; created if missing.
    """
    os.makedirs(output_dir, exist_ok=True)

    for i, pair in enumerate(all_pair_chunks):
        rec_id, pep_id = pair['receptor_id'], pair['ligand_id']
        rec_seq = pair['receptor_seq']
        rec_chunks = [c["embedding"] for c in pair["receptor_chunks"]]
        pep_chunks = [c["embedding"] for c in pair["peptide_chunks"]]

        print(f"[{i+1}/{len(all_pair_chunks)}] Predicting for: {rec_id} ↔ {pep_id}")

        # Both models take the query chunks as the scored sequence and the
        # partner chunks as context; only the checkpoint differs.
        probs_peptide = predict_probs(rec_chunks, pep_chunks, peptide_model_path, len(rec_seq))
        probs_receptor = predict_probs(rec_chunks, pep_chunks, receptor_model_path, len(rec_seq))

        # Save results

        # Safe filenames
        safe_rec_id = sanitize_filename(rec_id)
        safe_pep_id = sanitize_filename(pep_id)

        # Save detailed output
        pep_file = os.path.join(output_dir, f"{safe_rec_id}__{safe_pep_id}__peptide_centric.txt")
        rec_file = os.path.join(output_dir, f"{safe_rec_id}__{safe_pep_id}__receptor_centric.txt")
        save_probabilities(pep_file, rec_id, pep_id, "peptide", rec_seq, probs_peptide)
        save_probabilities(rec_file, rec_id, pep_id, "receptor", rec_seq, probs_receptor)

        # Generate a combined plot, show it and save to a file
        fig_name = os.path.join(output_dir, f"{safe_rec_id}__{safe_pep_id}__plot.png")
        fig = plt.figure(figsize=(10, 4))
        plt.plot(probs_receptor, label="Receptor-Centric", color='orange')
        plt.plot(probs_peptide, label="Peptide-Centric", linestyle='--', color='blue')
        plt.title(f"PPI Site Probabilities\nQuery: {rec_id} with Partner: {pep_id}")
        plt.xlabel("Query Residue Position")
        plt.ylabel("Probability")
        plt.legend(title="Prediction Model")
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(fig_name)
        plt.show()
        # Release the figure so repeated runs with many pairs do not pile up
        # open figures in pyplot's registry.
        plt.close(fig)



In [None]:
#@title Input Form
# Process Input and Run Predictions

# Input FASTA Sequences from User
# Area that captures everything printed by the prediction workflow
output_box = ipw.Output()

# Button that starts the workflow; the callback is attached right away
run_button = ipw.Button(
    description="Run PPI predictions",
    button_style='primary',
    layout=ipw.Layout(margin='20px 0px 0px 0px'),
)
run_button.on_click(run_embedding_workflow)

# FASTA text areas, pre-filled with short example sequences
query_seq_input = ipw.Textarea(
    value='>query\nMKTFFVGLAALVTMATGVHS',
    description='Query',
    layout=ipw.Layout(width='80%', height='100px'),
)
partner_seq_input = ipw.Textarea(
    value='>partner\nMTEITAAMVKELRESTGAGM',
    description='Partner(s)',
    layout=ipw.Layout(width='80%', height='100px'),
)

# Render the form: heading, both inputs, the button, then the output area
display(Markdown("### Input your query and partner protein sequences in FASTA format"))
display(query_seq_input, partner_seq_input, run_button, output_box)



In [None]:
#@title Review and Download Results
import zipfile
from glob import glob
from IPython.display import FileLink, display
from google.colab import files

# Widget containers
# Output area that lists the files of the selected run
view_box = ipw.Output()
# Path of the most recently built archive, kept in a dict so the callbacks
# can update it in place
current_zip_path = {"path": None}
# The download button starts hidden and is revealed once a run is selected
download_button = ipw.Button(
    description="⬇️ Download All Files",
    button_style='success',
    layout=ipw.Layout(display='none'),
)

# Create ZIP inside outputs/ folder
def create_zip(folder):
    """Archive the direct contents of ``folder`` into ``<folder>.zip``.

    The archive is written next to the folder (trailing slashes are ignored)
    and overwrites any previous archive of the same name. Entries are stored
    flat under their base names, in sorted order, using DEFLATE compression.

    Args:
        folder: Path of the run folder, with or without a trailing slash.

    Returns:
        The path of the created zip file.
    """
    root = folder.rstrip("/")
    zip_path = f"{root}.zip"
    with zipfile.ZipFile(zip_path, 'w', compression=zipfile.ZIP_DEFLATED) as zipf:
        # Sorted so the archive layout does not depend on filesystem order
        for file in sorted(glob(f"{root}/*")):
            zipf.write(file, arcname=os.path.basename(file))
    return zip_path

# Dropdown callback
def on_select_run(change):
    """Dropdown observer: list the files of the selected run and prepare its zip.

    Args:
        change: Trait-change dict from the dropdown; ``change['new']`` is the
            selected run folder, or an empty value when nothing is selected.
    """
    folder = change['new']
    if not folder:
        # Nothing selected: clear the listing and hide the download button
        view_box.clear_output()
        download_button.layout.display = 'none'
        return

    with view_box:
        view_box.clear_output()
        print(f"📂 Files in: {folder}")
        for f in sorted(glob(f"{folder}/*")):
            print(f"• {os.path.basename(f)}")

        # Create ZIP file
        zip_path = create_zip(folder)
        current_zip_path["path"] = zip_path
        download_button.description = f"Download All ({os.path.basename(zip_path)})"
        # Update the existing layout in place; swapping in a new Layout object
        # would discard the display setting applied here.
        download_button.layout.width = '300px'
        download_button.layout.display = 'inline-block'

# Download button click handler
def on_download_clicked(btn):
    """Button callback: download the current archive if one exists on disk.

    Does nothing when no archive path has been recorded yet or the recorded
    file is missing.
    """
    archive = current_zip_path.get("path")
    if not archive or not os.path.exists(archive):
        return
    files.download(archive)

# Hook up download action
# Collect every run folder under outputs/ (the trailing slash restricts the
# glob to directories)
output_sessions = sorted(glob("outputs/*/"))

# Dropdown with an empty first entry so that nothing is selected initially
session_selector = ipw.Dropdown(
    options=["", *output_sessions],
    description=f'Select query ({len(output_sessions)} found):',
    style={'description_width': 'initial'},
    layout=ipw.Layout(min_width='400px', max_width='800px'),
)

# Wire up the callbacks
session_selector.observe(on_select_run, names='value')
download_button.on_click(on_download_clicked)

# Display UI
display(Markdown(
    "🔁 **Tip:** If you've run multiple PPI predictions, re-run this cell to **refresh** the list of available results in the dropdown menu."
))
display(ipw.VBox([session_selector, view_box, download_button]))
