<a href="https://colab.research.google.com/github/afzal/LLM-NetThroughput/blob/main/LLM_NetThroughput.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install -U "transformers>=4.40" sentencepiece




In [2]:
from transformers import T5ForConditionalGeneration, T5TokenizerFast
import torch
print("OK. Torch version:", torch.__version__)


OK. Torch version: 2.8.0+cpu


In [3]:
!pip install -U datasets pandas torch accelerate



In [2]:

%%writefile llm_bw_predictor.py
#!/usr/bin/env python3
"""
LLM bandwidth predictor (t5-small fine-tune)

- Train on CSV with: loss, delay_ms, bandwidth_mbps
- Predict bandwidth (Mbps) for new loss & delay pairs
- Treats the task as text-to-text for a small LLM (T5-small)

Install:
  pip install -U "transformers>=4.41" datasets pandas torch accelerate sentencepiece

Examples:
  # make synthetic data
  python llm_bw_predictor.py --make_synth synth.csv

  # train (85/15 split if no --valid_csv)
  python llm_bw_predictor.py --train_csv synth.csv --output_dir ./bw_model --epochs 5 --batch_size 16

  # predict
  python llm_bw_predictor.py --predict --model_dir ./bw_model --loss 0.3 --delay_ms 45
"""
import argparse
import os
from pathlib import Path
from typing import Optional

import numpy as np
import pandas as pd

from transformers import (
    T5ForConditionalGeneration,
    T5TokenizerFast,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq,
)
from datasets import Dataset, DatasetDict


