# eGFP50 model — training notebook

Converted from `eGFP50_model.py`. This notebook contains the full pipeline to train a CNN+BiLSTM+Attention model
to predict mean ribosome load (rl) from 5' UTR sequence and structure (egfp_50 dataset). It is annotated with
explanations and split into runnable cells for interactive use.

**Author:** Mike Wang

**Notes**:
- Update DATA_PATH and MODEL_OUT_PATH in the Configuration cell before running.
- Model training can take a long time depending on hardware; consider working with a subset of the data for quick testing.

## Configuration
Edit file paths and hyperparameters below before running.

In [None]:
# Configuration / constants
# Tab-separated feature table read by the data-loading cell.
DATA_PATH = "~/data/GSM3130435_egfp_unmod_1_structure_feature_table_maxBPspan_30.txt"
# Destination file for the trained model (written by the save cell).
MODEL_OUT_PATH = "~/models/synthetic/model_eGFP_50.h5"
# Only the first MAX_ROWS rows of the loaded table are kept.
MAX_ROWS = 280_000
# Number of positions per example in the one-hot tensors; longer inputs are
# truncated and shorter ones are left zero-padded.
SEQ_LEN = 80
# Seed handed to the TensorFlow and NumPy global random generators.
RANDOM_STATE = 2021
# Mini-batch size passed to model.fit.
BATCH_SIZE = 128
# Upper bound on training epochs; early stopping may end training sooner.
EPOCHS = 999

# DATA_PATH, MODEL_OUT_PATH, MAX_ROWS, SEQ_LEN, RANDOM_STATE


## Imports and environment setup

In [None]:
# Imports and logging
import logging
import os
from typing import Tuple, Dict

import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from sklearn.model_selection import train_test_split
from scipy.stats import spearmanr, linregress

import tensorflow as tf
from tensorflow.keras import layers as L
from tensorflow.keras import initializers, regularizers, constraints
from tensorflow.keras.layers import Layer
from tensorflow.keras.callbacks import EarlyStopping

# Configure root logging at INFO level with a timestamped message format,
# and create the module-level logger used by the cells below.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s: %(message)s")
logger = logging.getLogger(__name__)

# Reproducibility: seed TensorFlow's and NumPy's global random generators from
# RANDOM_STATE (the Configuration cell must have been run before this one).
tf.random.set_seed(RANDOM_STATE)
np.random.seed(RANDOM_STATE)

# Report the TensorFlow version in use.
print('TensorFlow version:', tf.__version__)


## One-hot encoding function

In [None]:
def one_hot_encode(df: pd.DataFrame, seqcol: str = "sequence", strucol: str = "structure_full", seq_len: int = SEQ_LEN) -> Tuple[np.ndarray, np.ndarray]:
    """
    One-hot encode sequences and sequence+structure pairs for model input.

    Args:
        df: Table holding one example per row.
        seqcol: Name of the column with the nucleotide sequence.
        strucol: Name of the column with the per-position structure string,
            made of the characters "(", ")" and ".".
        seq_len: Number of positions per example in the output tensors.
            Longer inputs are truncated; shorter ones are left zero-padded.

    Returns:
        A tuple ``(seq_vectors, stru_vectors)`` of float32 arrays:
        - seq_vectors: (N, seq_len, 4), channel order [A, C, G, T].
        - stru_vectors: (N, seq_len, 12), one channel per base/structure
          combination, ordered A(, A), A., C(, C), C., G(, ... so that the
          channel index is ``3 * base_index + structure_index``.

    "U" is treated as "T" in both tensors. Missing values are treated as empty
    strings. Unknown characters (including "N"), padding, and positions where
    the structure string is shorter than the sequence remain all-zero.
    """
    bases = ["A", "C", "G", "T"]
    struct_symbols = ["(", ")", "."]

    # Channel index per base. "U" shares the "T" channel so RNA-alphabet input
    # is encoded identically in the sequence and the structure tensor.
    base_index: Dict[str, int] = {base: k for k, base in enumerate(bases)}
    base_index["U"] = base_index["T"]
    struct_index: Dict[str, int] = {sym: k for k, sym in enumerate(struct_symbols)}
    n_struct = len(struct_symbols)

    N = len(df)
    seq_vectors = np.zeros((N, seq_len, 4), dtype=np.float32)
    stru_vectors = np.zeros((N, seq_len, 12), dtype=np.float32)

    seq_series = df[seqcol].fillna("").astype(str).str.upper()
    stru_series = df[strucol].fillna("").astype(str).str.upper()

    for i, (seq, stru) in enumerate(zip(seq_series, stru_series)):
        seq = seq[:seq_len]
        stru = stru[:seq_len]
        stru_len = len(stru)
        for j, ch in enumerate(seq):
            b = base_index.get(ch)
            if b is None:
                # Unknown base: both tensors stay zero at this position.
                continue
            seq_vectors[i, j, b] = 1.0
            if j < stru_len:
                s = struct_index.get(stru[j])
                if s is not None:
                    stru_vectors[i, j, b * n_struct + s] = 1.0
    return seq_vectors, stru_vectors


