In [24]:
SYNA_PARENT_DIR = "/home/shd-sun-lab/SynapseNavigator"

In [25]:
import sys
import os
sys.path.append(SYNA_PARENT_DIR) # append the path of protgps
from argparse import Namespace
import pickle
import copy
import yaml
import requests
from tqdm import tqdm
from p_tqdm import p_map
import numpy as np
import pandas as pd
from collections import defaultdict
import torch 
from captum.attr import LayerIntegratedGradients
from captum.attr import visualization as viz
import protpy
from protpy import amino_acids as protpyAA
from sklearn.metrics import roc_auc_score
from matplotlib import pyplot as plt
from protgps.utils.loading import get_object

In [26]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Functions

### Make sure you run these, but can be ignored after

In [27]:
#Compartment labels & helper

COMPARTMENTS = [
    "cytosol",
    "ER",
    "mitochondrion",
    "nucleus",
    "Excitatory Synapse",
    "Inhibitory Synapses",
]

# In the original code this is identical, but conceptually OLDCOMPS is
# "the order used in the committed code / old model".
OLDCOMPS = [
    "cytosol",
    "ER",
    "mitochondrion",
    "nucleus",
    "Excitatory Synapse",
    "Inhibitory Synapses",
]

def transform_y(y: torch.Tensor):
    """
    Convert labels from COMPARTMENTS order to OLDCOMPS order.
    """
    indices = torch.nonzero(y).flatten()
    new_indices = torch.tensor([OLDCOMPS.index(COMPARTMENTS[i]) for i in indices])
    return torch.zeros(len(OLDCOMPS)).scatter_(0, new_indices, 1)


In [28]:

UNIPROT_FASTA_URL = "https://rest.uniprot.org/uniprotkb/{}.fasta"
UNIPROT_JSON_URL  = "https://rest.uniprot.org/uniprotkb/{}.json"


def get_organism(uniprot_id: str) -> str:
    """
    Return the scientific name of the organism for a UniProt ID,
    or empty string if not found.
    """
    r = requests.get(UNIPROT_JSON_URL.format(uniprot_id))
    if r.status_code != 200:
        return ""
    js = r.json()
    if "organism" in js:
        return js["organism"].get("scientificName", "")
    return ""


def parse_fasta(f: str) -> str:
    """
    Parse fasta text and return the amino acid sequence.
    """
    seq = []
    for line in f.splitlines():
        if line.startswith(">"):
            continue
        seq.append(line.strip())
    return "".join(seq)

from typing import Optional

def get_protein_fasta(uniprot_id: str) -> Optional[str]:
    """
    Download protein sequence for a UniProt ID.
    Returns None if request fails.
    """
    r = requests.get(UNIPROT_FASTA_URL.format(uniprot_id))
    if r.status_code != 200:
        return None
    return parse_fasta(r.text)


In [29]:
#Quick UniProt test
test_id = "O14983"
seq = get_protein_fasta(test_id)
org = get_organism(test_id)
print("ID:", test_id)
print("Organism:", org)
print("Sequence length:", len(seq) if seq else None)
print("First 60 AA:", (seq[:60] + "...") if seq else "None")



ID: O14983
Organism: Homo sapiens
Sequence length: 1001
First 60 AA: MEAAHAKTTEECLAYFGVSETTGLTPDQVKRNLEKYGLNELPAEEGKTLWELVIEQFEDL...


In [30]:
# [6] Model helper functions

def load_model(snargs):
    """
    Loads classifier model from args file + checkpoint path in snargs.model_path
    """
    modelpath = snargs.model_path
    model = get_object(snargs.lightning_name, "lightning")(snargs)
    model = model.load_from_checkpoint(
        checkpoint_path=modelpath,
        strict=not snargs.relax_checkpoint_matching,
        **{"args": snargs},
    )
    return model, snargs


def predict_condensates(model, sequences, batch_size=8, round_scores=True):
    """
    Run the model on a list of sequences and return scores tensor [N, n_labels].
    """
    all_scores = []
    model.eval()
    for i in tqdm(range(0, len(sequences), batch_size), ncols=100):
        batch = sequences[i : i + batch_size]
        with torch.no_grad():
            out = model.model({"x": batch})
        s = torch.sigmoid(out["logit"]).to("cpu")
        all_scores.append(s)
    scores = torch.vstack(all_scores)
    if round_scores:
        scores = torch.round(scores, decimals=3)
    return scores


def get_valid_rows(df, cols):
    """
    Keep only rows where all sequences in `cols` are shorter than 1800 AA.
    """
    rows_with_valid_seq_len = []
    for i in range(len(df)):
        if all(len(df.iloc[i][c]) < 1800 for c in cols):
            rows_with_valid_seq_len.append(i)
    return rows_with_valid_seq_len


In [31]:
#Load args + checkpoint from your SynapseNavigator folder

import glob

checkpoint_dir = os.path.join(
    SYNA_PARENT_DIR,
    "checkpoints",
    "protgps",
    "2025.10.17_Genes_Hein_2025_&_Marc_2023(2Synapse)",
)

print("Checkpoint dir:", checkpoint_dir)
print("Contents:", os.listdir(checkpoint_dir))

args_files = glob.glob(os.path.join(checkpoint_dir, "*.args"))
ckpt_files = glob.glob(os.path.join(checkpoint_dir, "*.ckpt"))

print("Found args files:", args_files)
print("Found ckpt files:", ckpt_files)

assert len(args_files) == 1, "Expected exactly one .args file here."
assert len(ckpt_files) == 1, "Expected exactly one .ckpt file here."

args_path = args_files[0]
ckpt_path = ckpt_files[0]

print("Using ARGS:", args_path)
print("Using CKPT:", ckpt_path)

# Load args
with open(args_path, "rb") as f:
    args = Namespace(**pickle.load(f))

# Fix up paths inside args
args.pretrained_hub_dir = os.path.join(SYNA_PARENT_DIR, "checkpoints", "esm2")
args.model_path         = ckpt_path
args.dataset_file_path  = os.path.join(SYNA_PARENT_DIR, "data", "dataset.json")

# Load model
model, args = load_model(args)
model = model.to(device)
model.eval()

print("Loaded model:", type(model))
print("Dataset name:", args.dataset_name)


