In [None]:
!pip install -U causal-conv1d
!pip install bitsandbytes
!pip install datasets evaluate accelerate
!pip install --no-build-isolation --no-cache-dir -U mamba-ssm

In [None]:
!git clone https://github.com/getorca/mamba_for_sequence_classification.git
!rm -rf mamba_for_sequence_classification/requirements.txt
!touch mamba_for_sequence_classification/requirements.txt
!pip install -q ./mamba_for_sequence_classification

In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig, TrainingArguments, Trainer, DataCollatorWithPadding
from datasets import load_dataset
from mamba_ssm import selective_scan_fn
from google.colab import drive
from peft import LoraConfig, get_peft_model, TaskType
from hf_mamba_classification import MambaForSequenceClassification
import pandas as pd
from datasets import Dataset, DatasetDict
import os
import evaluate
import glob
import inspect, os
import math
import torch
# Cache directory for Hugging Face `datasets`.
# NOTE(review): `datasets` is already imported above; if the library reads this variable
# at import time this assignment has no effect — confirm, or move it above the imports.
os.environ["HF_DATASETS_CACHE"] = "/content/hf_cache"

# Model and data
MODEL_NAME = "state-spaces/mamba-130m-hf"
NUM_LABELS = 2
TRAIN_CSV = "/content/train_clean.csv"
VAL_CSV   = "/content/val_clean.csv"
OUTPUT_DIR = "mamba_base_lora"  # local output directory (relative to the working dir)

# Tokenisation: every sequence is truncated/padded to this many tokens.
MAX_LENGTH = 128
max_length = MAX_LENGTH  # lowercase alias kept because train() reads `max_length`

# Optimisation
BATCH_SIZE = 32
NUM_EPOCHS = 18
LR = 2e-4

# LoRA adapter hyper-parameters (rank, alpha, dropout on the adapter input)
LORA_R = 12
LORA_ALPHA = 32
LORA_DROP = 0.05

In [None]:
def _load_tokenized_splits(tokenizer):
    """Read TRAIN_CSV / VAL_CSV and tokenize them into ``train`` / ``validation`` splits.

    Each example becomes the tokenizer's outputs, padded/truncated to ``max_length``
    tokens, plus an integer ``labels`` column. Every original CSV column is dropped
    (not only ``text``/``label``): the Trainer runs with ``remove_unused_columns=False``,
    so any leftover column would be handed to the collator and the model.
    """
    raw_dataset = DatasetDict({
        "train": Dataset.from_pandas(pd.read_csv(TRAIN_CSV)),
        "validation": Dataset.from_pandas(pd.read_csv(VAL_CSV)),
    })

    def preprocess(examples):
        texts = [str(x) for x in examples["text"]]
        enc = tokenizer(
            texts,
            truncation=True,
            padding="max_length",
            max_length=max_length,
        )
        enc["labels"] = [int(x) for x in examples["label"]]
        return enc

    return DatasetDict({
        split: ds.map(preprocess, batched=True, remove_columns=ds.column_names)
        for split, ds in raw_dataset.items()
    })


def _binary_metrics_fn():
    """Build a Trainer ``compute_metrics`` callback reporting accuracy, F1, precision, recall."""
    metric_acc = evaluate.load("accuracy")
    metric_f1 = evaluate.load("f1")
    metric_precision = evaluate.load("precision")
    metric_recall = evaluate.load("recall")

    def compute_metrics(p):
        # When the model returns extra outputs next to the logits, the predictions
        # arrive as a tuple whose first element is the logits.
        logits = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
        preds = logits.argmax(-1)
        refs = p.label_ids
        return {
            "accuracy": metric_acc.compute(predictions=preds, references=refs)["accuracy"],
            "f1": metric_f1.compute(predictions=preds, references=refs, average="binary")["f1"],
            "precision": metric_precision.compute(predictions=preds, references=refs, average="binary")["precision"],
            "recall": metric_recall.compute(predictions=preds, references=refs, average="binary")["recall"],
        }

    return compute_metrics


