In [None]:
import torch
import torch.nn as nn
from torchvision.models import resnet50
import os
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms
import torch.nn.functional as F
from torchvision.ops import generalized_box_iou
from scipy.optimize import linear_sum_assignment
from tqdm import tqdm
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib.patches as patches



# Pick the compute device once; every later cell moves tensors/models to it.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

In [None]:
class YoloDataset(Dataset):
    """Image/annotation dataset reading one text label file per `.PNG` image.

    Each label file holds one object per line as
    `<class_id> <x> <y> <w> <h>`. The four box values are passed through
    unchanged; the loss code in this notebook treats them as
    (cx, cy, w, h) and compares them with sigmoid outputs, so they are
    presumably normalised to [0, 1] -- verify against the label files.

    Args:
        image_dir: directory containing the `.PNG` images.
        label_dir: directory containing the matching `.txt` label files.
        transform: optional callable applied to the RGB PIL image. When
            omitted, images are resized to 800x1066, converted to a tensor
            and normalised with the ImageNet mean/std.
    """

    def __init__(self, image_dir, label_dir, transform=None):
        self.image_dir = image_dir
        self.label_dir = label_dir
        # os.listdir returns entries in arbitrary order; sort so that a given
        # index always refers to the same image across runs and machines.
        self.image_filenames = sorted(
            f for f in os.listdir(image_dir) if f.endswith('.PNG')
        )
        self.transform = transform or transforms.Compose([
            transforms.Resize((800, 1066)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        """Return the number of images found in `image_dir`."""
        return len(self.image_filenames)

    @staticmethod
    def _read_annotations(label_path):
        """Parse one label file into `(labels, boxes)` tensors.

        Blank lines are skipped. A file without annotations yields a
        `(0,)` long tensor and a `(0, 4)` float tensor so downstream code
        can always rely on the trailing box dimension.

        Raises:
            ValueError: if a non-blank line does not have exactly 5 fields
                or a field cannot be parsed as a number.
        """
        labels, boxes = [], []
        with open(label_path, 'r') as f:
            for line_no, line in enumerate(f, start=1):
                parts = line.split()
                if not parts:
                    continue  # tolerate blank / trailing empty lines
                if len(parts) != 5:
                    raise ValueError(
                        f"{label_path}:{line_no}: expected 5 fields "
                        f"(class x y w h), got {len(parts)}"
                    )
                labels.append(int(parts[0]))
                boxes.append([float(v) for v in parts[1:]])

        return (
            torch.tensor(labels, dtype=torch.long),
            # reshape keeps the (N, 4) layout even when N == 0
            torch.tensor(boxes, dtype=torch.float).reshape(-1, 4),
        )

    def __getitem__(self, idx):
        """Return `(image_tensor, target)` for sample `idx`.

        `target` is a dict with keys `"labels"` (long tensor, one class id
        per object) and `"boxes"` (float tensor of shape `(N, 4)`).
        """
        img_filename = self.image_filenames[idx]
        # Swap only the extension; str.replace would also rewrite any other
        # '.PNG' occurrence inside the file name.
        label_filename = os.path.splitext(img_filename)[0] + '.txt'

        # The context manager closes the file handle as soon as the pixel
        # data has been copied by convert().
        with Image.open(os.path.join(self.image_dir, img_filename)) as img:
            image = img.convert("RGB")

        labels, boxes = self._read_annotations(
            os.path.join(self.label_dir, label_filename)
        )

        image_tensor = self.transform(image)
        target = {
            "labels": labels,
            "boxes": boxes
        }

        return image_tensor, target

In [None]:
class DETR(nn.Module):
    """Minimal DETR-style detector.

    Pipeline: ResNet-50 feature extractor -> 1x1 conv projection to
    `hidden_dim` -> `nn.Transformer` fed with learned 2-D positional
    encodings and learned object queries -> linear class and box heads.

    Args:
        num_classes: number of object classes; the class head emits
            `num_classes + 1` logits (one extra slot, used as "no object").
        hidden_dim: transformer width. Half of it is used for the column
            embedding and half for the row embedding.
        nheads: number of attention heads.
        num_encoder_layers: number of transformer encoder layers.
        num_decoder_layers: number of transformer decoder layers.
        num_queries: number of learned object queries, i.e. predictions
            produced per image.
        max_pos: number of learned row/column positional embeddings; the
            backbone feature map may be at most `max_pos` x `max_pos`.
    """

    def __init__(self, num_classes=1, hidden_dim=256, nheads=8,
                 num_encoder_layers=6, num_decoder_layers=6,
                 num_queries=100, max_pos=50):
        super().__init__()

        # NOTE(review): newer torchvision releases deprecate `pretrained=`
        # in favour of `weights=`; confirm the installed version before
        # switching the argument.
        self.backbone = resnet50(pretrained=True)
        del self.backbone.fc  # classification head is never used

        # Project the 2048-channel backbone output down to the transformer width.
        self.conv = nn.Conv2d(2048, hidden_dim, 1)

        self.transformer = nn.Transformer(
            hidden_dim, nheads, num_encoder_layers, num_decoder_layers)

        self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
        self.linear_bbox = nn.Linear(hidden_dim, 4)

        # Learned object queries and learned 2-D positional encodings.
        self.query_pos = nn.Parameter(torch.rand(num_queries, hidden_dim))
        self.row_embed = nn.Parameter(torch.rand(max_pos, hidden_dim // 2))
        self.col_embed = nn.Parameter(torch.rand(max_pos, hidden_dim // 2))

    def _backbone_features(self, inputs):
        """Run the backbone stem and its four residual stages on `inputs`."""
        x = self.backbone.conv1(inputs)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)

        x = self.backbone.layer1(x)
        x = self.backbone.layer2(x)
        x = self.backbone.layer3(x)
        return self.backbone.layer4(x)

    def forward(self, inputs):
        """Predict class logits and boxes for a batch of images.

        Args:
            inputs: image batch indexed as (batch, channels, height, width).

        Returns:
            dict with
              * `'pred_logits'`: (batch, num_queries, num_classes + 1) raw logits;
              * `'pred_boxes'`: (batch, num_queries, 4) box values squashed
                to (0, 1) by a sigmoid.

        Raises:
            ValueError: if the backbone feature map has more rows/columns
                than there are positional embeddings.
        """
        h = self.conv(self._backbone_features(inputs))

        H, W = h.shape[-2:]
        max_rows, max_cols = self.row_embed.shape[0], self.col_embed.shape[0]
        if H > max_rows or W > max_cols:
            # Slicing the embeddings would silently truncate and torch.cat
            # would then fail with an opaque size-mismatch error.
            raise ValueError(
                f"feature map of size {H}x{W} exceeds the positional "
                f"embedding table ({max_rows}x{max_cols}); use a smaller "
                f"input or a larger max_pos"
            )

        # One positional vector per feature-map cell: [col_embed | row_embed],
        # flattened to (H*W, 1, hidden_dim) so it broadcasts over the batch.
        pos = torch.cat([
            self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
            self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
        ], dim=-1).flatten(0, 1).unsqueeze(1)

        # Sequence-first layout (H*W, batch, hidden_dim); features are scaled
        # by 0.1 before the positional encodings are added.
        src = pos + 0.1 * h.flatten(2).permute(2, 0, 1)
        batch_size = inputs.shape[0]
        queries = self.query_pos.unsqueeze(1).repeat(1, batch_size, 1)

        # Back to batch-first: (batch, num_queries, hidden_dim).
        decoded = self.transformer(src, queries).transpose(0, 1)

        return {'pred_logits': self.linear_class(decoded),
                'pred_boxes': self.linear_bbox(decoded).sigmoid()}

In [None]:
def box_cxcywh_to_xyxy(boxes):
    """Convert boxes from centre format to corner format.

    Args:
        boxes: tensor whose last dimension holds (cx, cy, w, h).

    Returns:
        Tensor of the same shape whose last dimension holds (x1, y1, x2, y2).
    """
    cx, cy, width, height = boxes.unbind(-1)
    half_w = 0.5 * width
    half_h = 0.5 * height
    corners = (cx - half_w, cy - half_h, cx + half_w, cy + half_h)
    return torch.stack(corners, dim=-1)

def compute_cost(pred_logits, pred_boxes, tgt_labels, tgt_boxes):
    """Build the query-vs-target matching cost matrix for one image.

    The cost of pairing query `i` with target `j` is the unweighted sum of
      * the negated softmax probability query `i` assigns to target `j`'s
        class (a probability, not a log-likelihood),
      * the L1 distance between the two boxes, and
      * `1 - GIoU` of the two boxes.

    Args:
        pred_logits: class logits for one image, indexed [query, class].
        pred_boxes: predicted boxes for one image, [query, 4] as (cx, cy, w, h).
        tgt_labels: integer class index of every target.
        tgt_boxes: target boxes as (cx, cy, w, h); anything that reshapes to
            [num_targets, 4], including an empty tensor.

    Returns:
        Tensor of shape [num_queries, num_targets].
    """
    # An image without annotations may arrive as a flat, empty tensor;
    # normalise to (N, 4) so the box ops below always see 2-D input.
    tgt_boxes = tgt_boxes.reshape(-1, 4)
    if tgt_boxes.shape[0] == 0:
        return pred_boxes.new_zeros((pred_boxes.shape[0], 0))

    # Classification cost: higher probability for the target class -> lower cost.
    prob = pred_logits.softmax(-1)
    cost_class = -prob[:, tgt_labels]

    # L1 cost between every predicted and every target box.
    cost_bbox = torch.cdist(pred_boxes, tgt_boxes, p=1)

    # IoU cost (1 - GIoU); GIoU expects corner-format boxes.
    giou = generalized_box_iou(
        box_cxcywh_to_xyxy(pred_boxes),
        box_cxcywh_to_xyxy(tgt_boxes)
    )
    cost_giou = 1 - giou

    return cost_class + cost_bbox + cost_giou

def match(pred_logits, pred_boxes, targets):
    """Assign predictions to ground-truth objects, one image at a time.

    For every image the cost matrix from `compute_cost` is solved with
    `scipy.optimize.linear_sum_assignment`.

    Args:
        pred_logits: class logits indexed [batch, query, class].
        pred_boxes: predicted boxes indexed [batch, query, 4].
        targets: one dict per image with keys `'labels'` and `'boxes'`.

    Returns:
        List with one `(row_ind, col_ind)` pair of integer numpy arrays per
        image: `row_ind[k]` is the query matched to target `col_ind[k]`.
        Images without targets yield two empty arrays.
    """
    indices = []
    # The assignment is not differentiable; skip building an autograd graph
    # for the cost matrix.
    with torch.no_grad():
        for b in range(pred_logits.shape[0]):
            tgt_labels = targets[b]['labels']
            if tgt_labels.numel() == 0:
                # Nothing to match: every query stays unassigned.
                no_match = torch.empty(0, dtype=torch.int64).numpy()
                indices.append((no_match, no_match.copy()))
                continue

            cost = compute_cost(
                pred_logits[b], pred_boxes[b],
                tgt_labels, targets[b]['boxes']
            ).detach().cpu().numpy()
            row_ind, col_ind = linear_sum_assignment(cost)
            indices.append((row_ind, col_ind))
    return indices

class DetrLoss(nn.Module):
    """Set-prediction loss: weighted classification + L1 box + GIoU terms.

    Predictions are first matched one-to-one to the ground-truth objects
    with `match`. Then
      * the classification loss is a cross-entropy over ALL queries: matched
        queries are trained towards their target's class, unmatched queries
        towards the last logit ("no object"), which is down-weighted by
        `eos_coef` because unmatched queries vastly outnumber matched ones;
      * the L1 and GIoU box losses are summed over matched pairs and divided
        by the number of ground-truth boxes in the batch.

    Args:
        class_weight: multiplier of the classification term in `total_loss`.
        bbox_weight: multiplier of the L1 box term.
        giou_weight: multiplier of the GIoU term.
        eos_coef: cross-entropy class weight of the "no object" class.
    """

    def __init__(self, class_weight=1, bbox_weight=5, giou_weight=2,
                 eos_coef=0.1):
        super().__init__()
        self.class_weight = class_weight
        self.bbox_weight = bbox_weight
        self.giou_weight = giou_weight
        self.eos_coef = eos_coef

    def forward(self, outputs, targets):
        """Compute the loss terms for one batch.

        Args:
            outputs: dict with `'pred_logits'` [B, Q, num_classes + 1] and
                `'pred_boxes'` [B, Q, 4].
            targets: one dict per image with `'labels'` and `'boxes'`.

        Returns:
            dict with scalar tensors `'loss_cls'`, `'loss_bbox'`,
            `'loss_giou'` and their weighted sum `'total_loss'`.
        """
        pred_logits = outputs['pred_logits']  # [B, Q, num_classes+1]
        pred_boxes = outputs['pred_boxes']    # [B, Q, 4]
        dev = pred_logits.device
        no_object = pred_logits.shape[-1] - 1  # index of the extra logit

        indices = match(pred_logits, pred_boxes, targets)

        # Every query gets a class target; unmatched ones keep "no object".
        # Without this the no-object logit is never trained and score
        # thresholding at inference is meaningless.
        target_classes = torch.full(
            pred_logits.shape[:2], no_object, dtype=torch.long, device=dev
        )
        loss_bbox = pred_boxes.new_zeros(())
        loss_giou = pred_boxes.new_zeros(())

        for b, (src_idx, tgt_idx) in enumerate(indices):
            src_idx = torch.as_tensor(src_idx, dtype=torch.long, device=dev)
            tgt_idx = torch.as_tensor(tgt_idx, dtype=torch.long, device=dev)
            if src_idx.numel() == 0:
                continue  # image without objects: only the class term applies

            target_classes[b, src_idx] = targets[b]['labels'][tgt_idx]

            src_boxes = pred_boxes[b][src_idx]
            tgt_boxes = targets[b]['boxes'][tgt_idx]

            # Sum (not mean) here; the single normalisation by the number of
            # boxes happens once below.
            loss_bbox = loss_bbox + F.l1_loss(src_boxes, tgt_boxes,
                                              reduction='sum')

            giou = generalized_box_iou(
                box_cxcywh_to_xyxy(src_boxes),
                box_cxcywh_to_xyxy(tgt_boxes)
            )
            # Only the diagonal pairs each prediction with its own target.
            loss_giou = loss_giou + (1 - giou.diag()).sum()

        class_weights = pred_logits.new_ones(no_object + 1)
        class_weights[no_object] = self.eos_coef
        loss_cls = F.cross_entropy(
            pred_logits.flatten(0, 1), target_classes.flatten(),
            weight=class_weights
        )

        # Guard against a batch that contains no objects at all.
        num_boxes = max(sum(len(t['labels']) for t in targets), 1)
        loss_dict = {
            'loss_cls': loss_cls,
            'loss_bbox': loss_bbox / num_boxes,
            'loss_giou': loss_giou / num_boxes
        }

        total_loss = (
            self.class_weight * loss_dict['loss_cls'] +
            self.bbox_weight * loss_dict['loss_bbox'] +
            self.giou_weight * loss_dict['loss_giou']
        )

        loss_dict['total_loss'] = total_loss
        return loss_dict

In [None]:
from torch.utils.data import DataLoader

# --- Data ------------------------------------------------------------------
dataset = YoloDataset("/datasets/tdt4265/ad/open/Poles/rgb/images/train", "/datasets/tdt4265/ad/open/Poles/rgb/labels/train")

# Sanity-check a single sample instead of decoding and printing the whole
# dataset, which loads every image once before training even starts.
print(f"Training images: {len(dataset)}")
if len(dataset) > 0:
    _, sample_target = dataset[0]
    print("First sample boxes:")
    print(sample_target['boxes'])

# Targets hold a different number of boxes per image, so the default collate
# cannot stack them; keep images and targets as per-sample tuples instead.
dataloader = DataLoader(dataset, batch_size=10, shuffle=True, collate_fn=lambda x: tuple(zip(*x)))

# --- Model, optimiser, loss ------------------------------------------------
model = DETR().to(device)
# NOTE(review): lr=1e-3 with plain Adam is aggressive for a transformer on
# top of a pretrained backbone; consider a lower rate if the loss diverges.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = DetrLoss()

num_epochs = 50

# --- Training --------------------------------------------------------------
for epoch in range(num_epochs):

    model.train()
    running_loss = 0

    dataloader_tqdm = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)

    for images, targets in dataloader_tqdm:

        # All images share one size after the dataset transform, so they can
        # be stacked into a single batch tensor.
        images = torch.stack(images).to(device)

        targets = [
            {
                "labels": t["labels"].to(device),
                "boxes": t["boxes"].to(device)
            }
            for t in targets
        ]

        outputs = model(images)

        loss_dict = loss_fn(outputs, targets)
        loss = loss_dict['total_loss']

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()  # single device sync per step
        running_loss += batch_loss
        dataloader_tqdm.set_postfix({"Loss": f"{batch_loss:.4f}"})

    # max(..., 1) avoids a ZeroDivisionError when the dataset is empty.
    avg_loss = running_loss / max(len(dataloader), 1)
    print(f"Epoch [{epoch+1}/{num_epochs}] - Avg Loss: {avg_loss:.4f}")

In [None]:
model.eval()

im_num = 30

image_path = "/datasets/tdt4265/ad/open/Poles/rgb/images/valid"

# glob yields files in arbitrary order; sort so that `im_num` always selects
# the same validation image.
image_paths = sorted(str(p) for p in Path(image_path).glob("*.PNG"))
if not image_paths:
    raise FileNotFoundError(f"no .PNG images found in {image_path}")
if not -len(image_paths) <= im_num < len(image_paths):
    raise IndexError(
        f"im_num={im_num} is out of range for {len(image_paths)} validation images"
    )

image = Image.open(image_paths[im_num]).convert("RGB")
# Same preprocessing as the default transform of YoloDataset.
transform = transforms.Compose([
    transforms.Resize((800, 1066)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
image_tensor = transform(image).unsqueeze(0).to(device)

with torch.no_grad():
    outputs = model(image_tensor)

logits = outputs['pred_logits'][0]
pred_boxes = outputs['pred_boxes'][0]
probs = logits.softmax(-1)
# Drop the last ("no object") column before taking the best class per query.
scores, labels = probs[..., :-1].max(-1)

threshold = 0.7
keep = scores > threshold

labels = labels[keep]
scores = scores[keep]

# Predicted boxes are relative (cx, cy, w, h); convert to corners with the
# box_cxcywh_to_xyxy helper defined in the loss cell and scale to the pixel
# size of the original (un-resized) image for plotting.
w, h = image.size
pixel_scale = torch.tensor([w, h, w, h], dtype=pred_boxes.dtype,
                           device=pred_boxes.device)
boxes = box_cxcywh_to_xyxy(pred_boxes[keep]) * pixel_scale

def plot_detections(image, boxes, scores, labels):
    """Show `image` with one red rectangle and score tag per detection.

    Args:
        image: image accepted by matplotlib's `imshow`.
        boxes: iterable of corner-format boxes (x1, y1, x2, y2) in the pixel
            coordinates of `image`; each entry must support `.tolist()`.
        scores: iterable of confidences, paired with `boxes` in order.
        labels: accepted for interface compatibility; currently not drawn.
    """
    fig, ax = plt.subplots(figsize=(12, 8))
    ax.imshow(image)

    tag_style = dict(facecolor='red', edgecolor='none', alpha=0.5)
    for box, score in zip(boxes, scores):
        left, top, right, bottom = box.tolist()
        outline = patches.Rectangle(
            (left, top), right - left, bottom - top,
            linewidth=2, edgecolor='r', facecolor='none',
        )
        ax.add_patch(outline)
        # Confidence tag anchored at the box's top-left corner.
        ax.text(left, top, f'{score:.2f}', fontsize=10, color='white',
                bbox=tag_style)

    ax.axis('off')
    plt.show()

# Draw the detections that passed the score threshold on the original image.
plot_detections(image, boxes, scores, labels)
