In [17]:
import torch
import torch.nn as nn
from torch.amp import autocast

In [4]:
# Configuration: everything below (autocast("cuda"), torch.cuda memory stats,
# flash_attn) assumes an NVIDIA GPU is available.
SEED = 0
# Seed before the model is built so weight init and the random batches are
# repeatable across "Restart & Run All" (up to non-deterministic CUDA kernels).
torch.manual_seed(SEED)
device = torch.device("cuda")

In [5]:
# Toy MLP: two 4096 -> 4096 linear layers with a ReLU between them.
model = nn.Sequential(
    nn.Linear(in_features=4096, out_features=4096),
    nn.ReLU(),
    nn.Linear(in_features=4096, out_features=4096),
)
# Module.to() moves parameters in place and returns the same module.
model = model.to(device)

In [6]:
# AdamW over every model parameter at lr=1e-4; other hyperparameters are left at the library defaults.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-4)

In [12]:
# Ten mixed-precision steps on fresh random batches. The objective mean(y**2)
# only drives the outputs toward zero, so this reads as a bf16-autocast /
# allocator smoke test rather than real training.
for step in range(10):
    x = torch.randn((32, 4096), device=device)

    # Forward pass and loss are computed inside the bfloat16 autocast region.
    with autocast(device_type="cuda", dtype=torch.bfloat16):
        y = model(x)
        loss = torch.mean(y ** 2)

    # Backward pass and parameter update happen outside autocast.
    loss.backward()
    optimizer.step()
    # Reset gradients after each update so they do not accumulate across steps.
    optimizer.zero_grad()

    print("loss: ", loss.item())

loss:  0.0371742770075798
loss:  0.036889269948005676
loss:  0.036593616008758545
loss:  0.036329906433820724
loss:  0.03667307645082474
loss:  0.036307454109191895
loss:  0.03571825474500656
loss:  0.036094337701797485
loss:  0.036077916622161865
loss:  0.03585782274603844


In [13]:
# Full (non-abbreviated) report from PyTorch's CUDA caching allocator for the
# current device, e.g. allocated and active memory, current vs. peak usage.
print(torch.cuda.memory_summary(device=None, abbreviated=False))

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      | 410720 KiB | 672928 KiB |  12141 MiB |  11740 MiB |
|       from large pool | 409856 KiB | 672000 KiB |  11920 MiB |  11520 MiB |
|       from small pool |    864 KiB |   3937 KiB |    221 MiB |    220 MiB |
|---------------------------------------------------------------------------|
| Active memory         | 410720 KiB | 672928 KiB |  12141 MiB |  11740 MiB |
|       from large pool | 409856 KiB | 672000 KiB |  11920 MiB |  11520 MiB |
|       from small pool |    864 KiB |   3937 KiB |    221 MiB |    220 MiB |
|---------------------------------------------------------------

In [14]:
# Snapshot of the CUDA caching allocator's counters.
# memory_stats() can return an empty mapping (e.g. before CUDA has been
# initialised in this process), so counters are read with .get(..., 0)
# instead of indexing, which would raise KeyError.
stats = torch.cuda.memory_stats()


def to_mb(x: float) -> float:
    """Convert a byte count to mebibytes.

    The divisor is 1024 * 1024, so the result is in MiB (2**20 bytes), the
    same binary unit torch.cuda.memory_summary() reports, not decimal MB.

    Args:
        x: number of bytes.

    Returns:
        ``x`` expressed in MiB.
    """
    return x / 1024 / 1024


# "Allocated" = memory held by live tensors; "Reserved" additionally counts
# blocks the caching allocator keeps cached for reuse.
print(f"Allocated: {to_mb(stats.get('allocated_bytes.all.current', 0)):.2f} MiB")
print(f"Reserved:  {to_mb(stats.get('reserved_bytes.all.current', 0)):.2f} MiB")


Allocated: 401.09 MB
Reserved:  730.00 MB


In [23]:
pip install --upgrade pip


Collecting pip
  Downloading pip-25.3-py3-none-any.whl.metadata (4.7 kB)
Downloading pip-25.3-py3-none-any.whl (1.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m16.1 MB/s[0m  [33m0:00:00[0m
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 25.2
    Uninstalling pip-25.2:
      Successfully uninstalled pip-25.2
Successfully installed pip-25.3
Note: you may need to restart the kernel to use updated packages.


In [31]:
!python -m pip install --no-cache-dir \
  "https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.8cxx11abiTRUE-cp312-cp312-linux_x86_64.whl"


Collecting flash-attn==2.8.3+cu12torch2.8cxx11abitrue
  Downloading https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.8cxx11abiTRUE-cp312-cp312-linux_x86_64.whl (256.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m256.0/256.0 MB[0m [31m433.1 MB/s[0m  [33m0:00:00[0ma [36m0:00:01[0m
Collecting einops (from flash-attn==2.8.3+cu12torch2.8cxx11abitrue)
  Downloading einops-0.8.1-py3-none-any.whl.metadata (13 kB)
Downloading einops-0.8.1-py3-none-any.whl (64 kB)
Installing collected packages: einops, flash-attn
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2/2[0m [flash-attn]2[0m [flash-attn]
[1A[2KSuccessfully installed einops-0.8.1 flash-attn-2.8.3


In [32]:
# Sanity-check the FlashAttention install. The wheel above already provides
# flash-attn 2.8.3; re-running an unpinned `pip install flash-attn
# --no-build-isolation` is redundant and, if no compatible build is present,
# may fall back to a slow from-source CUDA compile.
import flash_attn

print("flash-attn:", flash_attn.__version__, "| torch:", torch.__version__, "| CUDA:", torch.version.cuda)




In [33]:
from flash_attn import flash_attn_func

In [34]:
# FlashAttention-2 smoke test. flash_attn_func takes q/k/v laid out as
# (batch, seqlen, nheads, headdim) in fp16/bf16 on the GPU; here 1 x 4096 x 32 x 128
# in bfloat16. causal=True applies a causal mask. The output keeps q's shape.
q = torch.randn((1, 4096, 32, 128), dtype=torch.bfloat16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

out = flash_attn_func(q, k, v, causal=True)
print(out.shape)

torch.Size([1, 4096, 32, 128])
