In [34]:
import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig
import pandas as pd
import matplotlib.pyplot as plt
import os

import torch
import gpytorch
import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error
from scipy.stats import kendalltau
from joblib import Parallel, delayed
from sklearn.model_selection import train_test_split


In [35]:
#esm2_t6_8M_UR50D -> R^320 embeddings

esm_config = AutoConfig.from_pretrained("facebook/esm2_t6_8M_UR50D")
esm_tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
esm_model = AutoModel.from_pretrained("facebook/esm2_t6_8M_UR50D")

df = pd.read_csv("data/processed_data/UBE4B_MOUSE_Klevit2013-nscor_log2_ratio/data.csv")
df = df.rename(columns={"seq": "sequence", "log_fitness": "viability"})

print(esm_config)
print(df.shape)

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


EsmConfig {
  "architectures": [
    "EsmForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.0,
  "classifier_dropout": null,
  "dtype": "float32",
  "emb_layer_norm_before": false,
  "esmfold_config": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 320,
  "initializer_range": 0.02,
  "intermediate_size": 1280,
  "is_folding_model": false,
  "layer_norm_eps": 1e-05,
  "mask_token_id": 32,
  "max_position_embeddings": 1026,
  "model_type": "esm",
  "num_attention_heads": 20,
  "num_hidden_layers": 6,
  "pad_token_id": 1,
  "position_embedding_type": "rotary",
  "token_dropout": true,
  "transformers_version": "4.56.0",
  "use_cache": true,
  "vocab_list": null,
  "vocab_size": 33
}

(32290, 3)


In [36]:
embedding_file = "cache/embeddings.pt"

if os.path.exists(embedding_file):
    X = torch.load(embedding_file)
else:
    embeddings = []
    with torch.no_grad():
        for seq in df["sequence"]:
            tokenized_input = esm_tokenizer(seq, return_tensors="pt")
            output = esm_model(**tokenized_input)
            seq_embedding = output.last_hidden_state.mean(dim=1).squeeze()
            embeddings.append(seq_embedding)

    X = torch.stack(embeddings)
    torch.save(X, embedding_file)

# Embedding-dim = 320

X = torch.load(embedding_file)
y = torch.tensor(df["viability"].values, dtype=torch.float32)

# from sklearn.decomposition import PCA
# pca = PCA(n_components=80)
# X = pca.fit_transform(X.numpy())
# print("Explained variance ratio:", pca.explained_variance_ratio_.sum())

In [37]:
import gpytorch

class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

In [42]:

def run_mc_iteration(X, y, test_size=0.2, seed=None, training_iterations=100, lr=0.1):
    # Split dataset
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=seed)

    X_train_tensor = X_train
    X_test_tensor = X_test
    y_train_tensor = y_train
    y_test_tensor = y_test

    # Convert to tensors
    # X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
    # y_train_tensor = torch.tensor(y_train, dtype=torch.float32)
    # X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
    # y_test_tensor = torch.tensor(y_test, dtype=torch.float32)

    # Model + likelihood
    likelihood = gpytorch.likelihoods.GaussianLikelihood()
    gp_model = ExactGPModel(X_train_tensor, y_train_tensor, likelihood)

    # Training
    gp_model.train()
    likelihood.train()
    optimizer = torch.optim.Adam(gp_model.parameters(), lr=lr)
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, gp_model)

    for i in range(training_iterations):
        optimizer.zero_grad()
        output = gp_model(X_train_tensor)
        loss = -mll(output, y_train_tensor)
        loss.backward()
        optimizer.step()

    # Evaluation
    gp_model.eval()
    likelihood.eval()
    with torch.no_grad(), gpytorch.settings.fast_pred_var():
        test_predictions = likelihood(gp_model(X_test_tensor))
        test_means = test_predictions.mean.detach().numpy()

    # Metrics
    mae = mean_absolute_error(y_test_tensor, test_means)
    mse = mean_squared_error(y_test_tensor, test_means)
    rmse = np.sqrt(mse)
    tau, p_value = kendalltau(y_test_tensor, test_means)

    return {"mae": mae, "mse": mse, "rmse": rmse, "tau": tau, "p_value": p_value}

