In [None]:
!pip install transformers accelerate peft datasets wandb bitsandbytes -q

import os, torch, torch.nn as nn
import pandas as pd, numpy as np
from datasets import Dataset, DatasetDict
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model
from sklearn.metrics import accuracy_score, f1_score
import wandb


In [None]:
from typing import List

In [None]:
# Authenticate with Weights & Biases, then set the project name and ask W&B to log
# the final model ("end") for the Trainer run configured later in this notebook.
wandb.login()
os.environ.update(WANDB_PROJECT="Banking77_MoE", WANDB_LOG_MODEL="end")

In [None]:
TRAIN_CSV = "/content/drive/MyDrive/Banking77_Project/data/train.csv"
TEST_CSV  = "/content/drive/MyDrive/Banking77_Project/data/test.csv"

# Known column names, checked in priority order.
TEXT_COL_CANDIDATES = ["text", "utterance", "sentence", "query", "question", "content"]
LABEL_COL_CANDIDATES = ["label", "category", "intent", "target"]


def _detect_text_column(df: pd.DataFrame) -> str:
    """Return the name of the column holding the input text.

    Prefers a known text column name. Otherwise falls back to the first object-dtype
    column that is NOT a known label name, so a string label column such as
    ``category`` is never mistaken for the text.

    Raises:
        ValueError: if no candidate column exists.
    """
    for c in TEXT_COL_CANDIDATES:
        if c in df.columns:
            return c
    for c in df.columns:
        if df[c].dtype == object and c not in LABEL_COL_CANDIDATES:
            return c
    raise ValueError(f"No text-like column found. Columns: {df.columns}")


def _detect_label_column(df: pd.DataFrame) -> str:
    """Return the label column: a known label name, else the first integer-dtype column.

    Raises:
        ValueError: if no candidate column exists.
    """
    for c in LABEL_COL_CANDIDATES:
        if c in df.columns:
            return c
    numeric_cols = [c for c in df.columns if np.issubdtype(df[c].dtype, np.integer)]
    if numeric_cols:
        return numeric_cols[0]
    raise ValueError(f"No label-like column found. Columns: {df.columns}")


def _build_label2id(train_labels: pd.Series, test_labels: pd.Series) -> dict:
    """Map every label seen in train or test to a contiguous id ``0..K-1`` (sorted order).

    String labels are keyed by their original value. Numeric labels are keyed by ``int``.
    When they already are ``0..K-1`` this is the identity mapping. Otherwise (e.g. 1-based
    or gapped ids) they are re-indexed, so every id is a valid class index for a
    ``num_labels``-way classifier and ``CrossEntropyLoss``.
    """
    all_labels = pd.concat([train_labels, test_labels]).unique()
    if train_labels.dtype == object or test_labels.dtype == object:
        keys = sorted(all_labels)
    else:
        keys = sorted(int(l) for l in all_labels)
    return {lab: i for i, lab in enumerate(keys)}


train_df = pd.read_csv(TRAIN_CSV)
test_df  = pd.read_csv(TEST_CSV)

# Normalize the text column to 'text' (detected on train, required in test too).
text_col = _detect_text_column(train_df)
if text_col not in test_df.columns:
    raise ValueError(f"Text column '{text_col}' missing from test CSV. Columns: {test_df.columns}")
train_df = train_df.rename(columns={text_col: "text"})
test_df  = test_df.rename(columns={text_col: "text"})

# Normalize the label column to 'label' (detection runs after the text rename, as before).
label_col = _detect_label_column(train_df)
if label_col not in test_df.columns:
    raise ValueError(f"Label column '{label_col}' missing from test CSV. Columns: {test_df.columns}")
train_df = train_df.rename(columns={label_col: "label"})
test_df  = test_df.rename(columns={label_col: "label"})

# Encode labels as contiguous ints shared by both splits.
label2id = _build_label2id(train_df["label"], test_df["label"])
train_df["label"] = train_df["label"].map(label2id).astype(int)
test_df["label"]  = test_df["label"].map(label2id).astype(int)

id2label = {v: k for k, v in label2id.items()}
num_labels = len(label2id)
print(f"Detected text column: '{text_col}', label column mapped -> 'label', num_labels = {num_labels}")


In [None]:
# Convert each cleaned split to a Hugging Face Dataset. reset_index(drop=True) gives
# every frame a clean 0..n-1 index before conversion.
dataset = DatasetDict({
    split_name: Dataset.from_pandas(frame.reset_index(drop=True))
    for split_name, frame in (("train", train_df), ("test", test_df))
})


