# Lightning Ray

In this notebook, we perform a basic transformer classification task. 
The main purpose is exploration of PyTorch Lightning and Ray


Lets start with a simple smoke test. We will perform an inference baseline on this machine with nothing added on

In [1]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch, time

MODEL_ID = "sshleifer/tiny-distilroberta-base" # Super small, OK on CPU
ds = load_dataset("glue", "sst2", split = "train[:200]") #small slice

tok = AutoTokenizer.from_pretrained(MODEL_ID)
batch = tok(list(ds["sentence"][:8]),
            padding = True,
            truncation = True, 
            max_length=128,
            return_tensors ="pt")

print("Tokenized shapes:", {k: tuple(v.shape) for k, v in batch.items()})

model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, num_labels =2)
model.eval()

with torch.inference_mode():
    _ = model(**batch) # warmup
    iters = 50
    t0 = time.time()
    for _ in range(iters):
        _ = model(**batch)
        dt = time.time() - t0
        bs = batch["input_ids"].shape[0]
        print(f"Average inference per batch {dt/iters*1000:.2f} ms (batch_size) {bs}")
        
print("Smoke Test Complete")

  from .autonotebook import tqdm as notebook_tqdm


Tokenized shapes: {'input_ids': (8, 33), 'attention_mask': (8, 33)}


Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at sshleifer/tiny-distilroberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Average inference per batch 0.02 ms (batch_size) 8
Average inference per batch 0.06 ms (batch_size) 8
Average inference per batch 0.10 ms (batch_size) 8
Average inference per batch 0.12 ms (batch_size) 8
Average inference per batch 0.16 ms (batch_size) 8
Average inference per batch 0.18 ms (batch_size) 8
Average inference per batch 0.20 ms (batch_size) 8
Average inference per batch 0.22 ms (batch_size) 8
Average inference per batch 0.24 ms (batch_size) 8
Average inference per batch 0.31 ms (batch_size) 8
Average inference per batch 0.33 ms (batch_size) 8
Average inference per batch 0.37 ms (batch_size) 8
Average inference per batch 0.39 ms (batch_size) 8
Average inference per batch 0.43 ms (batch_size) 8
Average inference per batch 0.49 ms (batch_size) 8
Average inference per batch 0.53 ms (batch_size) 8
Average inference per batch 0.57 ms (batch_size) 8
Average inference per batch 0.61 ms (batch_size) 8
Average inference per batch 0.63 ms (batch_size) 8
Average inference per batch 0.6

Lets introduce some lightning elements

In [3]:
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorWithPadding
from torch.utils.data import DataLoader
import pytorch_lightning as pl

MODEL_ID = "sshleifer/tiny-distilroberta-base"
MAX_LEN = 128


class SST2DataModule(pl.LightningDataModule):
    def __init__(self, model_id = MODEL_ID, batch_size =32,num_workers=0, pin_memory=False, persistent_workers=False):
        super().__init__()
        self.model_id = model_id
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.persistent_workers = persistent_workers
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        self.collate = DataCollatorWithPadding(self.tokenizer)
        
    def prepare_data(self):
        # download/cache only
        load_dataset("glue", "sst2")
        AutoTokenizer.from_pretrained(self.model_id)
        
    def setup(self, stage= None):
        ds_train = load_dataset("glue", "sst2", split="train[:1000]")
        ds_val = load_dataset("glue", "sst2", split="validation[:200]")
        
        def tok_fn(examples):
            t = self.tokenizer(
                examples["sentence"],
                truncation=True,
                max_length=MAX_LEN,
            )
            t["labels"] = examples["label"]  # copy labels -> 'labels'
            return t
        
        # batched tokenization
        ds_train = ds_train.map(tok_fn, batched=True,  remove_columns=ds_train.column_names)
        ds_val = ds_val.map(tok_fn, batched=True,  remove_columns=ds_val.column_names)    
        
        self.ds_train, self.ds_val = ds_train, ds_val
        
    def train_dataloader(self):
        return DataLoader(
            self.ds_train,
            batch_size=self.batch_size,
            shuffle=True,
            collate_fn=self.collate,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            persistent_workers=self.persistent_workers if self.num_workers > 0 else False,
        )
        
    def val_dataloader(self):
        return DataLoader(
            self.ds_val,
            batch_size=self.batch_size,
            shuffle=False,
            collate_fn=self.collate,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            persistent_workers=self.persistent_workers if self.num_workers > 0 else False,
        )
        
