# In [None]:

# config.py
from pathlib import Path

import torch

# Central experiment settings shared by the sections below.
CONFIG = dict(
    # Input resolution, batching and optimisation.
    IMG_SIZE=320,
    BATCH_SIZE=8,
    NUM_CLASSES=3,
    EPOCHS=100,
    LR=1e-4,
    # Model.
    BACKBONE="resnet50",
    PRETRAINED=True,
    # Dataset locations (absolute paths on the training machine).
    TRAIN_DIR=Path("/home/bachelor/ml-carbucks/data/car_dd/images/train"),
    VAL_DIR=Path("/home/bachelor/ml-carbucks/data/car_dd/images/val"),
    TRAIN_ANN=Path("/home/bachelor/ml-carbucks/data/car_dd/instances_train.json"),
    VAL_ANN=Path("/home/bachelor/ml-carbucks/data/car_dd/instances_val.json"),
    # Runtime device name, resolved once at import time.
    DEVICE="cuda" if torch.cuda.is_available() else "cpu",
    # Decoding / target-encoding knobs.
    CONF_THRESH=0.3,
    NMS_IOU=0.5,
    TOPK=100,
    GAUSSIAN_RADIUS=2,
)

# dataset.py
import torch
from torchvision import transforms
from ml_carbucks.utils.coco import create_dataset_custom

# class DamageDataset(torch.utils.data.Dataset):
#     def __init__(self, img_dir, ann_file, img_size, augment=False, limit=None):
#         self.dataset = create_dataset_custom(
#             name="damage",
#             img_dir=img_dir,
#             ann_file=ann_file,
#             limit=limit
#         )
#         self.img_size = img_size
#         self.augment = augment
#         self.transform = transforms.Compose([
#             transforms.ToTensor(),
#             transforms.Resize((img_size, img_size)),
#         ])
#     def __len__(self):
#         return len(self.dataset)
#     def __getitem__(self, idx):
#         img, target = self.dataset[idx]
#         img = self.transform(img)
#         # convert target fields to tensors
#         target_t = {
#             'bbox': torch.tensor(target['bbox'], dtype=torch.float32),
#             'cls': torch.tensor(target['cls'], dtype=torch.long)
#         }
#         return img, target_t

# dataset_aug.py
import torch
from torch.utils.data import Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
from ml_carbucks.utils.coco import create_dataset_custom
import numpy as np


class DamageDatasetAug(Dataset):
    """Detection dataset that applies albumentations transforms to image and boxes.

    Each item is ``(image, {"bbox": FloatTensor[N, 4], "cls": LongTensor[N]})``.
    ``image`` is a normalized float tensor resized to ``img_size x img_size``.
    The boxes are transformed together with the image, so they are expressed in
    the resized image's pixel coordinates.

    Args:
        img_dir: directory with the images.
        ann_file: annotation file passed to ``create_dataset_custom``.
        img_size: side length images (and boxes) are resized to.
        augment: use the random training pipeline instead of resize-only.
        limit: optional cap on the number of samples (forwarded as-is).
    """

    def __init__(self, img_dir, ann_file, img_size, augment=True, limit=None):
        self.dataset = create_dataset_custom(
            name="damage", img_dir=img_dir, ann_file=ann_file, limit=limit
        )
        self.img_size = img_size
        self.augment = augment

        normalize = A.Normalize(
            mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0
        )
        # albumentations only transforms `bboxes=` together with the image when the
        # Compose has BboxParams. Without them the boxes would stay at the source
        # resolution after Resize / HorizontalFlip / ShiftScaleRotate.
        # NOTE(review): assumes create_dataset_custom yields absolute
        # [x_min, y_min, x_max, y_max] boxes ("pascal_voc"). That is how
        # encode_targets* and decode_predictions read them -- TODO confirm.
        bbox_params = A.BboxParams(format="pascal_voc", label_fields=["class_labels"])

        self.train_transform = A.Compose(
            [
                A.Resize(img_size, img_size),
                A.HorizontalFlip(p=0.5),
                A.RandomBrightnessContrast(p=0.5),
                A.ShiftScaleRotate(
                    shift_limit=0.1, scale_limit=0.1, rotate_limit=15, p=0.5
                ),
                normalize,
                ToTensorV2(),
            ],
            bbox_params=bbox_params,
        )

        self.val_transform = A.Compose(
            [
                A.Resize(img_size, img_size),
                normalize,
                ToTensorV2(),
            ],
            bbox_params=bbox_params,
        )

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img, target = self.dataset[idx]
        bboxes = target["bbox"]
        labels = target["cls"]

        transform = self.train_transform if self.augment else self.val_transform
        transformed = transform(image=np.array(img), bboxes=bboxes, class_labels=labels)
        img = transformed["image"]  # float32 CHW tensor (Normalize + ToTensorV2)
        # reshape keeps the [N, 4] layout even when no box survives (N == 0);
        # an empty list would otherwise become a shape-(0,) tensor.
        bboxes = torch.tensor(transformed["bboxes"], dtype=torch.float32).reshape(-1, 4)
        labels = torch.tensor(transformed["class_labels"], dtype=torch.long)

        return img, {"bbox": bboxes, "cls": labels}


