In [1]:
# Import packages
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data_utils
from torch.utils.data import Dataset, DataLoader
import torch.distributed as dist
import torchmetrics
import pytorch_lightning as pl
from pytorch_lightning.loggers import CSVLogger
from collections import OrderedDict
from torchtext import vocab # This package can give problems sometimes, it may be necessary to downgrade to a specific version
import seaborn as sns
import random
from random import choice
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.patches as mpatches
from sklearn import metrics
import os
import pickle
from transformers import AutoModelForMaskedLM, AutoTokenizer
import itertools
import copy
import warnings
import optuna
import logging
import sys
from torch_ema import ExponentialMovingAverage
import gc

In [2]:
# Import helper scripts
from functions import (load_reward_model, identify_mutations_and_count, generate_df, generate_and_evaluate_mutants, mutate_sequences_after_training, mutate_sequences_after_training_esm2_max_sampling)
from dataloading_RLXF_ESM2_DDP import (ProtDataModuleESM2, ProtRepDatasetESM2)
from PPO_with_psampling_and_model_saving import RLXF_PPO_ESM2
from MLP import MLP

In [3]:
# Amino-acid tokenizer built with torchtext: the 20 canonical residues are mapped to
# indices 0-19 in the order they appear in `AAs`. Usage: aa2ind(list(aa_sequence)).
AAs = 'ACDEFGHIKLMNPQRSTVWY'
residue_counts = OrderedDict.fromkeys(AAs, 1)  # count 1 per residue so none is dropped by the default min_freq
aa2ind = vocab.vocab(residue_counts)
# Characters outside `AAs` are looked up as index 20, one past the last residue index.
# NOTE(review): the original comment called this the "gap" index, but no gap token is in the
# vocabulary above — confirm downstream encoders reserve index 20 for unknown/gap characters.
aa2ind.set_default_index(20)

In [4]:
################################################## hyperparameters ##################################################
# model selections
model_identifier = 'esm2_t12_35M_UR50D' # esm2_t6_8M_UR50D # esm2_t12_35M_UR50D # esm2_t30_150M_UR50D # esm2_t33_650M_UR50D
num_reward_models = 100 # We have an ensemble of 100 MLP reward models
sft_model_path = None # 'SFT_ESM2_650M_with_data_v5_and_random_masking_v0.pt' # path to model to begin PPO (.pt filetype)

# learning rate hyperparameters
learning_rate = 0.0004
lr_mult = 1
lr_mult_factor = 1
warm_restart = 1 # with warm restart
use_scheduler = 1 # with scheduler

# optimizer hyperparameters
WD = 0.01
clip_type = 1 # with gradient clipping
grad_clip_threshold = 3.9369403420488362
grad_clip_threshold_factor = 1

# training hyperparameters
seed = 2549
batch_size = 1 # Only WT is loaded by the dataloader; variant designs are generated each batch
epochs = 1000
iterations = 1
num_updates = int((epochs/100)*iterations) # First restart occurs at 10 epochs (backprop will have occurred 10*iterations times)

# generating design hyperparameters
WT = 'MAGLRHTFVVADATLPDCPLVYASEGFYAMTGYGPDEVLGHNARFLQGEGTDPKEVQKIRDAIKKGEACSVRLLNYRKDGTPFWNLLTVTPIKTPDGRVSKFVGVQVDVTSKTEGKALA' # CreiLOV
num_sequences = 2 # 10 # initial batch size
inc_batch_size = 3 # increasing batch size each epoch until max_batch_size reached
max_batch_size = 2 # 20 # max batch size (dependent on GPU memory)
num_mutations = 3 # number of mutations in each generated design
high_conf_threshold = 0.9 # initial probability threshold to be considered high confidence mutation
cum_prob_threshold = 0.25 # initial cumulative probability threshold of non-WT residues to be considered candidate position to explore mutating