def train():
    """Fine-tune MODEL_NAME with LoRA on TRAIN_CSV and evaluate on VAL_CSV.

    Side effects:
      * mounts Google Drive and writes one checkpoint per epoch to
        /content/drive/MyDrive/mamba_checkpoints, resuming from the newest
        checkpoint found there;
      * saves the best-by-F1 model (``load_best_model_at_end``) as ``final_model``
        both in that Drive folder and under the local ``OUTPUT_DIR``, the path the
        evaluation cell loads its adapter from.
    """
    from transformers import set_seed
    from transformers.trainer_utils import get_last_checkpoint

    drive.mount('/content/drive')
    output_dir_drive = "/content/drive/MyDrive/mamba_checkpoints"
    os.makedirs(output_dir_drive, exist_ok=True)

    training_args = TrainingArguments(
        output_dir                  = output_dir_drive,
        per_device_train_batch_size = BATCH_SIZE,
        learning_rate               = LR,
        gradient_accumulation_steps = 8,  # effective batch = BATCH_SIZE * 8 per device
        eval_strategy               = "epoch",
        save_strategy               = "epoch",
        dataloader_num_workers      = 2,
        warmup_ratio                = 0.1,
        lr_scheduler_type           = "cosine",
        dataloader_pin_memory       = True,
        bf16                        = True,
        optim                       = "adamw_torch_fused",
        max_grad_norm               = 1.0,
        fp16                        = False,
        num_train_epochs            = NUM_EPOCHS,
        logging_strategy            = "epoch",
        load_best_model_at_end      = True,
        metric_for_best_model       = "f1",
        remove_unused_columns       = False,
        greater_is_better           = True,
        report_to                   = "none",
        label_names                 = ["labels"]
    )

    # Seed before building the model so the random classifier re-initialisation
    # below is reproducible across runs.
    set_seed(training_args.seed)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    id2label = {0: "NEGATIVE", 1: "POSITIVE"}
    label2id = {name: idx for idx, name in id2label.items()}
    model = MambaForSequenceClassification.from_pretrained(
        MODEL_NAME, num_labels=NUM_LABELS, use_cache=False, id2label=id2label, label2id=label2id
    )
    model.to("cuda")

    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    dataset = _load_tokenized_splits(tokenizer)

    peft_config = LoraConfig(
        task_type = TaskType.SEQ_CLS,
        target_modules = ["in_proj", "out_proj", "x_proj", "proj_in", "proj_out"],
        r = LORA_R,
        lora_alpha = LORA_ALPHA,
        lora_dropout = LORA_DROP,
        bias = 'none'
    )
    final_model = get_peft_model(model, peft_config)
    final_model.to("cuda")
    print(" % OF TRAINING")
    final_model.print_trainable_parameters()
    final_model.gradient_checkpointing_enable()

    # Re-initialise the classifier weight with the scheme nn.Linear uses by default.
    # NOTE(review): with TaskType.SEQ_CLS, PEFT may wrap `classifier` in a trainable copy;
    # confirm this attribute resolves to the weight that is actually trained.
    with torch.no_grad():
        torch.nn.init.kaiming_uniform_(final_model.classifier.weight, a=math.sqrt(5))

    trainer = Trainer(
        model               = final_model,
        args                = training_args,
        train_dataset       = dataset["train"],
        tokenizer           = tokenizer,
        eval_dataset        = dataset["validation"],
        data_collator       = DataCollatorWithPadding(tokenizer, return_tensors = 'pt'),
        compute_metrics     = _binary_metrics_fn()
    )

    # Newest `checkpoint-<step>` directory in the Drive folder, or None.
    last_checkpoint = get_last_checkpoint(output_dir_drive)
    if last_checkpoint:
        print("🔄 Riprendo da:", last_checkpoint)
    else:
        print("🔄 Nessun checkpoint trovato, inizio da zero")
    trainer.train(resume_from_checkpoint=last_checkpoint)

    metrics = trainer.evaluate()
    print("Final evaluation:", metrics)

    trainer.save_model(os.path.join(output_dir_drive, "final_model"))
    # Local copy at OUTPUT_DIR/final_model, which is where the evaluation cell looks.
    trainer.save_model(os.path.join(OUTPUT_DIR, "final_model"))

