In [None]:
from pathlib import Path
import pandas as pd
import torch
from tqdm import tqdm
from ThermalFold.predict_utils import thermalFold_predictor
from ThermalFold.esm_utils import ESM_embedding,ESMH5Cache
from ThermalFold.data_utils import parse_key_name
from lightning.fabric import seed_everything
import warnings
import shutil
import MDAnalysis as mda
from MDAnalysis.analysis import align
# Suppress every Python warning raised from this point on, for all libraries
# imported above (this also hides deprecation notices).
warnings.filterwarnings('ignore')

## Temperature Ramping Examples

In [35]:
# Model configuration and checkpoint handed to the predictor.
cfg_fname = '../weight/model_conf.yaml'
weight_path  = '../weight/model_weight.pt'

# Input table; the `name` and `sequence` columns are read below.
df = pd.read_csv("seqs/examples_1.csv")

seed_everything(42)
# Temperatures 273, 278, ..., 473 (steps of 5).
temp_lst = list(range(273,474,5))
# Uncomment the next line to discard earlier results before rerunning the prediction
#if Path("ThermalFold_prediction").is_dir(): shutil.rmtree('ThermalFold_prediction')
with ESM_embedding('esm2_650M',device='cuda:0',cache_path='./esm_cache') as esm:
    predictor = thermalFold_predictor(
        cfg_fname=cfg_fname,
        weight_path=weight_path,
        esm=esm,
        device='cuda:0',
    )
    name_seq_pairs = zip(df['name'],df['sequence'])
    for name,seq in tqdm(name_seq_pairs,total=len(df)):
        # One output directory per input row.
        tpath = Path(f"ThermalFold_prediction/{name}")
        tpath.mkdir(exist_ok=True,parents=True)
        for temp in temp_lst:
            tfname = tpath/f"NAME={name}:TEMP={temp}.pdb"
            # Files that already exist are left untouched, so a rerun only
            # fills in the missing temperatures.
            if not tfname.is_file():
                res = predictor.predict([[seq,temp]])[0]
                tfname.write_text(res)
# Release GPU memory held by PyTorch's caching allocator.
torch.cuda.empty_cache()

Seed set to 42
Using cache found in /home/b208/.cache/torch/hub/facebookresearch_esm_main
100%|██████████| 3/3 [01:22<00:00, 27.65s/it]


In [None]:
# Install the py3Dmol to visualize the predicted structures
!pip install py3Dmol

In [41]:
import py3Dmol

# Predicted structure written by the temperature-ramping cell (TEMP=363 file).
fname = Path("ThermalFold_prediction/Ice_structuring_protein/NAME=Ice_structuring_protein:TEMP=363.pdb")
pdb_str = fname.read_text()

view = py3Dmol.view(width=600, height=500)
view.addModel(pdb_str, 'pdb')

# Cartoon colour for each `ss` selector value, applied in this order.
ss_colors = {
    'h': '#c23779',
    's': '#f7b731',
    'c': 'grey',
}
for ss_code, color in ss_colors.items():
    view.setStyle({'ss': ss_code}, {'cartoon': {'color': color}})

view.zoomTo()
# Leaving the view as the last expression renders it as the cell output.
view

<py3Dmol.view at 0x7eb17bfc9990>

## Single Temperature Sampling

In [None]:
# Model configuration and checkpoint handed to the predictor.
cfg_fname = '../weight/model_conf.yaml'
weight_path  = '../weight/model_weight.pt'

# Input table; the `name` and `sequence` columns are read below.
df = pd.read_csv("seqs/examples_2.csv")

# Seeds 0..19: one prediction per seed, all at the same temperature.
seed_lst = list(range(20))
temp = 310
# Uncomment the next line to discard earlier results before rerunning the prediction
#if Path("ThermalFold_prediction").is_dir(): shutil.rmtree('ThermalFold_prediction')
with ESM_embedding('esm2_650M',device='cuda:0',cache_path='./esm_cache') as esm:
    predictor = thermalFold_predictor(
        cfg_fname=cfg_fname,
        weight_path=weight_path,
        esm=esm,
        device='cuda:0',
    )
    for name,seq in tqdm(zip(df['name'],df['sequence']),total=len(df)):
        tpath = Path(f"ThermalFold_prediction/{name}")
        tpath.mkdir(exist_ok=True,parents=True)
        for seed in seed_lst:
            # Reseed before every sample (this also happens for samples that
            # are skipped below because their file already exists).
            seed_everything(seed)
            tfname = tpath/f"NAME={name}:SEED={seed}:TEMP={temp}.pdb"
            if tfname.is_file():
                continue
            sample = predictor.predict([[seq,temp]])[0]
            tfname.write_text(sample)