In [None]:
# Hub ids of the backbones that become LoRA experts: two BERT-family, two GPT-2-family.
expert_model_names = ["bert-base-uncased", "distilbert-base-uncased", "gpt2", "distilgpt2"]

In [None]:
# NOTE(review): this single BERT WordPiece tokenizer feeds ALL experts, including the
# GPT-2 / DistilGPT-2 experts, which were pretrained on a different vocabulary. BERT's
# vocabulary is smaller than GPT-2's, so the ids index GPT-2's embedding table without
# crashing. However, the GPT-2 experts receive ids that mean different tokens to them.
# A proper fix tokenizes once per expert family and feeds each expert its own ids.
TOKENIZER_NAME = "bert-base-uncased"
MAX_LENGTH = 64  # fixed sequence length: longer texts are truncated, shorter ones padded

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME, use_fast=True)
if tokenizer.pad_token is None:
    # ensure pad token exists (some tokenizers like GPT2 lack it). A newly added token
    # would also require resizing each expert's input embeddings.
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    print("Added pad token to tokenizer")


def tokenize_fn(examples, max_length=MAX_LENGTH):
    """Tokenize a batch's ``text`` into fixed-length (truncated / max_length-padded) ids."""
    return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=max_length)


tokenized = dataset.map(tokenize_fn, batched=True, fn_kwargs={"max_length": MAX_LENGTH})

# Trainer expects the target column to be called 'labels'.
if "label" in tokenized["train"].column_names:
    tokenized = tokenized.rename_column("label", "labels")

# Actually drop everything else (raw text, token_type_ids, extra CSV columns).
# `Dataset.map` only adds/overwrites columns, so returning a subset from map does NOT
# remove anything. Removal is done per split because splits may differ in extra columns.
cols_to_keep = ["input_ids", "attention_mask", "labels"]
tokenized = DatasetDict({
    split: ds.remove_columns([c for c in ds.column_names if c not in cols_to_keep])
    for split, ds in tokenized.items()
})


In [None]:
def auto_detect_target_modules(model: "AutoModelForSequenceClassification", model_name_hint: str) -> List[str]:
    """Pick LoRA ``target_modules`` entries that actually exist in ``model``.

    The architecture family is chosen from ``model_name_hint`` (case-insensitive).
    Among all family keys contained in the hint, the LONGEST one wins. So
    "distilbert-base-uncased" resolves to ``distilbert`` rather than ``bert``, and
    "distilgpt2" to ``distilgpt2`` rather than ``gpt2``. If no family key matches,
    the candidates of every family are tried.

    A candidate is kept only if it matches some module name the way PEFT matches list
    entries of ``target_modules``: the full module name equals the candidate or ends
    with ``"." + candidate``. A plain substring test would accept e.g. "query" for a
    module named "enc.query_proj", which PEFT would then fail to find.

    Args:
        model: the base model; only ``model.named_modules()`` is used.
        model_name_hint: hub id / name used to guess the family.

    Returns:
        De-duplicated candidates in priority order; empty list if nothing matched.
    """
    all_module_names = [n for n, _ in model.named_modules()]

    def _matches(sub: str) -> bool:
        # Same rule PEFT applies to a list of target module names.
        return any(mn == sub or mn.endswith("." + sub) for mn in all_module_names)

    # candidate names per family (expandable)
    families = {
        "bert": ["attention.self.query", "attention.self.key", "attention.self.value", "attention.output.dense", "dense", "query", "key", "value"],
        "distilbert": ["attention.q_lin", "attention.k_lin", "attention.v_lin", "attention.out_lin", "q_lin", "k_lin", "v_lin", "out_lin"],
        "gpt2": ["attn.c_attn", "attn.c_proj", "c_attn", "c_proj"],
        "distilgpt2": ["attn.c_attn", "attn.c_proj", "c_attn", "c_proj"]
    }

    hint = model_name_hint.lower()
    matching_families = [k for k in families if k in hint]
    if matching_families:
        cand_subs = families[max(matching_families, key=len)]
    else:
        # no family recognised: try every family's candidates
        cand_subs = [sub for subs in families.values() for sub in subs]

    # dict.fromkeys de-duplicates while keeping priority order
    found = [sub for sub in dict.fromkeys(cand_subs) if _matches(sub)]

    if not found:
        # final fallback: broad names common across architectures
        broad = ["query", "key", "value", "dense", "c_attn", "c_proj", "q_lin", "k_lin", "v_lin", "out_lin"]
        found = [sub for sub in broad if _matches(sub)]
    return found