Checkpoint dir: /home/shd-sun-lab/SynapseNavigator/checkpoints/protgps/2025.10.17_Genes_Hein_2025_&_Marc_2023(2Synapse)
Contents: ['5f71fcd178bd6ceebc393003a6db70d9.args', '2025.10.17_Genes_Hein_2025_&_Marc_2023(2Synapse)_sampled.xlsx', '2025.10.17_Genes_Hein_2025_&_Marc_2023(1Synapse).xlsx:Zone.Identifier', '5f71fcd178bd6ceebc393003a6db70d9epoch=20.ckpt', 'Predictions.xlsx', '2025.10.17_Genes_Hein_2025_&_Marc_2023(2Synapse).json', '2025.10.17_Genes_Hein_2025_&_Marc_2023(2Synapse)_with_preds.csv', '2025.10.17_Genes_Hein_2025_&_Marc_2023(2Synapse)_sampled.xlsx:Zone.Identifier', '2025.10.17_Genes_Hein_2025_&_Marc_2023(2Synapse).xlsx:Zone.Identifier', 'Predictions.xlsx:Zone.Identifier', '2025.10.17_Genes_Hein_2025_&_Marc_2023(2Synapse).xlsx', 'Recorded Data']
Found args files: ['/home/shd-sun-lab/SynapseNavigator/checkpoints/protgps/2025.10.17_Genes_Hein_2025_&_Marc_2023(2Synapse)/5f71fcd178bd6ceebc393003a6db70d9.args']
Found ckpt files: ['/home/shd-sun-lab/SynapseNavigator/checkpoints/pr

Using cache found in /home/shd-sun-lab/SynapseNavigator/checkpoints/esm2/facebookresearch_esm_main
Using cache found in /home/shd-sun-lab/SynapseNavigator/checkpoints/esm2/facebookresearch_esm_main


Using ESM hidden layers 6
Loaded model: <class 'protgps.lightning.base.Base'>
Dataset name: protein_condensates_combined


In [32]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Move ALL nested submodules to GPU
model = model.to(device)
model.model = model.model.to(device)
model.model.encoder = model.model.encoder.to(device)
model.model.encoder.model = model.model.encoder.model.to(device)

model.eval()

def forward(batch_tokens):
    batch_tokens = batch_tokens.to(device)
    model.zero_grad()

    # ESM forward
    result = model.model.encoder.model(
        batch_tokens,
        repr_layers=[model.model.encoder.repr_layer],
        return_contacts=False,
    )

    hidden = result["representations"][model.model.encoder.repr_layer].mean(dim=1)

    out = model.model.mlp({"x": hidden})["logit"]
    scores = torch.sigmoid(out)

    return scores


# Predictions on Additional Data 

## Predict files can also be used

In [None]:
#Load train/dev/test datasets from dataset
# Ignore, this is just to get an overview of original dataset size

train_dataset = get_object(args.dataset_name, "dataset")(args, "train")
dev_dataset   = get_object(args.dataset_name, "dataset")(args, "dev")
test_dataset  = get_object(args.dataset_name, "dataset")(args, "test")

print("Train/dev/test sizes:",
      len(train_dataset.dataset),
      len(dev_dataset.dataset),
      len(test_dataset.dataset))

train_sequences = set(d["x"] for d in (train_dataset.dataset + dev_dataset.dataset))
len(train_sequences)


100%|██████████| 6627/6627 [00:00<00:00, 110291.02it/s]


TRAIN DATASET CREATED FOR PROTEIN_CONDENSATES_COMBINED.
* 3502 Proteins.
* 1254 CYTOSOL -- 444 ER -- 519 MITOCHONDRION -- 854 NUCLEUS -- 260 EXCITATORY SYNAPSE -- 358 INHIBITORY SYNAPSES


100%|██████████| 6627/6627 [00:00<00:00, 299793.48it/s]


DEV DATASET CREATED FOR PROTEIN_CONDENSATES_COMBINED.
* 739 Proteins.
* 225 CYTOSOL -- 110 ER -- 104 MITOCHONDRION -- 199 NUCLEUS -- 59 EXCITATORY SYNAPSE -- 72 INHIBITORY SYNAPSES


100%|██████████| 6627/6627 [00:00<00:00, 327146.23it/s]

TEST DATASET CREATED FOR PROTEIN_CONDENSATES_COMBINED.
* 737 Proteins.
* 262 CYTOSOL -- 103 ER -- 95 MITOCHONDRION -- 170 NUCLEUS -- 71 EXCITATORY SYNAPSE -- 69 INHIBITORY SYNAPSES
Train/dev/test sizes: 3502 739 737





4200

In [29]:
#Load the Excel file you want to make predictions on

import pandas as pd

excel_path = "/home/shd-sun-lab/SynapseNavigator/checkpoints/protgps/2025.10.17_Genes_Hein_2025_&_Marc_2023(2Synapse)/2025.10.17_Genes_Hein_2025_&_Marc_2023(2Synapse).xlsx"

data = pd.read_excel(excel_path)
print("Loaded Excel:", excel_path)
print("Shape:", data.shape)

Loaded Excel: /home/shd-sun-lab/SynapseNavigator/checkpoints/protgps/2025.10.17_Genes_Hein_2025_&_Marc_2023(2Synapse)/2025.10.17_Genes_Hein_2025_&_Marc_2023(2Synapse).xlsx
Shape: (7011, 23)


In [30]:
# [10] Detect sequence or UniProt columns

import numpy as np

# Common sequence column names
possible_seq_cols = ["Sequence", "AA_sequence", "ProteinSequence", "WT_Sequence"]

# Common UniProt ID column names
possible_uniprot_cols = ["Entry", "UniProt", "Accession", "ProteinID"]

seq_col = next((c for c in possible_seq_cols if c in data.columns), None)
uniprot_col = next((c for c in possible_uniprot_cols if c in data.columns), None)

print("Detected sequence column:", seq_col)
print("Detected UniProt ID column:", uniprot_col)


Detected sequence column: Sequence
Detected UniProt ID column: Entry


In [31]:
# [11] Extract sequences depending on what the Excel provides
#If sequences are not present in the Excel, we fetch them from UniProt using the IDs provided

if seq_col is not None:
    # Use sequences directly
    print("Using sequences from Excel column:", seq_col)
    
    sequences = []
    valid_indices = []
    
    for idx, seq in data[seq_col].items():
        if isinstance(seq, str) and 0 < len(seq) < 1800:
            sequences.append(seq)
            valid_indices.append(idx)
    
    print("Number of valid sequences found:", len(sequences))

elif uniprot_col is not None:
    # Fallback: fetch sequences from UniProt
    print("Fetching sequences via UniProt column:", uniprot_col)
    protein_ids = data[uniprot_col].astype(str).tolist()
    
    from p_tqdm import p_map
    seqs = p_map(get_protein_fasta, protein_ids)
    
    sequences = []
    valid_indices = []
    
    for idx, seq in enumerate(seqs):
        if seq is not None and len(seq) < 1800:
            sequences.append(seq)
            valid_indices.append(idx)
    
    print("Valid sequences obtained from UniProt:", len(sequences))

else:
    raise ValueError("ERROR: Neither sequence nor UniProt ID column found.")


Using sequences from Excel column: Sequence
Number of valid sequences found: 6694


In [32]:
# [12] Run predictions on the sequences

print(f"Running predictions on {len(sequences)} sequences...")
scores = predict_condensates(model, sequences, batch_size=8, round_scores=True)

print("Score tensor shape:", scores.shape)


Running predictions on 6694 sequences...


100%|█████████████████████████████████████████████████████████████| 837/837 [03:19<00:00,  4.20it/s]

Score tensor shape: torch.Size([6694, 6])





In [33]:
# [13] Insert model scores into the DataFrame

import numpy as np

for j, comp in enumerate(OLDCOMPS):
    col_name = f"{comp.upper()}_Score"
    score_col = [np.nan] * len(data)
    
    for df_idx, score_vec in zip(valid_indices, scores):
        score_col[df_idx] = float(score_vec[j])
    
    data[col_name] = score_col

data.head()


Unnamed: 0,From,Entry,Reviewed,Entry Name,Protein names,Gene Names,Organism,Sequence,centrosome,cytosol,...,translation,Excitatory Synapse,Inhibitory Synapses,Dopaminergic Synapses,CYTOSOL_Score,ER_Score,MITOCHONDRION_Score,NUCLEUS_Score,EXCITATORY SYNAPSE_Score,INHIBITORY SYNAPSES_Score
0,A0A023T6R1,A0A023T6R1,unreviewed,A0A023T6R1_HUMAN,"Mago nashi protein (Mago-nashi homolog, isofor...",FLJ10292 hCG_1773848,Homo sapiens (Human),MAVASDFYLRYYVGHKGKFGHEFLEFEFRPDGKLRYANNSNYKNDV...,0,0,...,0,0,0,0,0.685,0.027,0.023,0.032,0.172,0.263
1,A0A024QYR6,A0A024QYR6,unreviewed,A0A024QYR6_HUMAN,"Phosphatidylinositol 3,4,5-trisphosphate 3-pho...",PTEN,Homo sapiens (Human),MERGGEAAAAAAAAAAAPGRGSESPVTISRAGNAGELVSPLLLPPT...,0,1,...,0,0,0,0,0.357,0.0,0.0,0.718,0.042,0.021
2,A0A024QYS2,A0A024QYS2,unreviewed,A0A024QYS2_HUMAN,Transmembrane 9 superfamily member,,Homo sapiens (Human),MRPLPGALGVAAAAALWLLLLLLPRTRADEHEHTYQDKEEVVLWMN...,0,0,...,0,0,0,0,0.001,0.453,0.066,0.003,0.064,0.322
3,A0A024QYX0,A0A024QYX0,unreviewed,A0A024QYX0_HUMAN,Emopamil binding protein,EBP,Homo sapiens (Human),MTTNAGPLHPYWPQHLRLDNFVPNDRPTWHILAGLFSVTGVLVVTT...,0,0,...,0,0,0,0,0.043,0.894,0.007,0.03,0.02,0.026
4,A0A024QZ64,A0A024QZ64,unreviewed,A0A024QZ64_HUMAN,Fructose-bisphosphate aldolase (EC 4.1.2.13),,Homo sapiens (Human),MPHSYPALSAEQKKELSDIALRIVAPGKGILAADESVGSMAKRLSQ...,0,1,...,0,0,0,0,0.689,0.002,0.007,0.055,0.121,0.182


In [34]:
# [14] Save output

out_path = excel_path.replace(".xlsx", "_with_preds.csv")
data.to_csv(out_path, index=False)

print("Saved predictions to:\n", out_path)


Saved predictions to:
 /home/shd-sun-lab/SynapseNavigator/checkpoints/protgps/2025.10.17_Genes_Hein_2025_&_Marc_2023(2Synapse)/2025.10.17_Genes_Hein_2025_&_Marc_2023(2Synapse)_with_preds.csv


## Predictions on additional datasets (template)

This section is a template for running the trained SynapseNavigator model
on *any* new Excel dataset.

To use:
1. Point `excel_path` to your `.xlsx`.
2. Either:
   - set `seq_col` to the column containing amino acid sequences, **or**
   - set `uniprot_col` to the column containing UniProt IDs (if no sequences).
3. Run all cells below to produce a `<filename>_with_preds.csv` file.


In [10]:
# [T1] CONFIG: point to your new Excel here

import os
import pandas as pd
import numpy as np

# Example: copy+adapt this line when someone adds a new dataset
excel_path = "/path/to/your/new_dataset.xlsx"  # <--- CHANGE THIS

data = pd.read_excel(excel_path)
print("Loaded Excel:", excel_path)
print("Shape:", data.shape)
data.head()


FileNotFoundError: [Errno 2] No such file or directory: '/path/to/your/new_dataset.xlsx'

In [None]:
# [T2] Detect / set sequence or UniProt ID column

# Option A: let the code auto-detect common names
possible_seq_cols     = ["Sequence", "AA_sequence", "ProteinSequence", "WT_Sequence"]
possible_uniprot_cols = ["Entry", "UniProt", "Accession", "ProteinID"]

seq_col = next((c for c in possible_seq_cols if c in data.columns), None)
uniprot_col = next((c for c in possible_uniprot_cols if c in data.columns), None)

print("Auto-detected sequence column:", seq_col)
print("Auto-detected UniProt column:", uniprot_col)

In [None]:
# [T3] Build list of sequences to score

from p_tqdm import p_map

sequences = []
valid_indices = []

if seq_col is not None:
    print(f"Using sequences from column: {seq_col}")
    for idx, seq in data[seq_col].items():
        if isinstance(seq, str) and 0 < len(seq) < 1800:
            sequences.append(seq)
            valid_indices.append(idx)

elif uniprot_col is not None:
    print(f"Using UniProt IDs from column: {uniprot_col}")
    ids = data[uniprot_col].astype(str).tolist()
    seqs = p_map(get_protein_fasta, ids)
    for idx, seq in enumerate(seqs):
        if seq is not None and len(seq) < 1800:
            sequences.append(seq)
            valid_indices.append(idx)

else:
    raise ValueError("No usable sequence or UniProt column found.")

print("Number of valid sequences:", len(sequences))


In [None]:
# [T4] Predict scores

scores = predict_condensates(model, sequences, batch_size=8, round_scores=True)
print("Scores shape:", scores.shape)


In [None]:
# [T5] Add one score column per compartment

for j, comp in enumerate(OLDCOMPS):
    col_name = f"{comp.upper()}_Score"
    col_vals = [np.nan] * len(data)
    for df_idx, score_vec in zip(valid_indices, scores):
        col_vals[df_idx] = float(score_vec[j])
    data[col_name] = col_vals

data.head()


In [None]:
# [T6] Save <original>_with_preds.csv next to input

out_path = excel_path.replace(".xlsx", "_with_preds.csv").replace(".xls", "_with_preds.csv")
data.to_csv(out_path, index=False)
print("Saved predictions to:", out_path)


# AUCs

In [None]:
#prediction score on the 5% dataset

import os
import pandas as pd
import numpy as np
from sklearn.metrics import (
    confusion_matrix,
    roc_auc_score,
    precision_recall_curve,
    auc
)

# ================= CONFIG =================
THRESH = 0.5  # threshold for calling positives

# ================= FILES =================
pred_file = os.path.join(checkpoint_dir, "Predictions.xlsx")
truth_file = os.path.join(checkpoint_dir, "2025.10.17_Genes_Hein_2025_&_Marc_2023(2Synapse)_sampled.xlsx")

if not os.path.exists(pred_file):
    raise FileNotFoundError(f"Prediction file not found: {pred_file}")
if not os.path.exists(truth_file):
    raise FileNotFoundError(f"Truth file not found: {truth_file}")

print("Using prediction file:", pred_file)
print("Using truth file:", truth_file)

# ================= LOAD EXCEL FILES =================
pred_df = pd.read_excel(pred_file)
truth_df = pd.read_excel(truth_file)

# ================= COLUMN MAPPING =================
# Map prediction columns to truth columns
column_map = {
    "CYTOSOL_Score": "cytosol",
    "ER_Score": "ER",
    "MITOCHONDRION_Score": "mitochondrion",
    "NUCLEUS_Score": "nucleus",
    "EXCITATORY SYNAPSE_Score": "Excitatory Synapse",
    "INHIBITORY SYNAPSES_Score": "Inhibitory Synapses"
}

truth_label_cols = list(column_map.values())
pred_label_cols = list(column_map.keys())

# Ensure all columns exist
for col in truth_label_cols:
    if col not in truth_df.columns:
        raise ValueError(f"Truth column missing: {col}")
for col in pred_label_cols:
    if col not in pred_df.columns:
        raise ValueError(f"Prediction column missing: {col}")

# ================= EXTRACT NUMERIC ARRAYS =================
y_true = truth_df[truth_label_cols].values.astype(float)
y_pred_probs = pred_df[pred_label_cols].values.astype(float)
num_classes = y_true.shape[1]

# ================= COMPUTE METRICS =================
print("\n=== Per-Class Metrics ===\n")
for j in range(num_classes):
    y_true_col = y_true[:, j]
    y_score_col = y_pred_probs[:, j]
    y_pred_col = (y_score_col >= THRESH).astype(int)

    # Confusion matrix
    try:
        tn, fp, fn, tp = confusion_matrix(y_true_col, y_pred_col, labels=[0,1]).ravel()
    except ValueError:
        tn = fp = fn = tp = 0

    # Metrics
    accuracy = (tp + tn) / (tp + tn + fp + fn) if (tp + tn + fp + fn) > 0 else 0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall + 1e-8)

    try:
        roc_auc = roc_auc_score(y_true_col, y_score_col)
    except ValueError:
        roc_auc = float('nan')

    try:
        pr, rc, _ = precision_recall_curve(y_true_col, y_score_col)
        pr_auc = auc(rc, pr)
    except ValueError:
        pr_auc = float('nan')

    print(
        f"{truth_label_cols[j]:20s}: TP={tp:4d}, FP={fp:4d}, FN={fn:4d}, TN={tn:4d}, "
        f"Acc={accuracy:.3f}, Prec={precision:.3f}, Recall={recall:.3f}, "
        f"F1={f1:.3f}, ROC-AUC={roc_auc:.3f}"
    )


