In [None]:

import torch
import torch.nn as nn
import torchvision
import numpy as np
import sklearn
# !pip install sklearn
import sklearn.metrics
import matcher
from tqdm import tqdm
import torchvision
from torchvision.datasets import CocoDetection
from torchvision import transforms
from torch.utils.data import DataLoader


In [None]:
from typing import Dict
class CNNbackbone(nn.Module):
    """Convolutional feature extractor built on a truncated torchvision ResNet-34.

    The last two child modules of ``resnet34`` are dropped (in torchvision's
    ResNet these are the global average pool and the ``fc`` classifier), so the
    trunk keeps a spatial feature map. Two extra convolutions then widen the
    512-channel map to 2048 channels and project it down to ``d`` channels.

    Args:
        d: Number of output channels of the final 1x1 projection.
    """

    def __init__(self, d):
        super().__init__()
        full_resnet = torchvision.models.resnet34()
        # Kept as a plain list attribute as well; the layers are registered
        # as submodules through the nn.Sequential below.
        self.resnet = list(full_resnet.children())[:-2]
        self.cuttedResnet34 = nn.Sequential(*self.resnet)
        # 3x3 convolution, stride 1, padding 1: spatial size is preserved.
        self.cnn1 = nn.Conv2d(in_channels=512, out_channels=2048,
                              kernel_size=3, stride=1, padding=1)
        # 1x1 convolution that only changes the channel count.
        self.descaler = nn.Conv2d(in_channels=2048, out_channels=d, kernel_size=1)

    def forward(self, input: torch.Tensor):
        """Map an image batch to a ``d``-channel feature map."""
        trunk_features = self.cuttedResnet34(input)
        widened = self.cnn1(trunk_features)
        return self.descaler(widened)
