Cell 1: Install dependencies

In [1]:
pip install numpy torch pandas matplotlib scikit-learn fair-esm

Note: you may need to restart the kernel to use updated packages.




Cell 2: Imports & Environment Setup

In [2]:
import os, sys
import numpy as np
import esm
import torch
import tkinter as tk
from tkinter import scrolledtext, messagebox
import pandas as pd
import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from torch.utils.data import Dataset, DataLoader
import numpy as np
from typing import List
import pickle
import os
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_fscore_support
import itertools
import csv
import collections

from sklearn.metrics import classification_report
# Allow imports of modules located in the notebook's working directory by
# placing that directory at the front of sys.path.
THIS_DIR = os.getcwd()  # Modify if needed
# Only insert when it is not already the first entry, so re-running this cell
# does not keep stacking duplicate entries onto sys.path.
if sys.path[:1] != [THIS_DIR]:
    sys.path.insert(0, THIS_DIR)


Cell 3: Import your model code

In [3]:
# Assuming these are available under src/
from grid_search_scripts.grid_search_binding_predictor_esm_without_weight import (
    BindingPredictor,
    load_esm_model,
    get_esm_embeddings,
)



 Cell 4: Define model path with weight

In [8]:
# Settings for the checkpoint saved as "model_emb_best_weight.pt".
# Adjust the path if the saved model lives somewhere else.
EMBEDDING_SIZE = 640  # forwarded to load_esm_model(...) and BindingPredictor(emb_dim=...)
PROJECT_ROOT = os.path.abspath(os.path.join(THIS_DIR, os.pardir))  # parent of THIS_DIR
MODEL_PATH = os.path.join(PROJECT_ROOT, "saved_models", "model_emb_best_weight.pt")


 Cell 4 (alternative): [**Optional**] Define model path **without** weight (running this cell overrides the settings from the previous cell)

In [11]:
# Settings for the checkpoint saved as "model_emb_best_without_weight.pt".
# These assignments reuse the names PROJECT_ROOT / MODEL_PATH / EMBEDDING_SIZE,
# so whichever configuration cell runs last is the one the app will use.
# Adjust the path if the saved model lives somewhere else.
EMBEDDING_SIZE = 1280  # forwarded to load_esm_model(...) and BindingPredictor(emb_dim=...)
PROJECT_ROOT = os.path.abspath(os.path.join(THIS_DIR, os.pardir))  # parent of THIS_DIR
MODEL_PATH = os.path.join(PROJECT_ROOT, "saved_models", "model_emb_best_without_weight.pt")


Cell 5: GUI App, Prediction Logic & Output Handler

