In [1]:
# Import packages
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.distributed
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data_utils
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from pytorch_lightning.loggers import CSVLogger
from collections import OrderedDict
from torchtext import vocab # This package can give problems sometimes, it may be necessary to downgrade to a specific version
import seaborn as sns
import random
from random import choice
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from sklearn import metrics
import os
import pickle
from functions import (load_reward_model, identify_mutations_and_count, generate_df, generate_and_evaluate_mutants, generate_and_evaluate_mutants_max_sampling,
    mutate_sequences_after_training, mutate_sequences_after_training_esm2_max_sampling, get_sft_version_file)
from dataloading_RLXF_ESM2 import (ProtDataModuleESM2, ProtRepDatasetESM2)
from PPO_ESM2_650M_with_model_saving_DDP import RLXF_PPO_ESM2
from transformers import AutoModelForMaskedLM, AutoTokenizer
from MLP import MLP
import itertools
import copy
import warnings
import optuna
import logging
import sys
from optuna.exceptions import TrialPruned
from pytorch_lightning.callbacks import Callback
from esm.models.esmc import ESMC
from esm.sdk.api import ESMProtein, LogitsConfig

In [2]:
# Tokenization alphabet and the wild-type (WT) reference sequence.
AAs = 'ACDEFGHIKLMNPQRSTVWY'  # the amino-acid letters that receive their own index
WT = 'MAGLRHTFVVADATLPDCPLVYASEGFYAMTGYGPDEVLGHNARFLQGEGTDPKEVQKIRDAIKKGEACSVRLLNYRKDGTPFWNLLTVTPIKTPDGRVSKFVGVQVDVTSKTEGKALA'  # CreiLOV

# torchtext vocab that maps each letter of AAs to an index.
# Usage: aa2ind(list(AAsequence)). Every letter is entered with a count of 1.
aa_counts = OrderedDict.fromkeys(AAs, 1)
aa2ind = vocab.vocab(aa_counts)
aa2ind.set_default_index(20)  # unknown characters fall back to index 20 (treated as the gap)

# The WT length defines the context-window length used by the notebook.
sequence_length = len(WT)

In [3]:
# Load a pretrained ESM C checkpoint and keep it on the CPU.
model_identifier = 'esmc_600m'  # 'esmc_300m' is the other checkpoint name noted by the author
ESMC_model = ESMC.from_pretrained(model_identifier)
ESMC_model = ESMC_model.to("cpu")

# Run the unmasked WT through the model: wrap it, tokenize it, then request
# only the sequence-track logits.
protein = ESMProtein(sequence=WT)
protein_tensor = ESMC_model.encode(protein)
sequence_only = LogitsConfig(sequence=True)
logits_output = ESMC_model.logits(protein_tensor, sequence_only)
print(logits_output)

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

LogitsOutput(logits=ForwardTrackData(sequence=tensor([[[-22.0889, -21.9569, -22.0946,  ..., -22.0116, -22.0966, -22.0890],
         [-26.7014, -26.6767, -26.6852,  ..., -26.6674, -26.7295, -26.7216],
         [-26.7073, -26.7085, -26.7197,  ..., -26.7312, -26.7723, -26.7267],
         ...,
         [-24.4723, -24.4475, -24.4328,  ..., -24.4337, -24.5339, -24.4628],
         [-20.6860, -20.6346, -20.6401,  ..., -20.6410, -20.7196, -20.6419],
         [-22.6457, -22.6019, -22.6157,  ..., -22.6052, -22.6929, -22.6211]]]), structure=None, secondary_structure=None, sasa=None, function=None), embeddings=None, residue_annotation_logits=None)


In [4]:
# Debugging aid: list the attribute names available on the logits result object.
print(dir(logits_output))


['__annotations__', '__attrs_attrs__', '__attrs_own_setattr__', '__class__', '__delattr__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__match_args__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__setstate__', '__sizeof__', '__slots__', '__str__', '__subclasshook__', '__weakref__', 'embeddings', 'logits', 'residue_annotation_logits']


