
# Vision Multi-Task Notebook — Custom Models (runnable)
This notebook provides **self-contained implementations** of several vision tasks **using models written here** (no external pretrained models for the core custom architectures).
It also keeps the option to **load torchvision's ResNet50** for classification if you want to compare.

Features present for every task:
1. Dataset balance check
2. Augmentations using `torchvision.transforms`
3. Downloading datasets from torchvision (where appropriate) or generating small synthetic datasets so training is runnable
4. Ability to load an external pretrained model (ResNet50) for classification as an option
5. Loading a checkpoint if it exists
6. Saving checkpoints during training
7. Inference function that loads checkpoint and runs model on sample images
8. Plots: data samples, model predictions, loss curves, and TensorBoard logging

Tasks included (each is runnable):
- Classification (CIFAR10) — Custom small CNN + optional ResNet50 head
- Semantic Segmentation — Small UNet trained on synthetic shapes dataset
- Super-Resolution — SRCNN trained on CIFAR10 downsampled images
- GAN (DCGAN) — Generator/Discriminator trained on MNIST
- Simple Object Detection (toy) — single-object images with one rectangle; model predicts box coords + class
- Keypoint Regression (toy) — predict single keypoint in synthetic images (circle center)

Run the cells section by section. Training defaults are deliberately small (3–5 epochs) so each training cell works as a quick smoke test — increase `epochs` for real training.


In [3]:

# Utilities: plotting, checkpointing, tensorboard logging, balance check, augmentations
import os, math, torch, torch.nn as nn, torch.optim as optim
import torchvision, torchvision.transforms as T
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from collections import Counter
import random
import numpy as np
from PIL import Image, ImageDraw

# Pick the compute device once at import time; the helpers below move models
# and tensors to this string ('cuda' when a GPU is visible, otherwise 'cpu').
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print('DEVICE =', DEVICE)

def _to_hashable_label(value):
    """Recursively convert tensors/lists/tuples/dicts into hashable values for Counter."""
    if isinstance(value, torch.Tensor):
        return value.item() if value.numel() == 1 else _to_hashable_label(value.tolist())
    if isinstance(value, (list, tuple)):
        return tuple(_to_hashable_label(v) for v in value)
    if isinstance(value, dict):
        # Sort by key repr so equal dicts map to the same tuple regardless of insertion order.
        return tuple(sorted(((k, _to_hashable_label(v)) for k, v in value.items()),
                            key=lambda kv: repr(kv[0])))
    return value


def check_data_balance(dataset, name='dataset', max_items=10000, max_classes=20):
    """Count label frequencies in ``dataset`` and print a short summary.

    Each item is either ``(x, y, ...)`` (the label is taken from index 1) or a
    bare label. A ``dict`` label with a ``'label'`` key is reduced to that
    value. Single-element tensors become Python scalars. Multi-element tensors,
    lists and dicts (e.g. keypoint coordinates) are converted to nested tuples
    so they can be counted instead of raising ``TypeError: unhashable type``.

    Args:
        dataset: indexable object supporting ``len`` and ``__getitem__``.
        name: label used in the printed header.
        max_items: only the first ``max_items`` items are inspected (for speed).
        max_classes: at most this many classes are printed.

    Returns:
        ``collections.Counter`` mapping label -> count (empty for an empty dataset).
    """
    labels = []
    if hasattr(dataset, '__len__') and len(dataset) > 0:
        for i in range(min(len(dataset), max_items)):
            item = dataset[i]
            y = item[1] if isinstance(item, tuple) and len(item) > 1 else item
            # Detection-style targets carry the class under 'label'.
            if isinstance(y, dict) and 'label' in y:
                y = y['label']
            labels.append(_to_hashable_label(y))
    counter = Counter(labels)
    print(f'=== Balance for {name} (showing up to {max_classes} classes) ===')
    for k, c in list(counter.items())[:max_classes]:
        print(f'class {k}: {c}')
    return counter

def imshow(img, title=None, unnormalize=True):
    """Draw one image on the current matplotlib axes (axis ticks hidden).

    Args:
        img: a torch tensor of shape ``(C, H, W)`` with C in {1, 3} or ``(H, W)``;
            anything else (e.g. an ``(H, W, C)`` numpy array) is passed to
            ``plt.imshow`` unchanged.
        title: optional axes title.
        unnormalize: tensors only — map values with ``x * 0.5 + 0.5`` (the
            inverse of a ``[-1, 1]`` normalization) before display.
    """
    if isinstance(img, torch.Tensor):
        # detach() so tensors that still require grad can be converted to numpy.
        img = img.detach().cpu()
        if unnormalize:
            img = img * 0.5 + 0.5
        # Clip explicitly; matplotlib would otherwise clip float RGB with a warning.
        img = img.clamp(0, 1)
        if img.dim() == 3 and img.size(0) == 1:
            img = img[0]
        if img.dim() == 2:
            # plt.imshow cannot display (H, W, 1); show single-channel data as grayscale.
            plt.imshow(img.numpy(), cmap='gray')
        else:
            plt.imshow(img.permute(1, 2, 0).numpy())
    else:
        plt.imshow(img)
    if title is not None:
        plt.title(title)
    plt.axis('off')