## Attention Layer

In [None]:
class Attention(Layer):
    """
    Attention pooling over the time axis of a 3-D input.

    For an input ``x`` of shape (batch, steps, features) the layer computes a
    score per step as ``tanh(x . W + b)``, turns the scores into weights by
    exponentiating and normalising over the step axis (masked steps get zero
    weight), and returns the weighted sum of ``x`` over steps.

    Returns a pair ``(context_vector, attention_weights)`` with shapes
    (batch, features) and (batch, steps, 1).

    Args:
        bias: Whether to add a learned scalar bias to the scores.
        W_regularizer, b_regularizer: Optional Keras regularizer identifiers
            for the weight vector and the bias.
        W_constraint, b_constraint: Optional Keras constraint identifiers for
            the weight vector and the bias.
        **kwargs: Forwarded to the base ``Layer`` constructor (e.g. ``name``).
    """

    def __init__(self, bias: bool = True, W_regularizer=None, b_regularizer=None, W_constraint=None, b_constraint=None, **kwargs):
        super().__init__(**kwargs)
        self.supports_masking = True
        self.init = initializers.get("glorot_uniform")
        self.bias = bias
        self.W_regularizer = regularizers.get(W_regularizer)
        self.b_regularizer = regularizers.get(b_regularizer)
        self.W_constraint = constraints.get(W_constraint)
        self.b_constraint = constraints.get(b_constraint)
        # Set to the size of the input's last axis in build().
        self.features_dim = 0

    def build(self, input_shape):
        """Create the score weight vector (and optional bias) for a 3-D input."""
        if len(input_shape) != 3:
            raise ValueError(
                f"Attention expects a 3-D input (batch, steps, features); got shape {tuple(input_shape)}"
            )
        self.W = self.add_weight(shape=(input_shape[-1],), initializer=self.init, name=f"{self.name}_W",
                                 regularizer=self.W_regularizer, constraint=self.W_constraint)
        self.features_dim = input_shape[-1]
        if self.bias:
            self.b = self.add_weight(shape=(1,), initializer="zeros", name=f"{self.name}_b",
                                     regularizer=self.b_regularizer, constraint=self.b_constraint)
        else:
            self.b = None
        super().build(input_shape)

    def compute_mask(self, inputs, mask=None):
        # The step axis is reduced away, so no mask is propagated downstream.
        return None

    def call(self, x, mask=None):
        """Return ``(context_vector, attention_weights)`` for input ``x``."""
        features_dim = self.features_dim
        step_dim = tf.shape(x)[1]
        # Score every step: flatten to (batch*steps, features), project onto W,
        # then fold back to (batch, steps).
        flat_x = tf.reshape(x, (-1, features_dim))
        e = tf.reshape(tf.matmul(flat_x, tf.reshape(self.W, (features_dim, 1))), (-1, step_dim))
        if self.bias:
            e = e + self.b
        e = tf.tanh(e)
        a = tf.exp(e)
        if mask is not None:
            # Zero the weight of masked steps; cast to the compute dtype so the
            # layer does not assume float32.
            a *= tf.cast(mask, a.dtype)
        # epsilon guards against division by zero when every step is masked.
        a /= tf.reduce_sum(a, axis=1, keepdims=True) + tf.cast(tf.keras.backend.epsilon(), a.dtype)
        a_expanded = tf.expand_dims(a, axis=-1)
        c = tf.reduce_sum(a_expanded * x, axis=1)
        return c, a_expanded

    def compute_output_shape(self, input_shape):
        # One shape per tensor returned by call(): the context vector and the
        # per-step attention weights. Uses input_shape directly so the result is
        # correct even before build() has run.
        return [
            (input_shape[0], input_shape[-1]),
            (input_shape[0], input_shape[1], 1),
        ]

    def get_config(self):
        """Return the constructor arguments so the layer can be saved and re-created."""
        config = super().get_config()
        config.update({
            "bias": self.bias,
            "W_regularizer": regularizers.serialize(self.W_regularizer),
            "b_regularizer": regularizers.serialize(self.b_regularizer),
            "W_constraint": constraints.serialize(self.W_constraint),
            "b_constraint": constraints.serialize(self.b_constraint),
        })
        return config


