**Robust Prediction of Enzyme Variant Kinetics with RealKcat**
This notebook predicts the enzyme kinetic parameters $k_{\text{cat}}$ and $K_M$ for given protein sequences and substrates using the RealKcat model. Predictions are reported as ranges (predicted classes), not as single point values.

**Workflow:**
1. **Install Dependencies and Download Models**: Set up the required libraries and make sure the trained model files are available locally (takes ~2 minutes).
2. **Choose Mode**: Select `Demo`, `Interactive`, `Bulk`, or `Bulk-large`.
3. **Run Inference**: Provide input (or upload CSV) and generate predictions.
4. **Save Results**: Interactive mode offers to save predictions to a CSV file; the Bulk modes automatically download a CSV file with predictions when they finish.

---

**How to Navigate (few clicks):**
- Start by running the two cells sequentially.
- Select an inference mode from the dropdown.
- Follow on-screen prompts for input.
- View results below each executed cell.

_No coding experience is required! Keep section cells collapsed to keep the notebook organized._

In [1]:
#@title Install dependencies, download and unzip model weights, get dependencies  [~2 minute]
# Colab setup: install the embedding / model libraries. Versions are NOT pinned, so every
# fresh runtime pulls the latest releases (e.g. the newest torch).
!pip install transformers fair-esm torch torchvision torchaudio --quiet
import os
# Pickled kcat / KM classifiers shipped in model_weights.zip; they are later loaded with
# joblib through KcatInference(model_path=...).
kcat_model_path = "model_weights/kcat_model.pkl"
km_model_path = "model_weights/km_model.pkl"
# Download and extract the archive only when either pickle is missing from the working dir.
if not (os.path.exists(kcat_model_path) and os.path.exists(km_model_path)):
    print("Model weights not found locally. Downloading...")
    # NOTE(review): a failed wget is not detected here; unzip would then fail on a missing
    # or partial archive and the later joblib.load calls would raise.
    !wget https://chowdhurylab.github.io/assets/database/WT_MD_database/model_weights.zip -O model_weights.zip
    # -o overwrites existing files without prompting; the archive unpacks into model_weights/.
    !unzip -o model_weights.zip
else:
    print("Model weights already exist locally. Skipping download.")

import sys
import os
import warnings
warnings.filterwarnings("ignore")

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
import joblib
import random

from transformers import AutoTokenizer, AutoModel
import esm

from google.colab import files
import pandas as pd
import io
from IPython.display import display, Latex

# Helper function to convert floats to "number x10^(exponent)" format.
def format_sci(value):
    """Render ``value`` with two decimals as ``"<mantissa>x10^<exponent>"``.

    Examples: ``3.32e-08`` -> ``"3.32x10^-8"`` and ``3.32e+07`` -> ``"3.32x10^7"``.
    If the ``.2e`` rendering contains no exponent marker (``inf``, ``nan``), that
    rendering is returned unchanged.
    """
    rendered = f"{value:.2e}"
    mantissa, marker, exponent = rendered.partition("e")
    if not marker:
        return rendered
    # int() strips the explicit sign and zero padding: "-08" -> -8, "+07" -> 7.
    return f"{mantissa}x10^{int(exponent)}"
# Utility functions for reproducibility
def check_device():
    """Return ``torch.device("cuda")`` when PyTorch sees a GPU, else ``torch.device("cpu")``."""
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(device_name)

def set_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch with ``seed``.

    When CUDA is available, all GPU generators are seeded too and cuDNN is switched to
    deterministic, non-benchmarking mode.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Custom dataset class (same name as torch.utils.data.TensorDataset, but defined here).
class TensorDataset(Dataset):
    """Index-aligned dataset over parallel ``data`` and ``labels`` containers.

    ``ds[i]`` yields ``(data[i], labels[i])`` and ``len(ds)`` is ``len(data)``.
    """

    def __init__(self, data, labels):
        self.data, self.labels = data, labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        target = self.labels[idx]
        return sample, target

    def get_labels(self):
        """Return the labels object given to the constructor (not a copy)."""
        return self.labels

