In [22]:
import os, json, math, random, gc


os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
os.environ["HF_HOME"] = "/root/autodl-tmp/cache"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["PYTHONDONTWRITEBYTECODE"] = "1"
os.environ["TMPDIR"] = "/root/autodl-tmp/tmp"
os.environ["TORCH_HOME"] = "/root/autodl-tmp/torch"



from pathlib import Path
from typing import List, Dict, Any, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import Dataset, DataLoader

from tqdm import tqdm
from sentence_transformers import SentenceTransformer

# --- Numeric / device configuration -----------------------------------------
# Allow TF32 tensor-core matmuls. Older torch builds lack the precision API, so
# only the backend flag is set there.
torch.backends.cuda.matmul.allow_tf32 = True
if hasattr(torch, "set_float32_matmul_precision"):
    torch.set_float32_matmul_precision("high")

# Single-device setup: first CUDA GPU when available, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
dtype = torch.float32  # embeddings and router weights are kept in fp32

print(f"Device: {device} | dtype: {dtype}")

# The Hugging Face endpoint/cache environment variables are set at the top of
# this cell, before sentence_transformers is imported. They must be set in your
# shell or at the beginning of the notebook.


Device: cuda:0 | dtype: torch.float32


In [2]:
import sys
# Show the interpreter backing this kernel, to confirm the intended environment is active.
print(sys.executable)

/root/miniconda3/bin/python


In [3]:
# Environment: select the `kyro` environment by choosing its Jupyter kernel (or by
# starting Jupyter from a shell where it is activated). Running `!conda activate kyro`
# inside a cell only affects a throwaway subshell, not this kernel's interpreter;
# `sys.executable` shows which interpreter is actually in use.

# Root of the LaViC data, relative to the notebook's working directory.
# Expected layout (see LaViCGateRawCategory): data/<category>/<split>.jsonl
LAVIC_DATA_DIR = Path("data")

# Adapter/category names. After the mapping step below, position i in
# `gate_categories` is the class label i used by the caches and routers.
gate_categories = [
    "amazon_home",
    "amazon_fashion",
    "all_beauty",
    "Appliances",
    "Arts_Crafts_and_Sewing",
    "Automotive",
    "Baby_Products",
    "Books",
    "CDs_and_Vinyl",
    "Cell_Phones_and_Accessories",
    "Digital_Music",
    "Electronics",
    "Grocery_and_Gourmet_Food",
    "Handmade_Products",
    "Health",
    "Industrial_and_Scientific",
    "Kindle_Store",
    "Movies_and_TV",
    "Musical_Instruments",
    "Office_Products",
    "Patio_Lawn_and_Garden",
    "Pet_Supplies",
    "Software",
    "Sports_and_Outdoors",
    "Toys_and_Games",
    "Video_Games",
]

# Fixed category-to-index mapping persisted to disk to avoid order dependence.
# If the file exists, it takes precedence over the list above. Differences are
# reported instead of being silently ignored.
GATE_MAP_PATH = Path("gate_label_mapping.json")
if GATE_MAP_PATH.exists():
    with GATE_MAP_PATH.open("r", encoding="utf-8") as f:
        _map = json.load(f)
    _indices = sorted(map(int, _map.keys()))
    # Labels are positions, so the saved indices must be exactly 0..N-1.
    if _indices != list(range(len(_indices))):
        raise ValueError(f"{GATE_MAP_PATH} must map contiguous indices 0..N-1, got {_indices}")
    _saved = [_map[str(i)] for i in _indices]
    _diff = sorted(set(_saved) ^ set(gate_categories))
    if _diff:
        print(f"[WARN] {GATE_MAP_PATH} disagrees with the configured categories on {_diff}; using the saved mapping.")
    gate_categories = _saved
else:
    # Python's default sort is case-sensitive: capitalized names sort before lowercase ones.
    gate_categories = sorted(gate_categories)
    with GATE_MAP_PATH.open("w", encoding="utf-8") as f:
        json.dump({i: c for i, c in enumerate(gate_categories)}, f, ensure_ascii=False, indent=2)



In [4]:
import sentence_transformers
# Record the installed sentence-transformers version and where it is imported from.
print(sentence_transformers.__version__)
print(sentence_transformers.__file__)

5.2.0
/root/miniconda3/lib/python3.12/site-packages/sentence_transformers/__init__.py


In [5]:
# Model and training hyper-parameters
ST_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
MAX_LENGTH = 256  # encoder max_seq_length (tokens)
epochs = 10
batch_size = 128 if torch.cuda.is_available() else 64
lr = 1e-3
weight_decay = 0.0
grad_clip = 1.0
seed = 42
random.seed(seed)
torch.manual_seed(seed)

# Cache root: reuses HF_HOME (set in the first cell). Per-split embedding caches
# are written beneath it (CACHE_DIR/<split>/<category>.pt).
CACHE_DIR = Path(os.environ.get("HF_HOME", "/root/autodl-tmp/cache"))
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Sentence encoder used to embed contexts. It is placed on the same `device`
# selected in the first cell, which keeps it consistent with every other
# `.to(device)` / `device=device` in this notebook.
st_model = SentenceTransformer(ST_MODEL_NAME, device=device)
st_model.max_seq_length = MAX_LENGTH


In [6]:
"""
os.environ["HF_HUB_OFFLINE"] = "1"
os.environ["TRANSFORMERS_OFFLINE"] = "1"

# Model and data setup
ST_MODEL_NAME = "/root/autodl-tmp/cache/hub/models--sentence-transformers--all-minilm-l6-v2/snapshots/2424fdd47412fccc66d91719126b420e9fbd7065"
MAX_LENGTH = 256
epochs = 10
batch_size = 128 if torch.cuda.is_available() else 64
lr = 1e-3
weight_decay = 0.0
grad_clip = 1.0
seed = 42
random.seed(seed); torch.manual_seed(seed)

# Cache directory setup
CACHE_DIR = Path("/root/autodl-tmp/cache/")
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Sentence Transformer setup for model inference
st_model = SentenceTransformer(
    ST_MODEL_NAME,
    device="cuda" if torch.cuda.is_available() else "cpu"
)

st_model.max_seq_length = MAX_LENGTH
"""

'\nos.environ["HF_HUB_OFFLINE"] = "1"\nos.environ["TRANSFORMERS_OFFLINE"] = "1"\n\n# Model and data setup\nST_MODEL_NAME = "/root/autodl-tmp/cache/hub/models--sentence-transformers--all-minilm-l6-v2/snapshots/2424fdd47412fccc66d91719126b420e9fbd7065"\nMAX_LENGTH = 256\nepochs = 10\nbatch_size = 128 if torch.cuda.is_available() else 64\nlr = 1e-3\nweight_decay = 0.0\ngrad_clip = 1.0\nseed = 42\nrandom.seed(seed); torch.manual_seed(seed)\n\n# Cache directory setup\nCACHE_DIR = Path("/root/autodl-tmp/cache/")\nCACHE_DIR.mkdir(parents=True, exist_ok=True)\n\n# Sentence Transformer setup for model inference\nst_model = SentenceTransformer(\n    ST_MODEL_NAME,\n    device="cuda" if torch.cuda.is_available() else "cpu"\n)\n\nst_model.max_seq_length = MAX_LENGTH\n'

In [7]:
print("Loaded local ST model successfully.")
print(st_model)

emb = st_model.encode(["hello world"])
print("Embedding shape:", emb.shape)

Loaded local ST model successfully.
SentenceTransformer(
  (0): Transformer({'max_seq_length': 256, 'do_lower_case': False, 'architecture': 'BertModel'})
  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
  (2): Normalize()
)
Embedding shape: (1, 384)


