In [1]:
from transformers import AutoTokenizer
import torch
import pandas as pd
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModel
import torch.nn as nn
from sklearn.model_selection import KFold
from scipy.stats import spearmanr
import torch
from torch.utils.data import DataLoader, Subset

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

df = pd.read_csv("/home/ml4science0/novozymes/train_updated.csv")

sequences = df["protein_sequence"].tolist()
tm = df["tm"].values

In [3]:
class ProteinDataset(Dataset):
    def __init__(self, sequences, tm=None):
        self.input_ids = sequences["input_ids"]
        self.attention_mask = sequences["attention_mask"]
        self.tm = tm

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        if self.tm is None:
            return {
                "input_ids": self.input_ids[idx],
                "attention_mask": self.attention_mask[idx]
            }
        else:
            return {
                "input_ids": self.input_ids[idx],
                "attention_mask": self.attention_mask[idx],
                "tm": self.tm[idx]
            }

In [4]:
model_checkpoint = "facebook/esm2_t6_8M_UR50D"

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [5]:
class ProteinModel(nn.Module):
    def __init__(self, model_checkpoint):
        super(ProteinModel, self).__init__()
        self.model = AutoModel.from_pretrained(model_checkpoint)
        self.fc1 = nn.Linear(320, 120)
        self.fc2 = nn.Linear(120, 1)

    def forward(self, input_ids, attention_mask):
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = outputs.last_hidden_state
        cls_token = last_hidden_state[:, 0, :]
        out = self.fc1(cls_token)
        out = self.fc2(out)
        return out

In [6]:
# K-Fold Cross Validation
k = 5
kf = KFold(n_splits=k, shuffle=True, random_state=42)

# Tokenize all sequences at once (before splitting)
tokenized = tokenizer(sequences, max_length=512, padding="max_length", truncation=True, return_tensors="pt")
tm_values = torch.tensor(tm, dtype=torch.float32)

dataset = ProteinDataset(tokenized, tm_values)

predictions_per_fold = []
labels_per_fold = []
fold_results = []

print(dataset)

for fold, (train_idx, val_idx) in enumerate(kf.split(dataset)):
    print(f"Fold {fold + 1}/{k}")

    # Create train and validation subsets
    train_subset = Subset(dataset, train_idx)
    val_subset = Subset(dataset, val_idx)

    print(train_subset)
    print(val_subset)

    # Create DataLoaders
    train_loader = DataLoader(train_subset, batch_size=8, shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=8, shuffle=False)

    # Initialize model, loss, optimizer
    model = ProteinModel(model_checkpoint).to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

    num_epochs = 5
    val_losses = []
    val_correlations = []

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0

        # Training loop
        for batch in train_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            batch_tm = batch["tm"].float().to(device)

            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs.squeeze(1), batch_tm)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        epoch_loss = running_loss / len(train_loader)
        print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {epoch_loss:.4f}")

        # Validation loop
        model.eval()
        val_loss = 0.0
        total_preds = []
        total_tms = []

        with torch.no_grad():
            for batch in val_loader:
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                batch_tm = batch["tm"].float().to(device)

                outputs = model(input_ids, attention_mask)
                loss = criterion(outputs.squeeze(1), batch_tm)
                val_loss += loss.item()

                total_preds += outputs.squeeze(1).cpu().numpy().tolist()
                total_tms += batch_tm.cpu().numpy().tolist()

        val_loss /= len(val_loader)
        print(f"Validation Loss: {val_loss:.4f}")

        # Compute Spearman correlation
        spearman_corr = spearmanr(total_preds, total_tms)
        val_correlations.append(spearman_corr.correlation)
        print(f"Spearman Correlation: {spearman_corr.correlation:.4f}")

    model.eval()
    total_preds = []
    total_tms = []

    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            batch_tm = batch["tm"].float().to(device)

            outputs = model(input_ids, attention_mask)

            total_preds += outputs.squeeze(1).cpu().numpy().tolist()
            total_tms += batch_tm.cpu().numpy().tolist()

    predictions_per_fold.append(total_preds)
    labels_per_fold.append(total_tms)

    # Store results for this fold
    fold_results.append({
        "fold": fold + 1,
        "final_val_loss": val_loss,
        "final_spearman_corr": val_correlations[-1],
    })

# Aggregate results across folds
avg_val_loss = sum(result["final_val_loss"] for result in fold_results) / k
avg_spearman_corr = sum(result["final_spearman_corr"] for result in fold_results) / k

print(f"Average Validation Loss: {avg_val_loss:.4f}")
print(f"Average Spearman Correlation: {avg_spearman_corr:.4f}")

