Weights & Biases

In [1]:
# This is to access the WANDB_API_KEY secret
from google.colab import userdata

In [2]:
!pip install wandb weave

Collecting weave
  Downloading weave-0.52.9-py3-none-any.whl.metadata (27 kB)
Collecting diskcache==5.6.3 (from weave)
  Downloading diskcache-5.6.3-py3-none-any.whl.metadata (20 kB)
Collecting eval-type-backport (from weave)
  Downloading eval_type_backport-0.2.2-py3-none-any.whl.metadata (2.2 kB)
Collecting gql[aiohttp,requests] (from weave)
  Downloading gql-4.0.0-py3-none-any.whl.metadata (10 kB)
Collecting polyfile-weave (from weave)
  Downloading polyfile_weave-0.5.7-py3-none-any.whl.metadata (7.6 kB)
Collecting graphql-core<3.3,>=3.2 (from gql[aiohttp,requests]->weave)
  Downloading graphql_core-3.2.6-py3-none-any.whl.metadata (11 kB)
Collecting backoff<3.0,>=1.11.1 (from gql[aiohttp,requests]->weave)
  Downloading backoff-2.2.1-py3-none-any.whl.metadata (14 kB)
Collecting abnf~=2.2.0 (from polyfile-weave->weave)
  Downloading abnf-2.2.0-py3-none-any.whl.metadata (1.1 kB)