def dataset_to_tensors(dataset):
    """Collate an entire dataset into one ``(data, labels)`` batch, preserving order.

    Uses a non-shuffling DataLoader whose batch size equals ``len(dataset)``, so the
    first (and only) batch contains every sample.
    """
    full_batch_loader = DataLoader(dataset, batch_size=len(dataset), shuffle=False)
    first_batch = next(iter(full_batch_loader))
    data, labels = first_batch
    return data, labels

def standardize_x_global_separate(data, global_mean_1, global_std_1, global_mean_2, global_std_2):
    """Z-score two column groups of ``data`` with separate global statistics.

    Columns ``[:, :1280]`` use ``(global_mean_1, global_std_1)`` and columns ``[:, 1280:]``
    use ``(global_mean_2, global_std_2)``. Both stds are clamped to at least 1e-7 to avoid
    division by zero. The two groups are re-joined along dim 1, then dim 1 is squeezed
    (a no-op unless it has size 1).

    NOTE(review): the split is along dim 1. For an input of shape (D, 1), the first group
    therefore holds all D values and the second group is empty. Only the ``_1`` statistics
    are applied and the result has shape (D,). StandardizedDatasetGlobalSeparate passes
    exactly such (D, 1) inputs. If the saved models were trained through the same path,
    this must stay as-is; confirm against the training code before changing either side.
    """
    split_col = 1280  # presumably the ESM-2 650M embedding width - TODO confirm
    safe_std_1 = torch.clamp(global_std_1, min=1e-7)
    safe_std_2 = torch.clamp(global_std_2, min=1e-7)
    head = (data[:, :split_col] - global_mean_1) / safe_std_1
    tail = (data[:, split_col:] - global_mean_2) / safe_std_2
    return torch.cat((head, tail), dim=1).squeeze(1)

class StandardizedDatasetGlobalSeparate(Dataset):
    """Dataset view that standardizes each feature sample of ``subset`` when it is accessed.

    Labels pass through unchanged; the numeric work is done by
    ``standardize_x_global_separate`` with the four stored statistics.

    NOTE(review): a 1-D sample of shape (D,) is reshaped to (D, 1) before standardizing.
    ``standardize_x_global_separate`` splits along dim 1, so on a (D, 1) input every value
    is standardized with ``global_mean_1``/``global_std_1`` and the ``_2`` statistics are
    never applied. The returned sample has shape (D,). If the saved models were trained
    through this same path, this must be kept; confirm against the training code first.
    """

    def __init__(self, subset, global_mean_1, global_std_1, global_mean_2, global_std_2):
        self.subset = subset
        # Statistics forwarded verbatim to standardize_x_global_separate.
        self.global_mean_1, self.global_std_1 = global_mean_1, global_std_1
        self.global_mean_2, self.global_std_2 = global_mean_2, global_std_2

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        features, label = self.subset[idx]
        if features.ndim == 1:
            features = features.unsqueeze(1)
        standardized = standardize_x_global_separate(
            features,
            self.global_mean_1, self.global_std_1,
            self.global_mean_2, self.global_std_2,
        )
        return standardized, label

def apply_global_standardization_separate(dataset, global_mean_1, global_std_1, global_mean_2, global_std_2):
    """Wrap ``dataset`` in a StandardizedDatasetGlobalSeparate using the given statistics."""
    stats = (global_mean_1, global_std_1, global_mean_2, global_std_2)
    return StandardizedDatasetGlobalSeparate(dataset, *stats)