# model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50


class CenterNetHead(nn.Module):
    """CenterNet prediction head.

    A shared 3x3 conv -> BatchNorm -> ReLU trunk (256 channels) feeds three 1x1
    conv branches. All outputs keep the spatial size of the input feature map:

    * ``"heatmap"``: per-class centre scores, squashed with a sigmoid.
    * ``"wh"``: 2-channel raw box-size regression.
    * ``"offset"``: 2-channel raw sub-cell centre-offset regression.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        trunk_width = 256
        self.shared = nn.Sequential(
            nn.Conv2d(in_channels, trunk_width, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(trunk_width),
            nn.ReLU(inplace=True),
        )
        self.heatmap = nn.Conv2d(trunk_width, num_classes, kernel_size=1)
        self.wh = nn.Conv2d(trunk_width, 2, kernel_size=1)
        self.offset = nn.Conv2d(trunk_width, 2, kernel_size=1)
        self._init_weights()

    def _init_weights(self):
        # sigmoid(-2.19) ~= 0.1: every cell starts with a low centre probability.
        nn.init.constant_(self.heatmap.bias, -2.19)

    def forward(self, x):
        trunk = self.shared(x)
        outputs = {}
        outputs["heatmap"] = self.heatmap(trunk).sigmoid()
        outputs["wh"] = self.wh(trunk)
        outputs["offset"] = self.offset(trunk)
        return outputs


class CenterNet(nn.Module):
    """CenterNet detector: a torchvision ResNet trunk followed by ``CenterNetHead``.

    The ResNet's final average-pool and fully-connected layers are dropped, so
    the head runs on the last convolutional stage's feature map.

    Args:
        num_classes: number of heatmap channels.
        backbone_name: torchvision ResNet to build -- one of ``"resnet18"``,
            ``"resnet34"``, ``"resnet50"``, ``"resnet101"``, ``"resnet152"``.
        pretrained: load the ``IMAGENET1K_V1`` ImageNet weights.

    Raises:
        ValueError: if ``backbone_name`` is not a supported ResNet.
    """

    _SUPPORTED_BACKBONES = ("resnet18", "resnet34", "resnet50", "resnet101", "resnet152")

    def __init__(self, num_classes=3, backbone_name="resnet50", pretrained=True):
        super().__init__()
        # Local import so the constructor can be looked up by name without
        # depending on where torchvision is imported in this file.
        from torchvision import models as tv_models

        # Previously backbone_name was accepted but ignored (ResNet-50 was always
        # built). Honour it, and reject names we cannot build.
        if backbone_name not in self._SUPPORTED_BACKBONES:
            raise ValueError(
                f"Unsupported backbone {backbone_name!r}; expected one of "
                f"{', '.join(self._SUPPORTED_BACKBONES)}"
            )
        backbone = getattr(tv_models, backbone_name)(
            weights="IMAGENET1K_V1" if pretrained else None
        )
        # The classifier's input width equals the trunk's output channels
        # (512 for resnet18/34, 2048 for the bottleneck ResNets).
        feat_channels = backbone.fc.in_features
        self.backbone = nn.Sequential(*list(backbone.children())[:-2])
        self.head = CenterNetHead(feat_channels, num_classes)

    def forward(self, x):
        feat = self.backbone(x)
        return self.head(feat)


# losses.py
import torch
import torch.nn.functional as F


def focal_loss(pred, gt):
    """Penalty-reduced pixel-wise focal loss used for CenterNet heatmaps.

    Args:
        pred: predicted heatmap probabilities (post-sigmoid), same shape as ``gt``.
        gt: target heatmap. Cells exactly equal to 1 are positives. All other
            cells are negatives down-weighted by ``(1 - gt) ** 4``.

    Returns:
        Scalar loss, normalised by the number of positive cells (at least 1).
    """
    # Compute in float32. Under fp16 autocast, 1 - 1e-6 rounds to 1.0, so the
    # clamp is a no-op and log(1 - pred) becomes -inf (NaN once multiplied by 0).
    pred = pred.float().clamp(1e-6, 1 - 1e-6)
    gt = gt.float()

    pos_inds = gt.eq(1).float()
    neg_inds = gt.lt(1).float()
    neg_weights = torch.pow(1 - gt, 4)

    pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
    neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds

    # Normalise by the positive count, but never by ~0. The former
    # `num_pos + 1e-4` scaled loss on box-free batches up by 10,000x.
    num_pos = pos_inds.sum().clamp(min=1.0)
    return -(pos_loss.sum() + neg_loss.sum()) / num_pos


def compute_loss(preds, targets):
    """Total CenterNet training loss and its components.

    Args:
        preds: model outputs with ``"heatmap"``, ``"wh"`` and ``"offset"`` maps.
        targets: encoded targets with the same keys. Presumably
            ``[B, 2, H, W]`` wh/offset maps that are zero everywhere except at
            object-centre cells, as built by ``encode_targets_vectorized``.

    Returns:
        ``(total, (hm_loss, wh_loss, off_loss))``.
    """
    hm_loss = focal_loss(preds["heatmap"], targets["heatmap"])

    # Size/offset targets only exist at object centres; averaging over every cell
    # (the previous behaviour) mostly trained these heads to output 0. Supervise
    # only centre cells -- those with a non-zero size target -- and average per
    # regressed value.
    wh_target = targets["wh"].float()
    center_mask = (wh_target != 0).any(dim=1, keepdim=True).float()
    num_values = (center_mask.sum() * wh_target.shape[1]).clamp(min=1.0)

    wh_loss = (
        F.smooth_l1_loss(
            preds["wh"].float() * center_mask, wh_target * center_mask, reduction="sum"
        )
        / num_values
    )
    off_loss = (
        F.smooth_l1_loss(
            preds["offset"].float() * center_mask,
            targets["offset"].float() * center_mask,
            reduction="sum",
        )
        / num_values
    )
    return hm_loss + wh_loss + off_loss, (hm_loss, wh_loss, off_loss)


# utils.py
import torch
import torchvision
import torch.nn.functional as F


def draw_gaussian_2d(heatmap, cx, cy, sigma):
    """Draw a 2D Gaussian into ``heatmap`` in place, centred on ``(cx, cy)``.

    The centre is truncated to the integer cell ``(int(cx), int(cy))``, which
    receives exactly 1. The kernel extends ``int(3 * sigma)`` cells in each
    direction and is cropped to the heatmap bounds. Existing values are merged
    with an element-wise max, so overlapping objects keep the stronger response.
    Centres outside the heatmap are ignored.

    Args:
        heatmap: 2D tensor ``[H, W]`` (e.g. one class channel), modified in place.
        cx, cy: centre in heatmap cell coordinates (float or 0-dim tensor).
        sigma: Gaussian standard deviation in cells.
    """
    radius = int(3 * sigma)
    mu_x, mu_y = int(cx), int(cy)
    H, W = heatmap.shape[-2:]
    if not (0 <= mu_x < W and 0 <= mu_y < H):
        return

    size = 2 * radius + 1
    coords = torch.arange(0, size, dtype=torch.float32, device=heatmap.device) - radius
    g = torch.exp(-(coords[None, :] ** 2 + coords[:, None] ** 2) / (2 * sigma**2))

    # Extent of the kernel that fits on each side of the centre. Heatmap window
    # and kernel crop use the same extents, so the kernel centre always lands on
    # (mu_x, mu_y). The old crop truncated the far side whenever the kernel
    # overhung the left/top edge.
    left, right = min(mu_x, radius), min(W - mu_x, radius + 1)
    top, bottom = min(mu_y, radius), min(H - mu_y, radius + 1)

    rows = slice(mu_y - top, mu_y + bottom)
    cols = slice(mu_x - left, mu_x + right)
    patch = g[radius - top : radius + bottom, radius - left : radius + right]
    heatmap[rows, cols] = torch.max(heatmap[rows, cols], patch)


def encode_targets_vectorized(
    batch_boxes, batch_labels, output_size, num_classes, stride, sigma=2
):
    """Encode a batch of per-image boxes into dense CenterNet target maps.

    Despite the name, this is a Python loop over images and boxes. Each Gaussian
    is drawn by ``draw_gaussian_2d``.

    Args:
        batch_boxes: list of ``[num_boxes, 4]`` tensors, one per image, read as
            ``(x1, y1, x2, y2)`` in input-image pixels.
        batch_labels: list of ``[num_boxes]`` class-index tensors, one per image.
            NOTE(review): used directly as heatmap channel indices, so presumably
            0-based in ``[0, num_classes)`` -- confirm against the dataset.
        output_size: ``(H, W)`` of the model's output maps.
        num_classes: number of heatmap channels.
        stride: input pixels per output cell.
        sigma: Gaussian standard deviation, in output cells.

    Returns:
        dict with ``"heatmap"`` ``[B, num_classes, H, W]``, ``"wh"``
        ``[B, 2, H, W]`` and ``"offset"`` ``[B, 2, H, W]`` tensors. ``wh`` (box
        size in cells) and ``offset`` (sub-cell centre position) are written only
        at centre cells.
    """
    B = len(batch_boxes)
    H, W = output_size
    # An empty batch has no tensor to take the device from.
    device = batch_boxes[0].device if B > 0 else torch.device("cpu")

    heatmaps = torch.zeros(B, num_classes, H, W, device=device)
    wh_maps = torch.zeros(B, 2, H, W, device=device)
    offset_maps = torch.zeros(B, 2, H, W, device=device)

    for b in range(B):
        boxes = batch_boxes[b]
        labels = batch_labels[b]
        if boxes.numel() == 0:
            continue
        cx = (boxes[:, 0] + boxes[:, 2]) / 2 / stride
        cy = (boxes[:, 1] + boxes[:, 3]) / 2 / stride
        w = (boxes[:, 2] - boxes[:, 0]) / stride
        h = (boxes[:, 3] - boxes[:, 1]) / stride
        cx_int = cx.long()
        cy_int = cy.long()

        for i in range(boxes.shape[0]):
            x_cell, y_cell = int(cx_int[i]), int(cy_int[i])
            if not (0 <= x_cell < W and 0 <= y_cell < H):
                continue
            # Plain int index, so heatmaps[b, cls] is a view that
            # draw_gaussian_2d updates in place.
            cls = int(labels[i])
            draw_gaussian_2d(heatmaps[b, cls], cx[i], cy[i], sigma)
            wh_maps[b, :, y_cell, x_cell] = torch.stack((w[i], h[i]))
            offset_maps[b, :, y_cell, x_cell] = torch.stack(
                (cx[i] - x_cell, cy[i] - y_cell)
            )

    return {"heatmap": heatmaps, "wh": wh_maps, "offset": offset_maps}


def encode_targets(boxes, labels, output_size, num_classes, stride, radius=2):
    """Encode one image's boxes into CenterNet target maps (CPU tensors).

    Args:
        boxes: iterable of ``(x1, y1, x2, y2)`` boxes in input-image pixels.
        labels: iterable of class indices, one per box (used as heatmap channel).
        output_size: ``(H, W)`` of the output feature map.
        num_classes: number of heatmap channels.
        stride: input pixels per output cell.
        radius: half-width, in cells, of the square Gaussian kernel. Its standard
            deviation is ``radius / 3``. Non-positive values draw a single-cell peak.

    Returns:
        dict with ``"heatmap"`` ``[num_classes, H, W]``, ``"wh"`` ``[2, H, W]``
        and ``"offset"`` ``[2, H, W]``. Size/offset are set only at centre cells.
    """
    H, W = output_size
    heatmap = torch.zeros((num_classes, H, W))
    wh = torch.zeros((2, H, W))
    offset = torch.zeros((2, H, W))

    # The kernel depends only on `radius`, so build it once rather than per box.
    if radius > 0:
        coords = torch.arange(-radius, radius + 1, dtype=torch.float32)
        kernel = torch.exp(
            -(coords.view(-1, 1) ** 2 + coords.view(1, -1) ** 2)
            / (2 * (radius / 3) ** 2)
        )
    else:
        # sigma == 0 would give 0/0 = NaN; fall back to a lone peak.
        radius = 0
        kernel = torch.ones((1, 1))

    def draw_gaussian(hm, x, y):
        # Crop the heatmap window and the kernel with the same extents, so the
        # kernel centre always lands on (x, y). The previous version cropped from
        # the kernel's top-left corner, which moved the peak away from the object
        # centre near the top/left borders.
        left, right = min(x, radius), min(W - x, radius + 1)
        top, bottom = min(y, radius), min(H - y, radius + 1)
        rows, cols = slice(y - top, y + bottom), slice(x - left, x + right)
        patch = kernel[radius - top : radius + bottom, radius - left : radius + right]
        hm[rows, cols] = torch.max(hm[rows, cols], patch)

    for box, cls in zip(boxes, labels):
        x1, y1, x2, y2 = (float(v) for v in box)
        cx = (x1 + x2) / 2 / stride
        cy = (y1 + y2) / 2 / stride
        cx_int, cy_int = int(cx), int(cy)
        if 0 <= cx_int < W and 0 <= cy_int < H:
            draw_gaussian(heatmap[int(cls)], cx_int, cy_int)
            wh[:, cy_int, cx_int] = torch.tensor([(x2 - x1) / stride, (y2 - y1) / stride])
            offset[:, cy_int, cx_int] = torch.tensor([cx - cx_int, cy - cy_int])
    return {"heatmap": heatmap, "wh": wh, "offset": offset}


def decode_predictions(
    preds, conf_thresh=0.3, stride=32, K=100, nms_kernel=3, iou_thresh=0.5
):
    """Turn raw CenterNet head outputs into scored, labelled boxes.

    Peaks are cells equal to the maximum of their ``nms_kernel`` x ``nms_kernel``
    neighbourhood. For each (image, class), the top ``K`` peaks above
    ``conf_thresh`` become boxes. IoU-based NMS is then applied.

    Args:
        preds: dict with ``"heatmap"`` ``[B, C, H, W]``, ``"wh"`` ``[B, 2, H, W]``
            and ``"offset"`` ``[B, 2, H, W]``.
        conf_thresh: minimum heatmap score for a peak to be kept.
        stride: input pixels per output cell; box coordinates are scaled by it.
        K: maximum peaks per (image, class); clamped to ``H * W``.
        nms_kernel: max-pool window used for peak extraction.
        iou_thresh: IoU threshold for the final NMS. NMS runs per image and is
            class-agnostic within an image.

    Returns:
        ``(boxes [N, 4] as x1y1x2y2, scores [N], labels [N] int32)``, ordered by
        decreasing score. For ``B > 1`` the detections of all images are
        concatenated without an image index, so callers normally pass one image.
    """
    heatmap, wh, offset = preds["heatmap"], preds["wh"], preds["offset"]
    batch, num_classes, H, W = heatmap.shape
    device = heatmap.device

    pooled = F.max_pool2d(heatmap, nms_kernel, stride=1, padding=nms_kernel // 2)
    heatmap = heatmap * (pooled == heatmap).float()
    # torch.topk raises when k exceeds the number of cells (small input sizes).
    k = min(K, H * W)

    boxes_list, scores_list, labels_list, image_ids_list = [], [], [], []

    for b in range(batch):
        w_flat = wh[b, 0].reshape(-1)
        h_flat = wh[b, 1].reshape(-1)
        off_x_flat = offset[b, 0].reshape(-1)
        off_y_flat = offset[b, 1].reshape(-1)
        for c in range(num_classes):
            topk_scores, topk_inds = torch.topk(heatmap[b, c].reshape(-1), k)
            keep = topk_scores > conf_thresh
            if not keep.any():
                continue
            topk_scores = topk_scores[keep]
            topk_inds = topk_inds[keep]

            xs = (topk_inds % W).float() + off_x_flat[topk_inds]
            ys = (topk_inds // W).float() + off_y_flat[topk_inds]
            # Negative size predictions would produce inverted (x2 < x1) boxes.
            half_w = w_flat[topk_inds].clamp(min=0) / 2
            half_h = h_flat[topk_inds].clamp(min=0) / 2
            x1 = (xs - half_w) * stride
            y1 = (ys - half_h) * stride
            x2 = (xs + half_w) * stride
            y2 = (ys + half_h) * stride

            boxes_list.append(torch.stack([x1, y1, x2, y2], dim=-1))
            scores_list.append(topk_scores)
            labels_list.append(torch.full_like(topk_scores, c, dtype=torch.int))
            image_ids_list.append(torch.full_like(topk_scores, b, dtype=torch.long))

    if not boxes_list:
        return (
            torch.empty((0, 4), device=device),
            torch.empty((0,), device=device),
            torch.empty((0,), dtype=torch.int, device=device),
        )

    boxes = torch.cat(boxes_list)
    scores = torch.cat(scores_list)
    labels = torch.cat(labels_list)
    image_ids = torch.cat(image_ids_list)
    # batched_nms only lets boxes with the same id suppress each other, so boxes
    # from different images in a batch never cancel out. With B == 1 this is
    # identical to plain NMS.
    keep = torchvision.ops.batched_nms(boxes, scores, image_ids, iou_threshold=iou_thresh)
    return boxes[keep], scores[keep], labels[keep]


# train_production.py
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
import torchvision
from torchvision.models import resnet50
from pathlib import Path
import os

from tqdm import tqdm


# --- Config ---
IMG_SIZE = CONFIG["IMG_SIZE"]
BATCH_SIZE = CONFIG["BATCH_SIZE"]
NUM_CLASSES = CONFIG["NUM_CLASSES"]
EPOCHS = CONFIG["EPOCHS"]
LR = CONFIG["LR"]
CHECKPOINT_DIR = Path("./checkpoints")
CHECKPOINT_DIR.mkdir(exist_ok=True, parents=True)
# Read the device and Gaussian size from CONFIG instead of re-deriving or
# re-hard-coding them, so the script cannot silently disagree with the config.
DEVICE = torch.device(CONFIG["DEVICE"])
GAUSSIAN_RADIUS = CONFIG["GAUSSIAN_RADIUS"]


# --- Model ---
class CenterNetHead(nn.Module):
    """CenterNet prediction head.

    A shared 3x3 conv -> BatchNorm -> ReLU trunk (256 channels) feeds three 1x1
    conv branches. All outputs keep the spatial size of the input feature map:

    * ``"heatmap"``: per-class centre scores, squashed with a sigmoid.
    * ``"wh"``: 2-channel raw box-size regression.
    * ``"offset"``: 2-channel raw sub-cell centre-offset regression.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        trunk_width = 256
        self.shared = nn.Sequential(
            nn.Conv2d(in_channels, trunk_width, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(trunk_width),
            nn.ReLU(inplace=True),
        )
        self.heatmap = nn.Conv2d(trunk_width, num_classes, kernel_size=1)
        self.wh = nn.Conv2d(trunk_width, 2, kernel_size=1)
        self.offset = nn.Conv2d(trunk_width, 2, kernel_size=1)
        self._init_weights()

    def _init_weights(self):
        # sigmoid(-2.19) ~= 0.1: every cell starts with a low centre probability.
        nn.init.constant_(self.heatmap.bias, -2.19)

    def forward(self, x):
        trunk = self.shared(x)
        outputs = {}
        outputs["heatmap"] = self.heatmap(trunk).sigmoid()
        outputs["wh"] = self.wh(trunk)
        outputs["offset"] = self.offset(trunk)
        return outputs


class CenterNet(nn.Module):
    """ResNet-50 conv trunk (pool/FC removed) topped with a ``CenterNetHead``.

    Args:
        num_classes: number of heatmap channels.
        pretrained: load the ``IMAGENET1K_V1`` ImageNet weights for the trunk.
    """

    def __init__(self, num_classes=NUM_CLASSES, pretrained=True):
        super().__init__()
        weights = "IMAGENET1K_V1" if pretrained else None
        resnet = resnet50(weights=weights)
        # Keep every child except the final global average pool and FC layer.
        trunk_layers = list(resnet.children())[:-2]
        self.backbone = nn.Sequential(*trunk_layers)
        self.head = CenterNetHead(2048, num_classes)

    def forward(self, x):
        return self.head(self.backbone(x))


# --- Loss ---
def focal_loss(pred, gt):
    """Penalty-reduced pixel-wise focal loss used for CenterNet heatmaps.

    Args:
        pred: predicted heatmap probabilities (post-sigmoid), same shape as ``gt``.
        gt: target heatmap. Cells exactly equal to 1 are positives. All other
            cells are negatives down-weighted by ``(1 - gt) ** 4``.

    Returns:
        Scalar loss, normalised by the number of positive cells (at least 1).
    """
    # Compute in float32. Under fp16 autocast, 1 - 1e-6 rounds to 1.0, so the
    # clamp is a no-op and log(1 - pred) becomes -inf (NaN once multiplied by 0).
    pred = pred.float().clamp(1e-6, 1 - 1e-6)
    gt = gt.float()

    pos_inds = gt.eq(1).float()
    neg_inds = gt.lt(1).float()
    neg_weights = torch.pow(1 - gt, 4)

    pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
    neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds

    # Normalise by the positive count, but never by ~0. The former
    # `num_pos + 1e-4` scaled loss on box-free batches up by 10,000x.
    num_pos = pos_inds.sum().clamp(min=1.0)
    return -(pos_loss.sum() + neg_loss.sum()) / num_pos


def compute_loss(preds, targets):
    """Total CenterNet training loss and its components (L1 regression variant).

    Args:
        preds: model outputs with ``"heatmap"``, ``"wh"`` and ``"offset"`` maps.
        targets: encoded targets with the same keys. Presumably
            ``[B, 2, H, W]`` wh/offset maps that are zero everywhere except at
            object-centre cells, as built by ``encode_targets_vectorized``.

    Returns:
        ``(total, (hm_loss, wh_loss, off_loss))``.
    """
    hm_loss = focal_loss(preds["heatmap"], targets["heatmap"])

    # Size/offset targets only exist at object centres; averaging over every cell
    # (the previous behaviour) mostly trained these heads to output 0. Supervise
    # only centre cells -- those with a non-zero size target -- and average per
    # regressed value.
    wh_target = targets["wh"].float()
    center_mask = (wh_target != 0).any(dim=1, keepdim=True).float()
    num_values = (center_mask.sum() * wh_target.shape[1]).clamp(min=1.0)

    wh_loss = (
        F.l1_loss(preds["wh"].float() * center_mask, wh_target * center_mask, reduction="sum")
        / num_values
    )
    off_loss = (
        F.l1_loss(
            preds["offset"].float() * center_mask,
            targets["offset"].float() * center_mask,
            reduction="sum",
        )
        / num_values
    )
    return hm_loss + wh_loss + off_loss, (hm_loss, wh_loss, off_loss)


def collate_fn(batch):
    """Batch ``(image, target)`` samples for the DataLoader.

    Images are stacked into one tensor along a new leading dimension. Targets
    hold a different number of boxes per image, so they stay a plain list.

    Returns:
        ``(images, targets)``: the stacked image tensor and a list of target dicts.
    """
    images = [image for image, _ in batch]
    targets = [target for _, target in batch]
    return torch.stack(images, 0), targets


# --- Dataset & Loader ---
# Dataset locations come from CONFIG so this script and the config section
# cannot drift apart.
train_dataset = DamageDatasetAug(
    img_dir=CONFIG["TRAIN_DIR"],
    ann_file=CONFIG["TRAIN_ANN"],
    img_size=IMG_SIZE,
    augment=True,
)
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    collate_fn=collate_fn,
)