def imshow_batch(images, titles=None, n=6):
    """Show up to ``n`` images from ``images`` side by side in a new figure.

    Args:
        images: indexable batch (e.g. an ``(N, C, H, W)`` tensor); each element
            is drawn with ``imshow``.
        titles: optional per-image titles aligned with ``images``; missing
            entries (shorter list) are shown without a title.
        n: maximum number of images; clamped to ``len(images)`` so small
            batches do not raise ``IndexError``.
    """
    n = min(n, len(images))
    if n <= 0:
        return
    plt.figure(figsize=(15, 3))
    for i in range(n):
        plt.subplot(1, n, i + 1)
        title = titles[i] if titles and i < len(titles) else None
        imshow(images[i], title=title)

def save_checkpoint(model, optimizer, epoch, path):
    """Write model (and optionally optimizer) state plus ``epoch`` to ``path``.

    Parent directories are created if needed. The saved dict has keys
    ``'model_state'``, ``'optim_state'`` (``None`` when no optimizer is given)
    and ``'epoch'``.
    """
    parent_dir = os.path.dirname(path)
    os.makedirs(parent_dir if parent_dir else '.', exist_ok=True)
    payload = {
        'model_state': model.state_dict(),
        'optim_state': None if not optimizer else optimizer.state_dict(),
        'epoch': epoch,
    }
    torch.save(payload, path)
    print('Saved checkpoint to', path)

def load_checkpoint_if_exists(model, optimizer, path):
    """Restore ``model`` (and ``optimizer`` when given) from ``path`` if the file exists.

    Tensors are mapped onto ``DEVICE`` while loading.

    Returns:
        The epoch stored in the checkpoint (0 if the key is absent), or 0 when
        no file exists at ``path``. Training loops use it as the resume epoch.
    """
    if not os.path.exists(path):
        print('No checkpoint at', path)
        return 0
    ckpt = torch.load(path, map_location=DEVICE)
    model.load_state_dict(ckpt['model_state'])
    saved_optim = ckpt.get('optim_state')
    if optimizer and saved_optim is not None:
        optimizer.load_state_dict(saved_optim)
    print('Loaded checkpoint from', path, 'epoch', ckpt.get('epoch'))
    return ckpt.get('epoch', 0)

def make_writer(logdir='runs/exp'):
    """Ensure ``logdir`` exists and return a TensorBoard ``SummaryWriter`` logging there."""
    os.makedirs(logdir, exist_ok=True)
    writer = SummaryWriter(log_dir=logdir)
    return writer


DEVICE = cpu


  from .autonotebook import tqdm as notebook_tqdm



## Classification — Custom CNN (optionally compare with torchvision ResNet50)

Dataset: CIFAR10 (downloaded via torchvision).

We provide:
- Custom simple CNN defined below
- Option to use torchvision.models.resnet50 (not pretrained by default) for comparison
- Full training loop, checkpointing, inference, plotting, TensorBoard logging


In [4]:

# Classification: data loaders, models, train/eval, inference
from torchvision import datasets, models

def get_cifar10_loaders(batch_size=128, augment=True, root='./data', num_workers=2):
    """Build CIFAR-10 train/test DataLoaders, downloading the data if needed.

    Both splits are converted to tensors and normalized with the per-channel
    CIFAR-10 mean/std below.

    Args:
        batch_size: batch size for both loaders.
        augment: if True, the training split first applies a random horizontal
            flip and a 4-pixel-padded random 32x32 crop.
        root: directory where torchvision stores/looks for the dataset.
        num_workers: DataLoader worker processes for both loaders.

    Returns:
        ``(train_loader, test_loader, class_names, train_set)``; the train loader
        shuffles, the test loader does not.
    """
    mean, std = (0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)

    def _tensor_and_normalize():
        # Fresh transform objects for each pipeline.
        return [T.ToTensor(), T.Normalize(mean, std)]

    augmentations = [T.RandomHorizontalFlip(), T.RandomCrop(32, padding=4)] if augment else []
    train_tfms = T.Compose(augmentations + _tensor_and_normalize())
    test_tfms = T.Compose(_tensor_and_normalize())
    train_set = datasets.CIFAR10(root, train=True, download=True, transform=train_tfms)
    test_set = datasets.CIFAR10(root, train=False, download=True, transform=test_tfms)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_loader, test_loader, train_set.classes, train_set

