In [None]:
# Install the Airavata Python SDK with its "notebook" extra into the kernel's
# environment (-q keeps pip's output quiet).
# NOTE(review): the package version is not pinned — consider pinning for reproducibility.
%pip install -q "airavata-python-sdk[notebook]"
# NOTE(review): presumably importing this module is what registers the
# %authenticate / %request_runtime / %switch_runtime magics used below — confirm.
import airavata_jupyter_magic

# The three magics below must run in this order: authenticate, request a
# runtime named "hpc_gpu", then switch to it.
%authenticate
# Requests a runtime called "hpc_gpu" configured from cybershuttle.yml.
# NOTE(review): the unit of --walltime=60 and the exact meaning of the
# --use=<name>:<kind> list are not visible here — confirm against the SDK docs.
%request_runtime hpc_gpu --file=cybershuttle.yml --walltime=60 --use=NeuroData25VC2:cloud,Anvil:gpu
# NOTE(review): presumably subsequent cells execute on the "hpc_gpu" runtime after this — confirm.
%switch_runtime hpc_gpu

In [None]:
!git clone https://github.com/vizfold/attention-viz-demo /workspace/attention-viz-demo
!cd /workspace/attention-viz-demo
%cd /workspace/attention-viz-demo

In [None]:
# Fetch stereo_chemical_props.txt (URL pinned to a specific commit of the
# OpenStructure repository) into openfold/resources. -N skips the download
# when the local copy is not older than the remote file.
# NOTE(review): --no-check-certificate disables TLS certificate verification —
# confirm this is still required for the host.
!wget -N --no-check-certificate -P openfold/resources \
    https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt
