# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from pycocotools.coco import COCO
import numpy as np
from PIL import Image
import os
import math


# Number of object classes; main() passes it to both YOLOv3 and YOLOLoss.
COCO_CLASSES = 80
# Anchor box priors as (width, height) pairs, three per prediction scale.
# YOLOLoss.forward indexes this list with the position of each prediction
# tensor, so the groups follow the order of the YOLOv3.forward outputs
# (p1, p2, p3). The loss divides the values by the image size, so they are
# presumably pixel sizes at IMAGE_SIZE resolution.
ANCHORS = [
    [(116, 90), (156, 198), (373, 326)],  # Layer 82
    [(30, 61), (62, 45), (59, 119)],      # Layer 94
    [(10, 13), (16, 30), (33, 23)]        # Layer 106
]
# Side length images are resized to in main(); also the img_size given to YOLOLoss.
IMAGE_SIZE = 416

class ConvBlock(nn.Module):
    """Convolution followed by batch normalisation and LeakyReLU(0.1).

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Side length of the square convolution kernel.
        stride: Convolution stride (default 1).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1):
        super().__init__()

        # The convolution carries no bias term; BatchNorm2d supplies its own
        # learnable shift. Padding is (k - 1) // 2, i.e. 0 for 1x1 kernels
        # and 1 for 3x3 kernels.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=(kernel_size - 1) // 2,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.leaky = nn.LeakyReLU(negative_slope=0.1)

    def forward(self, x):
        """Apply conv -> batch norm -> leaky ReLU to ``x`` and return the result."""
        features = self.conv(x)
        features = self.bn(features)
        return self.leaky(features)

class ResidualBlock(nn.Module):
    """Bottleneck residual unit: 1x1 squeeze, 3x3 expand, plus skip connection.

    Args:
        channels: Channel count of both the input and the output; the inner
            1x1 convolution reduces it to ``channels // 2``.
    """

    def __init__(self, channels):
        super().__init__()

        bottleneck = channels // 2
        squeeze = ConvBlock(channels, bottleneck, 1)
        expand = ConvBlock(bottleneck, channels, 3)
        self.block = nn.Sequential(squeeze, expand)

    def forward(self, x):
        """Return ``x`` plus the output of the two-convolution branch."""
        residual = self.block(x)
        return x + residual

class Darknet53(nn.Module):
    """Convolutional backbone built from ConvBlock and ResidualBlock stages.

    A 3->32 stem is followed by five stride-2 downsampling convolutions, each
    followed by a stack of residual blocks (1, 2, 8, 8 and 4 blocks at 64,
    128, 256, 512 and 1024 channels respectively).
    """

    def __init__(self):
        super().__init__()

        def residual_stage(channels, repeats):
            # `repeats` residual blocks of equal width, applied in sequence.
            return nn.Sequential(
                *(ResidualBlock(channels) for _ in range(repeats))
            )

        self.conv1 = ConvBlock(3, 32, 3)

        self.conv2 = ConvBlock(32, 64, 3, stride=2)
        self.res1 = residual_stage(64, 1)

        self.conv3 = ConvBlock(64, 128, 3, stride=2)
        self.res2 = residual_stage(128, 2)

        self.conv4 = ConvBlock(128, 256, 3, stride=2)
        self.res3 = residual_stage(256, 8)

        self.conv5 = ConvBlock(256, 512, 3, stride=2)
        self.res4 = residual_stage(512, 8)

        self.conv6 = ConvBlock(512, 1024, 3, stride=2)
        self.res5 = residual_stage(1024, 4)

    def forward(self, x):
        """Return the outputs of the last three stages as a list.

        The list holds, in order, the 256-, 512- and 1024-channel feature
        maps, taken after three, four and five stride-2 convolutions.
        """
        x = self.res1(self.conv2(self.conv1(x)))
        x = self.res2(self.conv3(x))

        stage3 = self.res3(self.conv4(x))
        stage4 = self.res4(self.conv5(stage3))
        stage5 = self.res5(self.conv6(stage4))

        return [stage3, stage4, stage5]

class DetectionBlock(nn.Module):
    """Five alternating 1x1 / 3x3 ConvBlocks applied ahead of a prediction head.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count of the returned feature map; the 3x3
            layers temporarily widen it to ``2 * out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        wide = out_channels * 2
        self.conv1 = ConvBlock(in_channels, out_channels, 1)
        self.conv2 = ConvBlock(out_channels, wide, 3)
        self.conv3 = ConvBlock(wide, out_channels, 1)
        self.conv4 = ConvBlock(out_channels, wide, 3)
        self.conv5 = ConvBlock(wide, out_channels, 1)

    def forward(self, x):
        """Run ``x`` through conv1..conv5 in order and return the result."""
        layers = (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5)
        for layer in layers:
            x = layer(x)
        return x

