In [None]:
import time

import torch
from torch.utils.data import DataLoader
from torch.nn import functional as F
from torch.optim import AdamW
from transformers import Qwen3ForCausalLM, Qwen2Tokenizer

from tqdm.autonotebook import tqdm
from datasets import load_dataset

In [None]:
# Allow float32 matmuls to use faster reduced-precision internal math
# ("high" permits TF32-style kernels where the hardware supports them).
# NOTE(review): the model below is loaded in half precision, so this presumably
# only affects whatever float32 matmuls remain — confirm it still matters.
torch.set_float32_matmul_precision("high")

In [None]:
MODEL_NAME = "Qwen/Qwen3-0.6B"
# Load in bfloat16 rather than float16. These weights are updated directly by
# AdamW with no GradScaler / loss scaling. In float16, AdamW's default eps=1e-8
# rounds to 0 (it is below float16's smallest subnormal, ~6e-8) and small
# gradients / second moments underflow, so the update denominator can become 0
# and produce inf/NaN. bfloat16 keeps float32's exponent range, so eps is
# representable and underflow is far less likely.
model = Qwen3ForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16, device_map="cuda")
# Disabled alternative: select the FlashAttention-3 attention implementation.
# NOTE(review): presumably needs the matching kernel package installed.
# model.model.config._attn_implementation = "flash_attention_3"
tokenizer = Qwen2Tokenizer.from_pretrained(MODEL_NAME)

# max-autotune kernel search, but without CUDA graph capture.
model = torch.compile(model, mode="max-autotune-no-cudagraphs")
optimizer = AdamW(model.parameters(), lr=1e-4)

In [None]:
DS = load_dataset("roneneldan/TinyStories")

BATCH_SIZE = 4


def tokenize_batch(examples):
    """Tokenize the ``text`` column of a batch of rows (for ``map(batched=True)``)."""
    return tokenizer(examples["text"])


# 100 batches from the start of the train split for the timed loop, plus a
# disjoint 5-batch slice (from row 100000 of the same split) for warm-up.
train_ds = DS["train"].select(range(BATCH_SIZE * 100)).map(tokenize_batch, batched=True)
val_ds = DS["train"].select(range(100000, 100000 + BATCH_SIZE * 5)).map(tokenize_batch, batched=True)


def collate_fn(batch, max_len=256):
    """Pad or truncate tokenized examples into a fixed-length batch.

    Every example is truncated to at most ``max_len`` tokens and right-padded
    with zeros up to exactly ``max_len``, so every batch has the same shape
    (a static shape keeps the compiled model from being re-specialized for
    each new sequence length).

    Args:
        batch: list of examples, each a mapping with ``"input_ids"`` and
            ``"attention_mask"`` sequences of equal length.
        max_len: sequence length of the returned tensors.

    Returns:
        dict with ``"input_ids"`` and ``"attention_mask"``, both ``torch.long``
        tensors of shape ``(len(batch), max_len)``; padded positions have
        ``attention_mask == 0`` and ``input_ids == 0``.
    """
    batch_size = len(batch)
    padded_input_ids = torch.zeros((batch_size, max_len), dtype=torch.long)
    padded_attention_mask = torch.zeros((batch_size, max_len), dtype=torch.long)

    for i, example in enumerate(batch):
        # Truncate first so the source and destination slices always match.
        ids = torch.as_tensor(example["input_ids"][:max_len], dtype=torch.long)
        mask = torch.as_tensor(example["attention_mask"][:max_len], dtype=torch.long)
        padded_input_ids[i, :ids.numel()] = ids
        padded_attention_mask[i, :mask.numel()] = mask

    return {
        "input_ids": padded_input_ids,
        "attention_mask": padded_attention_mask,
    }

# Both loaders share batch size and collation; only the timed training loader
# shuffles, while the warm-up loader keeps a fixed order.
_loader_kwargs = {"batch_size": BATCH_SIZE, "collate_fn": collate_fn}
train_dataloader = DataLoader(train_ds, shuffle=True, **_loader_kwargs)
val_dataloader = DataLoader(val_ds, shuffle=False, **_loader_kwargs)

In [None]:
model.train()


def _train_step(batch):
    """Run one full training step (forward, backward, AdamW update) on ``batch``.

    Uses the module-level ``model`` and ``optimizer``.
    """
    ids = batch["input_ids"].to(model.device)
    attn_mask = batch["attention_mask"].to(model.device)
    # Padded positions (attention_mask == 0) must not be treated as next-token
    # targets; -100 is the ignore index of the Hugging Face causal-LM loss.
    labels = ids.masked_fill(attn_mask == 0, -100)
    loss = model(input_ids=ids, attention_mask=attn_mask, labels=labels).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()


# Warm-up runs exactly the same step as the timed loop, including
# optimizer.step(): AdamW allocates its per-parameter state on the first step,
# and compilation/autotuning happens on the first calls; otherwise both would
# be charged to the first timed iterations. Updating the weights here is
# irrelevant for a throughput benchmark.
print("Warming up...")
for batch in val_dataloader:
    _train_step(batch)

times = []

for batch in tqdm(train_dataloader):
    # CUDA kernels run asynchronously: without synchronizing, the wall-clock
    # delta mostly measures kernel launches and lets GPU work spill into the
    # next iteration. Synchronize on both sides so each sample covers one
    # complete step; perf_counter() is the monotonic clock meant for intervals.
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    _train_step(batch)
    torch.cuda.synchronize()
    times.append(time.perf_counter() - t0)

# Per-step timing statistics, reported in milliseconds.
times = torch.tensor(times)
for label, seconds in [
    ("Mean", times.mean()),
    ("Std", times.std()),
    ("Max", times.max()),
    ("Min", times.min()),
    ("Median", times.median()),
    ("90th percentile", times.quantile(0.9)),
    ("95th percentile", times.quantile(0.95)),
    ("99th percentile", times.quantile(0.99)),
]:
    print(f"{label}: {seconds * 1000:.2f} ms")


## Baseline

```
Mean: 87.52 ms
Std: 25.75 ms
Max: 339.04 ms
Min: 64.74 ms
Median: 84.44 ms
90th percentile: 91.71 ms
95th percentile: 92.99 ms
99th percentile: 96.91 ms
```

## Compile max-autotune

```
Mean: 309.25 ms
Std: 342.35 ms
Max: 2898.89 ms
Min: 30.25 ms
Median: 417.34 ms
90th percentile: 450.42 ms
95th percentile: 456.16 ms
99th percentile: 1181.15 ms
```

## Compile max-autotune and fixed shape (max_len=1024)

```
Mean: 66.36 ms
Std: 3.33 ms
Max: 74.56 ms
Min: 54.87 ms
Median: 65.75 ms
90th percentile: 71.71 ms
95th percentile: 72.78 ms
99th percentile: 73.38 ms
```