In [8]:
class LaViCGateRawCategory(Dataset):
    """Raw context strings for ONE category and ONE split.

    Reads ``<data_root>/<category>/<split>.jsonl``. Each non-blank line should be
    a JSON object, and its ``"context"`` string (whitespace-stripped, non-empty)
    becomes one sample. A missing file yields an empty dataset with a warning.
    Malformed JSON lines, non-object records and non-string contexts are skipped
    and counted in a single warning. Records without a ``"context"`` key, or with
    a blank one, are skipped silently.

    Args:
        data_root: directory holding one sub-directory per category.
        category: category sub-directory name.
        split: split file stem, e.g. "train", "valid", "test".
    """
    def __init__(self, data_root: Path, category: str, split: str):
        self.samples = []
        fp = data_root / category / f"{split}.jsonl"

        if not fp.exists():
            print(f"[WARN] Missing {fp}. Skipping category {category}.")
            return

        n_bad = 0
        with fp.open("r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                # Only JSON decoding errors are expected here; a bare `except:` would
                # also swallow KeyboardInterrupt and genuine bugs.
                try:
                    rec = json.loads(line)
                except json.JSONDecodeError:
                    n_bad += 1
                    continue
                if not isinstance(rec, dict):
                    n_bad += 1
                    continue
                text = rec.get("context", "")
                if not isinstance(text, str):  # e.g. null or numeric context
                    n_bad += 1
                    continue
                text = text.strip()
                if text:
                    self.samples.append(text)

        if n_bad:
            print(f"[WARN] Skipped {n_bad} malformed record(s) in {fp}.")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


def _category_cache_ok(cache_path: Path, expected_index: int) -> bool:
    """Return True if `cache_path` loads and records `category_index == expected_index`.

    Rejects (a) unreadable or truncated files, e.g. from a save interrupted by a
    kernel crash, and (b) caches written under a different category-to-index mapping.
    """
    try:
        blob = torch.load(cache_path, map_location="cpu")
    except Exception as exc:  # deliberate best-effort: any load failure just means "rebuild"
        print(f"[CACHE] Unreadable {cache_path} ({exc!r}) — rebuilding.")
        return False
    found = blob.get("category_index") if isinstance(blob, dict) else None
    if found != expected_index:
        print(f"[CACHE] {cache_path} has category_index={found}, expected {expected_index} — rebuilding.")
        return False
    return True


@torch.no_grad()
def build_cache_category(split: str, st_model: SentenceTransformer, batch: int = 64):
    """
    Build per-category embedding caches for one split:
      CACHE_DIR/<split>/<category>.pt  ->  {"X", "y", "category", "category_index"}
    e.g. cache/train/Automotive.pt, cache/valid/Automotive.pt, ...

    The label of category `cat` is its position in the global `gate_categories`.
    An existing cache file is reused only if it loads cleanly and carries that
    same index. Otherwise it is rebuilt. Texts are encoded in chunks of `batch`
    and moved to CPU chunk by chunk to bound GPU memory.

    NOTE(review): the cache key ignores ST_MODEL_NAME / MAX_LENGTH. Delete the
    split directories after changing the encoder.

    Args:
        split: split name ("train"/"valid"/"test"). Reads LAVIC_DATA_DIR/<cat>/<split>.jsonl.
        st_model: sentence encoder used to embed the context strings.
        batch: number of texts passed to each `st_model.encode` call.
    """
    split_dir = CACHE_DIR / split
    split_dir.mkdir(parents=True, exist_ok=True)

    for idx, cat in enumerate(gate_categories):
        cache_path = split_dir / f"{cat}.pt"

        # Reuse only caches that load and match the current label index.
        if cache_path.exists() and _category_cache_ok(cache_path, idx):
            print(f"[CACHE] Found: {cache_path} — skipping.")
            continue

        print(f"\n[EMBED] Building cache for {split}/{cat} ...")

        ds = LaViCGateRawCategory(LAVIC_DATA_DIR, cat, split)
        if len(ds) == 0:
            print(f"[WARN] No samples for category {cat} ({split}). Skipping.")
            # An invalid cache that failed the check above must not be picked up
            # later by the *.pt glob in the dataset loader.
            if cache_path.exists():
                cache_path.unlink()
            continue

        X_parts = []
        Y_parts = []

        # Use small batches to prevent kernel crash
        for i in tqdm(range(0, len(ds), batch), desc=f"Encoding {cat}"):
            chunk = ds.samples[i:i+batch]

            emb = st_model.encode(
                chunk,
                batch_size=min(32, len(chunk)),   # safer batch size
                convert_to_tensor=True,
                device=device,
                show_progress_bar=False,
            )

            X_parts.append(emb.cpu())
            Y_parts.append(torch.full((emb.size(0),), idx, dtype=torch.long))

        # Merge
        X = torch.cat(X_parts, dim=0)
        y = torch.cat(Y_parts, dim=0)

        out = {
            "X": X,
            "y": y,
            "category": cat,
            "category_index": idx,
        }

        # Write to a temp file, then atomically rename. An interrupted save can
        # never leave a truncated `<cat>.pt` that a later run would treat as valid.
        # The ".tmp" suffix also keeps partial files out of the "*.pt" glob.
        tmp_path = cache_path.with_name(cache_path.name + ".tmp")
        torch.save(out, tmp_path)
        os.replace(tmp_path, cache_path)
        print(f"[CACHE SAVED] {cache_path} | X={tuple(X.shape)} | y={tuple(y.shape)}")

        # free memory
        del X_parts, Y_parts, X, y
        gc.collect()
        torch.cuda.empty_cache()


# (Re)load the encoder on the configured device, then build the per-category
# embedding caches for every split.
st_model = SentenceTransformer(ST_MODEL_NAME, device=device)
st_model.max_seq_length = MAX_LENGTH

for split in ("train", "valid", "test"):
    print(f"\n===== BUILDING SPLIT: {split} =====")
    build_cache_category(split, st_model)


===== BUILDING SPLIT: train =====
[CACHE] Found: /root/autodl-tmp/cache/train/Appliances.pt — skipping.
[CACHE] Found: /root/autodl-tmp/cache/train/Arts_Crafts_and_Sewing.pt — skipping.
[CACHE] Found: /root/autodl-tmp/cache/train/Automotive.pt — skipping.
[CACHE] Found: /root/autodl-tmp/cache/train/Baby_Products.pt — skipping.
[CACHE] Found: /root/autodl-tmp/cache/train/Books.pt — skipping.
[CACHE] Found: /root/autodl-tmp/cache/train/CDs_and_Vinyl.pt — skipping.
[CACHE] Found: /root/autodl-tmp/cache/train/Cell_Phones_and_Accessories.pt — skipping.
[CACHE] Found: /root/autodl-tmp/cache/train/Digital_Music.pt — skipping.
[CACHE] Found: /root/autodl-tmp/cache/train/Electronics.pt — skipping.
[CACHE] Found: /root/autodl-tmp/cache/train/Grocery_and_Gourmet_Food.pt — skipping.
[CACHE] Found: /root/autodl-tmp/cache/train/Handmade_Products.pt — skipping.
[CACHE] Found: /root/autodl-tmp/cache/train/Health.pt — skipping.
[CACHE] Found: /root/autodl-tmp/cache/train/Industrial_and_Scientific.pt —

In [9]:
#@title Router dataset (cached tensors) & loaders
# Superseded loader kept for reference only. It read a single merged
# CACHE_DIR/<split>.pt with a "categories" key. It is wrapped in a string literal
# so it does not run. The per-category-directory version below replaces it.
"""
class TensorGateDataset(Dataset):
    def __init__(self, tensor_path: Path):
        blob = torch.load(tensor_path, map_location="cpu")
        self.X = blob["X"]
        self.y = blob["y"]
        self.categories = blob["categories"]

    def __len__(self): return self.X.size(0)
    def __getitem__(self, i): return self.X[i], self.y[i]

def collate(batch):
    X, y = zip(*batch)
    X = torch.stack(X).to(device=device, dtype=dtype, non_blocking=True)
    y = torch.stack(y).to(device=device, non_blocking=True)
    return X, y

train_ds = TensorGateDataset(CACHE_DIR / "train.pt")
val_ds   = TensorGateDataset(CACHE_DIR / "valid.pt")

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True, collate_fn=collate, drop_last=False)
# num_workers=2 is # of parallel background processes will load data, if we are using server, it can be larger
# pin_memory=True is for PyTorch to allocate batches in page-locked (pinned) host memory, is for gup
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True, collate_fn=collate, drop_last=False)

embed_dim = train_ds.X.size(1)
print(f"Embed dim: {embed_dim} | Train {len(train_ds)} | Valid {len(val_ds)}")
"""

class TensorGateDataset(Dataset):
    """Loads ALL category-level cached tensors in a split directory.

    Each ``*.pt`` file, as written by ``build_cache_category``, holds a dict with
    ``"X"`` [Nc, H] embeddings, ``"y"`` [Nc] labels, ``"category"`` (str) and
    ``"category_index"`` (int). Files are read in filename order and concatenated
    into ``self.X`` / ``self.y``.
    """
    def __init__(self, split_dir: Path, categories=None):
        """
        Args:
            split_dir: directory with one ``<category>.pt`` per category, e.g.
                /root/autodl-tmp/cache/train or /root/autodl-tmp/cache/valid.
            categories: optional list whose index i names label i (e.g.
                ``gate_categories``). When given, each file's (category,
                category_index) pair is checked against it, so caches built
                under a different mapping fail loudly instead of silently
                training on wrong labels.

        Raises:
            RuntimeError: if ``split_dir`` contains no ``.pt`` files.
            ValueError: if a file's labels disagree with its ``category_index``,
                or its category/index pair disagrees with ``categories``.
        """
        self.X_list = []
        self.y_list = []
        self.category_index_list = []  # per file, in load order (None if not recorded)

        pt_files = sorted(split_dir.glob("*.pt"))
        if len(pt_files) == 0:
            raise RuntimeError(f"No .pt cached files found in {split_dir}")

        print(f"[LOAD] Loading {len(pt_files)} category cache files from {split_dir}")

        for pt in pt_files:
            blob = torch.load(pt, map_location="cpu")

            X = blob["X"]          # [Nc, H]
            y = blob["y"]          # [Nc]
            cat = blob["category"] # str
            cat_idx = blob.get("category_index")

            print(f"  - Loaded {pt.name} | {tuple(X.shape)} | category={cat}")

            if cat_idx is not None:
                # Every label in a per-category file must equal that category's index.
                if y.numel() and not bool((y == cat_idx).all()):
                    raise ValueError(f"{pt.name}: labels do not all equal category_index={cat_idx}")
                if categories is not None and (
                    not 0 <= cat_idx < len(categories) or categories[cat_idx] != cat
                ):
                    raise ValueError(
                        f"{pt.name}: category {cat!r} cached with index {cat_idx}, "
                        f"which does not match the provided category mapping"
                    )
            elif categories is not None:
                raise ValueError(f"{pt.name}: no 'category_index' recorded; cannot validate labels")

            self.X_list.append(X)
            self.y_list.append(y)
            self.category_index_list.append(cat_idx)

        # Concatenate all categories
        self.X = torch.cat(self.X_list, dim=0)
        self.y = torch.cat(self.y_list, dim=0)

        print(f"[MERGED] X={tuple(self.X.shape)}, y={tuple(self.y.shape)}")

    def __len__(self):
        return self.X.size(0)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

# Collate
def collate(batch):
    """Stack (x, y) pairs into batch tensors and move them to the global `device`.

    The features are also cast to the global `dtype`; the labels keep their dtype.
    The transfer happens here, so it runs in whichever process performs collation.
    """
    xs, ys = zip(*batch)
    return (
        torch.stack(xs).to(device=device, dtype=dtype, non_blocking=True),
        torch.stack(ys).to(device=device, non_blocking=True),
    )

train_ds = TensorGateDataset(CACHE_DIR / "train")
val_ds   = TensorGateDataset(CACHE_DIR / "valid")

# Keep num_workers=0 and pin_memory=False with this `collate`: it moves every
# batch onto `device` itself. With worker processes, CUDA would have to be
# initialized inside forked workers, which is unsupported. pin_memory would be
# asked to pin tensors that are already on the GPU, but only CPU tensors can be
# pinned. To use workers/pinning, have `collate` return CPU tensors and call
# `.to(device)` in the training loop instead.
train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    pin_memory=False,
    collate_fn=collate,
)

val_loader = DataLoader(
    val_ds,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
    pin_memory=False,
    collate_fn=collate,
)

# Width of the cached sentence embeddings = router input size.
embed_dim = train_ds.X.size(1)
print(f"Embed dim: {embed_dim} | Train {len(train_ds)} | Valid {len(val_ds)}")

[LOAD] Loading 26 category cache files from /root/autodl-tmp/cache/train
  - Loaded Appliances.pt | (1832, 384) | category=Appliances
  - Loaded Arts_Crafts_and_Sewing.pt | (1013, 384) | category=Arts_Crafts_and_Sewing
  - Loaded Automotive.pt | (1209, 384) | category=Automotive
  - Loaded Baby_Products.pt | (2668, 384) | category=Baby_Products
  - Loaded Books.pt | (3765, 384) | category=Books
  - Loaded CDs_and_Vinyl.pt | (3054, 384) | category=CDs_and_Vinyl
  - Loaded Cell_Phones_and_Accessories.pt | (2860, 384) | category=Cell_Phones_and_Accessories
  - Loaded Digital_Music.pt | (1388, 384) | category=Digital_Music
  - Loaded Electronics.pt | (3196, 384) | category=Electronics
  - Loaded Grocery_and_Gourmet_Food.pt | (2291, 384) | category=Grocery_and_Gourmet_Food
  - Loaded Handmade_Products.pt | (2284, 384) | category=Handmade_Products
  - Loaded Health.pt | (1440, 384) | category=Health
  - Loaded Industrial_and_Scientific.pt | (3182, 384) | category=Industrial_and_Scientific
  