class YOLOv3(nn.Module):
    """YOLOv3-style detector: Darknet53 backbone plus three prediction heads.

    Args:
        num_classes: Number of object classes. Every head emits
            ``3 * (5 + num_classes)`` channels: 3 anchors per grid cell, each
            with 4 box terms, 1 objectness score and the class scores.
    """

    def __init__(self, num_classes):
        super().__init__()

        self.num_classes = num_classes
        self.backbone = Darknet53()

        # Detection layers. The second and third branches receive the
        # upsampled output of the previous branch concatenated with a
        # backbone feature map, so their input width is the sum of both:
        #   detect2: 256 (self.conv1 output) + 512 (backbone) = 768
        #   detect3: 128 (self.conv2 output) + 256 (backbone) = 384
        self.detect1 = DetectionBlock(1024, 512)
        self.detect2 = DetectionBlock(256 + 512, 256)
        self.detect3 = DetectionBlock(128 + 256, 128)

        # Prediction conv layers (output: [batch, anchors * (5 + num_classes), grid, grid])
        self.pred1 = nn.Conv2d(512, 3 * (5 + num_classes), 1)
        self.pred2 = nn.Conv2d(256, 3 * (5 + num_classes), 1)
        self.pred3 = nn.Conv2d(128, 3 * (5 + num_classes), 1)

        # Upsampling layer shared by both top-down connections
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

        # 1x1 convs that halve the channel count before upsampling
        self.conv1 = ConvBlock(512, 256, 1)
        self.conv2 = ConvBlock(256, 128, 1)

    def forward(self, x):
        """Return the raw predictions of the three heads, coarsest grid first.

        Args:
            x: Image batch of shape ``(batch, 3, H, W)``. The upsampled maps
                must match the backbone maps they are concatenated with; H
                and W divisible by 32 is sufficient for that.

        Returns:
            Tuple ``(p1, p2, p3)``; each tensor has
            ``3 * (5 + num_classes)`` channels. ``p1`` comes from the deepest
            backbone feature map, ``p3`` from the shallowest one used.
        """
        f1, f2, f3 = self.backbone(x)  # Large to small feature maps

        # First detection branch (deepest features)
        x1 = self.detect1(f3)
        p1 = self.pred1(x1)

        # Second detection branch: 256 upsampled + 512 backbone channels
        x1 = self.conv1(x1)
        x1 = self.upsample(x1)
        x2 = torch.cat([x1, f2], dim=1)
        x2 = self.detect2(x2)
        p2 = self.pred2(x2)

        # Third detection branch: 128 upsampled + 256 backbone channels
        x2 = self.conv2(x2)
        x2 = self.upsample(x2)
        x3 = torch.cat([x2, f1], dim=1)
        x3 = self.detect3(x3)
        p3 = self.pred3(x3)

        return p1, p2, p3