def load_esm2_model(device, work_dir="."):
    """Load the ESM-2 ``esm2_t33_650M_UR50D`` protein language model onto ``device``.

    On first use, the checkpoint is downloaded from the public fair-esm URL and cached as
    ``<work_dir>/esm2_t33_650M_UR50D.pt``. Later calls load that local copy instead.

    Args:
        device: torch device (or device string) used as ``map_location`` and as the
            device the returned model is moved to.
        work_dir: Directory that holds the cached checkpoint (default: current directory).

    Returns:
        ``(esm_model, alphabet)`` from ``esm.pretrained.load_model_and_alphabet_core``,
        with the model switched to eval mode and moved to ``device``.
    """
    model_url = "https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t33_650M_UR50D.pt"
    model_filename = model_url.split("/")[-1]
    model_path = os.path.join(work_dir, model_filename)

    if not os.path.exists(model_path):
        print("Downloading ESM2 model weights...")
        model_data = torch.hub.load_state_dict_from_url(
            model_url,
            progress=True,
            map_location=device
        )
        # Save under a temporary name and rename atomically. Otherwise an interrupted save
        # would leave a truncated file that the exists() check above trusts on the next run.
        tmp_path = model_path + ".part"
        torch.save(model_data, tmp_path)
        os.replace(tmp_path, model_path)
    else:
        print("ESM2 model weights found locally. Loading from disk...")
        try:
            # The checkpoint holds non-tensor objects (its training config), which
            # PyTorch >= 2.6 refuses to unpickle under its new default weights_only=True.
            # This unpickles arbitrary objects: only acceptable because the file is the one
            # this function itself downloaded from the official fair-esm URL.
            model_data = torch.load(model_path, map_location=device, weights_only=False)
        except TypeError:
            # PyTorch < 1.13 has no ``weights_only`` argument (and always fully unpickles).
            model_data = torch.load(model_path, map_location=device)

    esm_model, alphabet = esm.pretrained.load_model_and_alphabet_core("esm2_t33_650M_UR50D", model_data)
    esm_model.eval().to(device)
    return esm_model, alphabet


# @title KcatInference Class Definition