# --- Model & Optimizer ---
model = CenterNet(num_classes=NUM_CLASSES).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LR)
scaler = torch.cuda.amp.GradScaler()  # mixed precision

# --- Precompute stride / output size ---
# Probe the output resolution with one dummy forward pass in eval mode. In train
# mode the BatchNorm layers would fold the all-zero probe into their running
# statistics (torch.no_grad() does not prevent that), skewing the pretrained
# backbone before training even starts.
was_training = model.training
model.eval()
with torch.no_grad():
    dummy = torch.zeros(1, 3, IMG_SIZE, IMG_SIZE, device=DEVICE)
    pred = model(dummy)
model.train(was_training)
target_h, target_w = pred["heatmap"].shape[2], pred["heatmap"].shape[3]
stride = IMG_SIZE // target_h

# --- Training Loop ---
# for epoch in range(EPOCHS):
#     model.train()
#     running_loss = 0
#     for (imgs, targets) in tqdm(train_loader):
#         imgs = imgs.to(DEVICE)
#         batch_boxes = [t['bbox'].to(DEVICE) for t in targets]
#         batch_labels = [t['cls'].to(DEVICE) for t in targets]

#         # Encode targets vectorized
#         batch_targets = encode_targets_vectorized(
#             batch_boxes, batch_labels, output_size=(target_h, target_w),
#             num_classes=NUM_CLASSES, stride=stride, sigma=GAUSSIAN_RADIUS
#         )
#         optimizer.zero_grad()
#         with torch.cuda.amp.autocast():
#             preds = model(imgs)
#             loss, (hm_loss, wh_loss, off_loss) = compute_loss(preds, batch_targets)
#         scaler.scale(loss).backward()
#         scaler.step(optimizer)
#         scaler.update()
#         running_loss += loss.item()
#     avg_loss = running_loss / len(train_loader)
#     print(f"Epoch {epoch+1}/{EPOCHS} | Loss: {avg_loss:.4f} | hm: {hm_loss.item():.4f} | wh: {wh_loss.item():.4f} | off: {off_loss.item():.4f}")
#     # --- Save checkpoint every 5 epochs ---
#     if (epoch+1) % 5 == 0:
#         ckpt_path = CHECKPOINT_DIR / f"centernet_epoch{epoch+1}.pth"
#         torch.save({
#             "epoch": epoch+1,
#             "model_state": model.state_dict(),
#             "optimizer_state": optimizer.state_dict()
#         }, ckpt_path)
#         print(f"Saved checkpoint: {ckpt_path}")