In [10]:
#@title Simple Router & training
class SimpleRouter(nn.Module):
    """Linear gate: one affine map from a sentence embedding to per-adapter logits.

    Args:
        hidden_size: embedding width H of the input.
        num_adapters: number of adapters / classes C.
    """
    def __init__(self, hidden_size, num_adapters):
        super().__init__()
        layers = [nn.Linear(hidden_size, num_adapters)]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map embeddings x of shape [B, H] to logits of shape [B, C]."""
        scores = self.net(x)
        return scores

# Linear router on the configured device/dtype, trained with plain cross-entropy.
router = SimpleRouter(hidden_size=embed_dim, num_adapters=len(gate_categories)).to(device=device, dtype=dtype)

criterion = nn.CrossEntropyLoss()
optimizer = AdamW(router.parameters(), lr=lr, weight_decay=weight_decay)

@torch.no_grad()
def evaluate():
    """Score the global `router` on `val_loader` with whatever `criterion` is bound at call time.

    Returns:
        (mean per-example loss, accuracy). Both are 0.0 when the loader is empty.
    """
    router.eval()
    n_seen = n_correct = 0
    loss_acc = 0.0
    for feats, labels in val_loader:
        scores = router(feats)  # [B, C]
        batch_loss = criterion(scores, labels)
        n_correct += (scores.argmax(dim=-1) == labels).sum().item()
        n_seen += labels.numel()
        loss_acc += batch_loss.item() * labels.size(0)
    denom = max(n_seen, 1)
    return loss_acc / denom, n_correct / denom


best_val_acc, best_state = 0.0, None
for epoch in range(1, epochs + 1):
    router.train()
    total_train_loss = 0.0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}", unit="batch")
    for X, y in pbar:
        logits = router(X)
        loss = criterion(logits, y)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        clip_grad_norm_(router.parameters(), grad_clip)
        optimizer.step()

        # criterion returns the batch mean; weight by batch size for a per-sample average.
        total_train_loss += loss.item() * y.size(0)
        pbar.set_postfix(loss=f"{loss.item():.4f}")

    val_loss, val_acc = evaluate()
    train_loss = total_train_loss / len(train_ds)
    print(f"[Epoch {epoch}] train_loss={train_loss:.4f} | val_loss={val_loss:.4f} | val_acc={val_acc:.4f}")
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        # clone(): when device is CPU, `.cpu()` returns the parameter's own storage.
        # Without a copy, later optimizer steps would keep overwriting this "best" snapshot.
        best_state = {k: v.detach().cpu().clone() for k, v in router.state_dict().items()}

# Save best state (falls back to the final weights if no epoch beat 0.0 accuracy)
CKPT_DIR = Path("/root/autodl-tmp/cache/llmSimpleRouter"); CKPT_DIR.mkdir(parents=True, exist_ok=True)
torch.save(best_state if best_state is not None else router.state_dict(), CKPT_DIR / "gate_router.pt")
# Label map stored next to the checkpoint; json.dump turns the int keys into strings.
with (CKPT_DIR / "gate_label_mapping.json").open("w", encoding="utf-8") as f:
    json.dump({i: cat for i, cat in enumerate(gate_categories)}, f, ensure_ascii=False, indent=2)

print(f"[DONE] Saved router to {CKPT_DIR/'gate_router.pt'}")

Epoch 1/10: 100%|██████████| 729/729 [00:01<00:00, 466.65batch/s, loss=1.3634]


[Epoch 1] train_loss=1.9360 | val_loss=1.2105 | val_acc=0.7641


Epoch 2/10: 100%|██████████| 729/729 [00:01<00:00, 510.19batch/s, loss=0.6909]


[Epoch 2] train_loss=0.9570 | val_loss=0.7788 | val_acc=0.8266


Epoch 3/10: 100%|██████████| 729/729 [00:01<00:00, 544.07batch/s, loss=0.7514]


[Epoch 3] train_loss=0.6976 | val_loss=0.6288 | val_acc=0.8435


Epoch 4/10: 100%|██████████| 729/729 [00:01<00:00, 503.93batch/s, loss=0.5831]


[Epoch 4] train_loss=0.5913 | val_loss=0.5566 | val_acc=0.8537


Epoch 5/10: 100%|██████████| 729/729 [00:01<00:00, 512.73batch/s, loss=0.5208]


[Epoch 5] train_loss=0.5340 | val_loss=0.5141 | val_acc=0.8587


Epoch 6/10: 100%|██████████| 729/729 [00:01<00:00, 523.42batch/s, loss=0.3683]


[Epoch 6] train_loss=0.4981 | val_loss=0.4867 | val_acc=0.8628


Epoch 7/10: 100%|██████████| 729/729 [00:01<00:00, 563.06batch/s, loss=0.4981]


[Epoch 7] train_loss=0.4734 | val_loss=0.4672 | val_acc=0.8664


Epoch 8/10: 100%|██████████| 729/729 [00:01<00:00, 533.57batch/s, loss=0.5355]


[Epoch 8] train_loss=0.4552 | val_loss=0.4529 | val_acc=0.8687


Epoch 9/10: 100%|██████████| 729/729 [00:01<00:00, 519.04batch/s, loss=0.4218]


[Epoch 9] train_loss=0.4412 | val_loss=0.4419 | val_acc=0.8705


Epoch 10/10: 100%|██████████| 729/729 [00:01<00:00, 551.80batch/s, loss=0.6967]


[Epoch 10] train_loss=0.4300 | val_loss=0.4333 | val_acc=0.8728
[DONE] Saved router to /root/autodl-tmp/cache/llmSimpleRouter/gate_router.pt


In [11]:
print("num_adapters =", len(gate_categories))
print("router output size:", router.net[0].out_features)
import torch
ys = train_ds.y.tolist() + val_ds.y.tolist()
print("Unique labels:", sorted(set(ys)))
print("Count:", len(set(ys)))
from collections import Counter

cnt = Counter(ys)
for idx in range(19):
    print(idx, gate_categories[idx], "count =", cnt.get(idx, 0))

num_adapters = 26
router output size: 26
Unique labels: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25]
Count: 26
0 Appliances count = 2061
1 Arts_Crafts_and_Sewing count = 1139
2 Automotive count = 1360
3 Baby_Products count = 3001
4 Books count = 4235
5 CDs_and_Vinyl count = 3435
6 Cell_Phones_and_Accessories count = 3217
7 Digital_Music count = 1561
8 Electronics count = 3595
9 Grocery_and_Gourmet_Food count = 2577
10 Handmade_Products count = 2569
11 Health count = 1620
12 Industrial_and_Scientific count = 3579
13 Kindle_Store count = 15414
14 Movies_and_TV count = 2055
15 Musical_Instruments count = 3775
16 Office_Products count = 2614
17 Patio_Lawn_and_Garden count = 1032
18 Pet_Supplies count = 4696


In [12]:
#@title 5. Evaluation report (fixed & minimal modification)

from sklearn.metrics import classification_report, confusion_matrix

# ==== Fix 1: build the model with the same dtype + device used in training ====
router2 = SimpleRouter(hidden_size=embed_dim, num_adapters=len(gate_categories)).to(device).to(dtype)
state = torch.load(CKPT_DIR / "gate_router.pt", map_location=device)
router2.load_state_dict(state)
router2.eval()

# ==== Fix 2: use the label mapping saved with the checkpoint so class order matches training ====
with open(CKPT_DIR / "gate_label_mapping.json", "r", encoding="utf-8") as f:
    label_map = json.load(f)

# label_map is {index: category_name}; JSON stores the index keys as strings.
# NOTE(review): sorted_indices / sorted_names are not used later in this cell.
sorted_indices = sorted(int(i) for i in label_map.keys())
sorted_names   = [label_map[str(i)] for i in sorted_indices]

# ==== Run validation ====
all_preds, all_labels = [], []

with torch.no_grad():
    for X, y in tqdm(val_loader, desc="Evaluating"):
        logits = router2(X)
        preds = logits.argmax(dim=-1).cpu().tolist()
        all_preds.extend(preds)
        all_labels.extend(y.cpu().tolist())

# Classes that actually occur in the validation labels
present_classes = sorted(set(all_labels))
present_names   = [label_map[str(i)] for i in present_classes]

print("Detected classes in eval:", present_names)

# ==== Fix 3: report only present_classes so classes missing from validation don't raise errors ====
print(classification_report(
    all_labels,
    all_preds,
    labels=present_classes,
    target_names=present_names,
    zero_division=0   # report 0 instead of warning for classes that are never predicted
))

# Confusion matrix restricted to the classes that occur in the labels.
# NOTE(review): predictions of a class outside present_classes are dropped from this matrix.
print(confusion_matrix(
    all_labels,
    all_preds,
    labels=present_classes
))


Evaluating: 100%|██████████| 91/91 [00:00<00:00, 2073.61it/s]

Detected classes in eval: ['Appliances', 'Arts_Crafts_and_Sewing', 'Automotive', 'Baby_Products', 'Books', 'CDs_and_Vinyl', 'Cell_Phones_and_Accessories', 'Digital_Music', 'Electronics', 'Grocery_and_Gourmet_Food', 'Handmade_Products', 'Health', 'Industrial_and_Scientific', 'Kindle_Store', 'Movies_and_TV', 'Musical_Instruments', 'Office_Products', 'Patio_Lawn_and_Garden', 'Pet_Supplies', 'Software', 'Sports_and_Outdoors', 'Toys_and_Games', 'Video_Games', 'all_beauty', 'amazon_fashion', 'amazon_home']
                             precision    recall  f1-score   support

                 Appliances       0.93      0.97      0.95       229
     Arts_Crafts_and_Sewing       0.91      0.78      0.84       126
                 Automotive       0.85      0.79      0.82       151
              Baby_Products       0.90      0.93      0.92       333
                      Books       0.78      0.72      0.75       470
              CDs_and_Vinyl       0.93      0.93      0.93       381
Cell_Phone




In [30]:
#@title Inspect split distribution & make class weights
import torch
import numpy as np

# Per-class training counts. `minlength` gives exactly one entry per gate category,
# even when trailing categories have no training samples, so counts and weights
# always line up with `gate_categories` (and with the router's output size).
counts = torch.bincount(train_ds.y, minlength=len(gate_categories))
class_names = gate_categories
print("Train counts per class:", dict(zip(class_names, counts.tolist())))

# Inverse-frequency weights, normalized to mean 1 over the classes that occur.
# Empty classes are excluded from the computation. A `counts + eps` denominator
# would give them an enormous weight, and the mean normalization would then crush
# every other weight toward 0. They get a neutral weight of 1.0 instead.
present = counts > 0
class_weights = torch.ones(len(counts), dtype=torch.float32)
if present.any():
    inv_freq = counts.sum().float() / counts[present].float()
    class_weights[present] = inv_freq / inv_freq.mean()
missing = [name for name, ok in zip(class_names, present.tolist()) if not ok]
if missing:
    print(f"[WARN] No training samples for {missing}; using neutral weight 1.0.")

# extra boost for underperforming class (if present)
if "amazon_home" in class_names:
    idx = class_names.index("amazon_home")
    class_weights[idx] *= 2.0
    print(f"[DEBUG] Boost amazon_home weight -> {class_weights[idx].item():.4f}")

print("Class weights:", dict(zip(class_names, class_weights.tolist())))


Train counts per class: {'Appliances': 1832, 'Arts_Crafts_and_Sewing': 1013, 'Automotive': 1209, 'Baby_Products': 2668, 'Books': 3765, 'CDs_and_Vinyl': 3054, 'Cell_Phones_and_Accessories': 2860, 'Digital_Music': 1388, 'Electronics': 3196, 'Grocery_and_Gourmet_Food': 2291, 'Handmade_Products': 2284, 'Health': 1440, 'Industrial_and_Scientific': 3182, 'Kindle_Store': 13703, 'Movies_and_TV': 1827, 'Musical_Instruments': 3356, 'Office_Products': 2324, 'Patio_Lawn_and_Garden': 918, 'Pet_Supplies': 4175, 'Software': 1821, 'Sports_and_Outdoors': 2558, 'Toys_and_Games': 1527, 'Video_Games': 15313, 'all_beauty': 6142, 'amazon_fashion': 6416, 'amazon_home': 2961}
[DEBUG] Boost amazon_home weight -> 1.5155
Class weights: {'Appliances': 1.2247132062911987, 'Arts_Crafts_and_Sewing': 2.214881181716919, 'Automotive': 1.855810284614563, 'Baby_Products': 0.840957522392273, 'Books': 0.5959295034408569, 'CDs_and_Vinyl': 0.7346675395965576, 'Cell_Phones_and_Accessories': 0.7845016121864319, 'Digital_Music'

In [31]:
#@title Balanced sampler DataLoader + Weighted CE + label smoothing criterion
from torch.utils.data import WeightedRandomSampler, DataLoader

# Draw each training example with probability proportional to its class weight,
# with replacement, for len(train_ds) draws per epoch. Batches are then roughly
# class-balanced.
sample_weights = class_weights[train_ds.y]
sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True,
)

train_loader_balanced = DataLoader(
    train_ds,
    batch_size=batch_size,
    sampler=sampler,
    num_workers=0,
    pin_memory=False,
    collate_fn=collate,
    drop_last=False,
)

# NOTE(review): the same inverse-frequency weights are applied twice, once in the
# sampler above and once in this loss, so rare classes are up-weighted
# ~quadratically. Consider plain CE when using the balanced loader.
crit_weighted = torch.nn.CrossEntropyLoss(weight=class_weights.to(device), label_smoothing=0.05)

In [15]:
#@title RouterMLP (regularized)
import torch.nn as nn

class RouterMLP(nn.Module):
    """Two-layer MLP gate: Linear -> LayerNorm -> GELU -> Dropout -> Linear.

    Args:
        hidden_size: embedding width H of the input.
        num_adapters: number of adapters / classes C.
        width: hidden layer width.
        dropout: dropout probability after the activation.
    """
    def __init__(self, hidden_size, num_adapters, width=512, dropout=0.1):
        super().__init__()
        hidden_block = [
            nn.Linear(hidden_size, width),
            nn.LayerNorm(width),
            nn.GELU(),
            nn.Dropout(dropout),
        ]
        head = nn.Linear(width, num_adapters)
        self.net = nn.Sequential(*hidden_block, head)

    def forward(self, x):
        """Map embeddings x of shape [B, H] to logits of shape [B, C]."""
        scores = self.net(x)
        return scores

# Instantiate the MLP gate on the configured device/dtype and show its layers.
router_mlp = RouterMLP(embed_dim, len(gate_categories)).to(device=device, dtype=dtype)
print(router_mlp)

RouterMLP(
  (net): Sequential(
    (0): Linear(in_features=384, out_features=512, bias=True)
    (1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (2): GELU(approximate='none')
    (3): Dropout(p=0.1, inplace=False)
    (4): Linear(in_features=512, out_features=26, bias=True)
  )
)


In [16]:
#@title Train RouterMLP with balanced loader + cosine schedule + early stop
from torch.optim import AdamW
from torch.nn.utils import clip_grad_norm_
from tqdm import tqdm

# Training configuration for the MLP router.
max_epochs = 20
use_balanced = True   # set False to try original loader

if use_balanced:
    loader = train_loader_balanced  # class-balanced WeightedRandomSampler loader
else:
    loader = train_loader           # original shuffled loader

criterion = crit_weighted  # or: nn.CrossEntropyLoss()
# NOTE(review): weight_decay is 0.0, so the only regularization is the model's dropout/LayerNorm.
optimizer = AdamW(router_mlp.parameters(), lr=1e-3, weight_decay=0.0)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs, eta_min=1e-4)

# Early stopping: stop after `patience` epochs without a val-accuracy improvement.
patience = 4
bad = 0
best_acc = 0.0
best_state = None

@torch.no_grad()
def eval_router(model, loader=None, loss_fn=None):
    """Evaluate `model` on a loader of (X, y) batches.

    Args:
        model: classifier mapping X -> logits [B, C]. Switched to eval mode here.
        loader: iterable of (X, y) batches. Defaults to the global `val_loader`.
        loss_fn: loss(logits, y) returning a batch-mean scalar. Defaults to the
            global `criterion` at call time, which in this cell is the weighted,
            label-smoothed `crit_weighted`. Its val_loss is therefore not
            comparable to the linear router's plain CE; pass
            `nn.CrossEntropyLoss()` for a comparable number.

    Returns:
        (mean loss per example, accuracy). Returns (0.0, 0.0) for an empty loader.
        With a class-weighted loss, the "mean" is the batch-size-weighted average
        of per-batch weighted means.
    """
    if loader is None:
        loader = val_loader
    if loss_fn is None:
        loss_fn = criterion
    model.eval()
    total, correct, loss_sum = 0, 0, 0.0
    for X, y in loader:
        logits = model(X)
        loss = loss_fn(logits, y)
        preds = logits.argmax(-1)
        correct += (preds == y).sum().item()
        total += y.numel()
        loss_sum += loss.item() * y.size(0)
    return loss_sum / max(total, 1), correct / max(total, 1)

for epoch in range(1, max_epochs + 1):
    router_mlp.train()
    run_loss = 0.0
    pbar = tqdm(loader, desc=f"Epoch {epoch}/{max_epochs}", unit="batch")
    for X, y in pbar:
        logits = router_mlp(X)
        loss = criterion(logits, y)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        clip_grad_norm_(router_mlp.parameters(), 1.0)
        optimizer.step()
        run_loss += loss.item() * y.size(0)
        pbar.set_postfix(loss=f"{loss.item():.4f}")

    # val_loss uses the same (weighted, label-smoothed) criterion as training.
    val_loss, val_acc = eval_router(router_mlp)
    # Either loader yields len(train_ds) samples per epoch, so this is a per-sample mean.
    train_loss = run_loss / len(train_ds)
    print(f"[Epoch {epoch}] train_loss={train_loss:.4f} | val_loss={val_loss:.4f} | val_acc={val_acc:.4f}")

    if val_acc > best_acc:
        best_acc, bad = val_acc, 0
        # clone(): when device is CPU, `.cpu()` returns the parameter's own storage.
        # Without a copy, later optimizer steps would keep overwriting this "best" snapshot.
        best_state = {k: v.detach().cpu().clone() for k, v in router_mlp.state_dict().items()}
    else:
        bad += 1
        if bad >= patience:
            print(f"[EarlyStop] no val acc improvement for {patience} epochs.")
            break
    scheduler.step()

# Save the stronger router (falls back to the final weights if no epoch improved)
CKPT_DIR2 = Path("/root/autodl-tmp/cache/llmRouterMLP"); CKPT_DIR2.mkdir(parents=True, exist_ok=True)
torch.save(best_state if best_state is not None else router_mlp.state_dict(), CKPT_DIR2 / "gate_router_mlp.pt")
label_map = {str(i): name for i, name in enumerate(gate_categories)}
with (CKPT_DIR2 / "gate_label_mapping.json").open("w", encoding="utf-8") as f:
    json.dump(label_map, f, ensure_ascii=False, indent=2)
print(f"[DONE] Saved MLP router → {CKPT_DIR2/'gate_router_mlp.pt'} | best_acc={best_acc:.4f}")
print(f"[DONE] Saved label map → {CKPT_DIR2/'gate_label_mapping.json'}")

Epoch 1/20: 100%|██████████| 729/729 [00:02<00:00, 320.10batch/s, loss=0.5582]


[Epoch 1] train_loss=0.6616 | val_loss=1.4201 | val_acc=0.8132


Epoch 2/20: 100%|██████████| 729/729 [00:02<00:00, 329.45batch/s, loss=0.5212]


[Epoch 2] train_loss=0.5299 | val_loss=1.3984 | val_acc=0.8236


Epoch 3/20: 100%|██████████| 729/729 [00:02<00:00, 356.85batch/s, loss=0.4218]


[Epoch 3] train_loss=0.4939 | val_loss=1.3250 | val_acc=0.8543


Epoch 4/20: 100%|██████████| 729/729 [00:02<00:00, 347.63batch/s, loss=0.3568]


[Epoch 4] train_loss=0.4698 | val_loss=1.3298 | val_acc=0.8560


Epoch 5/20: 100%|██████████| 729/729 [00:02<00:00, 349.75batch/s, loss=0.4489]


[Epoch 5] train_loss=0.4527 | val_loss=1.3175 | val_acc=0.8592


Epoch 6/20: 100%|██████████| 729/729 [00:02<00:00, 345.03batch/s, loss=0.4983]


[Epoch 6] train_loss=0.4425 | val_loss=1.3340 | val_acc=0.8480


Epoch 7/20: 100%|██████████| 729/729 [00:02<00:00, 358.97batch/s, loss=0.4631]


[Epoch 7] train_loss=0.4297 | val_loss=1.3044 | val_acc=0.8628


Epoch 8/20: 100%|██████████| 729/729 [00:02<00:00, 348.58batch/s, loss=0.3729]


[Epoch 8] train_loss=0.4213 | val_loss=1.2867 | val_acc=0.8712


Epoch 9/20: 100%|██████████| 729/729 [00:02<00:00, 351.89batch/s, loss=0.4749]


[Epoch 9] train_loss=0.4157 | val_loss=1.2859 | val_acc=0.8677


Epoch 10/20: 100%|██████████| 729/729 [00:02<00:00, 346.60batch/s, loss=0.4295]


[Epoch 10] train_loss=0.4081 | val_loss=1.2885 | val_acc=0.8663


Epoch 11/20: 100%|██████████| 729/729 [00:02<00:00, 350.50batch/s, loss=0.4718]


[Epoch 11] train_loss=0.4036 | val_loss=1.2823 | val_acc=0.8719


Epoch 12/20: 100%|██████████| 729/729 [00:02<00:00, 329.57batch/s, loss=0.3660]


[Epoch 12] train_loss=0.3986 | val_loss=1.2703 | val_acc=0.8771


Epoch 13/20: 100%|██████████| 729/729 [00:02<00:00, 341.68batch/s, loss=0.3743]


[Epoch 13] train_loss=0.3964 | val_loss=1.2695 | val_acc=0.8769


Epoch 14/20: 100%|██████████| 729/729 [00:02<00:00, 322.17batch/s, loss=0.3555]


[Epoch 14] train_loss=0.3931 | val_loss=1.2622 | val_acc=0.8814


Epoch 15/20: 100%|██████████| 729/729 [00:02<00:00, 355.98batch/s, loss=0.4085]


[Epoch 15] train_loss=0.3914 | val_loss=1.2629 | val_acc=0.8805


Epoch 16/20: 100%|██████████| 729/729 [00:02<00:00, 337.71batch/s, loss=0.3911]


[Epoch 16] train_loss=0.3871 | val_loss=1.2587 | val_acc=0.8815


Epoch 17/20: 100%|██████████| 729/729 [00:01<00:00, 366.70batch/s, loss=0.4193]


[Epoch 17] train_loss=0.3862 | val_loss=1.2558 | val_acc=0.8832


Epoch 18/20: 100%|██████████| 729/729 [00:02<00:00, 336.28batch/s, loss=0.3465]


[Epoch 18] train_loss=0.3852 | val_loss=1.2589 | val_acc=0.8802


Epoch 19/20: 100%|██████████| 729/729 [00:02<00:00, 347.07batch/s, loss=0.4368]


[Epoch 19] train_loss=0.3839 | val_loss=1.2534 | val_acc=0.8833


Epoch 20/20: 100%|██████████| 729/729 [00:02<00:00, 344.15batch/s, loss=0.3690]


[Epoch 20] train_loss=0.3836 | val_loss=1.2560 | val_acc=0.8826
[DONE] Saved MLP router → /root/autodl-tmp/cache/llmRouterMLP/gate_router_mlp.pt | best_acc=0.8833
[DONE] Saved label map → /root/autodl-tmp/cache/llmRouterMLP/gate_label_mapping.json


In [17]:
#@title Evaluate Top-K accuracy on cached validation set
import torch

@torch.no_grad()
def eval_topk(model, val_ds, Ks=(1,2,3)):
    """Print (and return) the top-K accuracy of `model` on a cached dataset.

    Args:
        model: router mapping embeddings [N, H] -> logits [N, C]. Switched to eval mode.
        val_ds: object with tensors `X` [N, H] and integer labels `y` [N], e.g. a
            TensorGateDataset. The whole set is scored in one forward pass.
        Ks: K values to report. A K larger than C is clamped to C, where the
            accuracy is trivially 1.0.

    Returns:
        dict mapping each requested K to its accuracy (float).
    """
    model.eval()
    # Feed inputs on the model's own device/dtype. Parameter-free models use the data's.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        dev, dt = first_param.device, first_param.dtype
    else:
        dev, dt = val_ds.X.device, val_ds.X.dtype
    X = val_ds.X.to(device=dev, dtype=dt)
    y = val_ds.y.to(dev)
    logits = model(X)
    num_classes = logits.size(-1)
    results = {}
    for K in Ks:
        k = min(K, num_classes)  # topk raises if k exceeds the number of classes
        # softmax is monotonic, so ranking the raw logits gives the same top-k set.
        topk = logits.topk(k, dim=-1).indices
        acc = (topk == y.unsqueeze(1)).any(dim=1).float().mean().item()
        print(f"Top-{K} accuracy: {acc:.4f}")
        results[K] = acc
    return results

# Load the best MLP router and evaluate
# Rebuild the MLP router and load the best checkpoint saved by the training cell.
router_mlp2 = RouterMLP(embed_dim, len(gate_categories)).to(device=device, dtype=dtype)
router_mlp2.load_state_dict(torch.load(CKPT_DIR2 / "gate_router_mlp.pt", map_location=device))
router_mlp2.eval()
eval_topk(router_mlp2, val_ds, Ks=(1, 2, 3))

# Top-K accuracy is trivially 1.0 once K >= the number of categories (e.g. top-3
# with only 3 categories). With 26 categories it measures how often the true
# category is among the K adapters the gate would mix. A very high Top-1 may not
# be needed, because the MoLoRA stage mixes the top-k adapters.

Top-1 accuracy: 0.8833
Top-2 accuracy: 0.9509
Top-3 accuracy: 0.9643


In [18]:
#@title Thresholded gate → Top-K fallback at inference
# Instead of always choosing a single adapter (Top-1) or always mixing several
# (Top-K), the gate adapts to the router's confidence in its prediction.
@torch.no_grad()
def gate_weights_thresholded(prompt: str, top_k_fallback: int = 2, tau: float = 0.60,
                             temperature: float = 1.0):
    """Pick adapter indices and mixing weights for one prompt.

    Embeds `prompt` with the global `st_model` and scores it with the global
    `router_mlp2`. Then:
      * if the highest probability is >= `tau`, returns that single adapter
        with weight 1.0 (mode "top1");
      * otherwise returns the top-k adapters (k = top_k_fallback, clamped to
        [1, C]) with their probabilities renormalized to sum to 1 (mode "top{k}").

    Args:
        prompt: text to route.
        top_k_fallback: how many adapters to mix when the router is unsure.
        tau: confidence threshold on the max probability.
        temperature: logits are divided by this before the softmax. Pass the
            calibrated `T_star` from temperature scaling. 1.0 means no scaling,
            which matches the original behavior.

    Returns:
        (idxs, ws, mode, pmax): adapter indices, their weights, the mode string,
        and the (temperature-scaled) max probability.

    Raises:
        ValueError: if `temperature` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")
    emb = st_model.encode([prompt], convert_to_tensor=True, device=device, show_progress_bar=False).to(dtype)
    logits = router_mlp2(emb) / temperature
    probs = torch.softmax(logits, dim=-1)  # [1, C]
    pmax, imax = probs.max(dim=-1)         # [1]
    conf = pmax.item()
    if conf >= tau:
        return [imax.item()], [1.0], "top1", conf
    # max(1, ...) keeps a non-positive top_k_fallback from producing an empty mix.
    k = max(1, min(top_k_fallback, probs.size(-1)))
    topk = probs.topk(k=k, dim=-1)
    idxs = topk.indices.squeeze(0).tolist()
    ws = (topk.values.squeeze(0) / topk.values.sum()).tolist()
    return idxs, ws, f"top{k}", conf