class KcatInference:
    """Featurize (protein sequence, substrate SMILES) pairs and classify them with one pickled model.

    Each instance wraps a single joblib-loaded classifier (the notebook creates one for
    kcat and one for KM). The featurization pipeline is shared:

    1. ``load_data_from_pairs``: ESM-2 sequence embedding (mean over residues) concatenated
       with a ChemBERTa SMILES embedding (mean over tokens), one row per usable pair.
    2. ``standardize_test_data``: wrap those rows with the global training statistics.
    3. ``convert_to_numpy``: materialize the standardized rows as a NumPy array.
    4. ``predict``: run the wrapped classifier on that array.
    """

    def __init__(self, model_path, device=None):
        """Load the classifier at ``model_path`` and seed all RNGs with 42.

        Args:
            model_path: Path of a joblib file whose object exposes ``predict``.
                joblib unpickles the file, which can execute arbitrary code, so only load
                trusted model files.
            device: torch device for the embedding models; ``check_device()`` when falsy.
        """
        self.device = device if device else check_device()
        self.model = joblib.load(model_path)
        set_seed(42)
        self.X_test_tensor = None
        self.y1_test_tensor = None
        self.test_dataset_std = None
        # 0-based positions (within the most recent load_data_from_pairs input) of the pairs
        # that were embedded; row i of X_test_tensor belongs to input valid_indices[i].
        self.valid_indices = []

    def _load_encoders(self):
        """Load ESM-2 (+ alphabet) and the ChemBERTa tokenizer/model, in eval mode on ``self.device``."""
        esm_model, alphabet = load_esm2_model(self.device)
        esm_model.eval().to(self.device)
        chemberta_tokenizer = AutoTokenizer.from_pretrained('seyonec/PubChem10M_SMILES_BPE_450k')
        chemberta_model = AutoModel.from_pretrained('seyonec/PubChem10M_SMILES_BPE_450k')
        chemberta_model.eval().to(self.device)
        return esm_model, alphabet, chemberta_tokenizer, chemberta_model

    def load_data_from_pairs(self, sequence_substrate_pairs):
        """Embed every (sequence, SMILES) pair and store the feature rows on the instance.

        A pair is skipped, with a printed message, when:
        - either field is missing or not a string (e.g. NaN from an empty CSV cell);
        - the sequence is longer than 1022 residues;
        - the tokenized SMILES exceeds 512 tokens;
        - embedding raises.

        Args:
            sequence_substrate_pairs: Iterable of ``(protein_sequence, smiles)`` tuples.
                Messages number samples from 1 in input order.

        Side effects:
            Sets ``X_test_tensor`` (one row per kept pair: ESM-2 features followed by
            ChemBERTa features), ``y1_test_tensor`` (zero placeholder labels) and
            ``valid_indices`` (0-based input positions of the kept pairs).

        Raises:
            ValueError: If no pair could be embedded.
        """
        esm_model, alphabet, chemberta_tokenizer, chemberta_model = self._load_encoders()
        batch_converter = alphabet.get_batch_converter()
        print("Models loaded. Embedding sequences and substrates...")

        embeddings_list = []
        valid_indices = []
        for index, (sequence, substrate) in enumerate(sequence_substrate_pairs, start=1):
            # NaN (a float) is truthy, so a plain truthiness test would let it reach len().
            if not (isinstance(sequence, str) and sequence and isinstance(substrate, str) and substrate):
                print(f"Skipping sample {index} due to missing sequence or substrate.")
                continue
            # Presumably ESM-2's 1024-token training context minus BOS/EOS.
            if len(sequence) > 1022:
                print(f"Skipping sample {index} due to excessive sequence length.")
                continue

            # ESM-2 embedding
            try:
                _, _, batch_tokens = batch_converter([(f'sample_{index}', sequence)])
                batch_tokens = batch_tokens.to(self.device)
                token_count = int((batch_tokens != alphabet.padding_idx).sum(1)[0])
                with torch.no_grad():
                    # Contact maps are never used; skipping them leaves the representations
                    # unchanged and saves compute and memory.
                    results = esm_model(batch_tokens, repr_layers=[33], return_contacts=False)
                # Mean-pool layer-33 residue representations, excluding the BOS token
                # (position 0) and the EOS token (position token_count - 1).
                token_representations = results["representations"][33][0, 1 : token_count - 1]
                embedded_sequence = token_representations.mean(dim=0).type(torch.float32)
            except Exception as e:
                print(f"Error processing sequence at sample {index}: {e}")
                continue

            # ChemBERTa embedding
            try:
                inputs = chemberta_tokenizer([substrate], return_tensors='pt', padding=True, truncation=False)
                if inputs['input_ids'].shape[1] > 512:
                    print(f"Skipping sample {index} due to SMILES length exceeding model capacity.")
                    continue
                inputs = {key: val.to(self.device) for key, val in inputs.items()}
                with torch.no_grad():
                    outputs = chemberta_model(**inputs)
                chemberta_embedding = torch.mean(outputs.last_hidden_state, dim=1).squeeze(0).type(torch.float32)
            except Exception as e:
                print(f"Error processing SMILES at sample {index}: {e}")
                continue

            # Feature row = [ESM-2 features, ChemBERTa features], each as a 1-D vector.
            embeddings_list.append(torch.cat((embedded_sequence.flatten(), chemberta_embedding.flatten())))
            valid_indices.append(index - 1)

        if not embeddings_list:
            raise ValueError("No valid samples were processed.")
        print(f"Total valid samples processed: {len(embeddings_list)}")
        self.X_test_tensor = torch.stack(embeddings_list).to(self.device)
        # Placeholder labels: the Dataset helpers expect (features, label) samples.
        self.y1_test_tensor = torch.zeros(len(embeddings_list), dtype=torch.long).to(self.device)
        self.valid_indices = valid_indices

    def standardize_test_data(self, global_mean_1, global_std_1, global_mean_2, global_std_2):
        """Wrap the embedded rows in a lazily standardizing dataset (``self.test_dataset_std``)."""
        dataset = TensorDataset(self.X_test_tensor, self.y1_test_tensor)
        self.test_dataset_std = apply_global_standardization_separate(dataset, global_mean_1, global_std_1, global_mean_2, global_std_2)

    def convert_to_numpy(self):
        """Return ``(features, labels)`` of ``self.test_dataset_std`` as CPU NumPy arrays."""
        X_test_data, test_y1 = dataset_to_tensors(self.test_dataset_std)
        return X_test_data.cpu().numpy(), test_y1.cpu().numpy()

    def predict(self, X_test_data):
        """Return ``self.model.predict(X_test_data)`` (predicted class per row)."""
        return self.model.predict(X_test_data)

    def _display_prediction_ranges(self, predictions, class_ranges, label, sample_ids=None):
        """Print one line per prediction with its class and formatted [low, high] range.

        ``sample_ids`` supplies the printed sample numbers (default ``1..len(predictions)``);
        it must have the same length as ``predictions``.
        """
        if sample_ids is None:
            sample_ids = range(1, len(predictions) + 1)
        elif len(sample_ids) != len(predictions):
            raise ValueError("sample_ids and predictions must have the same length.")
        for sample_id, pred_class in zip(sample_ids, predictions):
            low_str = format_sci(class_ranges[pred_class]["low"])
            high_str = format_sci(class_ranges[pred_class]["high"])
            print(f"Sample {sample_id}: Predicted Class = {pred_class}, {label} range = [{low_str}, {high_str}]")

    def display_prediction_ranges_kcat(self, predictions, class_ranges, sample_ids=None):
        """Print kcat predictions; see ``_display_prediction_ranges`` for ``sample_ids``."""
        self._display_prediction_ranges(predictions, class_ranges, "kcat", sample_ids)

    def display_prediction_ranges_km(self, predictions, class_ranges, sample_ids=None):
        """Print KM predictions; see ``_display_prediction_ranges`` for ``sample_ids``."""
        self._display_prediction_ranges(predictions, class_ranges, "km", sample_ids)

