### python version 3.10.12

In [None]:
# If running locally or on a fresh runtime, uncomment the next line:
!pip install -q transformers>=4.40 accelerate torch scikit-learn pandas numpy

In [None]:
# Read the secret stored under the name "wand_ai_api_uog" through the
# kaggle_secrets client. `secret_value_0` is passed to wandb.login() in the
# next cell, so the variable name must stay as it is.
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("wand_ai_api_uog")

In [None]:
# Log in to Weights & Biases with the key fetched in the previous cell.
# NOTE(review): train_and_eval() builds its TrainingArguments with report_to=[],
# so this login may be unnecessary for the training run — confirm before removing.
import wandb
wandb.login(key=secret_value_0)

In [None]:
import os
import re
import torch
import numpy as np
import pandas as pd
from typing import Optional
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, f1_score, precision_score, recall_score,
    matthews_corrcoef, roc_auc_score, average_precision_score
)
from transformers import (
    AutoTokenizer, AutoConfig, AutoModelForSequenceClassification,
    AutoModel, Trainer, TrainingArguments, DataCollatorWithPadding, set_seed
)
from transformers import EarlyStoppingCallback


# -------------------- #
# FASTA utilities
# -------------------- #
def read_fasta(path):
    """Parse a FASTA file into a list of ``(header, sequence)`` tuples.

    Args:
        path: Path to a UTF-8 text file in FASTA format.

    Returns:
        One tuple per ``>`` record, in file order. The header has the leading
        ``>`` removed; the sequence is the concatenation of the record's
        non-blank lines with surrounding whitespace stripped.

    Notes:
        Blank lines are skipped. Sequence lines that appear before the first
        header belong to no record and are ignored. A record whose header text
        is empty (a bare ``>`` line) is kept with ``""`` as its header instead
        of being silently dropped.
    """
    records = []
    header = None
    chunks = []
    with open(path, "r", encoding="utf-8") as handle:
        for raw in handle:
            line = raw.strip()
            if not line:
                continue
            if line.startswith(">"):
                # Compare against None, not truthiness: "" is a valid (empty) header.
                if header is not None:
                    records.append((header, "".join(chunks)))
                header, chunks = line[1:], []
            elif header is not None:
                chunks.append(line)
    if header is not None:
        records.append((header, "".join(chunks)))
    return records

def sanitize_seq(seq: str) -> str:
    """Normalise a raw sequence string to the accepted residue alphabet.

    The input is upper-cased, every ``*`` is removed, and each remaining
    character outside the accepted letter set is replaced by ``X``.
    """
    alphabet = "ACDEFGHIKLMNPQRSTVWYBXZJUO"
    normalized = seq.upper().replace("*", "")
    # A negated character class matches any single character (newlines
    # included) that is not in the alphabet, so one substitution covers all.
    return re.sub("[^" + alphabet + "]", "X", normalized)

def infer_label(h: str) -> Optional[int]:
    """Infer a binary label (0 or 1) from a FASTA header, or ``None`` if unknown.

    Matching is case-insensitive and checked in this order:

    1. An explicit tag ``label=<0|1>`` / ``class:<0|1>`` (``:`` or ``=``,
       optional whitespace before the digit). This is checked first so that a
       loose keyword hit (e.g. "pos" inside "position") cannot override an
       explicit tag.
    2. The words "negative" / "non-epitope" -> 0, then "positive" /
       "epitope" -> 1 ("non-epitope" is tested before "epitope" on purpose).
    3. The fragments "neg" / "pos", only when exactly one of them is present.

    Args:
        h: Header text (without the leading ``>``).

    Returns:
        0, 1, or ``None`` when no rule matches.
    """
    hl = h.lower()

    tagged = re.search(r"(label|class)[:=]\s*([01])", hl)
    if tagged:
        return int(tagged.group(2))

    if "negative" in hl or "non-epitope" in hl:
        return 0
    if "positive" in hl or "epitope" in hl:
        return 1

    has_neg = "neg" in hl
    has_pos = "pos" in hl
    if has_neg and not has_pos:
        return 0
    if has_pos and not has_neg:
        return 1
    return None

