In [None]:
import evo_prot_grad
from transformers import AutoModel
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns; 
sns.set_style("whitegrid")
# set font size seaborn
sns.set_context("notebook", font_scale=1.25)

%load_ext autoreload
%autoreload 2

## HuggingFace ESM2 8M model + Supervised fluorescence regression model

In [None]:
# Expert 1: HuggingFace ESM2 8M model, run on CPU.
esm2_expert = evo_prot_grad.get_expert(
    'esm',
    'mutant_marginal',
    temperature=1.0,
    device='cpu',
)

# Expert 2: supervised fluorescence regression model.
# NOTE(review): trust_remote_code=True allows model code shipped in the
# Hub repository to be executed locally -- confirm the repository is trusted.
gfp_model = AutoModel.from_pretrained(
    'NREL/avGFP-fluorescence-onehot-cnn',
    trust_remote_code=True,
)
gfp_expert = evo_prot_grad.get_expert(
    'onehot_downstream_regression',
    'attribute_value',
    temperature=1.0,
    model=gfp_model,
    device='cpu',
)

# Short run: 2 parallel chains, 5 steps each, combining both experts.
# The sampler object is built first and then called to execute the run.
demo_sampler = evo_prot_grad.DirectedEvolution(
    experts=[esm2_expert, gfp_expert],
    wt_fasta='test/gfp.fasta',
    output='all',
    parallel_chains=2,
    n_steps=5,
    max_mutations=15,
    verbose=False,
)
variants, scores = demo_sampler()


In [None]:
# Score trajectory of every chain (one column of `scores` per chain), plus a
# black horizontal line at each chain's maximum score.
fig, ax = plt.subplots()
step_axis = np.arange(0, scores.shape[0])
for chain_idx in range(scores.shape[1]):
    chain_trace = scores[:, chain_idx]
    ax.plot(chain_trace, alpha=0.5)
    chain_best = chain_trace[np.argmax(chain_trace)]
    ax.plot(step_axis, np.full(step_axis.shape, chain_best), c='black')

# Red reference line at 0, labelled 'wild type'.
# NOTE(review): presumably scores are relative to the wild type -- confirm.
ax.plot(step_axis, np.zeros_like(step_axis, dtype=float), c='red', linewidth=3, label='wild type')
ax.set_xlabel('MCMC step')
ax.set_ylabel('best product of experts score')
ax.legend()
plt.show()

In [None]:
def _read_first_fasta_sequence(fasta_path):
    """Return the residues of the first record in a FASTA file as one string.

    Sequence lines belonging to the first record are concatenated, so records
    wrapped over several lines are read in full. Blank lines are skipped, and
    reading stops when a second header line ('>') is reached.

    Args:
        fasta_path: path of the FASTA file to read.

    Returns:
        The first record's sequence with surrounding whitespace removed from
        each line and no separators between residues.

    Raises:
        ValueError: if the file contains no sequence lines.
    """
    pieces = []
    with open(fasta_path, 'r') as handle:
        for raw_line in handle:
            text = raw_line.strip()
            if not text:
                # A blank (e.g. trailing) line must not wipe out the sequence.
                continue
            if text.startswith('>'):
                if pieces:
                    # Header of a second record: the first record is complete.
                    break
                continue
            pieces.append(text)
    if not pieces:
        raise ValueError(f'no sequence found in FASTA file: {fasta_path}')
    return ''.join(pieces)


# Wild-type sequence with residues separated by single spaces.
wtseq = ' '.join(_read_first_fasta_sequence('test/gfp.fasta'))

# For every chain, report its highest score and print the variant sampled at
# that step next to the wild-type sequence.
n_chains = scores.shape[1]
for chain_idx in range(n_chains):
    chain_trace = scores[:, chain_idx]
    top_step = np.argmax(chain_trace)
    print(f'chain {chain_idx}, score: {chain_trace[top_step]}')
    top_variant = variants[top_step][chain_idx]
    evo_prot_grad.common.utils.print_variant_in_color(top_variant, wtseq)

## Preserve regions in the wild-type sequence

In [None]:
# Longer run (16 parallel chains, 1000 steps) with the same two experts,
# passing two index ranges as `preserved_regions`.
# NOTE(review): whether the range end points are inclusive is not visible
# here -- confirm against the DirectedEvolution documentation.
preserved_sampler = evo_prot_grad.DirectedEvolution(
    experts=[esm2_expert, gfp_expert],
    wt_fasta='test/gfp.fasta',
    preserved_regions=[(0, 13), (150, 237)],
    output='all',
    parallel_chains=16,
    n_steps=1000,
    max_mutations=15,
    verbose=False,
)
variants, scores = preserved_sampler()


In [None]:
# Per-chain score trajectories for the current `scores`, each with a black
# horizontal line at that chain's maximum.
fig, ax = plt.subplots()
n_steps_run, n_chains_run = scores.shape[0], scores.shape[1]
mcmc_steps = np.arange(0, n_steps_run)
for column in range(n_chains_run):
    trajectory = scores[:, column]
    peak_value = trajectory[np.argmax(trajectory)]
    ax.plot(trajectory, alpha=0.5)
    ax.plot(mcmc_steps, np.full((n_steps_run,), peak_value), c='black')

# Red reference line at 0, labelled 'wild type'.
ax.plot(mcmc_steps, np.zeros((n_steps_run,)), c='red', linewidth=3, label='wild type')
ax.set_xlabel('MCMC step')
ax.set_ylabel('best product of experts score')
ax.legend()
plt.show()

# Report each chain's highest score and print the corresponding variant
# against `wtseq` (defined in an earlier cell).
for chain_no in range(scores.shape[1]):
    per_chain = scores[:, chain_no]
    peak_step = np.argmax(per_chain)
    print(f'chain {chain_no}, score: {per_chain[peak_step]}')
    evo_prot_grad.common.utils.print_variant_in_color(variants[peak_step][chain_no], wtseq)