# Release GPU memory held by PyTorch's caching allocator.
torch.cuda.empty_cache()

Using cache found in /home/b208/.cache/torch/hub/facebookresearch_esm_main
  0%|          | 0/3 [00:00<?, ?it/s]Seed set to 0
Seed set to 1
Seed set to 2
Seed set to 3
Seed set to 4
Seed set to 5
Seed set to 6
Seed set to 7
Seed set to 8
Seed set to 9
Seed set to 10
Seed set to 11
Seed set to 12
Seed set to 13
Seed set to 14
Seed set to 15
Seed set to 16
Seed set to 17
Seed set to 18
Seed set to 19
 33%|███▎      | 1/3 [00:13<00:27, 13.64s/it]Seed set to 0
Seed set to 1
Seed set to 2
Seed set to 3
Seed set to 4
Seed set to 5
Seed set to 6
Seed set to 7
Seed set to 8
Seed set to 9
Seed set to 10
Seed set to 11
Seed set to 12
Seed set to 13
Seed set to 14
Seed set to 15
Seed set to 16
Seed set to 17
Seed set to 18
Seed set to 19
 67%|██████▋   | 2/3 [00:27<00:14, 14.03s/it]Seed set to 0
Seed set to 1
Seed set to 2
Seed set to 3
Seed set to 4
Seed set to 5
Seed set to 6
Seed set to 7
Seed set to 8
Seed set to 9
Seed set to 10
Seed set to 11
Seed set to 12
Seed set to 13
Seed set to 14
See

In [None]:
import py3Dmol


def _split_pdb_models(pdb_text):
    """Split multi-model PDB text into one self-contained PDB string per model.

    A new chunk starts at every line beginning with ``MODEL``.  Within a chunk,
    reading stops at the first ``END``/``ENDMDL`` line; that terminator and
    anything after it (up to the next ``MODEL``) is dropped, and exactly one
    ``ENDMDL`` is appended to each returned model.  Chunks that contain no
    ``ATOM``/``HETATM`` record (e.g. the file header before the first
    ``MODEL``) are discarded.

    Returns a list of strings, one per kept model, in file order.
    """
    chunks = [[]]
    closed = False  # True once the current chunk has seen its terminator
    for line in pdb_text.splitlines():
        if line.startswith('MODEL'):
            chunks.append([line])
            closed = False
        elif closed:
            continue
        elif line.strip() in ('END', 'ENDMDL'):
            closed = True
        else:
            chunks[-1].append(line)
    return [
        '\n'.join(chunk) + '\nENDMDL'
        for chunk in chunks
        if any(l.startswith(('ATOM', 'HETATM')) for l in chunk)
    ]


# All sampled structures for one protein, in sorted filename order.
spectrin_dir = Path("ThermalFold_prediction/Spectrin")
f_lst = sorted(spectrin_dir.glob("*pdb"))
if not f_lst:
    # Fail with an actionable message instead of an IndexError on f_lst[0].
    raise FileNotFoundError(
        f"No PDB files found in {spectrin_dir}; run the sampling cell first."
    )

# Pass plain strings: they are accepted regardless of whether the installed
# MDAnalysis version handles pathlib objects.
pdb_paths = [str(p) for p in f_lst]
# First file provides the topology; all files together form the frames.
u = mda.Universe(pdb_paths[0], pdb_paths)
# Align the frames on their CA atoms (the same universe is used as mobile and
# reference) and write the aligned frames to temp.pdb.
aligner = align.AlignTraj(u, u, select="name CA", filename='temp.pdb', match_atoms=True)
aligner.run()

with open('temp.pdb', 'r') as f:
    pdb_content = f.read()