class SimpleCNN(nn.Module):
    """Three-stage convolutional classifier for RGB images.

    Each stage is Conv3x3 -> ReLU -> BatchNorm (32, 64, then 128 channels).
    The first two stages are followed by 2x2 max-pooling and the last by
    global average pooling. A linear layer maps the 128 pooled features to
    ``num_classes`` logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        widths = [(3, 32), (32, 64), (64, 128)]
        layers = []
        for stage, (c_in, c_out) in enumerate(widths):
            layers += [nn.Conv2d(c_in, c_out, 3, padding=1), nn.ReLU(), nn.BatchNorm2d(c_out)]
            # Downsample after the first two stages; the last stage pools globally.
            is_last = stage == len(widths) - 1
            layers.append(nn.AdaptiveAvgPool2d(1) if is_last else nn.MaxPool2d(2))
        self.features = nn.Sequential(*layers)
        self.classifier = nn.Linear(128, num_classes)

    def forward(self, x):
        pooled = self.features(x)
        return self.classifier(pooled.flatten(1))

def train_classification(model, train_loader, test_loader, epochs=3, lr=1e-3, ckpt_path='ckpts/cls.pth', use_tensorboard=True):
    """Train a classifier with Adam + cross-entropy, validating and checkpointing each epoch.

    If ``ckpt_path`` exists, model/optimizer state is restored and training
    resumes at the stored epoch. ``epochs`` is the total epoch count, so
    nothing runs when the checkpoint is already at ``epochs``. After every
    epoch the checkpoint is overwritten. When ``use_tensorboard`` is set,
    train/val loss and accuracy are logged under ``runs/classification``.

    Returns:
        dict of per-epoch lists ``train_loss``, ``train_acc``, ``val_loss`` and
        ``val_acc`` for the epochs run in this call. Train metrics are averaged
        over the samples actually seen (NaN if the loader yields nothing).
    """
    model = model.to(DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    start_epoch = load_checkpoint_if_exists(model, optimizer, ckpt_path)
    writer = make_writer('runs/classification') if use_tensorboard else None
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
    try:
        for e in range(start_epoch, epochs):
            model.train()
            total_loss, total, correct = 0.0, 0, 0
            for x, y in tqdm(train_loader):
                x, y = x.to(DEVICE), y.to(DEVICE)
                optimizer.zero_grad()
                out = model(x)
                loss = criterion(out, y)
                loss.backward()
                optimizer.step()
                # criterion returns the batch mean; re-weight by batch size for an exact epoch mean.
                total_loss += loss.item() * x.size(0)
                _, pred = out.max(1)
                correct += pred.eq(y).sum().item()
                total += y.size(0)
            # Divide loss and accuracy by the same count (samples seen, not len(dataset)),
            # which stays correct with drop_last or a subsampling sampler.
            train_loss = total_loss / total if total else float('nan')
            train_acc = correct / total if total else float('nan')
            val_loss, val_acc = eval_classification(model, test_loader, criterion)
            history['train_loss'].append(train_loss)
            history['train_acc'].append(train_acc)
            history['val_loss'].append(val_loss)
            history['val_acc'].append(val_acc)
            print(f'Epoch {e+1}/{epochs} train_loss={train_loss:.4f} train_acc={train_acc:.4f} val_loss={val_loss:.4f} val_acc={val_acc:.4f}')
            if writer:
                writer.add_scalars('loss', {'train': train_loss, 'val': val_loss}, e)
                writer.add_scalars('acc', {'train': train_acc, 'val': val_acc}, e)
            save_checkpoint(model, optimizer, e + 1, ckpt_path)
    finally:
        # Release the writer even if training is interrupted or raises.
        if writer:
            writer.close()
    return history

def eval_classification(model, loader, criterion):
    """Evaluate a classifier on ``loader`` without gradient tracking.

    Switches ``model`` to eval mode and leaves it there. ``criterion`` is
    assumed to return a batch-mean loss (as ``nn.CrossEntropyLoss`` does by
    default), so each batch loss is re-weighted by its batch size.

    Returns:
        ``(mean_loss, accuracy)`` averaged over the samples actually seen. Both
        are NaN for an empty loader instead of raising ``ZeroDivisionError``.
    """
    model.eval()
    total_loss, total, correct = 0.0, 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            out = model(x)
            total_loss += criterion(out, y).item() * x.size(0)
            _, pred = out.max(1)
            correct += pred.eq(y).sum().item()
            total += y.size(0)
    if total == 0:
        return float('nan'), float('nan')
    # Loss and accuracy share the same denominator (samples seen).
    return total_loss / total, correct / total

def inference_classification(ckpt_path, dataloader, model_builder='custom', num_classes=10, n_show=6):
    """Load a classifier from ``ckpt_path`` and plot its predictions on one batch.

    Args:
        ckpt_path: checkpoint written by ``save_checkpoint``. If it is missing,
            the freshly initialized model is used (a message is printed).
        dataloader: yields ``(images, labels)``; only the first batch is used.
        model_builder: ``'resnet50'`` for torchvision's ResNet-50 (randomly
            initialized, final layer resized to ``num_classes``); any other
            value builds ``SimpleCNN``.
        num_classes: number of output classes.
        n_show: maximum number of images to plot, clamped to the batch size.
    """
    if model_builder == 'resnet50':
        # weights=None is the current spelling of the deprecated pretrained=False.
        model = models.resnet50(weights=None)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    else:
        model = SimpleCNN(num_classes=num_classes)
    load_checkpoint_if_exists(model, None, ckpt_path)
    model = model.to(DEVICE).eval()
    imgs, labels = next(iter(dataloader))
    # Keep a CPU copy for plotting; the model input is moved to DEVICE.
    imgs_cpu = imgs.detach().cpu()
    with torch.no_grad():
        _, preds = model(imgs.to(DEVICE)).max(1)
    preds, labels = preds.cpu(), labels.cpu()
    n = min(n_show, len(labels))
    titles = [f'P:{preds[i].item()} / G:{labels[i].item()}' for i in range(n)]
    imshow_batch(imgs_cpu, titles, n=n)



## Semantic Segmentation — UNet on Synthetic Shapes

We generate synthetic images, each containing a single randomly colored rectangle or circle on a black background, together with pixel-wise binary masks for training a small UNet.
This makes the example fully runnable without heavy datasets.


In [None]:

# Synthetic segmentation dataset
class ShapesDataset(Dataset):
    """Synthetic binary-segmentation data: one random rectangle or circle per image.

    Every ``__getitem__`` call draws a new random shape (``idx`` is ignored)
    with a random RGB fill (each channel in ``[50, 255]``) on a black
    background.

    Returns ``(image, mask)``:
        image: ``(3, size, size)`` float tensor scaled to ``[-1, 1]``.
        mask: ``(size, size)`` long tensor, 1 on the shape and 0 elsewhere.

    ``transforms`` is stored but not applied by ``__getitem__``.
    """

    def __init__(self, n=1000, size=128, transforms=None):
        self.n = n
        self.size = size
        self.transforms = transforms

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        side = self.size
        canvas = Image.new('RGB', (side, side), (0, 0, 0))
        mask = Image.new('L', (side, side), 0)
        painters = (ImageDraw.Draw(canvas), ImageDraw.Draw(mask))
        # Random calls below keep the same order: shape kind, geometry, then color.
        if random.choice(['rect', 'circle']) == 'rect':
            x0 = random.randint(10, side // 2)
            y0 = random.randint(10, side // 2)
            x1 = random.randint(side // 2, side - 10)
            y1 = random.randint(side // 2, side - 10)
            bbox, primitive = [x0, y0, x1, y1], 'rectangle'
        else:
            cx = random.randint(20, side - 20)
            cy = random.randint(20, side - 20)
            r = random.randint(10, side // 3)
            bbox, primitive = [cx - r, cy - r, cx + r, cy + r], 'ellipse'
        color = tuple(random.randint(50, 255) for _ in range(3))
        # Draw the identical primitive on the image (colored) and the mask (value 1).
        getattr(painters[0], primitive)(bbox, fill=color)
        getattr(painters[1], primitive)(bbox, fill=1)
        image = T.ToTensor()(canvas) * 2 - 1  # [0, 1] -> [-1, 1]
        mask_t = (torch.from_numpy(np.array(mask)) > 0).long()
        return image, mask_t

# Simple UNet
class DoubleConv(nn.Module):
    """Two 3x3 same-padded convolutions, each followed by ReLU."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.ReLU(),
        )

    def forward(self, x):
        return self.net(x)