# model dependent hyperparameters
num_unfrozen_layers = 27 # initial number of layers of ESM2 unlocked
num_layers_unfreeze_each_epoch = 17 # numbers of layers of ESM2 to unlock each epoch until max_num_layers_unfreeze_each_epoch reached
max_num_layers_unfreeze_each_epoch = 82 # The max number of layers in ESM2 (650M) that will be aligned cannot exceed 82 -> We can go to at least 71 with bs = 10 on our GPUs @ Duke
training_pos_emb = 0 # do not train positional embeddings

# important PPO hyperparameters
average_type = 2
average_type_loss = 0
rel_to_WT = 0
epsilon = 0.25 # clipping parameter for PPO loss

# total reward hyperparameters
pairwise_hd_aver_factor = 0.01 # weight for pairwise hamming distance between generated designs each epoch
dkl_scale_init = 1e-8 # initial weight for Dkl
dkl_scale = 1e-7 # weight term for Dkl after 1st epoch

# hyperparameters regarding model saving
decay = 0.8
saving_models_threshold = 10 # required ratio of generated design fitness / predicted WT fitness to save a model; 10 effectively disables saving (previously 1.01812135525 = 4.225/4.1498)
filepath = 'toy_Aligning_ESM2_from_SFT_ESM2_with_SA_preference_data_and_psampling_PPO'

################################################## hyperparameters ##################################################

# Seed every RNG *before* any model is built or loaded, so model construction, reward-model loading
# and design sampling are all reproducible. pl.seed_everything seeds Python's `random`, NumPy and
# PyTorch (CPU and all CUDA devices).
os.environ['PYTHONHASHSEED'] = str(seed) # only inherited by child processes (e.g. DDP workers); the running kernel's hash seed is fixed at startup
pl.seed_everything(seed)

# Determine if we're training on a GPU or CPU
# NOTE(review): Apple-silicon GPUs (MPS) are detected by Lightning but not used; only CUDA is considered here.
if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True # Force cuDNN to use only deterministic algorithms (slower, but reproducible)
    torch.backends.cudnn.benchmark = False # Prevent cuDNN from autotuning with nondeterministic algorithms
    torch.set_float32_matmul_precision('medium')
    accelerator = "gpu"
    num_devices = torch.cuda.device_count() # Use all available GPUs
    # Lightning rejects strategy="ddp" inside Jupyter/IPython; "ddp_notebook" is its fork-based
    # variant for interactive sessions (CUDA must not be initialised before trainer.fit).
    interactive_session = hasattr(sys, "ps1") or bool(sys.flags.interactive)
    if num_devices > 1:
        strategy = "ddp_notebook" if interactive_session else "ddp"
    else:
        strategy = None
    print(f"Accelerator: {accelerator}, Number of devices: {num_devices}, Strategy: {strategy}")
else:
    accelerator = "cpu"
    max_threads = 16
    num_threads = min(os.cpu_count() or 1, max_threads) # Use all available CPUs up to 16; os.cpu_count() may return None
    torch.set_num_threads(num_threads) # Set the number of threads for PyTorch
    num_devices = 1 # Use the CPU
    strategy = None
    print(f"Accelerator: {accelerator}, Number of threads: {num_threads}, Strategy: {strategy}")

# Two copies of the same checkpoint: `sft_model` is the frozen reference model and
# `rl_updated_model` is the copy updated by PPO.
sft_model = AutoModelForMaskedLM.from_pretrained(f"facebook/{model_identifier}")
rl_updated_model = AutoModelForMaskedLM.from_pretrained(f"facebook/{model_identifier}")
tokenizer = AutoTokenizer.from_pretrained(f"facebook/{model_identifier}")

if sft_model_path is not None:
    # Begin PPO with 2 copies of the supervised fine-tuned model.
    # map_location='cpu' lets a checkpoint saved on a GPU load on any machine; Lightning moves the
    # models (submodules of the LightningModule) to the accelerator when training starts.
    state_dict = torch.load(sft_model_path, map_location='cpu')
    sft_model.load_state_dict(state_dict)
    rl_updated_model.load_state_dict(state_dict)
    del state_dict # drop the extra copy of the weights held in the notebook namespace
    print(f'Aligning supervised fine-tuned model from {sft_model_path}')
else:
    # Begin PPO with 2 copies of the pretrained model
    print(f'Aligning {model_identifier} model from huggingface')