blocks = _split_pdb_models(pdb_content)

view = py3Dmol.view(width=800, height=600)

# Add each aligned structure as its own model so they are drawn overlaid.
for i, block in enumerate(blocks):
    view.addModel(block, 'pdb')
    view.setStyle({'model': i}, {'cartoon': {'color': 'red', 'opacity': 1}})

view.zoomTo()
view.show()

## Single Mutation Effect of PF1066

In [None]:
# Model configuration and checkpoint handed to the predictor.
cfg_fname = '../weight/model_conf.yaml'
weight_path  = '../weight/model_weight.pt'

# Input table; the `name` and `sequence` columns are read below.
df = pd.read_csv("seqs/examples_3.csv")

seed_everything(42)
# Temperatures 273, 278, ..., 473 (steps of 5).
temp_lst = list(range(273,474,5))
# Uncomment the next line to discard earlier results before rerunning the prediction
#if Path("ThermalFold_prediction").is_dir(): shutil.rmtree('ThermalFold_prediction')
with ESM_embedding('esm2_650M',device='cuda:0',cache_path='./esm_cache') as esm:
    predictor = thermalFold_predictor(
        cfg_fname=cfg_fname,
        weight_path=weight_path,
        esm=esm,
        device='cuda:0',
    )
    for name,seq in tqdm(zip(df['name'],df['sequence']),total=len(df)):
        # One output directory per input row.
        tpath = Path(f"ThermalFold_prediction/{name}")
        tpath.mkdir(exist_ok=True,parents=True)
        for temp in temp_lst:
            tfname = tpath/f"NAME={name}:TEMP={temp}.pdb"
            # Only predict temperatures whose output file is still missing.
            already_done = tfname.is_file()
            if already_done:
                continue
            prediction = predictor.predict([[seq,temp]])[0]
            tfname.write_text(prediction)
# Release GPU memory held by PyTorch's caching allocator.
torch.cuda.empty_cache()

Seed set to 42
Using cache found in /home/b208/.cache/torch/hub/facebookresearch_esm_main
100%|██████████| 2/2 [00:52<00:00, 26.34s/it]


In [None]:
import py3Dmol

# Predicted structure of the 1SF0_wtm_V59K entry (TEMP=433 file).
fname = Path("ThermalFold_prediction/1SF0_wtm_V59K/NAME=1SF0_wtm_V59K:TEMP=433.pdb")
pdb_str = fname.read_text()

view = py3Dmol.view(width=600, height=500)
view.addModel(pdb_str, 'pdb')

# (ss selector value, cartoon colour) pairs, applied in this order.
for ss_code, color in (('h', '#c23779'), ('s', '#f7b731'), ('c', 'grey')):
    view.setStyle({'ss': ss_code}, {'cartoon': {'color': color}})

# Add sticks for residues 51, 50, 52 and 59 on top of the cartoon.
sel_loose = {'resi': [51, 50,52, 59]}
stick_style = {'stick': {'colorscheme': 'element', 'radius': 0.35}}
view.addStyle(sel_loose, stick_style)

view.zoomTo()
# Leaving the view as the last expression renders it as the cell output.
view

<py3Dmol.view at 0x760cbe391410>

In [34]:
import py3Dmol

# Predicted structure of the 1SF0_wtm entry (TEMP=433 file).
fname = Path("ThermalFold_prediction/1SF0_wtm/NAME=1SF0_wtm:TEMP=433.pdb")
pdb_str = fname.read_text()

view = py3Dmol.view(width=600, height=500)
view.addModel(pdb_str, 'pdb')

# Cartoon colour for each `ss` selector value, applied in this order.
cartoon_by_ss = [
    ('h', '#c23779'),
    ('s', '#f7b731'),
    ('c', 'grey'),
]
for ss_code, color in cartoon_by_ss:
    view.setStyle({'ss': ss_code}, {'cartoon': {'color': color}})

# Add sticks for residues 51, 50, 52 and 59 on top of the cartoon.
sel_loose = {'resi': [51, 50,52, 59]}
view.addStyle(
    sel_loose,
    {'stick': {'colorscheme': 'element', 'radius': 0.35}},
)

