In [None]:
import json, random
from datetime import datetime
from pathlib import Path
from typing import Dict, List

import numpy as np
import timm
import torch
import torch.nn as nn
from PIL import Image
from timm.optim import optim_factory
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms as T
from torchvision.transforms import functional as TF

# --- Data location -----------------------------------------------------------
BASE_DIR: Path = Path(r"C:\Users\meme machine\Downloads\Capstone-20250618T001711Z-1-001\Capstone\project")

# --- Input geometry ----------------------------------------------------------
IMG_SIZE = 224  # images are resized to IMG_SIZE x IMG_SIZE
PATCH_SIZE = 16  # side length of one square patch

# --- Optimisation ------------------------------------------------------------
BATCH_SIZE = 256
EPOCHS = 5
NUM_WORKERS = 4
LR = 1e-4
WEIGHT_DECAY = 1e-4
GRAD_CLIP_NORM = 1.0

# --- Augmentation / regularisation / memory ----------------------------------
AUG_HFLIP = True
GRAD_CHECKPOINT = True
LABEL_SMOOTH = 0.05

# --- Loss weights ------------------------------------------------------------
CLS_WEIGHT = 1.0
REG_WEIGHT = 1.0
IOU_WEIGHT = 0.5

# --- Model / reproducibility -------------------------------------------------
MODEL_NAME = "vit_large_patch16_224"
SEED = 42

# --- Runtime -----------------------------------------------------------------
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
USE_CUDA_AMP = DEVICE.type == "cuda"  # mixed precision is only used on CUDA

# Every run writes its checkpoints into its own timestamped directory.
RUN_DIR = Path("runs", datetime.now().strftime("%Y%m%d-%H%M%S"))
RUN_DIR.mkdir(parents=True, exist_ok=True)