# The reference model is never updated by the optimizer
for param in sft_model.parameters():
    param.requires_grad = False

# Load the ensemble of reward models and freeze them
reward_models = []
for i in range(num_reward_models):
    checkpoint_path = f"./MLP_Reward_Models/best_model_v{i}.ckpt"
    reward_model = load_reward_model(checkpoint_path)
    for param in reward_model.parameters():
        param.requires_grad = False
    reward_models.append(reward_model)

# Define logger for storing model metrics
logger = CSVLogger('logs', name=f"{filepath}")
version = logger.version

# Initialize the RLXF model
dm = ProtDataModuleESM2(WT, batch_size, seed)
model = RLXF_PPO_ESM2(model_identifier, sft_model, rl_updated_model, reward_models, tokenizer, num_reward_models, sft_model_path, # model selections
                num_unfrozen_layers, num_layers_unfreeze_each_epoch, max_num_layers_unfreeze_each_epoch, training_pos_emb, # model dependent hyperparameters
                seed, batch_size, epochs, iterations, num_updates, # training hyperparameters
                learning_rate, lr_mult, lr_mult_factor, use_scheduler, warm_restart, # learning rate hyperparameters
                WD, clip_type, grad_clip_threshold, grad_clip_threshold_factor, # optimizer hyperparameters
                WT, num_sequences, inc_batch_size, max_batch_size, num_mutations, high_conf_threshold, cum_prob_threshold, # generating design hyperparameters
                average_type_loss, average_type, rel_to_WT, epsilon, # important PPO hyperparameters
                pairwise_hd_aver_factor, dkl_scale, dkl_scale_init, # total reward hyperparameters
                decay, saving_models_threshold, filepath, version # hyperparameters regarding model saving
                     )

# Trainer setup in PyTorch Lightning
trainer = pl.Trainer(
    logger=logger,
    max_epochs=epochs,
    precision=16 if accelerator == "gpu" else 32, # Mixed precision only on GPU
    enable_progress_bar=True,
    log_every_n_steps=1,
    accelerator=accelerator,
    num_nodes=1,
    devices=num_devices,
    strategy=strategy
)

trainer.fit(model, dm)

Aligning esm2_t12_35M_UR50D model from huggingface


GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(

  | Name             | Type           | Params
----------------------------------------------------
0 | fixed_model      | EsmForMaskedLM | 34.0 M
1 | rl_updated_model | EsmForMaskedLM | 34.0 M
----------------------------------------------------
34.0 M    Trainable params
34.0 M    Non-trainable params
68.0 M    Total params
271.951   Total estimated model params size (MB)


Accelerator: cpu, Number of threads: 8, Strategy: None
Loading data to CPU


Training: 0it [00:00, ?it/s]

iteration 1
Saved heatmap for single mutant space from WT for sft model
Saved heatmap for single mutant space from WT for aligned model
Generated sequence with high confidence mutations from fixed model: {43: [('C', 0.95703125)]}
Saved heatmap for single mutant space from sequence with high-confidence mutations for sft model
Generated sequence with high confidence mutations from aligned model: {43: [('C', 0.95703125)]}
Generated ratios for high confidence mutations from aligned model


  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "/Users/nathanielblalock/miniconda3/envs/RLXF/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3508, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/var/folders/bx/wt_7f93n19z_0q92_mfgy2y80000gn/T/ipykernel_60398/635758733.py", line 147, in <module>
    trainer.fit(model, dm)
  File "/Users/nathanielblalock/miniconda3/envs/RLXF/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 579, in fit
    call._call_and_handle_interrupt(
  File "/Users/nathanielblalock/miniconda3/envs/RLXF/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py", line 54, in _call_and_handle_interrupt
    logger.finalize("failed")
KeyboardInterrupt

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/nathanielblalock/miniconda3/envs/RLXF/lib/python3.8/inspect.py", line 737, in getmodule
    file = getabsfile(object, _filename)
  File "

In [None]:
# # Save the model, appending the device name and version number to the filename
# model.save_rl_updated_esm2()