In [None]:
def make_peft_model_auto(model_name: str, num_labels: int, pad_token_id=None):
    """Load ``model_name`` as a sequence classifier and wrap it with a LoRA adapter.

    Args:
        model_name: Hugging Face hub id of the base model.
        num_labels: number of output classes for the classification head.
        pad_token_id: pad id to write into the model config when the config has none.
            Defaults to the notebook-level ``tokenizer``'s pad id, i.e. the tokenizer
            that produces the training inputs.

    Returns:
        The PEFT-wrapped model (LoRA r=8, alpha=32, dropout=0.1, task_type SEQ_CLS).

    Raises:
        ValueError: if no LoRA target modules can be detected for the model.
    """
    print(f"\n--> Loading base model for {model_name} ...")
    base = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)

    # Configs without a pad id (e.g. GPT-2 / DistilGPT-2) get the input tokenizer's pad id;
    # GPT-2's classification head uses it to locate the last non-padding token.
    if base.config.pad_token_id is None:
        base.config.pad_token_id = tokenizer.pad_token_id if pad_token_id is None else pad_token_id

    targets = auto_detect_target_modules(base, model_name)
    if not targets:
        sample_mods = [n for n, _ in base.named_modules()][:40]
        raise ValueError(f"No suitable LoRA target modules found for {model_name}. Sample modules:\n{sample_mods}")
    print(f"Detected target_modules for {model_name}: {targets}")

    peft_config = LoraConfig(
        r=8,
        lora_alpha=32,
        lora_dropout=0.1,
        bias="none",
        task_type="SEQ_CLS",
        target_modules=targets
    )
    peft_model = get_peft_model(base, peft_config)
    print(f"--> PEFT model ready for {model_name}")
    return peft_model


In [None]:
# Build one LoRA-wrapped expert per backbone. A failure is reported and re-raised so the
# notebook stops instead of continuing with a partial expert set.
experts = []
for expert_name in expert_model_names:
    try:
        expert = make_peft_model_auto(expert_name, num_labels=num_labels)
    except Exception as err:
        print(f"Error building PEFT expert for {expert_name}: {err}")
        raise
    experts.append(expert)

if not experts:
    raise RuntimeError("No experts were successfully created.")

print(f"Created {len(experts)} experts.")