Using prediction file: /home/shd-sun-lab/SynapseNavigator/checkpoints/protgps/2025.10.17_Genes_Hein_2025_&_Marc_2023(2Synapse)/Predictions.xlsx
Using truth file: /home/shd-sun-lab/SynapseNavigator/checkpoints/protgps/2025.10.17_Genes_Hein_2025_&_Marc_2023(2Synapse)/2025.10.17_Genes_Hein_2025_&_Marc_2023(2Synapse)_sampled.xlsx

=== Per-Class Metrics ===

cytosol             : TP=  75, FP=  73, FN=  31, TN= 205, Acc=0.729, Prec=0.507, Recall=0.708, F1=0.591, ROC-AUC=0.827
ER                  : TP=  27, FP=   8, FN=  11, TN= 338, Acc=0.951, Prec=0.771, Recall=0.711, F1=0.740, ROC-AUC=0.887
mitochondrion       : TP=  29, FP=   1, FN=  10, TN= 344, Acc=0.971, Prec=0.967, Recall=0.744, F1=0.841, ROC-AUC=0.978
nucleus             : TP=  50, FP=  19, FN=  20, TN= 295, Acc=0.898, Prec=0.725, Recall=0.714, F1=0.719, ROC-AUC=0.894
Excitatory Synapse  : TP=  23, FP=  25, FN=  16, TN= 320, Acc=0.893, Prec=0.479, Recall=0.590, F1=0.529, ROC-AUC=0.909
Inhibitory Synapses : TP=   7, FP=   6, FN=  42, 