## Model builder

In [None]:
def _conv_bn_relu(x, filters: int, kernel_size: int, name: str):
    """Apply Conv1D ("same" padding, named `name`) -> BatchNormalization -> ReLU to `x`."""
    x = L.Conv1D(filters, kernel_size, padding="same", name=name)(x)
    x = L.BatchNormalization()(x)
    return L.ReLU()(x)


def _bilstm_stack(x, hidden_dim: int, dropout: float, names):
    """Chain one sequence-returning bidirectional LSTM per entry of `names` onto `x`."""
    for name in names:
        x = L.Bidirectional(
            L.LSTM(hidden_dim, dropout=dropout, return_sequences=True, kernel_initializer="orthogonal"),
            name=name,
        )(x)
    return x


def build_model(maxlen: int = SEQ_LEN, hidden_dim: int = 64, dropout: float = 0.3, dropout_rate: float = 0.2):
    """
    Build and compile the two-branch CNN + BiLSTM + attention regression model.

    The structure input (maxlen, 12) and the sequence input (maxlen, 4) each go
    through the same kind of branch: a width-1 convolution, two parallel
    convolutions with different kernel sizes, a two-layer bidirectional LSTM on
    each, concatenation, a width-1 convolution, attention pooling and dropout.
    The two pooled vectors are concatenated with the scalar MFE input and fed
    to a dense head with a single linear output.

    Args:
        maxlen: Number of positions in the structure and sequence inputs.
        hidden_dim: Units per direction in every LSTM.
        dropout: Input dropout rate of every LSTM.
        dropout_rate: Rate of the last dropout layer in the dense head. The
            default (0.2) is the value that was previously hard-coded there.

    Returns:
        A ``tf.keras.Model`` taking ``[Input_stru, Input_seq, Input_mfe]`` (in
        that order), compiled with the Adam optimizer and mean-squared-error
        loss.
    """
    input_stru = L.Input(shape=(maxlen, 12), name="Input_stru")
    input_seq = L.Input(shape=(maxlen, 4), name="Input_seq")
    input_mfe = L.Input(shape=(1,), name="Input_mfe")

    # NOTE: layers are created in the same order as before the helpers were
    # extracted, so auto-generated layer names and seeded initialisation order
    # are unchanged.

    # Structure branch
    s = _conv_bn_relu(input_stru, 32, 1, "Conv1D_stru")
    s_branch1 = _conv_bn_relu(s, 64, 7, "Conv1D_stru_feature2")
    s_branch2 = _conv_bn_relu(s, 64, 9, "Conv1D_stru_feature3")
    lstm_s1 = _bilstm_stack(s_branch1, hidden_dim, dropout, ("LSTM_stru2", "LSTM_stru2_1"))
    lstm_s2 = _bilstm_stack(s_branch2, hidden_dim, dropout, ("LSTM_stru3", "LSTM_stru3_1"))
    merged_s = L.concatenate([lstm_s1, lstm_s2], axis=-1)
    merged_s = _conv_bn_relu(merged_s, 256, 1, "Conv1D_stru_merged")
    # The attention weights are not used by the model output.
    stru_out, _ = Attention(name="Attention_stru")(merged_s)
    stru_out = L.Dropout(0.3)(stru_out)

    # Sequence branch
    q = _conv_bn_relu(input_seq, 32, 1, "Conv1D_seq")
    q1 = _conv_bn_relu(q, 64, 3, "Conv1D_seq_feature1")
    q2 = _conv_bn_relu(q, 64, 5, "Conv1D_seq_feature2")
    lstm_q1 = _bilstm_stack(q1, hidden_dim, dropout, ("LSTM_seq", "LSTM_seq1"))
    lstm_q2 = _bilstm_stack(q2, hidden_dim, dropout, ("LSTM_seq2", "LSTM_seq2_1"))
    merged_q = L.concatenate([lstm_q1, lstm_q2], axis=-1)
    merged_q = _conv_bn_relu(merged_q, 256, 1, "Conv1D_seq_merged")
    seq_out, _ = Attention(name="Attention_seq")(merged_q)
    seq_out = L.Dropout(0.3)(seq_out)

    # Final head
    x_flat = L.Concatenate()([stru_out, seq_out, input_mfe])
    x = L.Dense(256, activation="relu")(x_flat)
    x = L.Dropout(0.5)(x)
    x = L.Dense(128, activation="relu")(x)
    x = L.Dropout(dropout_rate)(x)
    out = L.Dense(1, activation="linear", name="out")(x)

    model = tf.keras.Model(inputs=[input_stru, input_seq, input_mfe], outputs=out)
    model.compile(optimizer="adam", loss="mean_squared_error")
    return model