# Run the repository's parameter download script, targeting openfold/resources/params/.
!bash scripts/download_openfold_params.sh openfold/resources/params/
# Move the contents of the openfold_params/ subdirectory up one level, then
# remove that subdirectory (rmdir only succeeds if it is empty).
# NOTE(review): presumably the download script creates openfold_params/ — confirm.
# On a re-run these two commands will report errors if the subdirectory is
# missing or the destination files already exist.
!mv openfold/resources/params/openfold_params/* openfold/resources/params/
!rmdir openfold/resources/params/openfold_params/
# Ensure the outputs directory exists (-p makes this a no-op when it already does).
!mkdir -p /workspace/attention-viz-demo/outputs

In [None]:
# Imports and configuration
import os
import subprocess

# User-defined parameters
# NOTE(review): PROT (presumably a PDB identifier) is not referenced by any
# cell visible in this view of the notebook — confirm it is still needed.
PROT = "6KWC"
# Residue index passed to --triangle_residue_idx during inference and used as
# the single entry of residue_indices in the triangle_start visualizations.
TRI_RESIDUE_IDX = 18

# Input/output paths
# BASE_DATA_DIR is interpolated as a prefix in the inference command
# ("{BASE_DATA_DIR}/pdb_mmcif/mmcif_files", ...); while it is left empty those
# arguments become root-anchored paths such as "/pdb_mmcif/mmcif_files".
BASE_DATA_DIR = ""  # path to AlphaFold data
FASTA_DIR = "./examples/monomer/fasta_dir" # directory containing the FASTA file(s) to evaluate
# ATTN_MAP_DIR and OUTPUT_DIR currently point at the same directory.
ATTN_MAP_DIR = "./outputs" # directory for saving text files with top-k attention scores
ALIGNMENT_DIR = "./examples/monomer/alignments" # directory containing pre-computed alignment files (and MSAs)
OUTPUT_DIR = "./outputs" # directory to save outputs



In [None]:
# Run OpenFold inference and save top attention scores to text files.
#
# The command is built as an argument list and executed without a shell, so
# configured paths containing spaces or shell metacharacters reach the script
# as single, unmodified arguments. Every argument value is the same as in the
# equivalent shell command line.
inference_cmd = [
    "python3", "run_pretrained_openfold.py",
    # Positional arguments: FASTA directory, then the mmCIF template directory.
    FASTA_DIR,
    f"{BASE_DATA_DIR}/pdb_mmcif/mmcif_files",
    "--use_precomputed_alignments", ALIGNMENT_DIR,
    "--output_dir", OUTPUT_DIR,
    "--config_preset", "model_1_ptm",
    # Sequence database locations, all rooted at BASE_DATA_DIR.
    "--uniref90_database_path", f"{BASE_DATA_DIR}/uniref90/uniref90.fasta",
    "--mgnify_database_path", f"{BASE_DATA_DIR}/mgnify/mgy_clusters_2022_05.fa",
    "--pdb70_database_path", f"{BASE_DATA_DIR}/pdb70/pdb70",
    "--uniclust30_database_path", f"{BASE_DATA_DIR}/uniclust30/uniclust30_2018_08",
    "--bfd_database_path", f"{BASE_DATA_DIR}/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt",
    "--save_outputs",
    "--model_device", "cuda:0",
    # Attention-export options consumed by the visualization cells.
    "--attn_map_dir", ATTN_MAP_DIR,
    "--num_recycles_save", "1",
    "--triangle_residue_idx", str(TRI_RESIDUE_IDX),
    "--demo_attn",
]

# check=True raises subprocess.CalledProcessError if the script exits non-zero,
# so a failed inference run stops the notebook here.
subprocess.run(inference_cmd, check=True)


In [None]:
# Render attention arc diagrams for two attention types: MSA row attention and
# triangle-start attention.
from visualize_attention_arc_diagram_demo_utils import generate_arc_diagrams, parse_fasta_sequence

# Cell-level settings. The three path values are empty placeholders as shipped
# and need to be filled in before running.
FASTA_PATH = ""          # fasta file for labeling plots
arc_output_dir_msa = ""  # directory for saving msa attention arc diagrams
arc_output_dir_tri = ""  # directory for saving triangle attention arc diagrams
LAYER_IDX = 47           # selected layer for attention evaluation
TOP_K = 50               # top-k lines to plot (limit to 500)

residue_seq = parse_fasta_sequence(FASTA_PATH)

# Keyword arguments common to both diagram calls.
arc_common_kwargs = dict(
    attention_dir=ATTN_MAP_DIR,
    residue_sequence=residue_seq,
    top_k=TOP_K,
    layer_idx=LAYER_IDX,
)

# MSA row attention.
generate_arc_diagrams(
    output_dir=arc_output_dir_msa,
    attention_type="msa_row",
    **arc_common_kwargs
)

# Triangle-start attention, restricted to the configured residue index.
generate_arc_diagrams(
    output_dir=arc_output_dir_tri,
    attention_type="triangle_start",
    residue_indices=[TRI_RESIDUE_IDX],
    **arc_common_kwargs
)


In [None]:
# Produce PyMOL 3D attention visualizations for two attention types: MSA row
# attention and triangle-start attention.
from visualize_attention_3d_demo_utils import plot_pymol_attention_heads

# Cell-level settings. The three path values are empty placeholders as shipped
# and need to be filled in before running.
PDB_FILE = ""        # PDB file for plotting in 3D
output_dir_msa = ""  # directory for saving msa attention 3D visuals
output_dir_tri = ""  # directory for saving triangle attention 3D visuals
LAYER_IDX = 47       # selected layer for attention evaluation
TOP_K = 50           # top-k lines to plot (limit to 500)

# Keyword arguments common to both plotting calls.
pymol_common_kwargs = dict(
    pdb_file=PDB_FILE,
    attention_dir=ATTN_MAP_DIR,
    top_k=TOP_K,
    layer_idx=LAYER_IDX,
)

# MSA row attention.
plot_pymol_attention_heads(
    output_dir=output_dir_msa,
    attention_type="msa_row",
    **pymol_common_kwargs
)

# Triangle-start attention, restricted to the configured residue index.
plot_pymol_attention_heads(
    output_dir=output_dir_tri,
    attention_type="triangle_start",
    residue_indices=[TRI_RESIDUE_IDX],
    **pymol_common_kwargs
)