print("DataModule Defined")
        
        


DataModule Defined


In [4]:
#Now the lighnting module. This wraps the HF model
import torch
import torch.nn as nn
import pytorch_lightning as pl
from transformers import AutoModelForSequenceClassification

class LitTinyClassifier(pl.LightningModule):
    def __init__(self, model_id=MODEL_ID, lr=5e-5):
        super().__init__()
        self.save_hyperparameters()
        self.model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=2)
        self.lr = lr
        
    def forward(self, **batch):
        return self.model(**batch)
    
    
    def training_step(self, batch, batch_idx):
        out = self(**batch)
        loss = out.loss
        
        #Quick accuracy sanity check
        preds = out.logits.argmax(dim=-1)
        acc = (preds == batch["labels"]).float().mean()
        self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("train_acc", acc, on_step=False, on_epoch=True, prog_bar=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        out = self(**batch)
        loss = out.loss
        preds = out.logits.argmax(dim=-1)
        acc = (preds == batch["labels"]).float().mean()
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val_acc", acc, on_step=False, on_epoch=True, prog_bar=True)
        
    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.lr)
    
print("Lightning Module Ready")
        

        
        
        



Lightning Module Ready


In [5]:
#Lets do just a single epoch of training

import pytorch_lightning as pl
import torch

pl.seed_everything(42, workers=True)

dm = SST2DataModule(
    model_id=MODEL_ID,
    batch_size=32,
    num_workers=2,          # start at 2 on Windows
    pin_memory=True,        # good for CUDA async H2D copies
    persistent_workers=True # avoid respawn cost each epoch
)
dm.prepare_data()
dm.setup()

model = LitTinyClassifier(model_id=MODEL_ID, lr = 5e-5)

precision = "bf16-mixed" if hasattr(torch.cuda, "is_available") and torch.cuda.is_available() else "32-true"

trainer = pl.Trainer(
    max_epochs=1,
    accelerator="gpu",
    devices=1,
    precision=precision,
    log_every_n_steps=10,
)

trainer.fit(model, datamodule=dm)

Seed set to 42
Map: 100%|██████████| 1000/1000 [00:00<00:00, 17480.13 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 10243.00 examples/s]
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at sshleifer/tiny-distilroberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Using bfloat16 Automatic Mixed Precision (AMP)
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3050 6GB Laptop GPU') that has Tensor Cores. To proper

                                                                           



Epoch 0: 100%|██████████| 32/32 [00:01<00:00, 18.18it/s, v_num=2, val_loss=0.693, val_acc=0.495, train_loss=0.693, train_acc=0.542]

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 32/32 [00:01<00:00, 18.00it/s, v_num=2, val_loss=0.693, val_acc=0.495, train_loss=0.693, train_acc=0.542]


In [6]:
# Inference latency
import torch, numpy as np
import time

def _to_device(batch, device):
    return {k: v.to(device) for k, v in batch.items()}

def _should_sync(device: str) -> bool:
    return device.startswith("cuda") and torch.cuda.is_available()

def measure_latency(hf_model, batch, repeats = 200, warmup = 20, device ="cpu"):
    hf_model.eval().to(device)
    batch = _to_device(batch, device)
    
    # warmup (not timed)
    with torch.inference_mode():
        for _ in range(warmup):
            _ = hf_model(**batch)
    if _should_sync(device):
        torch.cuda.synchronize()
        
    # timed loop
    samples_ms = []
    with torch.inference_mode():
        for _ in range(repeats):
            t0 = time.perf_counter()
            _ = hf_model(**batch)
            if _should_sync(device):
                torch.cuda.synchronize()
            samples_ms.append((time.perf_counter() - t0) * 1000.0)

    samples_ms = np.asarray(samples_ms, dtype=float)
    return {
        "batch_size": int(batch["input_ids"].shape[0]),
        "mean_ms": float(samples_ms.mean()),
        "p50_ms": float(np.percentile(samples_ms, 50)),
        "p95_ms": float(np.percentile(samples_ms, 95)),
        "p99_ms": float(np.percentile(samples_ms, 99)),
        "repeats": int(repeats),
    }
    