class YOLOLoss(nn.Module):
    """Multi-scale YOLOv3 training loss.

    For every prediction scale the raw head output is decoded, ground-truth
    boxes are assigned to grid cells and anchors by :meth:`build_targets`, and
    these terms are summed: MSE on the x/y cell offsets and on the log-space
    width/height, BCE on objectness (object and no-object cells weighted
    separately) and BCE on the class scores.

    Args:
        anchors: One list of ``(width, height)`` anchor pairs per prediction
            scale, in input-image pixels, ordered like ``predictions``.
        num_classes: Number of object classes.
        img_size: Side length in pixels of the (square) network input.
    """

    def __init__(self, anchors, num_classes, img_size):
        super().__init__()
        self.anchors = anchors
        self.num_classes = num_classes
        self.img_size = img_size
        self.mse = nn.MSELoss()
        self.bce = nn.BCELoss()
        self.ignore_thres = 0.5
        self.obj_scale = 1
        self.noobj_scale = 100
        self.metrics = {}

    def forward(self, predictions, targets):
        """Compute the summed loss over all prediction scales.

        Args:
            predictions: Sequence of head outputs, each of shape
                ``(batch, num_anchors * (5 + num_classes), grid, grid)``.
            targets: Either one tensor of shape ``(N, 6)`` with rows
                ``[batch_index, class_index, cx, cy, w, h]`` (box values
                normalised to ``[0, 1]``) that is used for every scale, or a
                sequence holding one such tensor per prediction scale.

        Returns:
            Scalar tensor with the total loss.
        """
        device = predictions[0].device
        total_loss = 0

        for i, pred in enumerate(predictions):
            batch_size = pred.size(0)
            grid_size = pred.size(2)
            num_anchors = len(self.anchors[i])

            # (B, A*(5+C), G, G) -> (B, A, G, G, 5+C)
            prediction = (
                pred.view(batch_size, num_anchors, 5 + self.num_classes, grid_size, grid_size)
                .permute(0, 1, 3, 4, 2)
                .contiguous()
            )

            x = torch.sigmoid(prediction[..., 0])
            y = torch.sigmoid(prediction[..., 1])
            w = prediction[..., 2]
            h = prediction[..., 3]
            pred_conf = torch.sigmoid(prediction[..., 4])  # Object confidence
            pred_cls = torch.sigmoid(prediction[..., 5:])  # Class predictions

            # Targets and decoded boxes live in grid-cell units, so anchors
            # must be expressed in cells too: divide the pixel sizes by the
            # stride (pixels per cell), not by the image size.
            stride = self.img_size / grid_size
            scaled_anchors = torch.tensor(
                [(a_w / stride, a_h / stride) for a_w, a_h in self.anchors[i]],
                dtype=torch.float,
                device=device,
            )
            anchor_w = scaled_anchors[:, 0].view(1, num_anchors, 1, 1)
            anchor_h = scaled_anchors[:, 1].view(1, num_anchors, 1, 1)

            # Column / row index of every cell, broadcast over batch and anchors
            grid_x = torch.arange(grid_size, device=device).view(1, 1, 1, grid_size)
            grid_y = torch.arange(grid_size, device=device).view(1, 1, grid_size, 1)

            # Decoded boxes in grid units; only handed to build_targets, so
            # they are detached from the autograd graph.
            pred_boxes = torch.stack(
                (x + grid_x, y + grid_y, torch.exp(w) * anchor_w, torch.exp(h) * anchor_h),
                dim=-1,
            ).detach()

            scale_targets = targets if torch.is_tensor(targets) else targets[i]
            target_mask, obj_mask, noobj_mask, tx, ty, tw, th, tconf, tcls = self.build_targets(
                pred_boxes, scale_targets, scaled_anchors, grid_size
            )

            zero = pred.new_zeros(())
            if obj_mask.any():
                loss_x = self.mse(x[obj_mask], tx[obj_mask])
                loss_y = self.mse(y[obj_mask], ty[obj_mask])
                loss_w = self.mse(w[obj_mask], tw[obj_mask])
                loss_h = self.mse(h[obj_mask], th[obj_mask])
                loss_conf_obj = self.bce(pred_conf[obj_mask], tconf[obj_mask])
                loss_cls = self.bce(pred_cls[obj_mask], tcls[obj_mask])
            else:
                # A mean over an empty selection is NaN; a scale without any
                # assigned object contributes nothing to these terms instead.
                loss_x = loss_y = loss_w = loss_h = loss_conf_obj = loss_cls = zero

            if noobj_mask.any():
                loss_conf_noobj = self.bce(pred_conf[noobj_mask], tconf[noobj_mask])
            else:
                loss_conf_noobj = zero

            loss_conf = self.obj_scale * loss_conf_obj + self.noobj_scale * loss_conf_noobj

            total_loss += loss_x + loss_y + loss_w + loss_h + loss_conf + loss_cls

        return total_loss

    def build_targets(self, pred_boxes, target, anchors, grid_size):
        """Assign ground-truth boxes to grid cells and anchors for one scale.

        Args:
            pred_boxes: Decoded predictions of shape ``(B, A, G, G, 4)``; only
                its batch size and device are used.
            target: Tensor of shape ``(N, 6)`` with rows
                ``[batch_index, class_index, cx, cy, w, h]``, box values
                normalised to ``[0, 1]``. May be empty.
            anchors: ``(A, 2)`` anchor widths/heights in grid-cell units.
            grid_size: Number of cells along each side of the grid.

        Returns:
            Tuple ``(target_mask, obj_mask, noobj_mask, tx, ty, tw, th, tconf,
            tcls)``. The masks are boolean ``(B, A, G, G)`` tensors, the
            regression/confidence targets are float tensors of the same shape
            and ``tcls`` is a one-hot ``(B, A, G, G, num_classes)`` tensor.
        """
        device = pred_boxes.device
        batch_size = pred_boxes.size(0)
        num_anchors = len(anchors)
        shape = (batch_size, num_anchors, grid_size, grid_size)

        obj_mask = torch.zeros(shape, dtype=torch.bool, device=device)
        noobj_mask = torch.ones(shape, dtype=torch.bool, device=device)
        target_mask = torch.zeros(shape, dtype=torch.bool, device=device)

        tx = torch.zeros(shape, dtype=torch.float, device=device)
        ty = torch.zeros(shape, dtype=torch.float, device=device)
        tw = torch.zeros(shape, dtype=torch.float, device=device)
        th = torch.zeros(shape, dtype=torch.float, device=device)
        tconf = torch.zeros(shape, dtype=torch.float, device=device)
        tcls = torch.zeros(*shape, self.num_classes, dtype=torch.float, device=device)

        if target is None or len(target) == 0:
            return target_mask, obj_mask, noobj_mask, tx, ty, tw, th, tconf, tcls

        anchors = torch.as_tensor(anchors, dtype=torch.float, device=device)
        target = torch.as_tensor(target, dtype=torch.float, device=device)

        # Anchors as centre-format boxes located at the origin, so the IoU
        # against an origin-centred ground-truth box compares shapes only.
        anchor_boxes = torch.cat((anchors.new_zeros(num_anchors, 2), anchors), dim=1)

        for row in target:
            batch_idx = int(row[0])
            class_idx = int(row[1])

            # Ground-truth box in grid-cell units
            gx, gy, gw, gh = (row[2:6] * grid_size).unbind()

            # Containing cell; clamped because a centre of exactly 1.0 would
            # otherwise index one past the last cell.
            gi = min(max(int(gx), 0), grid_size - 1)
            gj = min(max(int(gy), 0), grid_size - 1)

            origin = gw.new_zeros(())
            gt_box = torch.stack((origin, origin, gw, gh)).unsqueeze(0)

            # bbox_iou is pairwise -> (1, A); drop the singleton dimension
            anchor_ious = bbox_iou(gt_box, anchor_boxes, x1y1x2y2=False).squeeze(0)
            best_anchor_idx = int(anchor_ious.argmax())

            # Anchors that fit this object well are excluded from the
            # no-object confidence loss even when they are not the best match.
            noobj_mask[batch_idx, :, gj, gi] = (
                noobj_mask[batch_idx, :, gj, gi] & (anchor_ious <= self.ignore_thres)
            )

            # Only assign the object if the best anchor fits well enough
            if anchor_ious[best_anchor_idx] > 0.3:
                obj_mask[batch_idx, best_anchor_idx, gj, gi] = True
                noobj_mask[batch_idx, best_anchor_idx, gj, gi] = False
                target_mask[batch_idx, best_anchor_idx, gj, gi] = True

                # Centre offset inside the cell
                tx[batch_idx, best_anchor_idx, gj, gi] = gx - gi
                ty[batch_idx, best_anchor_idx, gj, gi] = gy - gj

                # Width and height relative to the anchor, in log space
                tw[batch_idx, best_anchor_idx, gj, gi] = torch.log(gw / anchors[best_anchor_idx, 0] + 1e-16)
                th[batch_idx, best_anchor_idx, gj, gi] = torch.log(gh / anchors[best_anchor_idx, 1] + 1e-16)

                # Confidence and one-hot class targets
                tconf[batch_idx, best_anchor_idx, gj, gi] = 1
                tcls[batch_idx, best_anchor_idx, gj, gi, class_idx] = 1

        return target_mask, obj_mask, noobj_mask, tx, ty, tw, th, tconf, tcls