class DETR(nn.Module):
    """DETR-style detector: CNN backbone, transformer encoder/decoder, box and class heads.

    Args:
        descaleDims: Channel count produced by the backbone and used as the
            transformer model dimension. It must be even, because the positional
            embedding concatenates a column half and a row half of size
            ``descaleDims // 2`` each.
        num_heads: Number of attention heads in every transformer layer.
        num_classes: Number of object classes. The class head emits
            ``num_classes + 1`` logits per query (the extra slot is presumably
            the "no object" class - confirm against the loss).
        num_queries: Number of learned object queries, i.e. predictions per image.
    """

    def __init__(self, descaleDims: int, num_heads: int, num_classes: int, num_queries: int):
        super().__init__()
        self.backbone = CNNbackbone(descaleDims)
        self.descale = descaleDims

        self.object_queries = nn.Parameter(torch.rand(num_queries, self.descale))
        # Learned positional tables: one half of the channels encodes the column
        # index, the other half the row index. 100 entries per axis bounds the
        # feature-map size that forward() can handle.
        self.col_embed = nn.Parameter(torch.rand(100, self.descale // 2))
        self.row_embed = nn.Parameter(torch.rand(100, self.descale // 2))

        self.num_heads = num_heads
        self.num_classes = num_classes

        # nn.TransformerEncoder / nn.TransformerDecoder clone the layer they are
        # given, so encoderL / decoderL are templates only. The attributes are
        # kept so existing checkpoints keep the same parameter names.
        self.encoderL = nn.TransformerEncoderLayer(self.descale, self.num_heads)
        self.encoder = nn.TransformerEncoder(encoder_layer=self.encoderL, num_layers=4)

        self.decoderL = nn.TransformerDecoderLayer(self.descale, self.num_heads)
        self.decoder = nn.TransformerDecoder(self.decoderL, 4)

        # NOTE(review): FeedForward is not defined in the visible notebook cells;
        # it must be defined or imported before this class is instantiated.
        self.ffn1 = FeedForward(self.descale, 0.1)
        self.ffn2 = FeedForward(self.descale, 0.1)
        self.ffn3 = FeedForward(self.descale, 0.1)

        self.single_box_out = nn.Linear(self.descale, 4)  # X, Y _n | W, H _n
        self.categorical = nn.Linear(self.descale, num_classes + 1)
        self.tanh = nn.Tanh()  # not used by forward(); kept for compatibility
        self.sigmoid = nn.Sigmoid()

    def forward(self, input: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Run detection on an image batch.

        Args:
            input: Image tensor whose last two dimensions are (height, width).

        Returns:
            Dict with
              - ``"pred_boxes"``: ``(num_queries, batch, 4)`` boxes as
                ``[x1, y1, x2, y2]`` in pixels of the input image.
              - ``"pred_logits"``: ``(num_queries, batch, num_classes + 1)``
                raw class scores.

        Raises:
            ValueError: If the backbone output is not a 3-D or 4-D feature map,
                or if the feature map is larger than the positional tables.
        """
        # Image tensors are laid out (..., H, W): height comes first.
        H_ORG, W_ORG = input.shape[-2:]

        y = self.backbone(input)
        if y.dim() == 3:
            # Unbatched feature map: add the batch dimension.
            y = y.unsqueeze(0)
        if y.dim() != 4:
            raise ValueError(
                f"expected a 3-D or 4-D feature map from the backbone, got shape {tuple(y.shape)}"
            )
        batch_size, d, H, W = y.size()

        if H > self.row_embed.shape[0] or W > self.col_embed.shape[0]:
            raise ValueError(
                f"feature map {H}x{W} exceeds the positional embedding tables "
                f"({self.row_embed.shape[0]} rows, {self.col_embed.shape[0]} columns)"
            )

        # (H, W, d) grid of [column embedding | row embedding], flattened to
        # (H*W, 1, d) so it broadcasts over the batch dimension.
        emb_pos = torch.cat([
            self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
            self.row_embed[:H].unsqueeze(1).repeat(1, W, 1)
        ], dim=-1).flatten(0, 1).unsqueeze(1)

        # (B, d, H, W) -> (H*W, B, d): sequence-first layout for the transformer,
        # then add the positional embedding.
        h = y.flatten(2).permute(2, 0, 1)
        h = h + emb_pos
        memory = self.encoder(h)

        # One copy of the learned queries per image: (num_queries, B, d).
        tgt = self.object_queries.unsqueeze(1).repeat(1, batch_size, 1)
        y = self.decoder(tgt, memory)
        # Applied independently to every decoded query vector.
        y = self.ffn3(self.ffn2(self.ffn1(y)))

        boxes = self.single_box_out(y)
        classes = self.categorical(y)

        return {"pred_boxes": self.prediction_normalized_xywh_to_x1y1x2y2(self.sigmoid(boxes), W_ORG, H_ORG),
                "pred_logits": classes}

    def prediction_normalized_xywh_to_x1y1x2y2(self, box: torch.Tensor, imgWidth: int, imgHeight: int):
        """Convert normalized ``[x, y, w, h]`` boxes to pixel ``[x1, y1, x2, y2]``.

        ``x, y`` are treated as the top-left corner and ``w, h`` as the box
        size, all as fractions of the image size.

        Args:
            box: Tensor whose last dimension holds ``[x, y, w, h]``.
            imgWidth: Image width in pixels (scales x and w).
            imgHeight: Image height in pixels (scales y and h).

        Returns:
            Tensor of the same shape with ``[x1, y1, x2, y2]`` in pixels.
        """
        x_n = box[..., 0]
        y_n = box[..., 1]
        w_n = box[..., 2]
        h_n = box[..., 3]

        x1 = x_n * imgWidth
        y1 = y_n * imgHeight
        x2 = x1 + w_n * imgWidth
        y2 = y1 + h_n * imgHeight

        return torch.stack([x1, y1, x2, y2], dim=-1)







In [None]:
IMAGE_FIX_WIDTH = 800
IMAGE_FIX_HEIGHT = 800

# Image-only pipeline: resize to a fixed size and convert to a float tensor.
# transforms.Resize expects (height, width).
transform = transforms.Compose([
    transforms.Resize((IMAGE_FIX_HEIGHT, IMAGE_FIX_WIDTH)),
    transforms.ToTensor()
])


def resize_with_boxes(image, anns):
    """Joint image/target transform that keeps COCO boxes aligned with the resized image.

    COCO ``bbox`` values are ``[x, y, w, h]`` in pixels of the ORIGINAL image.
    Resizing only the image would leave the boxes in the wrong coordinate
    system, so they are rescaled by the same factors here.

    Args:
        image: PIL image as loaded by ``CocoDetection``.
        anns: List of COCO annotation dicts for that image.

    Returns:
        ``(image_tensor, scaled_anns)`` where every annotation is a shallow copy
        with its ``'bbox'`` expressed in pixels of the resized image.
    """
    orig_width, orig_height = image.size  # PIL reports (width, height)
    scale_x = IMAGE_FIX_WIDTH / orig_width
    scale_y = IMAGE_FIX_HEIGHT / orig_height

    scaled_anns = []
    for ann in anns:
        x, y, w, h = ann['bbox']
        # Copy instead of mutating: the dataset hands out the annotation dicts
        # it keeps in memory, so in-place edits would compound on every access.
        scaled_anns.append({**ann, 'bbox': [x * scale_x, y * scale_y, w * scale_x, h * scale_y]})

    return transform(image), scaled_anns


# Load the training and validation datasets. `transforms` (plural) receives
# both the image and its annotations; it cannot be combined with `transform`.
coco_train = CocoDetection(
    root='coco/train2017',
    annFile='coco/annotations/instances_train2017.json',
    transforms=resize_with_boxes
)
coco_valid = CocoDetection(
    root='coco/val2017',
    annFile='coco/annotations/instances_val2017.json',
    transforms=resize_with_boxes
)
def collate_fn(batch):
    """
    Collate COCO samples into a stacked image batch and per-image targets.

    batch: list of tuples (image, anns)
      - image: Tensor[C,H,W] (all images must share the same shape to be stacked)
      - anns: list of dicts, each with keys 'bbox' ([x, y, w, h]) and
        'category_id' (and others, which are ignored)
    Returns:
      images: Tensor[B, C, H, W]
      targets: list of B dicts, each with:
        - 'boxes': float32 Tensor[N_i, 4] in [x1,y1,x2,y2] format
        - 'labels': int64 Tensor[N_i], category_id shifted down by one
      An image without annotations yields 'boxes' of shape [0, 4] and
      'labels' of shape [0].
    """
    images, all_anns = zip(*batch)
    # Stack images into [B, C, H, W]
    images = torch.stack(images, dim=0)

    targets = []
    for anns in all_anns:
        # [x, y, w, h] -> [x1, y1, x2, y2]
        boxes = [[x, y, x + w, y + h] for x, y, w, h in (ann['bbox'] for ann in anns)]
        labels = [ann['category_id'] - 1 for ann in anns]  # zero-base labels

        targets.append({
            # Explicit reshape so an empty annotation list gives [0, 4], not [0].
            'boxes': torch.tensor(boxes, dtype=torch.float32).reshape(len(boxes), 4),
            'labels': torch.tensor(labels, dtype=torch.int64)
        })

    return images, targets


# Wrap both datasets in DataLoaders that share one configuration:
# batches of 8, shuffled, 8 worker processes, custom collation of targets.
train_loader, valid_loader = (
    DataLoader(dataset, batch_size=8, shuffle=True, num_workers=8, collate_fn=collate_fn)
    for dataset in (coco_train, coco_valid)
)

In [None]:
from util.box_ops import generalized_box_iou


# Number of object classes passed to DETR; the model's class head emits
# NUM_CLASSES + 1 logits per query.
NUM_CLASSES = 90

def train_fn(model: nn.Module, dataloader: torch.utils.data.DataLoader, epochs: int = 30):
    """Train a detection model with matched GIoU + classification loss.

    Gradients are accumulated over ``accumulation_steps`` batches before each
    optimizer step. After every epoch the whole model object is saved to
    ``chkpts/detr_<epoch>``. The model is left in eval mode on return.

    Args:
        model: Module returning a dict with ``'pred_boxes'`` and
            ``'pred_logits'``, both laid out as (num_queries, batch, ...).
        dataloader: Yields ``(images, targets)`` where ``targets`` is a list of
            dicts with ``'boxes'`` ([N, 4], x1y1x2y2) and ``'labels'`` ([N]).
        epochs: Number of passes over ``dataloader``.
    """
    import os  # only needed to create the checkpoint directory

    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(DEVICE)
    model.train()
    loss_clasfier = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters())
    loss_box_matcher = matcher.HungarianMatcher(device=DEVICE).to(DEVICE)

    # every accum_steps to optim step (weight update)
    accumulation_steps = 16

    # Create the checkpoint directory up front so a missing folder cannot
    # discard a full epoch of training at save time.
    os.makedirs("chkpts", exist_ok=True)

    for epoch in tqdm(range(epochs), desc="Training", total=epochs):
        model.train()
        # Gradients are cleared only here and after each optimizer step.
        # Clearing them on every batch would defeat the accumulation.
        optimizer.zero_grad()
        accumulated = 0
        batch_loop = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}", leave=False)
        for batch_idx, (images, targets) in enumerate(batch_loop):
            # A batch without any annotation has nothing to match; the mean
            # over zero elements would produce a NaN loss.
            if sum(len(t['labels']) for t in targets) == 0:
                continue

            images = images.to(DEVICE)
            prediction = model(images)
            # (num_queries, batch, ...) -> (batch, num_queries, ...)
            prediction['pred_boxes'] = prediction['pred_boxes'].permute(1, 0, 2)
            prediction['pred_logits'] = prediction['pred_logits'].permute(1, 0, 2)
            # [(tensor_indicies_boxes, tensor_indicies_clasf)]
            # NOTE(review): in the reference DETR matcher the tuple is
            # (prediction indices, target indices). If this matcher follows that
            # convention, logits should be indexed with recon[k][0] and the
            # targets reordered with recon[k][1] - confirm against matcher.py.
            recon = loss_box_matcher(prediction, targets)

            batch_range = range(prediction['pred_boxes'].shape[0])
            pred_boxes = torch.cat([prediction['pred_boxes'][k, recon[k][0], :] for k in batch_range])
            tg_boxes = torch.cat([target['boxes'].to(DEVICE) for target in targets])

            pred_clasf = torch.cat([prediction['pred_logits'][k, recon[k][1], :] for k in batch_range])
            tg_clasf = torch.cat([target['labels'].to(DEVICE) for target in targets])

            giou = generalized_box_iou(pred_boxes, tg_boxes)
            if giou.dim() == 2:
                # Pairwise [N, M] matrix: only the aligned (i-th prediction,
                # i-th target) pairs belong in the loss, the same pairing the
                # classification loss uses.
                giou = giou.diagonal()
            iou_loss = 1 - giou.mean()

            # Class-index targets are equivalent to one-hot probability targets.
            loss_clasf = loss_clasfier(pred_clasf, tg_clasf)

            # Scale so the accumulated gradient is an average over the window.
            total_loss = (iou_loss + loss_clasf) / accumulation_steps
            batch_loop.set_postfix({
                'Loss': f'{total_loss.item():.4f}',
                'IoU': f'{1 - iou_loss.item():.4f}',
                'Cls': f'{loss_clasf.item():.4f}'
            })
            total_loss.backward()
            accumulated += 1
            if accumulated == accumulation_steps:
                optimizer.step()
                optimizer.zero_grad()
                accumulated = 0

        # Apply gradients left over from an incomplete accumulation window.
        if accumulated > 0:
            optimizer.step()
            optimizer.zero_grad()

        # Saves the whole pickled module (not just a state_dict).
        torch.save(model, f"chkpts/detr_{epoch+1}")
    model.eval()

In [None]:
# box1 = torch.tensor([[10,10,20,20], [50, 100, 200, 300]])
# box2 = torch.tensor([[10,10,15,15], [80, 120, 220, 310]])
# print(f"iou: {calculate_iou(box1, box2)}")

# print(box1.shape, box2.shape)
# box1[:, :2], box2[:, :2]

In [None]:
# Drop the reference to any model built by a previous run of this cell before
# allocating the new one, so both do not have to fit in memory at once.
detr = None
# Fall back to the CPU when no GPU is available instead of crashing.
detr = DETR(1024, 8, NUM_CLASSES, 128).to('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# sample: one forward pass + matching on a real batch, without gradients
with torch.no_grad():
    # Run on whatever device the model currently lives on.
    sample_device = next(detr.parameters()).device
    images, targets = next(iter(train_loader))
    loss_box_matcher = matcher.HungarianMatcher(device=sample_device.type).to(sample_device)
    # Feed the loaded images so the predictions correspond to `targets`.
    prediction = detr(images.to(sample_device))
    # (num_queries, batch, ...) -> (batch, num_queries, ...)
    prediction['pred_boxes'] = prediction['pred_boxes'].permute(1, 0, 2)
    prediction['pred_logits'] = prediction['pred_logits'].permute(1, 0, 2)
    recon = loss_box_matcher(prediction, targets)
    pred_boxes = [prediction['pred_boxes'][k, recon[k][0], :] for k in range(prediction['pred_boxes'].shape[0])]
    tg_boxes = [targets[k]['boxes'] for k in range(len(targets))]
    pred_clasf = [prediction['pred_logits'][k, recon[k][1], :] for k in range(prediction['pred_logits'].shape[0])]
    tg_clasf = [targets[k]['labels'] for k in range(len(targets))]
    print(f"{pred_boxes}")

In [None]:
# for i in range(prediction['pred_logits'].shape[0]):
#     print(f"i: {i}, {recon[i][1]}")
#     print(prediction['pred_logits'][i, recon[i][1], :])

In [None]:
# t1 = torch.randn((8,32,91))
# for i in range(prediction['pred_logits'].shape[0]):
#     print(f"recon: {recon[i][1]}, {prediction['pred_logits'][i, recon[i][1], :]}")


In [None]:
# with torch.no_grad():
#     pred = detr(torch.rand((32,3,800,800), device='cuda'))
#     x = pred["pred_logits"]
#     classes = torch.argmax(pred["pred_logits"], dim=-1)
#     unique_classes = torch.unique(classes)
#     print(unique_classes)

In [None]:
def count_parameters(model):
    """Return the total number of elements over the model's trainable parameters.

    Parameters with ``requires_grad`` set to False are not counted.
    """
    trainable_total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        trainable_total += param.numel()
    return trainable_total

# Report the trainable parameter count and a storage estimate that uses
# 4 bytes per parameter, expressed in MiB (2 ** 20 bytes).
num_params = count_parameters(detr)
space_mb = (num_params * 4) / 2 ** 20
print(f"Parameters: {num_params}")
print(f"Estimated space: {space_mb:.2f} MB")

In [None]:
# Scratch check of how Tensor.repeat tiles a 1-D tensor along new leading
# dimensions. The result already has H leading entries, so [:H] keeps all of them.
H, W = 4, 2
l1 = torch.tensor([0.05, 0.01])
l2 = torch.tensor([0.3, 0.4])
l1.repeat(H, 1, 1)[:H]
# pos = torch.cat([l1.repeat(H, 1, 1)[:H].unsqueeze(0), l2.repeat(1, W, 1).unsqueeze(1)])

In [None]:
# k = torch.rand((2, 2, 2)) # B C P -> P B C, B P C
# k, k.flatten(2).permute(2, 0, 1)

In [None]:
## Playground for row and col embed in detr
# H = 1
# W = 2
# hidden_dim = 8
# row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
# col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))

# pos = torch.cat([
#  col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
#  row_embed[:H].unsqueeze(1).repeat(1, W, 1),
# ], dim=-1).flatten(0, 1).unsqueeze(1)
# row_embed[0], row_embed[:H].unsqueeze(1).repeat(1, W, 1)

In [None]:
# l = torch.tensor([[1,2,3,4], [4,5,6,7]])

# H = 2
# W = 2
# # l1 = l.unsqueeze(0).repeat(H, 1, 1)
# # l2 = l.unsqueeze(1).repeat(1, W, 1)

# torch.cat([
#     l.unsqueeze(0).repeat(H, 1, 1),
#     l.unsqueeze(1).repeat(1, W, 1),
# ], dim=-1).flatten(0, 1).unsqueeze(1)
# # l1.shape, l2.shape#, torch.cat([l1,l2], dim=-1)