In [2]:
#import esm
import torch
import numpy as np
import pandas as pd
import loralib as lora
import esmLoRA.esm as el
from esmLoRA.esm.model.esm2 import ESM2
import pytorch_lightning as pl
import math
from Bio import SeqIO
from torch.utils.data import random_split
from torch.utils.data import DataLoader
from torch.nn import functional as F
from tensordict import TensorDict

import random



In [3]:
# Load ESM model

#model used on this exemple available on git, most recent version 
pretrained_model, alphabet = el.pretrained.esm2_t6_8M_UR50D()

# Load batch converter
batch_converter = alphabet.get_batch_converter()

# Disable dropout for deterministic results
pretrained_model.eval()

# Save weights from pretrained model
# PATH = 'esm2_t6_8M_UR50D_pretrained_checkpoint.pt'
# torch.save(pretrained_model.state_dict(), PATH)

# Load previously saved weights, should all match
#pretrained_model.load_state_dict(torch.load(PATH), strict=False)

# Load model using ESM2 framework (edited version of ESM2)
model = ESM2(num_layers=6, embed_dim=320, attention_heads=20, token_dropout=False)

# Load weights from pretrained model
# Set path (same as saved above)
PATH = 'esm2_t6_8M_UR50D_pretrained_checkpoint.pt'

# Load weights to ESM2 model (Should error with missing weights from LoRA)
model.load_state_dict(torch.load(PATH), strict=False)

#Generate mapping for tokens to amino acids

# Pull mapping from alphabet
token_mapping = alphabet.to_dict() # {amino_acid: token_id}

# Reverse the mapping
token_mapping = {v: k for k, v in token_mapping.items()} # {token_id: amino_acid}


In [4]:
# # Test alphabet encode, print tokens 
# input_tokens = torch.Tensor([alphabet.encode("<cls>MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG<eos><pad><pad><pad><pad>")]).to(torch.int64)
# input_tokens


In [28]:
# Test encoding with masked values

#Set input tokens that include masked values
input_tokens = torch.Tensor([alphabet.encode("<cls>M<mask>TVRQERLKSIVRILER<mask>SKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRG<mask>YVLA<mask>G<eos><pad><pad><pad><pad>")], )
#input_tokens.to(torch.int64)

# Pull logits from model based on input tokens above
var = model.forward(input_tokens.int(), return_contacts = False)['logits']
print(var[0][0])
print(len(var[0][0]))

# Make prediction using argmax
prediction = var[0,:].argmax(dim=1)

# Reconstruct the sequence using the token mapping
reconstructed = ''
for token in prediction:
    reconstructed += token_mapping[token.item()]
reconstructed


tensor([ 1.5450e+01, -8.9811e+00, -6.8793e+00, -8.9750e+00, -1.4604e-01,
         5.2747e-01,  4.9740e-02,  3.4032e-01,  1.5516e-01, -1.9002e-01,
         3.1542e-01,  3.6667e-01, -3.8365e-01, -8.8549e-01, -2.9278e+00,
        -4.5261e-01, -6.4824e-03, -1.6359e+00, -1.3195e+00, -2.1832e-01,
         1.9455e+00, -1.8578e+00, -2.6934e+00, -3.6608e-01, -4.1013e-01,
        -1.1099e+01, -1.1333e+01, -1.1781e+01, -1.4696e+01, -1.5194e+01,
        -1.5063e+01, -1.5312e+01, -8.9679e+00], grad_fn=<SelectBackward0>)
33


'<cls>MSTVRQERLKSIVRILERLSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGGYVLAEG<eos>ARRG'