def bbox_iou(box1, box2, x1y1x2y2=True):
    """Return the pairwise IoU between two sets of boxes.

    Args:
        box1: Tensor of shape ``(N, 4)``.
        box2: Tensor of shape ``(M, 4)``.
        x1y1x2y2: If True the boxes are ``[x1, y1, x2, y2]`` corners,
            otherwise ``[cx, cy, w, h]`` centre format.

    Returns:
        Tensor of shape ``(N, M)`` with the IoU of every ``box1``/``box2``
        pair. Coordinates are treated as continuous, so no one-pixel offset
        is added to widths and heights.
    """
    if not x1y1x2y2:
        # Transform from centre and size to corner coordinates
        b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2
        b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2
        b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2
        b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2
    else:
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3]

    # Overlap extent along each axis, clamped to zero for disjoint boxes
    inter_w = (
        torch.min(b1_x2.unsqueeze(1), b2_x2.unsqueeze(0))
        - torch.max(b1_x1.unsqueeze(1), b2_x1.unsqueeze(0))
    ).clamp(min=0)
    inter_h = (
        torch.min(b1_y2.unsqueeze(1), b2_y2.unsqueeze(0))
        - torch.max(b1_y1.unsqueeze(1), b2_y1.unsqueeze(0))
    ).clamp(min=0)
    inter_area = inter_w * inter_h

    b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
    b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
    union_area = b1_area.unsqueeze(1) + b2_area.unsqueeze(0) - inter_area

    # The epsilon keeps two zero-area boxes from producing 0 / 0
    return inter_area / (union_area + 1e-16)