In [None]:
# Run fine-tuning end to end: mounts Google Drive, trains (resuming if a checkpoint exists),
# evaluates on the validation split and saves the final model.
train()

In [None]:
import os
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import logging
import matplotlib.pyplot as plt

from transformers import AutoTokenizer, DataCollatorWithPadding
from peft import PeftModel
from datasets import Dataset, DatasetDict
from torch.utils.data import DataLoader
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    balanced_accuracy_score,
    matthews_corrcoef,
    roc_auc_score,
    average_precision_score,
    roc_curve,
    precision_recall_curve,
    cohen_kappa_score,
    confusion_matrix,
    auc
)
from hf_mamba_classification import MambaForSequenceClassification

# CONFIGURATION
MODEL_NAME = "state-spaces/mamba-130m-hf"
# Preferred adapter location, relative to the working directory.
ADAPTER_DIR = "mamba_base_lora/final_model"
# The training cell saves its final model under this Google Drive folder. Fall back to it
# when the local copy is absent (e.g. a fresh runtime), provided Drive is mounted.
DRIVE_ADAPTER_DIR = "/content/drive/MyDrive/mamba_checkpoints/final_model"
if not os.path.isdir(ADAPTER_DIR) and os.path.isdir(DRIVE_ADAPTER_DIR):
    ADAPTER_DIR = DRIVE_ADAPTER_DIR
BATCH_SIZE = 32
MAX_LENGTH = 128  # should match the max_length used during training
RESULTS_DIR = "results"
DATA_DIR = Path("/content")  # adjust as needed

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Set up logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