view.zoomTo()
# Leaving the view as the last expression renders it as the cell output.
view

<py3Dmol.view at 0x7eb179a4e390>

## Disulfide Bonds

In [None]:
# Model configuration and checkpoint handed to the predictor.
cfg_fname = '../weight/model_conf.yaml'
weight_path  = '../weight/model_weight.pt'

# Input table; the `name` and `sequence` columns are read below.
df = pd.read_csv("seqs/examples_4.csv")

seed_everything(42)
# Temperatures 273, 278, ..., 503 (steps of 5).
temp_lst = list(range(273,504,5))
# This section writes to its own directory; the cleanup hint below and the
# output paths share this single definition so they cannot drift apart.
out_root = Path("ThermalFold_prediction_disulfide")
# Uncomment the next line to discard earlier disulfide results before rerunning the prediction
#if out_root.is_dir(): shutil.rmtree(out_root)
with ESM_embedding('esm2_650M',device='cuda:0',cache_path='./esm_cache') as esm:
    predictor = thermalFold_predictor(cfg_fname=cfg_fname,weight_path=weight_path,esm=esm,device='cuda:0')
    for name,seq in tqdm(zip(df['name'],df['sequence']),total=len(df)):
        # One output directory per input row.
        tpath = out_root/f"{name}"
        tpath.mkdir(exist_ok=True,parents=True)
        for temp in temp_lst:
            tfname = tpath/f"NAME={name}:TEMP={temp}.pdb"
            # Files that already exist are left untouched, so a rerun only
            # fills in the missing temperatures.
            if tfname.is_file(): continue
            inputs = [[seq,temp]]
            # num_samples=3 is forwarded to the predictor; the first returned
            # entry is the one written to disk.
            res = predictor.predict(inputs,num_samples=3)[0]
            with open(tfname,'w') as f:f.write(res)
# Release GPU memory held by PyTorch's caching allocator.
torch.cuda.empty_cache()

Seed set to 42
Using cache found in /home/b208/.cache/torch/hub/facebookresearch_esm_main
100%|██████████| 10/10 [14:55<00:00, 89.55s/it]


In [4]:
import py3Dmol

# Predicted structure of the MiniProteinA_wt entry (TEMP=403 file).
fname = Path("ThermalFold_prediction_disulfide/MiniProteinA_wt/NAME=MiniProteinA_wt:TEMP=403.pdb")
pdb_str = fname.read_text()

view = py3Dmol.view(width=600, height=500)
view.addModel(pdb_str, 'pdb')

# Cartoon colour for each `ss` selector value, applied in this order.
ss_colors = {'h': '#c23779', 's': '#f7b731', 'c': 'grey'}
for ss_code in ss_colors:
    view.setStyle({'ss': ss_code}, {'cartoon': {'color': ss_colors[ss_code]}})

# Add sticks for every residue named CYS on top of the cartoon.
sel_loose = {'resn': ['CYS']}
stick_style = {'stick': {'colorscheme': 'element', 'radius': 0.35}}
view.addStyle(sel_loose, stick_style)

view.zoomTo()
# Leaving the view as the last expression renders it as the cell output.
view

<py3Dmol.view at 0x760cbe439050>

In [None]:
import py3Dmol

# Predicted structure of the MiniProteinA_cys entry (TEMP=403 file).
fname = Path("ThermalFold_prediction_disulfide/MiniProteinA_cys/NAME=MiniProteinA_cys:TEMP=403.pdb")
pdb_str = fname.read_text()

view = py3Dmol.view(width=600, height=500)
view.addModel(pdb_str, 'pdb')

# (ss selector value, cartoon colour) pairs, applied in this order.
for ss_code, color in (('h', '#c23779'), ('s', '#f7b731'), ('c', 'grey')):
    view.setStyle({'ss': ss_code}, {'cartoon': {'color': color}})

# Add sticks for every residue named CYS on top of the cartoon.
sel_loose = {'resn': ['CYS']}
view.addStyle(
    sel_loose,
    {'stick': {'colorscheme': 'element', 'radius': 0.35}},
)

view.zoomTo()
# Leaving the view as the last expression renders it as the cell output.
view

<py3Dmol.view at 0x760dc8897dd0>