# Example: route one prompt. Prints (adapter indices, mixture weights, "top1"/"topK" mode, max probability).
idxs, ws, mode, conf = gate_weights_thresholded("I need a kettle and toaster set")
print(idxs, ws, mode, conf)

[25] [1.0] top1 0.969990074634552


In [19]:
#@title Temperature scaling on validation (grid search)
# make the router’s predicted probabilities match reality
import torch.nn.functional as F

@torch.no_grad()
def tune_temperature(model, val_loader, t_min=0.5, t_max=3.0, steps=26):
    """Grid-search the softmax temperature T that minimises validation NLL.

    Args:
        model: callable mapping a batch of inputs to class logits.
        val_loader: iterable of (X, y) batches with integer class labels y.
        t_min, t_max, steps: candidate grid, torch.linspace(t_min, t_max, steps).

    Returns:
        The best T as a Python float (first grid point on ties), or 1.0 (no scaling)
        when `val_loader` yields no batches.
    """
    # Logits do not depend on T, so run the model once instead of once per grid point.
    logit_chunks, label_chunks = [], []
    for X, y in val_loader:
        logit_chunks.append(model(X))
        label_chunks.append(y)
    if not logit_chunks:
        return 1.0
    logits = torch.cat(logit_chunks)
    labels = torch.cat(label_chunks).to(logits.device)

    best_T, best_nll = 1.0, float("inf")
    # Build the grid where the logits live instead of relying on a global `device`.
    for T in torch.linspace(t_min, t_max, steps=steps, device=logits.device):
        nll = F.cross_entropy(logits / T, labels, reduction="sum").item()
        if nll < best_nll:
            best_nll, best_T = nll, T.item()
    return best_T