class UNet(nn.Module):
    """Two-level UNet returning per-pixel class logits ``(N, n_classes, H, W)``.

    Encoder: DoubleConv(3->32) -> pool -> DoubleConv(32->64) -> pool ->
    DoubleConv(64->128). Each decoder level upsamples with a 2x2 transposed
    conv, concatenates the matching encoder skip, and fuses with a
    DoubleConv. A final 1x1 conv produces ``n_classes`` channels.

    Inputs whose H/W are not divisible by 4 are supported: each upsampled map
    is zero-padded on the bottom/right to its skip's size before concatenation.
    """

    def __init__(self, n_classes=2):
        super().__init__()
        self.d1 = DoubleConv(3, 32); self.p1 = nn.MaxPool2d(2)
        self.d2 = DoubleConv(32, 64); self.p2 = nn.MaxPool2d(2)
        self.d3 = DoubleConv(64, 128); self.up2 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.d4 = DoubleConv(128, 64); self.up1 = nn.ConvTranspose2d(64, 32, 2, stride=2)
        # The last concat yields 32 (upsampled) + 32 (skip) = 64 channels; fuse them back to
        # 32 before the 1x1 head, mirroring d4 at the level above.
        self.d5 = DoubleConv(64, 32)
        self.outc = nn.Conv2d(32, n_classes, 1)

    @staticmethod
    def _match_size(up, skip):
        """Zero-pad ``up`` on the bottom/right so its H, W equal those of ``skip``."""
        dh = skip.size(-2) - up.size(-2)
        dw = skip.size(-1) - up.size(-1)
        if dh or dw:
            up = nn.functional.pad(up, (0, dw, 0, dh))
        return up

    def forward(self, x):
        x1 = self.d1(x)
        x3 = self.d2(self.p1(x1))
        x5 = self.d3(self.p2(x3))
        x6 = torch.cat([self._match_size(self.up2(x5), x3), x3], dim=1)
        x7 = self.d4(x6)
        x8 = torch.cat([self._match_size(self.up1(x7), x1), x1], dim=1)
        return self.outc(self.d5(x8))

# training loop for segmentation (pixel-wise CE)
def train_segmentation(model, train_loader, val_loader, epochs=3, lr=1e-3, ckpt='ckpts/seg.pth'):
    """Train a segmentation model with Adam and pixel-wise cross-entropy.

    ``model`` is expected to return ``(N, C, H, W)`` logits. Resumes from
    ``ckpt`` when present (``epochs`` is the total count). Each epoch it:
    trains; computes the mean validation loss; prints both losses; logs them
    to TensorBoard under ``runs/seg``; and saves ``ckpt``.
    """
    model = model.to(DEVICE)
    opt = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    start = load_checkpoint_if_exists(model, opt, ckpt)
    writer = make_writer('runs/seg')
    for e in range(start, epochs):
        model.train()
        running = 0
        for images, masks in tqdm(train_loader):
            images, masks = images.to(DEVICE), masks.to(DEVICE)
            opt.zero_grad()
            logits = model(images)  # (N, C, H, W)
            batch_loss = criterion(logits, masks)
            batch_loss.backward()
            opt.step()
            running += batch_loss.item() * images.size(0)

        model.eval()
        val_running = 0
        with torch.no_grad():
            for images, masks in val_loader:
                images, masks = images.to(DEVICE), masks.to(DEVICE)
                val_running += criterion(model(images), masks).item() * images.size(0)
        val_loss = val_running / len(val_loader.dataset)
        train_loss = running / len(train_loader.dataset)
        print(f'Epoch {e+1}/{epochs} train_loss={train_loss:.4f} val_loss={val_loss:.4f}')
        writer.add_scalars('loss', {'train': train_loss, 'val': val_loss}, e)
        save_checkpoint(model, opt, e + 1, ckpt)
    writer.close()