In [305]:
#weight_names = ['layers.0.fc1.lora_A', 'layers.0.fc1.lora_B', 'layers.0.fc2.lora_A', 'layers.0.fc2.lora_B', 'layers.1.fc1.lora_A', 'layers.1.fc1.lora_B', 'layers.1.fc2.lora_A', 'layers.1.fc2.lora_B', 'layers.2.fc1.lora_A', 'layers.2.fc1.lora_B', 'layers.2.fc2.lora_A', 'layers.2.fc2.lora_B', 'layers.3.fc1.lora_A', 'layers.3.fc1.lora_B', 'layers.3.fc2.lora_A', 'layers.3.fc2.lora_B', 'layers.4.fc1.lora_A', 'layers.4.fc1.lora_B', 'layers.4.fc2.lora_A', 'layers.4.fc2.lora_B', 'layers.5.fc1.lora_A', 'layers.5.fc1.lora_B', 'layers.5.fc2.lora_A', 'layers.5.fc2.lora_B']
weight_names = ['model.layers[0].fc1.lora_A', 'model.layers[0].fc1.lora_B', 'model.layers[0].fc2.lora_A', 'model.layers[0].fc2.lora_B', 'model.layers[1].fc1.lora_A', 'model.layers[1].fc1.lora_B', 'model.layers[1].fc2.lora_A', 'model.layers[1].fc2.lora_B', 'model.layers[2].fc1.lora_A', 'model.layers[2].fc1.lora_B', 'model.layers[2].fc2.lora_A', 'model.layers[2].fc2.lora_B', 'model.layers[3].fc1.lora_A', 'model.layers[3].fc1.lora_B', 'model.layers[3].fc2.lora_A', 'model.layers[3].fc2.lora_B', 'model.layers[4].fc1.lora_A', 'model.layers[4].fc1.lora_B', 'model.layers[4].fc2.lora_A', 'model.layers[4].fc2.lora_B', 'model.layers[5].fc1.lora_A', 'model.layers[5].fc1.lora_B', 'model.layers[5].fc2.lora_A', 'model.layers[5].fc2.lora_B']



#for item in weight_names:
#    lora.nn.init.normal_(eval(item), std=0.02)


lora_before_training = [model.layers[0].fc1.lora_A, model.layers[0].fc1.lora_B, model.layers[0].fc2.lora_A, model.layers[0].fc2.lora_B, 
                        model.layers[1].fc1.lora_A, model.layers[1].fc1.lora_B, model.layers[1].fc2.lora_A, model.layers[1].fc2.lora_B, 
                        model.layers[2].fc1.lora_A, model.layers[2].fc1.lora_B, model.layers[2].fc2.lora_A, model.layers[2].fc2.lora_B,
                        model.layers[3].fc1.lora_A, model.layers[3].fc1.lora_B, model.layers[3].fc2.lora_A, model.layers[3].fc2.lora_B,
                        model.layers[4].fc1.lora_A, model.layers[4].fc1.lora_B, model.layers[4].fc2.lora_A, model.layers[4].fc2.lora_B,
                        model.layers[5].fc1.lora_A, model.layers[5].fc1.lora_B, model.layers[5].fc2.lora_A, model.layers[5].fc2.lora_B]

print(lora_before_training)