In [None]:
class HybridMoE(nn.Module):
    """Mixture of experts that mixes the logits of several sequence classifiers.

    Every expert runs on every batch. A small MLP router scores the experts from a
    pooled, expert-averaged final hidden state. The output logits are the
    softmax-weighted sum of the top-k experts' logits, chosen per example.

    Args:
        experts: classifiers called as ``expert(input_ids=..., attention_mask=...,
            output_hidden_states=True)``. Each must return an object with ``.logits``
            (B, num_labels) and ``.hidden_states`` whose last element is
            (B, seq_len, hidden_dim). All experts must share num_labels and hidden_dim,
            because their outputs are stacked.
        hidden_dim: width of the experts' final hidden states (router input size).
        top_k: experts mixed per example; clipped to the number of experts.
        pooling: how a sequence becomes one router feature vector.
            ``"mean"`` (default) is the attention-mask-weighted token mean. It is
            meaningful for both encoders and causal decoders such as GPT-2, whose
            position-0 state only sees the first token.
            ``"first"`` is the token-0 hidden state (the previous behaviour).
        router_hidden_dim: width of the router's hidden layer.

    Raises:
        ValueError: on an empty expert list, ``top_k < 1`` or an unknown ``pooling``.
    """

    def __init__(self, experts: List[nn.Module], hidden_dim: int = 768, top_k: int = 2,
                 pooling: str = "mean", router_hidden_dim: int = 256):
        super().__init__()
        if len(experts) == 0:
            raise ValueError("HybridMoE needs at least one expert")
        if top_k < 1:
            raise ValueError(f"top_k must be >= 1, got {top_k}")
        if pooling not in ("mean", "first"):
            raise ValueError(f"pooling must be 'mean' or 'first', got {pooling!r}")
        self.experts = nn.ModuleList(experts)
        self.num_experts = len(self.experts)
        self.top_k = top_k
        self.pooling = pooling
        # router: pooled expert-averaged hidden (B, hidden_dim) -> one score per expert
        self.router = nn.Sequential(
            nn.Linear(hidden_dim, router_hidden_dim),
            nn.ReLU(),
            nn.Linear(router_hidden_dim, self.num_experts)
        )

    @staticmethod
    def _pool_hidden(last_hidden, attention_mask=None, pooling="mean"):
        """Reduce (B, T, H) hidden states to (B, H), ignoring padded positions for ``"mean"``."""
        if pooling == "first":
            return last_hidden[:, 0, :]
        if pooling != "mean":
            raise ValueError(f"pooling must be 'mean' or 'first', got {pooling!r}")
        if attention_mask is None:
            return last_hidden.mean(dim=1)
        mask = attention_mask.unsqueeze(-1).to(last_hidden.dtype)  # (B, T, 1)
        summed = (last_hidden * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1.0)  # guard all-padding rows against /0
        return summed / counts

    def forward(self, input_ids=None, attention_mask=None, labels=None):
        """Run all experts, route, and mix the top-k experts' logits.

        The experts and router must already be on the inputs' device (e.g. via
        ``moe_model.to(device)``).

        Returns:
            dict with ``"logits"`` (B, num_labels) and ``"loss"``. The loss is the
            cross-entropy against ``labels``, or ``None`` when no labels are given.
        """
        expert_logits = []
        pooled = []
        for expert in self.experts:
            outputs = expert(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
            expert_logits.append(outputs.logits)
            # detached: the router learns only through the gate weights; its input
            # features do not back-propagate into the expert backbones
            pooled.append(self._pool_hidden(outputs.hidden_states[-1], attention_mask, self.pooling).detach())

        router_in = torch.stack(pooled, dim=1).mean(dim=1)  # (B, hidden_dim)
        router_scores = self.router(router_in)  # (B, num_experts)

        # top-k gating, renormalised over the selected experts only
        k = min(self.top_k, self.num_experts)
        topk_vals, topk_idx = torch.topk(router_scores, k, dim=-1)  # (B, k)
        topk_weights = torch.softmax(topk_vals, dim=-1)  # (B, k)

        # Vectorised weighted combination: gather each row's selected experts' logits.
        stacked = torch.stack(expert_logits, dim=1)  # (B, num_experts, num_labels)
        gather_idx = topk_idx.unsqueeze(-1).expand(-1, -1, stacked.size(-1))  # (B, k, num_labels)
        selected = stacked.gather(1, gather_idx)  # (B, k, num_labels)
        batch_logits = (topk_weights.unsqueeze(-1).to(selected.dtype) * selected).sum(dim=1)

        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(batch_logits, labels)
        return {"loss": loss, "logits": batch_logits}

# Assemble the mixture: router input width 768 (must equal each expert's hidden size),
# and each example mixes its 2 highest-scoring experts.
moe_model = HybridMoE(experts=experts, top_k=2, hidden_dim=768)
print("HybridMoE created with", len(experts), "experts.")


In [None]:
def compute_metrics(eval_pred):
    """Compute accuracy and support-weighted F1 for a Trainer evaluation pass.

    Args:
        eval_pred: an ``EvalPrediction`` or a plain ``(predictions, label_ids)`` tuple.
            Positional indexing is used instead of 2-way unpacking, so this still works
            when the Trainer appends the model inputs as a third element.

    Returns:
        dict with ``"accuracy"`` and ``"f1"`` (weighted average over classes).
    """
    logits, labels = eval_pred[0], eval_pred[1]
    # A model returning several non-loss outputs yields a tuple; class scores come first.
    if isinstance(logits, (tuple, list)):
        logits = logits[0]
    preds = np.argmax(logits, axis=-1)
    return {
        "accuracy": accuracy_score(labels, preds),
        "f1": f1_score(labels, preds, average="weighted")
    }

In [None]:
# fp16 mixed precision requires a CUDA device (TrainingArguments rejects it otherwise);
# fall back to fp32 so the notebook still runs, slowly, on a CPU-only runtime.
USE_FP16 = torch.cuda.is_available()

training_args = TrainingArguments(
    output_dir="/content/drive/MyDrive/Banking77_Project/outputs_moe4_peft_auto",
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-4,
    per_device_train_batch_size=4,   # conservative for T4 (reduce if OOM)
    per_device_eval_batch_size=4,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir="./logs_moe4_peft",
    report_to="wandb",    # use "none" to disable W&B logging
    run_name="HybridMoE_4Experts_PEFT_auto",
    fp16=USE_FP16
)

trainer = Trainer(
    model=moe_model,
    args=training_args,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["test"],
    processing_class=tokenizer,   # replaces the deprecated `tokenizer=` argument (transformers >= 4.46)
    compute_metrics=compute_metrics
)

In [None]:
# Fine-tune the mixture, then score it on the held-out split; the metrics dict is the cell output.
train_output = trainer.train()
eval_metrics = trainer.evaluate()
eval_metrics