## Evaluation helpers

In [None]:
def r2_score_from_slope(y_true: np.ndarray, y_pred: np.ndarray):
    """Return the squared correlation coefficient of a scipy ``linregress`` fit of y_pred on y_true."""
    fit = linregress(y_true, y_pred)
    return fit.rvalue ** 2

## Run training pipeline

In [None]:
# Load data
logger.info("Loading data from %s", DATA_PATH)
df = pd.read_csv(DATA_PATH, sep="\t")
df.reset_index(drop=True, inplace=True)
df = df.iloc[:MAX_ROWS].copy()
logger.info("Data shape after slicing: %s", df.shape)

# Fail fast with a clear message if the table lacks a column used below,
# instead of erroring part-way through encoding or scaling.
required_cols = ["sequence", "structure_full", "MFE_full", "rl"]
missing_cols = [col for col in required_cols if col not in df.columns]
if missing_cols:
    raise KeyError(f"Input table {DATA_PATH!r} is missing required column(s): {missing_cols}")

# Train/test split consistent with original script: the first N_TEST rows are
# held out for testing, everything after them is used for training.
N_TEST = 20000
if len(df) <= N_TEST:
    raise ValueError(
        f"Need more than {N_TEST} rows to leave any training data; got {len(df)} after slicing to MAX_ROWS"
    )
e_test = df.iloc[:N_TEST].copy()
e_train = df.iloc[N_TEST:].copy()
logger.info("Train / test split sizes: %d / %d", len(e_train), len(e_test))

# One-hot encoding for train
logger.info("One-hot encoding train sequences (seq_len=%d)", SEQ_LEN)
seq_e_train, stru_e_train = one_hot_encode(e_train, seqcol="sequence", strucol="structure_full", seq_len=SEQ_LEN)