In [12]:
class SequencePredictorApp:
    """Small Tkinter front-end for per-residue contact prediction.

    The window offers a text entry for an amino-acid sequence and a "Predict"
    button. Pressing the button embeds the sequence with the ESM backbone,
    scores it with the trained ``BindingPredictor`` and lists, for every
    position, the sigmoid score together with a marker on the ``TOP_K``
    highest-scoring residues.

    Relies on the notebook-level names ``EMBEDDING_SIZE`` and ``MODEL_PATH``
    and on ``load_esm_model`` / ``get_esm_embeddings`` / ``BindingPredictor``.
    """

    # Number of highest-scoring residues flagged as predicted contacts.
    TOP_K = 5
    # Layer index handed to ``get_esm_embeddings``.
    # NOTE(review): presumably has to match the layer used when the classifier
    # was trained - confirm against the training script.
    ESM_LAYER = 6

    def __init__(self, root):
        """Load the models and build the widgets inside ``root`` (a ``tk.Tk``)."""
        self.root = root
        root.title("NES–CRM1 Contact Predictor")

        # 1) load ESM backbone (once); this also sets self.device, which
        #    _load_classifier() needs, so the order of these two steps matters.
        self.esm_model, self.alphabet, self.batch_converter, self.device = \
            load_esm_model(embedding_size=EMBEDDING_SIZE)

        # 2) instantiate & load the trained classifier
        self.classifier = self._load_classifier(MODEL_PATH)

        # ─── build the simple UI ────────────────────────────────
        tk.Label(root, text="Enter amino-acid sequence:").pack(padx=10, pady=(10, 0))
        self.seq_var = tk.StringVar()
        tk.Entry(root, textvariable=self.seq_var, width=60).pack(padx=10, pady=5)
        tk.Button(root, text="Predict", command=self.on_predict).pack(padx=10, pady=5)
        tk.Label(root, text="Per-residue contact probability:").pack(padx=10, pady=(10, 0))

        # Read-only result pane; _set_output() temporarily re-enables it to write.
        self.output = scrolledtext.ScrolledText(root, width=70, height=15, state="disabled")
        self.output.pack(padx=10, pady=5)

    def _load_classifier(self, path):
        """Build a ``BindingPredictor``, load its weights from ``path`` and
        return it on ``self.device`` in eval mode.

        Errors raised by ``torch.load`` / ``load_state_dict`` (missing file,
        mismatching state dict) propagate to the caller.
        """
        model = BindingPredictor(emb_dim=EMBEDDING_SIZE, hidden_dim=128)
        # SECURITY: torch.load unpickles the file; only load checkpoints from a
        # trusted source.
        state = torch.load(path, map_location=self.device)
        model.load_state_dict(state)
        model.to(self.device)
        model.eval()
        return model

    def on_predict(self):
        """Button callback: read the sequence, predict, and show the result.

        Each output line shows the position, the residue letter and its
        sigmoid score; the ``TOP_K`` best-scoring residues are additionally
        tagged with ``[contact]``. Any exception raised during prediction is
        shown in the output pane instead of propagating into Tk.
        """
        # Drop all whitespace (not just leading/trailing): blanks are not
        # residues and would otherwise be listed as positions in the output.
        seq = "".join(self.seq_var.get().split()).upper()
        if not seq:
            messagebox.showwarning("Input Error", "Please enter a sequence.")
            return

        try:
            probs = self._predict_probs(seq)
            mask = self.topk_predictions(probs, k=self.TOP_K)
            # NOTE(review): zip() stops at the shortest input, so this assumes
            # the model returns exactly one score per residue - confirm.
            text = "\n".join(
                f"Pos {i+1} ({aa}): {p:.3f}" + ("  [contact]" if hit else "")
                for i, (aa, p, hit) in enumerate(zip(seq, probs, mask))
            )
            self._set_output(text)
        except Exception as e:
            self._set_output(f"Prediction error:\n{e}")

    def topk_predictions(self, pred_probs, k=5):
        """Return a 0/1 array shaped like ``pred_probs`` with ones at the
        positions of the ``k`` largest values.

        ``k`` larger than the number of entries marks every entry; ``k <= 0``
        marks none.
        """
        pred = np.zeros_like(pred_probs)
        if k <= 0:
            # Without this guard, [-0:] would select the whole array.
            return pred
        topk_idx = np.argsort(pred_probs)[-k:]
        pred[topk_idx] = 1
        return pred

    def _predict_probs(self, seq: str):
        """Return the sigmoid scores of the classifier for ``seq`` as a NumPy array."""
        # 1) Get ESM embedding for the single sequence
        emb_list = get_esm_embeddings(
            [seq],
            self.esm_model,
            self.alphabet,
            self.batch_converter,
            self.device,
            layer=self.ESM_LAYER,
        )
        # as_tensor avoids an extra copy (and the copy-construct warning) when
        # the helper already returns a tensor; unsqueeze adds the batch axis.
        emb = torch.as_tensor(emb_list[0], dtype=torch.float32).unsqueeze(0).to(self.device)

        # 2) Forward pass through the classifier, no gradients needed
        with torch.no_grad():
            logits = self.classifier(emb)
            probs = torch.sigmoid(logits)[0, :].cpu().numpy()
        return probs

    def _predict(self, seq: str):
        """
        Discrete prediction (0/1) per residue.

        The ``TOP_K`` residues with the highest sigmoid score are set to 1,
        all others to 0.
        """
        probs = self._predict_probs(seq)
        return self.topk_predictions(probs, k=self.TOP_K)

    def _set_output(self, text: str):
        """Replace the content of the read-only output pane with ``text``."""
        self.output.config(state="normal")
        self.output.delete("1.0", tk.END)
        self.output.insert(tk.END, text)
        self.output.config(state="disabled")



Cell 6: Launch the App

In [13]:
if __name__ == "__main__":
    root = tk.Tk()
    try:
        # Keep a reference to the app for interactive inspection in the notebook.
        app = SequencePredictorApp(root)
    except Exception:
        # Model loading happens in the constructor; if it fails, close the
        # already-created window instead of leaving it orphaned, then re-raise.
        root.destroy()
        raise
    # Blocks this cell until the window is closed.
    root.mainloop()