<__main__.ProteinDataset object at 0x7f108f7c7f50>
Fold 1/5
<torch.utils.data.dataset.Subset object at 0x7f108f84eae0>
<torch.utils.data.dataset.Subset object at 0x7f11dc3f5460>


2024-12-16 20:15:24.457738: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1734376524.467139 1661470 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1734376524.469945 1661470 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-12-16 20:15:24.479583: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialize

Epoch 1/5, Train Loss: 321.4490
Validation Loss: 67.7165
Spearman Correlation: 0.5082
Epoch 2/5, Train Loss: 64.9673
Validation Loss: 61.0064
Spearman Correlation: 0.5310
Epoch 3/5, Train Loss: 58.7290
Validation Loss: 57.1001
Spearman Correlation: 0.5488
Epoch 4/5, Train Loss: 52.6321
Validation Loss: 58.3605
Spearman Correlation: 0.5578
Epoch 5/5, Train Loss: 47.8951
Validation Loss: 55.7869
Spearman Correlation: 0.5562


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Fold 2/5
<torch.utils.data.dataset.Subset object at 0x7f0ff622c1a0>
<torch.utils.data.dataset.Subset object at 0x7f0ff622d9d0>
Epoch 1/5, Train Loss: 324.8848
Validation Loss: 65.4420
Spearman Correlation: 0.5216
Epoch 2/5, Train Loss: 62.7491
Validation Loss: 60.9669
Spearman Correlation: 0.5394
Epoch 3/5, Train Loss: 53.6794
Validation Loss: 58.9299
Spearman Correlation: 0.5615
Epoch 4/5, Train Loss: 46.5653
Validation Loss: 59.3775
Spearman Correlation: 0.5719
Epoch 5/5, Train Loss: 37.9396
Validation Loss: 57.1369
Spearman Correlation: 0.5675


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Fold 3/5
<torch.utils.data.dataset.Subset object at 0x7f108255dca0>
<torch.utils.data.dataset.Subset object at 0x7f108ec8d4f0>
Epoch 1/5, Train Loss: 322.2094
Validation Loss: 70.0897
Spearman Correlation: 0.5099
Epoch 2/5, Train Loss: 62.7953
Validation Loss: 63.7072
Spearman Correlation: 0.5435
Epoch 3/5, Train Loss: 56.7344
Validation Loss: 61.4422
Spearman Correlation: 0.5622
Epoch 4/5, Train Loss: 51.3118
Validation Loss: 59.1103
Spearman Correlation: 0.5691
Epoch 5/5, Train Loss: 45.8820
Validation Loss: 60.6027
Spearman Correlation: 0.5657


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Fold 4/5
<torch.utils.data.dataset.Subset object at 0x7f11dc3f5460>
<torch.utils.data.dataset.Subset object at 0x7f108f4cddf0>
Epoch 1/5, Train Loss: 319.2952
Validation Loss: 68.9296
Spearman Correlation: 0.5216
Epoch 2/5, Train Loss: 63.0562
Validation Loss: 58.3747
Spearman Correlation: 0.5595
Epoch 3/5, Train Loss: 54.5772
Validation Loss: 63.0117
Spearman Correlation: 0.5708
Epoch 4/5, Train Loss: 47.1118
Validation Loss: 58.9907
Spearman Correlation: 0.5707
Epoch 5/5, Train Loss: 39.1803
Validation Loss: 61.7452
Spearman Correlation: 0.5545


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Fold 5/5
<torch.utils.data.dataset.Subset object at 0x7f108f84eae0>
<torch.utils.data.dataset.Subset object at 0x7f10140c3aa0>
Epoch 1/5, Train Loss: 316.5728
Validation Loss: 69.7140
Spearman Correlation: 0.5210
Epoch 2/5, Train Loss: 62.7005
Validation Loss: 62.2377
Spearman Correlation: 0.5490
Epoch 3/5, Train Loss: 56.0497
Validation Loss: 60.2856
Spearman Correlation: 0.5683
Epoch 4/5, Train Loss: 50.4222
Validation Loss: 59.2929
Spearman Correlation: 0.5632
Epoch 5/5, Train Loss: 43.6345
Validation Loss: 59.7276
Spearman Correlation: 0.5572
Average Validation Loss: 58.9999
Average Spearman Correlation: 0.5602


In [27]:
# Drop the last element from fold 0 to make all arrays the same length
labels_per_fold[0] = labels_per_fold[0][:-1]
predictions_per_fold[0] = predictions_per_fold[0][:-1]

In [28]:
results = {}
for i in range(k):
    results[f"tm_{i}"] = labels_per_fold[i]
    results[f"preds_{i}"] = predictions_per_fold[i]

pd.DataFrame(results).to_csv("/home/ml4science0/novozymes/predictions/esm.csv", index=False)