def compute_probs_from_hf_dataset(
    hf_dataset: Dataset,
    model: torch.nn.Module,
    tokenizer,
    device: torch.device,
    batch_size: "int | None" = None,
) -> np.ndarray:
    """Run ``model`` over ``hf_dataset`` and return the positive-class probability per row.

    Args:
        hf_dataset: tokenized dataset; rows are batched with ``DataCollatorWithPadding``.
        model: sequence-classification model whose output exposes ``.logits``.
        tokenizer: tokenizer used by the padding collator.
        device: device the model lives on; batches are moved there.
        batch_size: rows per forward pass; defaults to the module-level ``BATCH_SIZE``.

    Returns:
        1-D array of ``softmax(logits)[:, 1]`` in dataset order (``shuffle=False``);
        an empty float32 array when the dataset has no rows.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE
    collator = DataCollatorWithPadding(tokenizer, return_tensors="pt")
    loader = DataLoader(
        hf_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collator,
        num_workers=2,
        pin_memory=device.type == "cuda",
    )
    all_probs = []
    model.eval()
    with torch.no_grad():
        for batch in loader:
            # Labels are not needed for inference; leaving them out skips the loss computation.
            inputs = {
                k: v.to(device)
                for k, v in batch.items()
                if isinstance(v, torch.Tensor) and k != "labels"
            }
            logits = model(**inputs).logits
            # Cast to float32 first: NumPy has no bfloat16, so .numpy() would fail on bf16 logits.
            probs = torch.softmax(logits.float(), dim=-1)[:, 1]
            all_probs.append(probs.cpu().numpy())
    if not all_probs:
        return np.empty(0, dtype=np.float32)
    return np.concatenate(all_probs, axis=0)


def _binary_metrics(y_true, probs, y_pred, threshold):
    """Return threshold-dependent and ranking metrics for a binary task as a flat dict.

    ROC-AUC and PR-AUC (average precision) are NaN when ``y_true`` holds a single
    class, where they are undefined.
    """
    # labels=[0, 1] keeps the matrix 2x2 even when a class is missing from both arrays,
    # so the four-way unpack cannot fail.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    has_both_classes = np.unique(y_true).size == 2
    return {
        "Threshold": threshold,
        "Accuracy": accuracy_score(y_true, y_pred),
        "Balanced Accuracy": balanced_accuracy_score(y_true, y_pred),
        # zero_division=0 gives the same value sklearn falls back to, without the warning.
        "F1": f1_score(y_true, y_pred, zero_division=0),
        "Precision": precision_score(y_true, y_pred, zero_division=0),
        "Recall (Sensitivity)": recall_score(y_true, y_pred, zero_division=0),
        "Specificity": tn / (tn + fp) if (tn + fp) > 0 else np.nan,
        "Negative Predictive Value": tn / (tn + fn) if (tn + fn) > 0 else np.nan,
        "MCC": matthews_corrcoef(y_true, y_pred),
        "Cohen Kappa": cohen_kappa_score(y_true, y_pred),
        "ROC-AUC": roc_auc_score(y_true, probs) if has_both_classes else np.nan,
        "PR-AUC": average_precision_score(y_true, probs) if has_both_classes else np.nan,
    }


def _save_roc_curve(y_true, probs, set_name, out_dir):
    """Plot the ROC curve (legend shows its AUC), save it as a PNG and return the path."""
    fpr, tpr, _ = roc_curve(y_true, probs)
    roc_auc = auc(fpr, tpr)
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(fpr, tpr, label=f"AUC = {roc_auc:.3f}", linewidth=2)
    ax.plot([0, 1], [0, 1], linestyle="--", color="gray", linewidth=1)
    ax.set(
        xlabel="False Positive Rate",
        ylabel="True Positive Rate",
        title=f"ROC Curve ({set_name})"
    )
    ax.legend(loc="lower right")
    fig.tight_layout()
    roc_path = os.path.join(out_dir, f"{set_name}_roc_curve.png")
    fig.savefig(roc_path, dpi=300)
    plt.close(fig)
    return roc_path


def _save_pr_curve(y_true, probs, set_name, out_dir):
    """Plot the precision-recall curve (legend shows average precision), save it and return the path."""
    precision, recall, _ = precision_recall_curve(y_true, probs)
    # Average precision — the same quantity written as "PR-AUC" to the metrics CSV, so
    # the legend label "AP" and the reported number agree.
    ap = average_precision_score(y_true, probs)
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(recall, precision, label=f"AP = {ap:.3f}", linewidth=2)
    ax.set(
        xlabel="Recall",
        ylabel="Precision",
        title=f"Precision–Recall Curve ({set_name})"
    )
    ax.legend(loc="upper right")
    fig.tight_layout()
    pr_path = os.path.join(out_dir, f"{set_name}_pr_curve.png")
    fig.savefig(pr_path, dpi=300)
    plt.close(fig)
    return pr_path


def save_results_and_plots(
    y_true: np.ndarray,
    probs: np.ndarray,
    set_name: str,
    threshold: float,
    out_dir: str = RESULTS_DIR
):
    """Write metrics, per-sample predictions and ROC / PR curves for one evaluation run.

    Args:
        y_true: ground-truth labels (0/1).
        probs: positive-class probabilities aligned with ``y_true``.
        set_name: prefix for every output file, e.g. ``"test_youden"``.
        threshold: a sample is predicted positive when ``probs >= threshold``.
        out_dir: output directory; created if missing.

    Writes ``{set_name}_metrics.csv`` and ``{set_name}_predictions.csv`` to ``out_dir``,
    plus ``{set_name}_roc_curve.png`` / ``{set_name}_pr_curve.png`` when ``y_true``
    contains both classes (the curves are undefined otherwise).
    """
    os.makedirs(out_dir, exist_ok=True)
    y_pred = (probs >= threshold).astype(int)

    metrics = _binary_metrics(y_true, probs, y_pred, threshold)
    metrics_path = os.path.join(out_dir, f"{set_name}_metrics.csv")
    pd.DataFrame([metrics]).to_csv(metrics_path, index=False)

    results_df = pd.DataFrame({
        "true_label": y_true,
        "probability": probs,
        "pred_label": y_pred
    })
    results_path = os.path.join(out_dir, f"{set_name}_predictions.csv")
    results_df.to_csv(results_path, index=False)

    logger.info(f"[Saved] {set_name} metrics -> {metrics_path}")
    logger.info(f"[Saved] {set_name} predictions -> {results_path}")

    if np.unique(y_true).size < 2:
        logger.warning(f"{set_name}: y_true contains a single class; skipping ROC/PR curves")
        return

    roc_path = _save_roc_curve(y_true, probs, set_name, out_dir)
    pr_path = _save_pr_curve(y_true, probs, set_name, out_dir)
    logger.info(f"[Saved] {set_name} ROC curve -> {roc_path}")
    logger.info(f"[Saved] {set_name} PR curve -> {pr_path}")


def main():
    """Evaluate the saved LoRA adapter on the test split.

    Loads the base model plus the adapter from ADAPTER_DIR, tokenizes
    ``val_clean.csv`` and ``test_clean.csv`` under DATA_DIR, picks a decision
    threshold on the validation set by maximising Youden's J (TPR - FPR), then
    writes test metrics, predictions and curves for that threshold
    (``test_youden``) and for 0.5 (``test_standard``).
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    base_model = MambaForSequenceClassification.from_pretrained(
        MODEL_NAME, num_labels=2, use_cache=False
    )
    model = PeftModel.from_pretrained(base_model, ADAPTER_DIR).to(device)

    raw_dataset = DatasetDict({
        "test": Dataset.from_pandas(pd.read_csv(DATA_DIR / "test_clean.csv")),
        "val": Dataset.from_pandas(pd.read_csv(DATA_DIR / "val_clean.csv")),
    })

    def preprocess(examples):
        """Tokenize to fixed length MAX_LENGTH and attach integer ``labels``."""
        texts = [str(x) for x in examples["text"]]
        encodings = tokenizer(
            texts,
            truncation=True,
            padding="max_length",
            max_length=MAX_LENGTH
        )
        encodings["labels"] = [int(x) for x in examples["label"]]
        return encodings

    # Drop every original column (not only text/label): numeric leftovers would be
    # collated into tensors and passed straight into model(**inputs).
    dataset = DatasetDict({
        split: ds.map(preprocess, batched=True, remove_columns=ds.column_names)
        for split, ds in raw_dataset.items()
    })

    y_val = np.array(dataset["val"]["labels"])
    y_test = np.array(dataset["test"]["labels"])

    # Threshold selection uses the validation split only, never the test split.
    probs_val = compute_probs_from_hf_dataset(dataset["val"], model, tokenizer, device)
    fpr, tpr, thresholds = roc_curve(y_val, probs_val)
    # NOTE(review): scikit-learn may place np.inf at thresholds[0]; it is chosen only if
    # no threshold reaches J > 0, in which case every test prediction is negative.
    best_thresh = thresholds[np.argmax(tpr - fpr)]
    logger.info(f"Optimal threshold from validation: {best_thresh:.3f}")

    probs_test = compute_probs_from_hf_dataset(dataset["test"], model, tokenizer, device)

    save_results_and_plots(y_test, probs_test, set_name="test_youden", threshold=best_thresh)
    save_results_and_plots(y_test, probs_test, set_name="test_standard", threshold=0.5)

In [None]:
test()
!zip -r results.zip /content/results