In [None]:
# !pip install --upgrade transformers torch accelerate bitsandbytes torchvision torchaudio fastai

In [None]:
from datasets import load_dataset

# AG News topic-classification dataset; this notebook uses its 'train' and 'test' splits.
ds = load_dataset("fancyzhx/ag_news")
ds

In [None]:
ds["test"].select(range(5))  # first five test rows, shown via the cell's rich display

In [None]:
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LogitsProcessor,
    LogitsProcessorList,
    BitsAndBytesConfig
)

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# 3. Define a K‐Way LogitsProcessor
class KWayLogitsProcessor(LogitsProcessor):
    """Restrict next-token selection to a fixed set of K allowed token ids.

    Every vocabulary entry whose id is not in ``allowed_token_ids`` has its
    score replaced by ``-inf``. Greedy decoding or sampling can then only pick
    an allowed id, and a softmax over the returned scores is zero everywhere
    else.

    Args:
        allowed_token_ids: token ids that stay selectable (duplicates are ignored).
    """

    def __init__(self, allowed_token_ids: list[int]):
        super().__init__()
        self.allowed_ids = set(allowed_token_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # scores: (batch_size, vocab_size); allowed ids index the last dimension.
        # Build the index from self.allowed_ids on each call so later edits to the
        # set are honoured, and place it on the same device as the scores.
        index = torch.tensor(sorted(self.allowed_ids), dtype=torch.long, device=scores.device)
        mask = torch.zeros_like(scores, dtype=torch.bool)
        # One vectorized column assignment instead of one indexed write per id.
        mask[:, index] = True
        # ban everything else
        return scores.masked_fill(~mask, -float("inf"))

In [None]:
from huggingface_hub import login
from transformers import MistralCommonBackend, FineGrainedFP8Config, Mistral3ForConditionalGeneration

import os

# Read the Hugging Face token from the environment instead of hard-coding it:
# a literal token in a cell is saved with the notebook and leaks when shared.
# If HF_TOKEN is unset, login(token=None) asks for the token interactively.
access_token = os.environ.get("HF_TOKEN")
login(token=access_token)

# Other checkpoints tried with this notebook. They are loaded through
# AutoTokenizer / AutoModelForCausalLM (optionally with a 4-bit
# BitsAndBytesConfig) instead of the Mistral-specific classes used below:
# model_id = "google/gemma-3-1b-it"
# model_id = 'Qwen/Qwen3-0.6B'
# model_id = 'Qwen/Qwen3-4B-Instruct-2507'
# model_id = 'meta-llama/Llama-3.2-3B-Instruct'
#   tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side='left')
#   model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
model_id = 'mistralai/Ministral-3-3B-Instruct-2512'
tokenizer = MistralCommonBackend.from_pretrained(model_id)  # for ministral

model = Mistral3ForConditionalGeneration.from_pretrained(
    model_id,
    # torch_dtype=torch.bfloat16,
    device_map="auto",
    # dequantize=True asks FineGrainedFP8Config to load the weights dequantized
    # rather than as FP8 — confirm the resulting dtype on the target GPU.
    quantization_config=FineGrainedFP8Config(dequantize=True),  # for ministral
)

# Batched generation needs a pad token. Add one only if the tokenizer has none,
# and grow the embedding matrix so the new id has a row.
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    model.resize_token_embeddings(len(tokenizer))
# Alternative: tokenizer.pad_token = tokenizer.eos_token

In [None]:
def _tri(L_in, L_out):
    # attention sweep over prefill + generation with KV cache
    return (L_in * (L_in + 1) / 2.0) + (L_out * L_in) + ((L_out - 1) * L_out / 2.0)

def _arch_constants_from_config(cfg):
    """
    Infer c_mlp (d^2-heavy) and c_attn (d·L) from HF config.
    - c_mlp = (projections) + (MLP)
      projections: Q,O full-size + K,V reduced by r_kv = n_kv_heads / n_heads
        => proj_factor ≈ 2 + 2 * r_kv   (≈4 when MHA; smaller when GQA/MQA)
      MLP: use expansion r = intermediate_size / hidden_size
        GeLU-like => mlp_factor ≈ 2 * r
        SwiGLU/SiLU-like => mlp_factor ≈ 3 * r
    - c_attn: ~2 (QK^T + AV). Keep a small mid-range default.
    """
    d = getattr(cfg, "hidden_size", None)
    if d is None:
      d = getattr(cfg.text_config, "hidden_size")
    n = getattr(cfg, "num_hidden_layers", None)
    if n is None:
      n = getattr(cfg.text_config, "num_hidden_layers")
    inter = getattr(cfg, "intermediate_size", 4 * d)
    r = inter / d

    act = (getattr(cfg, "hidden_activation", "") or "").lower()
    if "swiglu" in act or "silu" in act or "swish" in act:
        mlp_factor = 3.0 * r
    else:
        mlp_factor = 2.0 * r

    h = getattr(cfg, "num_attention_heads", None)
    h_kv = getattr(cfg, "num_key_value_heads", h)
    r_kv = (h_kv / h) if (h and h_kv) else 1.0
    proj_factor = 2.0 + 2.0 * r_kv  # Q,O full (2) + K,V scaled by r_kv

    c_mlp = proj_factor + mlp_factor   # total d^2-heavy constant
    c_attn = 2.5                       # mild mid-point for attention kernels
    return n, d, c_mlp, c_attn

def query_cost(config, L_in: int, L_out: int, mode: str = "tflops",
               count_mac_as_2flop: bool = True):
    """
    Return a single scalar per-query compute cost.

    Args:
        config: HF model config. The layer count and hidden size are read from
            it or, when absent, from ``config.text_config``.
        L_in: prompt length in tokens.
        L_out: number of generated tokens.
        mode: "units"  -> unitless n * ((L_in + L_out) * d^2 + d * tri), with no
                          architecture constants.
              "tflops" -> architecture-weighted estimate (constants from
                          ``_arch_constants_from_config``), scaled by 1e-12.
        count_mac_as_2flop: tflops mode only; count each multiply-accumulate
            as two FLOPs.

    Raises:
        ValueError: if ``mode`` is neither "units" nor "tflops". Without this
            check, a typo would silently return the tflops estimate.
    """
    if mode not in ("units", "tflops"):
        raise ValueError(f"mode must be 'units' or 'tflops', got {mode!r}")

    # Attention pairs over prefill + KV-cached decoding.
    tri = _tri(L_in, L_out)

    if mode == "units":
        n = getattr(config, "num_hidden_layers", None)
        if n is None:
            n = getattr(config.text_config, 'num_hidden_layers')
        d = getattr(config, "hidden_size", None)
        if d is None:
            d = getattr(config.text_config, 'hidden_size')
        return n * ((L_in + L_out) * (d ** 2) + d * tri)

    # tflops: architecture-aware constants + MAC->FLOPs + scale
    n, d, c_mlp, c_attn = _arch_constants_from_config(config)
    a = n * c_mlp * (d ** 2)     # d^2-heavy piece (per token)
    b = n * c_attn * d           # d·L attention piece (per attended pair)
    flops = (L_in + L_out) * a + b * tri
    if count_mac_as_2flop:
        flops *= 2.0
    return flops / 1e12  # TFLOPs-ish


In [None]:
# Candidate answers for constrained 4-way decoding. Their order defines the
# class index (pred_idx) compared against the dataset labels; it is assumed to
# match AG News label ids 0..3 — verify against ds['train'].features['label'].
label_tokens = ['world', 'sports', 'business', 'technology']
print(label_tokens)
label_ids = []
for lbl in label_tokens:
    # Leading space: the answer follows the prompt text, so the word is expected
    # to be tokenized in its mid-sentence (space-prefixed) form.
    toks = tokenizer(f" {lbl}", add_special_tokens=False).input_ids
    # Check before indexing so an empty tokenization fails with a clear message.
    assert len(toks) == 1, f"'{lbl}' mapped to {len(toks)} tokens, expected exactly 1!"
    print(toks[0])
    label_ids.append(toks[0])

# Predictions and per-label probabilities are read by token id. Two labels that
# share an id would therefore be indistinguishable.
assert len(set(label_ids)) == len(label_ids), f"label tokens share token ids: {label_ids}"

processor = KWayLogitsProcessor(label_ids)

def prompt1(sentence):
    """Build a short instruction-style prompt asking for the topic of ``sentence``.

    The category names listed (World, Sports, Business, Technology) are the
    four answers the constrained decoder can produce. The article is wrapped
    in double quotes, and the prompt ends right after "### Response:".
    """
    return f"""### Input:
Which of the following categories (World, Sports, Business or Technology) does this news article belong to? "{sentence}"
### Response:"""

def prompt2(sentence):
    """Build the detailed four-way news-topic classification prompt.

    The text describes the World / Sports / Business / Technology categories
    and the answering rules. It then quotes ``sentence`` and ends with
    "Response: " (including the trailing space). Sections are separated by a
    blank line.
    """
    intro = (
        "\nYou are a news topic classification assistant.\n"
        "Your task is to classify the given news article into exactly one of the following four categories:"
    )
    categories = [
        "1. World – International news, geopolitics, diplomacy, conflicts, global events, or foreign affairs",
        "2. Sports – Sports events, games, teams, athletes, scores, tournaments, or athletic competitions",
        "3. Business – Business, finance, economics, markets, companies, trade, or corporate news",
        "4. Technology – Science, technology, research, engineering, computing, internet, or scientific discoveries",
    ]
    rules = [
        "1. Answer exactly one category: World, Sports, Business or Technology.",
        "2. Base your decision on the main topic of the article.",
        "3. If multiple topics appear, select the dominant one.",
        "4. Do not create new categories.",
        "5. Do not attempt real-world fact verification; classify only based on textual characteristics and internal consistency.",
    ]
    sections = [intro, *categories, "Rules:", *rules, f'Text to classify: "{sentence}"', "Response: "]
    return "\n\n".join(sections)

def classify_k_way(samples, prompt=False, prompt_func=None, verbose=False):
    """Classify a batch of texts with one constrained greedy decoding step.

    Generation uses the module-level ``processor``, which only lets the token
    ids in ``label_ids`` through, so the single generated token is a label.

    Args:
        samples: batch mapping with 'text' and 'label' sequences (e.g. a slice
            of the AG News dataset).
        prompt: if True, wrap every text with ``prompt_func`` before tokenizing.
        prompt_func: callable str -> str; required when ``prompt`` is True.
        verbose: print raw generate() outputs and the masked probabilities.

    Returns:
        list[dict], one per input, with keys:
          text, the (possibly prompted) text;
          text_length, the unpadded token count;
          output_length, the number of generated tokens;
          pred, the decoded label token;
          pred_idx, the index into ``label_tokens``;
          probs, the probabilities of the label tokens in ``label_ids`` order;
          ground_truth_label.
    """
    texts = samples['text']
    model.eval()
    class_map = {cls: idx for idx, cls in enumerate(label_tokens)}
    # Generated token id -> class index. Looking up by id does not depend on how
    # the tokenizer decodes or whitespaces a single token (a decode mismatch used
    # to raise KeyError in class_map[...]).
    id_to_idx = {tok_id: idx for idx, tok_id in enumerate(label_ids)}
    if prompt:
        assert prompt_func is not None
        texts = list(map(prompt_func, texts))
    # The model is loaded with device_map="auto", so send inputs to model.device.
    inputs = tokenizer(texts, return_tensors="pt", padding=True).to(model.device)

    # Real (unpadded) prompt length of every row.
    L_in_list = inputs["attention_mask"].sum(dim=1).detach().cpu().tolist()
    # All rows share this padded width, so generated tokens start at this column.
    # NOTE(review): decoder-only batched generation expects left padding —
    # confirm tokenizer.padding_side for the model in use.
    pad_length = inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        generated = model.generate(
            **inputs,
            max_new_tokens=1,        # only the label token is needed
            logits_processor=[processor],
            do_sample=False,         # greedy
            return_dict_in_generate=True,
            output_scores=True,
            output_logits=True,
        )

    new_tokens = generated.sequences[:, pad_length:]
    L_out = max(generated.sequences.shape[-1] - pad_length, 0)
    assert L_out == 1, f"expected exactly one generated token, got {L_out}"
    L_out_list = [L_out] * len(L_in_list)

    if verbose:
        print(generated.sequences)
        print(generated.scores)
        print(generated.logits)

    preds = [tokenizer.decode(g[0], skip_special_tokens=True).strip() for g in new_tokens]
    pred_token_ids = new_tokens[:, 0].tolist()

    # scores[0] is taken after the logits processor: non-label entries are -inf,
    # so the softmax puts all mass on the label tokens.
    last_scores = generated.scores[0]  # tensor of shape (batch_size, vocab_size)
    if verbose:
        print(last_scores, last_scores.shape)
    probs = torch.softmax(last_scores, dim=-1)
    if verbose:
        print(probs)
        print(f"Number of zeros: {torch.sum(probs == 0).item()}, Number of non-zeros: {torch.sum(probs != 0).item()}")
        print(f'Non-zero indices: {torch.nonzero(probs)}')  # must be 4
    del inputs

    results = []
    for i, txt in enumerate(texts):
        results.append({
            "text": txt,
            "text_length": L_in_list[i],
            "output_length": L_out_list[i],
            "pred": preds[i],
            "pred_idx": id_to_idx.get(pred_token_ids[i], class_map.get(preds[i], -1)),
            "probs": probs[i, label_ids].tolist(),
            "ground_truth_label": samples['label'][i],
        })
    return results

In [None]:
import torch
import torch.profiler as prof

def _profiler_total_flops(profile):
    """Return the sum of the FLOP estimates attached to a finished profiler's events.

    ``with_flops=True`` only estimates FLOPs for specific operators (the
    torch.profiler docs name matrix multiplication and 2D convolution). Ops
    without a formula contribute nothing; fused attention kernels are possibly
    among them, depending on the torch version. Treat the total as a lower bound.
    """
    return sum(e.flops for e in profile.events() if getattr(e, "flops", None))


def classify_k_way_pytorch_prof(samples, prompt=False, prompt_func=None, verbose=False):
    """Same constrained 4-way classification as ``classify_k_way``, plus FLOP profiling.

    ``model.generate()`` runs under ``torch.profiler`` with ``with_flops=True``.
    The batch's total estimated FLOPs are split across queries in proportion to
    each query's real (unpadded) input + output token count.

    Args:
        samples: batch mapping with 'text' and 'label' sequences.
        prompt: if True, wrap every text with ``prompt_func`` before tokenizing.
        prompt_func: callable str -> str; required when ``prompt`` is True.
        verbose: print FLOP totals, lengths and the masked scores.

    Returns:
        list[dict], one per input. Each dict has the same keys as
        ``classify_k_way`` plus cost_tflops, the profiler-estimated TFLOPs
        attributed to that query. pred_idx is -1 if the prediction cannot be
        mapped to a label.
    """
    texts = samples['text']
    model.eval()
    class_map = {cls: idx for idx, cls in enumerate(label_tokens)}
    # Generated token id -> class index (independent of decode whitespace).
    id_to_idx = {tok_id: idx for idx, tok_id in enumerate(label_ids)}

    if prompt:
        assert prompt_func is not None
        texts = list(map(prompt_func, texts))

    # Tokenize with padding; device_map="auto" means inputs go to model.device.
    inputs = tokenizer(texts, return_tensors="pt", padding=True).to(model.device)

    # True input lengths per query
    L_in_list = inputs["attention_mask"].sum(dim=1).tolist()
    # Shared padded width; generated tokens start at this column.
    pad_length = inputs["input_ids"].shape[-1]

    # Request CUDA tracing only when a GPU exists; on CPU-only machines that
    # activity is unnecessary.
    activities = [prof.ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(prof.ProfilerActivity.CUDA)

    # --- Profile model.generate() to get FLOPs ---
    with prof.profile(activities=activities, with_flops=True, record_shapes=False) as p:
        with torch.inference_mode():
            generated = model.generate(
                **inputs,
                max_new_tokens=1,
                logits_processor=[processor],
                do_sample=False,
                return_dict_in_generate=True,
                output_scores=True,
                output_logits=True,
            )

    total_flops = _profiler_total_flops(p)
    total_tflops = total_flops / 1e12  # convert to TFLOPs

    # Output length is identical for every row (fixed max_new_tokens, shared width).
    L_out = max(generated.sequences.shape[-1] - pad_length, 0)
    assert L_out == 1, f"expected exactly one generated token, got {L_out}"
    L_out_list = [L_out] * len(L_in_list)

    # Attribute the batch total in proportion to each row's real token count.
    # Padded positions are computed too, so this is an attribution, not a
    # per-row measurement.
    total_tokens = sum(L_in + L_out for L_in, L_out in zip(L_in_list, L_out_list))
    cost_tflops_list = [
        total_tflops * (L_in + L_out) / total_tokens
        for L_in, L_out in zip(L_in_list, L_out_list)
    ]

    if verbose:
        print(f"Total FLOPs: {total_flops/1e12:.3f} TFLOPs")
        print(f"L_in_list: {L_in_list}")
        print(f"L_out_list: {L_out_list}")
        print(f"Per-query FLOPs (TFLOPs): {cost_tflops_list}")

    # Decode predictions
    new_tokens = generated.sequences[:, pad_length:]
    preds = [tokenizer.decode(g[0], skip_special_tokens=True).strip() for g in new_tokens]
    pred_token_ids = new_tokens[:, 0].tolist()

    # Post-processor scores: non-label entries are -inf, so softmax mass sits on the labels.
    last_scores = generated.scores[0]  # (batch_size, vocab_size)
    if verbose:
        print(last_scores, last_scores.shape)
    probs = torch.softmax(last_scores, dim=-1)

    # Build results with per-query FLOP cost
    results = []
    for i, txt in enumerate(texts):
        results.append({
            "text": txt,
            "text_length": L_in_list[i],
            "output_length": L_out_list[i],
            "pred": preds[i],
            "pred_idx": id_to_idx.get(pred_token_ids[i], class_map.get(preds[i], -1)),
            "probs": probs[i, label_ids].tolist(),
            "ground_truth_label": samples['label'][i],
            "cost_tflops": cost_tflops_list[i],
        })

    return results

In [None]:
# ds['train']['sentence']

In [None]:
def batch_list(data, batch_size):
    """Lazily yield consecutive slices ``data[start:start + batch_size]``.

    Works for anything that supports ``len()`` and slicing, e.g. lists or the
    HF Dataset used in this notebook. The final slice may be shorter than
    ``batch_size``.
    """
    yield from (data[start:start + batch_size]
                for start in range(0, len(data), batch_size))

In [None]:
from sklearn.metrics import accuracy_score
from tqdm import tqdm
# samples = [
#     "I absolutely loved the new album—it's fantastic!",
#     "The product broke after two days; I'm really disappointed."
# ]
# for r in tqdm(classify_binary(samples, prompt=True, verbose=False)):
#     print(r)  # dict item of results list
#     # print(f"> {r['text']!r}")
#     # print(f"  →  {r['pred']}, ({r['pred_idx']})")
#     predicted_labels.append(r['pred_idx'])
#     # print(f"     P(negative) = {r['prob_negative']:.3f},  P(positive) = {r['prob_positive']:.3f}\n")

n_samples = 8000  # size of the train subset to evaluate
samples = ds['train'].select(range(n_samples))
print(len(samples))
predicted_labels = []

batch_size = 16
batches = list(batch_list(samples, batch_size))

print(batches[0])

results = []
for batch in tqdm(batches):
    # classify_k_way(...) runs the same pipeline without FLOP profiling.
    r = classify_k_way_pytorch_prof(batch, prompt=True, prompt_func=prompt2, verbose=False)
    # Both classifiers return a list of per-sample dicts for every batch size,
    # so always extend. Appending when batch_size == 1 would nest a list inside
    # `results`.
    results.extend(r)
    predicted_labels.extend(ri['pred_idx'] for ri in r)

print(f"\nAccuracy: {accuracy_score(samples['label'], predicted_labels)}")

In [None]:
# Preview only: each result dict holds the full prompt text, so displaying the
# whole list floods the notebook output.
results[:5]

In [None]:
import pickle
# Persist the per-sample result dicts for offline analysis.
results_path = 'agnews_train_ministral3-3b-it_profiler.pkl'
with open(results_path, 'wb') as out_file:
    pickle.dump(results, out_file)