In [5]:
# Shape of the sequence logits for the unmasked WT; the recorded output is
# torch.Size([1, 121, 64]).
logits_output.logits.sequence.shape

torch.Size([1, 121, 64])

In [6]:
# Display the repr of an ESMProtein built from WT. The instance is not bound
# to a name, so this cell has no effect on later cells.
ESMProtein(sequence=WT)

ESMProtein(sequence='MAGLRHTFVVADATLPDCPLVYASEGFYAMTGYGPDEVLGHNARFLQGEGTDPKEVQKIRDAIKKGEACSVRLLNYRKDGTPFWNLLTVTPIKTPDGRVSKFVGVQVDVTSKTEGKALA', secondary_structure=None, sasa=None, function_annotations=None, coordinates=None, plddt=None, ptm=None, potential_sequence_of_concern=False)

In [7]:
# Run configuration
num_EnsMLPs = 100  # number of reward models in the ensemble
num_designs = 1000
seed = 7028


def _freeze(model):
    """Turn off gradient tracking on every parameter of `model` and return it."""
    for weight in model.parameters():
        weight.requires_grad = False
    return model


# Load each ensemble checkpoint from ./MLP_Reward_Models and freeze it.
reward_models = []
for i in range(num_EnsMLPs):
    model_name = f"best_model_v{i}.ckpt"
    checkpoint_path = f"./MLP_Reward_Models/{model_name}"
    reward_model = _freeze(load_reward_model(checkpoint_path))
    reward_models.append(reward_model)

In [8]:
# Replace one residue of WT with the "<mask>" token so the model can be asked
# to predict it.
masked_position = 6  # 0-based index into WT of the residue to mask

# Later cells read the logits row at masked_position + 1. A negative index
# would wrap around here and then point at the wrong logits row, so reject
# anything outside the sequence up front.
if not 0 <= masked_position < len(WT):
    raise IndexError(f"masked_position {masked_position} is outside WT (length {len(WT)})")

wt_residues = list(WT)  # strings are immutable, so edit a list copy of WT
wt_residues[masked_position] = "<mask>"  # the literal mask token, not a single character
masked_sequence = ''.join(wt_residues)  # back to one string for ESMProtein

In [9]:
# Wrap the masked sequence in an ESMProtein. This rebinds `protein`, which an
# earlier cell bound to the unmasked WT.
protein = ESMProtein(sequence=masked_sequence)
protein  # bare expression so the notebook displays the repr

ESMProtein(sequence='MAGLRH<mask>FVVADATLPDCPLVYASEGFYAMTGYGPDEVLGHNARFLQGEGTDPKEVQKIRDAIKKGEACSVRLLNYRKDGTPFWNLLTVTPIKTPDGRVSKFVGVQVDVTSKTEGKALA', secondary_structure=None, sasa=None, function_annotations=None, coordinates=None, plddt=None, ptm=None, potential_sequence_of_concern=False)

In [10]:
# Tokenize the masked protein with the loaded ESMC model (rebinds `protein_tensor`).
# In the recorded output the token sequence starts with 0 and ends with 2, and
# the masked residue shows up as token 32 at tensor index masked_position + 1.
protein_tensor = ESMC_model.encode(protein)
protein_tensor  # bare expression so the notebook displays the repr

ESMProteinTensor(sequence=tensor([ 0, 20,  5,  6,  4, 10, 21, 32, 18,  7,  7,  5, 13,  5, 11,  4, 14, 13,
        23, 14,  4,  7, 19,  5,  8,  9,  6, 18, 19,  5, 20, 11,  6, 19,  6, 14,
        13,  9,  7,  4,  6, 21, 17,  5, 10, 18,  4, 16,  6,  9,  6, 11, 13, 14,
        15,  9,  7, 16, 15, 12, 10, 13,  5, 12, 15, 15,  6,  9,  5, 23,  8,  7,
        10,  4,  4, 17, 19, 10, 15, 13,  6, 11, 14, 18, 22, 17,  4,  4, 11,  7,
        11, 14, 12, 15, 11, 14, 13,  6, 10,  7,  8, 15, 18,  7,  6,  7, 16,  7,
        13,  7, 11,  8, 15, 11,  9,  6, 15,  5,  4,  5,  2]), structure=None, secondary_structure=None, sasa=None, function=None, residue_annotations=None, coordinates=None, potential_sequence_of_concern=False)

