In [7]:
import torch
import esm
from torch import nn, optim
import random

In [23]:
# Load the pre-trained ESM-2 model and its alphabet
# model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()

batch_converter = alphabet.get_batch_converter()

# Set the model to evaluation mode initially
model.eval()

# Example data: list of tuples (label, sequence)
data = [
    ("protein1", "MKTAYIAKQRQISFVKSHFSRQDILDLWIYHTQGYFPQYKGSGRTQY"),
    ("protein2", "GIEVVVNATLDKAGFQAGYIGFLKTFTLGVAGSGLLGGTYTQAGG"),
    # Add more sequences here
]

# Convert the data to batch format
batch_labels, batch_strs, batch_tokens = batch_converter(data)

# Define a loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-5)

# Masking function
def mask_tokens(tokens, mask_idx, pad_idx, mask_prob=0.15):
    labels = tokens.clone()
    masked_tokens = tokens.clone()

    # Create a mask based on the probability
    mask = (torch.rand(tokens.shape) < mask_prob) & (tokens != pad_idx)

    # Replace masked positions with the mask index
    masked_tokens[mask] = mask_idx

    return masked_tokens, labels

# Enable training mode
model.train()

# Fine-tuning loop
num_epochs = 30
mask_idx = alphabet.mask_idx
pad_idx = alphabet.padding_idx

for epoch in range(num_epochs):
    optimizer.zero_grad()

    # Mask tokens
    masked_tokens, labels = mask_tokens(batch_tokens, mask_idx, pad_idx)

    # Forward pass: get the output from the model
    output = model(masked_tokens)  # , repr_layers=[num_layers]
    logits = output["logits"]

    # Compute loss for masked language modeling
    loss = criterion(logits.view(-1, logits.size(-1)), labels.view(-1))
    loss.backward()
    optimizer.step()

    print(f"Epoch: {epoch}, Loss: {loss.item()}")

# Save the fine-tuned model
torch.save(model.state_dict(), "fine_tuned_esm2_masked_model.pth")

# Switch back to evaluation mode after fine-tuning
model.eval()

Epoch: 0, Loss: 1.2193660736083984
Epoch: 1, Loss: 1.3682093620300293
Epoch: 2, Loss: 1.2818315029144287
Epoch: 3, Loss: 1.0532981157302856
Epoch: 4, Loss: 1.2930711507797241
Epoch: 5, Loss: 1.1215124130249023
Epoch: 6, Loss: 1.273398518562317
Epoch: 7, Loss: 0.9323391318321228
Epoch: 8, Loss: 1.2231507301330566
Epoch: 9, Loss: 1.173543930053711
Epoch: 10, Loss: 1.046103596687317
Epoch: 11, Loss: 0.8925619125366211
Epoch: 12, Loss: 1.120326280593872
Epoch: 13, Loss: 1.0627635717391968
Epoch: 14, Loss: 0.972506046295166
Epoch: 15, Loss: 1.1718634366989136
Epoch: 16, Loss: 1.0267871618270874
Epoch: 17, Loss: 1.0471450090408325
Epoch: 18, Loss: 0.978630006313324
Epoch: 19, Loss: 0.8714840412139893
Epoch: 20, Loss: 0.9836001396179199
Epoch: 21, Loss: 0.8305772542953491
Epoch: 22, Loss: 1.0646620988845825
Epoch: 23, Loss: 1.001206636428833
Epoch: 24, Loss: 0.8528508543968201
Epoch: 25, Loss: 1.0158283710479736
Epoch: 26, Loss: 1.1392499208450317
Epoch: 27, Loss: 0.9566947221755981
Epoch: 28

ESM2(
  (embed_tokens): Embedding(33, 320, padding_idx=1)
  (layers): ModuleList(
    (0-5): 6 x TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=320, out_features=320, bias=True)
        (v_proj): Linear(in_features=320, out_features=320, bias=True)
        (q_proj): Linear(in_features=320, out_features=320, bias=True)
        (out_proj): Linear(in_features=320, out_features=320, bias=True)
        (rot_emb): RotaryEmbedding()
      )
      (self_attn_layer_norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=320, out_features=1280, bias=True)
      (fc2): Linear(in_features=1280, out_features=320, bias=True)
      (final_layer_norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
    )
  )
  (contact_head): ContactPredictionHead(
    (regression): Linear(in_features=120, out_features=1, bias=True)
    (activation): Sigmoid()
  )
  (emb_layer_norm_after): LayerNorm((320,), eps=1e-05, elementwis

In [16]:
logits

tensor([[[ 14.6256,  -8.1990,  -5.8007,  ..., -15.5838, -15.7364,  -8.1972],
         [ -8.7877, -15.7859,  -9.4309,  ..., -15.8601, -16.0910, -15.7822],
         [-10.9809, -19.3192, -10.6574,  ..., -16.1523, -16.1474, -19.3129],
         ...,
         [-10.9388, -18.9534, -11.8980,  ..., -16.2301, -16.2017, -18.9448],
         [-11.7391, -20.4840, -11.2705,  ..., -16.2578, -16.3088, -20.4812],
         [ -6.1847,  -9.5217,  14.0999,  ..., -16.5717, -16.5293,  -9.5495]],

        [[ 15.3362,  -7.9659,  -5.2758,  ..., -15.5773, -15.7481,  -7.9697],
         [ -7.2720, -13.8548,  -6.5442,  ..., -15.7878, -15.9642, -13.8574],
         [-10.3269, -16.9376, -10.1985,  ..., -16.0456, -15.8895, -16.9370],
         ...,
         [ -5.3700,  -7.8912,  16.4789,  ..., -16.5534, -16.4161,  -7.9517],
         [ -7.6942, -12.6855,  -3.1495,  ..., -15.9508, -15.9348, -12.7032],
         [ -7.4733, -13.0082,  -3.5302,  ..., -15.9532, -15.9346, -13.0273]]],
       grad_fn=<AddBackward0>)

In [17]:
logits.size()

torch.Size([2, 49, 33])

In [18]:
labels.size()

torch.Size([2, 49])

In [20]:
masked_tokens

tensor([[ 0, 20, 32, 11,  5, 19, 32, 32, 15, 16, 10, 16, 12,  8, 18,  7, 15,  8,
         21, 32, 32, 10, 16, 13, 12,  4, 13,  4, 22, 12, 19, 21, 11, 16,  6, 19,
         32, 14, 16, 19, 15, 32,  8, 32, 10, 11, 32, 19,  2],
        [32,  6, 12,  9,  7,  7,  7, 32,  5, 32,  4, 13, 15,  5,  6, 18, 16,  5,
          6, 32, 12,  6, 32,  4, 15, 11, 18, 11,  4,  6,  7,  5,  6, 32,  6,  4,
          4,  6,  6, 11, 19, 11, 16,  5,  6,  6,  2,  1,  1]])

In [21]:
mask_idx

32

In [22]:
pad_idx

1