In [None]:
#Prediction scores on the Test_dataset (15%)
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score,
    precision_recall_curve,
    auc,
    confusion_matrix,
)

THRESH = 0.5

print("Using dataset for AUCs:", args.dataset_name)

# 1) Load test dataset
test_dataset = get_object(args.dataset_name, "dataset")(args, "test")
print("Loaded test samples:", len(test_dataset.dataset))

# 2) Build test sequences and labels
test_x = [sample["x"] for sample in test_dataset.dataset]
test_y = torch.vstack([transform_y(s["y"]) for s in test_dataset.dataset])

# 3) Run model predictions (probabilities)
test_preds = predict_condensates(
    model, test_x, batch_size=32, round_scores=False
)

print("Shapes -> y:", test_y.shape, " preds:", test_preds.shape)

print("\n=== Classification Metrics (per condensate) ===")

for j, comp in enumerate(OLDCOMPS):
    y_true = test_y[:, j].cpu().numpy()
    y_score = test_preds[:, j].cpu().numpy()
    y_pred = (y_score >= THRESH).astype(int)

    print(f"\n{comp}")

    # Accuracy
    acc = accuracy_score(y_true, y_pred)

    # Confusion matrix → FP/FN/etc.
    tn, fp, fn, tp = confusion_matrix(
        y_true, y_pred, labels=[0, 1]
    ).ravel()

    # Precision / Recall / F1
    prec = precision_score(y_true, y_pred, zero_division=0)
    rec = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)

    # PR-AUC
    pr, rc, _ = precision_recall_curve(y_true, y_score)
    pr_auc = auc(rc, pr)

    # ROC-AUC
    try:
        roc_auc = roc_auc_score(y_true, y_score)
    except ValueError:
        roc_auc = float("nan")

    print(
        f"{comp:20s}  "
        f"TP={tp:4d}  FP={fp:4d}  FN={fn:4d}  TN={tn:4d}  "
        f"  acc={acc:.3f}  "
        f"prec={prec:.3f}  rec={rec:.3f}  "
        f"f1={f1:.3f}  pr_auc={pr_auc:.3f}  roc_auc={roc_auc:.3f}"
    )

Using dataset for AUCs: protein_condensates_combined


Exception: INVALID DATASET NAME: /home/shd-sun-lab/SynapseNavigator/data/dataset.json. AVAILABLE dict_keys(['protein_compartment', 'protein_compartment_guy', 'protein_compartment_uniprot_combined', 'protein_condensates_combined', 'reverse_homology'])

In [35]:
# === AUC EVALUATION (Self-contained) ===
# Uses the same `args` and the same loaded `model` already in memory.

from sklearn.metrics import roc_auc_score

print("Using dataset for AUCs:", args.dataset_file_path)

# 1) Load test dataset
test_dataset = get_object(args.dataset_name, "dataset")(args, "test")
print("Loaded test samples:", len(test_dataset.dataset))

# 2) Build test sequences and labels
test_x = [sample["x"] for sample in test_dataset.dataset]
test_y = [transform_y(sample["y"]) for sample in test_dataset.dataset]
test_y = torch.vstack(test_y)

# 3) Run model predictions (probabilities)
test_preds = predict_condensates(model, test_x, batch_size=32, round_scores=False)

print("Shapes -> y:", test_y.shape, " preds:", test_preds.shape)
print("\n=== ROC-AUC Scores ===")

# 4) Compute AUC per condensate
for j, comp in enumerate(OLDCOMPS):
    y_true = test_y[:, j].numpy()
    y_score = test_preds[:, j].numpy()
    try:
        auc = roc_auc_score(y_true, y_score)
        print(f"{comp:20s}  {auc:.3f}")
    except ValueError as e:
        print(f"{comp:20s}  AUC undefined ({e})")


Using dataset for AUCs: /home/shd-sun-lab/SynapseNavigator/data/dataset.json


100%|██████████| 6627/6627 [00:00<00:00, 316370.20it/s]