def load_dataset_from_fasta(fasta_path: str) -> pd.DataFrame:
    """Load a labelled FASTA file into a DataFrame.

    Each record is read with ``read_fasta``, its label is derived from the
    header with ``infer_label``, and its sequence is cleaned with
    ``sanitize_seq``.

    Args:
        fasta_path: Path to the FASTA file.

    Returns:
        A DataFrame with the columns ``header``, ``sequence`` and ``label``
        (one row per record). The columns are present even when the file
        contains no records, so downstream column access does not fail.

    Raises:
        ValueError: If a label cannot be inferred for some header.
    """
    rows = []
    for header, seq in read_fasta(fasta_path):
        label = infer_label(header)
        if label is None:
            raise ValueError(f"Could not infer label from header: {header}")
        rows.append({"header": header, "sequence": sanitize_seq(seq), "label": label})
    # Pin the schema explicitly: pd.DataFrame([]) would otherwise have no columns.
    return pd.DataFrame(rows, columns=["header", "sequence", "label"])

def _maybe_infer_labels_df(self, df: pd.DataFrame) -> pd.DataFrame:
    """
    If headers contain obvious label hints, keep them; otherwise set label = NaN.
    """
    if "label" not in df.columns:
        df["label"] = np.nan
    # Keep whatever load_dataset_from_fasta produced; it already tries to infer.
    # If that raised earlier for missing labels, you can add a separate loader for unlabeled FASTA.
    return df

def _effective_max_len(self) -> int:
    """
    ESM2 models typically cap around 1022 tokens incl. specials. Respect tokenizer.model_max_length if set.
    """
    # transformers may give a giant int for "very long". Clamp to 1022 as a safe ceiling for ESM2.
    tk_max = getattr(self.tokenizer, "model_max_length", 1024)
    safe_cap = 1022
    return int(min(self.max_length, tk_max if tk_max < 10_000 else safe_cap))
    
# -------------------- #
# Dataset wrapper
# -------------------- #
class ESM2Dataset(torch.utils.data.Dataset):
    """Torch dataset that tokenizes one DataFrame row per item.

    The DataFrame must provide a ``sequence`` column (passed to the tokenizer)
    and a ``label`` column (converted with ``int`` into a long tensor).
    Tokenization happens lazily in ``__getitem__``; no padding is applied here.
    """

    def __init__(self, df, tokenizer, max_length=1022):
        self.df = df
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        encoded = self.tokenizer(
            record["sequence"],
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
            add_special_tokens=True,
        )
        # The tokenizer output carries a leading batch dimension of size 1
        # (return_tensors="pt" on a single string); drop it per tensor.
        sample = {}
        for key, tensor in encoded.items():
            sample[key] = tensor.squeeze(0)
        sample["labels"] = torch.tensor(int(record["label"]), dtype=torch.long)
        return sample


