In [None]:
#@title mount drive
import pandas as pd
from google.colab import drive

# Attach Google Drive so later cells can read datasets, feature files and
# checkpoints from /content/drive/MyDrive/gppipredv2_2/.
MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT)

In [None]:
#@title install dependencies
!pip install torch_geometric
!pip install torch_scatter
!pip install torch_sparse
!pip install Biopython
!pip install gradio

In [None]:
import torch
import os
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data, Batch
from torch_geometric.utils import add_self_loops
# --- 1. SCALED AA PHYSICS ---
# Five physico-chemical descriptors per amino acid, each pre-scaled to roughly
# [0, 1] so that no single property (e.g. volume) dominates the node features.
# NOTE(review): which property each of the five columns holds is not recorded
# here — document it next to the source table.
# 'X' is the all-zero row used for unknown / non-standard residues.
AA_PHYSICS_5D_SCALED = {
    'A': [0.70, 0.17, 0.11, 0.40, 0.00],
    'R': [0.00, 0.67, 0.71, 1.00, 1.00],
    'N': [0.11, 0.32, 0.33, 0.33, 0.00],
    'D': [0.11, 0.30, 0.26, 0.00, 0.31],
    'C': [0.78, 0.29, 0.31, 0.29, 0.67],
    'Q': [0.11, 0.50, 0.44, 0.36, 0.00],
    'E': [0.11, 0.47, 0.37, 0.06, 0.34],
    'G': [0.46, 0.00, 0.00, 0.40, 0.00],
    'H': [0.14, 0.55, 0.56, 0.60, 0.48],
    'I': [1.00, 0.64, 0.45, 0.41, 0.00],
    'L': [0.92, 0.64, 0.45, 0.40, 0.00],
    'K': [0.07, 0.65, 0.54, 0.87, 0.84],
    'M': [0.66, 0.61, 0.54, 0.37, 0.00],
    'F': [0.81, 0.77, 0.72, 0.34, 0.00],
    'P': [0.32, 0.31, 0.32, 0.44, 0.00],
    'S': [0.41, 0.17, 0.15, 0.36, 0.00],
    'T': [0.42, 0.33, 0.26, 0.35, 0.00],
    'W': [0.40, 1.00, 1.00, 0.39, 0.00],
    'Y': [0.36, 0.79, 0.73, 0.36, 0.81],
    'V': [0.97, 0.48, 0.34, 0.40, 0.00],
    'X': [0.00, 0.00, 0.00, 0.00, 0.00],
}