TEST DATASET CREATED FOR PROTEIN_CONDENSATES_COMBINED.
* 737 Proteins.
* 262 CYTOSOL -- 103 ER -- 95 MITOCHONDRION -- 170 NUCLEUS -- 71 EXCITATORY SYNAPSE -- 69 INHIBITORY SYNAPSES
Loaded test samples: 737


100%|███████████████████████████████████████████████████████████████| 24/24 [05:56<00:00, 14.85s/it]

Shapes -> y: torch.Size([737, 6])  preds: torch.Size([737, 6])

=== ROC-AUC Scores ===
cytosol               0.815
ER                    0.919
mitochondrion         0.934
nucleus               0.870
Excitatory Synapse    0.747
Inhibitory Synapses   0.738





# Analysis

In [10]:
import pandas as pd

# --- Load Excel ---
excel_path = "/home/shd-sun-lab/SynapseNavigator/checkpoints/protgps/2025.10.17_Genes_Hein_2025_&_Marc_2023(2Synapse)/2025.10.17_Genes_Hein_2025_&_Marc_2023(2Synapse).xlsx"
data = pd.read_excel(excel_path)

print("Loaded Excel:", excel_path)
print("Shape before filtering:", data.shape)

# --- Detect columns ---
possible_seq_cols = ["Sequence", "AA_sequence", "ProteinSequence", "WT_Sequence"]
possible_name_cols = ["Gene", "Name", "ID", "GeneName", "Protein", "Symbol"]
possible_uniprot_cols = ["Entry", "UniProt", "Accession", "ProteinID"]

seq_col = next((c for c in possible_seq_cols if c in data.columns), None)
name_col = next((c for c in possible_name_cols if c in data.columns), None)
uniprot_col = next((c for c in possible_uniprot_cols if c in data.columns), None)

print("Detected sequence column:", seq_col)
print("Detected name column:", name_col)
print("Detected UniProt column:", uniprot_col)

if seq_col is None:
    raise ValueError("No sequence column detected!")

# -----------------------------
# ⭐ COMPARTMENT FILTER HERE ⭐
# -----------------------------
COMPARTMENT = "Excitatory Synapse"   # <<< change this to whichever column you want
FILTER_VALUE = 1                     # only select proteins annotated with "1"

if COMPARTMENT not in data.columns:
    raise ValueError(f"Compartment column '{COMPARTMENT}' not found in Excel.")

filtered = data[data[COMPARTMENT] == FILTER_VALUE]

print(f"Filtering for compartment '{COMPARTMENT}' == {FILTER_VALUE}")
print("Shape after filtering:", filtered.shape)

# --- Settings ---
DEFAULT_CONDENSATE = COMPARTMENT  # optional: name condensate after the column
DEFAULT_SEED = 1

# --- Build IG list ---
sequences_for_ig = []

for idx, row in filtered.iterrows():
    seq = str(row[seq_col]).strip()
    if not isinstance(seq, str) or len(seq) == 0:
        continue

    # name selection
    if name_col and pd.notna(row[name_col]):
        name = str(row[name_col])
    elif uniprot_col and pd.notna(row[uniprot_col]):
        name = str(row[uniprot_col])
    else:
        name = f"Seq{idx+1}"

    sequences_for_ig.append({
        "condensate": DEFAULT_CONDENSATE,
        "name": name,
        "seed": DEFAULT_SEED,
        "sequence": seq,
    })

# Optional: define a tag that is always fused, or set to "" if not needed
mcherry = (
    ""
)

print(f"Prepared {len(sequences_for_ig)} sequences for IG.")


Loaded Excel: /home/shd-sun-lab/SynapseNavigator/checkpoints/protgps/2025.10.17_Genes_Hein_2025_&_Marc_2023(2Synapse)/2025.10.17_Genes_Hein_2025_&_Marc_2023(2Synapse).xlsx
Shape before filtering: (7011, 23)
Detected sequence column: Sequence
Detected name column: None
Detected UniProt column: Entry
Filtering for compartment 'Excitatory Synapse' == 1
Shape after filtering: (413, 23)
Prepared 413 sequences for IG.


In [73]:
# [AN1-TEMPLATE] Define sequences for attribution analysis
#
# HOW TO USE:
# - Fill the `sequences_for_ig` list with your own sequences.
# - Each entry is a dict with:
#     - "name":         an ID for the construct (for plots / tables)
#     - "condensate":   any label / class name you want to attach (optional, for you)
#     - "seed":         an integer identifier (optional, can be anything)
#     - "sequence":     the amino acid sequence (without tags like mCherry)
#
# You can reuse this template and change only `sequences_for_ig`.

sequences_for_ig = [
    {
        "condensate": "Synapse",    # e.g. "nucleolus", "synapse", "test_construct"
        "name": "VATG2",       # short ID
        "seed": 1,                     # any integer (or remove if not needed)
        "sequence": "MASQSQGIQQLLQAEKRAAEKVADARKRKARRLKQAKEEAQMEVEQYRREREHEFQSKQQAAMGSQGNLSAEVEQATRRQVQGMQSSQQRNRERVLAQLLGMVCDVRPQVHPNYRISA"  # <-- your AA seq
    },
    {
        "condensate": "Synapse",
        "name": "NeuM",
        "seed": 2,
        "sequence": "MLCCMRRTKQVEKNDDDQKIEQDGIKPEDKAHKAATKIQASFRGHITRKKLKGEKKDDVQAAEAEANKKDEAPVADGVEKKGEGTTTAEAAPATGSKPDEPGKAGETPSEEKKGEGDAATEQAAPQAPASSEEKAGSAETESATKASTDNSPSSKAEDAPAKEEPKQADVPAAVTAAAATTPAAEDAAAKATAQPPTETGESSQAEENIEAVDETKPKESARQDEGKEEEPEADQEHA"                 # another AA seq
    },
    # Add more dicts as needed
]

# Optional: define a tag that is always fused, or set to "" if not needed
mcherry = (
    ""
)
# If you don't want to fuse anything, just do:
# mcherry = ""

# For the rest of the IG pipeline, we will refer to this list:
experimental_sequences = sequences_for_ig
print(f"Prepared {len(experimental_sequences)} sequences for IG.")



Prepared 2 sequences for IG.


In [11]:
# [AN2] Visualization helper (Captum)

def visualize_text(datarecords, legend: bool = True) -> "HTML":
    dom = ["<table width: 100%>"]
    rows = [
        "<th>Amino Acid Importance</th>"
        "<th>Sample ID</th>"
        "<th>Target (score)</th>"
    ]
    for datarecord in datarecords:
        rows.append(
            "".join(
                [
                    "<tr>",
                    viz.format_word_importances(
                        datarecord.raw_input_ids, datarecord.word_attributions
                    ),
                    # just use true_class as-is (no split)
                    viz.format_classname(str(datarecord.true_class)),
                    viz.format_classname(
                        "{0} ({1:.2f})".format(
                            datarecord.pred_class, datarecord.pred_prob
                        )
                    ),
                    "<tr>",
                ]
            )
        )

    if legend:
        dom.append(
            '<div style="border-top: 1px solid; margin-top: 5px; \
            padding-top: 5px; display: inline-block">'
        )
        dom.append("<b>Legend: </b>")

        for value, label in zip([-1, 0, 1], ["Negative", "Neutral", "Positive"]):
            dom.append(
                '<span style="display: inline-block; width: 10px; height: 10px; \
                border: 1px solid; background-color: \
                {value}"></span> {label}  '.format(
                    value=viz._get_color(value), label=label
                )
            )
        dom.append("</div>")

    dom.append("".join(rows))
    dom.append("</table>")
    html = viz.HTML("".join(dom))
    viz.display(html)

    return html


