# YOLOv8

## Backbone

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Conv(nn.Module):
    """Conv2d -> BatchNorm2d -> SiLU building block.

    Args:
        c1: number of input channels.
        c2: number of output channels.
        k: kernel size.
        s: stride.
        p: explicit padding; when None, ``k // 2`` is used.
    """

    def __init__(self, c1, c2, k=3, s=1, p=None):
        super().__init__()
        padding = (k // 2) if p is None else p
        # No conv bias: the following BatchNorm supplies the affine shift.
        self.conv = nn.Conv2d(c1, c2, k, s, padding, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU()  # SiLU, also known as Swish

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return self.act(x)

class C2f(nn.Module):
    """C2f (Cross Stage Partial Fusion) block.

    ``cv1`` projects the input to ``2 * c_`` channels, which is split into two
    halves of ``c_`` channels. The last half is passed through ``n`` chained
    Bottleneck blocks, and every intermediate output is kept. All ``2 + n``
    tensors are concatenated and fused back to ``c2`` channels by ``cv2``.

    Args:
        c1: number of input channels.
        c2: number of output channels.
        n: number of chained Bottleneck blocks.
        shortcut: whether each Bottleneck uses its residual connection.
        e: hidden-channel ratio; ``c_ = int(c2 * e)``.
    """

    def __init__(self, c1, c2, n=1, shortcut=True, e=0.5):
        super().__init__()
        c_ = int(c2 * e)
        self.cv1 = Conv(c1, 2 * c_, 1, 1)
        self.cv2 = Conv((2 + n) * c_, c2, 1)
        # Forward `shortcut` so the constructor argument actually takes effect.
        self.m = nn.ModuleList([Bottleneck(c_, c_, shortcut) for _ in range(n)])

    def forward(self, x):
        y = list(self.cv1(x).chunk(2, 1))
        # Each Bottleneck consumes the most recent output; all outputs are kept.
        for m in self.m:
            y.append(m(y[-1]))
        return self.cv2(torch.cat(y, 1))

class Bottleneck(nn.Module):
    """Bottleneck block (used inside C2f): 1x1 Conv -> 3x3 Conv, optional residual.

    Args:
        c1: number of input channels.
        c2: number of output channels.
        shortcut: add the input to the output when shapes allow it.
    """

    def __init__(self, c1, c2, shortcut=True):
        super().__init__()
        self.cv1 = Conv(c1, c2, 1, 1)
        self.cv2 = Conv(c2, c2, 3, 1)
        # `out + x` is only shape-valid when input and output channels match;
        # otherwise the residual would raise at forward time.
        self.add = shortcut and c1 == c2

    def forward(self, x):
        out = self.cv2(self.cv1(x))
        return out + x if self.add else out

class SPPF(nn.Module):
    """SPPF (Spatial Pyramid Pooling - Fast).

    ``cv1`` maps ``c1 -> c2`` channels. One stride-1 max-pool (kernel ``k``,
    padding ``k // 2``) is then applied three times in succession. The input
    and the three pooled maps are concatenated (``4 * c2`` channels) and fused
    back to ``c2`` channels by ``cv2``.

    Args:
        c1: number of input channels.
        c2: number of output (and hidden) channels.
        k: max-pool kernel size.
    """

    def __init__(self, c1, c2, k=5):
        super().__init__()
        pad = k // 2
        self.cv1 = Conv(c1, c2, 1, 1)
        self.cv2 = Conv(c2 * 4, c2, 1, 1)
        self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=pad)

    def forward(self, x):
        pyramid = [self.cv1(x)]
        for _ in range(3):
            pyramid.append(self.m(pyramid[-1]))
        return self.cv2(torch.cat(pyramid, 1))

class YOLOv8Backbone(nn.Module):
    """Minimal example YOLOv8-style backbone.

    Pipeline: stem Conv (3 -> 32, stride 2) -> C2f(32 -> 64) -> C2f(64 -> 128)
    -> C2f(128 -> 256) -> SPPF(256 -> 256).

    Only the stem has stride 2; the C2f stages and SPPF keep the spatial size.
    The output is therefore a single 256-channel map at roughly half the input
    resolution.

    NOTE(review): this returns one tensor, whereas ``Neck`` expects a list
    ``[P3, P4, P5]`` of three scales — confirm how the two are meant to connect.
    """

    def __init__(self):
        super().__init__()
        self.stem = Conv(3, 32, 3, 2)
        self.stage1 = C2f(32, 64, n=1)
        self.stage2 = C2f(64, 128, n=2)
        self.stage3 = C2f(128, 256, n=3)
        self.sppf = SPPF(256, 256)

    def forward(self, x):
        for layer in (self.stem, self.stage1, self.stage2, self.stage3, self.sppf):
            x = layer(x)
        return x


## Neck

In [None]:
class Neck(nn.Module):
    """FPN (top-down) + PAN (bottom-up) feature-fusion neck.

    Args:
        channels: channel counts ``(c3, c4, c5)`` of the input feature maps
            ``[P3, P4, P5]``; only the first three entries are used. A tuple
            default avoids the shared-mutable-default pitfall. Any indexable
            sequence (list or tuple) works.

    Output channels are ``(c3, c4, c5)`` for ``[N3, N4, N5]``.
    """

    def __init__(self, channels=(256, 512, 1024)):  # example input channel counts
        super().__init__()
        c3, c4, c5 = channels[0], channels[1], channels[2]

        # Top-down (FPN) path
        self.reduce_conv1 = nn.Conv2d(c5, c4, 1)           # c5 -> c4
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.c2f1 = C2f(c4 * 2, c4, n=3)                   # cat(c4 + c4) -> c4

        self.reduce_conv2 = nn.Conv2d(c4, c3, 1)           # c4 -> c3
        self.c2f2 = C2f(c3 * 2, c3, n=3)                   # cat(c3 + c3) -> c3

        # Bottom-up (PAN) path
        self.downsample1 = nn.Conv2d(c3, c3, 3, stride=2, padding=1)
        # Concatenated with the c2f1 output (c4 channels), so the input is c3 + c4.
        self.c2f3 = C2f(c3 + c4, c4, n=3)                  # cat(c3 + c4) -> c4

        self.downsample2 = nn.Conv2d(c4, c4, 3, stride=2, padding=1)
        # Concatenated with the raw P5 input (c5 channels), so the input is c4 + c5.
        self.c2f4 = C2f(c4 + c5, c5, n=3)                  # cat(c4 + c5) -> c5

    def forward(self, features):
        """Fuse ``[P3, P4, P5]`` and return ``[y, z, w]`` (channels c3, c4, c5).

        The 2x upsampling and stride-2 convolutions only line up for
        concatenation when each level is half the spatial size of the previous
        one (P4 = P3 / 2, P5 = P4 / 2). TODO: confirm against the backbone.
        """
        P3, P4, P5 = features

        # FPN top-down
        x = self.reduce_conv1(P5)             # c5 -> c4
        x = self.upsample(x)
        x = torch.cat([x, P4], dim=1)         # c4 + c4
        x = self.c2f1(x)                      # c4, P4 resolution

        y = self.reduce_conv2(x)              # c4 -> c3
        y = self.upsample(y)
        y = torch.cat([y, P3], dim=1)         # c3 + c3
        y = self.c2f2(y)                      # c3, highest-resolution output

        # PAN bottom-up
        z = self.downsample1(y)               # c3, P4 resolution
        z = torch.cat([z, x], dim=1)          # c3 + c4
        z = self.c2f3(z)                      # c4

        w = self.downsample2(z)               # c4, P5 resolution
        w = torch.cat([w, P5], dim=1)         # c4 + c5
        w = self.c2f4(w)                      # c5

        # Multi-scale outputs consumed by the head.
        return [y, z, w]

## Head

In [None]:
class DetectPoseHead(nn.Module):
    """Per-scale 1x1-conv prediction head for joint detection and pose.

    For each input scale, a separate ``nn.Conv2d(c, out_dim, 1)`` produces
    ``out_dim`` channels laid out as:

        [0:4]        box (x, y, w, h)
        [4]          objectness logit     (only when ``with_objectness``)
        [.. C ..]    class logits
        [.. 3K ..]   keypoints, K groups of (x, y, visibility logit)

    With ``with_objectness=True`` (the default) this matches the layout that
    ``yolo_pose_loss`` indexes: objectness at 4, classes at ``5:5+C``, and
    keypoints from ``5+C``. Pass ``with_objectness=False`` for the legacy
    ``4 + C + 3K`` layout.

    Args:
        num_classes: number of classes C.
        num_keypoints: number of keypoints K.
        ch: channel count of each feature map coming from the neck.
        with_objectness: whether to emit the objectness channel.
    """

    def __init__(self, num_classes, num_keypoints, ch, with_objectness=True):
        super().__init__()
        self.num_classes = num_classes
        self.num_keypoints = num_keypoints
        self.with_objectness = with_objectness
        self.out_dim = 4 + int(with_objectness) + num_classes + num_keypoints * 3
        # A separate head for each scale (feature map).
        self.detect_layers = nn.ModuleList([
            nn.Conv2d(c, self.out_dim, 1) for c in ch
        ])

    def forward(self, features):
        """Apply each scale's head to its feature map.

        Args:
            features: neck outputs, one tensor per scale (e.g. ``[P3, P4, P5]``).

        Returns:
            List of raw tensors shaped ``[B, out_dim, H, W]``, one per scale.
            Callers presumably permute/flatten these to ``[B, N, out_dim]``
            before ``yolo_pose_loss``.
        """
        return [head(x) for x, head in zip(features, self.detect_layers)]


## Loss

In [None]:
import torch
import torch.nn.functional as F

def yolo_pose_loss(
    pred,              # [B, N, 5+C+K*3] (already flattened over cells/anchors)
    target,            # GT dict with keys: 'boxes', 'classes', 'objectness', 'keypoints', 'kp_vis'
    num_classes,
    num_keypoints,
    box_weight=7.5,    # box loss weight
    cls_weight=0.5,    # class loss weight
    obj_weight=1.0,    # objectness loss weight
    kpt_weight=1.5,    # keypoint loss weight
    device="cuda"
):
    """Compute the combined detection + pose loss.

    Last-dim layout of ``pred`` as indexed below:
        [0:4]           box (x, y, w, h)
        [4]             objectness logit
        [5:5+C]         class logits
        [5+C:5+C+3K]    keypoints, reshaped to (K, 3) = (x, y, visibility logit)

    Args:
        pred: tensor ``[B, N, 5 + C + K*3]``.
        target: dict of tensors
            'boxes':      [B, N, 4]    (x, y, w, h), normalized
            'classes':    [B, N]       class index; values at non-object
                                       anchors are ignored (e.g. -1 is fine)
            'objectness': [B, N]       1 if object, else 0
            'keypoints':  [B, N, K, 2] GT (x, y) per keypoint, normalized
            'kp_vis':     [B, N, K]    1 if visible, 0 if not
        num_classes: number of classes C.
        num_keypoints: number of keypoints K.
        box_weight, cls_weight, obj_weight, kpt_weight: term weights.
        device: device the target tensors are moved to; must match ``pred``.

    Returns:
        ``(total_loss, parts)``: ``total_loss`` is a scalar tensor that keeps
        the autograd graph. ``parts`` maps 'box_loss', 'obj_loss',
        'class_loss' and 'kpt_loss' to Python floats.
    """
    eps = 1e-8

    # 1. Objectness loss (BCE over every anchor)
    obj_pred = pred[..., 4]
    obj_gt = target['objectness'].float().to(device)
    obj_loss = F.binary_cross_entropy_with_logits(obj_pred, obj_gt, reduction='mean')
    num_pos = obj_gt.sum()

    # 2. Box regression loss (SmoothL1 for simplicity; CIoU/GIoU would be typical)
    box_pred = pred[..., :4]
    box_gt = target['boxes'].to(device)
    box_loss = F.smooth_l1_loss(box_pred, box_gt, reduction='none').mean(-1)  # [B, N]
    box_loss = (box_loss * obj_gt).sum() / (num_pos + eps)                     # positives only

    # 3. Classification loss (per-class BCE)
    class_pred = pred[..., 5:5 + num_classes]                                  # [B, N, C]
    classes = target['classes'].long().to(device)
    # Background anchors often carry placeholder labels such as -1, which
    # F.one_hot rejects. They are masked by obj_gt anyway, so map them to 0.
    classes = torch.where(obj_gt > 0, classes, torch.zeros_like(classes))
    class_gt = F.one_hot(classes, num_classes).float()                         # [B, N, C]
    class_loss = F.binary_cross_entropy_with_logits(class_pred, class_gt, reduction='none')
    class_loss = (class_loss.mean(-1) * obj_gt).sum() / (num_pos + eps)        # positives only

    # 4. Keypoint loss (SmoothL1 for (x, y), BCE for visibility)
    start = 5 + num_classes
    kpt_pred = pred[..., start:].reshape(*pred.shape[:-1], num_keypoints, 3)   # [B, N, K, 3]
    kpt_gt = target['keypoints'].to(device)                                    # [B, N, K, 2]
    kpt_vis = target['kp_vis'].float().to(device)                              # [B, N, K]
    pos_mask = obj_gt.unsqueeze(-1)                                            # [B, N, 1]

    # (x, y) loss: only visible keypoints of positive anchors are supervised.
    xy_loss = F.smooth_l1_loss(kpt_pred[..., :2], kpt_gt, reduction='none').sum(-1)  # [B, N, K]
    vis_pos = kpt_vis * pos_mask
    xy_loss = (xy_loss * vis_pos).sum() / (vis_pos.sum() + eps)

    # Visibility loss: every keypoint of every positive anchor is supervised,
    # visible or not, so normalize by that count rather than by the visible
    # count. The latter inflates the loss and explodes when nothing is visible.
    conf_loss = F.binary_cross_entropy_with_logits(kpt_pred[..., 2], kpt_vis, reduction='none')  # [B, N, K]
    conf_loss = (conf_loss * pos_mask).sum() / (num_pos * num_keypoints + eps)

    # 0.5 weight on the visibility term (the original comment attributes it to
    # the official code); adjustable.
    kpt_loss = xy_loss + 0.5 * conf_loss

    # Weighted total
    total_loss = (
        box_weight * box_loss +
        cls_weight * class_loss +
        obj_weight * obj_loss +
        kpt_weight * kpt_loss
    )

    # Individual loss values are also returned, for logging.
    return total_loss, {
        "box_loss": box_loss.item(),
        "obj_loss": obj_loss.item(),
        "class_loss": class_loss.item(),
        "kpt_loss": kpt_loss.item()
    }