# Configuration of standardization parameters and class ranges
device = check_device()
print("Using device:", device)

# Global standardization statistics (taken from training), consumed by
# standardize_x_global_separate as (mean_1, std_1, mean_2, std_2).
(global_mean_1, global_std_1,
 global_mean_2, global_std_2) = [
    torch.tensor(value, device=device)
    for value in (
        -0.0006011285004206002,   # global_mean_1
        0.18902993202209473,      # global_std_1
        -0.00015002528380136937,  # global_mean_2
        0.6113553047180176,       # global_std_2
    )
]

# Predicted class index -> {"low": ..., "high": ...} bounds used when reporting results.
# Units are not stated here (conventionally s^-1 for kcat and M for KM) - TODO confirm.
# The bounds are not contiguous (e.g. kcat class 3 ends at 1.0 and class 4 starts at 1.001);
# presumably they are the observed min/max of each training bin - TODO confirm.
class_ranges_kcat = {
    cls: {"low": low, "high": high}
    for cls, (low, high) in enumerate([
        (0.0, 3.32e-8),
        (3.33e-8, 1.0e-2),
        (1.01e-2, 1.0e-1),
        (1.01e-1, 1.0),
        (1.001, 10.0),
        (1.004e1, 1.0e2),
        (1.0025e2, 1.0e3),
        (1.002e3, 7.0e7),
    ])
}

class_ranges_km = {
    cls: {"low": low, "high": high}
    for cls, (low, high) in enumerate([
        (1.0e-10, 1.0e-5),
        (1.01e-5, 1.0e-4),
        (1.002e-4, 1.0e-3),
        (1.002e-3, 1.0e-2),
        (1.008e-2, 1.0e-1),
        (1.01e-1, 1.02e2),
    ])
}

Model weights not found locally. Downloading...
--2024-12-06 22:44:00--  https://chowdhurylab.github.io/assets/database/WT_MD_database/model_weights.zip
Resolving chowdhurylab.github.io (chowdhurylab.github.io)... 185.199.109.153, 185.199.108.153, 185.199.110.153, ...
Connecting to chowdhurylab.github.io (chowdhurylab.github.io)|185.199.109.153|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 69295457 (66M) [application/zip]
Saving to: ‘model_weights.zip’


2024-12-06 22:44:01 (199 MB/s) - ‘model_weights.zip’ saved [69295457/69295457]