# ----------------------------
# Helpers
# ----------------------------
def normalize_columns(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` reduced to the canonical three columns.

    Header matching is case-insensitive and ignores surrounding whitespace.
    For each canonical column the first header (in column order) that matches
    one of its accepted aliases is used.

    Args:
        df: Frame read from a user-supplied CSV.

    Returns:
        A new frame with exactly ``loss``, ``delay_ms`` and ``bandwidth_mbps``.

    Raises:
        ValueError: If any of the three columns cannot be located.
    """
    accepted_aliases = {
        "loss": ["loss", "loss_pct", "packet_loss", "loss_percent"],
        "delay_ms": ["delay_ms", "rtt_ms", "latency_ms", "delay"],
        "bandwidth_mbps": ["bandwidth_mbps", "throughput_mbps", "bw_mbps", "bandwidth"],
    }
    # normalized header -> header exactly as it appears in the frame
    header_lookup = {header.lower().strip(): header for header in df.columns}

    rename_map = {}
    for canonical, aliases in accepted_aliases.items():
        for normalized, original in header_lookup.items():
            if normalized in aliases:
                rename_map[original] = canonical
                break
        else:
            # No header matched any alias of this canonical column.
            raise ValueError("CSV must include columns for loss, delay_ms, bandwidth_mbps (case-insensitive).")

    renamed = df.rename(columns=rename_map)
    return renamed[["loss", "delay_ms", "bandwidth_mbps"]].copy()


def format_loss_for_prompt(loss_value: float, is_fraction: bool = False) -> str:
    """Render a packet-loss value for inclusion in the model prompt.

    The loss column is treated as a percentage, matching the synthetic data
    this script generates. The previous "values <= 1 are fractions" heuristic
    multiplied small percentages by 100, so 0.5 became "50.00%" while 1.5
    stayed "1.50%", making the prompt non-monotonic in the true loss.

    Args:
        loss_value: Packet loss, in percent unless ``is_fraction`` is set.
        is_fraction: Set to True when ``loss_value`` is a 0..1 fraction; it is
            then scaled to percent before formatting.

    Returns:
        ``"<value>%"`` with two decimals for values within 0..100; otherwise
        the raw number with four decimals and no percent sign.
    """
    if is_fraction:
        loss_value = loss_value * 100
    if 0 <= loss_value <= 100:
        return f"{loss_value:.2f}%"
    # Out-of-range input is passed through unchanged rather than clamped.
    return f"{loss_value:.4f}"


def make_prompt(loss: float, delay_ms: float) -> str:
    """Build the text-to-text input prompt for one (loss, delay) pair.

    The loss is rendered by ``format_loss_for_prompt`` and the delay with two
    decimals.
    """
    loss_text = format_loss_for_prompt(loss)
    delay_text = f"{delay_ms:.2f}"
    return f"predict bandwidth (Mbps) given: loss={loss_text}, delay={delay_text} ms"


def make_target(bw_mbps: float) -> str:
    """Build the target text for a bandwidth value: two decimals plus ' Mbps'."""
    rounded = format(bw_mbps, ".2f")
    return rounded + " Mbps"


def parse_bw_from_text(text: str) -> Optional[float]:
    """Extract the first number found in ``text``.

    Accepts an optional leading minus sign and an optional decimal part.

    Returns:
        The number as a float, or None when ``text`` contains no digits.
    """
    import re

    first_number = re.search(r"(-?\d+(\.\d+)?)", text)
    return None if first_number is None else float(first_number.group(1))


# ----------------------------
# Dataset preparation
# ----------------------------
def build_hf_dataset(
    train_csv: str,
    valid_csv: Optional[str],
    valid_fraction: float = 0.15,
    seed: int = 13,
) -> DatasetDict:
    """Load the CSV data and convert it to prompt/target text pairs.

    When ``valid_csv`` is given and exists it is used as the validation set.
    Otherwise ``valid_fraction`` of the training rows is held out after a
    seeded shuffle. Shuffling matters because CSVs are often ordered (by time,
    by test condition), and taking the unshuffled tail would validate on a
    different distribution than the one trained on.

    Args:
        train_csv: Path to the training CSV.
        valid_csv: Optional path to a validation CSV.
        valid_fraction: Share of training rows held out when no validation
            CSV is used. Must be strictly between 0 and 1.
        seed: Random seed for the hold-out shuffle, so splits are repeatable.

    Returns:
        A ``DatasetDict`` with ``train`` and ``validation`` splits, each with
        ``input_text`` and ``target_text`` columns.

    Raises:
        ValueError: If ``valid_fraction`` is out of range or a split is empty.
    """
    train_df = normalize_columns(pd.read_csv(train_csv))
    if valid_csv and Path(valid_csv).exists():
        valid_df = normalize_columns(pd.read_csv(valid_csv))
    else:
        if not 0.0 < valid_fraction < 1.0:
            raise ValueError(f"valid_fraction must be between 0 and 1 (exclusive), got {valid_fraction}")
        if valid_csv:
            # Keep the best-effort fallback, but do not hide it from the user.
            print(
                f"Validation CSV not found: {valid_csv}; "
                f"holding out {valid_fraction:.0%} of the training data instead."
            )
        shuffled = train_df.sample(frac=1.0, random_state=seed).reset_index(drop=True)
        split_idx = int((1.0 - valid_fraction) * len(shuffled))
        train_df = shuffled.iloc[:split_idx].reset_index(drop=True)
        valid_df = shuffled.iloc[split_idx:].reset_index(drop=True)

    if train_df.empty or valid_df.empty:
        raise ValueError(
            f"Need at least one row in each split (train={len(train_df)}, validation={len(valid_df)})."
        )

    def to_records(df: pd.DataFrame):
        # One prompt/target pair per CSV row.
        return [
            {"input_text": make_prompt(r.loss, r.delay_ms), "target_text": make_target(r.bandwidth_mbps)}
            for r in df.itertuples(index=False)
        ]

    return DatasetDict(
        {
            "train": Dataset.from_list(to_records(train_df)),
            "validation": Dataset.from_list(to_records(valid_df)),
        }
    )


# ----------------------------
# Training
# ----------------------------
def train_model(
    train_csv: str,
    valid_csv: Optional[str],
    output_dir: str,
    model_name: str = "t5-small",
    epochs: int = 5,
    batch_size: int = 16,
    lr: float = 3e-4,
    weight_decay: float = 0.01,
    warmup_ratio: float = 0.03,
):
    """Fine-tune a T5 checkpoint to map (loss, delay) prompts to bandwidth text.

    The model is evaluated and checkpointed once per epoch. The checkpoint
    with the lowest ``eval_loss`` is reloaded at the end and saved, together
    with the tokenizer, to ``output_dir``.

    Args:
        train_csv: Path to the training CSV.
        valid_csv: Optional validation CSV; see ``build_hf_dataset``.
        output_dir: Directory for checkpoints and the final model.
        model_name: Hugging Face model id or local path to start from.
        epochs: Number of training epochs.
        batch_size: Per-device batch size for training and evaluation.
        lr: Peak learning rate.
        weight_decay: Weight decay passed to the trainer.
        warmup_ratio: Fraction of steps used for learning-rate warmup.

    Environment:
        USE_FP16: "1" (the default) requests fp16 mixed precision. It is only
            honoured when a CUDA device is available.
    """
    import inspect

    import torch

    ds = build_hf_dataset(train_csv, valid_csv)
    tokenizer = T5TokenizerFast.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(model_name)

    def tokenize_examples(batch):
        # Padding is left to the data collator so each batch pads to its own max.
        model_inputs = tokenizer(batch["input_text"], padding=False, truncation=True)
        labels = tokenizer(batch["target_text"], padding=False, truncation=True)
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    tokenized = ds.map(tokenize_examples, batched=True, remove_columns=ds["train"].column_names)

    # fp16 needs a GPU: requesting it on a CPU-only machine makes
    # TrainingArguments raise, so gate the env default on CUDA availability.
    use_fp16 = os.environ.get("USE_FP16", "1") == "1" and torch.cuda.is_available()

    args = Seq2SeqTrainingArguments(
        output_dir=output_dir,
        learning_rate=lr,
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        weight_decay=weight_decay,
        warmup_ratio=warmup_ratio,
        eval_strategy="epoch",
        save_strategy="epoch",
        logging_steps=50,
        predict_with_generate=True,
        fp16=use_fp16,
        report_to=[],
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
    )

    data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

    # Numeric metrics computed after each epoch from generated validation text.
    def compute_eval_metrics(eval_preds):
        preds, labels = eval_preds
        if isinstance(preds, tuple):
            preds = preds[0]
        # -100 is a padding/ignore marker, not a vocabulary id; it can appear in
        # generated predictions as well as labels and cannot be decoded.
        preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        pred_vals = [parse_bw_from_text(t) for t in decoded_preds]
        label_vals = [parse_bw_from_text(t) for t in decoded_labels]

        # Keep only pairs where both sides parsed; report how many survived so
        # a model that emits unparseable text is not silently scored as good.
        pairs = [(p, l) for p, l in zip(pred_vals, label_vals) if p is not None and l is not None]
        parse_rate = len(pairs) / len(pred_vals) if pred_vals else 0.0
        if not pairs:
            return {"rmse": float("nan"), "mae": float("nan"), "parse_rate": parse_rate}

        errors = [p - l for p, l in pairs]
        mae = float(np.mean(np.abs(errors)))
        rmse = float(np.sqrt(np.mean(np.square(errors))))
        return {"rmse": rmse, "mae": mae, "parse_rate": parse_rate}

    # Newer transformers releases take the tokenizer as `processing_class` and
    # deprecate the `tokenizer` keyword; use whichever this version accepts.
    trainer_params = inspect.signature(Seq2SeqTrainer.__init__).parameters
    tokenizer_kwarg = "processing_class" if "processing_class" in trainer_params else "tokenizer"

    trainer = Seq2SeqTrainer(
        model=model,
        args=args,
        train_dataset=tokenized["train"],
        eval_dataset=tokenized["validation"],
        data_collator=data_collator,
        compute_metrics=compute_eval_metrics,
        **{tokenizer_kwarg: tokenizer},
    )

    trainer.train()
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f"Model saved to: {output_dir}")


# ----------------------------
# Inference
# ----------------------------
def load_model(model_dir: str):
    """Load the tokenizer and T5 model stored in ``model_dir``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    return (
        T5TokenizerFast.from_pretrained(model_dir),
        T5ForConditionalGeneration.from_pretrained(model_dir),
    )


def predict(model_dir: str, loss: float, delay_ms: float, num_return_sequences: int = 1):
    """Generate bandwidth predictions for one (loss, delay) pair.

    Args:
        model_dir: Directory containing the fine-tuned model and tokenizer.
        loss: Packet-loss value, formatted the same way as during training.
        delay_ms: Delay in milliseconds.
        num_return_sequences: Number of candidate outputs to return.

    Returns:
        A list of ``(generated_text, parsed_bandwidth_or_None)`` tuples, one
        per returned sequence.

    Raises:
        ValueError: If ``num_return_sequences`` is less than 1.
    """
    import torch

    if num_return_sequences < 1:
        raise ValueError("num_return_sequences must be >= 1")

    tokenizer, model = load_model(model_dir)
    prompt = make_prompt(loss, delay_ms)
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=12,
            # Deterministic decoding can only return several sequences from a
            # beam search with at least that many beams. One beam is plain
            # greedy decoding, so the default call behaves as before.
            num_beams=num_return_sequences,
            num_return_sequences=num_return_sequences,
            do_sample=False,
        )
    texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    bws = [parse_bw_from_text(t) for t in texts]
    return list(zip(texts, bws))


# ----------------------------
# Synthetic data (optional)
# ----------------------------
def make_synth(csv_path: str, n: int = 2000, seed: int = 13):
    """Write a synthetic (loss, delay_ms, bandwidth_mbps) CSV to ``csv_path``.

    Bandwidth follows a toy model: a random base rate that decays
    exponentially with delay, shrinks linearly with loss, has Gaussian noise
    added and is floored at 0.5 Mbps.

    Args:
        csv_path: Destination file.
        n: Number of rows to generate.
        seed: Seed for the random generator, so output is repeatable.
    """
    generator = np.random.default_rng(seed)

    # The draw order (delay, loss, base rate, noise) fixes the generated values.
    delay_ms = generator.uniform(10, 200, size=n)   # 10..200 ms
    loss_percent = generator.uniform(0, 3, size=n)  # 0..3 %
    base_rate = generator.uniform(20, 80, size=n)
    noise = generator.normal(0, 1.5, size=n)

    delay_factor = np.exp(-0.01 * (delay_ms - 10))
    loss_factor = 1 - 0.015 * loss_percent
    bandwidth = np.maximum(base_rate * delay_factor * loss_factor + noise, 0.5)

    columns = {
        "loss": loss_percent,  # in percent
        "delay_ms": delay_ms,
        "bandwidth_mbps": bandwidth,
    }
    pd.DataFrame(columns).to_csv(csv_path, index=False)
    print(f"Synthetic dataset written to {csv_path} (rows={n})")


# ----------------------------
# Main
# ----------------------------
def main():
    """Command-line entry point.

    Runs exactly one of three modes, checked in this order:

    1. ``--make_synth PATH``: write a synthetic CSV and exit.
    2. ``--predict``: load ``--model_dir`` and print a bandwidth estimate for
       ``--loss`` / ``--delay_ms``.
    3. Otherwise: train on ``--train_csv`` and save to ``--output_dir``.

    Raises:
        SystemExit: If the arguments required by the selected mode are missing.
    """
    parser = argparse.ArgumentParser(description="LLM bandwidth predictor (t5-small fine-tune)")
    parser.add_argument("--train_csv", type=str, help="Training CSV with loss, delay_ms, bandwidth_mbps")
    parser.add_argument("--valid_csv", type=str, default=None, help="Validation CSV (optional)")
    parser.add_argument("--output_dir", type=str, default="./bw_model", help="Model output directory")
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--predict", action="store_true", help="Run inference instead of training")
    parser.add_argument("--model_dir", type=str, default="./bw_model", help="Model dir for --predict")
    # argparse %-formats help strings, so a literal percent sign in the text
    # breaks --help. The wording avoids one and describes the value by
    # reference to the training data instead.
    parser.add_argument(
        "--loss",
        type=float,
        default=None,
        help="Loss value for inference, in the same units as the training CSV loss column",
    )
    parser.add_argument("--delay_ms", type=float, default=None, help="Delay (ms) for inference")
    parser.add_argument("--make_synth", type=str, default=None, help="Write a synthetic CSV to this path and exit")
    args = parser.parse_args()

    if args.make_synth:
        make_synth(args.make_synth)
        return

    if args.predict:
        if args.loss is None or args.delay_ms is None:
            raise SystemExit("--predict requires --loss and --delay_ms")
        results = predict(args.model_dir, args.loss, args.delay_ms, num_return_sequences=1)
        txt, bw = results[0]
        print(f"Model output: {txt}")
        if bw is not None:
            print(f"Parsed bandwidth estimate: {bw:.2f} Mbps")
        else:
            print("Could not parse a numeric bandwidth from the model output.")
        return

    if not args.train_csv:
        raise SystemExit("--train_csv is required for training")

    train_model(
        train_csv=args.train_csv,
        valid_csv=args.valid_csv,
        output_dir=args.output_dir,
        epochs=args.epochs,
        batch_size=args.batch_size,
        lr=args.lr,
    )


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()


Writing llm_bw_predictor.py


In [3]:
# Generate the synthetic training data (synth.csv) with the script written above.
!python llm_bw_predictor.py --make_synth synth.csv

2025-08-16 05:06:42.587964: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1755320802.619834    1676 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1755320802.629583    1676 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1755320802.653907    1676 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1755320802.653971    1676 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1755320802.653976    1676 computation_placer.cc:177] computation placer alr

In [9]:
# Peek at the header and first rows of the generated CSV.
!head -n 5 synth.csv


loss,delay_ms,bandwidth_mbps
1.1815093920577224,174.31154153315143,4.143646445682554
2.950307935616407,172.50747783709122,15.908660585082956
2.8268444008844287,164.09444576902501,8.565892021196026
2.9762169337032014,59.674808669130556,10.472085475811841


In [4]:
# Fine-tune on synth.csv for 3 epochs (the script default is 5) and save to ./bw_model.
!python llm_bw_predictor.py --train_csv synth.csv --output_dir ./bw_model --epochs 3 --batch_size 16

2025-08-16 05:07:04.522008: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1755320824.554381    1813 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1755320824.563784    1813 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1755320824.594088    1813 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1755320824.594167    1813 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1755320824.594178    1813 computation_placer.cc:177] computation placer alr

In [5]:
# Query the saved model in ./bw_model for loss=0.3 and delay=45 ms.
!python llm_bw_predictor.py --predict --model_dir ./bw_model --loss 0.3 --delay_ms 45


2025-08-16 05:24:00.263947: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1755321840.291004    5864 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1755321840.298877    5864 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1755321840.322528    5864 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1755321840.322590    5864 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1755321840.322595    5864 computation_placer.cc:177] computation placer alr

from transformers import T5ForConditionalGeneration, T5TokenizerFast
import torch
print("OK. Torch version:", torch.__version__)
