In [1]:
import torch
from transformers import AutoTokenizer
from sklearn.model_selection import train_test_split
import torch
from torch import nn
import pandas as pd
from torch.utils.data import DataLoader
from transformers import AutoModel
import torch.nn as nn
from scipy.stats import spearmanr
from torch.utils.data import DataLoader, Dataset

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

df = pd.read_csv("/home/ml4science0/novozymes/train.csv")
test_df = pd.read_csv("/home/ml4science0/novozymes/test.csv")

sequences = df["protein_sequence"].tolist()
tm = df["tm"].values

train_sequences, test_sequences, train_tm, test_tm = train_test_split(sequences, tm, test_size=0.2, shuffle=True, random_state=42)

In [16]:
model_checkpoint = "facebook/esm2_t6_8M_UR50D"

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

train_tokenized = tokenizer(train_sequences, max_length=512, padding="max_length", truncation=True, return_tensors="pt")
test_tokenized = tokenizer(test_sequences, max_length=512, padding="max_length", truncation=True, return_tensors="pt")

In [32]:
class ProteinModel(nn.Module):
    def __init__(self, model_checkpoint):
        super(ProteinModel, self).__init__()
        self.model = AutoModel.from_pretrained(model_checkpoint)
        self.fc = nn.Linear(320, 1)

    def forward(self, input_ids, attention_mask):
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = outputs.last_hidden_state
        cls_token = last_hidden_state[:, 0, :]
        out = self.fc(cls_token)
        return out

In [None]:
class ProteinDataset(Dataset):
    def __init__(self, sequences, tm):
        self.input_ids = sequences["input_ids"]
        self.attention_mask = sequences["attention_mask"]
        self.tm = tm

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
            "tm": self.tm[idx]
        }

train_dataset = ProteinDataset(train_tokenized, train_tm)
test_dataset = ProteinDataset(test_tokenized, test_tm)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

print(len(train_loader), len(test_loader), len(train_dataset), len(test_dataset))

3139 785 25112 6278


In [33]:
model_checkpoint = "facebook/esm2_t6_8M_UR50D"
model = ProteinModel(model_checkpoint)
model.load_state_dict(torch.load("../model_weights/esm.pth"), strict=False)
model.to(device)
model.eval()

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  model.load_state_dict(torch.load("../model_weights/esm.pth"), strict=False)


ProteinModel(
  (model): EsmModel(
    (embeddings): EsmEmbeddings(
      (word_embeddings): Embedding(33, 320, padding_idx=1)
      (dropout): Dropout(p=0.0, inplace=False)
      (position_embeddings): Embedding(1026, 320, padding_idx=1)
    )
    (encoder): EsmEncoder(
      (layer): ModuleList(
        (0-5): 6 x EsmLayer(
          (attention): EsmAttention(
            (self): EsmSelfAttention(
              (query): Linear(in_features=320, out_features=320, bias=True)
              (key): Linear(in_features=320, out_features=320, bias=True)
              (value): Linear(in_features=320, out_features=320, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (rotary_embeddings): RotaryEmbedding()
            )
            (output): EsmSelfOutput(
              (dense): Linear(in_features=320, out_features=320, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (LayerNorm): LayerNorm((320,), eps=1e-05, elementwise_a

In [34]:
def correlation(predictions, targets):
    # Flatten inputs to 1D tensors
    predictions = predictions.flatten()
    targets = targets.flatten()
    
    # Compute means
    pred_mean = predictions.mean()
    target_mean = targets.mean()
    
    # Compute covariance
    covariance = ((predictions - pred_mean) * (targets - target_mean)).mean()
    
    # Compute standard deviations
    pred_std = predictions.std()
    target_std = targets.std()
    
    # Compute Pearson correlation
    correlation = covariance / (pred_std * target_std + 1e-8)  # Add small epsilon to avoid division by zero
    
    return correlation

# Compute correlation loss on validation set
model.eval()  # Ensure model is in evaluation mode
all_predictions = []
all_labels = []
with torch.no_grad():
    for batch in test_loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        tm = batch["tm"].float().to(device)
        
        all_predictions.append(model(input_ids, attention_mask))
        all_labels.append(tm)

all_predictions = torch.cat(all_predictions, dim=0)
all_labels = torch.cat(all_labels, dim=0)

# Compute the total loss on the entire validation dataset
val_loss = correlation(all_predictions, all_labels).item()

print(f"Validation Correlation Loss: {val_loss:.4f}")

Validation Correlation Loss: 0.7463


In [35]:
from scipy.stats import spearmanr

print(spearmanr(all_predictions.cpu(), all_labels.cpu()).correlation)

0.5564864942596712


In [None]:
from sklearn.metrics import mean_squared_error 

print(mean_squared_error(all_predictions.cpu(), all_labels.cpu()))

85.65133