Archive:  model_weights.zip
   creating: model_weights/
  inflating: model_weights/kcat_model.pkl  
  inflating: model_weights/km_model.pkl  
Using device: cpu


**Choose RealKcat Mode of Inference: $k_{cat}$ and $K_{M}$**

**[Demo takes ~2 minutes]**

Notes:
- **Demo**: For testing purposes with predefined inputs. Runs in ~2 minutes.
- **Interactive**: Allows manual entry of sequences and SMILES for flexible input.
- **Bulk**: Processes an uploaded CSV file (columns `sequence` and `Isomeric SMILES`); best suited to fewer than 100 sequence–substrate pairs.
- **Bulk-large**: Processes very large CSV files in batches of 100 pairs to limit memory use, at the cost of slower overall speed.


In [2]:
#@title Choose RealKcat Mode of Inference,  $k_{cat}$ and $K_{M}$  [Demo  takes ~2 minutes]
mode = "Demo"  #@param ["Demo", "Interactive", "Bulk", "Bulk-large"]

# Number of pairs embedded per batch in "Bulk-large" mode.
BULK_LARGE_BATCH_SIZE = 100
# Maximum number of pairs accepted in "Interactive" mode.
MAX_INTERACTIVE_ENTRIES = 10
# Demo inputs: (protein sequence, substrate isomeric SMILES).
DEMO_PAIRS = [
    (
        "MKVAVLGAAGGIGQALALLLKTQLPSGSELSLYDIAPVTPGVAVDLSHIPTAVKIKGFSGEDATPALEGADVVLISAGVARKPGMDRSDLFNVNAGIVKNLVQQVAKTCPKACIGIITNPVNTTVAIAAEVLKKAGVYDKNKLFGVTTLDIIRSNTFVAELKGKQPGEVEVPVIGGHSGVTILPLLSQVPGVSFTEQEVADLTKRIQNAGTEVVEAKAGGGSATLSMGQAAARFGLSLVRALQGEQGVVECAYVEGDGQYARFFSQPLLLGKNGVEERKSIGTLSAFEQNALEGMLDTLKKDIALGEEFVNK",
        "C(C(=O)C(=O)O)C(=O)O",
    ),
    (
        "MGVEQILKRKTGVIVGEDVHNLFTYAKEHKFAIPAINVTSSSTAVAALEAARDSKSPIILQTSNGGAAYFAGKGISNEGQNASIKGAIAAAHYIRSIAPAYGIPVVLHSDHCAKKLLPWFDGMLEADEAYFKEHGEPLFSSHMLDLSEETDEENISTCVKYFKRMAAMDQWLEMEIGITGGEEDGVNNENADKEDLYTKPEQVYNVYKALHPISPNFSIAAAFGNCHGLYAGDIALRPEILAEHQKYTREQVGCKEEKPLFLVFHGGSGSTVQEFHTGIDNGVVKVNLDTDCQYAYLTGIRDYVLNKKDYIMSPVGNPEGPEKPNKKFFDPRVWVREGEKTMGAKITKSLETFRTTNTL",
        "C([C@H](C=O)O)OP(=O)(O)O",
    ),
]


def _upload_pairs_csv():
    """Ask for a CSV upload; return ``(DataFrame, [(sequence, SMILES), ...])`` in row order.

    Raises:
        ValueError: If nothing was uploaded or a required column is missing.
    """
    print("Please upload a CSV file containing 'sequence' and 'Isomeric SMILES' columns.")
    uploaded = files.upload()
    if not uploaded:
        raise ValueError("No file was uploaded.")
    csv_filename = list(uploaded.keys())[0]
    df = pd.read_csv(csv_filename)
    missing_columns = [c for c in ("sequence", "Isomeric SMILES") if c not in df.columns]
    if missing_columns:
        raise ValueError(f"Uploaded CSV is missing required column(s): {missing_columns}")
    return df, list(zip(df['sequence'], df['Isomeric SMILES']))