# grab a validation batch
val_loader = dm.val_dataloader()
batch_val = next(iter(val_loader))

# bs=8
fp32_bs8 = measure_latency(model.model, batch_val, repeats=150, warmup=30, device="cpu")

# bs=1 (edge/onboard-ish)
single = {k: v[:1].clone() for k, v in batch_val.items()}
fp32_bs1 = measure_latency(model.model, single, repeats=300, warmup=50, device="cpu")

print("FP32 baseline (bs=8):", fp32_bs8)
print("FP32 baseline (bs=1):", fp32_bs1)

FP32 baseline (bs=8): {'batch_size': 32, 'mean_ms': 2.0493573353936276, 'p50_ms': 1.9735000096261501, 'p95_ms': 2.537545037921517, 'p99_ms': 2.7957989822607487, 'repeats': 150}
FP32 baseline (bs=1): {'batch_size': 1, 'mean_ms': 1.238872332808872, 'p50_ms': 1.1840499937534332, 'p95_ms': 1.54692500946112, 'p99_ms': 2.0025459968019272, 'repeats': 300}


In [7]:
import torch

qmodel = torch.quantization.quantize_dynamic(
    model.model, 
    {torch.nn.Linear},
    dtype=torch.qint8
)


q_bs8 = measure_latency(qmodel, batch_val, repeats=150, warmup=30, device="cpu")
q_bs1 = measure_latency(qmodel, single,   repeats=300, warmup=50, device="cpu")


print("INT8 quant (bs=8):", q_bs8)
print("INT8 quant (bs=1):", q_bs1)

# quick accuracy sanity on a few batches
def quick_accuracy(hf_model, loader, max_batches=10, device="cpu"):
    hf_model.eval().to(device)
    correct = total = 0
    with torch.inference_mode():
        for i, b in enumerate(loader):
            if i >= max_batches: break
            b = _to_device(b, device)
            out = hf_model(**b)
            preds = out.logits.argmax(dim=-1)
            correct += (preds == b["labels"]).sum().item()
            total   += preds.numel()
    return correct / total

acc_fp32 = quick_accuracy(model.model, dm.val_dataloader(), max_batches=10)
acc_int8 = quick_accuracy(qmodel,      dm.val_dataloader(), max_batches=10)
print(f"Quick val accuracy FP32: {acc_fp32:.3f} | INT8: {acc_int8:.3f}")