class PPIDatasetV3(Dataset):
    """Pairs of protein graphs with a binary interaction label, read from a CSV.

    Each CSV row provides ``InteractorA``, ``InteractorB`` and ``Label``.
    Protein IDs are looked up (as strings) in ``id_map``, whose entries give
    ``adj_file`` and ``seq_file`` paths relative to ``base_path``. Built graphs
    are memoised in ``self.cache``, so each protein is read from disk once per
    dataset instance (per worker process if the DataLoader uses workers).
    """

    def __init__(self, csv_path, id_map, base_path):
        self.df = pd.read_csv(csv_path)
        self.id_map = id_map
        self.base_path = base_path
        self.cache = {}

    def _get_graph(self, p_id):
        """Return the (cached) residue graph for protein ``p_id``.

        Node features are the 5-D physics vectors of the residues; edges are
        the non-zero entries of the stored adjacency matrix plus self-loops.

        Raises:
            KeyError: if ``p_id`` is not present in ``id_map``.
            ValueError: if the adjacency references residues beyond the end of
                the sequence (mismatched feature files).
        """
        p_id = str(p_id)
        if p_id in self.cache:
            return self.cache[p_id]

        info = self.id_map[p_id]
        adj = np.load(os.path.join(self.base_path, info['adj_file'])).astype(np.float32)
        # NOTE: allow_pickle executes pickled code on load — only use with
        # feature files this project generated. Element 0 is presumably the
        # full residue string (TODO confirm the on-disk layout).
        seq = np.load(os.path.join(self.base_path, info['seq_file']), allow_pickle=True)[0]

        # One physics vector per residue; unknown letters fall back to 'X'.
        # reshape keeps a (num_residues, 5) shape even for an empty sequence.
        x = torch.tensor(
            [AA_PHYSICS_5D_SCALED.get(a, AA_PHYSICS_5D_SCALED['X']) for a in seq],
            dtype=torch.float32,
        ).reshape(-1, len(AA_PHYSICS_5D_SCALED['X']))
        num_nodes = x.size(0)

        edge_index = torch.from_numpy(np.array(np.nonzero(adj))).to(torch.long)
        # Edges pointing past the last residue would only surface later as an
        # opaque index error / CUDA device-side assert inside the GNN.
        if edge_index.numel() > 0 and int(edge_index.max()) >= num_nodes:
            raise ValueError(
                f"Protein {p_id}: adjacency references node {int(edge_index.max())} "
                f"but the sequence has only {num_nodes} residues."
            )
        edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)

        graph = Data(x=x, edge_index=edge_index)
        self.cache[p_id] = graph
        return graph

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(graph_a, graph_b, label)``; label is a float tensor of shape (1,)."""
        row = self.df.iloc[idx]
        return self._get_graph(row['InteractorA']), self._get_graph(row['InteractorB']), torch.tensor([float(row['Label'])])

def collate(batch):
    """Merge ``(graph_a, graph_b, label)`` samples into two PyG batches and a stacked label tensor."""
    graphs_a, graphs_b, labels = (list(column) for column in zip(*batch))
    batch_a = Batch.from_data_list(graphs_a)
    batch_b = Batch.from_data_list(graphs_b)
    return batch_a, batch_b, torch.stack(labels)

# Training

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import os
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data, Batch
from torch_geometric.utils import add_self_loops
from tqdm.auto import tqdm
from sklearn.metrics import matthews_corrcoef
# Root folder on Drive holding CSVs, feature files and checkpoints.
BASE_PATH = '/content/drive/MyDrive/gppipredv2_2/'
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# --- 1. SCALED AA PHYSICS ---
# Five pre-scaled (~[0, 1]) physico-chemical descriptors per amino acid; 'X'
# is the all-zero row used for unknown / non-standard residues.
# NOTE(review): the meaning of each column is not recorded here.
AA_PHYSICS_5D_SCALED = {
    'A': [0.70, 0.17, 0.11, 0.40, 0.00],
    'R': [0.00, 0.67, 0.71, 1.00, 1.00],
    'N': [0.11, 0.32, 0.33, 0.33, 0.00],
    'D': [0.11, 0.30, 0.26, 0.00, 0.31],
    'C': [0.78, 0.29, 0.31, 0.29, 0.67],
    'Q': [0.11, 0.50, 0.44, 0.36, 0.00],
    'E': [0.11, 0.47, 0.37, 0.06, 0.34],
    'G': [0.46, 0.00, 0.00, 0.40, 0.00],
    'H': [0.14, 0.55, 0.56, 0.60, 0.48],
    'I': [1.00, 0.64, 0.45, 0.41, 0.00],
    'L': [0.92, 0.64, 0.45, 0.40, 0.00],
    'K': [0.07, 0.65, 0.54, 0.87, 0.84],
    'M': [0.66, 0.61, 0.54, 0.37, 0.00],
    'F': [0.81, 0.77, 0.72, 0.34, 0.00],
    'P': [0.32, 0.31, 0.32, 0.44, 0.00],
    'S': [0.41, 0.17, 0.15, 0.36, 0.00],
    'T': [0.42, 0.33, 0.26, 0.35, 0.00],
    'W': [0.40, 1.00, 1.00, 0.39, 0.00],
    'Y': [0.36, 0.79, 0.73, 0.36, 0.81],
    'V': [0.97, 0.48, 0.34, 0.40, 0.00],
    'X': [0.00, 0.00, 0.00, 0.00, 0.00],
}

# --- 2. DATASET & COLLATE ---
class PPIDatasetV3(Dataset):
    """Pairs of protein graphs with a binary interaction label, read from a CSV.

    Each CSV row provides ``InteractorA``, ``InteractorB`` and ``Label``.
    Protein IDs are looked up (as strings) in ``id_map``, whose entries give
    ``adj_file`` and ``seq_file`` paths relative to ``base_path``. Built graphs
    are memoised in ``self.cache``, so each protein is read from disk once per
    dataset instance (per worker process if the DataLoader uses workers).
    """

    def __init__(self, csv_path, id_map, base_path):
        self.df = pd.read_csv(csv_path)
        self.id_map = id_map
        self.base_path = base_path
        self.cache = {}

    def _get_graph(self, p_id):
        """Return the (cached) residue graph for protein ``p_id``.

        Node features are the 5-D physics vectors of the residues; edges are
        the non-zero entries of the stored adjacency matrix plus self-loops.

        Raises:
            KeyError: if ``p_id`` is not present in ``id_map``.
            ValueError: if the adjacency references residues beyond the end of
                the sequence (mismatched feature files).
        """
        p_id = str(p_id)
        if p_id in self.cache:
            return self.cache[p_id]

        info = self.id_map[p_id]
        adj = np.load(os.path.join(self.base_path, info['adj_file'])).astype(np.float32)
        # NOTE: allow_pickle executes pickled code on load — only use with
        # feature files this project generated. Element 0 is presumably the
        # full residue string (TODO confirm the on-disk layout).
        seq = np.load(os.path.join(self.base_path, info['seq_file']), allow_pickle=True)[0]

        # One physics vector per residue; unknown letters fall back to 'X'.
        # reshape keeps a (num_residues, 5) shape even for an empty sequence.
        x = torch.tensor(
            [AA_PHYSICS_5D_SCALED.get(a, AA_PHYSICS_5D_SCALED['X']) for a in seq],
            dtype=torch.float32,
        ).reshape(-1, len(AA_PHYSICS_5D_SCALED['X']))
        num_nodes = x.size(0)

        edge_index = torch.from_numpy(np.array(np.nonzero(adj))).to(torch.long)
        # Edges pointing past the last residue would only surface later as an
        # opaque index error / CUDA device-side assert inside the GNN.
        if edge_index.numel() > 0 and int(edge_index.max()) >= num_nodes:
            raise ValueError(
                f"Protein {p_id}: adjacency references node {int(edge_index.max())} "
                f"but the sequence has only {num_nodes} residues."
            )
        edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)

        graph = Data(x=x, edge_index=edge_index)
        self.cache[p_id] = graph
        return graph

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(graph_a, graph_b, label)``; label is a float tensor of shape (1,)."""
        row = self.df.iloc[idx]
        return self._get_graph(row['InteractorA']), self._get_graph(row['InteractorB']), torch.tensor([float(row['Label'])])

def collate(batch):
    """Merge ``(graph_a, graph_b, label)`` samples into two PyG batches and a stacked label tensor."""
    graphs_a, graphs_b, labels = (list(column) for column in zip(*batch))
    batch_a = Batch.from_data_list(graphs_a)
    batch_b = Batch.from_data_list(graphs_b)
    return batch_a, batch_b, torch.stack(labels)