def inference_segmentation(ckpt, model, dataloader, n_show=4):
    """Load ``ckpt`` into ``model`` and plot image / GT mask / predicted mask rows.

    Uses the first batch of ``dataloader``; the predicted mask is the per-pixel
    argmax over class logits. Images (expected in ``[-1, 1]``) are mapped to
    ``[0, 1]`` once, by ``imshow``'s default unnormalization. At most
    ``n_show`` rows are plotted, clamped to the batch size.
    """
    load_checkpoint_if_exists(model, None, ckpt)
    model.to(DEVICE).eval()
    x, y = next(iter(dataloader))
    with torch.no_grad():
        pred = model(x.to(DEVICE)).argmax(1).cpu()
    for i in range(min(n_show, x.size(0))):
        plt.figure(figsize=(8, 3))
        # imshow already applies x*0.5+0.5 (== (x+1)/2); do not pre-scale here.
        plt.subplot(1, 3, 1); imshow(x[i], title='Image')
        plt.subplot(1, 3, 2); plt.imshow(y[i].numpy()); plt.title('GT Mask'); plt.axis('off')
        plt.subplot(1, 3, 3); plt.imshow(pred[i].numpy()); plt.title('Pred Mask'); plt.axis('off')
        plt.show()



## Super-Resolution — SRCNN on CIFAR10 downsampled images
A simple SRCNN-like model is trained to reconstruct CIFAR images from a blurred version, made by bilinearly downsampling by 2× and upsampling back to the original size.


In [None]:

class SRCNN(nn.Module):
    """SRCNN-style three-layer conv net: 9x9 (64 ch) -> 5x5 (32 ch) -> 5x5 (3 ch).

    ReLU follows the first two convolutions. Same-padding (``k // 2``) keeps
    the spatial size, so the output has the input's H and W.
    """

    def __init__(self):
        super().__init__()
        specs = [(3, 64, 9), (64, 32, 5), (32, 3, 5)]
        layers = []
        for i, (c_in, c_out, k) in enumerate(specs):
            layers.append(nn.Conv2d(c_in, c_out, k, padding=k // 2))
            if i < len(specs) - 1:
                layers.append(nn.ReLU())
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

def train_sr(model, train_loader, epochs=3, lr=1e-3, ckpt='ckpts/sr.pth', scale=2):
    """Train a super-resolution model to undo ``scale``x bilinear down/up-sampling.

    For each batch, targets ``x`` are bilinearly downsampled by ``1/scale``
    and upsampled back to ``x``'s exact H, W. The model learns to map this
    blurry input back to ``x`` under an MSE loss. Resumes from ``ckpt`` if
    present (``epochs`` is the total count). Each epoch it prints the mean
    loss, logs it as ``loss/train`` under ``runs/sr`` and saves ``ckpt``.
    """
    model = model.to(DEVICE)
    opt = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    start = load_checkpoint_if_exists(model, opt, ckpt)
    writer = make_writer('runs/sr')
    try:
        for e in range(start, epochs):
            model.train()
            total_loss, seen = 0.0, 0
            for x, _ in tqdm(train_loader):
                x = x.to(DEVICE)
                # Named low_res (not ``lr``) so it does not shadow the learning-rate argument.
                low_res = nn.functional.interpolate(x, scale_factor=1.0 / scale, mode='bilinear', align_corners=False)
                # Upsample to x's exact size so odd H/W still match the target in the loss.
                blurry = nn.functional.interpolate(low_res, size=x.shape[-2:], mode='bilinear', align_corners=False)
                opt.zero_grad()
                loss = criterion(model(blurry), x)
                loss.backward()
                opt.step()
                total_loss += loss.item() * x.size(0)
                seen += x.size(0)
            avg = total_loss / seen if seen else float('nan')
            print(f'Epoch {e+1}/{epochs} loss={avg:.6f}')
            writer.add_scalar('loss/train', avg, e)
            save_checkpoint(model, opt, e + 1, ckpt)
    finally:
        writer.close()

def inference_sr(ckpt, model, dataloader, n_show=4, scale=2):
    """Load ``ckpt`` into ``model`` and plot GT / degraded input / model output rows.

    Uses the first batch of ``dataloader``. The degraded input is built by
    bilinear downsampling by ``1/scale`` followed by upsampling back to the
    original H, W. Every panel is displayed with ``imshow``'s default
    ``x * 0.5 + 0.5`` mapping, applied once. At most ``n_show`` rows are
    plotted, clamped to the batch size.
    """
    load_checkpoint_if_exists(model, None, ckpt)
    model.to(DEVICE).eval()
    imgs, _ = next(iter(dataloader))
    imgs = imgs.to(DEVICE)
    with torch.no_grad():
        low_res = nn.functional.interpolate(imgs, scale_factor=1.0 / scale, mode='bilinear', align_corners=False)
        blurry = nn.functional.interpolate(low_res, size=imgs.shape[-2:], mode='bilinear', align_corners=False)
        out = model(blurry).cpu()
    imgs, blurry = imgs.cpu(), blurry.cpu()
    for i in range(min(n_show, imgs.size(0))):
        plt.figure(figsize=(8, 3))
        # No extra (x+1)/2 here: imshow already unnormalizes.
        plt.subplot(1, 3, 1); imshow(imgs[i], title='GT')
        plt.subplot(1, 3, 2); imshow(blurry[i], title='LR upsampled')
        plt.subplot(1, 3, 3); imshow(out[i], title='SRCNN out')
        plt.show()



## Generative Model — DCGAN on MNIST
Simple DCGAN implementation using models defined in the notebook and training loop that saves checkpoints and plots samples.


In [None]:

# DCGAN generator and discriminator (MNIST)
class DCGAN_G(nn.Module):
    """DCGAN generator: ``(N, zdim, 1, 1)`` noise -> ``(N, 1, 28, 28)`` images in ``[-1, 1]``.

    Spatial sizes: 1 -> 7 (7x7 transposed conv, stride 1) -> 14 -> 28 (4x4,
    stride 2, padding 1), with BatchNorm + ReLU between layers and Tanh at
    the end. 28x28 is what ``DCGAN_D``'s ``Linear(128*7*7, 1)`` head requires;
    a 4x4 first kernel would give 16x16 images that D cannot consume.
    """

    def __init__(self, zdim=100):
        super().__init__()
        self.net = nn.Sequential(
            nn.ConvTranspose2d(zdim, 128, 7, 1, 0, bias=False),  # 1x1 -> 7x7
            nn.BatchNorm2d(128), nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),    # 7x7 -> 14x14
            nn.BatchNorm2d(64), nn.ReLU(True),
            nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False),      # 14x14 -> 28x28
            nn.Tanh(),
        )

    def forward(self, z):
        return self.net(z)