model = CenterNet(num_classes=NUM_CLASSES).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LR)
scaler = torch.cuda.amp.GradScaler()  # mixed precision

# load model if exists
# Pick the checkpoint with the highest epoch number. Sorting the file names as
# strings is wrong once epochs reach three digits ("centernet_epoch95.pth" sorts
# after "centernet_epoch100.pth"). Files whose suffix is not a plain epoch number
# are ignored.
CKPT_PREFIX = "centernet_epoch"
ckpt_by_epoch = {
    int(p.stem[len(CKPT_PREFIX):]): p
    for p in CHECKPOINT_DIR.glob(f"{CKPT_PREFIX}*.pth")
    if p.stem[len(CKPT_PREFIX):].isdigit()
}
latest_ckpt = ckpt_by_epoch[max(ckpt_by_epoch)] if ckpt_by_epoch else None
start_epoch = 0
if latest_ckpt is not None:
    checkpoint = torch.load(latest_ckpt, map_location=DEVICE)
    model.load_state_dict(checkpoint["model_state"])
    optimizer.load_state_dict(checkpoint["optimizer_state"])
    start_epoch = checkpoint["epoch"]
    print(f"Resumed from checkpoint: {latest_ckpt} at epoch {start_epoch}")
else:
    # Make it obvious that the evaluation below uses freshly initialised weights.
    print(f"No checkpoint found in {CHECKPOINT_DIR}; using freshly initialised weights")