For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  qmodel = torch.quantization.quantize_dynamic(


INT8 quant (bs=8): {'batch_size': 32, 'mean_ms': 3.5069446661509573, 'p50_ms': 3.36169998627156, 'p95_ms': 4.2835799453314385, 'p99_ms': 5.731635066913436, 'repeats': 150}
INT8 quant (bs=1): {'batch_size': 1, 'mean_ms': 3.549782671422387, 'p50_ms': 3.6640000180341303, 'p95_ms': 4.545009910361841, 'p99_ms': 4.901293938746674, 'repeats': 300}
Quick val accuracy FP32: 0.495 | INT8: 0.495


In [8]:
fp32_cuda_bs8 = measure_latency(model.model, batch_val, repeats=200, warmup=60, device="cuda")
single = {k: v[:1].clone() for k, v in batch_val.items()}
fp32_cuda_bs1 = measure_latency(model.model, single, repeats=300, warmup=80, device="cuda")
print("CUDA FP (bs=8):", fp32_cuda_bs8)
print("CUDA FP (bs=1):", fp32_cuda_bs1)

CUDA FP (bs=8): {'batch_size': 32, 'mean_ms': 1.9203689997084439, 'p50_ms': 1.893299981020391, 'p95_ms': 2.1348749636672437, 'p99_ms': 2.5890200270805495, 'repeats': 200}
CUDA FP (bs=1): {'batch_size': 1, 'mean_ms': 1.8396346647447597, 'p50_ms': 1.7695500864647329, 'p95_ms': 2.163690055022016, 'p99_ms': 3.2102429785300024, 'repeats': 300}


In [9]:
%pip install -U ray


Collecting ray
  Downloading ray-2.49.1-cp312-cp312-win_amd64.whl.metadata (21 kB)
Collecting click>=7.0 (from ray)
  Downloading click-8.2.1-py3-none-any.whl.metadata (2.5 kB)
Collecting jsonschema (from ray)
  Using cached jsonschema-4.25.1-py3-none-any.whl.metadata (7.6 kB)
Collecting msgpack<2.0.0,>=1.0.0 (from ray)
  Downloading msgpack-1.1.1-cp312-cp312-win_amd64.whl.metadata (8.6 kB)
Collecting jsonschema-specifications>=2023.03.6 (from jsonschema->ray)
  Using cached jsonschema_specifications-2025.4.1-py3-none-any.whl.metadata (2.9 kB)
Collecting referencing>=0.28.4 (from jsonschema->ray)
  Using cached referencing-0.36.2-py3-none-any.whl.metadata (2.8 kB)
Collecting rpds-py>=0.7.1 (from jsonschema->ray)
  Downloading rpds_py-0.27.1-cp312-cp312-win_amd64.whl.metadata (4.3 kB)
Downloading ray-2.49.1-cp312-cp312-win_amd64.whl (26.2 MB)
   ---------------------------------------- 0.0/26.2 MB ? eta -:--:--
   ---------------------------------------- 0.3/26.2 MB ? eta -:--:--
   - -

In [10]:
# Try the extras (often fine on 3.12). If this errors, skip it—base `ray` is enough.
%pip install -U "ray[train]"


Collecting tensorboardX>=1.9 (from ray[train])
  Downloading tensorboardx-2.6.4-py3-none-any.whl.metadata (6.2 kB)
Collecting pydantic!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.*,!=2.4.*,<3 (from ray[train])
  Downloading pydantic-2.11.7-py3-none-any.whl.metadata (67 kB)
Collecting annotated-types>=0.6.0 (from pydantic!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.*,!=2.4.*,<3->ray[train])
  Downloading annotated_types-0.7.0-py3-none-any.whl.metadata (15 kB)
Collecting pydantic-core==2.33.2 (from pydantic!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.*,!=2.4.*,<3->ray[train])
  Downloading pydantic_core-2.33.2-cp312-cp312-win_amd64.whl.metadata (6.9 kB)
Collecting typing-inspection>=0.4.0 (from pydantic!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.*,!=2.4.*,<3->ray[train])
  Downloading typing_inspection-0.4.1-py3-none-any.whl.metadata (2.6 kB)
Downloading pydantic-2.11.7-py3-none-any.whl (444 kB)
Downloading pydantic_core-2.33.2-cp312-cp312-win_amd64.whl (2.0 MB)
   ---------------------------------------- 0.0/2.0 MB ? eta -:--:--
   ---------- -

In [11]:
import ray, platform, sys
print("Ray:", ray.__version__, "| Python:", sys.version.split()[0], "| OS:", platform.platform())

ray.init(ignore_reinit_error=True)  # starts a local Ray runtime

@ray.remote
def square(x): 
    return x * x

futures = [square.remote(i) for i in range(5)]
print("Squares:", ray.get(futures))

# (Optional) tiny actor test
@ray.remote
class Counter:
    def __init__(self): self.n = 0
    def inc(self): self.n += 1; return self.n

c = Counter.remote()
print("Counter:", ray.get([c.inc.remote() for _ in range(3)]))

ray.shutdown()


2025-09-06 23:08:07,474	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


Ray: 2.49.1 | Python: 3.12.10 | OS: Windows-11-10.0.26100-SP0


2025-09-06 23:08:10,215	INFO worker.py:1951 -- Started a local Ray instance.


Squares: [0, 1, 4, 9, 16]
Counter: [1, 2, 3]


In [12]:
try:
    from ray.train.torch import TorchTrainer
    from ray.train import ScalingConfig
    print("Ray Train OK")
except Exception as e:
    print("Ray Train not available:", e)


2025-09-06 23:08:12,922	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
2025-09-06 23:08:13,001	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


Ray Train OK


In [16]:

# Cell 1 — CUDA-friendly DataModule (workers/pinning)
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorWithPadding
from torch.utils.data import DataLoader
import pytorch_lightning as pl

# keep your existing MODEL_ID / MAX_LEN from before
# MODEL_ID = "..." 
# MAX_LEN = 128

class SST2DataModule(pl.LightningDataModule):
    def __init__(self, model_id=MODEL_ID, batch_size=32, num_workers=2, pin_memory=True, persistent_workers=True, prefetch_factor=2):
        super().__init__()
        self.model_id = model_id
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.persistent_workers = persistent_workers
        self.prefetch_factor = prefetch_factor
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        self.collate = DataCollatorWithPadding(self.tokenizer)

    def prepare_data(self):
        load_dataset("glue", "sst2")
        AutoTokenizer.from_pretrained(self.model_id)

    def setup(self, stage=None):
        ds_train = load_dataset("glue", "sst2", split="train[:1000]")
        ds_val   = load_dataset("glue", "sst2", split="validation[:200]")

        def tok_fn(batch):
            t = self.tokenizer(batch["sentence"], truncation=True, max_length=MAX_LEN)
            t["labels"] = batch["label"]
            return t

        ds_train = ds_train.map(tok_fn, batched=True, remove_columns=ds_train.column_names)
        ds_val   = ds_val.map(tok_fn,   batched=True, remove_columns=ds_val.column_names)
        self.ds_train, self.ds_val = ds_train, ds_val

    def _loader(self, ds, shuffle: bool):
        kw = dict(
            dataset=ds,
            batch_size=self.batch_size,
            shuffle=shuffle,
            collate_fn=self.collate,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            persistent_workers=self.persistent_workers if self.num_workers > 0 else False,
        )
        if self.num_workers > 0:
            kw["prefetch_factor"] = self.prefetch_factor
        return DataLoader(**kw)

    def train_dataloader(self):
        return self._loader(self.ds_train, shuffle=True)

    def val_dataloader(self):
        return self._loader(self.ds_val, shuffle=False)

print("DataModule (CUDA-ready) defined.")


DataModule (CUDA-ready) defined.


In [17]:
# Cell 2 — GPU training
import pytorch_lightning as pl, torch

pl.seed_everything(42, workers=True)

dm = SST2DataModule(model_id=MODEL_ID, batch_size=32, num_workers=2, pin_memory=True, persistent_workers=True)

dm.prepare_data(); dm.setup()

model = LitTinyClassifier(model_id=MODEL_ID, lr=5e-5)

can_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
precision = "bf16-mixed" if can_bf16 else ("16-mixed" if torch.cuda.is_available() else "32-true")

trainer = pl.Trainer(
    max_epochs=1,
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
    precision=precision,
    log_every_n_steps=10,
)
trainer.fit(model, datamodule=dm)


Seed set to 42
Map: 100%|██████████| 1000/1000 [00:00<00:00, 16523.48 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 6877.43 examples/s]
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at sshleifer/tiny-distilroberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Using bfloat16 Automatic Mixed Precision (AMP)
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Map: 100%|██████████| 1000/1000 [00:00<00:00, 17056.12 examples/s]
Map: 100%|██████████| 200/200 [00:00<

                                                                           



Epoch 0: 100%|██████████| 32/32 [00:01<00:00, 18.64it/s, v_num=3, val_loss=0.693, val_acc=0.495, train_loss=0.693, train_acc=0.542]

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 32/32 [00:01<00:00, 18.48it/s, v_num=3, val_loss=0.693, val_acc=0.495, train_loss=0.693, train_acc=0.542]


In [18]:
# Cell 3 — CUDA latency
import time, numpy as np, torch

def _to_device(batch, device):
    return {k: v.to(device, non_blocking=True) for k, v in batch.items()}

def _should_sync(device: str) -> bool:
    return device.startswith("cuda") and torch.cuda.is_available()

def measure_latency(hf_model, batch, repeats=200, warmup=60, device="cuda", amp=False):
    hf_model.eval().to(device)
    batch = _to_device(batch, device)

    # warmup
    with torch.inference_mode():
        if amp and device.startswith("cuda"):
            with torch.autocast(device_type="cuda", dtype=(torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16)):
                for _ in range(warmup): _ = hf_model(**batch)
        else:
            for _ in range(warmup): _ = hf_model(**batch)
    if _should_sync(device): torch.cuda.synchronize()

    times = []
    with torch.inference_mode():
        for _ in range(repeats):
            t0 = time.perf_counter()
            if amp and device.startswith("cuda"):
                with torch.autocast(device_type="cuda", dtype=(torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16)):
                    _ = hf_model(**batch)
            else:
                _ = hf_model(**batch)
            if _should_sync(device): torch.cuda.synchronize()
            times.append((time.perf_counter() - t0) * 1000.0)

    arr = np.asarray(times, dtype=float)
    return {
        "batch_size": int(batch["input_ids"].shape[0]),
        "mean_ms": float(arr.mean()),
        "p50_ms": float(np.percentile(arr, 50)),
        "p95_ms": float(np.percentile(arr, 95)),
        "p99_ms": float(np.percentile(arr, 99)),
        "repeats": int(repeats),
        "amp": bool(amp),
        "device": device,
    }

# grab a val batch
val_loader = dm.val_dataloader()
batch_val = next(iter(val_loader))

# bs=8
fp32_cuda_bs8 = measure_latency(model.model, batch_val, repeats=200, warmup=80, device="cuda", amp=False)
amp_cuda_bs8  = measure_latency(model.model, batch_val, repeats=200, warmup=80, device="cuda", amp=True)

# bs=1
single = {k: v[:1].clone() for k, v in batch_val.items()}
fp32_cuda_bs1 = measure_latency(model.model, single, repeats=300, warmup=100, device="cuda", amp=False)
amp_cuda_bs1  = measure_latency(model.model, single, repeats=300, warmup=100, device="cuda", amp=True)

print("CUDA FP32 (bs=8):", fp32_cuda_bs8)
print("CUDA AMP  (bs=8):", amp_cuda_bs8)
print("CUDA FP32 (bs=1):", fp32_cuda_bs1)
print("CUDA AMP  (bs=1):", amp_cuda_bs1)


CUDA FP32 (bs=8): {'batch_size': 32, 'mean_ms': 2.1199090004665777, 'p50_ms': 1.9768000347539783, 'p95_ms': 3.5530599532648877, 'p99_ms': 3.999122950481251, 'repeats': 200, 'amp': False, 'device': 'cuda'}
CUDA AMP  (bs=8): {'batch_size': 32, 'mean_ms': 2.750707999803126, 'p50_ms': 2.607349946629256, 'p95_ms': 3.9444049878511573, 'p99_ms': 5.274887993000447, 'repeats': 200, 'amp': True, 'device': 'cuda'}
CUDA FP32 (bs=1): {'batch_size': 1, 'mean_ms': 2.2397536667995155, 'p50_ms': 1.9745000172406435, 'p95_ms': 3.5744099703151733, 'p99_ms': 3.9626729849260256, 'repeats': 300, 'amp': False, 'device': 'cuda'}
CUDA AMP  (bs=1): {'batch_size': 1, 'mean_ms': 2.769810004004588, 'p50_ms': 2.556700026616454, 'p95_ms': 4.814254981465638, 'p99_ms': 5.666363951750099, 'repeats': 300, 'amp': True, 'device': 'cuda'}


In [19]:
# Cell 4 — torch.compile (inference)
import torch
if hasattr(torch, "compile"):
    try:
        cmodel = torch.compile(model.model, dynamic=True)
        c_bs8 = measure_latency(cmodel, batch_val, repeats=200, warmup=120, device="cuda", amp=False)
        c_bs1 = measure_latency(cmodel, single,   repeats=300, warmup=150, device="cuda", amp=False)
        print("CUDA compile FP32 (bs=8):", c_bs8)
        print("CUDA compile FP32 (bs=1):", c_bs1)
    except Exception as e:
        print("torch.compile not usable here:", repr(e))
else:
    print("torch.compile not available in this build.")


  return torch.layer_norm(
W0906 23:13:26.596000 32672 Lib\site-packages\torch\_inductor\utils.py:1436] [0/0_1] Not enough SMs to use max_autotune_gemm mode


torch.compile not usable here: TritonMissing('Cannot find a working triton installation. Either the package is not installed or it is too old. More information on installing Triton can be found at: https://github.com/triton-lang/triton\n\nSet TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you\'re reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"\n')


Online softmax is disabled on the fly since Inductor decides to
split the reduction. Cut an issue to PyTorch if this is an
important use case and you want to speed it up with online
softmax.