def _predict_pairs(pairs, kcat_predictor, km_predictor, allow_empty=False):
    """Embed ``pairs`` once and classify them with both the kcat and the KM model.

    ``load_data_from_pairs`` drops pairs it cannot embed. The predictions are therefore
    returned together with the 0-based positions (in ``pairs``) they belong to.

    Returns:
        ``(kept, y_kcat, y_km)``: ``y_kcat[i]`` and ``y_km[i]`` belong to ``pairs[kept[i]]``.
        When ``allow_empty`` is true and no pair is usable, three empty lists.

    Raises:
        ValueError: No pair is usable and ``allow_empty`` is false.
        RuntimeError: Pairs were dropped but the embedder does not report which ones.
    """
    try:
        kcat_predictor.load_data_from_pairs(pairs)
    except ValueError:
        # load_data_from_pairs raises ValueError when none of ``pairs`` can be embedded.
        if not allow_empty:
            raise
        print("No valid samples in this batch; skipping it.")
        return [], [], []
    kcat_predictor.standardize_test_data(global_mean_1, global_std_1, global_mean_2, global_std_2)
    X_test, _ = kcat_predictor.convert_to_numpy()
    y_kcat = kcat_predictor.predict(X_test)
    y_km = km_predictor.predict(X_test)

    kept = getattr(kcat_predictor, "valid_indices", None)
    if kept is None:
        # Without per-pair bookkeeping, rows can only be matched to inputs if none were dropped.
        if len(X_test) != len(pairs):
            raise RuntimeError(
                f"{len(pairs) - len(X_test)} of {len(pairs)} samples were skipped and their "
                "positions are unknown; cannot align predictions with the inputs."
            )
        kept = range(len(pairs))
    return list(kept), list(y_kcat), list(y_km)


def _scatter(values, kept, n_rows):
    """Return an ``n_rows`` list with ``values[i]`` at row ``kept[i]`` and None elsewhere."""
    column = [None] * n_rows
    for row, value in zip(kept, values):
        column[row] = value
    return column


def _add_prediction_columns(df, kept, y_kcat, y_km, prefix=""):
    """Add formatted kcat/KM range columns to ``df`` (in place); skipped rows stay empty."""
    n_rows = len(df)
    for name, predictions, class_ranges in (
        ("Kcat", y_kcat, class_ranges_kcat),
        ("KM", y_km, class_ranges_km),
    ):
        for bound in ("low", "high"):
            formatted = [format_sci(class_ranges[p][bound]) for p in predictions]
            df[f"{prefix}{name}_{bound}"] = _scatter(formatted, kept, n_rows)


def _print_ranges(kept, predictions, class_ranges, label):
    """Print predictions numbered by their 1-based position in the original input."""
    for row, pred_class in zip(kept, predictions):
        low_str = format_sci(class_ranges[pred_class]["low"])
        high_str = format_sci(class_ranges[pred_class]["high"])
        print(f"Sample {row + 1}: Predicted Class = {pred_class}, {label} range = [{low_str}, {high_str}]")


def _report_skipped(kept, total):
    """Tell the user how many input rows received no prediction."""
    skipped = total - len(kept)
    if skipped:
        print(f"{skipped} of {total} sample(s) were skipped; their prediction columns are left empty.")


def _read_interactive_pairs(max_entries=MAX_INTERACTIVE_ENTRIES):
    """Prompt for up to ``max_entries`` (sequence, SMILES) pairs; an empty sequence stops."""
    pairs = []
    while len(pairs) < max_entries:
        print(f"\nEntry {len(pairs) + 1}/{max_entries}:")
        user_sequence = input("Enter Sequence (or press Enter to stop): ").strip()
        if user_sequence == "":
            break
        user_smiles = input("Enter Isomeric SMILES: ").strip()
        if user_smiles == "":
            print("No SMILES provided. Skipping this entry.")
            continue
        pairs.append((user_sequence, user_smiles))
        if len(pairs) < max_entries:
            cont = input("Add another? (y/n): ").strip().lower()
            if cont not in ["y", "yes"]:
                break
    return pairs