# Fit the temperature for the MLP router loaded above, using the validation loader.
T_star = tune_temperature(router_mlp2, val_loader)
print(f"Calibrated temperature T* = {T_star:.3f}")
# To use: replace torch.softmax(router(emb), dim=-1) with torch.softmax(router(emb) / T_star, dim=-1).
# T* < 1 (0.700 in the logged run) sharpens the probabilities, i.e. the raw router was under-confident; T* > 1 would soften them.

Calibrated temperature T* = 0.700


In [20]:
#@title Full Evaluation (Top-K, calibration, thresholded gate)
import torch
import torch.nn.functional as F
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np
from tqdm import tqdm


# Load the trained router
router_eval = RouterMLP(embed_dim, len(gate_categories)).to(device).to(dtype)
ckpt = torch.load(CKPT_DIR2 / "gate_router_mlp.pt", map_location=device)
router_eval.load_state_dict(ckpt)
router_eval.eval()


# Top-1 evaluation over the validation loader (argmax of the router logits).
all_preds, all_labels = [], []
with torch.no_grad():
    for X, y in tqdm(val_loader, desc="Evaluating Top-1"):
        logits = router_eval(X)
        preds = logits.argmax(-1)
        all_preds.extend(preds.cpu().tolist())
        all_labels.extend(y.cpu().tolist())

# Pass the full label set explicitly. Without `labels=`, sklearn infers the classes from the data,
# so classification_report raises ValueError (target_names size mismatch) and confusion_matrix
# shrinks/re-indexes whenever some category is absent from both labels and predictions.
print("==== Top-1 classification report ====")
print(classification_report(all_labels, all_preds, labels=list(range(len(gate_categories))), target_names=gate_categories))
print("Confusion matrix:")
print(confusion_matrix(all_labels, all_preds, labels=list(range(len(gate_categories)))))

# Top-K accuracy metrics (K=1,2,3)
@torch.no_grad()
def eval_topk(model, val_ds, Ks=(1,2,3)):
    X = val_ds.X.to(device=device, dtype=dtype)
    y = val_ds.y.to(device)
    logits = model(X)
    probs = torch.softmax(logits, dim=-1)
    results = {}
    for K in Ks:
        topk = probs.topk(K, dim=-1).indices
        acc = (topk == y.unsqueeze(1)).any(dim=1).float().mean().item()
        results[f"Top-{K}"] = acc
    return results

# Top-K accuracies of router_eval on the full cached validation set.
# Keys are "Top-1", "Top-2", "Top-3" (default Ks); the summary below reads them.
topk_scores = eval_topk(router_eval, val_ds)
print("==== Top-K Accuracies ====")
for k, v in topk_scores.items():
    print(f"{k}: {v:.4f}")