n_splits = 4
test_size = 0.2
training_iterations = 250
lr = 0.01
n_jobs=1


results = Parallel(n_jobs=n_jobs, verbose=1)(
    delayed(run_mc_iteration)(X, y, test_size=test_size, seed=i, training_iterations=training_iterations, lr=lr) for i in range(n_splits)
)

#1.5802
# 6m 12s

[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed: 56.2min finished


In [43]:
maes = [r["mae"] for r in results]
mses = [r["mse"] for r in results]
rmses = [r["rmse"] for r in results]
taus = [r["tau"] for r in results]

print(f"sample size: {n_splits}")
print(f"MAE: {np.mean(maes):.4f} ± {np.std(maes):.4f}")
print(f"MSE: {np.mean(mses):.4f} ± {np.std(mses):.4f}")
print(f"RMSE: {np.mean(rmses):.4f} ± {np.std(rmses):.4f}")
print(f"Kendall's Tau: {np.mean(taus):.4f} ± {np.std(taus):.4f}")

sample size: 4
MAE: 1.2303 ± 0.0134
MSE: 2.4936 ± 0.0636
RMSE: 1.5790 ± 0.0201
Kendall's Tau: 0.4179 ± 0.0053


In [40]:
# ## residual plots
#
# import matplotlib.pyplot as plt
# import numpy as np
#
# residuals = y_test - test_means
#
# # 1. Residuals vs Predicted
# plt.figure(figsize=(6, 4))
# plt.scatter(test_means, residuals, alpha=0.6)
# plt.axhline(0, color='red', linestyle='--')
# plt.xlabel("Predicted values")
# plt.ylabel("Residuals")
# plt.title("Residuals vs Predicted")
# plt.show()
#
# # 2. Histogram of residuals
# plt.figure(figsize=(6, 4))
# plt.hist(residuals, bins=30, alpha=0.7, edgecolor='k')
# plt.xlabel("Residual")
# plt.ylabel("Frequency")
# plt.title("Histogram of Residuals")
# plt.show()
#
# # 3. QQ-plot of residuals (normality check)
# import scipy.stats as stats
# plt.figure(figsize=(6, 6))
# stats.probplot(residuals, dist="norm", plot=plt)
# plt.title("QQ Plot of Residuals")
# plt.show()
#
# # 4. Residuals vs each predictor (for multivariable regression)
# if X_test_tensor.shape[1] <= 5:  # avoid plotting too many features
#     X_test_np = X_test_tensor.numpy()
#     for i in range(X_test_np.shape[1]):
#         plt.figure(figsize=(6, 4))
#         plt.scatter(X_test_np[:, i], residuals, alpha=0.6)
#         plt.axhline(0, color='red', linestyle='--')
#         plt.xlabel(f"Feature {i}")
#         plt.ylabel("Residuals")
#         plt.title(f"Residuals vs Feature {i}")
#         plt.show()


In [77]:
# ## validate arbitrary training point
# train_x = X_train_tensor[0].unsqueeze(0)
# true_viability = y_train_tensor[0].item()
#
# with torch.no_grad(), gpytorch.settings.fast_pred_var():
#     observed_pred = likelihood(gp_model(train_x))
#     predicted_mean = observed_pred.mean.item()
#     lower_bound, upper_bound = observed_pred.confidence_region()
#
# print(f"True Viability: {true_viability}")
# print(f"Predicted Viability: {predicted_mean}")
# print(f"95% CI: [{lower_bound.item()}, {upper_bound.item()}]")

True Viability: -3.128000020980835
Predicted Viability: -2.9539124965667725
95% CI: [-7.312348365783691, 1.4045231342315674]
