In [24]:
import torch
import esm
from torch import nn, optim
import random
import pandas as pd
from tqdm import tqdm

In [7]:
!head ../data/mariana_to_100.tsv

NLEIPMFD_00006 93 MCIFLRRAVLGCSCNYVLFMEYRTSDIQIAALLYAEGIELVNVDSSNPRRKEFVFKDEQQISELVKGFWQDTHKIAPRKYMGAFRELKNKLYS 0
NLEIPMFD_00010 86 MKYFVYLARCADGSLYTGSCTNIKAREIRHNKGEGAAFTYKRRPVKIVYFEEFKTLIEAMRREKQIKRWTRKKKENLVKYGHPTKF 0
NLEIPMFD_00012 77 MVFNVKRSTYVILTSFQMGKGDLKEKVITKFAYSTPSVTLALPTLFQSKALSAEQMRKLHATIFSVIDPPTQEAVES 0
NLEIPMFD_00013 71 MLMKGVQHKNPITPFGKIIMHCPDPTVLAISQAHLVPFLIQIFGQQQHLLFSPADVGSVRVKEDSHGSLND 0
NLEIPMFD_00015 42 MLPVSCVLEDWNYSPNYLGGVSIKNAGVPGNMLQGQSLNPNS 1
NLEIPMFD_00022 99 MEKLVNIVKRCLNHHKLGESAKASHVLFTAQQFLDKWFVGEKMMAKPVQLKNAVLWIGVRHPTIAQEFRGVSDKLLKELQTRFGPKLVQKIRTKHLTSI 0
NLEIPMFD_00025 96 MRLLLESIGFKVLEASNARHALTLINTEKPDITLTDHMMPGELTGEQLARHLHERGLKVVLTSGYPIEEESCFQFIAKPPRIGVLTAVLKKELGIE 0
NLEIPMFD_00029 95 MLDLTLRSKFLSPTTMIKEAVILKYGVVFTGKRHNVIFNSAQAMGLGFAGLRGGEQGFVTESGEFVNRRKAFEIALACGQIEEREKRKLFSEDLY 0
NLEIPMFD_00030 46 MSISSPAIRISGSGQIVQVLPMMKSRVITKFMGFSYLSYRTVKLVL 1
NLEIPMFD_00034 39 MEIPFIKVHPARLRAEEVLKDVIVQLKEMIEAGRIRNDL 1


In [31]:
# Example data: list of tuples (label, sequence)
data = [
    ("protein1", "MKTAYIAKQRQISFVKSHFSRQDILDLWIYHTQGYFPQYKGSGRTQY"),
    ("protein2", "GIEVVVNATLDKAGFQAGYIGFLKTFTLGVAGSGLLGGTY"),
    # Add more sequences here
]

data = list()
df_paths = ['../data/AMP_new/AMP_2024_08_09.tsv', '../data/mariana_to_100.tsv']
for df_path,sep in zip(df_paths, ['\t', ' ']):
    df = pd.read_csv(df_path, header=None, sep=sep)
    for ind, row in df.iterrows():
        data.append( (row[0], row[2]) )
    # break

KeyboardInterrupt: 

In [32]:
len(data)

1107014

In [21]:
len(data)

55209

In [None]:
import torch
import esm
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset
import random
import pandas as pd


# Check if a GPU is available, otherwise use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')
# Load the pre-trained ESM-2 model and its alphabet
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
# model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
model = model.to(device)
batch_converter = alphabet.get_batch_converter()

# Convert the data to batch format
batch_labels, batch_strs, batch_tokens = batch_converter(data)

# Prepare a dataset and dataloader for batching
dataset = TensorDataset(batch_tokens)
batch_size = 4
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Define a loss function and optimizer
criterion = nn.CrossEntropyLoss(ignore_index=alphabet.padding_idx)
optimizer = optim.Adam(model.parameters(), lr=1e-5)

# Masking function
def mask_tokens(tokens, mask_idx, pad_idx, mask_prob=0.15):
    labels = tokens.clone()
    masked_tokens = tokens.clone()

    # Create a mask based on the probability
    mask = (torch.rand(tokens.shape) < mask_prob) & (tokens != pad_idx)

    # Replace masked positions with the mask index
    masked_tokens[mask] = mask_idx

    return masked_tokens, labels