if mode in ("Bulk", "Bulk-large"):
    df, sequence_substrate_pairs = _upload_pairs_csv()
    total_samples = len(sequence_substrate_pairs)

    # Load each pickled model once; the kcat instance also computes the embeddings.
    kcat_inference = KcatInference(model_path=kcat_model_path, device=device)
    km_inference = KcatInference(model_path=km_model_path, device=device)

    if mode == "Bulk":
        kept, y_pred_kcat, y_pred_km = _predict_pairs(sequence_substrate_pairs, kcat_inference, km_inference)
    else:
        batch_size = BULK_LARGE_BATCH_SIZE
        num_batches = (total_samples + batch_size - 1) // batch_size  # ceiling division
        print(f"Processing {total_samples} samples in batches of {batch_size}...")
        kept, y_pred_kcat, y_pred_km = [], [], []
        for batch_idx in range(num_batches):
            start_idx = batch_idx * batch_size
            end_idx = min(start_idx + batch_size, total_samples)
            print(f"Processing batch {batch_idx + 1}/{num_batches} (samples {start_idx + 1}-{end_idx})...")
            batch_kept, batch_kcat, batch_km = _predict_pairs(
                sequence_substrate_pairs[start_idx:end_idx], kcat_inference, km_inference, allow_empty=True
            )
            # Batch-local positions -> positions in the whole uploaded file.
            kept.extend(start_idx + i for i in batch_kept)
            y_pred_kcat.extend(batch_kcat)
            y_pred_km.extend(batch_km)

    _report_skipped(kept, total_samples)
    _add_prediction_columns(df, kept, y_pred_kcat, y_pred_km, prefix="Predicted_")

    output_csv = "inference_results.csv"
    df.to_csv(output_csv, index=False)
    print(f"Inference complete. The updated results are saved in '{output_csv}'.")
    files.download(output_csv)

elif mode in ("Interactive", "Demo"):
    if mode == "Demo":
        print("\nRunning in Demo Mode...")
        sequence_substrate_pairs = list(DEMO_PAIRS)
    else:
        sequence_substrate_pairs = _read_interactive_pairs()

    if len(sequence_substrate_pairs) == 0:
        print("No entries provided. Exiting.")
    else:
        print(f"Processing {len(sequence_substrate_pairs)} samples...")
        kcat_inference = KcatInference(model_path=kcat_model_path, device=device)
        km_inference = KcatInference(model_path=km_model_path, device=device)
        kept, y_pred_kcat, y_pred_km = _predict_pairs(sequence_substrate_pairs, kcat_inference, km_inference)

        display(Latex(r"\textbf{\large $k_{\text{cat}}$ Prediction Results:}"))
        _print_ranges(kept, y_pred_kcat, class_ranges_kcat, "kcat")

        display(Latex(r"\textbf{\large $K_{M}$ Prediction Results:}"))
        _print_ranges(kept, y_pred_km, class_ranges_km, "km")

        if mode == "Interactive":
            save_results = input("Save results to CSV? (y/n): ").strip().lower()
            if save_results in ["y", "yes"]:
                df = pd.DataFrame(sequence_substrate_pairs, columns=["sequence", "SMILES"])
                _add_prediction_columns(df, kept, y_pred_kcat, y_pred_km)
                df.to_csv("interactive_results.csv", index=False)
                print("Results saved to 'interactive_results.csv'.")

else:
    raise ValueError(f"Unknown mode: {mode!r}")




Running in Demo Mode...
Processing 2 samples...
Downloading ESM2 model weights...
Models loaded. Embedding sequences and substrates...
Total valid samples processed: 2


<IPython.core.display.Latex object>

Sample 1: Predicted Class = 6, kcat range = [1.00x10^2, 1.00x10^3]
Sample 2: Predicted Class = 5, kcat range = [1.00x10^1, 1.00x10^2]


<IPython.core.display.Latex object>

Sample 1: Predicted Class = 1, km range = [1.01x10^-5, 1.00x10^-4]
Sample 2: Predicted Class = 2, km range = [1.00x10^-4, 1.00x10^-3]