In [11]:
# Forward pass on the masked protein, requesting only the sequence track.
logits_output = ESMC_model.logits(protein_tensor, LogitsConfig(sequence=True))

# Drop the leading batch axis, then pick the row belonging to the masked residue.
# The +1 skips the extra token that encode() places before the first residue.
sequence_logits = logits_output.logits.sequence.squeeze(0)
masked_position_logits = sequence_logits[masked_position + 1]
probs = torch.softmax(masked_position_logits, dim=-1)  # normalize over the vocabulary axis

# Report the masked input and the distribution predicted for it.
print("Masked Sequence:", masked_sequence)
print("Probabilities for masked position:", probs)

Masked Sequence: MAGLRH<mask>FVVADATLPDCPLVYASEGFYAMTGYGPDEVLGHNARFLQGEGTDPKEVQKIRDAIKKGEACSVRLLNYRKDGTPFWNLLTVTPIKTPDGRVSKFVGVQVDVTSKTEGKALA
Probabilities for masked position: tensor([4.4926e-21, 4.7351e-21, 4.6001e-21, 3.1521e-07, 9.0963e-04, 4.9500e-02,
        3.4532e-02, 4.2323e-03, 5.8087e-01, 4.6608e-03, 2.9080e-03, 7.5115e-02,
        9.6038e-04, 1.9427e-02, 7.8629e-03, 1.2239e-03, 4.8081e-03, 1.1435e-01,
        1.8392e-03, 3.4115e-03, 1.6124e-03, 1.9138e-02, 1.2737e-03, 7.1365e-02,
        7.6950e-07, 2.3488e-11, 2.9986e-08, 1.0042e-12, 2.9761e-13, 4.1920e-21,
        4.4152e-21, 4.8907e-21, 4.5065e-21, 4.6019e-21, 4.5237e-21, 4.9978e-21,
        4.7506e-21, 5.0350e-21, 4.9012e-21, 4.8875e-21, 4.7245e-21, 4.7894e-21,
        4.3468e-21, 4.8849e-21, 4.6155e-21, 4.6724e-21, 4.7003e-21, 4.6213e-21,
        4.7450e-21, 4.6686e-21, 4.6493e-21, 4.4516e-21, 5.9318e-21, 4.7724e-21,
        4.8365e-21, 4.6127e-21, 4.8750e-21, 5.0812e-21, 5.0008e-21, 4.8811e-21,
        4.3710e-21, 4.9

In [12]:
# Token table in index order. The ids agree with the recorded encode() output
# earlier in this notebook (e.g. M -> 20, A -> 5, <mask> -> 32).
SEQUENCE_VOCAB = [
    # special tokens
    "<cls>", "<pad>", "<eos>", "<unk>",
    # residue letters
    "L", "A", "G", "V", "S", "E", "R", "T",
    "I", "D", "P", "K", "Q", "N", "F", "Y",
    "M", "H", "W", "C", "X", "B", "U", "Z", "O",
    # remaining symbols and the mask token
    ".", "-", "|", "<mask>",
]

# index -> token lookup
token_dict = dict(enumerate(SEQUENCE_VOCAB))

# Show the whole table, one entry per line.
for index, token in enumerate(SEQUENCE_VOCAB):
    print(f"Index {index}: Token {token}")

Index 0: Token <cls>
Index 1: Token <pad>
Index 2: Token <eos>
Index 3: Token <unk>
Index 4: Token L
Index 5: Token A
Index 6: Token G
Index 7: Token V
Index 8: Token S
Index 9: Token E
Index 10: Token R
Index 11: Token T
Index 12: Token I
Index 13: Token D
Index 14: Token P
Index 15: Token K
Index 16: Token Q
Index 17: Token N
Index 18: Token F
Index 19: Token Y
Index 20: Token M
Index 21: Token H
Index 22: Token W
Index 23: Token C
Index 24: Token X
Index 25: Token B
Index 26: Token U
Index 27: Token Z
Index 28: Token O
Index 29: Token .
Index 30: Token -
Index 31: Token |
Index 32: Token <mask>


In [13]:
from esm.sdk.forge import ESM3ForgeInferenceClient


In [14]:
# Switch to the hosted ESM C 6B model through the Forge API. This rebinds
# `ESMC_model`, so any cell run after this one talks to the remote model.
#
# The API token is read from the environment instead of being written into the
# notebook: a token committed in a notebook is exposed to everyone who can read
# the file. Any token that was previously committed here should be revoked.
forge_token = os.environ.get("ESM_API_KEY")
if not forge_token:
    raise RuntimeError(
        "Set the ESM_API_KEY environment variable to your Forge API token before running this cell."
    )
ESMC_model = ESM3ForgeInferenceClient(model="esmc-6b-2024-12", url="https://forge.evolutionaryscale.ai", token=forge_token)
protein_tensor = ESMC_model.encode(protein)

# Get logits
logits_output = ESMC_model.logits(protein_tensor, LogitsConfig(sequence=True))

# Index logits for the masked position
sequence_logits = logits_output.logits.sequence.squeeze(0)  # Remove batch dimension
masked_position_logits = sequence_logits[masked_position+1]  # +1 skips the token placed before the first residue
probs = torch.nn.functional.softmax(masked_position_logits, dim=-1)

# NOTE(review): in the recorded run of this cell most of the probability mass
# sits on indices 3 and 24, which the vocab table in this notebook lists as
# <unk> and X, rather than on a residue. Confirm that the remote model treats
# the "<mask>" token the same way as the local checkpoint.
# Output results
print("Masked Sequence:", masked_sequence)
print("Probabilities for masked position:", probs)

Masked Sequence: MAGLRH<mask>FVVADATLPDCPLVYASEGFYAMTGYGPDEVLGHNARFLQGEGTDPKEVQKIRDAIKKGEACSVRLLNYRKDGTPFWNLLTVTPIKTPDGRVSKFVGVQVDVTSKTEGKALA
Probabilities for masked position: tensor([9.8710e-16, 9.8710e-16, 9.8710e-16, 5.0826e-01, 1.7666e-02, 1.1676e-02,
        1.3545e-02, 1.0467e-02, 1.3335e-02, 2.5108e-02, 4.8776e-02, 5.9404e-03,
        4.0195e-03, 7.3354e-03, 5.7802e-03, 6.3235e-03, 1.2237e-02, 3.2806e-03,
        5.7017e-03, 3.3290e-03, 1.4055e-03, 7.4219e-03, 4.6627e-03, 1.1676e-02,
        2.7205e-01, 3.4011e-10, 1.5750e-12, 1.6066e-10, 3.1014e-13, 9.8710e-16,
        9.8710e-16, 9.8710e-16, 9.8710e-16, 9.8710e-16, 9.8710e-16, 9.8710e-16,
        9.8710e-16, 8.7111e-16, 9.8710e-16, 9.8710e-16, 9.8710e-16, 9.8710e-16,
        8.7111e-16, 9.8710e-16, 9.8710e-16, 9.8710e-16, 9.8710e-16, 9.8710e-16,
        9.8710e-16, 9.8710e-16, 8.7111e-16, 9.8710e-16, 9.8710e-16, 9.8710e-16,
        9.8710e-16, 9.8710e-16, 9.8710e-16, 9.8710e-16, 9.8710e-16, 9.8710e-16,
        9.8710e-16, 9.8