# Enable training mode
model.train()

# Fine-tuning loop
num_epochs = 5
mask_idx = alphabet.mask_idx# torch.tensor(alphabet.mask_idx).to(device)
pad_idx = alphabet.padding_idx #torch.tensor(alphabet.padding_idx).to(device)

for epoch in range(num_epochs):
    for batch in tqdm(dataloader):
        batch_tokens = batch[0]
        optimizer.zero_grad()

        # Mask tokens
        masked_tokens, labels = mask_tokens(batch_tokens, mask_idx, pad_idx)
        masked_tokens = masked_tokens.to(device)
        labels = labels.to(device)  # Move labels to GPU

        # Forward pass: get the output from the model
        # with torch.no_grad():
        output = model(masked_tokens, repr_layers=[33])
        logits = output["logits"]
        # Take the argmax of the logits to get the predicted amino acids
        # predictions = torch.argmax(logits, dim=-1)
        
        # print(logits.size(), labels.size())
        # print(logits.view(-1, logits.size(-1)).size(), labels.view(-1).size())
        # Compute loss for masked language modeling
        # argmax on 33 size vector (size of vocabulary) is performed inside CrossEntropyLoss function
        loss = criterion(logits.view(-1, logits.size(-1)), labels.view(-1))  
        loss.backward()
        optimizer.step()
        
    print(f"Epoch: {epoch}, Loss: {loss.item()}")

# Save the fine-tuned model
torch.save(model.state_dict(), "fine_tuned_esm2_masked_model.pth")

# Switch back to evaluation mode after fine-tuning
model.eval()

In [4]:
batch_tokens

tensor([[ 0, 20, 15, 11,  5, 19, 12,  5, 15, 16, 10, 16, 12,  8, 18,  7, 15,  8,
         21, 18,  8, 10, 16, 13, 12,  4, 13,  4, 22, 12, 19, 21, 11, 16,  6, 19,
         18, 14, 16, 19, 15,  6,  8,  6, 10, 11, 16, 19,  2],
        [ 0,  6, 12,  9,  7,  7,  7, 17,  5, 11,  4, 13, 15,  5,  6, 18, 16,  5,
          6, 19, 12,  6, 18,  4, 15, 11, 18, 11,  4,  6,  7,  5,  6,  8,  6,  4,
          4,  6,  6, 11, 19,  2,  1,  1,  1,  1,  1,  1,  1]])

In [5]:
masked_tokens

tensor([[ 0, 20, 32, 11, 32, 19, 12,  5, 15, 16, 10, 16, 12,  8, 18, 32, 15,  8,
         21, 18,  8, 10, 16, 13, 12,  4, 13,  4, 22, 12, 19, 21, 11, 16,  6, 19,
         18, 32, 16, 32, 15,  6,  8,  6, 10, 32, 32, 19,  2],
        [ 0,  6, 12,  9, 32,  7,  7, 17, 32, 11,  4, 13, 15,  5,  6, 18, 16, 32,
         32, 19, 32,  6, 18,  4, 15, 11, 18, 32,  4, 32,  7,  5,  6,  8,  6,  4,
          4,  6,  6, 11, 19,  2,  1,  1,  1,  1,  1,  1,  1]], device='cuda:0')

In [6]:
labels

tensor([[ 0, 20, 15, 11,  5, 19, 12,  5, 15, 16, 10, 16, 12,  8, 18,  7, 15,  8,
         21, 18,  8, 10, 16, 13, 12,  4, 13,  4, 22, 12, 19, 21, 11, 16,  6, 19,
         18, 14, 16, 19, 15,  6,  8,  6, 10, 11, 16, 19,  2],
        [ 0,  6, 12,  9,  7,  7,  7, 17,  5, 11,  4, 13, 15,  5,  6, 18, 16,  5,
          6, 19, 12,  6, 18,  4, 15, 11, 18, 11,  4,  6,  7,  5,  6,  8,  6,  4,
          4,  6,  6, 11, 19,  2,  1,  1,  1,  1,  1,  1,  1]], device='cuda:0')

In [7]:
output['logits']