# -------------------- #
# Main Pipeline Class
# -------------------- #
class ESM2Pipeline:
    """Fine-tune, evaluate, embed with, and predict from a sequence classifier.

    The pipeline loads a Hugging Face checkpoint (``model_name``), fine-tunes it
    as a two-class classifier on a labelled FASTA file, and stores the best
    model under ``<outdir>/best_model`` for later use by ``predict``.

    Args:
        model_name: Hugging Face model identifier for tokenizer and model.
        outdir: Directory for checkpoints, the best model and predictions
            (created if missing).
        max_length: Truncation length passed to the tokenizer.
        batch_size: Per-device batch size for training, evaluation,
            embedding extraction and prediction.
        lr: Learning rate.
        epochs: Maximum number of training epochs.
        weight_decay: Weight decay.
        seed: Seed for ``set_seed`` and the train/validation split.
        fp16: Request fp16 mixed precision. It is only enabled when CUDA is
            available, because it cannot be used on CPU-only runtimes.
        BF16: Stored as ``self.bf16``. No method of this class consumes it.
        freeze_encoder: If true, only parameters whose name contains
            "classifier" are trained.
        class_weighted_loss: Stored on the instance. No method of this class
            consumes it.
    """

    def __init__(self,
                 model_name="facebook/esm2_t12_35M_UR50D",
                 outdir="./esm2_out",
                 max_length=1022,
                 batch_size=8,
                 lr=2e-5,
                 epochs=5,
                 weight_decay=0.01,
                 seed=42,
                 fp16=True,
                 BF16=True,
                 freeze_encoder=False,
                 class_weighted_loss=True):
        self.model_name = model_name
        self.outdir = outdir
        self.max_length = max_length
        self.batch_size = batch_size
        self.lr = lr
        self.epochs = epochs
        self.weight_decay = weight_decay
        self.seed = seed
        self.fp16 = fp16
        # Previously accepted and silently dropped; keep it reachable on the
        # instance. It is deliberately not forwarded to TrainingArguments,
        # which rejects fp16 and bf16 being enabled together.
        self.bf16 = BF16
        self.freeze_encoder = freeze_encoder
        self.class_weighted_loss = class_weighted_loss

        os.makedirs(outdir, exist_ok=True)
        set_seed(seed)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=False)

    @staticmethod
    def _softmax_probs(logits):
        """Convert raw classifier logits into per-class probabilities.

        Accepts an array-like of shape ``(n, num_classes)``; when a tuple or
        list is given, its first element is used as the logits. Returns a
        float64 NumPy array whose rows sum to 1.
        """
        if isinstance(logits, (tuple, list)):
            logits = logits[0]
        scores = np.asarray(logits, dtype=np.float64)
        # Subtract the row maximum before exponentiating for numerical stability.
        scores = scores - scores.max(axis=-1, keepdims=True)
        exp_scores = np.exp(scores)
        return exp_scores / exp_scores.sum(axis=-1, keepdims=True)

    def compute_metrics(self):
        """Return the metrics callback handed to ``Trainer``.

        The callback receives ``(logits, labels)`` and returns precision,
        accuracy, recall, F1, MCC, ROC AUC and PR AUC. Class 1 is the positive
        class. The two AUC values become NaN when scikit-learn raises
        ``ValueError`` for them (e.g. only one class present), mirroring the
        guard used in ``predict``.
        """
        def _compute(eval_pred):
            logits, labels = eval_pred
            probs = self._softmax_probs(logits)
            preds = probs.argmax(axis=1)
            y_true = labels
            y_proba = probs[:, 1]
            try:
                roc_auc = roc_auc_score(y_true, y_proba)
                pr_auc = average_precision_score(y_true, y_proba)
            except ValueError:
                roc_auc, pr_auc = float("nan"), float("nan")
            return {
                "precision": precision_score(y_true, preds, zero_division=0),
                "accuracy": accuracy_score(y_true, preds),
                "recall": recall_score(y_true, preds, zero_division=0),
                "f1": f1_score(y_true, preds, zero_division=0),
                "mcc": matthews_corrcoef(y_true, preds),
                "roc_auc": roc_auc,
                "pr_auc": pr_auc,
            }
        return _compute

    def train_and_eval(self, train_fasta, ind_fasta, val_size=0.1):
        """Fine-tune on ``train_fasta`` and evaluate on validation and independent data.

        A stratified ``val_size`` fraction of the training file is held out for
        per-epoch evaluation. Training stops early after 2 evaluations without
        improvement, the best checkpoint by F1 is reloaded at the end, and the
        model is saved to ``<outdir>/best_model``.

        Args:
            train_fasta: Labelled FASTA used for training and validation.
            ind_fasta: Labelled FASTA used as an independent evaluation set.
            val_size: Fraction of the training data held out for validation.

        Returns:
            ``(val_metrics, ind_metrics)`` as returned by ``Trainer.evaluate``.
        """
        df_train_all = load_dataset_from_fasta(train_fasta)
        df_ind = load_dataset_from_fasta(ind_fasta)

        df_tr, df_val = train_test_split(df_train_all, test_size=val_size,
                                         random_state=self.seed, stratify=df_train_all["label"])

        train_ds = ESM2Dataset(df_tr, self.tokenizer, self.max_length)
        val_ds = ESM2Dataset(df_val, self.tokenizer, self.max_length)
        ind_ds = ESM2Dataset(df_ind, self.tokenizer, self.max_length)

        config = AutoConfig.from_pretrained(self.model_name, num_labels=2)
        model = AutoModelForSequenceClassification.from_pretrained(self.model_name, config=config)

        if self.freeze_encoder:
            for name, p in model.named_parameters():
                if "classifier" not in name:
                    p.requires_grad = False

        args = TrainingArguments(
            output_dir=self.outdir,
            learning_rate=self.lr,
            per_device_train_batch_size=self.batch_size,
            per_device_eval_batch_size=self.batch_size,
            num_train_epochs=self.epochs,
            weight_decay=self.weight_decay,
            eval_strategy="epoch",
            save_strategy="epoch",
            load_best_model_at_end=True,
            warmup_ratio=0.1,
            metric_for_best_model="f1",
            # fp16 mixed precision needs a CUDA device; fall back to fp32 otherwise
            # so the default configuration also runs on CPU-only machines.
            fp16=bool(self.fp16) and torch.cuda.is_available(),
            report_to=[]
        )

        trainer = Trainer(
            model=model,
            args=args,
            train_dataset=train_ds,
            eval_dataset=val_ds,
            tokenizer=self.tokenizer,
            data_collator=DataCollatorWithPadding(tokenizer=self.tokenizer),
            compute_metrics=self.compute_metrics(),
            callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]
        )

        trainer.train()
        trainer.save_model(os.path.join(self.outdir, "best_model"))

        val_metrics = trainer.evaluate(val_ds)
        ind_metrics = trainer.evaluate(ind_ds)

        return val_metrics, ind_metrics

    def extract_embeddings(self, fasta_path, pooling="mean", layer=None):
        """Compute one embedding vector per sequence with the base checkpoint.

        The model is loaded from ``self.model_name`` (not from the fine-tuned
        ``best_model`` directory) and run in batches of ``self.batch_size``.

        Args:
            fasta_path: Labelled FASTA file (loaded with ``load_dataset_from_fasta``).
            pooling: ``"cls"`` takes the first token's vector; ``"mean"``
                averages over positions where the attention mask is 1; any
                other value flattens the hidden states of each sequence.
                NOTE(review): with flattening, batches padded to different
                lengths produce different widths and the final concatenation
                would fail — confirm whether this mode is used.
            layer: Index into ``hidden_states``; ``None`` uses ``last_hidden_state``.

        Returns:
            ``(X, meta)`` where ``X`` is the stacked NumPy array of vectors and
            ``meta`` is a DataFrame with ``header`` and ``label`` columns.
        """
        df = load_dataset_from_fasta(fasta_path)
        model = AutoModel.from_pretrained(self.model_name, output_hidden_states=True).eval()
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)

        vecs, headers, labels = [], df["header"].tolist(), df["label"].tolist()
        # Inference only: disable autograd so no graph is built or kept per batch.
        with torch.no_grad():
            for i in range(0, len(df), self.batch_size):
                seqs = df["sequence"].iloc[i:i+self.batch_size].tolist()
                enc = self.tokenizer(seqs, truncation=True, padding=True,
                                     max_length=self.max_length, return_tensors="pt")
                enc = {k: v.to(device) for k, v in enc.items()}
                out = model(**enc)
                hidden = out.hidden_states[layer] if layer is not None else out.last_hidden_state

                if pooling == "cls":
                    vec = hidden[:, 0, :]
                elif pooling == "mean":
                    mask = enc["attention_mask"].unsqueeze(-1)
                    # clamp guards against division by zero for an all-masked row
                    vec = (hidden * mask).sum(1) / mask.sum(1).clamp(min=1)
                else:
                    vec = hidden.reshape(hidden.size(0), -1)
                vecs.append(vec.cpu().numpy())

        X = np.concatenate(vecs, axis=0)
        return X, pd.DataFrame({"header": headers, "label": labels})

    # ---------- prediction ----------
    def _load_prediction_frame(self, fasta_path: str) -> pd.DataFrame:
        """Load a FASTA file for prediction without requiring labels.

        Returns a DataFrame with ``seq_id``, ``sequence`` and ``label``
        columns; ``label`` is NaN wherever ``infer_label`` finds nothing.
        """
        rows = []
        for header, seq in read_fasta(fasta_path):
            lab = infer_label(header)
            rows.append({
                "seq_id": header,
                "sequence": sanitize_seq(seq),
                "label": np.nan if lab is None else lab,
            })
        return pd.DataFrame(rows, columns=["seq_id", "sequence", "label"])

    def predict(self, fasta_path: str, threshold: float = 0.5, out_csv: Optional[str] = None) -> pd.DataFrame:
        """Score a FASTA file with the saved best model and write a CSV.

        The model is loaded from ``<outdir>/best_model``. Logits are converted
        to probabilities with a softmax; ``proba`` is the probability of class
        1 and ``y_pred`` is 1 where ``proba >= threshold``. When every record
        has a label inferable from its header, classification metrics are
        printed; otherwise only predictions are written.

        Args:
            fasta_path: FASTA file to score (labels optional).
            threshold: Probability cut-off for predicting class 1.
            out_csv: Output CSV path; defaults to ``<outdir>/predictions.csv``.

        Returns:
            DataFrame with columns ``seq_id``, ``y_true``, ``proba``, ``y_pred``.

        Raises:
            ValueError: If the FASTA file contains no records.
        """
        df = self._load_prediction_frame(fasta_path)
        if df.empty:
            raise ValueError(f"No sequences found in FASTA file: {fasta_path}")

        # ESM2Dataset needs an integer label per row. Unlabelled rows get a
        # placeholder 0; it only influences the loss value, which is not used.
        model_input = df.assign(label=df["label"].fillna(0).astype(int))
        ds = ESM2Dataset(model_input, self.tokenizer, self.max_length)

        best_dir = os.path.join(self.outdir, "best_model")
        model = AutoModelForSequenceClassification.from_pretrained(best_dir)
        model.eval()

        # Explicit arguments keep inference artefacts inside outdir, reuse the
        # configured batch size and disable experiment-tracker reporting.
        infer_args = TrainingArguments(
            output_dir=os.path.join(self.outdir, "predict_tmp"),
            per_device_eval_batch_size=self.batch_size,
            report_to=[],
        )
        trainer = Trainer(
            model=model,
            args=infer_args,
            tokenizer=self.tokenizer,
            data_collator=DataCollatorWithPadding(tokenizer=self.tokenizer),
        )
        logits = trainer.predict(ds).predictions
        # Trainer.predict returns raw logits; the threshold is a probability.
        proba = self._softmax_probs(logits)[:, 1]
        y_pred = (proba >= threshold).astype(int)

        res = pd.DataFrame({
            "seq_id": df["seq_id"],
            "y_true": pd.to_numeric(df["label"], errors="coerce"),
            "proba": proba,
            "y_pred": y_pred,
        })

        # metrics if y_true present
        if res["y_true"].notna().all():
            y_true = res["y_true"].astype(int).values
            try:
                auc  = roc_auc_score(y_true, proba)
                prau = average_precision_score(y_true, proba)
            except ValueError:
                # e.g. only one class present in y_true
                auc, prau = float("nan"), float("nan")
            acc  = accuracy_score(y_true, y_pred)
            mcc  = matthews_corrcoef(y_true, y_pred)
            f1   = f1_score(y_true, y_pred, zero_division=0)
            prec = precision_score(y_true, y_pred, zero_division=0)
            rec  = recall_score(y_true, y_pred, zero_division=0)
            spec = recall_score(y_true, y_pred, pos_label=0, zero_division=0)
            print(f"Accuracy: {acc:.4f}")
            print(f"AUC: {auc:.4f}")
            print(f"PR AUC: {prau:.4f}")
            print(f"MCC: {mcc:.4f}")
            print(f"F1: {f1:.4f}")
            print(f"Precision: {prec:.4f}")
            print(f"Recall/Sensitivity: {rec:.4f}")
            print(f"Specificity: {spec:.4f}")
        else:
            print("No ground-truth labels detected; wrote probabilities only.")

        if out_csv is None:
            out_csv = os.path.join(self.outdir, "predictions.csv")
        res.to_csv(out_csv, index=False)
        print(f"Saved predictions to: {out_csv}")
        return res