class COCODataset(Dataset):
    """COCO detection dataset yielding images with normalised YOLO-format boxes.

    Args:
        root_dir: Directory holding the image files named in the annotations.
        ann_file: Path to a COCO-format JSON annotation file.
        transform: Optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, root_dir, ann_file, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.coco = COCO(ann_file)
        self.ids = list(self.coco.imgs.keys())
        # COCO category ids are sparse (the 80 detection classes use ids up
        # to 90), so they cannot index an 80-way class vector directly. Map
        # them to contiguous labels 0..K-1 in ascending id order.
        self.cat_id_to_label = {
            cat_id: label
            for label, cat_id in enumerate(sorted(self.coco.getCatIds()))
        }

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        """Return ``(img, boxes, labels)`` for the image at position ``idx``.

        Returns:
            img: The RGB image, passed through ``transform`` when one is set.
            boxes: Float tensor of shape ``(n, 4)`` with rows
                ``[center_x, center_y, width, height]`` divided by the
                original image width/height; shape ``(0, 4)`` when the image
                has no usable annotation.
            labels: Long tensor of shape ``(n,)`` with contiguous class
                indices (see ``cat_id_to_label``).
        """
        img_id = self.ids[idx]
        ann_ids = self.coco.getAnnIds(imgIds=img_id)
        annotations = self.coco.loadAnns(ann_ids)

        # Load image
        img_info = self.coco.loadImgs(img_id)[0]
        img_path = os.path.join(self.root_dir, img_info['file_name'])
        img = Image.open(img_path).convert('RGB')
        img_w, img_h = img_info['width'], img_info['height']

        # Get bounding boxes and labels
        boxes = []
        labels = []
        for ann in annotations:
            left, top, box_w, box_h = ann['bbox']  # [x, y, width, height]
            if box_w <= 0 or box_h <= 0:
                # Degenerate boxes have no area and cannot be regressed
                continue
            # Convert to YOLO format [center_x, center_y, width, height]
            boxes.append([
                (left + box_w / 2) / img_w,
                (top + box_h / 2) / img_h,
                box_w / img_w,
                box_h / img_h,
            ])
            labels.append(self.cat_id_to_label[ann['category_id']])

        # reshape keeps the (n, 4) layout even when there are no boxes
        boxes = torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        labels = torch.tensor(labels, dtype=torch.long)

        if self.transform:
            img = self.transform(img)

        return img, boxes, labels

def train_one_epoch(model, dataloader, optimizer, loss_fn, device):
    """Run one optimisation pass over ``dataloader`` and return the mean loss.

    Args:
        model: Network to train; called as ``model(images)`` and expected to
            return a sequence of prediction tensors (one per scale).
        dataloader: Iterable with ``__len__`` yielding ``(images, boxes,
            labels)`` batches. ``boxes`` must be a ``(N, 5)`` tensor with rows
            ``[sample_index, cx, cy, w, h]`` and ``labels`` a ``(N,)`` tensor
            holding the class index of each box.
        optimizer: Optimiser updating the parameters of ``model``.
        loss_fn: Called as ``loss_fn(predictions, targets)`` where ``targets``
            is a list with one ``(N, 6)`` tensor per prediction scale, rows
            ``[sample_index, class_index, cx, cy, w, h]``.
        device: Device the batch tensors are moved to.

    Returns:
        Mean per-batch loss as a float; ``0.0`` for an empty dataloader.

    Raises:
        ValueError: If ``boxes`` is not a ``(N, 5)`` tensor.
    """
    model.train()
    total_loss = 0.0
    num_batches = len(dataloader)

    for batch_idx, (images, boxes, labels) in enumerate(dataloader):
        images = images.to(device)
        boxes = boxes.to(device)
        labels = labels.to(device)

        if boxes.dim() != 2 or boxes.size(1) != 5:
            raise ValueError(
                'boxes must have shape (N, 5) with rows '
                f'[sample_index, cx, cy, w, h], got {tuple(boxes.shape)}'
            )

        # The loss reads each target row as
        # [sample_index, class_index, cx, cy, w, h] and indexes the targets
        # by prediction scale, so merge boxes and labels into that layout
        # and hand the same tensor to every scale.
        targets = torch.cat(
            (boxes[:, :1], labels.view(-1, 1).to(boxes.dtype), boxes[:, 1:]),
            dim=1,
        )

        optimizer.zero_grad()
        predictions = model(images)
        loss = loss_fn(predictions, [targets] * len(predictions))

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        if batch_idx % 100 == 0:
            print(f'Batch [{batch_idx}/{num_batches}], Loss: {loss.item():.4f}')

    return total_loss / num_batches if num_batches else 0.0

def evaluate(model, dataloader, loss_fn, device):
    """Return the mean loss of ``model`` over ``dataloader`` without training.

    The model is switched to eval mode and gradients are disabled.

    Args:
        model: Network to evaluate; called as ``model(images)`` and expected
            to return a sequence of prediction tensors (one per scale).
        dataloader: Iterable with ``__len__`` yielding ``(images, boxes,
            labels)`` batches. ``boxes`` must be a ``(N, 5)`` tensor with rows
            ``[sample_index, cx, cy, w, h]`` and ``labels`` a ``(N,)`` tensor
            holding the class index of each box.
        loss_fn: Called as ``loss_fn(predictions, targets)`` where ``targets``
            is a list with one ``(N, 6)`` tensor per prediction scale, rows
            ``[sample_index, class_index, cx, cy, w, h]``.
        device: Device the batch tensors are moved to.

    Returns:
        Mean per-batch loss as a float; ``0.0`` for an empty dataloader.

    Raises:
        ValueError: If ``boxes`` is not a ``(N, 5)`` tensor.
    """
    model.eval()
    total_loss = 0.0
    num_batches = len(dataloader)

    with torch.no_grad():
        for images, boxes, labels in dataloader:
            images = images.to(device)
            boxes = boxes.to(device)
            labels = labels.to(device)

            if boxes.dim() != 2 or boxes.size(1) != 5:
                raise ValueError(
                    'boxes must have shape (N, 5) with rows '
                    f'[sample_index, cx, cy, w, h], got {tuple(boxes.shape)}'
                )

            # The loss reads each target row as
            # [sample_index, class_index, cx, cy, w, h] and indexes the
            # targets by prediction scale, so merge boxes and labels into
            # that layout and hand the same tensor to every scale.
            targets = torch.cat(
                (boxes[:, :1], labels.view(-1, 1).to(boxes.dtype), boxes[:, 1:]),
                dim=1,
            )

            predictions = model(images)
            loss = loss_fn(predictions, [targets] * len(predictions))
            total_loss += loss.item()

    return total_loss / num_batches if num_batches else 0.0

def collate_fn(batch):
    """Merge dataset samples with varying box counts into one batch.

    Args:
        batch: List of ``(img, boxes, labels)`` samples, where ``img`` is a
            tensor of identical shape for every sample, ``boxes`` holds
            ``[cx, cy, w, h]`` rows (an empty tensor is accepted) and
            ``labels`` holds one class index per box.

    Returns:
        Tuple ``(images, boxes, labels)``: the stacked images, a ``(N, 5)``
        float tensor with rows ``[sample_index, cx, cy, w, h]`` covering all
        boxes of the batch, and the matching ``(N,)`` long tensor of labels.
        The sample index records which image of the batch a box belongs to.
    """
    images = []
    box_rows = []
    label_rows = []
    for sample_idx, (img, boxes, labels) in enumerate(batch):
        images.append(img)
        boxes = boxes.reshape(-1, 4).float()
        index_col = torch.full((boxes.size(0), 1), float(sample_idx))
        box_rows.append(torch.cat((index_col, boxes), dim=1))
        label_rows.append(labels.reshape(-1).long())

    return torch.stack(images), torch.cat(box_rows), torch.cat(label_rows)


def main():
    """Train YOLOv3 on COCO and keep the checkpoint with the best validation loss.

    Builds the train/val datasets and loaders, the model, the Adam optimiser
    and a ReduceLROnPlateau scheduler, then alternates training and
    validation for NUM_EPOCHS epochs. Whenever the validation loss improves,
    the model and optimiser state are written to ``best_model.pth``.
    """
    # Hyperparameters
    BATCH_SIZE = 8
    LEARNING_RATE = 1e-4
    NUM_EPOCHS = 100
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Transform
    transform = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                           std=[0.229, 0.224, 0.225])
    ])

    # Dataset and DataLoader
    train_dataset = COCODataset(
        root_dir='path/to/coco/train2017',
        ann_file='path/to/coco/annotations/instances_train2017.json',
        transform=transform
    )

    val_dataset = COCODataset(
        root_dir='path/to/coco/val2017',
        ann_file='path/to/coco/annotations/instances_val2017.json',
        transform=transform
    )

    # The samples carry a different number of boxes each, so the default
    # collation cannot stack them; collate_fn flattens them instead.
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        collate_fn=collate_fn
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
        collate_fn=collate_fn
    )

    # Model, optimizer, and loss
    model = YOLOv3(num_classes=COCO_CLASSES).to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    loss_fn = YOLOLoss(ANCHORS, COCO_CLASSES, IMAGE_SIZE)

    # Learning rate scheduler. The `verbose` flag is deprecated in recent
    # PyTorch releases, so learning-rate changes are reported below instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', patience=3
    )
    current_lr = optimizer.param_groups[0]['lr']

    # Training loop
    best_val_loss = float('inf')
    for epoch in range(NUM_EPOCHS):
        print(f'\nEpoch {epoch+1}/{NUM_EPOCHS}')

        # Train
        train_loss = train_one_epoch(model, train_loader, optimizer, loss_fn, DEVICE)
        print(f'Training Loss: {train_loss:.4f}')

        # Validate
        val_loss = evaluate(model, val_loader, loss_fn, DEVICE)
        print(f'Validation Loss: {val_loss:.4f}')

        # Learning rate scheduling
        scheduler.step(val_loss)
        new_lr = optimizer.param_groups[0]['lr']
        if new_lr != current_lr:
            print(f'Learning rate reduced to {new_lr:.2e}')
            current_lr = new_lr

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': val_loss,
            }, 'best_model.pth')


if __name__ == '__main__':
    main()