[Parameter containing:
tensor([[ 0.0085, -0.0014, -0.0319,  ...,  0.0292, -0.0222,  0.0454],
        [ 0.0132, -0.0434, -0.0193,  ..., -0.0275, -0.0431,  0.0065],
        [-0.0115,  0.0435, -0.0405,  ..., -0.0415, -0.0135, -0.0348],
        ...,
        [ 0.0573, -0.0326,  0.0628,  ...,  0.0367,  0.0041,  0.0364],
        [ 0.0218, -0.0299,  0.0414,  ..., -0.0203,  0.0226,  0.0147],
        [-0.0314,  0.0071,  0.0014,  ..., -0.0147, -0.0373,  0.0111]],
       requires_grad=True), Parameter containing:
tensor([[-0.0117, -0.0116, -0.0118,  ...,  0.0107,  0.0113, -0.0112],
        [-0.0119, -0.0118, -0.0112,  ..., -0.0114, -0.0110, -0.0069],
        [ 0.0110, -0.0107,  0.0092,  ...,  0.0119,  0.0120, -0.0113],
        ...,
        [ 0.0119,  0.0119,  0.0119,  ...,  0.0116,  0.0051, -0.0119],
        [ 0.0118,  0.0118,  0.0119,  ...,  0.0119,  0.0119,  0.0110],
        [-0.0116, -0.0056, -0.0115,  ..., -0.0118, -0.0118, -0.0022]],
       requires_grad=True), Parameter containing:
tensor([[

In [7]:
# Import data

# Read fasta file to csv with SeqIO
fasta_file = 'data/uniprotkb_reviewed_true_AND_annotation_2023_07_12.fasta'

with open(fasta_file) as fasta_file:  # Will close handle cleanly
    id = []
    sequence = []
    for seq_record in SeqIO.parse(fasta_file, 'fasta'):  # (generator)
        id.append(seq_record.id)
        sequence.append(str(seq_record.seq))

tuples = list(zip(id, sequence))

# Remove length > 1022
tuples = [t for t in tuples if len(t[1]) <= 1022]

# create dataframe
# df = pd.DataFrame(tuples, columns=['id', 'sequence'])
# df = df[df['sequence'].str.len() <= 1022]

data = tuples
batch_labels, batch_strs, batch_tokens = batch_converter(data)
batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)
len(batch_tokens)

12037

In [302]:
# forward + backward + optimize
input = batch_tokens[0:1000]
y = input.clone()
mask_index_list = []

# Mask 15% of the sequence
for i in range(len(y)):
    # Remove all padding, cls, eos tokens
    mask_arr = (y[i] != model.padding_idx) * (y[i] != model.cls_idx) * (y[i] != model.eos_idx)
    short_seq = y[i][mask_arr]
    masking_index = random.sample(range(len(short_seq)), int(0.15*len(short_seq)))
    # Add one value for each index to account for skipped BOS token in original sequence
    masking_index = [i + 1 for i in masking_index]
    # Randomly input old character, mask, or random character
    mask = np.random.choice(['mask', 'old', 'random'], len(masking_index), p=[0.8, 0.1, 0.1])

    # Lists masking_index and mask to dictionary
    mask_and_index = dict(zip(masking_index, mask))
    
    # Loop through dictionary and replace values in y
    for masking_index, mask in mask_and_index.items():
        if mask == 'old':
            continue
        if mask == 'random':
            # generate random token
            y[i][masking_index] = np.random.choice(range(4,29))
        if mask == 'mask':
            y[i][masking_index] = model.mask_idx

    mask_index_list.append(list(mask_and_index.keys()))

# Pad and set mask_index_list to tensor
mask_index = []
for i in range(len(mask_index_list)):
    mask_index.append(torch.tensor(mask_index_list[i]))
mask_index = torch.nn.utils.rnn.pad_sequence(mask_index, batch_first=True, padding_value=0)

# Set x to be one-hot encoded amino acid at each position
target = torch.zeros(len(input), len(input[0]), len(alphabet.all_toks))

for i in range(len(input)):
    for j in range(len(input[i])):
        # print("i: " + str(i) + " j: " + str(j) + " input[i][j]: " + str(int(input[i][j])))
        target[i][j][int(input[i][j])] = 1

# Edit print option for testing
#torch.set_printoptions(profile="full")
torch.set_printoptions(profile="default")

print(len(target))
print(len(y))
print(len(mask_index))

data_dict = {'target': target, 'data': y, 'mask_index': mask_index}

data_inputs = []
#Generate a list of dictionaries for each sequence
for i in range(len(data_dict['target'])):
    data_inputs.append({'target': data_dict['target'][i], 'data': data_dict['data'][i], 'mask_index': data_dict['mask_index'][i]})


1000
1000
1000


In [303]:
### TRAINING! ###

# Set up model for Training LoRA Weights
lora.mark_only_lora_as_trainable(model)

# Select device and move model to device
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)

# Set criterion and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Setup Dataloader
dataset = data_inputs
train, val = random_split(dataset, [math.ceil(0.9*len(dataset)), math.floor(0.1*len(dataset))])

train_loader = DataLoader(train, batch_size=32)
val_loader = DataLoader(val, batch_size=32)


In [352]:
# Training Loop
for epoch in range(2):
    batch_number = 0
    running_loss = 0.0
    print("Working on epoch number: ", epoch+1)
    for batch in train_loader:
        batch_number+=1
        print("Working on batch number: ", batch_number)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        y_hat = batch['target']
        y = batch['data']
        mask_index = batch['mask_index']
        
        # Generate outputs
        y = model(y)['logits']
        
        # Pad location of unmasked residues in y and y_hat
        for i in range(len(mask_index)):
            y_hat[i][mask_index[i][mask_index[i] != 0]] = 1
            y[i][mask_index[i][mask_index[i] != 0]] = 1

        # Calculate loss and backpropagate
        # for i in mask_index:
        #     print(i)
        # print(torch.nn.CrossEntropyLoss(reduction='none')(y, y_hat))

        loss = criterion(y, y_hat)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        print(f'[{epoch + 1}, {batch_number}] batch loss: {loss.item():.3f} running loss: {running_loss:.3f}')
    running_loss = 0.0

print('Finished Training')

Working on epoch number:  1
Working on batch number:  1
[1, 1] batch loss: 785.916 running loss: 785.916
Working on batch number:  2
[1, 2] batch loss: 779.286 running loss: 1565.203
Working on batch number:  3
[1, 3] batch loss: 739.778 running loss: 2304.981
Working on batch number:  4
[1, 4] batch loss: 741.516 running loss: 3046.497
Working on batch number:  5
[1, 5] batch loss: 735.762 running loss: 3782.259
Working on batch number:  6
[1, 6] batch loss: 757.762 running loss: 4540.021
Working on batch number:  7
[1, 7] batch loss: 718.927 running loss: 5258.949
Working on batch number:  8
[1, 8] batch loss: 707.476 running loss: 5966.425
Working on batch number:  9


KeyboardInterrupt: 

In [350]:
len(y_hat[2])

1024

In [340]:
len(y[0][mask_index[0][mask_index[0] != 0]])

y[0][mask_index[0][mask_index[0] != 0]]
y_hat[0][mask_index[0][mask_index[0] != 0]]

#mask_index
thing = mask_index[0][mask_index[0] != 0]
thing.tolist()
#len(mask_index[0])

[41,
 13,
 103,
 122,
 108,
 7,
 155,
 44,
 158,
 77,
 30,
 93,
 82,
 99,
 102,
 124,
 146,
 62,
 85,
 3,
 97,
 164,
 141,
 78,
 142,
 119]

In [228]:
lora_after_training = [model.layers[0].fc1.lora_A, model.layers[0].fc1.lora_B, model.layers[0].fc2.lora_A, model.layers[0].fc2.lora_B, 
                        model.layers[1].fc1.lora_A, model.layers[1].fc1.lora_B, model.layers[1].fc2.lora_A, model.layers[1].fc2.lora_B, 
                        model.layers[2].fc1.lora_A, model.layers[2].fc1.lora_B, model.layers[2].fc2.lora_A, model.layers[2].fc2.lora_B,
                        model.layers[3].fc1.lora_A, model.layers[3].fc1.lora_B, model.layers[3].fc2.lora_A, model.layers[3].fc2.lora_B,
                        model.layers[4].fc1.lora_A, model.layers[4].fc1.lora_B, model.layers[4].fc2.lora_A, model.layers[4].fc2.lora_B,
                        model.layers[5].fc1.lora_A, model.layers[5].fc1.lora_B, model.layers[5].fc2.lora_A, model.layers[5].fc2.lora_B]

print(lora_after_training)


[Parameter containing:
 tensor([[-0.0089,  0.0239, -0.0300,  ...,  0.0442,  0.0503,  0.0449],
         [ 0.0088, -0.0534,  0.0089,  ...,  0.0220,  0.0446, -0.0437],
         [-0.0215,  0.0343,  0.0420,  ..., -0.0398,  0.0237, -0.0056],
         ...,
         [-0.0499, -0.0167, -0.0011,  ..., -0.0099,  0.0430, -0.0011],
         [-0.0497, -0.0211, -0.0406,  ..., -0.0469, -0.0152, -0.0057],
         [-0.0548,  0.0412, -0.0019,  ..., -0.0253, -0.0369, -0.0421]],
        requires_grad=True),
 Parameter containing:
 tensor([[ 0.0019, -0.0019,  0.0019,  ...,  0.0019, -0.0019,  0.0018],
         [-0.0020,  0.0020, -0.0020,  ..., -0.0019,  0.0020, -0.0020],
         [-0.0019, -0.0019, -0.0020,  ...,  0.0020, -0.0019, -0.0020],
         ...,
         [-0.0020, -0.0019, -0.0019,  ...,  0.0020,  0.0015, -0.0020],
         [ 0.0020,  0.0020,  0.0020,  ..., -0.0020,  0.0020,  0.0019],
         [-0.0020,  0.0020, -0.0020,  ...,  0.0017,  0.0020, -0.0020]],
        requires_grad=True)]