class DCGAN_D(nn.Module):
    """DCGAN discriminator: ``(N, 1, 28, 28)`` images -> ``(N, 1)`` probabilities of "real".

    Two 4x4 stride-2 convolutions (28 -> 14 -> 7) with LeakyReLU(0.2), with
    BatchNorm on the second. These are followed by a linear layer over the
    flattened 128x7x7 features and a sigmoid.
    """

    def __init__(self):
        super().__init__()
        layers = [nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, inplace=True)]
        layers += [nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1), nn.BatchNorm2d(128),
                   nn.LeakyReLU(0.2, inplace=True)]
        layers += [nn.Flatten(), nn.Linear(128 * 7 * 7, 1), nn.Sigmoid()]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

def train_dcgan(G, D, dataloader, epochs=3, ckpt='ckpts/dcgan.pth', zdim=100):
    """Train a DCGAN generator/discriminator pair with the BCE GAN objective.

    Per batch:
      * D is updated to label real images 1 and detached fakes 0.
      * G is then updated so that D labels those same fakes as real.
    Both use Adam(lr=2e-4, betas=(0.5, 0.999)).

    Checkpoints live at ``ckpt + '_G'`` and ``ckpt + '_D'``. When present they
    are restored, and training resumes from G's stored epoch (``epochs`` is
    the total count). After each epoch:
      * the epoch-mean D/G losses are printed and logged under ``runs/dcgan``;
      * 8 generated samples are plotted;
      * both checkpoints are saved.

    Args:
        zdim: noise dimension; must match ``G``'s input channels.
    """
    G, D = G.to(DEVICE), D.to(DEVICE)
    optG = optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
    optD = optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))
    criterion = nn.BCELoss()
    # Resume both networks; the generator's stored epoch decides where training restarts.
    start = load_checkpoint_if_exists(G, optG, ckpt + '_G')
    load_checkpoint_if_exists(D, optD, ckpt + '_D')
    writer = make_writer('runs/dcgan')
    try:
        for e in range(start, epochs):
            sum_d = sum_g = 0.0
            n_batches = 0
            for x, _ in tqdm(dataloader):
                real = x.to(DEVICE)
                bsz = real.size(0)
                label_real = torch.ones(bsz, 1, device=DEVICE)
                label_fake = torch.zeros(bsz, 1, device=DEVICE)
                # Discriminator step: real -> 1, fake (detached so G gets no gradient) -> 0.
                optD.zero_grad()
                lossD_real = criterion(D(real).view(-1, 1), label_real)
                fake = G(torch.randn(bsz, zdim, 1, 1, device=DEVICE))
                lossD_fake = criterion(D(fake.detach()).view(-1, 1), label_fake)
                lossD = lossD_real + lossD_fake
                lossD.backward()
                optD.step()
                # Generator step: push D to classify the same fakes as real.
                optG.zero_grad()
                lossG = criterion(D(fake).view(-1, 1), label_real)
                lossG.backward()
                optG.step()
                sum_d += lossD.item()
                sum_g += lossG.item()
                n_batches += 1
            # Epoch means (NaN for an empty dataloader instead of a NameError).
            avg_d = sum_d / n_batches if n_batches else float('nan')
            avg_g = sum_g / n_batches if n_batches else float('nan')
            print(f'Epoch {e+1}/{epochs} lossD={avg_d:.4f} lossG={avg_g:.4f}')
            writer.add_scalars('loss', {'D': avg_d, 'G': avg_g}, e)
            with torch.no_grad():
                samples = G(torch.randn(16, zdim, 1, 1, device=DEVICE)).cpu()
            fig, axs = plt.subplots(1, 8, figsize=(12, 2))
            for i in range(8):
                axs[i].imshow((samples[i, 0] + 1) / 2, cmap='gray')
                axs[i].axis('off')
            plt.show()
            save_checkpoint(G, optG, e + 1, ckpt + '_G')
            save_checkpoint(D, optD, e + 1, ckpt + '_D')
    finally:
        writer.close()



