In [1]:
import torch
import torch.nn as nn
from transformers import EsmTokenizer, EsmForSequenceClassification, EsmModel, EsmConfig
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

import numpy as np
import pandas as pd
import random
import os
import time

In [2]:
# Check if CUDA is available and set PyTorch to use GPU or CPU accordingly
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
model_name = "../models/esm2_650M"
tokenizer = EsmTokenizer.from_pretrained(model_name)

In [4]:
from peft import PeftModel, PeftConfig
config = PeftConfig.from_pretrained('../models/esm2_650M_LORA_SEQ_CLS_0.99')
model = EsmForSequenceClassification.from_pretrained("../models/esm2_650M", num_labels=2)
 
model = PeftModel.from_pretrained(model, '../models/esm2_650M_LORA_SEQ_CLS_0.99')
model.eval()
model.to(device)

Some weights of the model checkpoint at ../models/esm2_650M were not used when initializing EsmForSequenceClassification: ['lm_head.bias', 'lm_head.decoder.weight', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.dense.weight', 'lm_head.dense.bias']
- This IS expected if you are initializing EsmForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing EsmForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at ../models/esm2_650M and are newly initialized: ['classifier.out_proj.weight', 'classifier.dense.weight', 'classifier.dense.bias', 'classifier.out_p

PeftModelForSequenceClassification(
  (base_model): LoraModel(
    (model): EsmForSequenceClassification(
      (esm): EsmModel(
        (embeddings): EsmEmbeddings(
          (word_embeddings): Embedding(33, 1280, padding_idx=1)
          (dropout): Dropout(p=0.0, inplace=False)
          (position_embeddings): Embedding(1026, 1280, padding_idx=1)
        )
        (encoder): EsmEncoder(
          (layer): ModuleList(
            (0-32): 33 x EsmLayer(
              (attention): EsmAttention(
                (self): EsmSelfAttention(
                  (query): Linear(
                    in_features=1280, out_features=1280, bias=True
                    (lora_dropout): ModuleDict(
                      (default): Dropout(p=0.6, inplace=False)
                    )
                    (lora_A): ModuleDict(
                      (default): Linear(in_features=1280, out_features=48, bias=False)
                    )
                    (lora_B): ModuleDict(
                      (default)

In [5]:
def get_fasta_dict(fasta_file):
    fasta_dict = {}
    with open(fasta_file, 'r') as infile:
        for line in infile:
            if line.startswith(">"):
                head = line.replace("\n", "").replace(">", "")
                fasta_dict[head] = ''
            else:
                fasta_dict[head] += line.replace("\n", "")
    return fasta_dict

## Genetic Algorithms
<font size=3>We employ property-based model to filter sequences within the positive range as preferred offsprings  
<font size=3>It takes about 19 minutes to generate sequences.

In [6]:
test_dict=get_fasta_dict('../database/LBD.fasta')
initial_pseqs=[]
initial_nseqs=[]
for i in test_dict.keys():
    if i.split('|')[-1] == 'positive':
        initial_pseqs.append(test_dict[i])
    else:
        initial_nseqs.append(test_dict[i])
print(len(initial_pseqs),len(initial_nseqs))

10 20


In [7]:
initial_pt = tokenizer(initial_pseqs, return_tensors='pt', padding="max_length", truncation=True, max_length=24).to("cuda")
initial_nt = tokenizer(initial_nseqs, return_tensors='pt', padding="max_length", truncation=True, max_length=24).to("cuda")
p_outputs=model.esm(**initial_pt,output_attentions=True,output_hidden_states=True)
n_outputs=model.esm(**initial_nt,output_attentions=True,output_hidden_states=True)
initial_pe=p_outputs.last_hidden_state.mean(1)
initial_ne=n_outputs.last_hidden_state.mean(1)

In [8]:
amino_acid = {
    4: 'L', 
    5: 'A', 
    6: 'G', 
    7: 'V', 
    8: 'S', 
    9: 'E', 
    10: 'R', 
    11: 'T', 
    12: 'I', 
    13: 'D',
    14: 'P', 
    15: 'K', 
    16: 'Q', 
    17: 'N', 
    18: 'F', 
    19: 'Y', 
    20: 'M', 
    21: 'H', 
    22: 'W', 
    23: 'C'}
hydrophobicity={
    4: 1.700,
    5: 0.310,
    6: 0.,
    7: 1.220,
    8: -0.040,
    9: -0.640,
    10: -1.010,
    11: 0.260,
    12: 1.800,
    13: -0.770,
    14: 0.720,
    15: -0.990,
    16: -0.220,
    17: -0.600,
    18: 1.790,
    19: 0.960,
    20: 1.230,
    21: 0.130,
    22: 2.250,
    23: 1.540
}
dic_new = dict(zip(amino_acid.values(), amino_acid.keys()))

In [9]:
seq_ids=[]
for seq,ids in dic_new.items():
    seq_ids.append(ids)
print(min(seq_ids),max(seq_ids))

4 23


In [10]:
def filter_func(x):
    """Use the fine-tuned model to get the index of the physicochemistry properties in changes
        First, wrap the input tokens
        Then, use the fine-tuned model to predict the physicochemistry properties
        Finally, filter the in distribution embeddings and return the index
    """
    new_population=add(x)
    #Wrap the input tokens
    attention_mask=np.ones((20,24))
    attention_mask=attention_mask.astype(int)
    new_population=torch.tensor(new_population).to(device)
    attention_mask=torch.tensor(attention_mask).to(device)
    initial_nt_change={}
    initial_nt_change={'input_ids':new_population,'attention_mask':attention_mask}
    #Extract physicochemical properties in embeddings
    with torch.no_grad():
        outputs = model(**initial_nt_change)
        logits = outputs.logits
        predictions = torch.argmax(logits, dim=1)
        predictions=predictions.cpu()
    index=np.where(predictions==0)
    
    return index[0]

def add(pop):
    """Add special tokens and Cys token to data"""
    new_population=list()
    for i in pop:
        pop=np.insert(i,0,[0,23])
        pop=np.insert(pop,len(pop),[23,2])
        new_population.append(pop)
    new_population=np.array(new_population)
    return new_population

def crossover(parents, offspring_size):
    offspring = np.empty(offspring_size)
    # The point at which crossover takes place between two parents. Usually, it is at the center.
    crossover_point = np.uint8(offspring_size[1]/2)

    for k in range(offspring_size[0]):
        # Index of the first parent to mate.
        parent1_idx = k%parents.shape[0]
        # Index of the second parent to mate.
        parent2_idx = (k+1)%parents.shape[0]
        # The new offspring will have its first half of its genes taken from the first parent.
        offspring[k, 0:crossover_point] = parents[parent1_idx, 0:crossover_point]
        # The new offspring will have its second half of its genes taken from the second parent.
        offspring[k, crossover_point:] = parents[parent2_idx, crossover_point:]
    return offspring

def mutation(offspring_crossover, num_mutations=1):
    mutations_counter = np.uint8(offspring_crossover.shape[1] / num_mutations)
    # Mutation changes a number of genes as defined by the num_mutations argument. The changes are random.
    for idx in range(offspring_crossover.shape[0]):
        gene_idx = mutations_counter - 1
        for mutation_num in range(num_mutations):
            # The random value to be added to the gene.
            random_value = np.random.randint(min(seq_ids),max(seq_ids),1)
            offspring_crossover[idx, gene_idx] = random_value
            gene_idx = gene_idx + mutations_counter
    return offspring_crossover

In [11]:
equation_inputs=initial_pt['input_ids']
sol_per_pop = len(initial_nt['input_ids'])
num_parents_mating = 8
population=initial_nt['input_ids'][:,2:-2].cpu().numpy()
num_weights=len(population)
pop_size=(sol_per_pop,num_weights)

In [None]:
pp_in=np.zeros((1,len(population[0])))
num_generations = 10000
for generation in tqdm(range(num_generations)):
    time.sleep(0.05)
    # Measuring the fitness of each chromosome in the population.
    index = filter_func(population)
    population_filter=list()
    parents=list()
    for i,token in enumerate(population):
        if i in index:
            population_filter.append(token)
        else:
            parents.append(token)
    
    population_filter=np.array(population_filter)
    parents=np.array(parents)
    if population_filter.size > 0:
        for row in population_filter:
            if not np.any(np.all(pp_in == row, axis=0)):
                pp_in = np.vstack((pp_in, row))
    
    # Generating next generation using crossover.
    if population_filter.size > 0:
        offspring_crossover = crossover(population_filter,
                                       offspring_size=(pop_size[0]-parents.shape[0], num_weights))
    else:
        offspring_crossover = crossover(parents,
                                       offspring_size=(pop_size[0]-parents.shape[0], num_weights))

    # Adding some variations to the offspring using mutation.
    offspring_mutation = mutation(offspring_crossover, num_mutations=3)

    # Creating the new population based on the parents and offspring.
    population[0:parents.shape[0], :] = parents
    population[parents.shape[0]:, :] = offspring_mutation

  0%|          | 0/10000 [00:00<?, ?it/s]

In [None]:
pp_in=pp_in.copy()
pp_in=np.delete(pp_in,slice(0,11),axis=0)
new_population=list()
for i in pp_in:
    pp_in=np.insert(i,0,23)
    pp_in=np.insert(pp_in,len(pp_in),23)
    new_population.append(pp_in)
new_population=np.array(new_population)

In [None]:
# f1=open('./ouputs/LBD_mutation.fasta','w')
# for i,token in enumerate(new_population):
#     token=token.astype(int)
#     seq=''
#     for j in token:
#         seq+=''.join(amino_acid[j])
#     f1.write('>'+str(i)+'\n'+seq+'\n')
# f1.close()

## filter the mutations

In [None]:
from typing import List, Tuple, Optional, Dict, NamedTuple, Union, Callable
import esm
import torch.nn.functional as F

In [None]:
def symmetrize(x):
    "Make layer symmetric in final two dimensions, used for contact prediction."
    return x + x.transpose(-1, -2)

def apc(x):
    "Perform average product correct, used for contact prediction."
    a1 = x.sum(-1, keepdims=True)
    a2 = x.sum(-2, keepdims=True)
    a12 = x.sum((-1, -2), keepdims=True)

    avg = a1 * a2
    avg.div_(a12)  # in-place to reduce memory
    normalized = x - avg
    return normalized

In [None]:
#Adapted from https://github.com/facebookresearch/esm/blob/main/esm/modules.py
class AttentionLogisticRegression(nn.Module):
    """Performs symmetrization, apc, and computes a logistic regression on the output features"""

    def __init__(
        self,
        in_features:int,
        prepend_bos: bool,
        append_eos: bool,
        bias=True,
        eos_idx: Optional[int] = None,
    ):
        super().__init__()
        self.in_features=in_features
        self.prepend_bos = prepend_bos
        self.append_eos = append_eos
        if append_eos and eos_idx is None:
            raise ValueError("Using an alphabet with eos token, but no eos token was passed in.")
        self.eos_idx = eos_idx
        self.regression = nn.Linear(in_features, 1, bias)
        self.activation = nn.Sigmoid()
    
    def forward(self, tokens, attentions):
        # remove eos token attentions
        if self.append_eos:
            eos_mask = tokens.ne(self.eos_idx).to(attentions)
            eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
            attentions = attentions * eos_mask[:, None, None, :, :]
            attentions = attentions[..., :-1, :-1]
        # remove cls token attentions
        if self.prepend_bos:
            attentions = attentions[..., 1:, 1:]
        batch_size, layers, heads, seqlen, _ = attentions.size()
        attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen)
        attentions = attentions.to(self.regression.weight.device)  # attentions always float32, may need to convert to float16
        attentions= apc(symmetrize(attentions))
        attentions = attentions.permute(0, 2, 3, 1)
        
        return self.activation(self.regression(attentions).squeeze(3))

In [None]:
mymodel=torch.load('../models/contact-based model.pt')
mymodel.eval()

### Get bceloss

In [None]:
def get_pred_contact(attention):
    attention=torch.where(attention < 0.9, torch.tensor(0), torch.tensor(1))
    
    return attention

In [None]:
def get_bceloss(test_output):
    loss_list=[]
    for number,test in enumerate(test_output):
        new_pred=get_pred_contact(test)
        loss=criterion(torch.tensor(lbdb,dtype=torch.float32),torch.tensor(new_pred,dtype=torch.float32))
        loss_list.append(loss)
    
    return loss_list

In [None]:
criterion = nn.BCELoss()

In [None]:
lbdin_dict=get_fasta_dict('../database/LBD.fasta')
lbdin_seqs=[]
for header,seq in lbdin_dict.items():
    lbdin_seqs.append(seq)

lbdin_inputs = tokenizer(lbdin_seqs, return_tensors='pt', padding="max_length", truncation=True, max_length=24).to(device)
lbdin_outputs=model.esm(**lbdin_inputs,output_attentions=True,output_hidden_states=True)
lbdin_attention=torch.stack(lbdin_outputs.attentions,1)
lbdin_outputs=mymodel(lbdin_inputs['input_ids'],lbdin_attention).cpu()

In [None]:
lbdsy_dict=get_fasta_dict('../database/LBD_test.fasta')
lbdsy_seqs=[]
for header,seq in lbdsy_dict.items():
    lbdsy_seqs.append(seq)

lbdsy_inputs = tokenizer(lbdsy_seqs, return_tensors='pt', padding="max_length", truncation=True, max_length=24).to(device)
lbdsy_outputs=model.esm(**lbdsy_inputs,output_attentions=True,output_hidden_states=True)
lbdsy_attention=torch.stack(lbdsy_outputs.attentions,1)
lbdsy_outputs=mymodel(lbdsy_inputs['input_ids'],lbdsy_attention).cpu()

In [None]:
lbdb=lbdsy_outputs[1]

In [None]:
mutation_dict=get_fasta_dict('../outputs/LBD_mutation.fasta')
seqs=[]
for header,seq in mutation_dict.items():
    seqs.append(seq)
inputs = tokenizer(seqs, return_tensors='pt', padding="max_length", truncation=True, max_length=24).to(device)
outputs=model.esm(**inputs,output_attentions=True,output_hidden_states=True)
mutation_attention=torch.stack(outputs.attentions,1)
mutation_outputs=mymodel(inputs['input_ids'],mutation_attention).cpu()
mutation_contact=get_pred_contact(mutation_outputs)

#### Plan A

In [None]:
def filter_pred_contact(test_output):
    index_list=[]
    for index,attention in enumerate(test_output):
        pred_contact=get_pred_contact(attention)
        if pred_contact[16][18] ==1 & pred_contact[13][16] ==1:
            index_list.append(index)
    return index_list

In [None]:
mutation_index=filter_pred_contact(mutation_outputs)
pp_filter=mutation_outputs[mutation_index]
loss_list=get_bceloss(pp_filter)

In [None]:
bceloss=torch.tensor(loss_list)
values,indices=torch.topk(bceloss,k=5,largest=False)
for i in indices:
    seq=mutation_dict[list(mutation_dict.keys())[mutation_index[i]]]
    print(seq)

#### Plan B

In [None]:
class ContactMapRegression(nn.Module):

    def __init__(
        self,
        in_features:int,
        bias=True,
    ):
        super().__init__()
        self.in_features=in_features
        self.regression = nn.Linear(in_features, 2, bias)
        self.activation = nn.Sigmoid()
    
    def forward(self,contact_map):
        contact_map=contact_map.reshape((contact_map.shape[0],contact_map.shape[1]*contact_map.shape[1]))
        outputs=self.activation(self.regression(contact_map))
        return outputs

In [None]:
contactmodel=torch.load('../models/contactmap_filter_planb.pt')
contactmodel.eval()

In [None]:
pe_outputs=torch.cat([lbdin_outputs[:8],lbdsy_outputs[:6]],dim=0)
pe_contact=get_pred_contact(pe_outputs)

In [None]:
def test_bceloss(seq):
    lbd2=get_pred_contact(seq)
    
    loss_list=[]
    for lbd in lbd2:
        loss=0
        for i in pe_contact:
            loss+=criterion(torch.tensor(i,dtype=torch.float32),torch.tensor(lbd,dtype=torch.float32)).item()
        average_loss = loss / len(pe_contact)
        loss_list.append(average_loss)
    
    return loss_list

In [None]:
mutation_pred_contacts=get_pred_contact(mutation_outputs)
mutation_pred_contacts=torch.tensor(mutation_pred_contacts,dtype=torch.float32).to(device)
contact_outputs = contactmodel(mutation_pred_contacts)
_,contact_prediction=contact_outputs.max(dim=1)
index=np.where(contact_prediction.cpu()==1)
pp_filter=mutation_outputs[index[0]]
loss_list=test_bceloss(pp_filter)

In [None]:
bceloss=torch.tensor(loss_list)
values,indices=torch.topk(bceloss,k=5,largest=False)
for i in indices:
    seq=mutation_dict[list(mutation_dict.keys())[index[0][i]]]
    print(seq)