<div align="center">

# unZipro  [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Gabriel-QIN/unZipro/blob/master/notebooks/unZipro.ipynb)  [![GitHub](https://img.shields.io/badge/-GitHub-181717?logo=github&logoColor=white)](https://github.com/Gabriel-Qin/unZipro)  [![Server Status](https://img.shields.io/badge/Website-unZipro-green)](https://ai4bio.online/unZipro/home)

> **This is the official Google Colab tutorial for _unZipro_**, an unsupervised zero-shot inverse folding framework for protein evolution and high-fitness variant prediction.

</div>

---

<div align="center">
<img src="https://raw.githubusercontent.com/Gabriel-QIN/unZipro/master/image/easy_workflow.png" width="90%">
</div>

# Overview

**unZipro** (<u>un</u>supervised Zero-shot <u>i</u>nverse folding framework for <u>pro</u>tein evolution) is a lightweight **graph neural network (GNN)-based framework** designed for AI-guided protein engineering.

By combining general inverse folding constraints with family-specific adaptation, unZipro efficiently prioritizes high-fitness mutations without exhaustive screening.

<div align="center">
<img src="https://raw.githubusercontent.com/Gabriel-QIN/unZipro/master/image/applications.png" width="90%">
</div>

## ⚙️ How it works

unZipro treats protein engineering as “hunting for a needle in a haystack”:

- 🧠 **Zero-shot transfer learning** captures a universal protein fitness landscape.
- 🧩 **Meta-learning** adapts the model to family-specific fitness landscapes.
- ✅ **Prioritization** ranks the most promising high-fitness variants for experimental validation.

---



# 1. Installation

In [1]:
#@title Clone the unZipro repository {display-mode: "form"}
!rm -rf unZipro
!git clone https://github.com/Gabriel-Qin/unZipro.git

Cloning into 'unZipro'...
remote: Enumerating objects: 3727, done.
remote: Total 3727 (delta 0), reused 0 (delta 0), pack-reused 3727 (from 2)
Receiving objects: 100% (3727/3727), 68.28 MiB | 12.47 MiB/s, done.
Resolving deltas: 100% (544/544), done.


In [2]:
#@title Install dependencies {display-mode: "form"}
!pip install numpy pandas biotite requests plotly py3Dmol
### Install PyTorch
## If you are running in a local environment, uncomment the following lines.
# !pip install torch==2.4.1+cu124 --index-url https://download.pytorch.org/whl/cu124
# !pip install learn2learn==0.2.0
print("✅ unZipro installation complete.")

Collecting biotite
  Downloading biotite-1.5.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.metadata (5.5 kB)
Collecting py3Dmol
  Downloading py3dmol-2.5.3-py2.py3-none-any.whl.metadata (2.1 kB)
Collecting biotraj<2.0,>=1.0 (from biotite)
  Downloading biotraj-1.2.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (32 kB)
Downloading biotite-1.5.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (57.5 MB)
Downloading py3dmol-2.5.3-py2.py3-none-any.whl (7.2 kB)
Downloading biotraj-1.2.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.2 MB)

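Before moving on, it can help to confirm that every package imported later in this tutorial is actually available, especially `torch` and `learn2learn`, which the install cell leaves commented out. The sketch below uses only the standard library; the helper name `missing_packages` is hypothetical and the package list is simply collected from the import statements in this notebook.

```python
import importlib.util

def missing_packages(names):
    """Return the top-level module names that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]

# Import names used by the cells in this tutorial
required = ["numpy", "pandas", "biotite", "requests", "plotly",
            "py3Dmol", "torch", "learn2learn", "tqdm"]

missing = missing_packages(required)
if missing:
    print("Missing packages:", ", ".join(missing))
else:
    print("All required packages were found.")
```

If anything is reported as missing, install it with `pip` before running the later cells.
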
# 2. Zero-shot inference (without fine-tuning)

In [3]:
#@title unZipro Mutation Prioritization {display-mode: "form"}
#@markdown **Parameter Descriptions**
#@markdown - **pdb_code**: PDB ID or AlphaFold DB UniProt ID to load
#@markdown   (e.g., `"6vpc"` for PDB or `"Q9NUG6"` for AFDB).
#@markdown - **chain_id**: Protein chain to analyze.
#@markdown   For PDB structures, specify the exact chain (e.g., `"E"`);
#@markdown   for AFDB models, always set to `"A"`.

pdb_code = "6vpc"  #@param {type:"string"}
chain_id = "E" #@param {type:"string"}

import os
import sys
import json
import torch
import torch.nn as nn
sys.path.append("unZipro/script/")
from utils import *
from model import unZipro
from fetch_PDB_parallel import safe_fetch
from unZipro_mutation import infer_single_protein

model_param = 'unZipro/Models/unZipro_params.pt'
cache_dir = 'tmp/'
outdir = 'output'
config_path = 'unZipro/config/unZipro_pretrain.json'
pdb_dir = 'PDB'
pdb_name = f'{pdb_code}{chain_id}'
temperature = 1.0
nneighbor = 20
os.makedirs(outdir, exist_ok=True)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
res = None
print(f"INFO | Starting mutation prioritization for {pdb_name} with res={res or 'ALL'}")

safe_fetch(pdb_name, pdb_dir)

with open(config_path, "r") as f:
    data = json.load(f)
model_config = Config(**data)
model = unZipro(model_config).to(device)
state_dict = torch.load(model_param, map_location=device, weights_only=True)
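try:
    model.load_state_dict(state_dict)
except RuntimeError:
    # Checkpoints saved from a wrapped model prefix every key with "module."; strip it and retry
    new_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    model.load_state_dict(new_state_dict)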

dataset = GraphDataset(
    datalist=[os.path.join(pdb_dir, f'{pdb_name}.pdb')],
    nneighbor=nneighbor,
    noise=0,
    cache_dir=cache_dir
)
loader = get_loader(dataset=dataset, batchsize=1)
criterion = nn.CrossEntropyLoss().to(device)
model.eval()

df_list = infer_single_protein(
    model, criterion, loader, pdb_dir, outdir,
    temperature=temperature, device=device,
    output_prob=True, output_logits=True,
    rank_by_prob=True, res=res or None
)
print("✅ Done! Results saved in", outdir)


INFO | Starting mutation prioritization for 6vpcE with res=ALL
✅ Download 6vpcE
[INFO] 6vpcE | Recovery: 44.83% | Loss: 1.9499
[INFO] Saved in silico mutation scores to output/6vpcE.info.csv!
[INFO] Saved ranked scores to output/6vpcE.info_rank_by_prob.csv!
[INFO] Saved per-residue probability matrix  to output/6vpcE.info_probs.csv!
[INFO] Saved per-residue logits to output/6vpcE.info_logits.csv!
Total time: 1.972 s
Average per item: 1.972 s
✅ Done! Results saved in output
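
The log above reports a recovery percentage, and the run writes per-residue logits, probabilities, and a ranked mutation table. As a rough illustration of how such quantities relate, the sketch below turns toy logits into probabilities with a temperature-scaled softmax, computes recovery as the fraction of positions where the top-scoring residue matches the native one, and ranks non-native substitutions by probability. This is a minimal sketch under stated assumptions, not unZipro's code: the helper names, the toy logits, the amino-acid ordering, and the exact definitions of "recovery" and of how `temperature` is applied are all assumed for illustration.

```python
import math

AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"  # assumed column order, for illustration only

def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over the logits of one position."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the maximum for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def recovery(native_seq, prob_rows):
    """Fraction of positions whose most probable residue equals the native residue."""
    hits = 0
    for native, probs in zip(native_seq, prob_rows):
        predicted = AMINO_ACIDS[probs.index(max(probs))]
        hits += predicted == native
    return hits / len(native_seq)

def rank_substitutions(native_seq, prob_rows, top_k=3):
    """Non-native residues sorted by predicted probability, highest first."""
    candidates = []
    for pos, (native, probs) in enumerate(zip(native_seq, prob_rows), start=1):
        for aa, p in zip(AMINO_ACIDS, probs):
            if aa != native:
                candidates.append((p, f"{native}{pos}{aa}"))
    candidates.sort(reverse=True)
    return [(name, round(p, 3)) for p, name in candidates[:top_k]]

# Toy 3-residue example: the model agrees with the native residue at
# positions 1 and 2, but prefers R over the native K at position 3.
native = "AGK"
logits = [[0.0] * 20 for _ in native]
logits[0][AMINO_ACIDS.index("A")] = 3.0
logits[1][AMINO_ACIDS.index("G")] = 2.0
logits[2][AMINO_ACIDS.index("R")] = 3.0
logits[2][AMINO_ACIDS.index("K")] = 1.0

probs = [softmax(row, temperature=1.0) for row in logits]
toy_recovery = recovery(native, probs)
top_hits = rank_substitutions(native, probs, top_k=1)
print(toy_recovery, top_hits)
```

In this toy case two of three positions are recovered, and the substitution K3R ranks first because the model assigns R a higher probability than the native K.
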


### Visualize fitness landscape

In [4]:
#@title Fitness Landscape Viewer {display-mode:"form"}
color_scale = "YlGnBu_r"  #@param ["YlGnBu_r","Viridis","YlOrRd","purples", "rdpu", "reds"]
fig_width = 1400  #@param {type:"number"}
fig_height = 600  #@param {type:"number"}
font_size = 14  #@param {type:"number"}
title_size = 22  #@param {type:"number"}

import pandas as pd
import plotly.express as px

df = pd.read_csv(f'{outdir}/{pdb_name}.info_probs.csv')
aa_cols = [c for c in df.columns if c.startswith("prob_")]

# Prepare score matrix
score = df[aa_cols].copy()
score_T = score.T
score_T.index = [c.replace("prob_", "") for c in aa_cols]
score_T.columns = df["auth_idx"]

fig = px.imshow(
    score_T,
    color_continuous_scale=color_scale,
    aspect="auto",
    labels=dict(color="Fitness"),
    title=f"<b>Fitness Landscape for {pdb_name}</b>",
)

fig.update_layout(
    width=fig_width,
    height=fig_height,
    font=dict(family="Arial", size=font_size),
    title=dict(x=0.5, y=0.95, font=dict(size=title_size)),
    xaxis_title="<b>Position</b>",
    yaxis_title="<b>Amino Acid</b>",
    margin=dict(l=120, r=80, t=100, b=120),
    coloraxis_colorbar=dict(
        title="<b>Score</b>",
        thickness=18,
        len=0.7,
        tickfont=dict(size=12)
    )
)

fig.update_xaxes(
    tickfont=dict(family="Arial", size=16),
    tickangle=0,
    showgrid=False,
    title_font=dict(family="Arial", size=18)
)
fig.update_yaxes(
    tickfont=dict(family="Arial", size=16),
    showgrid=False,
    title_font=dict(family="Arial", size=18)
)

fig.update_traces(
    xgap=0, ygap=0,
    hovertemplate="AA: %{y}<br>Pos: %{x}<br>Value: %{z}<extra></extra>"
)

fig.show()
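
The "Prepare score matrix" step in the cell above reorients the table so that amino acids become rows and positions become columns. The same reshaping can be sketched with the standard library alone. The two-row CSV text below is made-up data that only mimics the `auth_idx` and `prob_*` column naming used in the cell; the real file has one `prob_*` column per amino acid.

```python
import csv
import io

# Made-up miniature of a per-residue probability table
csv_text = """auth_idx,prob_A,prob_G,prob_K
10,0.70,0.20,0.10
11,0.15,0.80,0.05
"""

rows = list(csv.DictReader(io.StringIO(csv_text)))
aa_cols = [c for c in rows[0] if c.startswith("prob_")]

# Column labels of the heatmap: residue numbers from the structure file
positions = [int(r["auth_idx"]) for r in rows]

# Row labels of the heatmap: amino acids, each mapped to its scores across positions
matrix = {c[len("prob_"):]: [float(r[c]) for r in rows] for c in aa_cols}

for aa, scores in matrix.items():
    print(aa, dict(zip(positions, scores)))
```
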


### Prioritization of mutations on structure

In [7]:
#@title 3D Structure Viewer (click to expand) {display-mode:"form"}
import py3Dmol
import pandas as pd

style = "cartoon"  #@param ["cartoon", "rainbow"]
highlight_color = "salmon"  #@param ["salmon", "red", "orange", "blue"]
sphere_size = 1.0  #@param {type:"number"}
stick_size = 0.3  #@param {type:"number"}

mut_df = pd.read_csv(f'{outdir}/{pdb_name}.info_rank_by_prob.csv')
top_mut = mut_df.sort_values("mut_prob", ascending=False).head(10)
top_positions = top_mut["auth_idx"].tolist()
top_values = top_mut["mut_prob"].tolist()

with open(f"{pdb_dir}/{pdb_name}.pdb", "r") as f:
    pdb_data = f.read()

view = py3Dmol.view(width=900, height=650)
view.addModel(pdb_data, 'pdb')

if style == "cartoon":
    view.setStyle({'cartoon': {
        'color': 'lightblue',
        'opacity': 0.85,
        'smooth': True,
        'thickness': 0.5
    }})
elif style == "rainbow":
    view.setStyle({'cartoon': {'color': 'spectrum'}})

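# Every option offered in the highlight_color dropdown needs an entry here
color_dict = {
    "salmon": "rgb(250,128,114)",
    "red":    "rgb(255,0,0)",
    "orange": "rgb(255,165,0)",
    "blue":   "rgb(0,0,255)",
    "yellow": "rgb(255,215,0)",
    "cyan":   "rgb(0,255,255)"
}
HL = color_dict[highlight_color]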

for pos in top_positions:
    view.addStyle(
        {'resi': str(pos), 'atom': 'CA'},
        {'sphere': {'color': HL, 'radius': sphere_size},
         'stick':  {'color': HL, 'radius': stick_size}}
    )

view.setBackgroundColor("white")
view.zoomTo()
view.show()
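
The viewer highlights the positions of the ten highest-scoring rows. If the ranked table holds several candidate substitutions for the same residue, those ten rows may cover fewer than ten distinct positions. The sketch below shows one way to pick the top distinct positions instead. It is an optional variation, not part of unZipro: the helper `top_distinct_positions` and the example scores are hypothetical.

```python
def top_distinct_positions(rows, k=10):
    """Keep the best score per position, then return the k best (position, score) pairs."""
    best = {}
    for pos, score in rows:
        if pos not in best or score > best[pos]:
            best[pos] = score
    ranked = sorted(best.items(), key=lambda item: item[1], reverse=True)
    return ranked[:k]

# Made-up (position, score) pairs with repeated positions
rows = [(12, 0.61), (12, 0.55), (47, 0.58), (88, 0.40), (47, 0.20)]
top = top_distinct_positions(rows, k=2)
print(top)
```

With a pandas table, the same pairs could be built from the `auth_idx` and `mut_prob` columns that the viewer cell already reads.
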


# 3. Family-specific inference (meta-transfer learning)
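
The fine-tuning cell in this section imports learn2learn and exposes `adapt_lr`, `meta_lr`, and `adapt_step`. To give a feel for what a meta-learning loop with an inner and an outer learning rate does, here is a toy first-order MAML-style sketch on a one-parameter problem. It is not unZipro's training code: the functions `loss_grad`, `adapt`, and `meta_train`, the scalar "tasks", and the learning rates are all hypothetical, and treating the adaptation setting as the number of inner-loop steps is an assumption.

```python
def loss_grad(w, target):
    """Gradient of the squared error (w - target)**2 with respect to w."""
    return 2.0 * (w - target)

def adapt(w, target, adapt_lr, adapt_steps):
    """Inner loop: a few gradient steps on one task, starting from the shared weight."""
    for _ in range(adapt_steps):
        w = w - adapt_lr * loss_grad(w, target)
    return w

def meta_train(task_targets, w=0.0, adapt_lr=0.1, meta_lr=0.05,
               adapt_steps=3, epochs=200):
    """Outer loop: move the shared weight so that adaptation works well on all tasks."""
    for _ in range(epochs):
        meta_grad = 0.0
        for target in task_targets:
            w_task = adapt(w, target, adapt_lr, adapt_steps)
            # First-order approximation: use the gradient at the adapted weight
            meta_grad += loss_grad(w_task, target)
        w = w - meta_lr * meta_grad / len(task_targets)
    return w

# Three toy tasks whose optima are 1.0, 2.0 and 3.0
targets = [1.0, 2.0, 3.0]
w_meta = meta_train(targets)
print(round(w_meta, 3))
```

The shared weight settles near the mean of the task optima, a starting point from which each task can be reached in a few adaptation steps.
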

In [8]:
#@title unZipro family-specific fine-tuning {display-mode: "form"}
#@markdown **Parameter Descriptions**
#@markdown - **pdb_code**: PDB ID or AlphaFold DB UniProt ID to load
#@markdown   (e.g., `"6vpc"` for PDB or `"Q9NUG6"` for AFDB).
#@markdown - **chain_id**: Protein chain to analyze.
#@markdown   Use the exact chain for PDB (e.g., `"E"`); AFDB models should use `"A"`.
#@markdown - **train_size**: Number of family sequences used for fine-tuning. A larger value improves performance but increases computation.
#@markdown - **patience**: Early-stopping patience (in epochs). Training stops if validation does not improve for this number of epochs.
#@markdown - **epochs**: Maximum number of training epochs.
#@markdown - **adapt_step**: Update interval (in steps) for adapting model parameters during fine-tuning.
pdb_code = "6vpc"  #@param {type:"string"}
chain_id = "E" #@param {type:"string"}
train_size = 100 #@param {type:"number"}
patience = 5 #@param {type:"number"}
epochs = 10 #@param {type:"number"}
adapt_step = 10 #@param {type:"number"}

import os
import sys
sys.path.append("unZipro/script/")
import time
import json
import argparse
from tempfile import gettempdir
from concurrent.futures import ThreadPoolExecutor
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
import learn2learn as l2l
from tqdm import tqdm
import requests
import biotite.structure as struc
import biotite.structure.io as strucio
from utils import *
from model import unZipro, weights_init
from foldseek_api import submit_pdb_to_foldseek, parse_pdb_input
from parse_foldseek_results import parse_mmseqs, write_ids
from fetch_PDB_parallel import fetch_and_save, safe_fetch
from unZipro_finetuning import unZipro_finetune
from unZipro_mutation import infer_single_protein

outdir = 'output'
pdb_dir = 'pdb'
work_dir = './'
# train_size is taken from the form field above; it is not overwritten here
config_path = 'unZipro/config/unZipro_pretrain.json'
param = 'unZipro/Models/unZipro_params.pt'
cache_dir = 'tmp/'
temperature = 1.0
nneighbor = 20
pdb_name = f'{pdb_code}{chain_id}'
safe_fetch(pdb_name, pdb_dir)
os.makedirs(outdir, exist_ok=True)
cpu = os.cpu_count()
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
cpu_only = not torch.cuda.is_available()
adapt_lr = 1e-6
meta_lr = 1e-6
res = None
print(f"INFO | Step 1: start to retrieve similar structures using Foldseek!")
_ = submit_pdb_to_foldseek(
    pdb_list=parse_pdb_input(pdb_name),
    outdir=work_dir,
    pdb_dir=pdb_dir,
    use_wget=True,
    use_aria2c=False,
    wait_time=30,
    only_download=False
)
print(f"INFO | Step 2: split data into meta-training and testing!")
try:
    m8_path = os.path.join(work_dir, f'{pdb_name}/alis_pdb100.m8')
    savepath = os.path.join(work_dir, f'{pdb_name}_pdb100.csv')
    data_dir = os.path.join(work_dir, f'{pdb_name}/')
    pdb_id_path = os.path.join(data_dir, 'PDB_IDs.txt')
    os.makedirs(data_dir, exist_ok=True)
    train_path = os.path.join(data_dir, f'train.csv')
    test_path = os.path.join(data_dir, f'test.csv')
    train, test, download_pdblist, num_pdbs, num_af_pdbs = parse_mmseqs(m8_path, savepath, train_path, test_path, train_size=train_size, include_af2=True)
    if len(train) < 10 or len(test) < 10:
        raise ValueError(f"Too few homologs for {pdb_name}: train size {len(train)} | test size {len(test)}")
    if len(train) < 100 or len(test) < 20:
        print(f'Warning! [{pdb_name}]: Train size {len(train)} | Test size {len(test)}')
except Exception as e:
    print(f'Error in {pdb_name}: {e}')
    raise  # the following steps need train/test, so stop here
print(f"INFO | Step 3: start to download structural data!")
pdb_dir = os.path.join(work_dir, 'pdb')
os.makedirs(pdb_dir, exist_ok=True)
print(f"INFO | Found {len(train)} meta-training structures and {len(test)} meta-testing structures!")
with ThreadPoolExecutor(max_workers=4) as executor:
    # futures = [executor.submit(fetch_and_save, pdb, pdb_dir) for pdb in download_pdblist]
    futures = [executor.submit(safe_fetch, pdb, pdb_dir) for pdb in download_pdblist]
    for f in futures:
        f.result()
print(f"INFO | Step 4: start unZipro finetuning!")
model_store_dir = os.path.join(work_dir, 'model/')
cache_dir = os.path.join(work_dir, 'tmp/')
# unZipro finetuning
model_param = unZipro_finetune(train_path, test_path, config_path=config_path, pdb_dir=pdb_dir, model_store_dir=model_store_dir,
                param_file=param, project_name=pdb_name, epochs=epochs,
                adapt_lr=adapt_lr, meta_lr=meta_lr, adapt_step=adapt_step,
                batchsize=1, cpu=cpu, gpu=0,cpu_only=cpu_only,
                noise=0.01, nneighbor=nneighbor, patience=patience, cache_dir=cache_dir, save_model_ckp=False)
os.system(f'rm -rf {cache_dir}/*')
model_param = f'{model_store_dir}/{pdb_name}.pt'
print(f"INFO | Step 5: start mutation prioritization!")
os.makedirs(f'{outdir}', exist_ok=True)
device = torch.device(f'cuda:0' if torch.cuda.is_available() else 'cpu')
with open(config_path, "r") as f:
    data = json.load(f)
model_config = Config(**data)
model = unZipro(model_config).to(device)
state_dict = torch.load(model_param, map_location=torch.device(device), weights_only=True)
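try:
    model.load_state_dict(state_dict)
except RuntimeError:
    # Checkpoints saved from a wrapped model prefix every key with "module."; strip it and retry
    new_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    model.load_state_dict(new_state_dict)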
# dataloader setup
dataset = GraphDataset(datalist=[os.path.join(pdb_dir, f'{pdb_name}.pdb')], nneighbor=nneighbor, noise=0, cache_dir=cache_dir)
loader = get_loader(dataset=dataset, batchsize=1)
criterion = nn.CrossEntropyLoss().to(device)
model.eval()
df_list = infer_single_protein(model, criterion, loader, pdb_dir, outdir, temperature=temperature, device=device, output_prob=True, output_logits=True, rank_by_prob=True, res=res)
print("✅ Done! Results saved in", outdir)

✅ Download 6vpcE
INFO | Step 1: start to retrieve similar structures using Foldseek!
Status: 200 | Response: {"id":"H-56SZLu1Hfs1T_dF4TYXbbVZ3GeCCt4X3PUog","status":"COMPLETE"}

Query 6vpcE | Ticket: H-56SZLu1Hfs1T_dF4TYXbbVZ3GeCCt4X3PUog
Processed 1/1: 6vpcE. Waiting 30s...
All tasks submitted. Now waiting for results...
Executing download script for 6vpcE...
Download completed for 6vpcE.
All tasks processed and results downloaded.
INFO | Step 2: split data into meta-training and testing!
INFO | Step 3: start to download structural data!
INFO | Found 100 meta-training structures and 20 meta-testing structures!
✅ Download 2hxvA
✅ Download 2o3kB
✅ Download 8dqcA
✅ Download 5xkpB
✅ Download 2b3zB
✅ Download 3zpgA
✅ Download 8dqbA
✅ Download 5xkpC
✅ Download 7bv5D
✅ Download 7bv5C
✅ Download 5xkoA
✅ Download 2o7pA
✅ Download 8dq9B
✅ Download 5xkqD
✅ Download 5xkqA
✅ Download 5xkqC
✅ Download 8aw33
✅ Download 2w4lD
✅ Download 2w4lF
✅ Download 2w4lE

Finetuning:  10%|█         | 1/10 [00:25<03:47, 25.27s/epoch]

Epoch 1 | Train Loss: 1.6870 | Train Acc: 49.05% | Valid Loss: 1.5825 | Valid Acc: 51.72%


Finetuning:  20%|██        | 2/10 [00:42<02:43, 20.49s/epoch]

Epoch 2 | Train Loss: 1.8210 | Train Acc: 46.51% | Valid Loss: 1.7899 | Valid Acc: 45.95%


Finetuning:  30%|███       | 3/10 [00:56<02:03, 17.58s/epoch]

Epoch 3 | Train Loss: 1.7903 | Train Acc: 47.70% | Valid Loss: 1.5803 | Valid Acc: 48.98%


Finetuning:  40%|████      | 4/10 [01:08<01:32, 15.49s/epoch]

Epoch 4 | Train Loss: 1.7411 | Train Acc: 48.67% | Valid Loss: 1.6479 | Valid Acc: 49.97%


Finetuning:  50%|█████     | 5/10 [01:20<01:10, 14.19s/epoch]

Epoch 5 | Train Loss: 1.6898 | Train Acc: 49.18% | Valid Loss: 1.6444 | Valid Acc: 49.40%


Finetuning:  60%|██████    | 6/10 [01:32<00:53, 13.44s/epoch]

Epoch 6 | Train Loss: 1.8021 | Train Acc: 46.80% | Valid Loss: 1.6107 | Valid Acc: 50.35%


Finetuning:  70%|███████   | 7/10 [01:44<00:39, 13.03s/epoch]

Epoch 7 | Train Loss: 1.7821 | Train Acc: 47.25% | Valid Loss: 1.6301 | Valid Acc: 49.54%


Finetuning:  70%|███████   | 7/10 [01:56<00:49, 16.66s/epoch]


Epoch 8 | Train Loss: 1.7866 | Train Acc: 47.18% | Valid Loss: 1.6859 | Valid Acc: 49.34%
Early stopping triggered.
Finished training. Total elapsed time: 117s. Best valid loss: 1.5803 at epoch 3
INFO | Step 5: start mutation prioritization!
[INFO] 6vpcE | Recovery: 49.66% | Loss: 1.7615
[INFO] Saved in silico mutation scores to output/6vpcE.info.csv!
[INFO] Saved ranked scores to output/6vpcE.info_rank_by_prob.csv!
[INFO] Saved per-residue probability matrix  to output/6vpcE.info_probs.csv!
[INFO] Saved per-residue logits to output/6vpcE.info_logits.csv!
Total time: 0.150 s
Average per item: 0.150 s
✅ Done! Results saved in output
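
The training log above stops at epoch 8 and reports the best validation loss at epoch 3, with `patience` set to 5. A simple patience counter reproduces that behaviour on the logged validation losses. This is only a sketch: the function `early_stopping_epoch` is hypothetical, and whether `unZipro_finetune` counts patience in exactly this way is an assumption.

```python
def early_stopping_epoch(valid_losses, patience):
    """Return (epoch where training stops, best epoch, best validation loss)."""
    best_loss = float("inf")
    best_epoch = 0
    epochs_without_improvement = 0
    for epoch, loss in enumerate(valid_losses, start=1):
        if loss < best_loss:
            # New best: remember it and reset the counter
            best_loss, best_epoch = loss, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch, best_epoch, best_loss
    return len(valid_losses), best_epoch, best_loss

# Validation losses copied from the log above (epochs 1 to 8)
valid_losses = [1.5825, 1.7899, 1.5803, 1.6479, 1.6444, 1.6107, 1.6301, 1.6859]
stopped_at, best_epoch, best_loss = early_stopping_epoch(valid_losses, patience=5)
print(stopped_at, best_epoch, best_loss)
```

Epoch 3 sets the best loss, the next five epochs fail to improve on it, and the counter reaches the patience limit at epoch 8.
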


## Visualize fitness landscape

In [9]:
#@title Fitness Landscape Viewer {display-mode:"form"}
color_scale = "YlGnBu_r"  #@param ["YlGnBu_r","Viridis","YlOrRd","purples", "rdpu", "reds"]
fig_width = 1400  #@param {type:"number"}
fig_height = 600  #@param {type:"number"}
font_size = 14  #@param {type:"number"}
title_size = 22  #@param {type:"number"}

import pandas as pd
import plotly.express as px

df = pd.read_csv(f'{outdir}/{pdb_name}.info_probs.csv')
aa_cols = [c for c in df.columns if c.startswith("prob_")]

# Prepare score matrix
score = df[aa_cols].copy()
score_T = score.T
score_T.index = [c.replace("prob_", "") for c in aa_cols]
score_T.columns = df["auth_idx"]

fig = px.imshow(
    score_T,
    color_continuous_scale=color_scale,
    aspect="auto",
    labels=dict(color="Fitness"),
    title=f"<b>Fitness Landscape for {pdb_name}</b>",
)

fig.update_layout(
    width=fig_width,
    height=fig_height,
    font=dict(family="Arial", size=font_size),
    title=dict(x=0.5, y=0.95, font=dict(size=title_size)),
    xaxis_title="<b>Position</b>",
    yaxis_title="<b>Amino Acid</b>",
    margin=dict(l=120, r=80, t=100, b=120),
    coloraxis_colorbar=dict(
        title="<b>Score</b>",
        thickness=18,
        len=0.7,
        tickfont=dict(size=12)
    )
)

fig.update_xaxes(
    tickfont=dict(family="Arial", size=16),
    tickangle=0,
    showgrid=False,
    title_font=dict(family="Arial", size=18)
)
fig.update_yaxes(
    tickfont=dict(family="Arial", size=16),
    showgrid=False,
    title_font=dict(family="Arial", size=18)
)

fig.update_traces(
    xgap=0, ygap=0,
    hovertemplate="AA: %{y}<br>Pos: %{x}<br>Value: %{z}<extra></extra>"
)

fig.show()


## Prioritization of mutations on structure

In [10]:
#@title 3D Structure Viewer (click to expand) {display-mode:"form"}
import py3Dmol
import pandas as pd

style = "cartoon"  #@param ["cartoon", "rainbow"]
highlight_color = "salmon"  #@param ["salmon", "red", "orange", "blue"]
sphere_size = 1.0  #@param {type:"number"}
stick_size = 0.3  #@param {type:"number"}

mut_df = pd.read_csv(f'{outdir}/{pdb_name}.info_rank_by_prob.csv')
top_mut = mut_df.sort_values("mut_prob", ascending=False).head(10)
top_positions = top_mut["auth_idx"].tolist()
top_values = top_mut["mut_prob"].tolist()

with open(f"{pdb_dir}/{pdb_name}.pdb", "r") as f:
    pdb_data = f.read()

view = py3Dmol.view(width=900, height=650)
view.addModel(pdb_data, 'pdb')

if style == "cartoon":
    view.setStyle({'cartoon': {
        'color': 'lightblue',
        'opacity': 0.85,
        'smooth': True,
        'thickness': 0.5
    }})
elif style == "rainbow":
    view.setStyle({'cartoon': {'color': 'spectrum'}})

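# Every option offered in the highlight_color dropdown needs an entry here
color_dict = {
    "salmon": "rgb(250,128,114)",
    "red":    "rgb(255,0,0)",
    "orange": "rgb(255,165,0)",
    "blue":   "rgb(0,0,255)",
    "yellow": "rgb(255,215,0)",
    "cyan":   "rgb(0,255,255)"
}
HL = color_dict[highlight_color]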

for pos in top_positions:
    view.addStyle(
        {'resi': str(pos), 'atom': 'CA'},
        {'sphere': {'color': HL, 'radius': sphere_size},
         'stick':  {'color': HL, 'radius': stick_size}}
    )

view.setBackgroundColor("white")
view.zoomTo()
view.show()