In [12]:
# [AN3] Forward function for IG and Captum setup

model = model.to("cpu")
model.eval()

def forward(batch_tokens):
    """
    Forward pass that:
    - gets ESM representations
    - mean-pools across positions
    - passes through the MLP head
    - returns sigmoid probabilities
    """
    model.zero_grad()
    result = model.model.encoder.model(
        batch_tokens,
        repr_layers=[model.model.encoder.repr_layer],
        return_contacts=False,
    )
    hidden = result["representations"][model.model.encoder.repr_layer].mean(axis=1)
    scores = torch.sigmoid(model.model.mlp({"x": hidden})["logit"])
    return scores

alphabet = model.model.encoder.alphabet

lig = LayerIntegratedGradients(forward, model.model.encoder.model.embed_tokens)


In [None]:
# [AN4] Integrated Gradients for the experimental sequences

records = []
sequence_dict_copy = copy.deepcopy(experimental_sequences)

for sequence_dict in sequence_dict_copy:
    seq = sequence_dict["sequence"]
    input_seq = seq + mcherry

    # Baseline: CLS + [MASK] * len(seq) + EOS
    baseline = torch.tensor(
        [alphabet.cls_idx] + [alphabet.mask_idx] * len(input_seq) + [alphabet.eos_idx]
    ).unsqueeze(0)

    # Inputs: use ESM batch_converter
    fair_x = [(0, input_seq)]
    _, _, batch_tokens = model.model.encoder.batch_converter(fair_x)

    # Prediction for the full fusion
    with torch.no_grad():
        model.eval()
        out = model.model({"x": [input_seq]})
    probs = torch.sigmoid(out["logit"]).detach().cpu()
    pred_idx = probs.argmax().item()
    pred_class_name = OLDCOMPS[pred_idx]

    # Store model's predicted class info alongside original label
    sequence_dict["model_pred_class_idx"] = int(pred_idx)
    sequence_dict["model_pred_class_name"] = pred_class_name
    sequence_dict["model_pred_probs"] = probs.squeeze().tolist()

    # IG attributions w.r.t. the predicted class
    attributions, delta = lig.attribute(
        inputs=batch_tokens,
        baselines=baseline,
        return_convergence_delta=True,
        target=pred_idx,
        n_steps=25,
    )
    A = attributions.sum(-1)[0, 1:-1]  # drop CLS/EOS
    A = A / torch.norm(A)
    sequence_dict["attributions"] = A.tolist()

    # Build Captum visualization record
    record = viz.VisualizationDataRecord(
        word_attributions=A * 10,
        pred_prob=probs.max().item(),
        pred_class=pred_class_name,              # model's predicted class
        true_class=sequence_dict["name"],        # your sequence label / ID
        attr_class="-",
        attr_score=attributions[0, 1:-1].sum(),
        raw_input_ids=input_seq,
        convergence_score=delta,
    )
    records.append(record)

# Take a quick look at what model predicted vs your labels
sequence_dict_copy