In [None]:

# Input FASTA files: the training set and the independent evaluation set.
TRAIN_FASTA = "/kaggle/input/lbtope-ibce-final-dataset-features/new_dataset/Train-ibce.fasta"
IND_FASTA   = "/kaggle/input/lbtope-ibce-final-dataset-features/new_dataset/Ind-ibce50_renamed.fasta"

# NOTE(review): outdir mentions "t33_650M" while model_name is the t12_35M
# checkpoint — confirm which of the two is intended.
# `pipeline` is reused by the prediction cells below.
pipeline = ESM2Pipeline(
    model_name="facebook/esm2_t12_35M_UR50D",
    outdir="./esm2_run_64_t33_650M",
    max_length=64,
    epochs=5,
    batch_size=8,
    fp16=True
)

# Train + evaluate
val_metrics, ind_metrics = pipeline.train_and_eval(TRAIN_FASTA, IND_FASTA)
# Show both metric dicts side by side: row 1 = validation split, row 2 = independent set.
# NOTE: `df_ind` is overwritten by each prediction cell below.
df_ind = pd.DataFrame([val_metrics, ind_metrics], index=[1,2])
df_ind

In [None]:
# Score the independent set with the saved best model; 0.1 is the cut-off that
# predict() applies to its class-1 score. Results are written to the CSV below.
df_ind = pipeline.predict(IND_FASTA, threshold=0.1, out_csv="ind_ibce_predictions.csv")
df_ind.head()

In [None]:
# Score the CLBE test set with the saved best model (cut-off 0.25).
# The path starts with a single "/" like every other input path in this notebook.
Test_clbe = "/kaggle/input/lbtope-ibce-final-dataset-features/features/ind_clbe/Test_clbe_filtered.fasta"
df_ind = pipeline.predict(Test_clbe, threshold=0.25, out_csv="test_clbe_predictions.csv")
df_ind.head()

In [None]:
# Score the ABCpred test set with the saved best model (cut-off 0.1) and
# write the per-sequence results to the CSV below.
Test_abcpred = "/kaggle/input/lbtope-ibce-final-dataset-features/features/ind_abcpred/abcpred_filtered.fasta"
df_ind = pipeline.predict(Test_abcpred, threshold=0.1, out_csv="test_abcpred_predictions.csv")
df_ind.head()