In [25]:
from pathlib import Path

import pandas as pd
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModel
from transformers import AutoTokenizer

In [26]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's global RNG (CPU and CUDA). When the cells run in order, this makes the
# regression-head initialisation, dropout and the DataLoader shuffle order reproducible.
SEED = 42
torch.manual_seed(SEED)

df = pd.read_csv("../train_updated.csv")

sequences = df["protein_sequence"].tolist()
tm = df["tm"].to_numpy()  # regression target, one value per sequence

# Random 80/20 split, using the same seed as torch so the whole run is repeatable.
train_sequences, test_sequences, train_tm, test_tm = train_test_split(
    sequences, tm, test_size=0.2, shuffle=True, random_state=SEED
)

In [27]:
model_checkpoint = "facebook/esm2_t6_8M_UR50D"

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# Identical settings for both splits: every sequence is truncated/padded to exactly
# 512 tokens and returned as PyTorch tensors.
tokenize_kwargs = dict(
    max_length=512,
    padding="max_length",
    truncation=True,
    return_tensors="pt",
)

train_tokenized, test_tokenized = (
    tokenizer(split_sequences, **tokenize_kwargs)
    for split_sequences in (train_sequences, test_sequences)
)

In [28]:
class ProteinDataset(Dataset):
    """Map-style dataset pairing pre-tokenized protein sequences with regression targets.

    Args:
        sequences: Tokenizer output containing ``"input_ids"`` and ``"attention_mask"``,
            each indexable along its first dimension (one row per sequence).
        tm: Sequence/array of target values, one per tokenized sequence.

    Raises:
        ValueError: if the number of targets does not match the number of sequences.
    """

    def __init__(self, sequences, tm):
        self.input_ids = sequences["input_ids"]
        self.attention_mask = sequences["attention_mask"]
        # Convert targets to float32 once, so batches collate to float32 rather than the
        # float64 that numpy scalars would produce.
        self.tm = torch.as_tensor(tm, dtype=torch.float32)
        # Fail fast on misaligned inputs. Otherwise extra labels would be silently
        # ignored (__len__ follows input_ids) or indexing would fail mid-epoch.
        if len(self.tm) != len(self.input_ids):
            raise ValueError(
                f"got {len(self.input_ids)} tokenized sequences but {len(self.tm)} tm values"
            )

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return the token ids, attention mask and target for sequence ``idx``."""
        return {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
            "tm": self.tm[idx],
        }
    
# Pair each tokenized split with its targets. Only the training split is shuffled, so
# evaluation batches follow the order of test_sequences / test_tm.
BATCH_SIZE = 8

train_dataset = ProteinDataset(train_tokenized, train_tm)
test_dataset = ProteinDataset(test_tokenized, test_tm)

train_loader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=BATCH_SIZE)

# Batches per split, then samples per split.
print(len(train_loader), len(test_loader), len(train_dataset), len(test_dataset))

2898 725 23184 5797


In [29]:
class ProteinModel(nn.Module):
    """Pretrained encoder with a single-output linear regression head on the first token.

    Args:
        model_checkpoint: Hugging Face model id or local path, passed to
            ``AutoModel.from_pretrained``.
    """

    def __init__(self, model_checkpoint):
        super().__init__()
        self.model = AutoModel.from_pretrained(model_checkpoint)
        # Size the head from the encoder's config rather than hard-coding 320 (the width
        # of esm2_t6_8M). Other checkpoints then work without editing this class, and
        # the state_dict for the current checkpoint is unchanged.
        self.fc = nn.Linear(self.model.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        """Return one regression output per sequence, shape (batch, 1)."""
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        # Pool by taking position 0. This assumes the tokenizer prepends a CLS/BOS-style
        # special token there (TODO confirm for any non-ESM checkpoint).
        cls_token = outputs.last_hidden_state[:, 0, :]
        return self.fc(cls_token)


model = ProteinModel(model_checkpoint).to(device)

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
NUM_EPOCHS = 5


def evaluate_mse(model, loader, criterion, device):
    """Return the per-sample mean of ``criterion`` over every batch in ``loader``.

    Each batch's loss is weighted by its batch size. A short final batch therefore
    counts proportionally, not as much as a full one. Assumes ``criterion`` uses the
    default ``reduction="mean"``.

    Puts ``model`` in eval mode and runs without gradients. The caller must switch back
    to train mode before further training. Returns NaN for an empty loader.
    """
    model.eval()
    total_loss, total_samples = 0.0, 0
    with torch.no_grad():
        for batch in loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            targets = batch["tm"].float().to(device)

            outputs = model(input_ids, attention_mask)
            batch_size = targets.size(0)
            total_loss += criterion(outputs.squeeze(1), targets).item() * batch_size
            total_samples += batch_size
    return total_loss / total_samples if total_samples else float("nan")


for epoch in range(NUM_EPOCHS):
    model.train()
    for batch in train_loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        # Named `targets` (not `tm`) so the loop does not overwrite the global `tm` array.
        targets = batch["tm"].float().to(device)

        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask)
        loss = criterion(outputs.squeeze(1), targets)
        loss.backward()
        optimizer.step()

    # Validation MSE after each epoch (evaluate_mse leaves the model in eval mode;
    # model.train() at the top of the next epoch switches it back).
    val_loss = evaluate_mse(model, test_loader, criterion, device)
    print(f"Epoch {epoch+1}, Loss: {val_loss}")

Epoch 1, Loss: 282.7626096449227


In [None]:
def correlation(predictions, targets):
    """Pearson correlation coefficient between two tensors with the same number of elements.

    Both inputs are flattened first. For example, an (N, 1) prediction column and an
    (N,) target vector are compared element-wise.

    Args:
        predictions: Tensor of model outputs.
        targets: Tensor of ground-truth values with the same number of elements.

    Returns:
        0-d tensor holding r, approximately in [-1, 1]. Higher is better: this is a
        metric, not a loss.

    Raises:
        ValueError: if the inputs have different numbers of elements.
    """
    predictions = predictions.flatten()
    targets = targets.flatten()
    # Guard against silent broadcasting (e.g. a 1-element tensor against N targets).
    if predictions.numel() != targets.numel():
        raise ValueError(
            "predictions and targets must have the same number of elements, "
            f"got {predictions.numel()} and {targets.numel()}"
        )

    pred_centered = predictions - predictions.mean()
    target_centered = targets - targets.mean()

    # Covariance and both standard deviations must share the same 1/N normalisation.
    # torch.std() defaults to 1/(N-1), which would shrink r by a factor of (N-1)/N.
    covariance = (pred_centered * target_centered).mean()
    pred_std = pred_centered.pow(2).mean().sqrt()
    target_std = target_centered.pow(2).mean().sqrt()

    # Small epsilon avoids division by zero when either input is constant.
    return covariance / (pred_std * target_std + 1e-8)

# Pearson correlation between predictions and targets over the whole validation split.
# It is a similarity metric (higher is better), not a loss.
model.eval()  # make sure dropout is off even if the training cell was interrupted
all_predictions = []
all_labels = []
with torch.no_grad():
    for batch in test_loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        # `targets`, not `tm`, so the global `tm` array is not overwritten.
        targets = batch["tm"].float()

        # Keep collected results on the CPU; later cells only need them for metrics.
        all_predictions.append(model(input_ids, attention_mask).cpu())
        all_labels.append(targets)

all_predictions = torch.cat(all_predictions, dim=0)  # (N, 1): the head has one output
all_labels = torch.cat(all_labels, dim=0)            # (N,)

val_pearson = correlation(all_predictions, all_labels).item()

print(f"Validation Pearson Correlation: {val_pearson:.4f}")

Validation Correlation Loss: 0.7463


In [None]:
from scipy.stats import spearmanr

# Spearman rank correlation on the validation split. SciPy needs the data on the CPU.
predictions_cpu, labels_cpu = all_predictions.cpu(), all_labels.cpu()
print(spearmanr(predictions_cpu, labels_cpu).correlation)

0.5564864942596712


In [None]:
# Persist only the learned weights (state_dict). Reload into a freshly constructed
# ProteinModel with load_state_dict.
weights_path = Path("../model_weights/esm2.pth")
# torch.save does not create missing parent directories, so ensure the folder exists.
weights_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), weights_path)