[{'condensate': 'Synapse',
  'name': 'VATG2',
  'seed': 1,
  'sequence': 'MASQSQGIQQLLQAEKRAAEKVADARKRKARRLKQAKEEAQMEVEQYRREREHEFQSKQQAAMGSQGNLSAEVEQATRRQVQGMQSSQQRNRERVLAQLLGMVCDVRPQVHPNYRISA',
  'model_pred_class_idx': 4,
  'model_pred_class_name': 'Excitatory Synapse',
  'model_pred_probs': [0.06139502674341202,
   0.0011658864095807076,
   1.1990982784482185e-05,
   0.13033036887645721,
   0.9232924580574036,
   0.47822701930999756],
  'attributions': [0.04107086880367465,
   -0.27569252430774005,
   0.6031264164626648,
   -0.017165788675634522,
   0.0016837581281229576,
   0.007321488045300473,
   0.031743149265642916,
   -0.05979586770967675,
   0.055249870115039186,
   0.03399575650794573,
   0.09451724376733123,
   -0.08618136275329669,
   0.0426475515773979,
   0.017179287172215604,
   0.11569484279894426,
   0.06945644800351981,
   -0.06515974882636677,
   0.028947182604885296,
   0.035146172980187025,
   0.036831958599218784,
   0.051163747350563345,
   0.01829941594464258,


In [77]:
# [AN5] Visualize attributions and save to disk

html = visualize_text(records)

with open("html_file.html", "w") as f:
    f.write(html.data)

# Also save the attribution data as CSV
for sequence_dict in sequence_dict_copy:
    seq = sequence_dict["sequence"]
    sequence_dict["full_sequence"] = seq + mcherry

attr_df = pd.DataFrame(sequence_dict_copy)
attr_df.to_csv("attributions.csv", index=False)

attr_df.head()


0,1,2
M A S Q S Q G I Q Q L L Q A E K R A A E K V A D A R K R K A R R L K Q A K E E A Q M E V E Q Y R R E R E H E F Q S K Q Q A A M G S Q G N L S A E V E Q A T R R Q V Q G M Q S S Q Q R N R E R V L A Q L L G M V C D V R P Q V H P N Y R I S A,VATG2,Excitatory Synapse (0.92)
,,
M L C C M R R T K Q V E K N D D D Q K I E Q D G I K P E D K A H K A A T K I Q A S F R G H I T R K K L K G E K K D D V Q A A E A E A N K K D E A P V A D G V E K K G E G T T T A E A A P A T G S K P D E P G K A G E T P S E E K K G E G D A A T E Q A A P Q A P A S S E E K A G S A E T E S A T K A S T D N S P S S K A E D A P A K E E P K Q A D V P A A V T A A A A T T P A A E D A A A K A T A Q P P T E T G E S S Q A E E N I E A V D E T K P K E S A R Q D E G K E E E P E A D Q E H A,NeuM,Excitatory Synapse (0.88)
,,


Unnamed: 0,condensate,name,seed,sequence,model_pred_class_idx,model_pred_class_name,model_pred_probs,attributions,full_sequence
0,Synapse,VATG2,1,MASQSQGIQQLLQAEKRAAEKVADARKRKARRLKQAKEEAQMEVEQ...,4,Excitatory Synapse,"[0.06139502674341202, 0.0011658864095807076, 1...","[0.04107086880367465, -0.27569252430774005, 0....",MASQSQGIQQLLQAEKRAAEKVADARKRKARRLKQAKEEAQMEVEQ...
1,Synapse,NeuM,2,MLCCMRRTKQVEKNDDDQKIEQDGIKPEDKAHKAATKIQASFRGHI...,4,Excitatory Synapse,"[0.115046925842762, 0.001337190275080502, 1.78...","[0.07333355394692458, 0.17847705944933295, -0....",MLCCMRRTKQVEKNDDDQKIEQDGIKPEDKAHKAATKIQASFRGHI...


# LARGE BATCH PROCESSING (Now with higher steps and lower batches)

In [64]:
# [AN1] Load Excel and prepare sequences for attribution analysis (with truncation/windowing)
import os
import pandas as pd

# Adjustable batch size
BATCH_SIZE = 1  # can increase if memory allows

# --- NEW CONTROL VARIABLE ---
# Set to True to filter sequences by the COMPARTMENT and FILTER_VALUE criteria.
# Set to False to process ALL sequences in the Excel file.
ACTIVATE_FILTERING = False
# --------------------------

# Targeted sequence truncation/windowing
TRUNCATE_THRESHOLD = 3100  # sequences longer than this get split
WINDOW_SIZE = 1500
OVERLAP = 200

# Filtering parameters (used only if ACTIVATE_FILTERING is True)
COMPARTMENT = "Inhibitory Synapses"
FILTER_VALUE = 1

# Load Excel
excel_path = "/home/shd-sun-lab/SynapseNavigator/SynGO_SynapsePredicted_Proteins.xlsx"
data = pd.read_excel(excel_path)
print("Loaded Excel:", excel_path)
print("Shape before processing:", data.shape)

# Detect columns
possible_seq_cols = ["Sequence", "AA_sequence", "ProteinSequence", "WT_Sequence", "sequences"]
possible_name_cols = ["Gene", "Name", "ID", "GeneName", "Protein", "Symbol"]
possible_uniprot_cols = ["Entry", "UniProt", "Accession", "ProteinID"]

seq_col = next((c for c in possible_seq_cols if c in data.columns), None)
name_col = next((c for c in possible_name_cols if c in data.columns), None)
uniprot_col = next((c for c in possible_uniprot_cols if c in data.columns), None)

if seq_col is None:
    raise ValueError("No sequence column detected!")

# --- Conditional Filtering Logic ---
if ACTIVATE_FILTERING:
    if COMPARTMENT not in data.columns:
        raise ValueError(f"Compartment column '{COMPARTMENT}' not found in Excel.")
        
    sequences_to_process = data[data[COMPARTMENT] == FILTER_VALUE].copy()
    print(f"Filtering activated: Processing sequences where '{COMPARTMENT}' == {FILTER_VALUE}.")
else:
    sequences_to_process = data.copy()
    print("Filtering deactivated: Processing ALL sequences.")

print("Shape of data to process:", sequences_to_process.shape)
# -----------------------------------

# Prepare sequences for IG
DEFAULT_CONDENSATE = COMPARTMENT if ACTIVATE_FILTERING else "All"
DEFAULT_SEED = 1
mcherry = ""  # optional fused tag

experimental_sequences = []

for idx, row in sequences_to_process.iterrows():
    seq = str(row[seq_col]).strip()
    if not isinstance(seq, str) or len(seq) == 0:
        continue

    if name_col and pd.notna(row[name_col]):
        base_name = str(row[name_col])
    elif uniprot_col and pd.notna(row[uniprot_col]):
        base_name = str(row[uniprot_col])
    else:
        base_name = f"Seq{idx+1}"

    if len(seq) > TRUNCATE_THRESHOLD:
        # split into windows
        window_count = 0
        for start in range(0, len(seq), WINDOW_SIZE - OVERLAP):
            sub_seq = seq[start:start + WINDOW_SIZE]
            if len(sub_seq) == 0:
                continue
            window_name = f"{base_name}_win{start}"
            experimental_sequences.append({
                "condensate": DEFAULT_CONDENSATE,
                "name": window_name,
                "seed": DEFAULT_SEED,
                "sequence": sub_seq,
            })
            window_count += 1
        print(f"Sequence {base_name} ({len(seq)} AA) split into {window_count} windows")
    else:
        experimental_sequences.append({
            "condensate": DEFAULT_CONDENSATE,
            "name": base_name,
            "seed": DEFAULT_SEED,
            "sequence": seq,
        })

print(f"\nPrepared {len(experimental_sequences)} sequences for IG.")

Loaded Excel: /home/shd-sun-lab/SynapseNavigator/SynGO_SynapsePredicted_Proteins.xlsx
Shape before processing: (366, 4)
Filtering deactivated: Processing ALL sequences.
Shape of data to process: (366, 4)
Sequence PCLO_HUMAN (5142 AA) split into 4 windows
Sequence BSN_HUMAN (3926 AA) split into 4 windows

Prepared 372 sequences for IG.


In [65]:
# [AN2] Setup device, model, forward, and Captum LayerIntegratedGradients
import torch
from captum.attr import LayerIntegratedGradients

# Device setup: GPU if available, else CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

# Move model to device
model = model.to(device)
model.eval()

# Forward function
def forward(batch_tokens):
    batch_tokens = batch_tokens.to(device)
    model.zero_grad()
    result = model.model.encoder.model(
        batch_tokens,
        repr_layers=[model.model.encoder.repr_layer],
        return_contacts=False,
    )
    hidden = result["representations"][model.model.encoder.repr_layer].mean(dim=1)
    out = model.model.mlp({"x": hidden})["logit"]
    scores = torch.sigmoid(out)
    return scores

alphabet = model.model.encoder.alphabet
lig = LayerIntegratedGradients(forward, model.model.encoder.model.embed_tokens)


Using device: cuda


In [None]:
# [AN3] Batch processing with incremental CSV saving (GPU if available)
import torch
from captum.attr import LayerIntegratedGradients
import pandas as pd
import gc
from tqdm import tqdm
import os

# Device setup
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

# Move model to device
model = model.to(device)
model.eval()

# Forward function
def forward(batch_tokens):
    batch_tokens = batch_tokens.to(device)
    model.zero_grad()
    result = model.model.encoder.model(
        batch_tokens,
        repr_layers=[model.model.encoder.repr_layer],
        return_contacts=False,
    )
    hidden = result["representations"][model.model.encoder.repr_layer].mean(dim=1)
    out = model.model.mlp({"x": hidden})["logit"]
    scores = torch.sigmoid(out)
    return scores

alphabet = model.model.encoder.alphabet
lig = LayerIntegratedGradients(forward, model.model.encoder.model.embed_tokens)

# Output CSV
OUTPUT_CSV = "attributions_precise.csv"
N_STEPS = 25  # can lower if GPU memory issues persist - Lower value reduces quality of score

# Load existing CSV to resume
if os.path.exists(OUTPUT_CSV):
    existing_df = pd.read_csv(OUTPUT_CSV)
    processed_names = set(existing_df["name"].tolist())
else:
    existing_df = pd.DataFrame()
    processed_names = set()

# Skip already processed sequences
remaining_sequences = [seq for seq in experimental_sequences if seq["name"] not in processed_names]
print(f"Remaining sequences to process: {len(remaining_sequences)}")

def process_batch(batch_seqs, internal_batch_size=1):
    batch_results = []
    for seq_dict in batch_seqs:
        seq = seq_dict["sequence"]
        input_seq = seq + mcherry

        baseline = torch.tensor(
            [alphabet.cls_idx] + [alphabet.mask_idx] * len(input_seq) + [alphabet.eos_idx]
        ).unsqueeze(0).to(device)

        # Batch tokens via ESM batch_converter
        fair_x = [(0, input_seq)]
        _, _, batch_tokens = model.model.encoder.batch_converter(fair_x)
        batch_tokens = batch_tokens.to(device)

        # Forward pass for prediction
        with torch.no_grad():
            out = model.model({"x": [input_seq]})
        probs = torch.sigmoid(out["logit"]).detach().cpu()
        pred_idx = probs.argmax().item()
        pred_class_name = OLDCOMPS[pred_idx]

        # IG attributions
        attributions, delta = lig.attribute(
            inputs=batch_tokens,
            baselines=baseline,
            target=pred_idx,
            n_steps=N_STEPS,
            internal_batch_size=internal_batch_size,
            return_convergence_delta=True,
        )
        A = attributions.sum(-1)[0, 1:-1]  # drop CLS/EOS
        A = A / torch.norm(A)

        seq_dict["model_pred_class_idx"] = int(pred_idx)
        seq_dict["model_pred_class_name"] = pred_class_name
        seq_dict["model_pred_probs"] = probs.squeeze().tolist()
        seq_dict["attributions"] = A.tolist()
        seq_dict["full_sequence"] = input_seq

        batch_results.append(seq_dict)

        # Free memory
        del batch_tokens, baseline, attributions, out, A
        gc.collect()

    return batch_results

# Process sequences in batches
for i in tqdm(range(0, len(remaining_sequences), BATCH_SIZE)):
    batch = remaining_sequences[i:i+BATCH_SIZE]
    batch_results = process_batch(batch, internal_batch_size=1)

    # Append to CSV incrementally
    batch_df = pd.DataFrame(batch_results)
    batch_df.to_csv(
        OUTPUT_CSV,
        mode="a",
        index=False,
        header=not os.path.exists(OUTPUT_CSV)
    )

    # Clean memory
    del batch, batch_results, batch_df
    gc.collect()

print("All batches processed. Results saved to:", OUTPUT_CSV)


Using device: cuda
Remaining sequences to process: 372


100%|██████████| 372/372 [07:30<00:00,  1.21s/it]

All batches processed. Results saved to: attributions_precise.csv





# Combine split up sequences


In [68]:
import pandas as pd
import numpy as np
from ast import literal_eval
import re

# --- CRITICAL FIX: Ensure OLDCOMPS is correctly defined in the environment ---
OLDCOMPS = [
    "cytosol",
    "ER",
    "mitochondrion",
    "nucleus",
    "Excitatory Synapse",
    "Inhibitory Synapses",
]
NUM_COMPARTMENTS = 6 
# -------------------------------------------------------------------------

# --- Configuration ---
INPUT_CSV = "attributions_precise.csv"
OUTPUT_CSV = "attributions_precise_aligned.csv"

# Load the CSV
converters = {
    'attributions': literal_eval,
    'model_pred_probs': literal_eval
}
df = pd.read_csv(INPUT_CSV, converters=converters)

print(f"Loaded CSV with {df.shape[0]} rows.")

# --- Prepare Helper Columns ---
df["is_window"] = df["name"].str.contains(r"_win\d+$", na=False)
df["base_name"] = df["name"].apply(lambda x: re.sub(r"_win\d+$", "", x))
df["win_start"] = df["name"].str.extract(r"_win(\d+)$").astype(float).fillna(0).astype(int)

# --- Group and Aggregate ---
aggregated_rows = []

for base_name, group in df.groupby("base_name"):
    if group["is_window"].any():
        # 1. Determine True Full Length and Sequence
        max_win_end = (group["win_start"] + group["sequence"].str.len()).max()
        full_length = int(max_win_end)
        
        # Sequence reconstruction logic
        seq_parts = {}
        for _, row in group.sort_values("win_start").iterrows():
             start = row["win_start"]
             sub_seq = row["sequence"]
             for i, char in enumerate(sub_seq):
                 if start + i < full_length and start + i not in seq_parts:
                     seq_parts[start + i] = char
                          
        full_sequence = "".join(seq_parts[i] for i in range(full_length) if i in seq_parts)
        full_length = len(full_sequence) 

        # Initialize arrays for averaging
        summed_attr = np.zeros(full_length, dtype=float)
        counts = np.zeros(full_length, dtype=int)
        summed_probs = np.zeros(NUM_COMPARTMENTS, dtype=float)

        # 2. Sum Attributions and Probabilities
        for idx, row in group.iterrows(): 
            start = row["win_start"]
            attr = np.array(row["attributions"])
            probs = np.array(row["model_pred_probs"])
            
            # Since model_pred_probs is guaranteed to be 6, this is fine
            if probs.shape[0] != NUM_COMPARTMENTS:
                 print(f"FATAL WARNING: Skipping {row['name']} (idx {idx}). Probability vector size is {probs.shape[0]}, expected {NUM_COMPARTMENTS}.")
                 continue
            
            attr_len = len(attr)
            
            # Defensive Check: Skip if window is outside bounds or has no attribution
            if start >= full_length or attr_len == 0:
                print(f"Warning: Skipping {row['name']} (idx {idx}). Start ({start}) >= Full length ({full_length}) or attr_len=0.")
                continue 

            # Define the end position for the summation
            end = min(start + attr_len, full_length)
            
            if start >= end:
                print(f"Warning: Skipping {row['name']} (idx {idx}). Slice is invalid: [{start}:{end}].")
                continue

            # Core summation for attributions
            summed_attr[start:end] += attr[:end-start] 
            counts[start:end] += 1
            
            # Probability summation
            summed_probs += probs

        # 3. Compute Final Averages and Metadata
        
        # Calculate average attributions
        with np.errstate(divide='ignore', invalid='ignore'):
            avg_attr = np.where(counts > 0, summed_attr / counts, 0.0).tolist()
        
        # Calculate average probabilities
        if len(group) == 0:
            avg_probs_arr = np.zeros(NUM_COMPARTMENTS)
        else:
            avg_probs_arr = summed_probs / len(group)
        
        # Handle potential NaN/Inf values before argmax
        avg_probs_arr[~np.isfinite(avg_probs_arr)] = 0.0
        
        pred_idx = int(np.argmax(avg_probs_arr))
        avg_probs = avg_probs_arr.tolist()
        
        # Final index check (now guaranteed to pass if OLDCOMPS size is 6)
        if pred_idx >= len(OLDCOMPS) or pred_idx < 0:
             # Should no longer be reached
            pred_class_name = "Unknown"
        else:
            pred_class_name = OLDCOMPS[pred_idx]
             
        # Take representative metadata
        rep_row = group.iloc[0].copy()
        
        # Update required fields
        rep_row["name"] = base_name 
        rep_row["sequence"] = full_sequence
        rep_row["attributions"] = avg_attr
        rep_row["model_pred_probs"] = avg_probs
        rep_row["model_pred_class_idx"] = pred_idx
        rep_row["model_pred_class_name"] = pred_class_name
        rep_row["full_sequence"] = full_sequence
        rep_row["aligned_avg"] = True 

        aggregated_rows.append(rep_row)
    else:
        # Not windowed, keep original row, add False for aligned_avg
        row = group.iloc[0].copy()
        row["aligned_avg"] = False
        aggregated_rows.append(row)

# Combine all rows
final_df = pd.DataFrame(aggregated_rows)

# Drop helper columns
final_df = final_df.drop(columns=["is_window", "base_name", "win_start"])

# Save final CSV
final_df.to_csv(OUTPUT_CSV, index=False)
print(f"Final CSV saved as {OUTPUT_CSV}")

Loaded CSV with 372 rows.
Final CSV saved as attributions_precise_aligned.csv
