In [None]:
import math
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50

from ml_carbucks.utils.coco import create_dataset_custom

# Input resolution handed to the data loader as (IMG_SIZE, IMG_SIZE).
IMG_SIZE: int = 320
# Number of images per optimisation step.
BATCH_SIZE: int = 1

class CenterNetHead(nn.Module):
    """Dense CenterNet prediction head.

    A shared 3x3 conv + BatchNorm + ReLU trunk (256 channels) feeds three
    1x1 conv branches, all preserving the input's spatial size:

    * ``heatmap`` -- ``num_classes`` channels, passed through a sigmoid;
    * ``wh``      -- 2 channels (box width / height);
    * ``offset``  -- 2 channels (sub-cell centre offset).
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        trunk_channels = 256
        # Shared trunk used by every branch
        self.shared = nn.Sequential(
            nn.Conv2d(in_channels, trunk_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(trunk_channels),
            nn.ReLU(inplace=True),
        )
        # Per-task 1x1 projections
        self.heatmap = nn.Conv2d(trunk_channels, num_classes, kernel_size=1)
        self.wh = nn.Conv2d(trunk_channels, 2, kernel_size=1)
        self.offset = nn.Conv2d(trunk_channels, 2, kernel_size=1)
        self._init_weights()

    def _init_weights(self):
        # sigmoid(-2.19) ~= 0.1: the heatmap starts out predicting low confidence everywhere.
        nn.init.constant_(self.heatmap.bias, -2.19)

    def forward(self, x):
        trunk = self.shared(x)
        outputs = {}
        outputs['heatmap'] = self.heatmap(trunk).sigmoid()
        outputs['wh'] = self.wh(trunk)
        outputs['offset'] = self.offset(trunk)
        return outputs


class CenterNet(nn.Module):
    """ResNet-50 CenterNet detector producing stride-4 dense predictions.

    The ResNet-50 trunk (average-pool and fc layers removed) yields a
    2048-channel C5 map. Three stride-2 transposed-conv upsampling stages
    bring it back to 1/4 of the input resolution before the head. That is
    the resolution at which ``encode_targets`` and the training loop in this
    file build their targets (``imgs.shape[2] // 4``).

    Args:
        num_classes: number of heatmap channels (one per object class).
        backbone_name: backbone identifier; only ``'resnet50'`` is implemented.
        pretrained: load torchvision's ``IMAGENET1K_V1`` weights for the backbone.

    Raises:
        ValueError: if ``backbone_name`` is anything other than ``'resnet50'``.
    """

    # Output channels of the three upsampling stages (each doubles H and W).
    _UPSAMPLE_CHANNELS = (256, 256, 256)

    def __init__(self, num_classes=3, backbone_name='resnet50', pretrained=True):
        super().__init__()
        # The argument used to be silently ignored; fail loudly instead of
        # quietly building a different network than the caller asked for.
        if backbone_name != 'resnet50':
            raise ValueError(
                f"Unsupported backbone {backbone_name!r}; only 'resnet50' is implemented"
            )
        backbone = resnet50(weights='IMAGENET1K_V1' if pretrained else None)
        # Drop the global average pool and fc layer, keeping the C5 feature map
        self.backbone = nn.Sequential(*list(backbone.children())[:-2])

        # C5 has stride 32; three exact 2x upsamplings give stride 4, matching the
        # targets. Without them, prediction and target maps differ in size and the
        # losses cannot be computed.
        layers = []
        in_ch = 2048
        for out_ch in self._UPSAMPLE_CHANNELS:
            layers += [
                # kernel 4 / stride 2 / padding 1 -> output is exactly 2H x 2W
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
            in_ch = out_ch
        self.upsample = nn.Sequential(*layers)
        self.head = CenterNetHead(in_ch, num_classes)

    def forward(self, x):
        """Return the head's dict ('heatmap', 'wh', 'offset') at 1/4 of ``x``'s H and W.

        NOTE(review): the input side lengths are presumably multiples of 32, so the
        stride-32 backbone and the 8x upsampling land exactly on ``H // 4``.
        """
        feat = self.backbone(x)
        feat = self.upsample(feat)
        return self.head(feat)

def focal_loss(pred, gt):
    """Penalty-reduced pixel-wise focal loss (CornerNet / CenterNet).

    Args:
        pred: predicted heatmap probabilities (already sigmoid-activated),
            same shape as ``gt``.
        gt: target heatmap. Cells exactly equal to 1 are positives. All other
            cells are negatives, down-weighted by ``(1 - gt) ** 4``.

    Returns:
        Scalar tensor: the negated sum of positive and negative terms divided
        by the number of positives, or by 1 when there are none.
    """
    pos_mask = gt.eq(1).float()
    neg_mask = gt.lt(1).float()
    neg_weights = torch.pow(1 - gt, 4)

    # Keep both log() calls finite
    pred = torch.clamp(pred, 1e-6, 1 - 1e-6)
    pos_loss = (torch.log(pred) * torch.pow(1 - pred, 2) * pos_mask).sum()
    neg_loss = (torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_mask).sum()

    # Normalise by the positive count clamped to >= 1. Adding a tiny epsilon
    # instead would multiply the loss of a positive-free input by ~1e4.
    num_pos = pos_mask.sum().clamp(min=1.0)
    return -(pos_loss + neg_loss) / num_pos

def _masked_l1(pred, target, mask):
    """Mean absolute error over the cells selected by ``mask``; 0 if none are selected.

    ``mask`` is broadcast over ``pred``'s channel dimension, so the mean is taken
    per regressed coordinate.
    """
    mask = mask.expand_as(pred)
    abs_err = (pred - target).abs() * mask
    return abs_err.sum() / mask.sum().clamp(min=1.0)


def compute_loss(preds, targets, wh_weight=0.1, off_weight=1.0):
    """Combine the CenterNet heatmap, size and offset losses.

    Args:
        preds: dict with 'heatmap', 'wh' and 'offset' tensors from the model.
        targets: dict with the same keys. As stacked by the training loop in this
            file, these are 'heatmap' (B, num_classes, H, W) and 'wh' / 'offset'
            (B, 2, H, W).
        wh_weight: weight of the size loss in the total (CenterNet uses 0.1).
        off_weight: weight of the offset loss in the total (CenterNet uses 1.0).

    Returns:
        ``(total, (hm_loss, wh_loss, off_loss))``, where the three components
        are unweighted.
    """
    hm_loss = focal_loss(preds['heatmap'], targets['heatmap'])

    # Size/offset targets only exist at object centres: cells where some class
    # heatmap is exactly 1. Everywhere else the targets are zero placeholders.
    # A dense mean would train the regressors toward predicting 0.
    center_mask = targets['heatmap'].eq(1).any(dim=1, keepdim=True).to(preds['wh'].dtype)
    wh_loss = _masked_l1(preds['wh'], targets['wh'], center_mask)
    off_loss = _masked_l1(preds['offset'], targets['offset'], center_mask)

    total = hm_loss + wh_weight * wh_loss + off_weight * off_loss
    return total, (hm_loss, wh_loss, off_loss)


def decode_predictions(heatmap, wh, offset, K=100):
    """Decode dense CenterNet outputs into the top-K boxes of every class.

    Args:
        heatmap: (B, C, H, W) centre probabilities.
        wh: (B, 2, H, W) box width / height in output-map cells.
        offset: (B, 2, H, W) sub-cell centre offset, channel 0 = x, channel 1 = y.
        K: detections kept per class; clipped to H * W.

    Returns:
        boxes: (B, C, K, 4) as (x1, y1, x2, y2) in output-map coordinates.
            Multiply by the network stride (presumably 4, as in
            ``encode_targets``) for input-pixel coordinates.
        scores: (B, C, K) peak scores. Slots beyond the real peaks score 0.
    """
    batch, num_cls, height, width = heatmap.shape
    k = min(K, height * width)

    # 3x3 max-pool NMS: keep only cells that are local maxima of their class map.
    # The pooled map itself must not be used as the scores, since that would
    # copy each peak onto its neighbours.
    pooled = F.max_pool2d(heatmap, kernel_size=3, stride=1, padding=1)
    peaks = heatmap * (pooled == heatmap).to(heatmap.dtype)

    scores, inds = torch.topk(peaks.reshape(batch, num_cls, -1), k)  # both (B, C, k)
    ys = torch.div(inds, width, rounding_mode='floor').float()
    xs = (inds % width).float()
    flat_inds = inds.reshape(batch, -1)

    def _pick(maps, channel):
        # Value of maps[:, channel] at every (class, k) peak location -> (B, C, k)
        flat = maps[:, channel].reshape(batch, -1)
        return flat.gather(1, flat_inds).reshape(batch, num_cls, k)

    xs = xs + _pick(offset, 0)
    ys = ys + _pick(offset, 1)
    w = _pick(wh, 0)
    h = _pick(wh, 1)

    boxes = torch.stack([xs - w / 2, ys - h / 2, xs + w / 2, ys + h / 2], dim=-1)
    return boxes, scores

import torch
import torch.nn.functional as F
from torch import optim

# Prefer the GPU when one is visible. The model, input images and encoded
# targets below are all placed on this device.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ImageNet-initialised ResNet-50 CenterNet for the 3 classes used throughout this script.
model = CenterNet(num_classes=3, backbone_name="resnet50", pretrained=True)
model = model.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

def encode_targets(boxes, labels, output_size, num_classes, stride=4, out_device=None):
    """Rasterise one image's boxes into dense CenterNet targets.

    Each box marks a single cell with 1 in its class heatmap: its centre,
    downsampled by ``stride``. The box size and the sub-cell centre offset
    are stored at that same cell. No Gaussian splatting is applied.

    Args:
        boxes: iterable of (x1, y1, x2, y2) in input-image pixels.
        labels: class index per box. Negative labels are treated as padding
            and skipped.
        output_size: (H, W) of the target maps.
        num_classes: number of heatmap channels.
        stride: input pixels per output cell (default 4).
        out_device: device for the returned tensors. Defaults to the
            module-level ``device``.

    Returns:
        dict with 'heatmap' (num_classes, H, W), 'wh' (2, H, W) and
        'offset' (2, H, W).

    Raises:
        IndexError: if a label is >= ``num_classes``.
    """
    if out_device is None:
        out_device = device
    out_h, out_w = output_size
    heatmap = torch.zeros((num_classes, out_h, out_w), device=out_device)
    wh = torch.zeros((2, out_h, out_w), device=out_device)
    offset = torch.zeros((2, out_h, out_w), device=out_device)

    for box, cls in zip(boxes, labels):
        cls = int(cls)
        if cls < 0:
            # Padding entry. Indexing with it would silently alias the last class.
            continue
        if cls >= num_classes:
            raise IndexError(f"label {cls} out of range for {num_classes} classes")

        # Rescale boxes from image size -> output size (downsample by `stride`)
        x1, y1, x2, y2 = (float(v) for v in box)
        cx = (x1 + x2) / 2 / stride
        cy = (y1 + y2) / 2 / stride
        w = (x2 - x1) / stride
        h = (y2 - y1) / stride

        # floor (int() truncates toward zero) plus clamp: centres on or past the
        # border must neither index out of range nor wrap via negative indices.
        cx_int = min(max(math.floor(cx), 0), out_w - 1)
        cy_int = min(max(math.floor(cy), 0), out_h - 1)

        heatmap[cls, cy_int, cx_int] = 1
        wh[0, cy_int, cx_int] = w
        wh[1, cy_int, cx_int] = h
        # Offset is relative to the (possibly clamped) cell, so cell + offset
        # recovers the exact downsampled centre.
        offset[0, cy_int, cx_int] = cx - cx_int
        offset[1, cy_int, cx_int] = cy - cy_int

    return {'heatmap': heatmap, 'wh': wh, 'offset': offset}


# Training split of the car_dd COCO-style dataset.
# NOTE(review): `limit=1` presumably caps the dataset at a single sample for a quick
# smoke/overfit test; confirm against create_dataset_custom.
# NOTE(review): machine-specific absolute paths; consider making the data root configurable.
train_dataset = create_dataset_custom(
    name="train",
    img_dir=Path("/home/bachelor/ml-carbucks/data/car_dd/images/train"),
    ann_file=Path("/home/bachelor/ml-carbucks/data/car_dd/instances_train.json"),
    limit=1
)
from effdet.data import create_loader
# Batched loader producing (imgs, targets) pairs at IMG_SIZE x IMG_SIZE.
# NOTE(review): effdet's prefetching loader is believed to move batches to CUDA itself.
# Confirm it works on CPU-only machines, and confirm the target layout it yields matches
# what the training loop below expects (one dict per image with 'boxes' and 'labels').
train_loader = create_loader(
    train_dataset,
    input_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
    use_prefetcher=True,
    num_workers=4,
    pin_mem=False,
)


model.train()
for epoch in range(10):  # just 10 for test
    # Running sums of (total, hm, wh, off) so the printed figures are epoch means,
    # not whatever the final batch happened to produce.
    running = [0.0, 0.0, 0.0, 0.0]
    num_batches = 0
    for imgs, targets in train_loader:
        imgs = imgs.to(device)

        # Convert the loader's targets to dense CenterNet targets at 1/4 input resolution.
        # NOTE(review): assumes iterating `targets` yields one dict per image, with 'boxes'
        # as (x1, y1, x2, y2) input pixels and 'labels' as class indices. Confirm against
        # create_dataset_custom / the effdet collate, which may yield one batched dict instead.
        output_size = (imgs.shape[2] // 4, imgs.shape[3] // 4)
        batch_targets = [encode_targets(t['boxes'], t['labels'], output_size, 3) for t in targets]
        batch_hm = torch.stack([bt['heatmap'] for bt in batch_targets])
        batch_wh = torch.stack([bt['wh'] for bt in batch_targets])
        batch_off = torch.stack([bt['offset'] for bt in batch_targets])

        pred = model(imgs)
        loss, (hm_loss, wh_loss, off_loss) = compute_loss(
            pred, {'heatmap': batch_hm, 'wh': batch_wh, 'offset': batch_off}
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running = [acc + part.item() for acc, part in zip(running, (loss, hm_loss, wh_loss, off_loss))]
        num_batches += 1

    # An empty loader would otherwise leave `loss` undefined and the print would raise NameError.
    if num_batches == 0:
        print(f"Epoch {epoch+1} | no batches produced by train_loader")
        continue
    mean_loss, mean_hm, mean_wh, mean_off = (total / num_batches for total in running)
    print(f"Epoch {epoch+1} | Loss: {mean_loss:.4f} | hm: {mean_hm:.4f} | wh: {mean_wh:.4f} | off: {mean_off:.4f}")