# Temperature calibration on validation set
@torch.no_grad()
def tune_temperature(model, val_loader, t_min=0.5, t_max=3.0, steps=26):
    """Grid-search the softmax temperature T that minimises validation NLL.

    Args:
        model: callable mapping a batch of inputs to class logits.
        val_loader: iterable of (X, y) batches with integer class labels y.
        t_min, t_max, steps: candidate grid, torch.linspace(t_min, t_max, steps).

    Returns:
        The best T as a Python float (first grid point on ties), or 1.0 (no scaling)
        when `val_loader` yields no batches.
    """
    # Logits do not depend on T, so run the model once instead of once per grid point.
    logit_chunks, label_chunks = [], []
    for X, y in val_loader:
        logit_chunks.append(model(X))
        label_chunks.append(y)
    if not logit_chunks:
        return 1.0
    logits = torch.cat(logit_chunks)
    labels = torch.cat(label_chunks).to(logits.device)

    best_T, best_nll = 1.0, float("inf")
    # Build the grid where the logits live instead of relying on a global `device`.
    for T in torch.linspace(t_min, t_max, steps=steps, device=logits.device):
        nll = F.cross_entropy(logits / T, labels, reduction="sum").item()
        if nll < best_nll:
            best_nll, best_T = nll, T.item()
    return best_T

# Fit T* for router_eval; the calibrated thresholded-gate run below divides logits by it.
T_star = tune_temperature(router_eval, val_loader)
print(f"==== Calibrated temperature T* = {T_star:.3f} ====")


# Thresholded-gate evaluation (Top-1 unless uncertain)
@torch.no_grad()
def eval_thresholded_gate(model, val_ds, tau=0.6, top_k_fallback=2, T=1.0):
    X = val_ds.X.to(device=device, dtype=dtype)
    y = val_ds.y.to(device)
    logits = model(X) / T
    probs = torch.softmax(logits, dim=-1)
    pmax, imax = probs.max(dim=-1)
    # If confident → Top-1; else → Top-K
    topk = probs.topk(top_k_fallback, dim=-1).indices
    preds_thresh = []
    for i in range(len(y)):
        if pmax[i] >= tau:
            preds_thresh.append(imax[i].item())
        else:
            # Choose the top-K category that includes y[i] if any, else top1 fallback
            if y[i].item() in topk[i]:
                preds_thresh.append(y[i].item())  # counted as correct
            else:
                preds_thresh.append(imax[i].item())
    acc = np.mean(np.array(preds_thresh) == y.cpu().numpy())
    print(f"[τ={tau:.2f}, K={top_k_fallback}] adaptive Top-1/Top-{top_k_fallback} accuracy: {acc:.4f}")
    return acc

# Run the thresholded evaluation both uncalibrated and calibrated
# NOTE(review): T_star was fit on val_loader and the gate is scored on val_ds — presumably the same
# validation split, in which case the calibrated number is optimistic; confirm or use a held-out split.
acc_uncal = eval_thresholded_gate(router_eval, val_ds, tau=0.6, top_k_fallback=2, T=1.0)
acc_cal   = eval_thresholded_gate(router_eval, val_ds, tau=0.6, top_k_fallback=2, T=T_star)

# Summary
print("==== Summary ====")
print(f"Top-1 accuracy:   {topk_scores['Top-1']:.4f}")
print(f"Top-2 accuracy:   {topk_scores['Top-2']:.4f}")
print(f"Top-3 accuracy:   {topk_scores['Top-3']:.4f}")
print(f"Thresholded acc (uncalibrated): {acc_uncal:.4f}")
print(f"Thresholded acc (calibrated T*): {acc_cal:.4f}")
print("Note: Higher Top-K or calibrated thresholded acc → better MoLoRA routing coverage.")

# τ is the confidence threshold on the router's top-1 probability. If p_max ≥ τ, the model is
# confident enough and Top-1 is trusted; otherwise the gate falls back to Top-K, and
# eval_thresholded_gate counts such a sample as correct when the true label is in the Top-K.

Evaluating Top-1: 100%|██████████| 91/91 [00:00<00:00, 531.94it/s]


==== Top-1 classification report ====
                             precision    recall  f1-score   support

                 Appliances       0.93      0.99      0.96       229
     Arts_Crafts_and_Sewing       0.74      0.95      0.83       126
                 Automotive       0.81      0.87      0.84       151
              Baby_Products       0.92      0.98      0.95       333
                      Books       0.65      0.96      0.77       470
              CDs_and_Vinyl       0.92      0.94      0.93       381
Cell_Phones_and_Accessories       0.96      0.97      0.97       357
              Digital_Music       0.87      0.95      0.91       173
                Electronics       0.82      0.91      0.86       399
   Grocery_and_Gourmet_Food       0.88      0.95      0.91       286
          Handmade_Products       0.85      0.94      0.89       285
                     Health       0.71      0.80      0.75       180
  Industrial_and_Scientific       0.87      0.82      0.84      

In [32]:
#@title Integration with LoRA adapters (Top‑1 & Top‑K)
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

ENABLE_LORA_INTEGRATION = False  # router-only by default
peft_model = None

# --- LoRAs to load: adapter_name -> directory containing adapter_config.json & adapter weights.
# Names must match the `adapter_names` passed to the generate_* helpers (the demo uses f"lm_{category}").
# IMPORTANT: the base model **must match** the base used to train each LoRA.
LORA_ADAPTER_DIRS = {
    # "lm_all_beauty": "/content/lora_ckpts/lm_all_beauty",
    # "lm_amazon_fashion": "/content/lora_ckpts/lm_amazon_fashion",
    # "lm_amazon_home": "/content/lora_ckpts/lm_amazon_home",
}

if ENABLE_LORA_INTEGRATION:
    if not LORA_ADAPTER_DIRS:
        raise ValueError("ENABLE_LORA_INTEGRATION is True but LORA_ADAPTER_DIRS is empty; add at least one adapter directory.")
    BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # placeholder, change to the model LoRA used.
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
    base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, torch_dtype=dtype, device_map="auto")
    # `PeftModel(base_model)` raises TypeError: the constructor also requires a PeftConfig.
    # Wrap the base model by loading the first LoRA with from_pretrained, then attach the rest.
    _lora_items = list(LORA_ADAPTER_DIRS.items())
    _first_name, _first_dir = _lora_items[0]
    peft_model = PeftModel.from_pretrained(base_model, _first_dir, adapter_name=_first_name)
    for _name, _dir in _lora_items[1:]:
        peft_model.load_adapter(_dir, adapter_name=_name)
else:
    print("[INFO] LoRA integration disabled (router-only).")

if peft_model is not None:
    peft_model.eval()

# --- Reload gate
# NOTE(review): this loads the SimpleRouter checkpoint (CKPT_DIR / "gate_router.pt"), not the MLP
# routers evaluated above (router_mlp2 / RouterMLPPlus) — confirm which router should drive adapter selection.
# `router` is the global used by gate_weights_from_text below.
router = SimpleRouter(hidden_size=embed_dim, num_adapters=len(gate_categories)).to(device).to(dtype)
router.load_state_dict(torch.load(CKPT_DIR / "gate_router.pt", map_location=device))
router.eval()

@torch.no_grad()
def gate_weights_from_text(prompt: str, top_k: int = 1):
    """Compute router softmax and return Top‑K (indices, weights)."""
    emb = st_model.encode([prompt], convert_to_tensor=True, device=device, show_progress_bar=False).to(dtype)
    logits = router(emb)                  # [1, C]
    probs = torch.softmax(logits, dim=-1) # [1, C]
    k = min(top_k, probs.size(-1))
    topk = torch.topk(probs, k=k, dim=-1)
    idxs = topk.indices.squeeze(0).tolist()
    ws = topk.values.squeeze(0).tolist()
    # normalize to sum=1 for safety
    s = sum(ws) + 1e-12
    ws = [w/s for w in ws]
    return idxs, ws

def _ensure_loaded_adapters(adapter_names: List[str]):
    """Raise ValueError unless every name in `adapter_names` is registered in `peft_model.peft_config`."""
    if peft_model is None:
        raise ValueError("peft_model is None. Enable LoRA integration to use adapters.")
    registered = peft_model.peft_config
    missing = [name for name in adapter_names if name not in registered]
    if not missing:
        return
    raise ValueError(f"Adapters not loaded: {missing}. Use peft_model.load_adapter(path, adapter_name) first.")


@torch.no_grad()
def generate_top1(prompt: str, adapter_names: List[str], max_new_tokens: int = 64):
    """Route `prompt` to its single best adapter and generate with only that adapter active.

    Returns:
        (decoded text including the prompt, {"adapter": chosen adapter name}).
    """
    route_idxs, _ = gate_weights_from_text(prompt, top_k=1)
    picked = [adapter_names[i] for i in route_idxs]
    _ensure_loaded_adapters(picked)
    chosen = picked[0]
    peft_model.set_adapter(chosen)
    encoded = tokenizer(prompt, return_tensors="pt").to(device)
    output_ids = peft_model.generate(**encoded, max_new_tokens=max_new_tokens)
    decoded = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return decoded, {"adapter": chosen}


def _generate_weighted_fusion(prompt: str, names: List[str], ws: List[float], max_new_tokens: int) -> str:
    """Fuse `names` into one temporary adapter weighted by `ws` and decode with a single generate()."""
    tmp_name = "mol_temp"
    # Clean a temp adapter left over from an earlier call.
    if tmp_name in getattr(peft_model, "peft_config", {}):
        peft_model.delete_adapter(tmp_name)
    peft_model.add_weighted_adapter(adapters=names, weights=ws, adapter_name=tmp_name)
    try:
        peft_model.set_adapter(tmp_name)
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        out = peft_model.generate(**inputs, max_new_tokens=max_new_tokens)
        return tokenizer.decode(out[0], skip_special_tokens=True)
    finally:
        # Drop the temp adapter even if generate() fails, so it does not linger between calls.
        peft_model.delete_adapter(tmp_name)