def seed_everything(seed: int = 42):
    """Seed the Python, NumPy and PyTorch (CPU and every CUDA device) RNGs.

    Deterministic algorithms are explicitly *not* enforced and cuDNN
    autotuning is switched on, so runs favour speed over bit-exact
    reproducibility.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    torch.use_deterministic_algorithms(False)
    torch.backends.cudnn.benchmark = True

seed_everything(seed=SEED)  # seed all RNGs once, at import time

def make_patch_centers(img_size: int = 224, patch: int = 16):
    """Return the pixel centre ``(x, y)`` of every ``patch x patch`` cell.

    Rows run top-to-bottom and, within a row, left-to-right. PatchCriterion
    pairs centre ``i`` with predicted token ``i``, so this order is assumed to
    match the backbone's row-major patch-token order.

    Returns:
        float32 tensor of shape ``(num_patches, 2)``; ``(0, 2)`` when the
        image is too small to hold a single patch centre.
    """
    offsets = range(patch // 2, img_size, patch)
    centres = [(cx, cy) for cy in offsets for cx in offsets]
    return torch.tensor(centres, dtype=torch.float32).reshape(-1, 2)

# One (x, y) centre per ViT patch, shared by the loss' target assignment.
PATCH_CENTERS = make_patch_centers(img_size=IMG_SIZE, patch=PATCH_SIZE)
NUM_PATCHES = len(PATCH_CENTERS)

def load_lvis_mapping(base_dir: Path):
    """Build the LVIS category-id -> contiguous-index mapping and the class names.

    Reads ``annotations/lvis_v1_train.json`` under *base_dir*, falling back to
    ``lvis_v1_val.json`` when the train file is absent. Categories are sorted
    by id, so index ``i`` always refers to the same category whichever file
    was read.

    Returns:
        ``(id_to_idx, names)`` where ``names[id_to_idx[cat_id]]`` is the name
        of category ``cat_id``.

    Raises:
        FileNotFoundError: if neither annotation file exists.
    """
    p_train = base_dir / "annotations" / "lvis_v1_train.json"
    p_val = base_dir / "annotations" / "lvis_v1_val.json"
    p = p_train if p_train.exists() else p_val
    # Use a context manager so the (large) annotation file handle is closed promptly.
    with open(p, "r", encoding="utf-8") as f:
        data = json.load(f)
    cats = sorted(data["categories"], key=lambda c: c["id"])
    id_to_idx = {c["id"]: i for i, c in enumerate(cats)}
    names = [c["name"] for c in cats]
    return id_to_idx, names

CAT_ID_TO_IDX, CAT_NAMES = load_lvis_mapping(base_dir=BASE_DIR)
NUM_CLASSES = len(CAT_NAMES)
# One extra logit: index 0 is background, foreground classes are shifted by +1.
NUM_LOGITS = 1 + NUM_CLASSES

class LVISDataset(Dataset):
    """LVIS v1 split returning transformed images and ``xyxy`` boxes.

    Annotations are read from ``BASE_DIR/annotations/lvis_v1_{split}.json``.
    Boxes (COCO ``[x, y, w, h]`` in original pixels) are converted to
    ``[x1, y1, x2, y2]`` and scaled into the ``IMG_SIZE x IMG_SIZE`` frame.
    That assumes *transform* resizes images to exactly that size, as
    ``tf_train``/``tf_val`` do. Annotations flagged ``iscrowd`` are dropped.

    Args:
        split: ``"train"`` or ``"val"``. Selects the annotation file and the
            default image folder ``BASE_DIR/{split}2017``.
        transform: transform applied to the PIL image.
        id_to_idx: LVIS category id -> contiguous class index.
        aug_hflip: mirror image and boxes horizontally with probability 0.5.
    """

    def __init__(self, split: str, transform: T.Compose, id_to_idx: Dict[int, int], aug_hflip: bool = False):
        root = BASE_DIR / f"{split}2017"
        ann = BASE_DIR / "annotations" / f"lvis_v1_{split}.json"
        self.root, self.transform, self.aug_hflip = root, transform, aug_hflip
        with open(ann, "r", encoding="utf-8") as f:
            data = json.load(f)
        self.images: Dict[int, dict] = {im["id"]: im for im in data["images"]}
        self.ann_map: Dict[int, List[dict]] = {}  # image id -> its non-crowd annotations
        for a in data["annotations"]:
            if a.get("iscrowd", 0):
                continue
            self.ann_map.setdefault(a["image_id"], []).append(a)
        self.ids = list(self.images.keys())
        self.id_to_idx = id_to_idx

    def __len__(self):
        return len(self.ids)

    def _image_path(self, info: dict) -> Path:
        """Resolve the on-disk path of one image record.

        COCO-style records with ``file_name`` resolve under ``self.root``.
        Records without it fall back to the folder and file named in the last
        two path segments of ``coco_url``, e.g. ``.../train2017/000000155379.jpg``
        resolves to ``BASE_DIR/train2017/000000155379.jpg``. This assumes LVIS v1
        records carry ``coco_url`` instead of ``file_name`` and that some val
        images live in ``train2017``; confirm against your annotation file.
        """
        file_name = info.get("file_name")
        if file_name:
            return self.root / file_name
        coco_url = info.get("coco_url")
        if coco_url:
            split_folder, name = coco_url.rsplit("/", 2)[-2:]
            return BASE_DIR / split_folder / name
        raise KeyError(f"image {info.get('id')} has neither 'file_name' nor 'coco_url'")

    def _resize_boxes(self, boxes, sx, sy):
        """Convert COCO ``[x, y, w, h]`` boxes to ``[x1, y1, x2, y2]`` scaled by (sx, sy)."""
        out = []
        for (x, y, w, h) in boxes:
            out.append([x * sx, y * sy, (x + w) * sx, (y + h) * sy])
        return out

    def __getitem__(self, idx):
        iid = self.ids[idx]
        info = self.images[iid]
        img = Image.open(self._image_path(info)).convert("RGB")
        sx, sy = IMG_SIZE / info["width"], IMG_SIZE / info["height"]
        anns = self.ann_map.get(iid, [])
        boxes_coco = [ann["bbox"] for ann in anns]
        labels = [self.id_to_idx[int(ann["category_id"])] for ann in anns]
        boxes_xyxy = self._resize_boxes(boxes_coco, sx, sy)
        if self.transform:
            img = self.transform(img)
        # reshape keeps an annotation-free image at shape (0, 4) instead of (0,).
        boxes = torch.tensor(boxes_xyxy, dtype=torch.float32).reshape(-1, 4)
        labels = torch.tensor(labels, dtype=torch.int64)
        if self.aug_hflip and random.random() < 0.5:
            img = TF.hflip(img)
            if boxes.numel() > 0:
                # Mirror in the IMG_SIZE-wide frame; swap x1/x2 so x1 <= x2 still holds.
                x1, y1, x2, y2 = boxes.unbind(-1)
                new_x1 = IMG_SIZE - x2
                new_x2 = IMG_SIZE - x1
                boxes = torch.stack([new_x1, y1, new_x2, y2], dim=-1)
        return img, {"boxes": boxes, "labels": labels}

def collate(batch):
    """Transpose ``[(img, target), ...]`` into ``([img, ...], [target, ...])``.

    Images and targets stay as Python lists. Targets hold a variable number
    of boxes, so they cannot be stacked; callers stack the images themselves.
    """
    imgs, tgts = map(list, zip(*batch))
    return imgs, tgts

class ViTDetector(nn.Module):
    """timm ViT backbone plus a per-patch head predicting a box and class logits.

    Output shape is ``(B, num_patch_tokens, 4 + num_classes + 1)``. The first
    4 channels are box coordinates; the rest are logits, with index 0 being
    background as assigned by PatchCriterion.
    """

    def __init__(self, num_classes: int, model_name: str = "vit_base_patch16_224", grad_checkpoint: bool = False):
        super().__init__()
        self.backbone = timm.create_model(model_name, pretrained=True, num_classes=0, drop_path_rate=0.1)
        if grad_checkpoint and hasattr(self.backbone, "set_grad_checkpointing"):
            self.backbone.set_grad_checkpointing()
        d = self.backbone.num_features
        self.head = nn.Sequential(
            nn.LayerNorm(d),
            nn.Linear(d, d),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(d, 4 + num_classes + 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Strip every prefix token (class token, plus register/distillation tokens
        # on backbones that have them) so exactly one row per image patch remains.
        # The fallback of 1 reproduces the plain-ViT "[:, 1:]" behaviour.
        n_prefix = getattr(self.backbone, "num_prefix_tokens", 1)
        feats = self.backbone.forward_features(x)[:, n_prefix:]
        out = self.head(feats)
        return out

def iou_aligned(b1: torch.Tensor, b2: torch.Tensor):
    """Element-wise IoU between row-aligned boxes ``b1[i]`` and ``b2[i]``.

    Boxes are ``[x1, y1, x2, y2]`` in columns 0..3. Inverted boxes get zero
    area, and a 1e-6 epsilon in the union keeps empty pairs finite (IoU 0).
    """
    ax1, ay1, ax2, ay2 = (b1[:, k] for k in range(4))
    bx1, by1, bx2, by2 = (b2[:, k] for k in range(4))

    overlap_w = (torch.min(ax2, bx2) - torch.max(ax1, bx1)).clamp(min=0)
    overlap_h = (torch.min(ay2, by2) - torch.max(ay1, by1)).clamp(min=0)
    inter = overlap_w * overlap_h

    area_a = (ax2 - ax1).clamp(min=0) * (ay2 - ay1).clamp(min=0)
    area_b = (bx2 - bx1).clamp(min=0) * (by2 - by1).clamp(min=0)
    union = area_a + area_b - inter + 1e-6
    return inter / union

class PatchCriterion(nn.Module):
    """Per-patch detection loss: weighted classification + SmoothL1 + (1 - IoU).

    Each patch centre in ``PATCH_CENTERS`` that falls inside a ground-truth box
    is assigned that box's label (shifted by +1; 0 is background) and the box's
    ``xyxy`` coordinates as its regression target. Boxes are processed in
    order, so a later box overwrites an earlier one on shared patches.
    Regression and IoU terms only cover foreground patches.
    """

    def __init__(self, cls_w=1.0, reg_w=1.0, iou_w=0.5, label_smooth=0.05):
        super().__init__()
        self.cls_loss = nn.CrossEntropyLoss(label_smoothing=label_smooth)
        self.reg_loss = nn.SmoothL1Loss()
        self.cls_w = cls_w
        self.reg_w = reg_w
        self.iou_w = iou_w

    def assign(self, boxes: torch.Tensor, labels: torch.Tensor):
        """Build per-patch targets for one image.

        Returns:
            ``(cls_t, reg_t, mask)`` of shapes ``(P,)``, ``(P, 4)`` and ``(P,)``:
            class target (0 = background), box target, and foreground mask.
        """
        P = NUM_PATCHES
        device = boxes.device
        cls_t = torch.zeros(P, dtype=torch.long, device=device)
        reg_t = torch.zeros(P, 4, dtype=torch.float32, device=device)
        mask = torch.zeros(P, dtype=torch.bool, device=device)
        if boxes.numel() == 0:
            return cls_t, reg_t, mask
        centers = PATCH_CENTERS.to(device)
        cx, cy = centers[:, 0], centers[:, 1]
        for box, lbl in zip(boxes, labels):
            x1, y1, x2, y2 = box
            inside = (cx >= x1) & (cx <= x2) & (cy >= y1) & (cy <= y2)
            if inside.any():
                cls_t[inside] = int(lbl) + 1
                reg_t[inside] = box
                mask[inside] = True
        return cls_t, reg_t, mask

    def forward(self, pred: torch.Tensor, targets: List[dict]):
        """Compute the scalar loss for a batch.

        Args:
            pred: ``(B, P, 4 + num_logits)`` model output.
            targets: one ``{"boxes", "labels"}`` dict per image.
        """
        # Under CUDA autocast the head's Linear layer can emit float16. Box-area
        # products in the IoU term approach float16's ~65504 limit for
        # image-sized (or larger) predicted boxes, so compute the loss in float32.
        pred = pred.float()
        B, P, _ = pred.shape
        reg_p, cls_p = pred[..., :4], pred[..., 4:]
        cls_t_all = torch.zeros(B, P, dtype=torch.long, device=pred.device)
        reg_t_all = torch.zeros(B, P, 4, dtype=torch.float32, device=pred.device)
        mask_all = torch.zeros(B, P, dtype=torch.bool, device=pred.device)
        for i, tgt in enumerate(targets):
            c, r, m = self.assign(tgt["boxes"].to(pred.device), tgt["labels"].to(pred.device))
            cls_t_all[i], reg_t_all[i], mask_all[i] = c, r, m
        # CrossEntropyLoss expects the class dimension second: (B, C, P).
        cls_loss = self.cls_loss(cls_p.permute(0, 2, 1), cls_t_all)
        if mask_all.any():
            rp = reg_p[mask_all]
            rt = reg_t_all[mask_all]
            reg_loss = self.reg_loss(rp, rt)
            iou = iou_aligned(rp, rt)
            iou_loss = (1.0 - iou).mean()
        else:
            reg_loss = torch.tensor(0.0, device=pred.device)
            iou_loss = torch.tensor(0.0, device=pred.device)
        return self.cls_w * cls_loss + self.reg_w * reg_loss + self.iou_w * iou_loss

criterion = PatchCriterion(
    cls_w=CLS_WEIGHT,
    reg_w=REG_WEIGHT,
    iou_w=IOU_WEIGHT,
    label_smooth=LABEL_SMOOTH,
).to(DEVICE)

# Training: square resize, light photometric jitter, then map pixels to [-1, 1].
tf_train = T.Compose(
    [
        T.Resize((IMG_SIZE, IMG_SIZE)),
        T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
        T.ToTensor(),
        T.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

# Validation: identical geometry and normalisation, no augmentation.
tf_val = T.Compose(
    [
        T.Resize((IMG_SIZE, IMG_SIZE)),
        T.ToTensor(),
        T.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

def make_loaders(workers: int = 4):
    """Build the LVIS train and val DataLoaders.

    The train loader shuffles and applies ``AUG_HFLIP``; the val loader does
    neither. Both share batch size, worker count and collate function.

    Returns:
        ``(train_loader, val_loader)``.
    """
    train_ds = LVISDataset("train", tf_train, CAT_ID_TO_IDX, aug_hflip=AUG_HFLIP)
    val_ds = LVISDataset("val", tf_val, CAT_ID_TO_IDX, aug_hflip=False)
    common = dict(
        batch_size=BATCH_SIZE,
        num_workers=workers,
        collate_fn=collate,
        # Pinned host memory only helps host->GPU copies; on CPU-only runs it
        # just emits a warning.
        pin_memory=(DEVICE.type == "cuda"),
        persistent_workers=(workers > 0),
    )
    train_ld = DataLoader(train_ds, shuffle=True, **common)
    val_ld = DataLoader(val_ds, shuffle=False, **common)
    return train_ld, val_ld

def build_optimizer(model):
    """Create AdamW with parameters split into decay / no-decay groups.

    ``param_groups_weight_decay`` already exempts 1-D tensors and biases.
    The backbone may also report extra exempt names via ``no_weight_decay()``
    (for timm ViTs this is expected to include ``pos_embed``/``cls_token``).
    Those names are relative to the backbone, so they are re-prefixed with
    ``backbone.`` to match the parameter names seen from this wrapper model.
    """
    skip = ()
    backbone = getattr(model, "backbone", None)
    if backbone is not None and hasattr(backbone, "no_weight_decay"):
        skip = {f"backbone.{name}" for name in backbone.no_weight_decay()}
    param_groups = optim_factory.param_groups_weight_decay(
        model, weight_decay=WEIGHT_DECAY, no_weight_decay_list=skip
    )
    opt = torch.optim.AdamW(param_groups, lr=LR)
    return opt

def build_scheduler(opt, train_steps_total):
    """Linear warmup (10% -> 100% of the base LR) followed by cosine decay.

    Warmup lasts ``max(100, 5% of train_steps_total)`` steps, capped at the
    total step count. Without the cap, a short run would spend every step
    warming up and never reach the base LR or start decaying. The scheduler
    is meant to be stepped once per optimizer step.
    """
    warmup_steps = min(max(100, int(0.05 * train_steps_total)), max(1, train_steps_total))
    sched1 = torch.optim.lr_scheduler.LinearLR(opt, start_factor=0.1, total_iters=warmup_steps)
    sched2 = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=max(1, train_steps_total - warmup_steps))
    sched = torch.optim.lr_scheduler.SequentialLR(opt, schedulers=[sched1, sched2], milestones=[warmup_steps])
    return sched

def train_epoch(model, dl, opt, scaler, scheduler, epoch: int):
    """Run one training epoch and return the mean per-batch loss.

    The LR scheduler is stepped once per batch. Gradients are clipped to
    ``GRAD_CLIP_NORM``; with AMP they are unscaled before clipping.
    ``epoch`` is unused and kept for interface compatibility.
    """
    model.train()
    total = 0.0
    n = 0
    for imgs, tgts in dl:
        imgs = torch.stack(imgs).to(DEVICE, non_blocking=True)
        # torch.cuda.amp.autocast has no ``device_type`` parameter (passing one
        # raises TypeError); the device-generic torch.autocast does.
        with torch.autocast(device_type=DEVICE.type, enabled=USE_CUDA_AMP):
            preds = model(imgs)
            loss = criterion(preds, tgts)
        opt.zero_grad(set_to_none=True)
        if USE_CUDA_AMP:
            scaler.scale(loss).backward()
            # Unscale first so the clip threshold applies to the true gradients
            # rather than to gradients multiplied by the current loss scale.
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP_NORM)
            scaler.step(opt)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP_NORM)
            opt.step()
        scheduler.step()
        total += float(loss.item())
        n += 1
    return total / max(n, 1)

@torch.no_grad()
def validate(model, dl):
    """Return the mean per-batch loss of *model* over *dl*, in eval mode without gradients."""
    model.eval()
    total = 0.0
    n = 0
    for imgs, tgts in dl:
        imgs = torch.stack(imgs).to(DEVICE, non_blocking=True)
        # torch.cuda.amp.autocast has no ``device_type`` parameter; use the
        # device-generic torch.autocast instead.
        with torch.autocast(device_type=DEVICE.type, enabled=USE_CUDA_AMP):
            preds = model(imgs)
            loss = criterion(preds, tgts)
        total += float(loss.item())
        n += 1
    return total / max(n, 1)

def main():
    """Train the detector for ``EPOCHS`` epochs, validating after each one.

    After every epoch the model ``state_dict`` (weights only, no optimizer
    state) is saved to ``RUN_DIR/epochNN.pth``.
    """
    train_ld, val_ld = make_loaders(workers=NUM_WORKERS)
    model = ViTDetector(
        NUM_CLASSES,
        model_name=MODEL_NAME,
        grad_checkpoint=GRAD_CHECKPOINT,
    ).to(DEVICE)
    opt = build_optimizer(model)
    scaler = GradScaler(enabled=USE_CUDA_AMP)
    total_steps = EPOCHS * len(train_ld)
    scheduler = build_scheduler(opt, total_steps)

    for epoch in range(1, EPOCHS + 1):
        print(f"Epoch {epoch}/{EPOCHS}")
        tr_loss = train_epoch(model, train_ld, opt, scaler, scheduler, epoch)
        val_loss = validate(model, val_ld)
        print(f"Epoch {epoch}: train {tr_loss:.4f} | val {val_loss:.4f}")
        ckpt_path = RUN_DIR / f"epoch{epoch:02d}.pth"
        torch.save(model.state_dict(), ckpt_path)

if __name__ == "__main__":
    try:
        main()
    except RuntimeError as e:
        # Add a hint for OOM failures; every RuntimeError is still re-raised.
        if "out of memory" in str(e).lower():
            print("CUDA out of memory. Reduce BATCH_SIZE or IMG_SIZE, or keep GRAD_CHECKPOINT=True.")
        raise




Epoch


model.safetensors:   0%|          | 0.00/1.22G [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
  scaler = GradScaler(enabled=USE_CUDA_AMP)


Epoch 1/5


In [None]:
# Inference preprocessing: same square resize and [-1, 1] normalisation as validation.
tf_infer = T.Compose(
    [
        T.Resize((IMG_SIZE, IMG_SIZE)),
        T.ToTensor(),
        T.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

def latest_checkpoint(runs_dir=Path("runs")):
    """Return the most recently modified ``*.pth`` file anywhere under *runs_dir*.

    The search recurses into subdirectories (one per training run).

    Raises:
        FileNotFoundError: if *runs_dir* does not exist or holds no ``.pth`` files.
    """
    # pathlib replaces os.walk here: the notebook never imports ``os``, so the
    # previous version raised NameError on first use.
    runs_dir = Path(runs_dir)
    pths = [p for p in runs_dir.rglob("*.pth") if p.is_file()] if runs_dir.exists() else []
    if not pths:
        raise FileNotFoundError("no checkpoints found in runs/")
    return max(pths, key=lambda p: p.stat().st_mtime)

def load_id_to_class(base_dir: Path):
    """Map LVIS category id -> category name.

    Prefers ``annotations/lvis_v1_val.json`` and falls back to the train file.
    Returns an empty dict when neither file exists.
    """
    ann_val = base_dir / "annotations" / "lvis_v1_val.json"
    ann_train = base_dir / "annotations" / "lvis_v1_train.json"
    path = ann_val if ann_val.exists() else ann_train
    if not path.exists():
        return {}
    # Context manager so the annotation file handle is closed deterministically.
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    return {c["id"]: c["name"] for c in data["categories"]}

class ViTDetector(torch.nn.Module):
    """Inference copy of the training-time ViT patch detector.

    ``load_model`` loads the checkpoints written by the training cell with
    ``strict=True``. The training cell saves them from a ``MODEL_NAME``
    backbone with a LayerNorm -> Linear -> GELU -> Dropout -> Linear head, so
    this class must rebuild exactly that layout; otherwise parameter names and
    shapes do not match. Weights come from the checkpoint, so no pretrained
    weights are downloaded.

    Output shape: ``(B, num_patch_tokens, 4 + num_classes + 1)``.
    """

    def __init__(self, num_classes: int, grad_checkpoint: bool = False, model_name: str = MODEL_NAME):
        super().__init__()
        import timm
        self.backbone = timm.create_model(model_name, pretrained=False, num_classes=0)
        if grad_checkpoint and hasattr(self.backbone, "set_grad_checkpointing"):
            self.backbone.set_grad_checkpointing()
        d = self.backbone.num_features
        # Same module order as training so state_dict keys (head.0 ... head.4) line up.
        self.head = torch.nn.Sequential(
            torch.nn.LayerNorm(d),
            torch.nn.Linear(d, d),
            torch.nn.GELU(),
            torch.nn.Dropout(0.1),
            torch.nn.Linear(d, 4 + num_classes + 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Drop all prefix tokens (class/register); fallback of 1 matches plain ViT.
        n_prefix = getattr(self.backbone, "num_prefix_tokens", 1)
        feats = self.backbone.forward_features(x)[:, n_prefix:]
        out = self.head(feats)
        return out

def predict_image(img_path: Path, model, id_to_class, score_thresh=0.5, iou_thresh=0.5, max_dets=100):
    """Detect objects in one image and return it annotated with boxes and labels.

    Args:
        img_path: image file to run on.
        model: patch detector returning ``(1, P, 4 + num_logits)``, with logit 0
            as background.
        id_to_class: LVIS *category id* -> name, e.g. from ``load_id_to_class``.
        score_thresh: minimum softmax probability for a patch prediction to be kept.
        iou_thresh: IoU threshold for per-class non-maximum suppression.
        max_dets: maximum number of detections drawn in total, highest score first.

    Returns:
        A copy of the image at its original resolution with detections drawn,
        or the unmodified image when no patch passes ``score_thresh``.
    """
    # Local imports: none of these names are imported elsewhere in the notebook.
    import torch.nn.functional as F
    from PIL import ImageDraw
    from torchvision.ops import batched_nms

    img = Image.open(img_path).convert("RGB")
    x = tf_infer(img).unsqueeze(0).to(DEVICE)
    with torch.inference_mode():
        with torch.autocast(device_type=DEVICE.type, enabled=USE_CUDA_AMP):
            out = model(x)
        out = out[0].float()  # post-process in float32, outside autocast
        reg, cls = out[:, :4], out[:, 4:]
        scores, labels = F.softmax(cls, dim=-1).max(dim=-1)
        keep = (labels != 0) & (scores >= score_thresh)
        if not keep.any():
            return img
        boxes = reg[keep]
        scores = scores[keep]
        labels = labels[keep] - 1  # back to contiguous class index
        # batched_nms only suppresses boxes sharing a label (per-class NMS) and
        # returns indices sorted by decreasing score, so slicing caps the total.
        kept = batched_nms(boxes, scores, labels, iou_thresh)[:max_dets]
        boxes = boxes[kept].clamp(0, IMG_SIZE).tolist()
        scores = scores[kept].tolist()
        labels = labels[kept].tolist()

    # Model class indices are positions in the id-sorted category list; invert
    # CAT_ID_TO_IDX to recover the LVIS category id used as id_to_class's key.
    idx_to_cat = {idx: cat_id for cat_id, idx in CAT_ID_TO_IDX.items()}
    # Boxes are in the IMG_SIZE x IMG_SIZE frame; scale them back to the original image.
    sx, sy = img.width / IMG_SIZE, img.height / IMG_SIZE
    canvas = img.copy()
    draw = ImageDraw.Draw(canvas)
    for (bx1, by1, bx2, by2), score, label in zip(boxes, scores, labels):
        # Raw regression output can be inverted; sort so x1 <= x2 and y1 <= y2.
        x1, x2 = sorted((bx1 * sx, bx2 * sx))
        y1, y2 = sorted((by1 * sy, by2 * sy))
        cat_id = idx_to_cat.get(label, label)
        name = id_to_class.get(cat_id, f"id_{cat_id}")
        draw.rectangle([(x1, y1), (x2, y2)], outline=(255, 0, 0), width=2)
        draw.text((x1 + 2, y1 + 2), f"{name} {score:.2f}")
    return canvas

def load_model(ckpt_path: Path):
    """Rebuild the detector, load the *ckpt_path* state_dict (strictly) and switch it to eval mode."""
    net = ViTDetector(NUM_CLASSES, grad_checkpoint=False).to(DEVICE)
    state = torch.load(ckpt_path, map_location=DEVICE)
    net.load_state_dict(state, strict=True)
    net.eval()
    return net

# Load the newest checkpoint and category names.
ckpt = latest_checkpoint()
model = load_model(ckpt)
id_to_class = load_id_to_class(BASE_DIR)

# Run on the first 8 validation images (sorted by path), if the folder exists.
VAL_DIR = BASE_DIR / "val2017"
paths = sorted(VAL_DIR.glob("*.jpg"))[:8] if VAL_DIR.exists() else []
OUT_DIR = Path("preds")
OUT_DIR.mkdir(parents=True, exist_ok=True)

results = []
for p in paths:
    img_pred = predict_image(p, model, id_to_class, score_thresh=0.5, iou_thresh=0.5)
    save_path = OUT_DIR.joinpath(f"{p.stem}_pred.jpg")
    img_pred.save(save_path)
    results.append(save_path)

# Show the saved predictions inline (``display`` is provided by IPython/Jupyter).
for p in results:
    display(Image.open(p))