model.train(False)  # equivalent to model.eval()
torch.set_grad_enabled(False)  # inference only; autograd stays off from here on

all_preds = []
all_gts = []

# Dataset locations come from CONFIG so this script and the config section
# cannot drift apart.
val_dataset = DamageDatasetAug(
    img_dir=CONFIG["VAL_DIR"],
    ann_file=CONFIG["VAL_ANN"],
    img_size=IMG_SIZE,
    augment=False,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
    collate_fn=collate_fn,
)

from torchmetrics.detection.mean_ap import MeanAveragePrecision

# Fresh per-image accumulators for the evaluation loop.
all_preds, all_gts = [], []

for imgs, targets in val_loader:
    imgs = imgs.to(DEVICE)
    preds = model(imgs)

    for i, target in enumerate(targets):
        # Decode one image at a time; slicing i:i+1 keeps the batch axis.
        single_image_preds = {key: value[i : i + 1] for key, value in preds.items()}
        boxes_i, scores_i, labels_i = decode_predictions(
            single_image_preds, conf_thresh=0.05, stride=stride
        )

        # Normalise dtypes (float32 boxes/scores, int64 labels) and move to CPU.
        no_detections = len(boxes_i) == 0
        boxes_i = torch.zeros((0, 4), dtype=torch.float32) if no_detections else boxes_i.float()
        scores_i = torch.zeros((0,), dtype=torch.float32) if no_detections else scores_i.float()
        labels_i = torch.zeros((0,), dtype=torch.int64) if no_detections else labels_i.long()
        all_preds.append(
            dict(boxes=boxes_i.cpu(), scores=scores_i.cpu(), labels=labels_i.cpu())
        )

        # Ground truth, normalised the same way.
        gt_boxes, gt_labels = target["bbox"], target["cls"]
        no_objects = len(gt_boxes) == 0
        gt_boxes = torch.zeros((0, 4), dtype=torch.float32) if no_objects else gt_boxes.float()
        gt_labels = torch.zeros((0,), dtype=torch.int64) if no_objects else gt_labels.long()
        all_gts.append(dict(boxes=gt_boxes.cpu(), labels=gt_labels.cpu()))

# Compute metric
# Aggregate all per-image predictions/ground truths into COCO-style mAP
# (the metric's default bbox / xyxy settings).
metric = MeanAveragePrecision()
metric.update(preds=all_preds, target=all_gts)
result = metric.compute()
print(result)