def _generate_logits_mixture(prompt: str, names: List[str], ws: List[float], max_new_tokens: int) -> str:
    """Greedy decoding where each step's next-token logits are the `ws`-weighted sum over adapters.

    Each adapter keeps its own KV cache, because a cache holds that adapter's activations and cannot
    be shared. After the first step, only the newest token is fed, since the cache already covers
    the earlier positions. Decoding stops at the tokenizer's EOS token, like generate().
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    generated = inputs["input_ids"]
    attn = inputs.get("attention_mask", None)
    caches = {name: None for name in names}
    step_ids = generated  # full prompt on the first step, then only the latest token
    eos_id = tokenizer.eos_token_id

    for _ in range(max_new_tokens):
        logits_sum = None
        for name, w in zip(names, ws):
            peft_model.set_adapter(name)
            out = peft_model(input_ids=step_ids, attention_mask=attn, past_key_values=caches[name], use_cache=True)
            caches[name] = out.past_key_values
            weighted = w * out.logits[:, -1, :]  # last position only
            logits_sum = weighted if logits_sum is None else logits_sum + weighted

        next_id = torch.argmax(logits_sum, dim=-1, keepdim=True)
        generated = torch.cat([generated, next_id], dim=-1)
        if attn is not None:
            # The mask must cover cached + current positions.
            attn = torch.cat([attn, torch.ones_like(next_id)], dim=-1)
        step_ids = next_id
        if eos_id is not None and next_id.item() == eos_id:
            break

    return tokenizer.decode(generated[0], skip_special_tokens=True)


@torch.no_grad()
def generate_topk(prompt: str, adapter_names: List[str], top_k: int = 2, max_new_tokens: int = 64):
    """Top-K mixture inference.

    If the PEFT model exposes `add_weighted_adapter`, the selected adapters are fused once into a
    temporary adapter and a single generate() is run. Otherwise, or if fusion fails, it falls back
    to a per-step logits mixture (k forward passes per token; slower but version-independent).

    Returns:
        (decoded text including the prompt,
         {"adapters": names, "weights": ws, "mode": "weighted_fusion" | "logits_mixture"}).

    Raises:
        ValueError: if adapters are missing, LoRA integration is disabled, or top_k selects none.
    """
    idxs, ws = gate_weights_from_text(prompt, top_k=top_k)
    names = [adapter_names[i] for i in idxs]
    _ensure_loaded_adapters(names)
    if not names:
        raise ValueError(f"top_k must be >= 1 to select at least one adapter (got {top_k}).")

    # Fast path: weighted fusion into a temporary adapter.
    if hasattr(peft_model, "add_weighted_adapter"):
        try:
            text = _generate_weighted_fusion(prompt, names, ws, max_new_tokens)
            return text, {"adapters": names, "weights": ws, "mode": "weighted_fusion"}
        except Exception as e:  # deliberate best-effort: warn and use the slower path
            print("[WARN] Weighted fusion unavailable, falling back to logits mixture:", e)

    text = _generate_logits_mixture(prompt, names, ws, max_new_tokens)
    return text, {"adapters": names, "weights": ws, "mode": "logits_mixture"}

[INFO] LoRA integration disabled (router-only).


In [33]:
#@title Demo
# Adapter names follow the router's class order: router output index i selects adapter_names[i].
adapter_names = [f"lm_{c}" for c in gate_categories]

prompt = "I'm looking for a pair of running shoes for daily jogging—any suggestions?"
# Before running, enable LoRA integration and make sure real adapters whose names match adapter_names are loaded.
# text1, info1 = generate_top1(prompt, adapter_names, max_new_tokens=64)
# print("\n[TOP-1]", info1, "\n", text1)

# textk, infok = generate_topk(prompt, adapter_names, top_k=2, max_new_tokens=64)
# print("\n[TOP-K]", infok, "\n", textk)

In [34]:
#@title Train RouterMLPPlus (deeper, label smoothing)
import torch.nn as nn
from torch.optim import AdamW
from torch.nn.utils import clip_grad_norm_

class RouterMLPPlus(nn.Module):
    """MLP router: embedding -> [Linear -> LayerNorm -> GELU -> Dropout] per hidden width -> Linear.

    Args:
        hidden_size: dimensionality of the input sentence embedding.
        num_adapters: number of output logits (one per adapter/category).
        widths: hidden layer widths, any length. The previous version silently used only
            the first two entries and failed with fewer than two.
        dropout: dropout probability after each hidden activation.

    All layers live in one ``nn.Sequential`` called ``net``. For the default two widths, the
    parameter names (net.0, net.1, net.4, net.5, net.8) and init order match the earlier
    hand-written version, so existing checkpoints still load.
    """

    def __init__(self, hidden_size, num_adapters, widths=(512, 256), dropout=0.1):
        super().__init__()
        layers = []
        in_dim = hidden_size
        for width in widths:
            layers += [nn.Linear(in_dim, width), nn.LayerNorm(width), nn.GELU(), nn.Dropout(dropout)]
            in_dim = width
        layers.append(nn.Linear(in_dim, num_adapters))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

# Fresh RouterMLPPlus on the notebook device/dtype.
router_mlp_plus = RouterMLPPlus(embed_dim, len(gate_categories)).to(device).to(dtype)

max_epochs = 30
use_balanced = True
loader = train_loader_balanced if use_balanced else train_loader

# Use class weights only if an earlier cell defined `class_weights`.
# NOTE(review): with use_balanced=True these weights are applied on top of balanced sampling,
# which may over-correct toward rare classes — confirm this double correction is intended.
weight = None
if "class_weights" in globals():
    weight = class_weights.to(device)
# The TypeError fallback covers PyTorch versions whose CrossEntropyLoss lacks `label_smoothing`.
try:
    criterion = nn.CrossEntropyLoss(weight=weight, label_smoothing=0.1)
except TypeError:
    criterion = nn.CrossEntropyLoss(weight=weight)

optimizer = AdamW(router_mlp_plus.parameters(), lr=5e-4, weight_decay=1e-4)
# Cosine decay over max_epochs; stepped once per epoch in the training loop.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs, eta_min=1e-5)
# Early stopping: stop after `patience` epochs without a val-accuracy improvement.
patience, bad = 6, 0
best_acc, best_state = 0.0, None

@torch.no_grad()
def eval_router_plus(model):
    model.eval()
    total, correct, loss_sum = 0, 0, 0.0
    for X, y in val_loader:
        logits = model(X)
        loss = criterion(logits, y)
        preds = logits.argmax(-1)
        correct += (preds == y).sum().item()
        total += y.numel()
        loss_sum += loss.item() * y.size(0)
    return loss_sum / max(total, 1), correct / max(total, 1)

# Train with early stopping on validation accuracy; the best weights are snapshotted to CPU.
for epoch in range(1, max_epochs + 1):
    router_mlp_plus.train()  # eval_router_plus switches to eval() every epoch, so re-enable dropout here
    run_loss = 0.0
    pbar = tqdm(loader, desc=f"Epoch {epoch}/{max_epochs}", unit="batch")
    # NOTE(review): X, y are not moved to `device` — assumes the loaders already yield device tensors.
    for X, y in pbar:
        logits = router_mlp_plus(X)
        loss = criterion(logits, y)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        clip_grad_norm_(router_mlp_plus.parameters(), 1.0)
        optimizer.step()
        run_loss += loss.item() * y.size(0)
        pbar.set_postfix(loss=f"{loss.item():.4f}")

    val_loss, val_acc = eval_router_plus(router_mlp_plus)
    # NOTE(review): normalised by len(train_ds) while iterating `loader` (the balanced loader when
    # use_balanced=True); if the sampler draws a different count per epoch this is approximate.
    train_loss = run_loss / len(train_ds)
    print(f"[Epoch {epoch}] train_loss={train_loss:.4f} | val_loss={val_loss:.4f} | val_acc={val_acc:.4f}")

    if val_acc > best_acc:
        best_acc, bad = val_acc, 0
        # .clone() is required: on a CPU run `.cpu()` returns the same storage, so later
        # optimizer steps would silently overwrite the "best" snapshot with the latest weights.
        best_state = {k: v.detach().cpu().clone() for k, v in router_mlp_plus.state_dict().items()}
    else:
        bad += 1
        if bad >= patience:
            print(f"[EarlyStop] no val acc improvement for {patience} epochs.")
            break
    scheduler.step()

# Persist the best-val-accuracy weights, or the final weights if no epoch improved on best_acc=0.0.
CKPT_DIR3 = Path("/root/autodl-tmp/cache/llmRouterMLPPlus"); CKPT_DIR3.mkdir(parents=True, exist_ok=True)
torch.save(best_state if best_state else router_mlp_plus.state_dict(), CKPT_DIR3 / "gate_router_mlp_plus.pt")
# Label map: stringified router output index -> category name, for decoding predictions later.
label_map = {str(i): name for i, name in enumerate(gate_categories)}
with (CKPT_DIR3 / "gate_label_mapping.json").open("w", encoding="utf-8") as f:
    json.dump(label_map, f, ensure_ascii=False, indent=2)
print(f"[DONE] Saved MLP+ router → {CKPT_DIR3/'gate_router_mlp_plus.pt'} | best_acc={best_acc:.4f}")
print(f"[DONE] Saved label map → {CKPT_DIR3/'gate_label_mapping.json'}")


Epoch 1/30: 100%|██████████| 729/729 [00:03<00:00, 229.60batch/s, loss=0.8074]


[Epoch 1] train_loss=0.9697 | val_loss=2.1110 | val_acc=0.7822


Epoch 2/30: 100%|██████████| 729/729 [00:03<00:00, 223.60batch/s, loss=0.6451]


[Epoch 2] train_loss=0.7984 | val_loss=2.0490 | val_acc=0.8125


Epoch 3/30: 100%|██████████| 729/729 [00:03<00:00, 215.50batch/s, loss=0.6860]


[Epoch 3] train_loss=0.7491 | val_loss=2.0121 | val_acc=0.8275


Epoch 4/30: 100%|██████████| 729/729 [00:03<00:00, 226.91batch/s, loss=0.6907]


[Epoch 4] train_loss=0.7180 | val_loss=1.9759 | val_acc=0.8429


Epoch 5/30: 100%|██████████| 729/729 [00:03<00:00, 214.33batch/s, loss=0.7307]


[Epoch 5] train_loss=0.7057 | val_loss=1.9711 | val_acc=0.8460


Epoch 6/30: 100%|██████████| 729/729 [00:03<00:00, 204.04batch/s, loss=0.6715]


[Epoch 6] train_loss=0.6848 | val_loss=1.9534 | val_acc=0.8544


Epoch 7/30: 100%|██████████| 729/729 [00:03<00:00, 221.40batch/s, loss=0.6671]


[Epoch 7] train_loss=0.6747 | val_loss=1.9789 | val_acc=0.8414


Epoch 8/30: 100%|██████████| 729/729 [00:03<00:00, 217.28batch/s, loss=0.6855]


[Epoch 8] train_loss=0.6648 | val_loss=1.9330 | val_acc=0.8597


Epoch 9/30: 100%|██████████| 729/729 [00:03<00:00, 221.47batch/s, loss=0.7669]


[Epoch 9] train_loss=0.6590 | val_loss=1.9360 | val_acc=0.8600


Epoch 10/30: 100%|██████████| 729/729 [00:03<00:00, 238.06batch/s, loss=0.6926]


[Epoch 10] train_loss=0.6533 | val_loss=1.9398 | val_acc=0.8590


Epoch 11/30: 100%|██████████| 729/729 [00:03<00:00, 207.55batch/s, loss=0.6780]


[Epoch 11] train_loss=0.6475 | val_loss=1.9153 | val_acc=0.8690


Epoch 12/30: 100%|██████████| 729/729 [00:03<00:00, 222.94batch/s, loss=0.6210]


[Epoch 12] train_loss=0.6415 | val_loss=1.8980 | val_acc=0.8754


Epoch 13/30: 100%|██████████| 729/729 [00:03<00:00, 229.81batch/s, loss=0.7335]


[Epoch 13] train_loss=0.6388 | val_loss=1.9116 | val_acc=0.8686


Epoch 14/30: 100%|██████████| 729/729 [00:03<00:00, 210.53batch/s, loss=0.6169]


[Epoch 14] train_loss=0.6344 | val_loss=1.8897 | val_acc=0.8782


Epoch 15/30: 100%|██████████| 729/729 [00:03<00:00, 210.97batch/s, loss=0.5764]


[Epoch 15] train_loss=0.6327 | val_loss=1.8949 | val_acc=0.8747


Epoch 16/30: 100%|██████████| 729/729 [00:03<00:00, 232.36batch/s, loss=0.6407]


[Epoch 16] train_loss=0.6293 | val_loss=1.8938 | val_acc=0.8773


Epoch 17/30: 100%|██████████| 729/729 [00:03<00:00, 219.33batch/s, loss=0.5910]


[Epoch 17] train_loss=0.6249 | val_loss=1.9002 | val_acc=0.8751


Epoch 18/30: 100%|██████████| 729/729 [00:03<00:00, 201.69batch/s, loss=0.6292]


[Epoch 18] train_loss=0.6240 | val_loss=1.8817 | val_acc=0.8806


Epoch 19/30: 100%|██████████| 729/729 [00:03<00:00, 210.59batch/s, loss=0.6812]


[Epoch 19] train_loss=0.6228 | val_loss=1.8755 | val_acc=0.8834


Epoch 20/30: 100%|██████████| 729/729 [00:03<00:00, 209.23batch/s, loss=0.6047]


[Epoch 20] train_loss=0.6202 | val_loss=1.8759 | val_acc=0.8826


Epoch 21/30: 100%|██████████| 729/729 [00:03<00:00, 218.83batch/s, loss=0.6465]


[Epoch 21] train_loss=0.6202 | val_loss=1.8838 | val_acc=0.8812


Epoch 22/30: 100%|██████████| 729/729 [00:03<00:00, 212.74batch/s, loss=0.6120]


[Epoch 22] train_loss=0.6183 | val_loss=1.8851 | val_acc=0.8794


Epoch 23/30: 100%|██████████| 729/729 [00:03<00:00, 216.13batch/s, loss=0.5859]


[Epoch 23] train_loss=0.6181 | val_loss=1.8776 | val_acc=0.8823


Epoch 24/30: 100%|██████████| 729/729 [00:03<00:00, 217.41batch/s, loss=0.6505]


[Epoch 24] train_loss=0.6170 | val_loss=1.8859 | val_acc=0.8808


Epoch 25/30: 100%|██████████| 729/729 [00:03<00:00, 209.51batch/s, loss=0.7001]


[Epoch 25] train_loss=0.6174 | val_loss=1.8723 | val_acc=0.8840


Epoch 26/30: 100%|██████████| 729/729 [00:03<00:00, 217.58batch/s, loss=0.6396]


[Epoch 26] train_loss=0.6150 | val_loss=1.8789 | val_acc=0.8826


Epoch 27/30: 100%|██████████| 729/729 [00:03<00:00, 219.93batch/s, loss=0.7186]


[Epoch 27] train_loss=0.6145 | val_loss=1.8756 | val_acc=0.8837


Epoch 28/30: 100%|██████████| 729/729 [00:03<00:00, 214.25batch/s, loss=0.7159]


[Epoch 28] train_loss=0.6153 | val_loss=1.8715 | val_acc=0.8865


Epoch 29/30: 100%|██████████| 729/729 [00:03<00:00, 220.94batch/s, loss=0.6352]


[Epoch 29] train_loss=0.6149 | val_loss=1.8731 | val_acc=0.8856


Epoch 30/30: 100%|██████████| 729/729 [00:03<00:00, 225.77batch/s, loss=0.6244]


[Epoch 30] train_loss=0.6135 | val_loss=1.8709 | val_acc=0.8857
[DONE] Saved MLP+ router → /root/autodl-tmp/cache/llmRouterMLPPlus/gate_router_mlp_plus.pt | best_acc=0.8865
[DONE] Saved label map → /root/autodl-tmp/cache/llmRouterMLPPlus/gate_label_mapping.json


In [36]:
#@title Full Evaluation (Top-K, calibration, thresholded gate) - RouterMLPPlus
import torch
import torch.nn.functional as F
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np
from tqdm import tqdm

# Load the saved RouterMLPPlus checkpoint (path matches the training cell's CKPT_DIR3, hard-coded so
# this cell can run without re-training).
router_eval_plus = RouterMLPPlus(embed_dim, len(gate_categories)).to(device).to(dtype)
ckpt = torch.load(Path("/root/autodl-tmp/cache/llmRouterMLPPlus") / "gate_router_mlp_plus.pt", map_location=device)
router_eval_plus.load_state_dict(ckpt)
router_eval_plus.eval()

# Top-1 evaluation
all_preds, all_labels = [], []
with torch.no_grad():
    for X, y in tqdm(val_loader, desc="Evaluating Top-1"):
        logits = router_eval_plus(X)
        preds = logits.argmax(-1)
        all_preds.extend(preds.cpu().tolist())
        all_labels.extend(y.cpu().tolist())

# Pass the full label set explicitly. Without `labels=`, sklearn infers the classes from the data,
# so classification_report raises ValueError (target_names size mismatch) and confusion_matrix
# shrinks/re-indexes whenever some category is absent from both labels and predictions.
print("==== Top-1 classification report ====")
print(classification_report(all_labels, all_preds, labels=list(range(len(gate_categories))), target_names=gate_categories))
print("Confusion matrix:")
print(confusion_matrix(all_labels, all_preds, labels=list(range(len(gate_categories)))))

# Top-K accuracy metrics (K=1,2,3)
@torch.no_grad()
def eval_topk(model, val_ds, Ks=(1,2,3)):
    X = val_ds.X.to(device=device, dtype=dtype)
    y = val_ds.y.to(device)
    logits = model(X)
    probs = torch.softmax(logits, dim=-1)
    results = {}
    for K in Ks:
        topk = probs.topk(K, dim=-1).indices
        acc = (topk == y.unsqueeze(1)).any(dim=1).float().mean().item()
        results[f"Top-{K}"] = acc
    return results

# Top-K accuracies of router_eval_plus on the full cached validation set.
# Keys are "Top-1", "Top-2", "Top-3" (default Ks); the summary below reads them.
topk_scores = eval_topk(router_eval_plus, val_ds)
print("==== Top-K Accuracies ====")
for k, v in topk_scores.items():
    print(f"{k}: {v:.4f}")

# Temperature calibration on validation set
@torch.no_grad()
def tune_temperature(model, val_loader, t_min=0.5, t_max=3.0, steps=26):
    """Grid-search the softmax temperature T that minimises validation NLL.

    Args:
        model: callable mapping a batch of inputs to class logits.
        val_loader: iterable of (X, y) batches with integer class labels y.
        t_min, t_max, steps: candidate grid, torch.linspace(t_min, t_max, steps).

    Returns:
        The best T as a Python float (first grid point on ties), or 1.0 (no scaling)
        when `val_loader` yields no batches.
    """
    # Logits do not depend on T, so run the model once instead of once per grid point.
    logit_chunks, label_chunks = [], []
    for X, y in val_loader:
        logit_chunks.append(model(X))
        label_chunks.append(y)
    if not logit_chunks:
        return 1.0
    logits = torch.cat(logit_chunks)
    labels = torch.cat(label_chunks).to(logits.device)

    best_T, best_nll = 1.0, float("inf")
    # Build the grid where the logits live instead of relying on a global `device`.
    for T in torch.linspace(t_min, t_max, steps=steps, device=logits.device):
        nll = F.cross_entropy(logits / T, labels, reduction="sum").item()
        if nll < best_nll:
            best_nll, best_T = nll, T.item()
    return best_T

# Fit T* for router_eval_plus; the calibrated thresholded-gate run below divides logits by it.
T_star = tune_temperature(router_eval_plus, val_loader)
print(f"==== Calibrated temperature T* = {T_star:.3f} ====")

# Thresholded-gate evaluation (Top-1 unless uncertain)
@torch.no_grad()
def eval_thresholded_gate(model, val_ds, tau=0.6, top_k_fallback=2, T=1.0):
    X = val_ds.X.to(device=device, dtype=dtype)
    y = val_ds.y.to(device)
    logits = model(X) / T
    probs = torch.softmax(logits, dim=-1)
    pmax, imax = probs.max(dim=-1)
    topk = probs.topk(top_k_fallback, dim=-1).indices
    preds_thresh = []
    for i in range(len(y)):
        if pmax[i] >= tau:
            preds_thresh.append(imax[i].item())
        else:
            if y[i].item() in topk[i]:
                preds_thresh.append(y[i].item())
            else:
                preds_thresh.append(imax[i].item())
    acc = np.mean(np.array(preds_thresh) == y.cpu().numpy())
    print(f"[τ={tau:.2f}, K={top_k_fallback}] adaptive Top-1/Top-{top_k_fallback} accuracy: {acc:.4f}")
    return acc

# Thresholded gate for RouterMLPPlus, uncalibrated (T=1) and with the fitted T*.
# NOTE(review): T_star was fit on val_loader and the gate is scored on val_ds — presumably the same
# validation split, in which case the calibrated number is optimistic; confirm or use a held-out split.
acc_uncal = eval_thresholded_gate(router_eval_plus, val_ds, tau=0.6, top_k_fallback=2, T=1.0)
acc_cal   = eval_thresholded_gate(router_eval_plus, val_ds, tau=0.6, top_k_fallback=2, T=T_star)

print("==== Summary ====")
print(f"Top-1 accuracy:   {topk_scores['Top-1']:.4f}")
print(f"Top-2 accuracy:   {topk_scores['Top-2']:.4f}")
print(f"Top-3 accuracy:   {topk_scores['Top-3']:.4f}")
print(f"Thresholded acc (uncalibrated): {acc_uncal:.4f}")
print(f"Thresholded acc (calibrated T*): {acc_cal:.4f}")
print("Note: Higher Top-K or calibrated thresholded acc → better MoLoRA routing coverage.")


Evaluating Top-1: 100%|██████████| 91/91 [00:00<00:00, 534.09it/s]


==== Top-1 classification report ====
                             precision    recall  f1-score   support

                 Appliances       0.93      0.99      0.96       229
     Arts_Crafts_and_Sewing       0.82      0.95      0.88       126
                 Automotive       0.80      0.87      0.84       151
              Baby_Products       0.93      0.98      0.95       333
                      Books       0.64      0.96      0.77       470
              CDs_and_Vinyl       0.95      0.95      0.95       381
Cell_Phones_and_Accessories       0.96      0.97      0.97       357
              Digital_Music       0.91      0.94      0.92       173
                Electronics       0.83      0.92      0.87       399
   Grocery_and_Gourmet_Food       0.89      0.95      0.92       286
          Handmade_Products       0.86      0.92      0.89       285
                     Health       0.74      0.83      0.78       180
  Industrial_and_Scientific       0.88      0.81      0.84      