In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import numpy as np
import pickle
import torch
from torch import nn
import torch.nn.functional as F
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import math
import random
import itertools
import re
import torch.optim as optim
from scipy.stats import spearmanr
from transformers import AutoTokenizer, AutoModel

In [3]:
class BasicBlock1D(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv1d(in_planes, planes,
                               kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1   = nn.BatchNorm1d(planes)
        self.relu  = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv1d(planes, planes,
                               kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn2   = nn.BatchNorm1d(planes)
        self.downsample = downsample

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(identity)

        out += identity
        return self.relu(out)


class CNN_Contrastive(nn.Module):
    def __init__(self, d_esm: int, use_first_fitness: bool = False):
        """
        CNN‐based contrastive model using a 1D-ResNet18 backbone.

        Args:
        -----
        d_esm: int
            Dimension of the input ESM embeddings.
        use_first_fitness: bool
            If True, append the fitness value of the first variant before the final FC.
        """
        super().__init__()
        self.use_first_fitness = use_first_fitness
        self.inplanes = 64

        # --- initial conv/bn/relu/pool (1D) ---
        self.conv1   = nn.Conv1d(d_esm, 64,
                                 kernel_size=7, stride=2,
                                 padding=3, bias=False)
        self.bn1     = nn.BatchNorm1d(64)
        self.relu    = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool1d(kernel_size=3,
                                    stride=2, padding=1)

        # --- the 4 ResNet layers (2 blocks each) ---
        self.layer1 = self._make_layer(BasicBlock1D,  64, 2)
        self.layer2 = self._make_layer(BasicBlock1D, 128, 2, stride=2)
        self.layer3 = self._make_layer(BasicBlock1D, 256, 2, stride=2)
        self.layer4 = self._make_layer(BasicBlock1D, 512, 2, stride=2)

        # --- global pooling + final FC ---
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        fc_in = 512 * BasicBlock1D.expansion + (1 if use_first_fitness else 0)
        self.fc = nn.Linear(fc_in, 1)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv1d(self.inplanes,
                          planes * block.expansion,
                          kernel_size=1,
                          stride=stride,
                          bias=False),
                nn.BatchNorm1d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes,
                            stride=stride,
                            downsample=downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, emb1, emb2, fit1= None):
        """
        Args:
        -----
        emb1, emb2: (N, L, d_esm)
            ESM embeddings for sequence 1 and 2.
        fit1: (N, 1), optional
            The 1‐D fitness feature of the first variant to append before the final FC.

        Returns:
        --------
        out: (N,)
            The scalar logit/regression output.
        """
        # compute difference and permute to (N, d_esm, L)
        x = emb2 - emb1
        x = x.permute(0, 2, 1)

        # stem
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        # ResNet layers
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # global pooling
        x = self.avgpool(x)     # (N, 512, 1)
        x = torch.flatten(x, 1) # (N, 512)

        # 5) optionally append Fitness 1
        if self.use_first_fitness and fit1 is not None:
            x = torch.cat([x, fit1], dim=1)  # (N, 513)

        # 6) final FC -> scalar
        out = self.fc(x)  # (N, 1)
        return out.squeeze(1)  # (N,)

In [4]:
def load_data(idx1, idx2, fit1, labels, emb, batch, batch_size, device):
    """
    Returns:
      emb1_batch: (B, L, d_esm) torch.Tensor
      emb2_batch: (B, L, d_esm) torch.Tensor
      fit1_batch: (B, 1)      torch.Tensor
      labels_batch:(B,)       torch.Tensor
    """
    start = batch * batch_size
    end   = min(start + batch_size, len(labels))
    b_idx1 = idx1[start:end]
    b_idx2 = idx2[start:end]
    b_fit1 = np.array(fit1[start:end], dtype=np.float32).reshape(-1,1)
    b_lbl  = np.array(labels[start:end], dtype=np.float32).reshape(-1,1)

    emb1 = emb[np.array(b_idx1)]
    emb2 = emb[np.array(b_idx2)]
    emb1 = torch.from_numpy(emb1).float().to(device)
    emb2 = torch.from_numpy(emb2).float().to(device)
    fit1 = torch.from_numpy(b_fit1).float().to(device)
    lbl  = torch.from_numpy(b_lbl).squeeze(1).float().to(device)
    return emb1, emb2, fit1, lbl

def train_epoch(model, optimizer, idx1, idx2, fit1, labels, emb, batch_size, epoch, device, train_frac):
    model.train()
    criterion = nn.MSELoss()
    running_loss = 0.0
    total = 0
    num_batches = math.ceil(len(labels) / batch_size)
    batch_set_size = math.ceil(num_batches/train_frac)
    batch_set_idx = (epoch-1)%train_frac
    start_b = batch_set_idx * batch_set_size
    end_b = (batch_set_idx+1) * batch_set_size
    for b in tqdm(list(range(num_batches))[start_b: end_b], desc=f"Epoch {epoch}"):
        e1, e2, f1, y = load_data(idx1, idx2, fit1, labels, emb, b, batch_size, device)
        preds = model(e1, e2, f1)
        loss = criterion(preds, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * preds.size(0)
        total += preds.size(0)
    return running_loss / total


def test_epoch(model, idx1, idx2, fit1, labels, emb, batch_size, epoch, device, test_frac=None):
    model.eval()
    criterion = nn.MSELoss()
    running_loss = 0.0
    total = 0
    all_preds = []
    all_targets = []
    num_batches = math.ceil(len(labels) / batch_size)
    b_set = list(range(num_batches))
    if test_frac:
        b_set = random.sample(range(num_batches), math.ceil(num_batches/test_frac))
    with torch.no_grad():
        for b in tqdm(b_set, desc=f"Epoch {epoch}"):
            e1, e2, f1, y = load_data(idx1, idx2, fit1, labels, emb, b, batch_size, device)
            preds = model(e1, e2, f1)
            loss = criterion(preds, y)
            running_loss += loss.item() * preds.size(0)
            total += preds.size(0)
            all_preds.append(preds.cpu())
            all_targets.append(y.cpu())
    avg_loss = running_loss / total
    all_preds = torch.cat(all_preds).detach().cpu().numpy()
    all_targets = torch.cat(all_targets).detach().cpu().numpy()
    corr, _ = spearmanr(all_preds, all_targets)
    return avg_loss, corr, all_preds, all_targets

# Re-train the CNN_Contrastive model on the entire dataset

In [None]:
# Define hyperparameters and instantiate model
batch_size    = 16
train_frac    = 50                 # fraction of training batches to include in each epoch
n_epoch       = train_frac * 2
device        = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = CNN_Contrastive(d_esm=1280, use_first_fitness=True).to(device)
optimizer = optim.Adam(model.parameters(), lr=4e-6, weight_decay=1e-4)


# --- load & preprocess data ---
n_round = 3
train_idx1, train_idx2, train_fit1, train_lbl = [], [], [], []

emb_list = []
offset = 0
for r in [f'r{i}' for i in range(1, n_round+1)]:
    emb = np.load(f'/content/drive/MyDrive/Mid_1_data_train/ESM_emb_{r}.npy')
    emb_list.append(emb)
    seq2idx = pickle.load(open(f'/content/drive/MyDrive/Mid_1_data_train/seq_to_index_{r}.pkl','rb'))
    df_tr = pd.read_csv(f'/content/drive/MyDrive/Mid_1_data_train/data_contrastive_inference_{r}.csv').sample(frac=1, random_state=1).reset_index(drop=True)

    train_idx1 += [seq2idx[s.rstrip('*')]+offset for s in df_tr['seq_1']]
    train_idx2 += [seq2idx[s.rstrip('*')]+offset for s in df_tr['seq_2']]
    train_fit1 += [float(f) for f in df_tr['fitness_1']]
    train_lbl  += [float(y) for y in df_tr['label']]

    offset += emb.shape[0]
emb = np.concatenate(emb_list)




# --- training loop ---

train_losses = []

for epoch in range(1, n_epoch+1):
    tr_loss = train_epoch(model, optimizer,
        train_idx1, train_idx2,
        train_fit1, train_lbl,
        emb, batch_size, epoch, device, train_frac)
    train_losses.append(tr_loss)

    print(f"Epoch {epoch:3d} ▶ Train Loss {tr_loss:.4f}")

torch.save(model.state_dict(), "/content/drive/MyDrive/Mid_1_Contrast_results/CNN_Contrastive_inference.pt")



# --- plot loss curve ---
epochs = range(len(train_losses))
ticks  = np.arange(0, len(train_losses), 10)

plt.figure()
plt.plot(epochs, train_losses)
plt.xticks(ticks)
plt.xlabel("Epoch")
plt.ylabel("MSE Loss")
plt.title(f"CNN_Contrastive Loss Curve")
plt.savefig("/content/drive/MyDrive/Mid_1_Contrast_results/CNN_Contrastive_loss_inference.png")

Epoch 1: 100%|██████████| 11907/11907 [06:23<00:00, 31.05it/s]


Epoch   1 ▶ Train Loss 8.0359


Epoch 2: 100%|██████████| 11907/11907 [06:35<00:00, 30.11it/s]


Epoch   2 ▶ Train Loss 6.8450


Epoch 3: 100%|██████████| 11907/11907 [06:18<00:00, 31.50it/s]


Epoch   3 ▶ Train Loss 4.9328


Epoch 4: 100%|██████████| 11907/11907 [06:45<00:00, 29.39it/s]


Epoch   4 ▶ Train Loss 3.8338


Epoch 5: 100%|██████████| 11907/11907 [06:27<00:00, 30.73it/s]


Epoch   5 ▶ Train Loss 2.2549


Epoch 6: 100%|██████████| 11907/11907 [06:45<00:00, 29.37it/s]


Epoch   6 ▶ Train Loss 1.7173


Epoch 7: 100%|██████████| 11907/11907 [06:33<00:00, 30.22it/s]


Epoch   7 ▶ Train Loss 1.4130


Epoch 8: 100%|██████████| 11907/11907 [06:26<00:00, 30.81it/s]


Epoch   8 ▶ Train Loss 1.2198


Epoch 9: 100%|██████████| 11907/11907 [06:24<00:00, 30.93it/s]


Epoch   9 ▶ Train Loss 1.0862


Epoch 10: 100%|██████████| 11907/11907 [06:21<00:00, 31.17it/s]


Epoch  10 ▶ Train Loss 0.9769


Epoch 11: 100%|██████████| 11907/11907 [06:34<00:00, 30.18it/s]


Epoch  11 ▶ Train Loss 0.8934


Epoch 12: 100%|██████████| 11907/11907 [06:38<00:00, 29.87it/s]


Epoch  12 ▶ Train Loss 0.8280


Epoch 13: 100%|██████████| 11907/11907 [06:24<00:00, 30.96it/s]


Epoch  13 ▶ Train Loss 0.7691


Epoch 14: 100%|██████████| 11907/11907 [06:20<00:00, 31.26it/s]


Epoch  14 ▶ Train Loss 0.7200


Epoch 15: 100%|██████████| 11907/11907 [06:20<00:00, 31.29it/s]


Epoch  15 ▶ Train Loss 0.6754


Epoch 16: 100%|██████████| 11907/11907 [06:25<00:00, 30.92it/s]


Epoch  16 ▶ Train Loss 0.6361


Epoch 17: 100%|██████████| 11907/11907 [06:16<00:00, 31.63it/s]


Epoch  17 ▶ Train Loss 0.6027


Epoch 18: 100%|██████████| 11907/11907 [06:18<00:00, 31.43it/s]


Epoch  18 ▶ Train Loss 0.5761


Epoch 19: 100%|██████████| 11907/11907 [06:34<00:00, 30.15it/s]


Epoch  19 ▶ Train Loss 0.5437


Epoch 20: 100%|██████████| 11907/11907 [06:11<00:00, 32.05it/s]


Epoch  20 ▶ Train Loss 0.5170


Epoch 21: 100%|██████████| 11907/11907 [06:31<00:00, 30.42it/s]


Epoch  21 ▶ Train Loss 0.4934


Epoch 22: 100%|██████████| 11907/11907 [06:39<00:00, 29.81it/s]


Epoch  22 ▶ Train Loss 0.4686


Epoch 23: 100%|██████████| 11907/11907 [06:34<00:00, 30.17it/s]


Epoch  23 ▶ Train Loss 0.4534


Epoch 24: 100%|██████████| 11907/11907 [06:28<00:00, 30.62it/s]


Epoch  24 ▶ Train Loss 0.4324


Epoch 25: 100%|██████████| 11907/11907 [06:20<00:00, 31.25it/s]


Epoch  25 ▶ Train Loss 0.4140


Epoch 26: 100%|██████████| 11907/11907 [06:20<00:00, 31.33it/s]


Epoch  26 ▶ Train Loss 0.3984


Epoch 27: 100%|██████████| 11907/11907 [06:25<00:00, 30.88it/s]


Epoch  27 ▶ Train Loss 0.3776


Epoch 28: 100%|██████████| 11907/11907 [06:29<00:00, 30.57it/s]


Epoch  28 ▶ Train Loss 0.3590


Epoch 29: 100%|██████████| 11907/11907 [06:27<00:00, 30.75it/s]


Epoch  29 ▶ Train Loss 0.3456


Epoch 30: 100%|██████████| 11907/11907 [06:30<00:00, 30.50it/s]


Epoch  30 ▶ Train Loss 0.3305


Epoch 31: 100%|██████████| 11907/11907 [06:26<00:00, 30.81it/s]


Epoch  31 ▶ Train Loss 0.3172


Epoch 32: 100%|██████████| 11907/11907 [06:27<00:00, 30.71it/s]


Epoch  32 ▶ Train Loss 0.3018


Epoch 33: 100%|██████████| 11907/11907 [06:36<00:00, 30.02it/s]


Epoch  33 ▶ Train Loss 0.2838


Epoch 34: 100%|██████████| 11907/11907 [06:51<00:00, 28.90it/s]


Epoch  34 ▶ Train Loss 0.2746


Epoch 35: 100%|██████████| 11907/11907 [06:29<00:00, 30.56it/s]


Epoch  35 ▶ Train Loss 0.2607


Epoch 36: 100%|██████████| 11907/11907 [06:35<00:00, 30.10it/s]


Epoch  36 ▶ Train Loss 0.2496


Epoch 37: 100%|██████████| 11907/11907 [06:37<00:00, 29.99it/s]


Epoch  37 ▶ Train Loss 0.2357


Epoch 38: 100%|██████████| 11907/11907 [06:24<00:00, 30.98it/s]


Epoch  38 ▶ Train Loss 0.2259


Epoch 39: 100%|██████████| 11907/11907 [06:27<00:00, 30.72it/s]


Epoch  39 ▶ Train Loss 0.2127


Epoch 40: 100%|██████████| 11907/11907 [06:25<00:00, 30.87it/s]


Epoch  40 ▶ Train Loss 0.2050


Epoch 41: 100%|██████████| 11907/11907 [06:25<00:00, 30.85it/s]


Epoch  41 ▶ Train Loss 0.1921


Epoch 42: 100%|██████████| 11907/11907 [06:30<00:00, 30.50it/s]


Epoch  42 ▶ Train Loss 0.1846


Epoch 43: 100%|██████████| 11907/11907 [06:27<00:00, 30.70it/s]


Epoch  43 ▶ Train Loss 0.1745


Epoch 44: 100%|██████████| 11907/11907 [06:31<00:00, 30.45it/s]


Epoch  44 ▶ Train Loss 0.1655


Epoch 45: 100%|██████████| 11907/11907 [06:27<00:00, 30.70it/s]


Epoch  45 ▶ Train Loss 0.1605


Epoch 46: 100%|██████████| 11907/11907 [06:32<00:00, 30.30it/s]


Epoch  46 ▶ Train Loss 0.1502


Epoch 47: 100%|██████████| 11907/11907 [06:39<00:00, 29.78it/s]


Epoch  47 ▶ Train Loss 0.1426


Epoch 48: 100%|██████████| 11907/11907 [06:26<00:00, 30.85it/s]


Epoch  48 ▶ Train Loss 0.1371


Epoch 49: 100%|██████████| 11907/11907 [06:29<00:00, 30.58it/s]


Epoch  49 ▶ Train Loss 0.1314


Epoch 50: 100%|██████████| 11875/11875 [06:20<00:00, 31.17it/s]


Epoch  50 ▶ Train Loss 0.1276


Epoch 51: 100%|██████████| 11907/11907 [06:24<00:00, 30.95it/s]


Epoch  51 ▶ Train Loss 1.2834


Epoch 52: 100%|██████████| 11907/11907 [06:34<00:00, 30.21it/s]


Epoch  52 ▶ Train Loss 1.3179


Epoch 53: 100%|██████████| 11907/11907 [06:41<00:00, 29.62it/s]


Epoch  53 ▶ Train Loss 0.6289


Epoch 54: 100%|██████████| 11907/11907 [06:25<00:00, 30.91it/s]


Epoch  54 ▶ Train Loss 0.4084


Epoch 55: 100%|██████████| 11907/11907 [06:33<00:00, 30.24it/s]


Epoch  55 ▶ Train Loss 0.1447


Epoch 56: 100%|██████████| 11907/11907 [06:25<00:00, 30.92it/s]


Epoch  56 ▶ Train Loss 0.1181


Epoch 57: 100%|██████████| 11907/11907 [06:19<00:00, 31.34it/s]


Epoch  57 ▶ Train Loss 0.1076


Epoch 58: 100%|██████████| 11907/11907 [06:38<00:00, 29.85it/s]


Epoch  58 ▶ Train Loss 0.1016


Epoch 59: 100%|██████████| 11907/11907 [06:32<00:00, 30.35it/s]


Epoch  59 ▶ Train Loss 0.0942


Epoch 60: 100%|██████████| 11907/11907 [06:24<00:00, 30.97it/s]


Epoch  60 ▶ Train Loss 0.0899


Epoch 61: 100%|██████████| 11907/11907 [06:29<00:00, 30.57it/s]


Epoch  61 ▶ Train Loss 0.0878


Epoch 62: 100%|██████████| 11907/11907 [06:30<00:00, 30.46it/s]


Epoch  62 ▶ Train Loss 0.0818


Epoch 63: 100%|██████████| 11907/11907 [06:40<00:00, 29.76it/s]


Epoch  63 ▶ Train Loss 0.0784


Epoch 64: 100%|██████████| 11907/11907 [06:24<00:00, 30.98it/s]


Epoch  64 ▶ Train Loss 0.0767


Epoch 65: 100%|██████████| 11907/11907 [06:26<00:00, 30.78it/s]


Epoch  65 ▶ Train Loss 0.0726


Epoch 66: 100%|██████████| 11907/11907 [06:33<00:00, 30.25it/s]


Epoch  66 ▶ Train Loss 0.0703


Epoch 67: 100%|██████████| 11907/11907 [06:31<00:00, 30.43it/s]


Epoch  67 ▶ Train Loss 0.0688


Epoch 68: 100%|██████████| 11907/11907 [06:22<00:00, 31.11it/s]


Epoch  68 ▶ Train Loss 0.0634


Epoch 69: 100%|██████████| 11907/11907 [06:40<00:00, 29.72it/s]


Epoch  69 ▶ Train Loss 0.0635


Epoch 70: 100%|██████████| 11907/11907 [06:31<00:00, 30.40it/s]


Epoch  70 ▶ Train Loss 0.0593


Epoch 71: 100%|██████████| 11907/11907 [06:26<00:00, 30.79it/s]


Epoch  71 ▶ Train Loss 0.0590


Epoch 72: 100%|██████████| 11907/11907 [06:20<00:00, 31.27it/s]


Epoch  72 ▶ Train Loss 0.0573


Epoch 73: 100%|██████████| 11907/11907 [06:38<00:00, 29.89it/s]


Epoch  73 ▶ Train Loss 0.0571


Epoch 74: 100%|██████████| 11907/11907 [06:24<00:00, 31.00it/s]


Epoch  74 ▶ Train Loss 0.0534


Epoch 75: 100%|██████████| 11907/11907 [06:26<00:00, 30.78it/s]


Epoch  75 ▶ Train Loss 0.0522


Epoch 76: 100%|██████████| 11907/11907 [06:33<00:00, 30.27it/s]


Epoch  76 ▶ Train Loss 0.0512


Epoch 77: 100%|██████████| 11907/11907 [06:32<00:00, 30.34it/s]


Epoch  77 ▶ Train Loss 0.0484


Epoch 78: 100%|██████████| 11907/11907 [06:40<00:00, 29.74it/s]


Epoch  78 ▶ Train Loss 0.0474


Epoch 79: 100%|██████████| 11907/11907 [06:36<00:00, 30.00it/s]


Epoch  79 ▶ Train Loss 0.0461


Epoch 80: 100%|██████████| 11907/11907 [06:25<00:00, 30.89it/s]


Epoch  80 ▶ Train Loss 0.0476


Epoch 81: 100%|██████████| 11907/11907 [06:22<00:00, 31.10it/s]


Epoch  81 ▶ Train Loss 0.0443


Epoch 82: 100%|██████████| 11907/11907 [06:23<00:00, 31.07it/s]


Epoch  82 ▶ Train Loss 0.0433


Epoch 83: 100%|██████████| 11907/11907 [06:27<00:00, 30.72it/s]


Epoch  83 ▶ Train Loss 0.0418


Epoch 84: 100%|██████████| 11907/11907 [06:30<00:00, 30.46it/s]


Epoch  84 ▶ Train Loss 0.0436


Epoch 85: 100%|██████████| 11907/11907 [06:36<00:00, 30.03it/s]


Epoch  85 ▶ Train Loss 0.0421


Epoch 86: 100%|██████████| 11907/11907 [06:35<00:00, 30.09it/s]


Epoch  86 ▶ Train Loss 0.0391


Epoch 87: 100%|██████████| 11907/11907 [06:33<00:00, 30.28it/s]


Epoch  87 ▶ Train Loss 0.0383


Epoch 88: 100%|██████████| 11907/11907 [06:22<00:00, 31.16it/s]


Epoch  88 ▶ Train Loss 0.0387


Epoch 89: 100%|██████████| 11907/11907 [06:26<00:00, 30.81it/s]


Epoch  89 ▶ Train Loss 0.0395


Epoch 90: 100%|██████████| 11907/11907 [06:21<00:00, 31.18it/s]


Epoch  90 ▶ Train Loss 0.0362


Epoch 91: 100%|██████████| 11907/11907 [06:27<00:00, 30.74it/s]


Epoch  91 ▶ Train Loss 0.0355


Epoch 92: 100%|██████████| 11907/11907 [06:31<00:00, 30.42it/s]


Epoch  92 ▶ Train Loss 0.0379


Epoch 93: 100%|██████████| 11907/11907 [06:31<00:00, 30.41it/s]


Epoch  93 ▶ Train Loss 0.0349


Epoch 94:  49%|████▉     | 5819/11907 [03:11<03:19, 30.56it/s]

# Inference on point mutations

In [8]:
# 1) Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CNN_Contrastive(d_esm=1280, use_first_fitness=True).to(device)
checkpoint = "/content/drive/MyDrive/Mid_1_Contrast_results/CNN_Contrastive_inference.pt"
model.load_state_dict(torch.load(checkpoint, map_location=device))
model.eval()

# 2) Prepare ESM-2 for embedding
esm_model_name = "facebook/esm2_t33_650M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(esm_model_name, do_lower_case=False)
esm = AutoModel.from_pretrained(esm_model_name).to(device)
esm.eval()


# 3) Build all point‐mutants of the reference seq
ref_seq = "MAGSDSPLAEQIKNTLTFIGQANAAGRMDEVRTLQKNLHPLWAEYFQLTEGSGGSPLAQQIQNGHVLIHQARAAGRMDEVRRLTEKTLQLMKEYFQQSD"
aas = list("ACDEFGHIKLMNPQRSTVWY")
start_idx = 5    # index start from 1

variants = []
mutations = []
L = len(ref_seq)
for i, wt in enumerate(ref_seq):
    if i < start_idx-1:
      continue
    for aa in aas:
        if aa == wt:
            continue
        var = ref_seq[:i] + aa + ref_seq[i+1:]
        variants.append(var)
        mutations.append(f"{wt}{i-3}{aa}")

# include the wild‐type itself
variants.insert(0, ref_seq)
mutations.insert(0, "")


# 4) Embed mid13sc once (padded/truncated to length 101 tokens)
max_len = len(ref_seq) + 2
with torch.no_grad():
    mid_toks = tokenizer(
        ref_seq,
        return_tensors="pt",
        padding="max_length",
        max_length=max_len,
        truncation=True,
        add_special_tokens=True
    ).to(device)
    mid_emb = esm(**mid_toks).last_hidden_state.squeeze(0)  # (101,1280)

# 5) Batch‐infer all variants
batch_size = 32
fit1_value = 8.666429599066408
results = []

for start in tqdm(range(0, len(variants), batch_size), desc="Inferring variants"):
    end = min(start + batch_size, len(variants))
    batch_seqs = variants[start:end]
    batch_muts = mutations[start:end]

    # tokenize & embed variants
    with torch.no_grad():
        toks = tokenizer(
            batch_seqs,
            return_tensors="pt",
            padding="max_length",
            max_length=max_len,
            truncation=True,
            add_special_tokens=True
        ).to(device)
        var_emb = esm(**toks).last_hidden_state       # (B, 101, 1280)

    # prepare emb1 = repeated mid_emb
    B = end - start
    emb1 = mid_emb.unsqueeze(0).repeat(B, 1, 1)       # (B, 101, 1280)
    emb2 = var_emb                                    # (B, 101, 1280)
    fit1 = torch.full((B, 1), fit1_value, device=device)  # (B,1)

    # predict
    with torch.no_grad():
        preds = model(emb1, emb2, fit1)               # (B,)

    for seq, mut, score in zip(batch_seqs, batch_muts, preds.cpu().numpy()):
        results.append({
            "sequence": seq,
            "mutation": mut,
            "predicted_improvement": score
        })

# 6) Rank & save
df_out = pd.DataFrame(results)
df_out = df_out.sort_values(
    "predicted_improvement",
    ascending=False
).reset_index(drop=True)

out_csv = "/content/drive/MyDrive/Mid_1_Contrast_results/mid13sc_point_mut_results_cnn.csv"
df_out.to_csv(out_csv, index=False)
print(f"Saved {len(df_out)} variants to {out_csv}")

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Inferring variants: 100%|██████████| 57/57 [02:52<00:00,  3.03s/it]

Saved 1806 variants to /content/drive/MyDrive/Mid_1_Contrast_results/mid13sc_point_mut_results_cnn.csv





# Find where the reference sequence is ranked

In [9]:
df_out[df_out['mutation'] == '']

Unnamed: 0,sequence,mutation,predicted_improvement
1120,MAGSDSPLAEQIKNTLTFIGQANAAGRMDEVRTLQKNLHPLWAEYF...,,-9.010009


# Inference on double/ triple mutations

In [10]:
# 1) Load the previous single‐mutant results and pick top 30 (excluding wildtype)
df_single = pd.read_csv(
    "/content/drive/MyDrive/Mid_1_Contrast_results/mid13sc_point_mut_results_cnn.csv"
)
# exclude the wild‐type entry which has empty mutation string
df_filtered = df_single[df_single['mutation'] != ""].reset_index(drop=True)
top30 = df_filtered.head(30)
mut_list = top30['mutation'].tolist()

# 2) Generate all doublet & triplet combinations (only if positions differ)
#    parse mutation strings like "M1D" → (ref_aa, pos, alt_aa)

mut_pattern = re.compile(r'^([A-Z])(\d+)([A-Z])$')

def parse_mut(mut_str):
    """
    Given a mutation string like "G50A", parse into:
      (0‐based position, new_amino_acid).
    According to the new scheme, the integer in the middle is offset by –3,
    so actual_index = int(number) + 3.
    """
    m = mut_pattern.match(mut_str)
    if m is None:
        raise ValueError(f"Unrecognized mutation format: {mut_str}")
    raw_num = int(m.group(2))
    actual_idx = raw_num + 3           # offset by –3 → +3 here
    new_aa = m.group(3)
    return (actual_idx, new_aa)



# Pre‐parse to tuples for fast checks
parsed = [parse_mut(m) for m in mut_list]
# Keep track of original mutation strings in same order
mutation_to_parsed = dict(zip(mut_list, parsed))

# Build valid combos
doublets = [
    combo for combo in itertools.combinations(mut_list, 2)
    if mutation_to_parsed[combo[0]][0] != mutation_to_parsed[combo[1]][0]
]
triplets = [
    combo for combo in itertools.combinations(mut_list, 3)
    if len({mutation_to_parsed[m][0] for m in combo}) == 3
]
all_combos = doublets + triplets

# 3) Apply mutations to the reference sequence
ref_seq = "MAGSDSPLAEQIKNTLTFIGQANAAGRMDEVRTLQKNLHPLWAEYFQLTEGSGGSPLAQQIQNGHVLIHQARAAGRMDEVRRLTEKTLQLMKEYFQQSD"


def apply_mutations(ref, muts):
    arr = list(ref)
    for mut in muts:
        pos, new_aa = mutation_to_parsed[mut]
        arr[pos] = new_aa
    return "".join(arr)

variants   = [apply_mutations(ref_seq, combo) for combo in all_combos]
mut_strs   = [",".join(combo) for combo in all_combos]


# 4) Load model & ESM embedder (assumes you ran the prior snippet)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Re‐load CNN_Contrastive
model = CNN_Contrastive(d_esm=1280, use_first_fitness=True).to(device)
model.load_state_dict(torch.load(
    "/content/drive/MyDrive/Mid_1_Contrast_results/CNN_Contrastive_inference.pt",
    map_location=device
))
model.eval()

# Re‐load ESM-2
from transformers import AutoTokenizer, AutoModel
esm_name = "facebook/esm2_t33_650M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(esm_name, do_lower_case=False)
esm = AutoModel.from_pretrained(esm_name).to(device)
esm.eval()

# Pre‐embed the reference once
max_len = len(ref_seq) + 2
with torch.no_grad():
    tok = tokenizer(
        ref_seq,
        return_tensors="pt",
        padding="max_length",
        max_length=max_len,
        truncation=True,
        add_special_tokens=True
    ).to(device)
    ref_emb = esm(**tok).last_hidden_state.squeeze(0)  # (101,1280)

# 5) Batch‐embed & predict all combos
batch_size = 32
fit1_value = 8.666429599066408
results = []

for i in tqdm(range(0, len(variants), batch_size), desc="Doublet/Triplet inference"):
    batch_seqs = variants[i:i+batch_size]
    batch_muts = mut_strs[i:i+batch_size]
    B = len(batch_seqs)

    # Tokenize + embed
    with torch.no_grad():
        toks = tokenizer(
            batch_seqs,
            return_tensors="pt",
            padding="max_length",
            max_length=max_len,
            truncation=True,
            add_special_tokens=True
        ).to(device)
        var_emb = esm(**toks).last_hidden_state  # (B,101,1280)

    # Prepare emb1, emb2, fit1
    emb1 = ref_emb.unsqueeze(0).repeat(B,1,1)         # (B,101,1280)
    emb2 = var_emb                                    # (B,101,1280)
    fit1 = torch.full((B,1), fit1_value, device=device)

    # Predict
    with torch.no_grad():
        preds = model(emb1, emb2, fit1).cpu().numpy()  # (B,)

    # Collect
    for seq, muts, score in zip(batch_seqs, batch_muts, preds):
        results.append({
            "sequence": seq,
            "mutations": muts,
            "predicted_improvement": score
        })

# 6) Rank & save
df_out = pd.DataFrame(results)
df_out = df_out.sort_values(
    "predicted_improvement",
    ascending=False
).reset_index(drop=True)

out_csv = "/content/drive/MyDrive/Mid_1_Contrast_results/mid13sc_doublet_triplet_results_cnn.csv"
df_out.to_csv(out_csv, index=False)
print(f"Saved {len(df_out)} combined variants to {out_csv}")

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Doublet/Triplet inference: 100%|██████████| 128/128 [06:29<00:00,  3.04s/it]

Saved 4068 combined variants to /content/drive/MyDrive/Mid_1_Contrast_results/mid13sc_doublet_triplet_results_cnn.csv



