# Triton CNN (Colab Ready)
This notebook sets up Triton on a Colab GPU, builds a small CNN, implements Triton kernels (LayerNorm, GELU, Swish, and fused LayerNorm+GELU), then trains and benchmarks.

**How to use:** Runtime → Change runtime type → **GPU** → Save → Run cells from top to bottom.

In [1]:
#Verify GPU
!nvidia-smi


Mon Nov 10 03:49:25 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   44C    P8             10W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [2]:
import os
os.environ["TRITON_BACKEND"] = "cuda"   # set before importing triton
%pip install -q --upgrade --force-reinstall "triton==3.4.0"



[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
ipython 7.34.0 requires jedi>=0.16, which is not installed.[0m[31m
[0m

In [3]:
import torch, triton
print("torch.cuda.is_available():", torch.cuda.is_available())
print("Torch:", torch.__version__, "| CUDA in torch:", torch.version.cuda)
print("Triton:", triton.__version__)

from triton.runtime import driver
print("Active driver:", driver.active)


torch.cuda.is_available(): True
Torch: 2.8.0+cu126 | CUDA in torch: 12.6
Triton: 3.4.0
Active driver: <triton.backends.nvidia.driver.CudaDriver object at 0x7f6190861610>


In [4]:
%%bash
mkdir -p triton-cnn/src/triton_ops
mkdir -p triton-cnn/data
mkdir -p triton-cnn/experiments
echo "Folders ready."




Folders ready.


In [5]:
%%writefile triton-cnn/src/datasets.py
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
def get_mnist_loaders(batch_size=64, root="/content/triton-cnn/data"):
    tfm = transforms.Compose([transforms.ToTensor()])
    train = datasets.MNIST(root, train=True, download=True, transform= tfm)
    test  = datasets.MNIST(root, train=False, download=True, transform= tfm)
    return (DataLoader(train, batch_size=batch_size, shuffle=True,  num_workers=2, pin_memory=True),
            DataLoader(test,  batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True))
print("datasets.py written")



Overwriting triton-cnn/src/datasets.py


In [6]:
%%writefile triton-cnn/src/triton_ops/layernorm.py
import torch, triton, triton.language as tl
@triton.jit
def _layernorm_row(X, Y, G, B, eps, C, BLOCK: tl.constexpr):
    r = tl.program_id(0)
    offs = r*C + tl.arange(0, BLOCK)
    m = tl.arange(0, BLOCK) < C
    x = tl.load(X + offs, mask=m, other=0.)
    mu  = tl.sum(x, axis=0) / C
    xc  = x - mu
    var = tl.sum(xc*xc, axis=0) / C
    inv = tl.math.rsqrt(var + eps)
    g = tl.load(G + tl.arange(0, BLOCK), mask=m, other=1.)
    b = tl.load(B + tl.arange(0, BLOCK), mask=m, other=0.)
    y = (xc * inv) * g + b
    tl.store(Y + offs, y, mask=m)
def layernorm_triton(x, gamma, beta, eps=1e-5):
    rows, cols = x.shape
    y = torch.empty_like(x)
    BLOCK = min(1024, 1 << (cols-1).bit_length())
    _layernorm_row[(rows,)](x, y, gamma, beta, eps, cols, BLOCK=BLOCK)
    return y
print("layernorm.py written")


Overwriting triton-cnn/src/triton_ops/layernorm.py


In [7]:
%%writefile triton-cnn/src/triton_ops/gelu.py
import torch, triton, triton.language as tl

@triton.jit
def _gelu(X, Y, N: tl.constexpr, BLOCK: tl.constexpr):
    pid  = tl.program_id(0)
    offs = pid*BLOCK + tl.arange(0, BLOCK)
    m = offs < N
    x = tl.load(X + offs, mask=m, other=0.)
    inv_sqrt2 = 0.7071067811865476  # 1/sqrt(2)
    y = 0.5 * x * (1 + tl.math.erf(x * inv_sqrt2))
    tl.store(Y + offs, y, mask=m)

def gelu_triton(x):
    y = torch.empty_like(x)
    N = x.numel()
    BLOCK = 1024
    _gelu[ ((N+BLOCK-1)//BLOCK,) ](x, y, N, BLOCK=BLOCK)
    return y
print("gelu.py written (erf-based GELU)")





Overwriting triton-cnn/src/triton_ops/gelu.py


In [8]:
%%writefile triton-cnn/src/triton_ops/swish.py
import torch, triton, triton.language as tl
@triton.jit
def _swish(X, Y, N: tl.constexpr, BLOCK: tl.constexpr):
    pid  = tl.program_id(0)
    offs = pid*BLOCK + tl.arange(0, BLOCK)
    m = offs < N
    x = tl.load(X + offs, mask=m, other=0.)
    y = x / (1 + tl.exp(-x))
    tl.store(Y + offs, y, mask=m)
def swish_triton(x):
    y = torch.empty_like(x); N=x.numel(); BLOCK=1024
    _swish[ ((N+BLOCK-1)//BLOCK,) ](x, y, N, BLOCK=BLOCK)
    return y
print("swish.py written")



Overwriting triton-cnn/src/triton_ops/swish.py


In [9]:
%%writefile triton-cnn/src/triton_ops/fused_ln_gelu.py
import torch, triton, triton.language as tl

@triton.jit
def _fused_ln_gelu(X, Y, G, B, eps, C, BLOCK: tl.constexpr):
    r = tl.program_id(0)
    offs = r*C + tl.arange(0, BLOCK)
    m = tl.arange(0, BLOCK) < C
    x = tl.load(X + offs, mask=m, other=0.)

    # LayerNorm
    mu  = tl.sum(x, axis=0)/C
    xc  = x - mu
    var = tl.sum(xc*xc, axis=0)/C
    inv = tl.math.rsqrt(var + eps)
    g = tl.load(G + tl.arange(0, BLOCK), mask=m, other=1.)
    b = tl.load(B + tl.arange(0, BLOCK), mask=m, other=0.)
    y = (xc * inv) * g + b

    # GELU (exact, erf-based)
    inv_sqrt2 = 0.7071067811865476
    y = 0.5 * y * (1 + tl.math.erf(y * inv_sqrt2))

    tl.store(Y + offs, y, mask=m)

def fused_ln_gelu_triton(x, gamma, beta, eps=1e-5):
    rows, cols = x.shape
    y = torch.empty_like(x)
    BLOCK = min(1024, 1 << (cols-1).bit_length())
    _fused_ln_gelu[(rows,)](x, y, gamma, beta, eps, cols, BLOCK=BLOCK)
    return y
print("fused_ln_gelu.py written (erf-based GELU)")





Overwriting triton-cnn/src/triton_ops/fused_ln_gelu.py


In [10]:
%%writefile triton-cnn/src/model_triton.py
import torch, torch.nn as nn
from triton_ops.layernorm import layernorm_triton
from triton_ops.gelu import gelu_triton
from triton_ops.swish import swish_triton
from triton_ops.fused_ln_gelu import fused_ln_gelu_triton

class TritonCNN(nn.Module):
    def __init__(self, activation="gelu", use_fused=False, hidden=128):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool  = nn.MaxPool2d(2)   # 28x28 -> 14x14
        self.fc1   = nn.Linear(64*14*14, hidden)   # <-- use this
        self.fc2   = nn.Linear(hidden, 10)
        self.g = nn.Parameter(torch.ones(hidden))
        self.b = nn.Parameter(torch.zeros(hidden))
        self.activation = activation
        self.use_fused  = use_fused

    def forward(self, x):
        # conv stack
        x = self.conv1(x).relu()
        x = self.pool(self.conv2(x).relu())
        # flatten
        x = x.view(x.size(0), -1)          # [B, 64*14*14] = [B, 12544]
        # project to hidden
        x = self.fc1(x)                    # [B, hidden] = [B, 128]
        # normalize + activation (or fused)
        if self.use_fused and self.activation == "gelu":
            x = fused_ln_gelu_triton(x, self.g, self.b)
        else:
            x = layernorm_triton(x, self.g, self.b)
            if self.activation == "gelu":
                x = gelu_triton(x)
            elif self.activation == "swish":
                x = swish_triton(x)
        # classify
        return self.fc2(x)

print("model_triton.py written (uses fc1 before LN/activation)")




Overwriting triton-cnn/src/model_triton.py


In [11]:
%%writefile triton-cnn/src/train.py
import argparse, torch, torch.nn as nn
from datasets import get_mnist_loaders
from model_triton import TritonCNN

def train_one_epoch(model, loader, device, criterion, opt=None):
    model.train(); total=correct=0; running=0.0
    for x,y in loader:
        x,y = x.to(device), y.to(device)
        out = model(x); loss = criterion(out,y)
        if opt: opt.zero_grad(); loss.backward(); opt.step()
        running += loss.item()*x.size(0)
        correct += (out.argmax(1)==y).sum().item(); total += x.size(0)
    return running/total, correct/total

def eval_epoch(model, loader, device, criterion):
    model.eval(); total=correct=0; running=0.0
    with torch.no_grad():
        for x,y in loader:
            x,y = x.to(device), y.to(device)
            out = model(x); loss = criterion(out,y)
            running += loss.item()*x.size(0)
            correct += (out.argmax(1)==y).sum().item(); total += x.size(0)
    return running/total, correct/total

def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--epochs", type=int, default=1)
    ap.add_argument("--batch", type=int, default=64)
    ap.add_argument("--act", choices=["gelu","swish"], default="gelu")
    ap.add_argument("--fused", action="store_true")
    args = ap.parse_args()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    train_loader, test_loader = get_mnist_loaders(args.batch)
    model = TritonCNN(activation=args.act, use_fused=args.fused).to(device)
    crit = nn.CrossEntropyLoss(); opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    for ep in range(1, args.epochs+1):
        tl, ta = train_one_epoch(model, train_loader, device, crit, opt)
        vl, va = eval_epoch(model, test_loader, device, crit)
        print(f"Epoch {ep}: train loss {tl:.4f} acc {ta:.4f} | test loss {vl:.4f} acc {va:.4f}")
if __name__ == "__main__": main()
print("train.py written")



Overwriting triton-cnn/src/train.py


In [12]:
%%writefile triton-cnn/src/benchmark.py
import argparse, torch
from datasets import get_mnist_loaders
from model_triton import TritonCNN
def time_batches(model, loader, device, iters=200, warmup=50):
    model.eval(); x,_ = next(iter(loader)); x=x.to(device)
    for _ in range(warmup):
        with torch.no_grad(): _ = model(x)
    start = torch.cuda.Event(True); end = torch.cuda.Event(True)
    torch.cuda.synchronize(); start.record()
    for _ in range(iters):
        with torch.no_grad(): _ = model(x)
    end.record(); torch.cuda.synchronize()
    return start.elapsed_time(end)/iters
def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--act", choices=["gelu","swish"], default="gelu")
    ap.add_argument("--fused", action="store_true")
    ap.add_argument("--batch", type=int, default=64)
    ap.add_argument("--iters", type=int, default=200)
    ap.add_argument("--warmup", type=int, default=50)
    args = ap.parse_args()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    train_loader, _ = get_mnist_loaders(args.batch)
    model = TritonCNN(activation=args.act, use_fused=args.fused).to(device)
    ms = time_batches(model, train_loader, device, iters=args.iters, warmup=args.warmup)
    print(f"Triton act={args.act} fused={args.fused} batch={args.batch} -> {ms:.3f} ms/batch")
if __name__ == "__main__": main()
print("benchmark.py written")



Overwriting triton-cnn/src/benchmark.py


## Run training & benchmarks

In [13]:
%cd /content/triton-cnn/src
!python train.py --epochs 1 --batch 64 --act gelu



/content/triton-cnn/src
datasets.py written
layernorm.py written
gelu.py written (erf-based GELU)
swish.py written
fused_ln_gelu.py written (erf-based GELU)
model_triton.py written (uses fc1 before LN/activation)
Epoch 1: train loss 0.9872 acc 0.7802 | test loss 0.5870 acc 0.8544
train.py written


In [19]:
#fused Triton version
%cd /content/triton-cnn/src
!python train.py --epochs 1 --batch 64 --act gelu --fused


/content/triton-cnn/src
datasets.py written
layernorm.py written
gelu.py written (erf-based GELU)
swish.py written
fused_ln_gelu.py written (erf-based GELU)
model_triton.py written (uses fc1 before LN/activation)
Epoch 1: train loss 1.1848 acc 0.7437 | test loss 0.6989 acc 0.8478
train.py written


In [20]:

!python train.py --epochs 1 --batch 64 --act swish


datasets.py written
layernorm.py written
gelu.py written (erf-based GELU)
swish.py written
fused_ln_gelu.py written (erf-based GELU)
model_triton.py written (uses fc1 before LN/activation)
Epoch 1: train loss 0.9769 acc 0.7860 | test loss 0.5499 acc 0.8696
train.py written


In [22]:
# higher accuracy
!python train.py --epochs 3 --batch 64 --act gelu --fused


datasets.py written
layernorm.py written
gelu.py written (erf-based GELU)
swish.py written
fused_ln_gelu.py written (erf-based GELU)
model_triton.py written (uses fc1 before LN/activation)
Epoch 1: train loss 1.0330 acc 0.7756 | test loss 0.6042 acc 0.8565
Epoch 2: train loss 0.5314 acc 0.8630 | test loss 0.4507 acc 0.8831
Epoch 3: train loss 0.4367 acc 0.8800 | test loss 0.3919 acc 0.8950
train.py written


In [15]:

!python benchmark.py --act gelu --batch 64
!python benchmark.py --act gelu --batch 64 --fused
!python benchmark.py --act swish --batch 64


datasets.py written
layernorm.py written
gelu.py written (erf-based GELU)
swish.py written
fused_ln_gelu.py written (erf-based GELU)
model_triton.py written (uses fc1 before LN/activation)
Triton act=gelu fused=False batch=64 -> 1.304 ms/batch
benchmark.py written
datasets.py written
layernorm.py written
gelu.py written (erf-based GELU)
swish.py written
fused_ln_gelu.py written (erf-based GELU)
model_triton.py written (uses fc1 before LN/activation)
Triton act=gelu fused=True batch=64 -> 1.228 ms/batch
benchmark.py written
datasets.py written
layernorm.py written
gelu.py written (erf-based GELU)
swish.py written
fused_ln_gelu.py written (erf-based GELU)
model_triton.py written (uses fc1 before LN/activation)
Triton act=swish fused=False batch=64 -> 1.300 ms/batch
benchmark.py written


In [21]:
#for a picture
from PIL import Image
import torchvision.transforms as transforms
import torch
from google.colab import files
from model_triton import TritonCNN

# upload an image of a single handwritten digit
uploaded = files.upload()
fname = list(uploaded.keys())[0]

# make it look like MNIST
img = Image.open(fname).convert('L').resize((28, 28))
tfm = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
x = tfm(img).unsqueeze(0).to("cuda")   # (1,1,28,28)

# use the (already trained) model shape
model = TritonCNN(activation="gelu", use_fused=True).to("cuda")
model.eval()
with torch.no_grad():
    out = model(x)
    print("Prediction:", out.argmax(1).item())


layernorm.py written
gelu.py written (erf-based GELU)
swish.py written
fused_ln_gelu.py written (erf-based GELU)
model_triton.py written (uses fc1 before LN/activation)


Saving Screenshot 2025-11-05 193455.png to Screenshot 2025-11-05 193455.png
Prediction: 3