# --- 3. INITIALIZATION ---
# Update these paths to your current working files.
TRAIN_CSV = "/content/drive/MyDrive/gppipredv2_2/train_subset_90.csv"
VAL_CSV = "/content/drive/MyDrive/gppipredv2_2/internal_sanity_val.csv"
CHECKPOINT_PATH = os.path.join(BASE_PATH, "best_siamese_v3_boat.pt")
NUM_EPOCHS = 30
SEED = 42

# This cell uses two names it does not create: `master_map` (protein ID ->
# {'adj_file', 'seq_file'}; not built anywhere in this notebook) and
# `SiameseGAT_v3` (defined in the app cell near the end). Fail fast with a
# clear message on a fresh kernel instead of a NameError mid-cell.
_missing = [name for name in ("master_map", "SiameseGAT_v3") if name not in globals()]
if _missing:
    raise NameError(f"Run the cells that define {', '.join(_missing)} before training.")

# Seed weight init and DataLoader shuffling so runs are reproducible.
torch.manual_seed(SEED)

train_ds = PPIDatasetV3(TRAIN_CSV, master_map, BASE_PATH)
val_ds = PPIDatasetV3(VAL_CSV, master_map, BASE_PATH)

train_loader = DataLoader(train_ds, batch_size=128, shuffle=True, collate_fn=collate)
val_loader = DataLoader(val_ds, batch_size=128, shuffle=False, collate_fn=collate)

model = SiameseGAT_v3(in_channels=5).to(DEVICE)


def init_weights(m):
    """Xavier-uniform weights and zero biases for every ``nn.Linear`` module.

    Modules that are not ``nn.Linear`` instances keep their own default
    initialisation (presumably including GATv2Conv's internal projections —
    TODO confirm for the installed torch_geometric version).
    """
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)


model.apply(init_weights)

optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.05)
criterion = nn.BCEWithLogitsLoss()


def train_one_epoch(net, loader, opt, loss_fn, epoch):
    """Run one optimisation pass over ``loader`` and return the mean batch loss."""
    net.train()
    epoch_loss = 0.0
    pbar = tqdm(loader, desc=f"Epoch {epoch} [Train]")
    for g1, g2, y in pbar:
        g1, g2, y = g1.to(DEVICE), g2.to(DEVICE), y.to(DEVICE)
        opt.zero_grad()
        loss = loss_fn(net(g1, g2), y)
        loss.backward()
        opt.step()
        epoch_loss += loss.item()
        pbar.set_postfix({'loss': f"{loss.item():.4f}"})
    return epoch_loss / len(loader)


def evaluate_mcc(net, loader, epoch):
    """Return the MCC of ``net`` on ``loader`` with predictions at logit > 0."""
    net.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for g1, g2, y in tqdm(loader, desc=f"Epoch {epoch} [Val]"):
            logits = net(g1.to(DEVICE), g2.to(DEVICE))
            # logit > 0 <=> sigmoid(logit) > 0.5, matching BCEWithLogitsLoss.
            all_preds.extend((logits > 0.0).float().cpu().numpy().flatten())
            all_labels.extend(y.numpy().flatten())
    return matthews_corrcoef(all_labels, all_preds)


# --- 4. TRAINING LOOP ---
# best_mcc starts at 0, so only checkpoints with a positive MCC are saved.
best_mcc = 0.0
print("🚀 Fresh Training Session Started. Goal: Stable Positive MCC.")

for epoch in range(1, NUM_EPOCHS + 1):
    avg_loss = train_one_epoch(model, train_loader, optimizer, criterion, epoch)
    val_mcc = evaluate_mcc(model, val_loader, epoch)

    print(f"📈 Epoch {epoch} | Loss: {avg_loss:.4f} | Val MCC: {val_mcc:.4f}")

    if val_mcc > best_mcc:
        best_mcc = val_mcc
        torch.save(model.state_dict(), CHECKPOINT_PATH)
        print(f"⭐ New Best MCC: {best_mcc:.4f}")

# Later cells evaluate the in-memory `model`; restore the selected checkpoint
# so they see the best weights rather than whatever the last epoch produced.
if best_mcc > 0.0:
    model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=True))
    print(f"Restored best checkpoint (MCC {best_mcc:.4f}) from {CHECKPOINT_PATH}")

# Validation

In [None]:
import os
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data, Batch
from torch_geometric.utils import add_self_loops
from tqdm.auto import tqdm
from sklearn.metrics import matthews_corrcoef, confusion_matrix, roc_auc_score, accuracy_score
import matplotlib.pyplot as plt
import seaborn as sns

# --- 1. LOCAL DATA PREPARATION ---
# Python-side copy: a shell glob over tens of thousands of .npy files can
# exceed the command-line length limit, and a failed `!cp` does not stop the
# cell, so the success message used to print even when nothing was copied.
import glob
import shutil

LOCAL_DATA_DIR = "/content/local_graphs"
DRIVE_FEATURE_DIR = "/content/drive/MyDrive/gppipredv2_2/validation_protein_features"
os.makedirs(LOCAL_DATA_DIR, exist_ok=True)

print("🚚 Copying files to local disk (this speed up is necessary for 72k pairs)...")
npy_files = glob.glob(os.path.join(DRIVE_FEATURE_DIR, "*.npy"))
if not npy_files:
    raise FileNotFoundError(f"No .npy files found in {DRIVE_FEATURE_DIR}")
for src_path in tqdm(npy_files, desc="Copying"):
    shutil.copy2(src_path, LOCAL_DATA_DIR)
print(f"✅ Files localized ({len(npy_files)} files).")