# Norm_mfe and scaling: divide MFE by 80, then min-max scale to [0, 1].
# The scaler is fitted on the training rows only and re-used for the test rows.
e_train["Norm_mfe"] = e_train["MFE_full"] / 80.0
e_test["Norm_mfe"] = e_test["MFE_full"] / 80.0
mm_scaler = MinMaxScaler(feature_range=(0, 1))
scaled_train_meta = mm_scaler.fit_transform(e_train[["Norm_mfe"]].values)
scaled_test_meta = mm_scaler.transform(e_test[["Norm_mfe"]].values)

# Scale target rl on training set (kept so predictions can be mapped back later)
target_scaler = StandardScaler()
e_train["scaled_rl"] = target_scaler.fit_transform(e_train["rl"].values.reshape(-1, 1)).flatten()

# Build model
logger.info("Building model")
model = build_model(maxlen=SEQ_LEN)
model.summary()

# Train/validation split (80% / 20% of the training rows)
# NOTE(review): random_state is a literal 42 rather than RANDOM_STATE; kept
# as-is to preserve the existing split - confirm whether it should follow the
# configured seed.
X_train_seq, X_val_seq, X_train_stru, X_val_stru, X_train_meta, X_val_meta, y_train, y_val = train_test_split(
    seq_e_train, stru_e_train, scaled_train_meta, e_train["scaled_rl"].values, test_size=0.2, random_state=42
)

# Stop once val_loss has not improved for 10 epochs and roll back to the best weights.
early_stopping = EarlyStopping(monitor="val_loss", patience=10, restore_best_weights=True)
# Input order must match the model's inputs: structure, sequence, MFE.
history = model.fit(
    [X_train_stru, X_train_seq, X_train_meta],
    y_train,
    validation_data=([X_val_stru, X_val_seq, X_val_meta], y_val),
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks=[early_stopping],
    verbose=1
)



In [None]:
# Save model
# Expand a leading "~" explicitly: os.makedirs does not, and would otherwise
# create a literal "~" directory under the current working directory.
model_out_path = os.path.expanduser(MODEL_OUT_PATH)
model_out_dir = os.path.dirname(model_out_path)
if model_out_dir:
    # A bare file name has no directory part; os.makedirs("") would raise.
    os.makedirs(model_out_dir, exist_ok=True)
model.save(model_out_path)
logger.info("Saved trained model to %s", model_out_path)

# Test encoding and prediction
logger.info("One-hot encoding test sequences (seq_len=%d)", SEQ_LEN)
seq_e_test, stru_e_test = one_hot_encode(e_test, seqcol="sequence", strucol="structure_full", seq_len=SEQ_LEN)

logger.info("Predicting on test set")
# Input order must match the model's inputs: structure, sequence, MFE.
preds_scaled = model.predict([stru_e_test, seq_e_test, scaled_test_meta]).reshape(-1)
# Undo the target standardisation so predictions are on the original rl scale.
preds_orig = target_scaler.inverse_transform(preds_scaled.reshape(-1, 1)).reshape(-1)

e_test = e_test.copy()
e_test["pred"] = preds_orig

# Evaluation metrics
r2 = r2_score_from_slope(e_test["rl"].values, e_test["pred"].values)
spearman_corr, spearman_p = spearmanr(e_test["rl"].values, e_test["pred"].values)
mse = ((e_test["rl"].values - e_test["pred"].values) ** 2).mean()
rmse = mse ** 0.5

logger.info("Evaluation results on test set:")
logger.info("  r-squared (linear regression) = %.4f", r2)
logger.info("  Spearman rho = %.4f (p=%.3e)", spearman_corr, spearman_p)
logger.info("  MSE = %.6f, RMSE = %.6f", mse, rmse)

# Show some example predictions
e_test[["rl","pred"]].head()