In [1]:
import esm
import sys, os
import pandas as pd
import numpy as np

import torch
import random

file_path = "../model"
sys.path.append(file_path)
from dictionary import AutoEncoder

### Load ESM and SAE

In [None]:
# Use the first GPU when available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Pretrained ESM-2 model and its alphabet, put in inference mode on `device`.
esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
esm_model.eval()
esm_model = esm_model.to(device)
batch_converter = alphabet.get_batch_converter()

# MotifAE checkpoint, put in inference mode on `device`.
chk_path = '/path/to/MotifAE_step_80000.pt' # please download this file from zenodo: https://zenodo.org/records/17488191
motifae = AutoEncoder.from_pretrained(chk_path)
motifae.eval()
motifae = motifae.to(device)

# map_location lets a tensor that was saved from a GPU be loaded on a CPU-only machine,
# which the CPU fallback above would otherwise not survive.
gate = torch.load(
    '/path/to/1404_stability_associated_features.pt',
    weights_only=True,
    map_location=device,
).to(device) # download from zenodo: https://zenodo.org/records/17488191

# set activated gate to n, others to 1
# (assumes the loaded gate is a 0/1 mask over the MotifAE features: 1 -> gate_n, 0 -> 1 — TODO confirm)
gate_n = 4
gate = gate*(gate_n-1) + 1

  state_dict = t.load(path)


### Functions for protein design

In [6]:
def get_per_site_prob(esm_model, seq, motifae, gate):
    """Score every candidate residue at every position relative to the current residue.

    The sequence is passed through ``esm_model``; its layer-33 representation is
    encoded with ``motifae``, the resulting features are multiplied by ``gate``,
    decoded back, and fed to ``esm_model.lm_head``. The softmax of those logits is
    divided, per position, by the probability assigned to the residue currently
    present in ``seq``.

    Relies on the notebook-level ``batch_converter``, ``device`` and ``alphabet``.

    Parameters
    ----------
    esm_model : model exposing ``__call__(tokens, repr_layers=..., return_contacts=...)``
        and ``lm_head``.
    seq : str
        Amino-acid sequence; each character must be a token in ``alphabet.all_toks``.
    motifae : object exposing ``encode`` and ``decode``.
    gate : tensor broadcastable against the output of ``motifae.encode``.

    Returns
    -------
    pandas.DataFrame
        One row per (position, candidate token) pair, ordered by position, with
        columns ``wt`` (residue in ``seq``), ``mt`` (candidate token), ``p``
        (probability of ``mt`` divided by probability of ``wt``) and ``site``
        (0-based position in ``seq``).

    Raises
    ------
    KeyError
        If ``seq`` contains a character that is not in ``alphabet.all_toks``.
    ValueError
        If the model output does not line up with ``seq`` / ``alphabet.all_toks``.
    """
    batch_labels, batch_strs, batch_tokens = batch_converter([('_', seq)])
    batch_tokens = batch_tokens.to(device)

    with torch.no_grad():
        results = esm_model(batch_tokens, repr_layers=[33], return_contacts=False)
        embed = results["representations"][33]
        f = motifae.encode(embed)
        f = gate * f
        embed = motifae.decode(f)

        logits = esm_model.lm_head(embed)

    # Drop the first and last token positions (presumably the BOS/EOS tokens added
    # by the batch converter — TODO confirm) so rows line up with `seq`.
    probs = torch.softmax(logits[0, 1:-1, :], dim=-1).cpu().numpy()

    all_toks = list(alphabet.all_toks)
    wt = list(seq)
    if probs.shape != (len(wt), len(all_toks)):
        raise ValueError(
            f"probability matrix has shape {probs.shape}, "
            f"expected ({len(wt)}, {len(all_toks)})"
        )

    # Probability of the residue currently present at each position, gathered
    # directly instead of materialising an L x L frame and taking its diagonal.
    tok_to_col = {tok: col for col, tok in enumerate(all_toks)}
    wt_cols = np.array([tok_to_col[aa] for aa in wt], dtype=int)
    wt_prob = probs[np.arange(len(wt)), wt_cols]

    # Vocabulary entries 4..23 are the candidate tokens (assumed to be the 20
    # standard amino acids in the ESM alphabet — TODO confirm).
    mt_tokens = all_toks[4:24]
    n_mt = len(mt_tokens)
    with np.errstate(divide='ignore', invalid='ignore'):
        relative = probs[:, 4:24] / wt_prob[:, None]

    # Build the long table with explicit site indices, so `site` stays correct
    # regardless of any NaN entries (the positional `index // 20` shortcut would
    # shift if rows were dropped while reshaping).
    return pd.DataFrame({
        'wt': [aa for aa in wt for _ in range(n_mt)],
        'mt': mt_tokens * len(wt),
        'p': relative.reshape(-1),
        'site': np.repeat(np.arange(len(wt)), n_mt),
    })