# --- 2. DATASET DEFINITION (RAM-SAFE) ---
class PPIDatasetV3_Final(Dataset):
    """Uncached pair dataset for large evaluation sets.

    Takes an in-memory DataFrame with ``InteractorA``, ``InteractorB`` and
    ``Label`` columns and rebuilds each protein graph from ``data_folder`` on
    every access, so memory does not grow with the number of distinct proteins.
    Relies on ``AA_PHYSICS_5D_SCALED`` defined in an earlier cell.
    """

    def __init__(self, df, id_map, data_folder):
        self.df = df
        self.id_map = id_map
        self.data_folder = data_folder
        # Deliberately no graph cache: it exhausted RAM on large datasets.

    def _get_graph(self, p_id):
        """Build the residue graph for ``p_id`` from the local feature files.

        Raises:
            KeyError: if ``p_id`` is not present in ``id_map``.
            ValueError: if the adjacency references residues beyond the end of
                the sequence (mismatched feature files).
        """
        p_id = str(p_id)
        info = self.id_map[p_id]

        # Load directly from LOCAL disk. NOTE: allow_pickle executes pickled
        # code — only use with feature files this project generated.
        adj = np.load(os.path.join(self.data_folder, info['adj_file'])).astype(np.float32)
        seq = np.load(os.path.join(self.data_folder, info['seq_file']), allow_pickle=True)[0]

        # reshape keeps a (num_residues, 5) shape even for an empty sequence.
        x = torch.tensor(
            [AA_PHYSICS_5D_SCALED.get(a, AA_PHYSICS_5D_SCALED['X']) for a in seq],
            dtype=torch.float32,
        ).reshape(-1, len(AA_PHYSICS_5D_SCALED['X']))
        num_nodes = x.size(0)

        edge_index = torch.from_numpy(np.array(np.nonzero(adj))).to(torch.long)
        # Edges pointing past the last residue would only surface later as an
        # opaque index error / CUDA device-side assert inside the GNN.
        if edge_index.numel() > 0 and int(edge_index.max()) >= num_nodes:
            raise ValueError(
                f"Protein {p_id}: adjacency references node {int(edge_index.max())} "
                f"but the sequence has only {num_nodes} residues."
            )
        edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)

        return Data(x=x, edge_index=edge_index)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(graph_a, graph_b, label)``; label is a float tensor of shape (1,)."""
        row = self.df.iloc[idx]
        return self._get_graph(row['InteractorA']), self._get_graph(row['InteractorB']), torch.tensor([float(row['Label'])])

def collate(batch):
    """Merge ``(graph_a, graph_b, label)`` samples into two PyG batches and a stacked label tensor."""
    graphs_a, graphs_b, labels = (list(column) for column in zip(*batch))
    batch_a = Batch.from_data_list(graphs_a)
    batch_b = Batch.from_data_list(graphs_b)
    return batch_a, batch_b, torch.stack(labels)

# --- 3. RUN VALIDATION ---
# Needs objects created by earlier cells: `df_ready` (pairs with InteractorA /
# InteractorB / Label columns — not built anywhere in this notebook),
# `master_map`, `DEVICE` and a trained `model`. Fail fast on a fresh kernel.
_missing = [name for name in ("df_ready", "master_map", "model", "DEVICE") if name not in globals()]
if _missing:
    raise NameError(f"Missing from the kernel: {', '.join(_missing)}. Run the cells that create them first.")

val_ds = PPIDatasetV3_Final(df_ready, master_map, LOCAL_DATA_DIR)
val_loader = DataLoader(val_ds, batch_size=128, shuffle=False, collate_fn=collate, num_workers=2)

model.eval()
all_logits, all_labels = [], []

print("🔬 Starting Validation Loop...")
with torch.no_grad():
    for g1, g2, y in tqdm(val_loader, desc="Validating 72k Pairs"):
        logits = model(g1.to(DEVICE), g2.to(DEVICE))
        all_logits.extend(logits.cpu().numpy().flatten())
        all_labels.extend(y.cpu().numpy().flatten())

# --- 4. CALCULATE & PLOT ---
# Flipped "boat" model: P(binder) = 1 - sigmoid(logit) = sigmoid(-logit).
# Evaluating sigmoid(-logit) in float64 avoids the float32 cancellation of
# 1 - sigmoid(logit) (large logits collapse to an exact 0.0, creating
# artificial ROC ties) and the overflow warnings from np.exp.
logits_arr = np.asarray(all_logits, dtype=np.float64)
probs = torch.sigmoid(torch.from_numpy(-logits_arr)).numpy()
preds = (probs > 0.5).astype(float)

mcc = matthews_corrcoef(all_labels, preds)
# Not named `auc`: that would shadow sklearn.metrics.auc used by later cells.
val_auc = roc_auc_score(all_labels, probs)
print(f"\n📊 FINAL RESULTS: MCC={mcc:.4f} | AUC={val_auc:.4f}")

cm = confusion_matrix(all_labels, preds)
plt.figure(figsize=(6, 5))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.title(f'Final Validation (n={len(all_labels):,}) | MCC: {mcc:.4f}')
plt.show()

In [None]:
# Sweep decision thresholds on the validation probabilities and keep the one
# with the highest MCC. The grid spans 0.05–0.95 so an optimum above 0.85 is
# not cut off by the search range; values are rounded so they print exactly.
THRESHOLD_GRID = np.round(np.linspace(0.05, 0.95, 19), 2)

best_mcc = -1
best_threshold = 0.5

for t in THRESHOLD_GRID:
    t_mcc = matthews_corrcoef(all_labels, (probs > t).astype(float))
    if t_mcc > best_mcc:
        best_mcc = t_mcc
        best_threshold = float(t)

print(f"🎯 Optimal Threshold: {best_threshold:.2f}")
print(f"🚀 Optimized MCC: {best_mcc:.4f}")

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_curve, auc, confusion_matrix

def generate_final_figures(labels, probabilities, threshold=0.85):
    """Draw the two-panel publication figure: ROC curve and confusion matrix.

    Args:
        labels: Ground-truth binary labels (0 = non-binder, 1 = binder).
        probabilities: Predicted probability of the positive class, aligned
            with ``labels``.
        threshold: Probability above which a pair is predicted as a binder.

    Side effects:
        Updates global matplotlib rcParams (font size/family), writes
        ``Final_GATv3_Metrics.pdf`` and ``Final_GATv3_Metrics.png`` to the
        working directory and shows the figure.
    """
    import numpy as np
    # Local import: a notebook variable named `auc` (e.g. an AUC score from the
    # validation cell) would otherwise shadow sklearn's function depending on
    # the order in which cells were run.
    from sklearn.metrics import auc as area_under_curve

    y_true = np.asarray(labels).astype(int)
    probabilities = np.asarray(probabilities)
    preds = (probabilities > threshold).astype(int)

    # rcParams only apply to artists created afterwards, so set them before
    # the figure/axes exist (otherwise tick labels keep the previous style).
    plt.rcParams.update({'font.size': 12, 'font.family': 'sans-serif'})
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))

    # --- Panel A: ROC Curve ---
    fpr, tpr, _ = roc_curve(y_true, probabilities)
    roc_auc = area_under_curve(fpr, tpr)

    axes[0].plot(fpr, tpr, color='#1f77b4', lw=3, label=f'ROC AUC = {roc_auc:.4f}')
    axes[0].plot([0, 1], [0, 1], color='grey', lw=1, linestyle='--')
    axes[0].set_xlim([0.0, 1.0])
    axes[0].set_ylim([0.0, 1.05])
    axes[0].set_xlabel('False Positive Rate (1 - Specificity)')
    axes[0].set_ylabel('True Positive Rate (Sensitivity)')
    axes[0].set_title('A. Receiver Operating Characteristic')
    axes[0].legend(loc="lower right", frameon=True)
    axes[0].grid(alpha=0.2)

    # --- Panel B: Confusion Matrix (raw counts) ---
    # labels=[0, 1] keeps the matrix 2x2 even if one class is never predicted,
    # so the two tick labels always line up with the cells.
    cm = confusion_matrix(y_true, preds, labels=[0, 1])
    tick_names = ['Non-Binder', 'Binder']
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=axes[1], cbar=False,
                xticklabels=tick_names, yticklabels=tick_names)

    axes[1].set_title(f'B. Confusion Matrix (Threshold = {threshold})')
    axes[1].set_xlabel('Predicted Label')
    axes[1].set_ylabel('True Label')

    fig.tight_layout()

    # Save files
    fig.savefig('Final_GATv3_Metrics.pdf', dpi=300, bbox_inches='tight')
    fig.savefig('Final_GATv3_Metrics.png', dpi=300, bbox_inches='tight')
    plt.show()

# Render and save the ROC + confusion-matrix figure at the chosen operating point.
generate_final_figures(labels=all_labels, probabilities=probs, threshold=0.85)

In [None]:
from sklearn.metrics import classification_report, f1_score, precision_score, recall_score

# Final operating point, taken from the MCC-vs-threshold sweep.
FINAL_THRESHOLD = 0.85

# 1. Binarise predictions at the chosen threshold
final_preds = (probs > FINAL_THRESHOLD).astype(float)

# 2. Compute components
precision = precision_score(all_labels, final_preds)
recall = recall_score(all_labels, final_preds)
f1 = f1_score(all_labels, final_preds)

# 3. Create a summary dataframe for the paper
metrics_data = {
    "Metric": ["ROC-AUC", "MCC", "Precision", "Recall (Sensitivity)", "F1-Score", "Accuracy"],
    "Value": [
        f"{roc_auc_score(all_labels, probs):.4f}",
        f"{matthews_corrcoef(all_labels, final_preds):.4f}",
        f"{precision:.4f}",
        f"{recall:.4f}",
        f"{f1:.4f}",
        f"{accuracy_score(all_labels, final_preds):.4f}"
    ]
}

metrics_df = pd.DataFrame(metrics_data)

# Markdown table (DataFrame.to_markdown needs the `tabulate` package)
print(f"### 📊 Final Validation Performance (Threshold = {FINAL_THRESHOLD})")
print(metrics_df.to_markdown(index=False))

# Optional: full per-class report
print("\n### 🧬 Detailed Per-Class Report")
print(classification_report(all_labels, final_preds, target_names=['Non-Binder (0)', 'Binder (1)']))

In [None]:
# MCC across a fine threshold grid. The marked optimum is the argmax of the
# plotted curve itself, so the line and its label cannot disagree with the data.
thresholds = np.arange(0.1, 0.95, 0.02)
mccs = [matthews_corrcoef(all_labels, (probs > t).astype(float)) for t in thresholds]
opt_threshold = float(thresholds[int(np.argmax(mccs))])

fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(thresholds, mccs, color='#e67e22', lw=2)
ax.axvline(opt_threshold, color='red', linestyle='--', label=f'Optimum ({opt_threshold:.2f})')
ax.set_title('MCC vs. Classification Threshold')
ax.set_xlabel('Threshold')
ax.set_ylabel('Matthews Correlation Coefficient')
ax.legend()
ax.grid(alpha=0.3)
fig.savefig('MCC_Threshold_Optimization.png', dpi=300)
plt.show()

In [None]:
from sklearn.metrics import precision_recall_curve
import matplotlib.pyplot as plt

def plot_precision_recall_vs_threshold(labels, probabilities, optimal_t=0.85):
    """Plot precision and recall against the decision threshold.

    A dotted red line and an arrow annotation mark ``optimal_t``. The figure
    is saved as PNG and PDF in the working directory, then shown. Also sets
    the global matplotlib font size to 12.
    """
    prec, rec, cutoffs = precision_recall_curve(labels, probabilities)

    plt.rcParams.update({'font.size': 12})
    fig, ax = plt.subplots(figsize=(10, 6))

    # precision/recall carry one extra trailing point with no threshold;
    # drop it so both curves align with `cutoffs`.
    ax.plot(cutoffs, prec[:-1], "b--", label="Precision", linewidth=2)
    ax.plot(cutoffs, rec[:-1], "g-", label="Recall (Sensitivity)", linewidth=2)

    # Mark the selected operating point.
    ax.axvline(x=optimal_t, color='red', linestyle=':', alpha=0.7)
    ax.annotate(
        f'Selected Threshold: {optimal_t}',
        xy=(optimal_t, 0.65),
        xytext=(optimal_t - 0.3, 0.4),
        arrowprops=dict(facecolor='black', shrink=0.05, width=1, headwidth=8),
    )

    ax.set_title('Precision and Recall vs. Decision Threshold')
    ax.set_xlabel('Threshold')
    ax.set_ylabel('Metric Score')
    ax.legend(loc="lower left")
    ax.grid(True, alpha=0.3)
    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1.05])

    # Publication exports.
    fig.savefig('Precision_Recall_Threshold_Tradeoff.png', dpi=300, bbox_inches='tight')
    fig.savefig('Precision_Recall_Threshold_Tradeoff.pdf', dpi=300, bbox_inches='tight')
    plt.show()

# Precision/recall trade-off figure around the chosen operating point.
plot_precision_recall_vs_threshold(labels=all_labels, probabilities=probs, optimal_t=0.85)

# gPPIpred App

In [None]:
import os
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import gradio as gr
import matplotlib.pyplot as plt
import seaborn as sns
import tempfile
from torch_geometric.data import Data, Batch
from torch_geometric.nn import GATv2Conv, global_mean_pool

# --- 1. CONFIGURATION & PHYSICS ---
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Five pre-scaled (~[0, 1]) physico-chemical descriptors per amino acid; must
# match the table used at training time. 'X' is the all-zero row for unknown /
# non-standard residues.
# NOTE(review): the meaning of each column is not recorded here.
AA_PHYSICS_5D_SCALED = {
    'A': [0.70, 0.17, 0.11, 0.40, 0.00],
    'R': [0.00, 0.67, 0.71, 1.00, 1.00],
    'N': [0.11, 0.32, 0.33, 0.33, 0.00],
    'D': [0.11, 0.30, 0.26, 0.00, 0.31],
    'C': [0.78, 0.29, 0.31, 0.29, 0.67],
    'Q': [0.11, 0.50, 0.44, 0.36, 0.00],
    'E': [0.11, 0.47, 0.37, 0.06, 0.34],
    'G': [0.46, 0.00, 0.00, 0.40, 0.00],
    'H': [0.14, 0.55, 0.56, 0.60, 0.48],
    'I': [1.00, 0.64, 0.45, 0.41, 0.00],
    'L': [0.92, 0.64, 0.45, 0.40, 0.00],
    'K': [0.07, 0.65, 0.54, 0.87, 0.84],
    'M': [0.66, 0.61, 0.54, 0.37, 0.00],
    'F': [0.81, 0.77, 0.72, 0.34, 0.00],
    'P': [0.32, 0.31, 0.32, 0.44, 0.00],
    'S': [0.41, 0.17, 0.15, 0.36, 0.00],
    'T': [0.42, 0.33, 0.26, 0.35, 0.00],
    'W': [0.40, 1.00, 1.00, 0.39, 0.00],
    'Y': [0.36, 0.79, 0.73, 0.36, 0.81],
    'V': [0.97, 0.48, 0.34, 0.40, 0.00],
    'X': [0.00, 0.00, 0.00, 0.00, 0.00],
}

# --- 2. ARCHITECTURE ---
class SiameseGAT_v3(nn.Module):
    """Siamese graph-attention network producing one interaction logit per protein pair.

    Both proteins go through the same weight-shared encoder (``forward_once``):
    a linear node embedding followed by ``num_layers`` residual blocks of
    GATv2Conv -> BatchNorm -> ELU, then mean pooling per graph. The two pooled
    embeddings are concatenated and scored by an MLP head.

    Args:
        in_channels: Per-residue input feature size.
        hidden_channels: Node embedding width; split over 4 attention heads,
            so it should be divisible by 4 for the residual add to line up.
        num_layers: Number of residual GATv2 blocks.

    Attribute names (node_lin, convs, batch_norms, fc) are the state_dict keys
    of saved checkpoints and must not be renamed.
    """

    def __init__(self, in_channels=5, hidden_channels=64, num_layers=8):
        super().__init__()
        per_head = hidden_channels // 4
        self.node_lin = nn.Linear(in_channels, hidden_channels)
        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        for _layer in range(num_layers):
            # 4 concatenated heads of width hidden/4 -> back to hidden width.
            self.convs.append(GATv2Conv(hidden_channels, per_head, heads=4))
            self.batch_norms.append(nn.BatchNorm1d(hidden_channels))

        self.fc = nn.Sequential(
            nn.Linear(2 * hidden_channels, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )

    def forward_once(self, data):
        """Encode a batch of graphs into one pooled embedding per graph."""
        h = self.node_lin(data.x)
        for layer_idx in range(len(self.convs)):
            update = self.convs[layer_idx](h, data.edge_index)
            update = F.elu(self.batch_norms[layer_idx](update))
            h = h + update  # residual connection
        return global_mean_pool(h, data.batch)

    def forward(self, g1, g2):
        """Return the (num_pairs, 1) interaction logits for graph batches ``g1`` and ``g2``."""
        pair_embedding = torch.cat((self.forward_once(g1), self.forward_once(g2)), dim=1)
        return self.fc(pair_embedding)

# --- 3. HOTSPOT LOGIC ---
def get_hotspots(model, g_bait, g_prey, bait_seq, prey_seq, top_n=10):
    """Rank residues of a bait/prey pair by input-gradient saliency.

    The pair logit is back-propagated to the per-residue input features; a
    residue's importance is the sum of absolute gradients over its feature
    vector.

    Args:
        model: Siamese model taking two PyG batches and returning logits.
        g_bait, g_prey: Single-protein graphs with ``x`` and ``edge_index``.
            They are not modified.
        bait_seq, prey_seq: Residue strings used to label positions.
        top_n: Number of residues to report per protein (0 -> empty string).

    Returns:
        Two strings like ``"K120, R248"`` (residue letter + 1-based position),
        most important first.
    """
    def _leaf_copy(graph):
        # Fresh leaf tensor for x: the caller's graph (reused for every prey in
        # run_prediction) is not mutated, and gradients from earlier calls can
        # never accumulate into this one.
        return Data(x=graph.x.detach().clone().requires_grad_(True), edge_index=graph.edge_index)

    bait_graph = _leaf_copy(g_bait)
    prey_graph = _leaf_copy(g_prey)

    # enable_grad: saliency must work even if the caller is inside no_grad().
    with torch.enable_grad():
        logits = model(Batch.from_data_list([bait_graph]), Batch.from_data_list([prey_graph]))
        model.zero_grad()
        logits.backward()

    def extract_top(grad, seq):
        importance = grad.abs().sum(dim=1).cpu().numpy()
        # Descending order, then take from the front: top_n=0 yields nothing
        # (slicing [-0:] from the back would have returned every residue).
        indices = np.argsort(importance)[::-1][:top_n]
        return ", ".join([f"{seq[i]}{i+1}" for i in indices if i < len(seq)])

    bait_hot = extract_top(bait_graph.x.grad, bait_seq)
    prey_hot = extract_top(prey_graph.x.grad, prey_seq)
    # Parameter gradients from this saliency pass are not needed afterwards.
    model.zero_grad()
    return bait_hot, prey_hot

# --- 4. DATA UTILS ---
CACHE_DIR = "protein_cache"
os.makedirs(CACHE_DIR, exist_ok=True)

# Seconds to wait on each UniProt REST request before giving up.
UNIPROT_TIMEOUT = 30


def _normalize_uniprot_id(uniprot_id):
    """Strip and upper-case a user-supplied ID; reject anything but [A-Z0-9_-].

    The ID comes from a free-text UI field and is interpolated into both a REST
    URL and a cache file name, so path separators, dots, spaces and URL
    metacharacters must never get through.

    Raises:
        ValueError: if the ID is empty or contains other characters.
    """
    import re

    key = uniprot_id.strip().upper()
    if not re.fullmatch(r"[A-Z0-9_-]+", key):
        raise ValueError(f"Invalid UniProt ID: {uniprot_id!r}")
    return key


def _chain_edge_index(length):
    """Edge index (2, E), int64, of a linear residue chain with self-loops.

    Contains i->i-1, i->i and i->i+1 for every residue, in the same row-major
    order that ``nonzero()`` yields on the dense tridiagonal adjacency matrix,
    but without allocating that O(length^2) matrix (a 30,000-residue protein
    would need ~7 GB as float64).
    """
    idx = np.arange(length)
    rows = np.concatenate([idx[1:], idx, idx[:-1]])
    cols = np.concatenate([idx[:-1], idx, idx[1:]])
    order = np.lexsort((cols, rows))  # primary key: row, secondary: column
    return torch.from_numpy(np.stack([rows[order], cols[order]])).long()


def get_protein_data(uniprot_id):
    """Return ``(graph, sequence, protein_name)`` for a UniProt entry.

    Results are cached as ``<CACHE_DIR>/<ID>.pt``. On a cache miss the sequence
    comes from the UniProt REST FASTA endpoint and the name from the JSON
    endpoint ("Unknown Protein" if that request fails, "Unknown" if the entry
    has no recommended full name). The graph is a linear residue chain
    (neighbours plus self-loops) with the 5-D physics vectors as node features.

    Raises:
        ValueError: invalid ID, entry not found, or empty sequence.
        requests.RequestException: network failure or timeout.
    """
    uniprot_id = _normalize_uniprot_id(uniprot_id)
    cache_path = os.path.join(CACHE_DIR, f"{uniprot_id}.pt")
    if os.path.exists(cache_path):
        # The cache holds a pickled (Data, str, str) tuple written below;
        # weights_only=False is required to unpickle the Data object (PyTorch
        # >= 2.6 defaults to weights_only=True). Full unpickling can execute
        # code, so only ever load files this app wrote itself.
        return torch.load(cache_path, weights_only=False)

    # Fetch FASTA for sequence
    res_fasta = requests.get(f"https://rest.uniprot.org/uniprotkb/{uniprot_id}.fasta", timeout=UNIPROT_TIMEOUT)
    if res_fasta.status_code != 200:
        raise ValueError(f"ID {uniprot_id} not found.")

    # FASTA: header line first, then the sequence wrapped over several lines.
    seq = "".join(res_fasta.text.splitlines()[1:]).strip()
    if not seq:
        raise ValueError(f"ID {uniprot_id} returned an empty sequence.")

    # Fetch JSON for metadata (Protein Name)
    p_name = "Unknown Protein"
    res_json = requests.get(f"https://rest.uniprot.org/uniprotkb/{uniprot_id}.json", timeout=UNIPROT_TIMEOUT)
    if res_json.status_code == 200:
        json_data = res_json.json()
        p_name = json_data.get('proteinDescription', {}).get('recommendedName', {}).get('fullName', {}).get('value', "Unknown")

    edge_index = _chain_edge_index(len(seq))
    x = torch.tensor([AA_PHYSICS_5D_SCALED.get(a, AA_PHYSICS_5D_SCALED['X']) for a in seq], dtype=torch.float32)
    data = Data(x=x, edge_index=edge_index)

    package = (data, seq, p_name)
    torch.save(package, cache_path)
    return package

# --- 5. PREDICTION CORE ---
def run_prediction(bait_id, prey_raw, threshold):
    """Score one bait protein against one or more prey proteins for the Gradio UI.

    Args:
        bait_id: UniProt ID of the bait protein.
        prey_raw: Prey IDs separated by commas and/or whitespace.
        threshold: Probability at or above which a pair is reported as binding.

    Returns:
        ``(figure, results DataFrame, CSV path, status message)``. On any
        failure the first three are None and the status carries the error.
    """
    try:
        bait_key = bait_id.strip().upper()
        if not bait_key:
            raise ValueError("Please enter a bait UniProt ID.")
        prey_ids = [p.strip().upper() for p in prey_raw.replace(',', ' ').split() if p.strip()]
        if not prey_ids:
            raise ValueError("Please enter at least one prey UniProt ID.")

        g_bait, bait_seq, bait_name = get_protein_data(bait_key)
        g_bait = g_bait.to(DEVICE)
        model.eval()

        results = []
        for p_id in prey_ids:
            g_prey, prey_seq, prey_name = get_protein_data(p_id)
            g_prey = g_prey.to(DEVICE)

            with torch.no_grad():
                logits = model(Batch.from_data_list([g_bait]), Batch.from_data_list([g_prey]))
                prob = torch.sigmoid(logits).item()

            b_hot, p_hot = get_hotspots(model, g_bait, g_prey, bait_seq, prey_seq)

            results.append({
                "Bait ID": bait_key,
                "Bait Name": bait_name,
                "Prey ID": p_id,
                "Prey Name": prey_name,
                "Probability": round(prob, 4),
                "Binds": "Yes" if prob >= threshold else "No",
                "Bait Hotspots": b_hot,
                "Prey Hotspots": p_hot,
                "Bait Sequence": bait_seq,
                "Prey Sequence": prey_seq
            })
            torch.cuda.empty_cache()

        df = pd.DataFrame(results)

        fig, ax = plt.subplots(figsize=(8, max(4, 0.5 * len(df))))
        sns.barplot(data=df, x='Probability', y='Prey ID', palette='viridis', ax=ax)
        ax.axvline(threshold, color='red', linestyle='--', label=f'Threshold {threshold}')
        ax.set_title(f"PPI Analysis for Bait: {bait_key}")
        ax.set_xlim(0, 1)
        ax.legend()
        fig.tight_layout()
        # Gradio renders the Figure object itself; detaching it from pyplot
        # stops every request from leaving another open figure in memory.
        plt.close(fig)

        # A unique file per request, so concurrent users cannot overwrite each
        # other's download.
        fd, tmp_csv = tempfile.mkstemp(prefix="ppi_results_", suffix=".csv")
        os.close(fd)
        df.to_csv(tmp_csv, index=False)
        return fig, df, tmp_csv, f"✅ Analyzed {len(df)} preys."
    except Exception as e:
        # Deliberate catch-all: every failure is reported in the UI status box.
        return None, None, None, f"❌ Error: {str(e)}"

# --- 6. INITIALIZATION & UI ---
model = SiameseGAT_v3(in_channels=5).to(DEVICE)
MODEL_PATH = "GATv3_FINAL_CORRECTED.pth"
MODEL_LOADED = os.path.exists(MODEL_PATH)
if MODEL_LOADED:
    model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True))
else:
    # Without the checkpoint the network keeps its random initialisation and
    # every probability it reports is meaningless — make that impossible to miss.
    print(f"⚠️ WARNING: {MODEL_PATH} not found – serving an UNTRAINED (randomly initialised) model.")
model.eval()

custom_theme = gr.themes.Soft(font=[gr.themes.GoogleFont("Inter"), "sans-serif"])

with gr.Blocks() as demo:
    gr.Markdown("# 🧬 gPPIpred v3 – Advanced Siamese GAT Predictor")
    if not MODEL_LOADED:
        gr.Markdown(f"**⚠️ Model weights `{MODEL_PATH}` were not found: predictions come from an untrained network and are meaningless.**")

    with gr.Row():
        with gr.Column(scale=2):
            b_in = gr.Textbox(label="Bait ID (UniProt)", value="P04637")
            p_in = gr.Textbox(label="Prey IDs (comma separated)", value="O15169, P00519")
        with gr.Column(scale=1):
            t_slider = gr.Slider(0.1, 0.95, value=0.90, step=0.01, label="Sensitivity Threshold")
            btn = gr.Button("🔍 Predict", variant="primary")
            status = gr.Textbox(label="Status")

    with gr.Tabs():
        with gr.TabItem("Chart"):
            plot = gr.Plot()
        with gr.TabItem("Data"):
            table = gr.Dataframe()
            csv_file = gr.File(label="Download CSV")

    btn.click(run_prediction, inputs=[b_in, p_in, t_slider], outputs=[plot, table, csv_file, status])

if __name__ == "__main__":
    # NOTE(review): `theme` is passed to launch(); older Gradio releases only
    # accept it as gr.Blocks(theme=...). gradio is unpinned in the install
    # cell — confirm the installed version supports this call.
    demo.launch(theme=custom_theme)