# **S-PLM v2: Quickstart**

This notebook is a **usage example** of **S-PLM v2**.

* **Purpose:**

    1. Process PDB structures into the standardized inputs expected by our model.
    2. Generate **protein-level** and **residue-level** embeddings.
    3. Run sample evaluations and export metrics/logs.
* **Checkpoint:** An S-PLM v2 `.pth` checkpoint. Download from the provided [SharePoint link](https://mailmissouri-my.sharepoint.com/:u:/g/personal/wangdu_umsystem_edu/EUZ74fO3NOxHjTvc6uvKwDsB5fELaaw-oiPHFU9CJky_hg?e=4phwL0).



### **Environment Setup**

We **recommend** using an NVIDIA **A100** in Colab; other GPUs or a CPU runtime will work but may be slower or run out of memory.
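
To see which accelerator the runtime actually provides, a quick check like the following can help. This is a minimal sketch that shells out to `nvidia-smi`; the helper name `list_gpus` is illustrative and not part of S-PLM.

```python
import shutil
import subprocess


def list_gpus():
    """Return one 'name, total memory' string per visible NVIDIA GPU.

    Returns an empty list when nvidia-smi is missing or reports an error,
    which usually means the runtime is CPU-only.
    """
    exe = shutil.which("nvidia-smi")
    if exe is None:
        return []
    try:
        out = subprocess.run(
            [exe, "--query-gpu=name,memory.total", "--format=csv,noheader"],
            capture_output=True, text=True, check=True, timeout=30,
        ).stdout
    except (subprocess.SubprocessError, OSError):
        return []
    return [line.strip() for line in out.splitlines() if line.strip()]


gpus = list_gpus()
print(gpus if gpus else "No NVIDIA GPU detected; running on CPU.")
```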


In [None]:
import os
from google.colab import drive

if not os.path.ismount('/content/drive'):
    drive.mount('/content/drive')

In [None]:
# Clone S-PLM
!git clone -q https://github.com/Yichuan0712/SPLM-V2-GVP /content/SPLMv2

# Install minimal deps
!pip install 'git+https://github.com/facebookresearch/esm.git' -q
!pip install 'git+https://github.com/katsura-jp/pytorch-cosine-annealing-with-warmup' -q
!pip install biopython -q

In [None]:
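# Pin PyTorch to a CUDA 12.1 build so the PyG wheels below match it.
!pip install -q "torch==2.5.0" "torchvision==0.20.0" "torchaudio==2.5.0" \
  --index-url https://download.pytorch.org/whl/cu121
import torch
TORCH = "2.5.0"
# torch.version.cuda is a string such as "12.1"; it is None on CPU-only builds.
CUDA = "cu" + torch.version.cuda.replace(".", "")
whl_url = f"https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html"
print("Using wheel URL:", whl_url)
# Install the PyG companion packages from the wheel index matching this torch/CUDA pair.
!pip install -q pyg_lib torch-scatter torch-sparse torch-cluster torch-spline-conv \
    -f {whl_url}
!pip install -q torch-geometric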

In [None]:
import os
os.chdir('/content/SPLMv2')
!git pull origin main

### **Prepare Checkpoint**

1. **Download the model** from the provided **[SharePoint link](https://mailmissouri-my.sharepoint.com/:u:/g/personal/wangdu_umsystem_edu/EUZ74fO3NOxHjTvc6uvKwDsB5fELaaw-oiPHFU9CJky_hg?e=4phwL0)** to your local machine.
2. **Upload to your Colab runtime** (Files pane → Upload to session storage), then set:




In [None]:
CHECKPOINT_PATH = "/content/checkpoint_0280000_gvp.pth"

3. **Faster alternative to step 2 (recommended):** Save the checkpoint to the top level of your Google Drive, then mount Drive and copy the file into the Colab runtime.


In [None]:
from google.colab import drive
import shutil

# Mount Drive (remounting if it is already mounted), then copy the checkpoint
# from the top level of My Drive into the local session storage.
drive.mount('/content/drive', force_remount=True)
CHECKPOINT_PATH = "/content/checkpoint_0280000_gvp.pth"
shutil.copy("/content/drive/MyDrive/checkpoint_0280000_gvp.pth", CHECKPOINT_PATH)
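
Before launching the longer steps, it may be worth confirming that the checkpoint actually arrived. The sketch below only checks that the file exists and is non-empty; `check_checkpoint` is a hypothetical helper, and the fallback path simply mirrors the one used in this notebook.

```python
from pathlib import Path


def check_checkpoint(path):
    """Return the file size in MiB, or raise if the file is missing or empty."""
    p = Path(path)
    if not p.is_file():
        raise FileNotFoundError(f"Checkpoint not found: {p}")
    size = p.stat().st_size
    if size == 0:
        raise ValueError(f"Checkpoint is empty (interrupted upload?): {p}")
    return size / 2**20


# In the notebook, CHECKPOINT_PATH is normally already defined by an earlier cell.
checkpoint = globals().get("CHECKPOINT_PATH", "/content/checkpoint_0280000_gvp.pth")
try:
    print(f"Checkpoint OK: {check_checkpoint(checkpoint):.1f} MiB")
except (FileNotFoundError, ValueError) as err:
    print(err)
```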

### **Generate Sequence Embeddings**

Use the GVP model to generate embeddings from FASTA sequences, with optional truncation and residue-level outputs.

* **Standard run:** reads a `.fasta` file and writes **protein-level** embeddings to a `.pkl` file.
* **Truncated run:** sets `--truncate_inference 1 --max_length_inference 1022` to cap long sequences at a maximum length of 1022.
* **Residue-level run:** adds `--residue_level` to write per-residue embeddings.
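
To see which entries a truncated run would actually affect, the FASTA file can be scanned for sequences above the limit. This is a minimal sketch with a hand-rolled parser; `read_fasta` and `long_sequences` are illustrative names, and treating 1022 as a residue count is an assumption based on the flag value.

```python
def read_fasta(lines):
    """Yield (header, sequence) pairs from an iterable of FASTA lines."""
    header, chunks = None, []
    for line in lines:
        line = line.strip()
        if not line:
            continue
        if line.startswith(">"):
            if header is not None:
                yield header, "".join(chunks)
            header, chunks = line[1:], []
        elif header is not None:
            chunks.append(line)
    if header is not None:
        yield header, "".join(chunks)


def long_sequences(lines, max_length=1022):
    """Return {header: length} for sequences longer than max_length."""
    return {h: len(s) for h, s in read_fasta(lines) if len(s) > max_length}


# Tiny in-memory example; in the notebook, pass an open file handle instead,
# e.g. open("/content/SPLMv2/dataset/protein.fasta").
example = [">short", "MKT", "AYI", ">long", "A" * 1500]
print(long_sequences(example))
# {'long': 1500}
```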

**Inputs:** `--input_seq` (FASTA), `--config_path`, `--checkpoint_path`.

**Outputs:** pickled embeddings in the working directory (per protein or per residue, depending on flags).
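
The layout of the pickled object is not described here, so a generic summary is a safe first look once a run has finished. The sketch below walks nested dicts, lists, and tuples and reports the `shape` of anything array-like; `summarize` is a hypothetical helper, not part of the repository, and the example data is a stand-in.

```python
def summarize(obj, depth=0, max_items=3):
    """Return indented lines describing a nested container of arrays."""
    pad = "  " * depth
    shape = getattr(obj, "shape", None)
    if shape is not None:  # NumPy arrays and torch tensors both expose .shape
        return [f"{pad}{type(obj).__name__} shape={tuple(shape)}"]
    if isinstance(obj, dict):
        lines = [f"{pad}dict with {len(obj)} keys"]
        for key in list(obj)[:max_items]:
            lines.append(f"{pad}  key {key!r}:")
            lines.extend(summarize(obj[key], depth + 2, max_items))
        return lines
    if isinstance(obj, (list, tuple)):
        lines = [f"{pad}{type(obj).__name__} of length {len(obj)}"]
        for item in obj[:max_items]:
            lines.extend(summarize(item, depth + 1, max_items))
        return lines
    return [f"{pad}{type(obj).__name__}: {obj!r}"]


# Stand-in data; in the notebook, pass the object returned by pickle.load.
class FakeArray:
    shape = (4,)


print("\n".join(summarize({"protein_1": FakeArray(), "ids": ["a", "b"]})))
```

Only the first `max_items` entries of each container are shown, which keeps the output readable for files holding many proteins.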


In [None]:
import os
os.chdir('/content/SPLMv2')

# standard run
!python3 -m utils.generate_seq_embedding --input_seq /content/SPLMv2/dataset/protein.fasta \
  --config_path /content/SPLMv2/configs/config_plddtallweight_noseq_rotary_foldseek.yaml \
  --checkpoint_path /content/checkpoint_0280000_gvp.pth \
  --result_path ./

In [None]:
import os
os.chdir('/content/SPLMv2')

# truncate_inference with max_length_inference=1022
!python3 -m utils.generate_seq_embedding --input_seq /content/SPLMv2/dataset/protein.fasta \
--config_path /content/SPLMv2/configs/config_plddtallweight_noseq_rotary_foldseek.yaml \
--checkpoint_path /content/checkpoint_0280000_gvp.pth \
--result_path ./ --out_file truncate_protein_embeddings.pkl \
--truncate_inference 1 --max_length_inference 1022

import pickle
with open('truncate_protein_embeddings.pkl', 'rb') as f:
    data = pickle.load(f)

In [None]:
import os
os.chdir('/content/SPLMv2')

# residue_level representations
!python3 -m utils.generate_seq_embedding --input_seq /content/SPLMv2/dataset/protein.fasta \
--config_path /content/SPLMv2/configs/config_plddtallweight_noseq_rotary_foldseek.yaml \
--checkpoint_path /content/checkpoint_0280000_gvp.pth \
--result_path ./ --out_file truncate_protein_residue_embeddings.pkl \
--truncate_inference 1 --max_length_inference 1022 --residue_level

import pickle
with open('truncate_protein_residue_embeddings.pkl', 'rb') as f:
    data = pickle.load(f)

### **Preprocess PDB**

First preprocess your PDB files using the provided script; only the resulting HDF5 files can be fed into the S-PLM v2 GVP model.


In [None]:
!python /content/SPLMv2/data/preprocess_pdb.py --data /content/SPLMv2/dataset/CATH_4_3_0_non-rep_pdbs/ --save_path /content/CATH_4_3_0_non-rep_gvp/ --max_workers 4

### **Generate Structure Embeddings**

Use the GVP model to produce **residue-level structure embeddings** from the **preprocessed HDF5** inputs. The cell below saves them to `protein_struct_embeddings.pkl`, then loads and prints the result for a quick inspection.

**Inputs:** `--hdf5_path` (preprocessed data), `--config_path`, `--checkpoint_path`.

**Output:** `protein_struct_embeddings.pkl` in the current directory (embeddings per protein/chain).

**Note:** You **must preprocess** the PDB files first; the model only accepts the processed HDF5 tensors.
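
One quick way to confirm that preprocessing produced files is to count the contents of the output directory by extension. This is only a sketch: `count_by_suffix` is an illustrative helper, and the extension written by the preprocessing script is not stated here, so the code simply reports whatever it finds.

```python
from collections import Counter
from pathlib import Path


def count_by_suffix(directory):
    """Count regular files under `directory` (recursively) by file extension."""
    root = Path(directory)
    if not root.is_dir():
        return Counter()
    return Counter(
        p.suffix.lower() or "<none>" for p in root.rglob("*") if p.is_file()
    )


# Output directory passed as --save_path to the preprocessing script.
counts = count_by_suffix("/content/CATH_4_3_0_non-rep_gvp/")
print(dict(counts) if counts else "No files found; has preprocessing run?")
```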


In [None]:
import os
os.chdir('/content/SPLMv2')
!python -m utils.generate_struct_embedding \
  --hdf5_path /content/CATH_4_3_0_non-rep_gvp/ \
  --config_path /content/SPLMv2/configs/config_plddtallweight_noseq_rotary_foldseek.yaml \
  --checkpoint_path /content/checkpoint_0280000_gvp.pth \
  --result_path ./ \
  --residue_level

import pickle
with open('protein_struct_embeddings.pkl', 'rb') as f:
    print(pickle.load(f))

### **General clustering evaluation (CATH / Kinase)**

We evaluate embedding quality using clustering-based analyses. The evaluation supports both structure embeddings and sequence embeddings. We report both visualizations (t-SNE scatter plots) and quantitative metrics (Calinski–Harabasz, ARI, silhouette).


### Inputs

* `checkpoint_path`: path to the pretrained model checkpoint (`.pth`)
* `config_path`: path to the YAML config used for the checkpoint
* Path to the evaluation dataset; the flag and the expected format depend on the task:

  * `cath_struct` (`--cath_path`): preprocessed CATH HDF5 directory, e.g. `dataset/CATH_4_3_0_non-rep_h5/`
  * `cath_seq` (`--cath_seq`): CATH FASTA file with CATH codes in headers (e.g., `1.10.10.2080|cath|...`)
  * `kinase_seq` (`--kinase_path`): TSV file containing kinase metadata and sequences (e.g., `Kinase_group`, `Kinase_domain`)
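
A CATH code such as `1.10.10.2080` is hierarchical: each additional dot-separated field narrows the classification (Class, Architecture, Topology/fold, Homologous superfamily). The sketch below shows one way labels at the first three levels could be derived from a header in the format quoted above. How the evaluation scripts do this internally is not shown here, and `cath_labels` is a hypothetical helper.

```python
def cath_labels(header):
    """Split a 'C.A.T.H|...' FASTA header into labels at three granularities."""
    code = header.lstrip(">").split("|", 1)[0].strip()
    parts = code.split(".")
    if len(parts) < 3 or not all(p.isdigit() for p in parts):
        raise ValueError(f"Not a CATH code: {code!r}")
    return {
        "class": parts[0],
        "architecture": ".".join(parts[:2]),
        "fold": ".".join(parts[:3]),
    }


print(cath_labels("1.10.10.2080|cath|..."))
# {'class': '1', 'architecture': '1.10', 'fold': '1.10.10'}
```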

### What it does
* **Computes embeddings** for all samples in the dataset.
* **Runs clustering evaluation** at one or more label granularities (e.g., CATH Class / Architecture / Fold, or Kinase Group).
* **Generates visualizations**:

  * Projects embeddings to 2D using t-SNE and saves scatter plots colored by ground-truth labels.
* **Computes clustering metrics**:

  * Calinski–Harabasz score (full space and t-SNE 2D)
  * Adjusted Rand Index (ARI) using k-means on the t-SNE space
  * Silhouette score in the full embedding space
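* **Saves outputs**: the plots are written to a results folder under `configs/` named after the config file (for example `configs/config_plddtallweight_noseq_rotary_foldseek/CATH_test_release/`), which is where the display cells in this section load them from.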
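
The metrics listed above are all available in scikit-learn. The following is a minimal sketch on synthetic data, not the repository's evaluation code: `clustering_report` is a hypothetical helper, and the t-SNE and k-means settings (perplexity, seed, `n_init`) are illustrative choices rather than the values used by the scripts.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE
from sklearn.metrics import (
    adjusted_rand_score,
    calinski_harabasz_score,
    silhouette_score,
)


def clustering_report(embeddings, labels, perplexity=10.0, seed=0):
    """Compute the four clustering metrics for one label granularity."""
    labels = np.asarray(labels)
    # 2D projection used for both the scatter plot and the t-SNE-space metrics.
    coords = TSNE(n_components=2, perplexity=perplexity,
                  init="pca", random_state=seed).fit_transform(embeddings)
    # k-means with one cluster per ground-truth label, run in t-SNE space.
    k = len(np.unique(labels))
    predicted = KMeans(n_clusters=k, n_init=10, random_state=seed).fit_predict(coords)
    return {
        "ch_full": calinski_harabasz_score(embeddings, labels),
        "ch_tsne": calinski_harabasz_score(coords, labels),
        "ari_tsne_kmeans": adjusted_rand_score(labels, predicted),
        "silhouette_full": silhouette_score(embeddings, labels),
    }


# Synthetic stand-in: three well-separated groups of 20 points in 16 dimensions.
rng = np.random.default_rng(0)
centers = rng.normal(scale=10.0, size=(3, 16))
X = np.vstack([c + rng.normal(size=(20, 16)) for c in centers])
y = np.repeat([0, 1, 2], 20)
report = clustering_report(X, y)
print({name: round(float(value), 3) for name, value in report.items()})
```

Higher values indicate better-separated clusters for all four metrics; ARI and silhouette are bounded above by 1, while Calinski–Harabasz is unbounded and only comparable across runs on the same data.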




In [None]:
import os
os.chdir('/content/SPLMv2')
!python cath_with_struct.py --checkpoint_path /content/checkpoint_0280000_gvp.pth \
--config_path /content/SPLMv2/configs/config_plddtallweight_noseq_rotary_foldseek.yaml \
--cath_path /content/SPLMv2/dataset/CATH_4_3_0_non-rep_h5/

In [None]:
from PIL import Image
import matplotlib.pyplot as plt

paths = [
    "/content/SPLMv2/configs/config_plddtallweight_noseq_rotary_foldseek/CATH_test_release/CATHgvp_1.png",
    "/content/SPLMv2/configs/config_plddtallweight_noseq_rotary_foldseek/CATH_test_release/CATHgvp_2.png",
    "/content/SPLMv2/configs/config_plddtallweight_noseq_rotary_foldseek/CATH_test_release/CATHgvp_3.png",
]

imgs = [Image.open(p) for p in paths]

total_width = sum(im.width for im in imgs)
max_height = max(im.height for im in imgs)

new_img = Image.new("RGB", (total_width, max_height), (255, 255, 255))
x = 0
for im in imgs:
    new_img.paste(im, (x, 0))
    x += im.width


dpi = 800
plt.figure(figsize=(total_width / dpi, max_height / dpi), dpi=dpi)
plt.imshow(new_img)
plt.axis("off")
plt.show()

In [None]:
import os
os.chdir('/content/SPLMv2')

!python cath_with_seq.py \
  --cath_seq ./dataset/Rep_subfamily_basedon_S40pdb.fa \
  --checkpoint_path /content/checkpoint_0280000_gvp.pth \
  --config_path /content/SPLMv2/configs/config_plddtallweight_noseq_rotary_foldseek.yaml

In [None]:
from PIL import Image
import matplotlib.pyplot as plt

paths = [
    "/content/SPLMv2/configs/config_plddtallweight_noseq_rotary_foldseek/CATH_test_release_seq/step_0_CATH_1.png",
    "/content/SPLMv2/configs/config_plddtallweight_noseq_rotary_foldseek/CATH_test_release_seq/step_0_CATH_2.png",
    "/content/SPLMv2/configs/config_plddtallweight_noseq_rotary_foldseek/CATH_test_release_seq/step_0_CATH_3.png",
]

imgs = [Image.open(p) for p in paths]

total_width = sum(im.width for im in imgs)
max_height = max(im.height for im in imgs)

new_img = Image.new("RGB", (total_width, max_height), (255, 255, 255))
x = 0
for im in imgs:
    new_img.paste(im, (x, 0))
    x += im.width


dpi = 800
plt.figure(figsize=(total_width / dpi, max_height / dpi), dpi=dpi)
plt.imshow(new_img)
plt.axis("off")
plt.show()

In [None]:
import os
os.chdir('/content/SPLMv2')

!python kinase_with_seq.py \
  --kinase_path ./dataset/GPS5.0_homo_hasPK_with_kinasedomain.txt \
  --checkpoint_path /content/checkpoint_0280000_gvp.pth \
  --config_path /content/SPLMv2/configs/config_plddtallweight_noseq_rotary_foldseek.yaml


In [None]:
from PIL import Image
import matplotlib.pyplot as plt

img = Image.open("/content/SPLMv2/configs/config_plddtallweight_noseq_rotary_foldseek/Kinase_test_release_seq/step_0_kinase.png")
plt.imshow(img)
plt.axis("off")
plt.show()