Collecting cint>=1.0.0 (from polyfile-weave->weave)
  Downloading cint-1.0.0-py3-none-any.whl.metadata (511 

In [1]:
import os, wandb
from google.colab import userdata

api_key = userdata.get('WANDB_API_KEY')
if not api_key:
    raise ValueError("Colab secret 'WANDB_API_KEY' not found. Add it in Colab > ⚙️ > User secrets.")

wandb.login(key=api_key)

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33malice-chua[0m ([33malice-chua-university-of-toronto[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

# Model training & Logging on Wandb

In [2]:
!pip -q install accelerate datasets scikit-learn

In [3]:
!pip install --no-cache-dir "transformers==4.57.0"



In [4]:
!pip -q install onnx onnxruntime

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m18.2/18.2 MB[0m [31m57.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.4/17.4 MB[0m [31m81.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.0/46.0 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.8/86.8 kB[0m [31m4.8 MB/s[0m eta [36m0:00:00[0m
[?25h

In [5]:
import torch, platform
print("PyTorch:", torch.__version__, "| CUDA:", torch.cuda.is_available(), "| Device:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
print("Python:", platform.python_version())

PyTorch: 2.8.0+cu126 | CUDA: True | Device: Tesla T4
Python: 3.12.12


In [6]:
import transformers, datasets, accelerate
print("transformers:", transformers.__version__)
print("datasets:", datasets.__version__)
print("accelerate:", accelerate.__version__)

transformers: 4.57.0
datasets: 4.0.0
accelerate: 1.10.1


In [7]:
# Install once per session
!pip -q install gdown

import os, pathlib, zipfile, gdown

import re, os, zipfile, pathlib, gdown, requests

BASE = "/content"
DL_DIR = f"{BASE}/data"
os.makedirs(DL_DIR, exist_ok=True)

def extract_id(url: str) -> str:
    # works for /file/d/<id>/view and ...id=<id> formats
    m = re.search(r"/d/([a-zA-Z0-9_-]+)", url)
    if m: return m.group(1)
    m = re.search(r"[?&]id=([a-zA-Z0-9_-]+)", url)
    if m: return m.group(1)
    raise ValueError("Could not extract file id from URL")

def download_drive_csv(url: str, out_csv_path: str) -> str:
    """Download a Drive *file* (csv/zip) or export a Google Sheet to CSV."""
    # Case A: Google Sheets -> export
    if "docs.google.com/spreadsheets/" in url:
        # Build export URL
        sheet_id = extract_id(url)
        export_url = f"https://docs.google.com/spreadsheets/d/{sheet_id}/export?format=csv"
        r = requests.get(export_url)
        r.raise_for_status()
        with open(out_csv_path, "wb") as f:
            f.write(r.content)
        return out_csv_path

    # Case B: Regular Drive file -> use id + force extension
    file_id = extract_id(url)
    tmp_path = out_csv_path  # ensure it ends with .csv or .zip
    path = gdown.download(id=file_id, output=tmp_path, quiet=False)
    if path is None:
        raise RuntimeError(f"Download failed for: {url}")

    # If ZIP, unzip and return first CSV
    if path.endswith(".zip"):
        with zipfile.ZipFile(path, "r") as z:
            z.extractall(os.path.dirname(path))
        csvs = sorted(pathlib.Path(os.path.dirname(path)).rglob("*.csv"))
        if not csvs:
            raise RuntimeError(f"No CSV files found after unzipping {path}")
        return str(csvs[0])

    # If CSV, return it
    if path.endswith(".csv"):
        return path

    # Detect HTML (permission/quota page)
    with open(path, "rb") as f:
        head = f.read(256).lower()
    if b"<html" in head:
        raise RuntimeError(
            "Downloaded an HTML page (likely permission/quota issue). "
            "Make sure the Drive file is shared as 'Anyone with the link' or openable by your account."
        )

    # Last resort: if the content is CSV but filename lacks extension, just rename it
    try:
        # quick sniff: try reading a few bytes and see commas/newlines
        if b"," in head or b"\n" in head:
            new_path = path + ".csv"
            os.rename(path, new_path)
            return new_path
    except Exception:
        pass

    raise RuntimeError(f"Downloaded file is not a .csv or .zip: {path}")

# ==== PUT YOUR LINKS HERE ====
TRAIN_URL = "https://drive.google.com/file/d/1UOKDzTjzT1wgMNrfo_uqI2L6J5bK5rbe/view?usp=sharing"
TEST_URL  = "https://drive.google.com/file/d/1-raSWDL-DDLDR_oDcelnrLJJKsHShEJF/view?usp=sharing"
# =============================
TRAIN_PATH = download_drive_csv(TRAIN_URL, os.path.join(DL_DIR, "train.csv"))
TEST_PATH  = download_drive_csv(TEST_URL,  os.path.join(DL_DIR, "test.csv"))

print("TRAIN_PATH:", TRAIN_PATH)
print("TEST_PATH :", TEST_PATH)

Downloading...
From: https://drive.google.com/uc?id=1UOKDzTjzT1wgMNrfo_uqI2L6J5bK5rbe
To: /content/data/train.csv
100%|██████████| 988k/988k [00:00<00:00, 8.23MB/s]
Downloading...
From: https://drive.google.com/uc?id=1-raSWDL-DDLDR_oDcelnrLJJKsHShEJF
To: /content/data/test.csv
100%|██████████| 421k/421k [00:00<00:00, 5.18MB/s]

TRAIN_PATH: /content/data/train.csv
TEST_PATH : /content/data/test.csv





In [8]:
# =========================
# Flexible HF text-classification trainer
# Swap models by editing MODEL_ID only
# =========================
import os, re, random, numpy as np, pandas as pd, wandb, math
from typing import Dict, Any
from dataclasses import dataclass

from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, accuracy_score, precision_recall_fscore_support

import torch
from transformers import (
    AutoTokenizer, AutoConfig, AutoModelForSequenceClassification,
    DataCollatorWithPadding, Trainer, TrainingArguments, EarlyStoppingCallback,
    set_seed
)
from datasets import Dataset

import json, onnx, onnxruntime as ort


# -------------------------
# Reproducibility
# -------------------------
SEED = 42
set_seed(SEED)

# -------------------------
# Data paths
# -------------------------
SUB_PATH   = "submission.csv"

# -------------------------
# Light tweet normalization
# -------------------------
URL_RE  = re.compile(r"https?://\S+|www\.\S+")
USER_RE = re.compile(r"@\w+")
def normalize_tweet(t: str) -> str:
    t = URL_RE.sub(" <url> ", t)
    t = USER_RE.sub(" <user> ", t)
    t = re.sub(r"\s+", " ", t).strip()
    return t

# -------------------------
# Load & build text
# -------------------------
df = pd.read_csv(TRAIN_PATH)
df["text"] = df["text"].astype(str).map(normalize_tweet)

def build_text(row):
    kw  = f" keyword: {row['keyword']}"   if isinstance(row.get("keyword"), str) else ""
    loc = f" location: {row['location']}" if isinstance(row.get("location"), str) else ""
    return f"{kw}{loc} text: {row['text']}"

df["text"] = df.apply(build_text, axis=1)

df_train, df_val = train_test_split(
    df, test_size=0.1, random_state=SEED, stratify=df["target"]
)

df_test = pd.read_csv(TEST_PATH)
df_test["text"] = df_test["text"].astype(str).map(normalize_tweet)
df_test["text"] = df_test.apply(build_text, axis=1)

# -------------------------
# Hugging Face Datasets
# -------------------------
train_ds = Dataset.from_pandas(df_train[["text", "target"]].reset_index(drop=True))
val_ds   = Dataset.from_pandas(df_val[["text", "target"]].reset_index(drop=True))
test_ds  = Dataset.from_pandas(df_test[["id", "text"]].reset_index(drop=True))

# =========================
# Flexible model swap
# =========================
NUM_LABELS = 2

#@title Choose model (edit freely)
MODEL_ID = "google/electra-base-discriminator" #@param {type:"string"}# Examples you can try:
# "roberta-base"
# "microsoft/deberta-v3-base"
# "google/electra-base-discriminator"
# "xlm-roberta-base"
# "bert-base-uncased"

# Recommended per-model defaults (override-able below)
RECS: Dict[str, Dict[str, Any]] = {
    "distilbert": {"max_len": 128, "lr": 1e-5,  "batch": 32},
    "bert":       {"max_len": 256, "lr": 2e-5,  "batch": 32},
    "roberta":    {"max_len": 256, "lr": 2e-5,  "batch": 32},
    "deberta":    {"max_len": 256, "lr": 2e-5,  "batch": 24},
    "electra":    {"max_len": 256, "lr": 2e-5,  "batch": 32},
    "xlm-roberta":{"max_len": 256, "lr": 2e-5,  "batch": 24},
}

def short_name(model_id: str) -> str:
    base = model_id.split("/")[-1]
    # strip size-specific suffixes for routing
    for k in RECS:
        if k in model_id.lower() or k in base.lower():
            return base
    return base

def family_key(model_id: str) -> str:
    lid = model_id.lower()
    for k in RECS:
        if k in lid:
            return k
    # fallback: treat as bert-like
    return "bert"

fam = family_key(MODEL_ID)
defaults = RECS[fam]

# -------------------------
# Train hyperparams (override here if you like)
# -------------------------
EPOCHS = int(os.environ.get("EPOCHS", 1))
BATCH  = int(os.environ.get("BATCH",  defaults["batch"]))
LR     = float(os.environ.get("LR",   defaults["lr"]))
WARMUP_RATIO = float(os.environ.get("WARMUP_RATIO", 0.1))
MAX_LEN = int(os.environ.get("MAX_LEN", defaults["max_len"]))

# -------------------------
# Load tokenizer/config/model with safe defaults
# -------------------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)

# Handle missing pad token (e.g., some RoBERTa variants are fine; causal LMs are not supported here)
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
    tokenizer.pad_token = tokenizer.eos_token

config = AutoConfig.from_pretrained(MODEL_ID, num_labels=NUM_LABELS)
model  = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, config=config)

# -------------------------
# Tokenization
# -------------------------
def tok_fn(batch):
    return tokenizer(
        batch["text"],
        truncation=True,
        padding=False,
        max_length=MAX_LEN,
    )

train_ds = train_ds.map(tok_fn, batched=True, remove_columns=["text"])
val_ds   = val_ds.map(tok_fn, batched=True, remove_columns=["text"])
test_ds  = test_ds.map(tok_fn, batched=True, remove_columns=["text"])

train_ds = train_ds.rename_column("target", "labels")
val_ds   = val_ds.rename_column("target", "labels")

collator = DataCollatorWithPadding(tokenizer=tokenizer)

# -------------------------
# Metrics
# -------------------------
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = (logits.argmax(axis=-1)).astype(int)
    return {
        "accuracy": accuracy_score(labels, preds),
        "f1": f1_score(labels, preds)
    }

# -------------------------
# W&B
# -------------------------
run_name = f"{short_name(MODEL_ID)}-baseline"
wandb_run = wandb.init(
    entity="alice-chua-university-of-toronto",
    project="disaster-tweets",
    name=run_name,
    config=dict(
        model=MODEL_ID,
        seed=SEED,
        max_len=MAX_LEN,
        batch=BATCH,
        epochs=EPOCHS,
        lr=LR,
        warmup_ratio=WARMUP_RATIO,
        weight_decay=0.01,
    ),
)

# -------------------------
# Training arguments
# -------------------------
out_dir = f"./checkpoints/{short_name(MODEL_ID)}"
args = TrainingArguments(
    output_dir=out_dir,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,

    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH,
    per_device_eval_batch_size=BATCH,
    learning_rate=LR,
    weight_decay=0.01,
    warmup_ratio=WARMUP_RATIO,
    max_grad_norm=1.0,
    fp16=torch.cuda.is_available(),
    logging_steps=50,
    save_total_limit=2,
    report_to=["wandb"],
    run_name=run_name,

    # Nice-to-haves for larger models
    gradient_checkpointing=False,      # flip True for big DeBERTa if needed
    bf16=torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8,
)

# -------------------------
# Trainer
# -------------------------
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=tokenizer,
    data_collator=collator,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]
)

trainer.train()
metrics = trainer.evaluate()
print("Validation metrics:", metrics)

# -------------------------
# Threshold tuning on val
# -------------------------
val_pred = trainer.predict(val_ds)
logits = val_pred.predictions
y_true = val_ds["labels"]

probs = torch.softmax(torch.tensor(logits), dim=1).numpy()[:, 1]

best = {"f1": -1, "thr": 0.5, "prec": 0.0, "rec": 0.0}
for thr in np.linspace(0.2, 0.8, 61):
    y_pred = (probs >= thr).astype(int)
    prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="binary", zero_division=0)
    if f1 > best["f1"] or (np.isclose(f1, best["f1"]) and rec > best["rec"]):
        best = {"f1": float(f1), "thr": float(thr), "prec": float(prec), "rec": float(rec)}
print(f"Best threshold = {best['thr']:.3f} | F1={best['f1']:.4f} (P={best['prec']:.4f}, R={best['rec']:.4f})")

# -------------------------
# Inference on test & submission
# -------------------------
test_logits = trainer.predict(test_ds).predictions
test_probs = torch.softmax(torch.tensor(test_logits), dim=1).numpy()[:, 1]
test_labels = (test_probs >= best["thr"]).astype(int)

sub = pd.DataFrame({"id": df_test["id"], "target": test_labels})
sub.to_csv(SUB_PATH, index=False)
print(f"Saved {SUB_PATH} with tuned threshold {best['thr']:.3f}")

# -------------------------
# Log a PyTorch “training” artifact (safetensors + config + tokenizer)
# -------------------------
from pathlib import Path

BEST_DIR = out_dir  # or the actual best checkpoint dir you resolved
PT_PACK = Path("./artifacts_pt"); PT_PACK.mkdir(parents=True, exist_ok=True)

# Save tokenizer and model (safetensors)
tokenizer.save_pretrained(PT_PACK / "tokenizer")
model.save_pretrained(PT_PACK / "pytorch", safe_serialization=True)  # writes model.safetensors + config.json

# Metadata for the *training* artifact
arch_name = getattr(model.config, "_name_or_path", MODEL_ID)
vocab_sz  = getattr(tokenizer, "vocab_size", None) or len(tokenizer.get_vocab())
dtype_str = str(next(model.parameters()).dtype).replace("torch.", "")
ctx_len   = MAX_LEN

train_meta = {
    "stage": "training",
    "format": "pytorch",
    "architecture": arch_name,
    "context_length": ctx_len,
    "vocab_size": int(vocab_sz),
    "dtype": dtype_str,
    "num_labels": int(model.config.num_labels),
    # eval metrics
    "best_f1": float(best["f1"]),
    "best_threshold": float(best["thr"]),
    "accuracy": float(metrics.get("eval_accuracy", 0.0)),
    "f1_eval": float(metrics.get("eval_f1", 0.0)),
}

train_art = wandb.Artifact(
    name="disaster-tweet-classifier-training",
    type="model",
    metadata=train_meta,
)
train_art.add_dir(str(PT_PACK))
train_logged = wandb_run.log_artifact(train_art, aliases=["training", "candidate"])
# =========================
# ONNX export + latency probe + rich metadata
# =========================
import os, json, time, numpy as np, torch, onnx, onnxruntime as ort
from pathlib import Path

onnx_dir  = f"./onnx_exports/{short_name(MODEL_ID)}"
Path(onnx_dir).mkdir(parents=True, exist_ok=True)
onnx_path = os.path.join(onnx_dir, f"{short_name(MODEL_ID)}.onnx")

# Keep track of original device, then export on CPU
orig_device = next(model.parameters()).device if any(p.requires_grad for p in model.parameters()) else torch.device("cpu")
model_cpu = model.to("cpu").eval()

# --- Build a NON-constant example (two lengths) to avoid folding masks away ---
example = tokenizer(
    ["short", "a much longer example that forces padding differences here"],
    max_length=MAX_LEN, padding=True, truncation=True, return_tensors="pt"
)

# If your family doesn't use token_type_ids (e.g., DeBERTa/RoBERTa), drop it from export
if "deberta" in MODEL_ID.lower():
    example.pop("token_type_ids", None)

# Use the keys actually present in the example
input_names = list(example.keys())              # e.g. ['input_ids','attention_mask']
dynamic_axes = {k: {0: "batch", 1: "sequence"} for k in input_names}
dynamic_axes["logits"] = {0: "batch"}

# Ensure types
example = {k: v.long() for k, v in example.items()}

# Warmup
with torch.inference_mode():
    _ = model_cpu(**example)

# Export (set dynamo=True if you want the new exporter; leave as False to keep legacy path)
EXPORTER_FLAGS = {"do_constant_folding": True, "dynamo": False}
torch.onnx.export(
    model_cpu,
    args=tuple(example[k] for k in input_names),
    f=onnx_path,
    input_names=input_names,
    output_names=["logits"],
    dynamic_axes=dynamic_axes,
    opset_version=17,
    do_constant_folding=EXPORTER_FLAGS["do_constant_folding"],
    # dynamo=EXPORTER_FLAGS["dynamo"],  # uncomment if your torch supports it and you want to opt in
)

# Validate ONNX
onnx.checker.check_model(onnx.load(onnx_path))

# -------------------------
# ORT latency probe (CPU)
# -------------------------
sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
sess_input_names = [i.name for i in sess.get_inputs()]
print("ONNX inputs:", sess_input_names)

# Build a tiny batch for timing
enc_np = tokenizer(
    ["probe one", "a longer probe that forces padding"],
    max_length=MAX_LEN, truncation=True, padding=True, return_tensors="np"
)
feed = {}
for n in sess_input_names:
    if n in enc_np:
        feed[n] = enc_np[n].astype(np.int64)
    elif n == "token_type_ids":
        feed[n] = np.zeros_like(enc_np["input_ids"], dtype=np.int64)
    else:
        raise KeyError(f"Missing required ONNX input: {n}")

# Warmup
for _ in range(5):
    _ = sess.run(["logits"], feed)

# Time a few runs
N = 50
t0 = time.time()
for _ in range(N):
    _ = sess.run(["logits"], feed)
lat_ms = (time.time() - t0) * 1000.0 / N  # ms per run for this batch on CPU
print(f"ORT latency (CPU, batch={list(feed.values())[0].shape[0]}): {lat_ms:.2f} ms")

# -------------------------
# Save tokenizer + labels alongside ONNX (for ONNX-only inference)
# -------------------------
tok_dir = os.path.join(onnx_dir, "tokenizer"); os.makedirs(tok_dir, exist_ok=True)
tokenizer.save_pretrained(tok_dir)
with open(os.path.join(onnx_dir, "id2label.json"), "w") as f:
    json.dump({int(k): v for k, v in model.config.id2label.items()}, f)

# -------------------------
# Rich metadata for the artifact
# -------------------------
arch_name = getattr(model.config, "_name_or_path", MODEL_ID)
vocab_sz  = getattr(tokenizer, "vocab_size", None) or len(tokenizer.get_vocab())
dtype_str = str(next(model.parameters()).dtype).replace("torch.", "")
ctx_len   = MAX_LEN

onnx_meta = {
    "stage": "inference",
    "format": "onnx",
    "opset": 17,
    "exporter": "torch.onnx",
    "export_flags": EXPORTER_FLAGS,   # e.g., {'do_constant_folding': True, 'dynamo': False}
    "architecture": arch_name,
    "context_length": int(ctx_len),
    "vocab_size": int(vocab_sz),
    "dtype": dtype_str,
    "num_labels": int(model.config.num_labels),
    # Copy your key eval metrics for convenience
    "best_f1": float(best["f1"]),
    "best_threshold": float(best["thr"]),
    "eval_accuracy": float(metrics.get("eval_accuracy", 0.0)),
    "eval_f1": float(metrics.get("eval_f1", 0.0)),
    # Simple latency probe details (document hardware in notes if you want)
    "latency_ms_cpu_batch": float(lat_ms),
}

# -------------------------
# W&B: log ONNX artifact with rich metadata + useful aliases
# -------------------------
onnx_art = wandb.Artifact(
    name="disaster-tweet-classifier-onnx",  # use a stable, clear name
    type="model",
    metadata=onnx_meta,
)
onnx_art.add_file(onnx_path)   # the .onnx graph
onnx_art.add_dir(tok_dir)      # tokenizer/ + id2label.json
onnx_logged = wandb_run.log_artifact(onnx_art, aliases=["inference", "candidate"])
if hasattr(onnx_logged, "wait"): onnx_logged.wait()

# Return model to its original device (resume PT inference if needed)
model = model.to(orig_device)


# -------------------------
# Link artifacts into the Model Registry
# -------------------------
ENTITY = "alice-chua-university-of-toronto-org"
TARGET = "wandb-registry-model/disaster-tweet-model-registry"  # collection name only

# 1) Link ONNX artifact with inference-oriented aliases
wandb_run.link_artifact(
    artifact=onnx_logged,
    target_path=TARGET,
    aliases=["inference", "candidate"]
)

# 2) Link PyTorch "training/source-of-truth" artifact if available
#    (Assumes you previously created and logged it as `train_logged`.)
try:
    if 'train_logged' in globals() and train_logged is not None:
        wandb_run.link_artifact(
            artifact=train_logged,
            target_path=TARGET,
            aliases=["training", "source-of-truth"]
        )
except Exception as e:
    print(f"(Note) Could not link training artifact: {e}")

# (Optional) Promote one to production (only one artifact can hold 'production' at a time)
wandb_run.link_artifact(
    artifact=onnx_logged,   # or `train_logged` if you want PT to be prod
    target_path=TARGET,
    aliases=["production"]
)

# -------------------------
# Optional: metric-based promotion helper (kept from your version)
# -------------------------
TIE_BREAK_GE = True
api = wandb.Api()
import numpy as np

METRIC = "best_f1"

def _get_registry_artifact(alias: str, target: str = TARGET, entity: str = ENTITY):
    path = f"{entity}/{target}:{alias}"
    try:
        return api.artifact(path)
    except Exception as e:
        print(f"[WARN] Failed to fetch '{alias}' at {path}: {e!r}")
        return None

def _to_float(x, name="value"):
    try:
        return float(x)
    except Exception:
        raise ValueError(f"Cannot parse {name}={x!r} as float")

def decide_and_transition(
    current_artifact: Any,
    metric_key: str = METRIC,
    promote_on_tie: bool = True,
    target: str = TARGET,
):
    """Promote better artifact to 'production'. No 'archived' aliasing."""
    if hasattr(current_artifact, "wait"):
        current_artifact.wait()

    cur_val_raw = (getattr(current_artifact, "metadata", None) or {}).get(metric_key)
    cur_val = _to_float(cur_val_raw, f"current.{metric_key}")

    prev_prod = _get_registry_artifact("production", target=target)
    prev_val = None
    if prev_prod is not None:
        prev_val_raw = (getattr(prev_prod, "metadata", None) or {}).get(metric_key)
        prev_val = float("-inf") if prev_val_raw is None else _to_float(prev_val_raw, f"production.{metric_key}")

    better = (prev_prod is None) or (cur_val > prev_val) or (promote_on_tie and np.isclose(cur_val, prev_val))

    if better:
        wandb_run.link_artifact(artifact=current_artifact, target_path=target, aliases=["production"])
        print(f"✅ Promoted to PRODUCTION. new {metric_key}={cur_val:.4f}"
              + ("" if prev_val is None else f" vs old {prev_val:.4f}"))
        try:
            kept = [a for a in (current_artifact.aliases or []) if a != "candidate"]
            current_artifact.aliases = sorted(set(kept + ["production"]))
            current_artifact.save()
        except Exception as e:
            print(f"(Note) Could not edit project aliases on current: {e}")
    else:
        print(f"ℹ️ Kept existing PRODUCTION (old {metric_key}={prev_val:.4f} ≥ new {cur_val:.4f}).")

print("ENTITY/TARGET:", ENTITY, TARGET)
print("artifact name (onnx):", onnx_logged.name)
print("artifact type:", getattr(onnx_logged, "type", None))  # must be "model"

# If you still want metric-driven promotion, call with whichever artifact you want to consider:
decide_and_transition(onnx_logged, METRIC)

wandb.finish()

Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/666 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/440M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of ElectraForSequenceClassification were not initialized from the model checkpoint at google/electra-base-discriminator and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Map:   0%|          | 0/6851 [00:00<?, ? examples/s]

Map:   0%|          | 0/762 [00:00<?, ? examples/s]

Map:   0%|          | 0/3263 [00:00<?, ? examples/s]

  trainer = Trainer(


Epoch,Training Loss,Validation Loss,Accuracy,F1
1,0.4148,0.416654,0.833333,0.806697


Validation metrics: {'eval_loss': 0.4166543185710907, 'eval_accuracy': 0.8333333333333334, 'eval_f1': 0.806697108066971, 'eval_runtime': 0.8123, 'eval_samples_per_second': 938.032, 'eval_steps_per_second': 29.544, 'epoch': 1.0}
Best threshold = 0.500 | F1=0.8067 (P=0.8030, R=0.8104)


Saved submission.csv with tuned threshold 0.500


[34m[1mwandb[0m: Adding directory to artifact (artifacts_pt)... Done. 2.1s
  torch.onnx.export(


ONNX inputs: ['input_ids', 'token_type_ids', 'attention_mask']
ORT latency (CPU, batch=2): 49.03 ms


[34m[1mwandb[0m: Adding directory to artifact (onnx_exports/electra-base-discriminator/tokenizer)... Done. 0.0s


ENTITY/TARGET: alice-chua-university-of-toronto-org wandb-registry-model/disaster-tweet-model-registry
artifact name (onnx): disaster-tweet-classifier-onnx:v1
artifact type: model
✅ Promoted to PRODUCTION. new best_f1=0.8067 vs old 0.8067


0,1
eval/accuracy,▁▁
eval/f1,▁▁
eval/loss,▁▁
eval/runtime,▁█
eval/samples_per_second,█▁
eval/steps_per_second,█▁
test/accuracy,▁
test/f1,▁
test/loss,▁
test/runtime,▁█

0,1
eval/accuracy,0.83333
eval/f1,0.8067
eval/loss,0.41665
eval/runtime,0.8123
eval/samples_per_second,938.032
eval/steps_per_second,29.544
test/accuracy,0.83333
test/f1,0.8067
test/loss,0.41665
test/runtime,3.374


# Pull latest model from Wandb and run as chatbot

In [9]:
# =========================
# Pull @production from W&B Model Registry & run an interactive classifier
# =========================
!pip -q install wandb transformers onnxruntime >/dev/null

import os, re, json, pathlib, numpy as np, torch, wandb

# ----- Configure your registry coordinates -----
ENTITY = "alice-chua-university-of-toronto-org"  # org/user that owns the registry
TARGET = "wandb-registry-model/disaster-tweet-model-registry"  # collection (no entity here)
ALIAS  = "production"

api = wandb.Api()
art = api.artifact(f"{ENTITY}/{TARGET}:{ALIAS}")
local_dir = art.download()
meta = art.metadata or {}
fmt  = (meta.get("format") or "").lower()
print("Downloaded artifact:", art.name, "| type:", getattr(art, "type", "?"), "| format:", fmt)
print("Local dir:", local_dir)

# ----- helpers -----
def _find_one(root: str, pattern: str) -> str:
    p = list(pathlib.Path(root).rglob(pattern))
    return str(p[0]) if p else ""

def resolve_model_name(local_dir: str, meta: dict) -> str:
    # 1) prefer explicit metadata from training/logging
    if meta and meta.get("model_id"):
        return meta["model_id"]
    # 2) try HF config if present (PyTorch export)
    cfg_path = os.path.join(local_dir, "config.json")
    if os.path.exists(cfg_path):
        try:
            with open(cfg_path, "r", encoding="utf-8") as f:
                cfg = json.load(f)
            # HF usually stores the original model id here
            if "_name_or_path" in cfg and cfg["_name_or_path"]:
                return cfg["_name_or_path"]
            # fall back to model_type if name_or_path missing
            if "model_type" in cfg:
                return cfg["model_type"]
        except Exception:
            pass
    # 3) ONNX-only: look for a hint file we saved
    hint = os.path.join(local_dir, "tokenizer", "special_tokens_map.json")
    if os.path.exists(hint):
        try:
            from transformers import AutoTokenizer
            tok = AutoTokenizer.from_pretrained(os.path.dirname(hint), use_fast=True)
            if getattr(tok, "name_or_path", None):
                return tok.name_or_path
        except Exception:
            pass
    # 4) last resort: artifact name or folder basename
    return getattr(art, "name", os.path.basename(local_dir))

MODEL_NAME = resolve_model_name(local_dir, meta)
print("Resolved model name:", MODEL_NAME)


# Light tweet normalization (same as training)
URL_RE  = re.compile(r"https?://\S+|www\.\S+")
USER_RE = re.compile(r"@\w+")
def normalize_tweet(t: str) -> str:
    t = URL_RE.sub(" <url> ", str(t))
    t = USER_RE.sub(" <user> ", t)
    t = re.sub(r"\s+", " ", t).strip()
    return t

# Label mapping (fallback if not present in config)
DEFAULT_ID2LABEL = {0: "not disaster", 1: "disaster"}

predict_impl = None
IDX2LABEL = DEFAULT_ID2LABEL

# ===== Branch A: ONNX artifact =====
onnx_path = _find_one(local_dir, "*.onnx")
if fmt == "onnx" or onnx_path:
    import onnxruntime as ort
    from transformers import AutoTokenizer

    tok_dir = _find_one(local_dir, "tokenizer") or local_dir
    tokenizer = AutoTokenizer.from_pretrained(tok_dir, use_fast=True)

    # try to load id2label if present
    id2label_path = _find_one(local_dir, "id2label.json")
    if id2label_path:
        with open(id2label_path, "r") as f:
            raw = json.load(f)
        IDX2LABEL = {int(k): v for k, v in raw.items()}

    sess = ort.InferenceSession(onnx_path or _find_one(local_dir, "*.onnx"), providers=["CPUExecutionProvider"])
    sess_input_names = [i.name for i in sess.get_inputs()]
    print("ONNX inputs:", sess_input_names)

    def predict_one(text: str, max_len: int = int(meta.get("max_len", 256))):
        text = normalize_tweet(text)
        enc = tokenizer([text], max_length=max_len, truncation=True, padding=True, return_tensors="np")

        # --- feed ONLY what the graph declares; synthesize token_type_ids if required ---
        feed = {}
        for name in sess_input_names:
            if name in enc:
                feed[name] = enc[name].astype(np.int64)
            elif name == "token_type_ids":
                feed[name] = np.zeros_like(enc["input_ids"], dtype=np.int64)
            else:
                # If the graph expects a name we didn't create, fail loudly
                raise KeyError(f"Missing required ONNX input: {name}")

        logits = sess.run(["logits"], feed)[0]
        probs = torch.softmax(torch.tensor(logits), dim=-1).numpy()[0]
        pred = int(np.argmax(probs))
        return pred, probs

    print("Loaded ONNX @production ✅ | model:", MODEL_NAME)

# ===== Branch B: PyTorch HF directory =====
else:
    from transformers import AutoTokenizer, AutoModelForSequenceClassification

    tokenizer = AutoTokenizer.from_pretrained(local_dir, use_fast=True)
    model = AutoModelForSequenceClassification.from_pretrained(local_dir)
    if getattr(model.config, "id2label", None):
        IDX2LABEL = {int(k): v for k, v in model.config.id2label.items()}

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device).eval()

    @torch.inference_mode()
    def predict_one(text: str, max_len: int = 256):
        text = normalize_tweet(text)
        enc = tokenizer(text, truncation=True, max_length=max_len, padding=True, return_tensors="pt")
        enc = {k: v.to(device) for k, v in enc.items()}
        logits = model(**enc).logits
        probs = torch.softmax(logits, dim=-1).squeeze(0).cpu().numpy()
        pred = int(np.argmax(probs))
        return pred, probs

    print("Loaded PyTorch HF @production ✅ | model:", MODEL_NAME)

# ----- Interactive loop -----
print("\nType a tweet and press Enter to classify. Type 'quit' (or empty line) to exit.\n")
while True:
    try:
        user_text = input("Tweet> ").strip()
    except EOFError:
        break
    if user_text.lower() in {"", "quit", "exit"}:
        print("Bye!")
        break
    pred, probs = predict_one(user_text, max_len=int(meta.get("max_len", 256)))
    if len(probs) == 2:
        print(f"Prediction: {IDX2LABEL.get(pred, pred)}  |  P({IDX2LABEL.get(0,'0')})={probs[0]:.3f}, P({IDX2LABEL.get(1,'1')})={probs[1]:.3f}\n")
    else:
        print(f"Prediction: {IDX2LABEL.get(pred, pred)}  |  probs={probs}\n")

[34m[1mwandb[0m: Downloading large artifact 'disaster-tweet-model-registry:production', 418.74MB. 5 files...
[34m[1mwandb[0m:   5 of 5 files downloaded.  
Done. 00:00:02.3 (179.0MB/s)


Downloaded artifact: disaster-tweet-model-registry:production | type: model | format: onnx
Local dir: /content/artifacts/disaster-tweet-classifier-onnx:v1
Resolved model name: disaster-tweet-model-registry:production
ONNX inputs: ['input_ids', 'token_type_ids', 'attention_mask']
Loaded ONNX @production ✅ | model: disaster-tweet-model-registry:production

Type a tweet and press Enter to classify. Type 'quit' (or empty line) to exit.

Tweet> hi
Prediction: not disaster  |  P(not disaster)=0.584, P(disaster)=0.416

Tweet> i am in a burning building
Prediction: disaster  |  P(not disaster)=0.410, P(disaster)=0.590

Tweet> help
Prediction: disaster  |  P(not disaster)=0.425, P(disaster)=0.575

Tweet> exirt
Prediction: not disaster  |  P(not disaster)=0.724, P(disaster)=0.276

Tweet> exit
Bye!
