In [16]:
%pip install biopython scikit-learn tensorflow==2.12.0 requests tqdm tensorflow-metal

[31mERROR: Could not find a version that satisfies the requirement tensorflow==2.12.0 (from versions: 2.20.0rc0, 2.20.0)[0m[31m
[0m[31mERROR: No matching distribution found for tensorflow==2.12.0[0m[31m
[0mNote: you may need to restart the kernel to use updated packages.


In [8]:
import os, sys, json, math, random
from collections import Counter
from tqdm import tqdm
from Bio import SeqIO
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, confusion_matrix, matthews_corrcoef, cohen_kappa_score
)
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Bidirectional, LSTM, Dense, Dropout, BatchNormalization, Flatten
from tensorflow.keras.optimizers import RMSprop, Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint


In [9]:
FASTA_PATH = "Vista_Dataset/vista_sequence.fasta"
RESULT_DIR = "results"
os.makedirs(RESULT_DIR, exist_ok=True)
SEED = 42
EPOCHS = 500      # paper value
BATCH_SIZE = 32
MAX_SEQ_LEN = 4000   # set None for full length (may be large)
QUICK_TEST = False   # True = quick smoke run

# Reproducibility
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)


In [10]:
def load_fasta_records(path):
    records = []
    for rec in SeqIO.parse(path, "fasta"):
        sid = rec.id
        seq = str(rec.seq).upper()
        # species label from id prefix (hs / mm)
        species = "human" if sid.startswith("hs") else "mouse" if sid.startswith("mm") else "unknown"
        records.append({"id": sid, "species": species, "sequence": seq})
    return pd.DataFrame(records)

df = load_fasta_records(FASTA_PATH)
print(f"Loaded {len(df)} sequences from {FASTA_PATH}")

# Infer positive/negative if possible
df["status"] = "positive"  # if your FASTA only contains enhancers; adjust if negatives exist

FileNotFoundError: [Errno 2] No such file or directory: 'Vista_Dataset/vista_sequence.fasta'

In [3]:
SCENARIO1_PARAMS = {
    # Predict: human vs mouse (only enhancer sequences)
    "bi_lstm_units": [256, 128, 64],
    "dropouts": [0.15, 0.20, 0.20],
    "dense_units": [512, 256, 128],
    "activation": "selu",
    "output_activation": "sigmoid",
    "loss": "binary_crossentropy",
    "optimizer": tf.keras.optimizers.RMSprop(),
    "epochs": 500,
    "batch_size": 32,
    "scenario": 1,
    "multiclass": False
}
MAX_LEN_CAP = 1000 

In [11]:
def parse_vista_fasta(fasta_path):
    """
    Parse FASTA; return list of dicts with fields: seq_id, seq, species, is_enhancer (True/False), label_text
    The VISTA FASTA header format varies; this function uses heuristics to extract species and enhancer status.
    """
    records = []
    # Accept zipped directories as well: find .fa/.fasta files in parent dir
    if fasta_path.is_dir():
        fasta_files = list(fasta_path.glob("*.fa")) + list(fasta_path.glob("*.fasta")) + list(fasta_path.glob("*.txt"))
        if not fasta_files:
            raise FileNotFoundError("No fasta files found in dir: " + str(fasta_path))
        fasta_path = fasta_files[0]

    print("Parsing FASTA:", fasta_path)
    for rec in SeqIO.parse(str(fasta_path), "fasta"):
        header = rec.description
        seq = str(rec.seq).upper()
        # heuristic species detection
        species = None
        if re.search(r'\b(mouse|mus|mm|mouse)\b', header, re.I):
            species = "mouse"
        if re.search(r'\b(human|hs|hg|homo sapiens)\b', header, re.I):
            species = "human"
        # enhancer label detection; VISTA often tags "positive" or "negative" or "enhancer"
        is_enhancer = None
        if re.search(r'\b(positive|enhancer|activity|yes|pos)\b', header, re.I):
            is_enhancer = True
        if re.search(r'\b(negative|no|neg|not enhancer)\b', header, re.I):
            is_enhancer = False
        # fallback: if neither, attempt to parse metadata after '|' or ';'
        rec_id = rec.id
        records.append({
            "seq_id": rec_id,
            "header": header,
            "seq": seq,
            "species": species,
            "is_enhancer": is_enhancer
        })
    df = pd.DataFrame(records)
    print("Parsed", len(df), "records; summary:\n", df[['species','is_enhancer']].agg(lambda s: s.isnull().sum()))
    return df

In [6]:
INT_MAP = {"A":1, "C":3, "G":2, "T":4, "N":0}
ATOMIC_MAP = {"A":70, "C":58, "G":78, "T":66, "N":0}
EIIP_MAP = {"A":0.1260, "C":0.1340, "G":0.0806, "T":0.1335, "N":0.0}

def encode_integer(seq):
    return np.array([INT_MAP.get(b, 0) for b in seq], dtype=float)

def encode_atomic(seq):
    return np.array([ATOMIC_MAP.get(b, 0) for b in seq], dtype=float)

def encode_eiip(seq):
    return np.array([EIIP_MAP.get(b, 0.0) for b in seq], dtype=float)

def encode_bfdna(seq):
    # BFDNA: each base value = (count of that base in the sequence) / (sequence length).
    seq = seq.upper()
    L = len(seq)
    counts = Counter(seq)
    # handle zero-length (shouldn't happen)
    if L == 0:
        return np.zeros(0, dtype=float)
    # compute values for A,C,G,T
    vals = {}
    for b in ["A","C","G","T"]:
        vals[b] = counts.get(b, 0) / L
    # map sequence to values
    return np.array([vals.get(b, 0.0) for b in seq], dtype=float)

ENCODERS = {
    "integer": encode_integer,
    "atomic": encode_atomic,
    "eiip": encode_eiip,
    "bfdna": encode_bfdna
}

In [7]:
def pad_truncate(arr, max_len):
    if len(arr) >= max_len:
        return arr[:max_len]
    else:
        pad = np.zeros(max_len - len(arr), dtype=float)
        return np.concatenate([arr, pad])

# Min-max scaling across dataset (per position). We'll stack sequences into shape (N, L).
def minmax_scale_dataset(X):
    # X: numpy array shape (N, L)
    scaler = MinMaxScaler(feature_range=(0,1))
    # fit per-column
    X_scaled = scaler.fit_transform(X)
    return X_scaled, scaler

In [8]:
def build_bilstm_model(input_len, params, num_classes=1):
    inp = layers.Input(shape=(input_len, 1))
    x = inp
    # Stack BiLSTM layers as specified. Use return_sequences=True for intermediate layers.
    for i, units in enumerate(params["bi_lstm_units"]):
        return_seq = True if i < (len(params["bi_lstm_units"]) - 1) else False
        x = layers.Bidirectional(layers.LSTM(units, activation=params["activation"], return_sequences=return_seq))(x)
        do = params["dropouts"][i] if i < len(params["dropouts"]) else 0.2
        if do > 0:
            x = layers.Dropout(do)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Flatten()(x)
    for u in params["dense_units"]:
        x = layers.Dense(u, activation=params["activation"])(x)
    if params["multiclass"]:
        out = layers.Dense(num_classes, activation=params["output_activation"])(x)
    else:
        out = layers.Dense(1, activation=params["output_activation"])(x)
    model = models.Model(inputs=inp, outputs=out)
    model.compile(optimizer=params["optimizer"], loss=params["loss"], metrics=["accuracy"])
    model.summary()
    return model

In [9]:
def classification_success_index(y_true, y_pred):
    # CSI = Precision + TPR - 1
    tpr = recall_score(y_true, y_pred, zero_division=0)
    prec = precision_score(y_true, y_pred, zero_division=0)
    return prec + tpr - 1

def g_mean(y_true, y_pred):
    # G-mean = sqrt(recall * specificity)
    cm = confusion_matrix(y_true, y_pred)
    if cm.shape == (2,2):
        tn, fp, fn, tp = cm.ravel()
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
        return math.sqrt(recall * specificity)
    else:
        # multiclass: geometric mean of per-class recall
        recalls = []
        for i in range(len(cm)):
            tp = cm[i,i]
            fn = cm[i,:].sum() - tp
            recall_i = tp / (tp + fn) if (tp + fn) > 0 else 0
            recalls.append(recall_i)
        prod = 1.0
        for r in recalls: prod *= max(r, 1e-12)
        return prod ** (1.0/len(recalls))

In [None]:
def prepare_dataset(df, encoder_name="bfdna", scenario=1, max_len_cap=MAX_LEN_CAP):
    encoder = ENCODERS[encoder_name]
    # Decide which rows to include based on scenario:
    # scenario 1 -> only enhancer sequences and species labels (human/mouse) -> binary
    # scenario 2 -> multiclass: human_enhancer, mouse_enhancer, no_enhancer
    df2 = df.copy()
    # Ensure species/is_enhancer are not None: try to infer from header if missing
    # (We already did heuristics in parse)
    if scenario == 1:
        # keep only rows where is_enhancer is True and species is known
        df_filt = df2[(df2["is_enhancer"]==True) & (df2["species"].notnull())].reset_index(drop=True)
        y = df_filt["species"].map(lambda s: 1 if s=="human" else 0).values  # human=1, mouse=0
    else:
        # multiclass
        # three classes: 'human_enhancer', 'mouse_enhancer', 'no_enhancer'
        def map_label(row):
            if row["is_enhancer"]==True and row["species"]=="human":
                return "human_enhancer"
            if row["is_enhancer"]==True and row["species"]=="mouse":
                return "mouse_enhancer"
            return "no_enhancer"
        df_filt = df2.copy()
        df_filt["label"] = df_filt.apply(map_label, axis=1)
        y = df_filt["label"].values

    # Encode sequences:
    encoded = []
    lengths = []
    for seq in df_filt["seq"].values:
        arr = encoder(seq)
        lengths.append(len(arr))
        encoded.append(arr)
    if len(encoded) == 0:
        raise RuntimeError("No sequences found after filtering — check dataset parsing/headers.")

    # Determine max_len: min(max(lengths), max_len_cap)
    dataset_max_len = max(lengths)
    max_len = min(dataset_max_len, max_len_cap)
    print(f"Dataset max length = {dataset_max_len}; using max_len = {max_len}")

    # Pad/truncate and stack
    X = np.stack([pad_truncate(arr, max_len) for arr in encoded], axis=0)  # shape (N, L)
    # Min-max normalization across dataset (per position)
    X_scaled, scaler = minmax_scale_dataset(X)
    # Expand dims for model input: (N, L, 1)
    X_scaled = X_scaled.reshape((X_scaled.shape[0], X_scaled.shape[1], 1))
    return X_scaled, y, df_filt, scaler, max_len

In [10]:
def train_and_evaluate(X, y, params, df_meta, label_names=None):
    # split 75% train, 15% val, 15% test (paper)
    N = X.shape[0]
    idx = np.arange(N)
    np.random.shuffle(idx)
    train_end = int(0.75 * N)
    val_end = int(0.90 * N)
    train_idx, val_idx, test_idx = idx[:train_end], idx[train_end:val_end], idx[val_end:]
    X_train, X_val, X_test = X[train_idx], X[val_idx], X[test_idx]
    y_train, y_val, y_test = y[train_idx], y[val_idx], y[test_idx]

    # Prepare labels
    if params["multiclass"]:
        # y are strings -> one-hot encode
        lb = LabelBinarizer()
        lb.fit(y)  # fits unique labels
        y_train_o = lb.transform(y_train)
        y_val_o = lb.transform(y_val)
        y_test_o = lb.transform(y_test)
        num_classes = y_train_o.shape[1]
    else:
        # y are 0/1
        y_train_o, y_val_o, y_test_o = y_train.astype(int), y_val.astype(int), y_test.astype(int)
        num_classes = 1

    model = build_bilstm_model(input_len=X.shape[1], params=params, num_classes=num_classes)

    # Callbacks: save best model
    ckpt_path = MODEL_DIR / f"scenario{params['scenario']}_{random.randint(0,9999)}.h5"
    cb = [
        callbacks.ModelCheckpoint(str(ckpt_path), monitor='val_accuracy', save_best_only=True, verbose=1),
        # optional early stopping (not in paper) — commented here; you can enable if desired.
        # callbacks.EarlyStopping(monitor='val_loss', patience=20, restore_best_weights=True)
    ]

    history = model.fit(X_train, y_train_o, validation_data=(X_val, y_val_o),
                        epochs=params["epochs"], batch_size=params["batch_size"], callbacks=cb, verbose=2)

    # Load best model
    model.load_weights(str(ckpt_path))

    # Predictions
    y_pred_prob = model.predict(X_test)
    if params["multiclass"]:
        y_pred_idx = np.argmax(y_pred_prob, axis=1)
        # map back to labels
        if isinstance(y_test_o, np.ndarray):
            # decode using LabelBinarizer
            # We need lb stored earlier; re-fit to all labels to get mapping
            lb2 = LabelBinarizer(); lb2.fit(y)
            label_list = lb2.classes_
            y_test_decoded = y_test
            y_pred_decoded = label_list[y_pred_idx]
        else:
            raise RuntimeError("Unexpected type for multiclass y")
    else:
        y_pred_prob = y_pred_prob.ravel()
        y_pred_bin = (y_pred_prob >= 0.5).astype(int)
        y_test_decoded = y_test
        y_pred_decoded = y_pred_bin

    # Compute metrics
    metrics = {}
    if params["multiclass"]:
        # For multiclass AUC we compute macro-average via one-vs-rest if possible
        lb2 = LabelBinarizer(); lb2.fit(y)
        y_test_bin = lb2.transform(y_test)
        try:
            auc = roc_auc_score(y_test_bin, y_pred_prob, average="macro", multi_class="ovr")
        except Exception:
            auc = float("nan")
        metrics["AUC_macro"] = auc
        # Accuracy etc using decoded labels
        acc = accuracy_score(y_test_decoded, y_pred_decoded)
        prec = precision_score(y_test_decoded, y_pred_decoded, average="macro", zero_division=0)
        rec = recall_score(y_test_decoded, y_pred_decoded, average="macro", zero_division=0)
        f1 = f1_score(y_test_decoded, y_pred_decoded, average="macro", zero_division=0)
        metrics.update({"accuracy":acc, "precision_macro":prec, "recall_macro":rec, "f1_macro":f1})
        # For per-class metrics you can compute separately
    else:
        auc = roc_auc_score(y_test, y_pred_prob) if len(np.unique(y_test)) == 2 else float("nan")
        metrics["AUC"] = auc
        metrics["accuracy"] = accuracy_score(y_test, y_pred_bin)
        metrics["precision"] = precision_score(y_test, y_pred_bin, zero_division=0)
        metrics["recall"] = recall_score(y_test, y_pred_bin, zero_division=0)
        metrics["f1"] = f1_score(y_test, y_pred_bin, zero_division=0)
        metrics["CSI"] = classification_success_index(y_test, y_pred_bin)
        metrics["Gmean"] = g_mean(y_test, y_pred_bin)
        metrics["MCC"] = matthews_corrcoef(y_test, y_pred_bin)
        metrics["Kappa"] = cohen_kappa_score(y_test, y_pred_bin)

    print("Evaluation metrics:")
    for k,v in metrics.items():
        print(f"  {k}: {v}")

    # Save predictions CSV
    out_df = df_meta.iloc[test_idx].copy()
    out_df = out_df.reset_index(drop=True)
    if params["multiclass"]:
        out_df["true_label"] = y_test
        out_df["pred_label"] = y_pred_decoded
        # store per-class probs
        for i,lab in enumerate(lb2.classes_):
            out_df[f"prob_{lab}"] = y_pred_prob[:,i]
    else:
        out_df["true_label"] = y_test
        out_df["pred_label"] = y_pred_decoded
        out_df["prob_positive"] = y_pred_prob
    out_csv = RESULTS_DIR / f"predictions_scenario{params['scenario']}.csv"
    out_df.to_csv(out_csv, index=False)
    print("Saved predictions to", out_csv)

    # Save metrics
    with open(RESULTS_DIR / f"metrics_scenario{params['scenario']}.json", "w") as f:
        json.dump(metrics, f, indent=2)
    print("Saved metrics.")

    return model, history, metrics

In [14]:
if __name__ == "__main__":
    # 1) Download dataset (may throw if fails)
    try:
        downloaded = FASTA_FILE
    except Exception as e:
        raise RuntimeError("Automatic download failed. Please manually download the VISTA FASTA from https://enhancer.lbl.gov/vista/downloads and upload to `vista_data/` then re-run.")

    # 2) Parse FASTA
    df = parse_vista_fasta(downloaded)

    # Quick cleanup heuristics: if there are ambiguous 'N' categories or missing labels,
    # user might need to manually map. We'll proceed with available labels as best-effort.

    # 3) For each encoder and scenario, prepare dataset, train, evaluate
    encoders_to_try = ["bfdna", "eiip", "atomic", "integer"]  # order matches paper
    results_summary = {}
    for encoder_name in encoders_to_try:
        print("\n\n==============================")
        print("Encoder:", encoder_name)
        print("==============================")
        # Scenario 1
        try:
            X1, y1, df_meta1, scaler1, max_len1 = prepare_dataset(df, encoder_name=encoder_name, scenario=1)
            model1, hist1, metrics1 = train_and_evaluate(X1, y1, SCENARIO1_PARAMS, df_meta1)
            results_summary[f"{encoder_name}_scenario1"] = metrics1
        except Exception as e:
            print("Scenario 1 failed for encoder", encoder_name, ":", e)

        # Scenario 2
        try:
            X2, y2, df_meta2, scaler2, max_len2 = prepare_dataset(df, encoder_name=encoder_name, scenario=2)
            model2, hist2, metrics2 = train_and_evaluate(X2, y2, SCENARIO2_PARAMS, df_meta2)
            results_summary[f"{encoder_name}_scenario2"] = metrics2
        except Exception as e:
            print("Scenario 2 failed for encoder", encoder_name, ":", e)

    # Save summary
    with open(RESULTS_DIR / "summary_results.json", "w") as f:
        json.dump(results_summary, f, indent=2)
    print("All done. Results saved to", RESULTS_DIR)

Parsing FASTA: Vista_Dataset/vista_sequences.fasta
Parsed 3408 records; summary:
 species        0
is_enhancer    0
dtype: int64


Encoder: bfdna
Scenario 1 failed for encoder bfdna : name 'prepare_dataset' is not defined
Scenario 2 failed for encoder bfdna : name 'prepare_dataset' is not defined


Encoder: eiip
Scenario 1 failed for encoder eiip : name 'prepare_dataset' is not defined
Scenario 2 failed for encoder eiip : name 'prepare_dataset' is not defined


Encoder: atomic
Scenario 1 failed for encoder atomic : name 'prepare_dataset' is not defined
Scenario 2 failed for encoder atomic : name 'prepare_dataset' is not defined


Encoder: integer
Scenario 1 failed for encoder integer : name 'prepare_dataset' is not defined
Scenario 2 failed for encoder integer : name 'prepare_dataset' is not defined
All done. Results saved to results