def get_mut_candidate(esm_model, seq, motifae, gate, n_candidates=11, rng=None):
    """Sample one point mutation for ``seq`` from the top-scoring substitutions.

    Scores come from ``get_per_site_prob``. Rows where the candidate equals the
    current residue are discarded, the remaining rows are ranked by ``p``
    (descending), the best ``n_candidates`` are kept, and one of them is drawn
    with probability proportional to ``p``.

    Parameters
    ----------
    esm_model, seq, motifae, gate
        Forwarded unchanged to ``get_per_site_prob``.
    n_candidates : int, default 11
        Number of top-ranked substitutions to sample from. The default of 11
        reproduces the previous ``.loc[:10]`` slice, which is label-inclusive and
        therefore kept 11 rows.
    rng : object with a ``choices`` method, optional
        Source of randomness, e.g. ``random.Random(seed)``. Defaults to the
        global ``random`` module.

    Returns
    -------
    pandas.Series
        The sampled candidate with fields ``wt``, ``mt``, ``p``, ``site``,
        ``p_norm`` (``p`` normalised over the kept candidates) and ``mut_seq``
        (``seq`` with the substitution applied).
    """
    relative_prob = get_per_site_prob(esm_model, seq, motifae, gate)

    is_substitution = relative_prob['wt'] != relative_prob['mt']
    # .copy() so the column assignment below acts on an independent frame rather
    # than on a possible view of `relative_prob` (avoids SettingWithCopyWarning).
    mut_candidate = (
        relative_prob[is_substitution]
        .sort_values('p', ascending=False)
        .reset_index(drop=True)
        .head(n_candidates)
        .copy()
    )
    mut_candidate['p_norm'] = mut_candidate['p'] / mut_candidate['p'].sum()

    sampler = random if rng is None else rng
    sampled_mut = sampler.choices(
        list(mut_candidate.index),
        weights=mut_candidate['p_norm'].tolist(),
        k=1,
    )[0]

    # Attach the mutated sequence to the chosen row only, instead of adding a
    # mostly-empty column to the whole candidate table.
    chosen = mut_candidate.loc[sampled_mut].copy()
    mut_seq = list(seq)
    mut_seq[int(chosen['site'])] = chosen['mt']
    chosen['mut_seq'] = ''.join(mut_seq)

    return chosen

### test set representative proteins

In [9]:
# Protein metadata table; keep only test-split rows whose cluster representative
# is the protein itself.
pro = pd.read_csv('../data/412pro_info.csv')
in_test_split = pro['split'] == 'test'
is_cluster_representative = pro['cluster_representative'] == pro['WT_name']
pro_test = pro.loc[in_test_split & is_cluster_representative].reset_index(drop=True)

In [None]:
i = 0
# NOTE: this rebinds `pro` from the metadata DataFrame loaded earlier to the
# wild-type name (a string); re-run the loading cell before reusing the table.
pro, seq = pro_test.loc[i, 'WT_name'], pro_test.loc[i, 'aa_seq']

# Mutations are drawn with `random.choices`, so fix the seed to make the
# trajectories reproducible under Restart & Run All.
SEED = 0
random.seed(SEED)

# Key 0 holds the starting sequence; '<rep>_<round>' holds the mutation sampled
# in that round of that trajectory.
design_record = {0: {'mut_seq': seq}}
for rep in range(2): # two independent design trajectories
    for r in range(1, 5): # four rounds of design
        # Round 1 starts from the wild type; later rounds build on the previous round.
        prev_key = f'{rep}_{r-1}' if r > 1 else 0
        design_record[f'{rep}_{r}'] = get_mut_candidate(
            esm_model, design_record[prev_key]['mut_seq'], motifae, gate
        )

design_record = pd.DataFrame(design_record).T

In [11]:
# Show the design table: row 0 is the starting sequence, rows '<rep>_<round>'
# are the mutations sampled in each round of each trajectory.
design_record

Unnamed: 0,mt,mut_seq,p,p_norm,site,wt
0,,KVTIVVENIKVFGEDGKLTDEARRLLEKALEEAKRFPGKEVEIVLLP,,,,
0_1,Y,KVTIVVENIKVFGEDGKLTDEARRLLEKALEEAKRFYGKEVEIVLLP,13.534982,0.233103,36.0,P
0_2,M,MVTIVVENIKVFGEDGKLTDEARRLLEKALEEAKRFYGKEVEIVLLP,131.699371,0.949434,0.0,K
0_3,L,MVTIVVENIKVFGEDGKLTDEARRLLEKALEEAKRLYGKEVEIVLLP,0.793624,0.111265,35.0,F
0_4,K,MVKIVVENIKVFGEDGKLTDEARRLLEKALEEAKRLYGKEVEIVLLP,2.683129,0.444825,2.0,T
1_1,Y,KVTIVVENIKVFGEDGKLTDEARRLLEKALEEAKRFYGKEVEIVLLP,13.534982,0.233103,36.0,P
1_2,M,MVTIVVENIKVFGEDGKLTDEARRLLEKALEEAKRFYGKEVEIVLLP,131.699371,0.949434,0.0,K
1_3,I,MVTIVVENIKVFGEDGKLTDEARRLLEKAIEEAKRFYGKEVEIVLLP,0.255718,0.035851,29.0,L
1_4,K,MVTIVVENIKVFGEDGKLTDEARRLLEKAIEEAKRKYGKEVEIVLLP,0.955164,0.158452,35.0,F