tensor([[[ 14.9760,  -8.2150,  -5.9843,  ..., -15.5670, -15.7102,  -8.2143],
         [ -8.6602, -15.5507,  -9.4913,  ..., -15.7807, -16.0151, -15.5496],
         [-10.9838, -19.3604, -10.5763,  ..., -16.1565, -16.1443, -19.3552],
         ...,
         [-11.1903, -19.1748, -11.2916,  ..., -16.1480, -16.1324, -19.1696],
         [-11.7353, -20.3083, -11.2006,  ..., -16.2858, -16.3197, -20.3066],
         [ -6.3241,  -9.8155,  13.7309,  ..., -16.5888, -16.5480,  -9.8441]],

        [[ 15.6065,  -8.8213,  -5.9375,  ..., -15.5666, -15.7425,  -8.8252],
         [ -7.9181, -16.1651,  -7.7032,  ..., -15.8610, -16.0526, -16.1708],
         [-10.8046, -18.6040, -11.3581,  ..., -16.1035, -15.9465, -18.5995],
         ...,
         [-10.3284, -17.2729,  -9.5998,  ..., -15.9651, -15.9785, -17.2731],
         [ -9.8864, -16.6112,  -8.8543,  ..., -15.9494, -15.9664, -16.6132],
         [ -8.4657, -14.5306,  -5.9612,  ..., -15.8741, -15.8850, -14.5394]]],
       device='cuda:0', grad_fn=<AddBackward

In [24]:
logits.view(-1, logits.size(-1))

tensor([[ 14.6652,  -8.3590,  -5.4929,  ..., -15.6166, -15.7767,  -8.3568],
        [ -7.6443, -14.8921,  -7.3066,  ..., -15.8550, -16.0314, -14.8942],
        [-10.9162, -18.5463, -11.2390,  ..., -16.1038, -15.9538, -18.5438],
        ...,
        [-10.7189, -18.5088, -11.2823,  ..., -16.2630, -16.2313, -18.4974],
        [-11.6306, -19.6774, -10.7238,  ..., -16.2581, -16.2839, -19.6738],
        [ -6.5811,  -9.0778,  14.4713,  ..., -16.5973, -16.4729,  -9.1092]],
       device='cuda:0', grad_fn=<ViewBackward0>)

In [25]:
logits.view(-1, logits.size(-1)).size()

torch.Size([98, 33])

In [28]:
masked_tokens

tensor([[ 0, 20, 15, 11,  5, 19, 12,  5, 15, 16, 10, 32, 12,  8, 18,  7, 15, 32,
         21, 18,  8, 10, 32, 13, 12,  4, 13, 32, 22, 32, 19, 21, 11, 16,  6, 19,
         18, 14, 16, 19, 15,  6, 32,  6, 10, 11, 32, 19,  2],
        [32,  6, 12, 32,  7,  7, 32, 17,  5, 11,  4, 13, 15, 32, 32, 32, 16,  5,
          6, 19, 12, 32, 18,  4, 15, 11, 18, 32,  4, 32,  7,  5, 32,  8,  6,  4,
          4,  6,  6, 11, 19, 11, 16,  5,  6,  6,  2,  1,  1]])

In [11]:
# Take the argmax of the logits to get the predicted amino acids
predictions = torch.argmax(logits, dim=-1).float()
predictions

tensor([[ 0., 20., 12.,  9.,  7.,  7.,  7.,  6.,  6., 11.,  4., 13., 15.,  5.,
          6., 18., 16.,  5.,  6., 19., 12.,  6.,  6.,  4., 15., 11., 18., 11.,
          6.,  6.,  7.,  5.,  6.,  8.,  6.,  4.,  4.,  6.,  6., 11., 19.,  2.,
          2.,  2.,  2.,  2.,  2.,  2.,  2.],
        [ 0., 20., 15., 11.,  5., 19., 12.,  5., 15., 16., 10.,  4., 12.,  8.,
         18.,  4.,  4.,  8., 21., 18.,  8., 10.,  4., 13.,  4.,  4., 13.,  4.,
         22., 12., 19., 21., 11., 16., 15.,  4., 18.,  4., 16., 19., 15.,  6.,
          8.,  6., 10., 11., 16., 19.,  2.]], device='cuda:0')

In [30]:
pad_idx

1