## Toy Object Detection — Single-object bounding box regression (synthetic dataset)

We create images with one colored rectangle; the model predicts bounding box coordinates (x_min,y_min,x_max,y_max) normalized to [0,1] and class (single class).
This is a simplified detection setup (not COCO format), but it's fully trainable and demonstrates checkpointing and inference.


In [None]:

# Synthetic detection dataset
class BoxDataset(Dataset):
    """Synthetic single-object detection data: one filled rectangle per image.

    Each ``__getitem__`` call draws a fresh random rectangle (``idx`` is
    ignored). Its top-left corner lies in ``[5, size//2]`` and its
    bottom-right corner in ``[size//2, size-5]``; the random RGB fill has each
    channel in ``[50, 255]``.

    Returns ``(image, target)``:
        image: ``(3, size, size)`` float tensor in ``[-1, 1]``.
        target: ``{'box': float32 tensor [x0, y0, x1, y1] / size, 'label': long tensor(0)}``.

    ``transforms`` is stored but not applied by ``__getitem__``.
    """

    def __init__(self, n=1000, size=128, transforms=None):
        self.n = n
        self.size = size
        self.transforms = transforms

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        s = self.size
        x0, y0 = random.randint(5, s // 2), random.randint(5, s // 2)
        x1, y1 = random.randint(s // 2, s - 5), random.randint(s // 2, s - 5)
        color = tuple(random.randint(50, 255) for _ in range(3))
        canvas = Image.new('RGB', (s, s), (0, 0, 0))
        ImageDraw.Draw(canvas).rectangle([x0, y0, x1, y1], fill=color)
        image = T.ToTensor()(canvas) * 2 - 1
        target = {
            'box': torch.tensor([x0 / s, y0 / s, x1 / s, y1 / s], dtype=torch.float32),
            'label': torch.tensor(0, dtype=torch.long),
        }
        return image, target

# Simple detection model: feature extractor + head for bbox regression + classification
class SimpleDet(nn.Module):
    """Tiny single-object detector: a shared conv backbone with box and score heads.

    The backbone has two Conv3x3/ReLU/MaxPool stages (32 then 64 channels)
    followed by global average pooling. ``forward`` returns ``(box, score)``:
        box: ``(N, 4)`` in ``[0, 1]`` (sigmoid of the box head).
        score: ``(N,)`` in ``[0, 1]`` (sigmoid of the 1-unit class head).
    """

    def __init__(self):
        super().__init__()
        stages = []
        for c_in, c_out in [(3, 32), (32, 64)]:
            stages += [nn.Conv2d(c_in, c_out, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2)]
        stages.append(nn.AdaptiveAvgPool2d(1))
        self.backbone = nn.Sequential(*stages)
        self.fc_box = nn.Linear(64, 4)
        self.fc_cls = nn.Linear(64, 1)

    def forward(self, x):
        feats = self.backbone(x).flatten(1)
        box = self.fc_box(feats).sigmoid()
        score = self.fc_cls(feats).sigmoid().squeeze(1)
        return box, score

def train_detection(model, dataset, epochs=5, batch_size=32, ckpt='ckpts/det.pth'):
    """Train a ``(box, score)`` detector with L1 box loss + BCE score loss (Adam, lr=1e-3).

    ``dataset`` items are expected to look like ``BoxDataset``'s:
    ``(image, {'box': (4,) float tensor, 'label': scalar tensor})``. The
    model's scores must already be probabilities, since ``nn.BCELoss`` is
    applied to them. Resumes from ``ckpt`` if present (``epochs`` is the
    total count). Each epoch it prints the mean loss, logs it as
    ``loss/train`` under ``runs/det`` and saves ``ckpt``.
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    model = model.to(DEVICE)
    opt = optim.Adam(model.parameters(), lr=1e-3)
    l1, bce = nn.L1Loss(), nn.BCELoss()
    start = load_checkpoint_if_exists(model, opt, ckpt)
    writer = make_writer('runs/det')
    try:
        for e in range(start, epochs):
            model.train()
            total_loss = 0.0
            for x, target in tqdm(loader):
                x = x.to(DEVICE)
                # The default collate_fn merges per-sample dicts into ONE dict of batched
                # tensors ({'box': (B, 4), 'label': (B,)}). Iterating that dict would yield
                # its key strings, so index by key instead.
                boxes = target['box'].to(DEVICE)
                labels = target['label'].float().to(DEVICE)
                opt.zero_grad()
                pred_box, pred_cls = model(x)
                loss = l1(pred_box, boxes) + bce(pred_cls, labels)
                loss.backward()
                opt.step()
                total_loss += loss.item() * x.size(0)
            avg = total_loss / len(loader.dataset)
            print(f'Epoch {e+1}/{epochs} loss={avg:.4f}')
            writer.add_scalar('loss/train', avg, e)
            save_checkpoint(model, opt, e + 1, ckpt)
    finally:
        writer.close()

def inference_detection(ckpt, model, dataset):
    """Load ``ckpt`` into ``model`` and draw its predicted box on ``dataset[0]``.

    The normalized box output is scaled by ``dataset.size`` to pixels,
    printed, and drawn as a red rectangle over the image (mapped from
    ``[-1, 1]`` to ``[0, 1]``).
    """
    load_checkpoint_if_exists(model, None, ckpt)
    model.to(DEVICE).eval()
    image, _target = dataset[0]
    with torch.no_grad():
        pred_box, _score = model(image.unsqueeze(0).to(DEVICE))
    pixel_box = pred_box[0].cpu().numpy() * dataset.size
    print('Pred box coords (px):', pixel_box)
    x0, y0, x1, y1 = pixel_box
    plt.imshow(((image + 1) / 2).permute(1, 2, 0))
    outline = plt.Rectangle((x0, y0), x1 - x0, y1 - y0, edgecolor='r', facecolor='none', linewidth=2)
    plt.gca().add_patch(outline)
    plt.axis('off')
    plt.show()



## Keypoint Regression — Predict center of a circle (synthetic)

We create images with a single filled circle and train a model to regress its (x,y) center normalized to [0,1].


In [None]:

class KeypointDataset(Dataset):
    """Synthetic keypoint data: one filled circle per image; the target is its center.

    Each ``__getitem__`` call draws a fresh random circle (``idx`` is
    ignored): center in ``[20, size-20]``, radius in ``[5, 20]``, and an RGB
    fill with each channel in ``[50, 255]``.

    Returns ``(image, keypoint)``:
        image: ``(3, size, size)`` float tensor in ``[-1, 1]``.
        keypoint: float32 tensor ``[cx / size, cy / size]``.
    """

    def __init__(self, n=1000, size=128):
        self.n, self.size = n, size

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        side = self.size
        cx, cy = random.randint(20, side - 20), random.randint(20, side - 20)
        r = random.randint(5, 20)
        fill = tuple(random.randint(50, 255) for _ in range(3))
        canvas = Image.new('RGB', (side, side), (0, 0, 0))
        ImageDraw.Draw(canvas).ellipse([cx - r, cy - r, cx + r, cy + r], fill=fill)
        image = T.ToTensor()(canvas) * 2 - 1
        keypoint = torch.tensor([cx / side, cy / side], dtype=torch.float32)
        return image, keypoint

class KeypointModel(nn.Module):
    """Regress a single normalized ``(x, y)`` keypoint from an RGB image.

    Two Conv3x3/ReLU/MaxPool stages (32 then 64 channels), global average
    pooling, and a linear layer to 2 outputs. A sigmoid squashes the outputs
    into ``[0, 1]``.
    """

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(64, 2),
        )

    def forward(self, x):
        raw = self.net(x)
        return raw.sigmoid()  # normalized coordinates

def train_keypoint(model, dataset, epochs=5, batch_size=32, ckpt='ckpts/kp.pth'):
    """Train a keypoint regressor on ``dataset`` with MSE loss and Adam(lr=1e-3).

    Builds a shuffled DataLoader with 2 workers and resumes from ``ckpt`` if
    present (``epochs`` is the total count). After each epoch it prints the
    mean training loss, logs it as ``loss/train`` under ``runs/kp`` and saves
    ``ckpt``.
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    model = model.to(DEVICE)
    opt = optim.Adam(model.parameters(), lr=1e-3)
    lossf = nn.MSELoss()
    start = load_checkpoint_if_exists(model, opt, ckpt)
    writer = make_writer('runs/kp')
    for e in range(start, epochs):
        model.train()
        running = 0
        for images, targets in tqdm(loader):
            images, targets = images.to(DEVICE), targets.to(DEVICE)
            opt.zero_grad()
            loss = lossf(model(images), targets)
            loss.backward()
            opt.step()
            running += loss.item() * images.size(0)
        epoch_loss = running / len(loader.dataset)
        print(f'Epoch {e+1}/{epochs} loss={epoch_loss:.6f}')
        writer.add_scalar('loss/train', epoch_loss, e)
        save_checkpoint(model, opt, e + 1, ckpt)
    writer.close()

def inference_kp(ckpt, model, dataset):
    """Load ``ckpt`` into ``model`` and plot its predicted keypoint on ``dataset[0]``.

    The model's normalized ``(x, y)`` output is scaled by ``dataset.size`` to
    pixel coordinates and drawn as a red dot over the image (mapped from
    ``[-1, 1]`` to ``[0, 1]``).
    """
    load_checkpoint_if_exists(model, None, ckpt)
    model.to(DEVICE).eval()
    image, _ = dataset[0]
    with torch.no_grad():
        px, py = model(image.unsqueeze(0).to(DEVICE)).cpu().numpy()[0]
    plt.imshow(((image + 1) / 2).permute(1, 2, 0))
    plt.scatter([px * dataset.size], [py * dataset.size], c='r')
    plt.show()



---

### How to use this notebook

- Run the Utilities cell first.
- For each task, run the dataset cell (or creation line), then the training cell to train, then the inference cell to visualize results.
- Checkpoints are saved under `ckpts/` and TensorBoard logs under `runs/`.
- Increase epochs for real training. The defaults are small for a quick runnable demo.
