# In [None]:
import torch
import torch.nn as nn
import torchvision
from torchvision import models
from torchvision.ops import RoIAlign
import torch.nn.functional as F

class ResNetBackbone(nn.Module):
    """ResNet-18 trunk that returns a spatial feature map.

    The network is built with ``models.resnet18`` and its last two child
    modules (average pooling and the fully connected classifier) are
    dropped, so ``forward`` yields convolutional features instead of
    classification logits.

    Args:
        pretrained: Forwarded unchanged to ``models.resnet18``.
    """

    def __init__(self, pretrained=False):
        super().__init__()
        trunk = models.resnet18(pretrained=pretrained)
        # Keep every child module except the trailing avgpool & fc.
        kept_layers = list(trunk.children())[:-2]
        self.features = nn.Sequential(*kept_layers)
        # Channel count of the feature map handed to downstream heads.
        self.out_channels = 512

    def forward(self, x):
        """Run the trunk on ``x`` and return the feature map (B, 512, G, G)."""
        feature_map = self.features(x)
        return feature_map

class AnchorGridDetector(nn.Module):
    """Anchor-based grid detection head.

    A 1x1 convolution predicts, for every feature-map cell and every anchor,
    five raw values ``(tx, ty, tw, th, obj_conf)``.

    Args:
        in_channels: Channel count of the incoming feature map.
        grid_size: Nominal grid size, stored as ``self.G``. ``forward`` takes
            the actual grid dimensions from the feature map it receives.
        anchors: Sequence of ``(w, h)`` pairs normalized to [0, 1] relative
            to the image size, e.g. ``[(0.3, 0.3), (0.6, 0.6), (1.0, 1.0)]``.
            When ``None``, ``DEFAULT_ANCHORS`` is used.

    Raises:
        ValueError: If ``anchors`` is not a non-empty (A, 2) collection.
    """

    # Fallback anchors (the example values from the original docstring).
    DEFAULT_ANCHORS = ((0.3, 0.3), (0.6, 0.6), (1.0, 1.0))

    def __init__(self, in_channels, grid_size=7, anchors=None):
        super().__init__()
        if anchors is None:
            anchors = self.DEFAULT_ANCHORS
        anchor_tensor = torch.as_tensor(anchors, dtype=torch.float32).clone()
        if anchor_tensor.dim() != 2 or anchor_tensor.size(1) != 2 or anchor_tensor.size(0) == 0:
            raise ValueError("anchors must be a non-empty sequence of (w, h) pairs")

        self.G = grid_size
        # Registered as a buffer so it follows the module across devices.
        # Non-persistent keeps the state_dict keys unchanged.
        self.register_buffer("anchors", anchor_tensor, persistent=False)  # (A, 2)
        self.A = anchor_tensor.size(0)
        # A x 5 channels: tx, ty, tw, th, obj_conf
        self.head = nn.Conv2d(in_channels, self.A * 5, kernel_size=1)

    def forward(self, fmap):
        """Predict raw detection values and their decoded form.

        Args:
            fmap: Feature map of shape (B, C, Gh, Gw).

        Returns:
            ``(p, (x_abs, y_abs, w_abs, h_abs, conf))`` where ``p`` is the raw
            prediction of shape (B, Gh, Gw, A, 5), the four box tensors have
            shape (B, Gh, Gw, A, 1) in image-normalized units, and ``conf``
            has shape (B, Gh, Gw, A).
        """
        B, _, Gh, Gw = fmap.shape

        # 1) raw predictions
        p = self.head(fmap)                       # (B, A*5, Gh, Gw)
        p = p.view(B, self.A, 5, Gh, Gw)          # (B, A, 5, Gh, Gw)
        p = p.permute(0, 3, 4, 1, 2)              # (B, Gh, Gw, A, 5)

        # 2) decode (only needed at inference time)
        anchors = self.anchors.to(device=p.device, dtype=p.dtype)
        conf = torch.sigmoid(p[..., 4])           # (B, Gh, Gw, A)
        xy = torch.sigmoid(p[..., 0:2])           # offsets inside the cell
        # Anchors are already image-normalized, so exp(t) * anchor is the
        # normalized box size; no extra per-cell scaling is applied.
        wh = torch.exp(p[..., 2:4]) * anchors.view(1, 1, 1, self.A, 2)

        # Grid cell coordinates. The trailing singleton dims line up with the
        # (B, Gh, Gw, A, 1) slices of `xy` for broadcasting.
        rows = torch.arange(Gh, device=p.device, dtype=p.dtype)
        cols = torch.arange(Gw, device=p.device, dtype=p.dtype)
        gy, gx = torch.meshgrid(rows, cols, indexing='ij')
        gx = gx.view(1, Gh, Gw, 1, 1)
        gy = gy.view(1, Gh, Gw, 1, 1)

        x_abs = (gx + xy[..., 0:1]) / Gw
        y_abs = (gy + xy[..., 1:2]) / Gh
        w_abs = wh[..., 0:1]
        h_abs = wh[..., 1:2]

        # Use the decoded values to build box lists at inference time.
        return p, (x_abs, y_abs, w_abs, h_abs, conf)
    

class PoseHead(nn.Module):
    """Per-object keypoint regressor built on RoIAlign features.

    Each box is pooled from the feature map to a fixed
    ``pool_size x pool_size`` grid, flattened, and passed through a two-layer
    MLP that outputs ``num_keypoints * 2`` values per box.

    Args:
        in_channels: Channel count of the feature map.
        num_keypoints: Number of keypoints predicted per object.
        pool_size: Side length of the RoIAlign output grid.
    """

    def __init__(self, in_channels, num_keypoints=4, pool_size=7):
        super().__init__()
        # spatial_scale stays 1.0 because forward() converts the normalized
        # boxes to feature-map units itself before calling RoIAlign.
        self.roi_align = RoIAlign((pool_size, pool_size),
                                  spatial_scale=1.0,
                                  sampling_ratio=-1)
        self.flatten = nn.Flatten()
        self.fc = nn.Sequential(
            nn.Linear(in_channels * pool_size * pool_size, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, num_keypoints * 2)
        )

    def forward(self, fmap, boxes):
        """Regress keypoints for every box.

        Args:
            fmap: Feature map of shape (B, C, H, W).
            boxes: List of B tensors, each (N_i, 4), holding normalized
                ``[x1, y1, x2, y2]`` coordinates in [0, 1].

        Returns:
            Tensor of shape (sum_N, K*2). Rows follow the order of ``boxes``;
            images without boxes contribute no rows.
        """
        feat_h, feat_w = fmap.shape[-2:]
        # RoIAlign reads box coordinates in feature-map units (after
        # spatial_scale), so normalized boxes are scaled by the map size.
        # Left unscaled, every ROI would cover less than one feature cell.
        to_fmap = fmap.new_tensor([feat_w, feat_h, feat_w, feat_h])

        # Concatenate all boxes, prefixing each row with its batch index.
        rois = []
        for b_idx, b in enumerate(boxes):
            if b.numel() == 0:
                continue
            b = b.to(device=fmap.device, dtype=fmap.dtype).reshape(-1, 4)
            idx = b.new_full((b.size(0), 1), float(b_idx))
            rois.append(torch.cat([idx, b * to_fmap], dim=1))
        if not rois:
            return fmap.new_empty((0, self.fc[-1].out_features))

        rois = torch.cat(rois, dim=0)         # (sum_N, 5)
        aligned = self.roi_align(fmap, rois)  # (sum_N, C, pool, pool)
        flat = self.flatten(aligned)          # (sum_N, C*pool*pool)
        coords = self.fc(flat)                # (sum_N, K*2)
        return coords

class MultiObjectKeypointNet(nn.Module):
    """Backbone + grid detector + per-object keypoint head.

    Args:
        num_keypoints: Keypoints predicted per detected object.
        grid_size: Nominal detector grid size (kept as ``self.grid_size``).
            The real grid dimensions are taken from the detector output.
        pretrained_backbone: Forwarded to ``ResNetBackbone``.
        anchors: Optional sequence of image-normalized ``(w, h)`` anchors.
            Defaults to ``DEFAULT_ANCHORS``.
        conf_threshold: Objectness threshold used by ``forward`` to keep a
            predicted box.
    """

    DEFAULT_ANCHORS = ((0.3, 0.3), (0.6, 0.6), (1.0, 1.0))

    def __init__(self, num_keypoints=4, grid_size=7, pretrained_backbone=False,
                 anchors=None, conf_threshold=0.5):
        super().__init__()
        if anchors is None:
            anchors = self.DEFAULT_ANCHORS
        self.backbone = ResNetBackbone(pretrained=pretrained_backbone)
        # Anchors are passed explicitly: the detector cannot be built from None.
        self.detector = AnchorGridDetector(
            self.backbone.out_channels, grid_size,
            anchors=[tuple(a) for a in anchors])
        self.pose_head = PoseHead(self.backbone.out_channels, num_keypoints)
        self.grid_size = grid_size
        self.conf_threshold = conf_threshold

    def _decode_boxes(self, p):
        """Turn raw predictions (B, Gh, Gw, A, 5) into per-image corner boxes.

        Returns a list of B tensors of shape (N_i, 4) holding normalized
        ``[x1, y1, x2, y2]`` clamped to [0, 1], keeping only anchors whose
        objectness exceeds ``self.conf_threshold``. No NMS is applied.
        """
        B, Gh, Gw, A, _ = p.shape
        anchors = self.detector.anchors.to(device=p.device, dtype=p.dtype)

        conf = torch.sigmoid(p[..., 4])                             # (B,Gh,Gw,A)
        xy = torch.sigmoid(p[..., 0:2])
        # Inverse of the log(w / anchor) encoding used in training_forward.
        wh = torch.exp(p[..., 2:4]) * anchors.view(1, 1, 1, A, 2)

        cols = torch.arange(Gw, device=p.device, dtype=p.dtype).view(1, 1, Gw, 1)
        rows = torch.arange(Gh, device=p.device, dtype=p.dtype).view(1, Gh, 1, 1)
        cx = (cols + xy[..., 0]) / Gw
        cy = (rows + xy[..., 1]) / Gh
        half_w = wh[..., 0] / 2
        half_h = wh[..., 1] / 2

        corners = torch.stack(
            [cx - half_w, cy - half_h, cx + half_w, cy + half_h], dim=-1
        ).clamp(0.0, 1.0)                                           # (B,Gh,Gw,A,4)
        keep = conf > self.conf_threshold
        return [corners[b][keep[b]] for b in range(B)]

    @staticmethod
    def _xywh_to_corners(xywh):
        """Convert (N, 4) ``[xc, yc, w, h]`` rows to ``[x1, y1, x2, y2]``."""
        half = xywh[:, 2:4] / 2
        return torch.cat([xywh[:, 0:2] - half, xywh[:, 0:2] + half], dim=1)

    def forward(self, x):
        """Run detection and keypoint regression for inference.

        Args:
            x: Image batch of shape (B, 3, H, W).

        Returns:
            ``(boxes_batch, keypoints_batch)``: two lists of length B.
            ``boxes_batch[i]`` is (N_i, 4) normalized ``[x1, y1, x2, y2]`` and
            ``keypoints_batch[i]`` is (N_i, K*2). Images without detections
            get empty tensors.
        """
        fmap = self.backbone(x)
        p, _ = self.detector(fmap)
        boxes_batch = self._decode_boxes(p)
        coords = self.pose_head(fmap, boxes_batch)  # (sum_N, K*2)

        # Split the flat keypoint tensor back into one chunk per image.
        counts = [boxes.size(0) for boxes in boxes_batch]
        keypoints_batch = list(torch.split(coords, counts, dim=0))
        return boxes_batch, keypoints_batch

    def training_forward(self, x, det_labels, pose_labels):
        """Compute predictions and training targets for one batch.

        Args:
            x: Image batch of shape (B, 3, H, W).
            det_labels: List of B tensors, each (N_i, 5) as
                ``[conf, xc, yc, w, h]`` in normalized image coordinates.
                Empty tensors are accepted.
            pose_labels: List of B tensors, each (N_i, K*2).

        Returns:
            ``(p, det_target, ignore_mask, pose_pred, pose_target)``:
              * ``p``: raw detector output (B, Gh, Gw, A, 5), channels
                ``tx, ty, tw, th, obj``.
              * ``det_target``: same shape as ``p``. ``tx``/``ty`` targets are
                in-cell offsets in [0, 1], ``tw``/``th`` are
                ``log(size / anchor)``, and ``obj`` is 1 at the best anchor.
              * ``ignore_mask``: bool (B, Gh, Gw, A); True marks anchors that
                overlap a GT shape by more than 0.5 and should be left out of
                the no-object loss. The positive anchor may also be flagged,
                so callers must treat positives separately.
              * ``pose_pred`` / ``pose_target``: (sum_N, K*2) keypoint
                predictions on the GT boxes and their labels.
        """
        B = x.size(0)
        fmap = self.backbone(x)

        # 1) Raw detection output. Grid dimensions come from the tensor
        #    itself so targets always match the prediction shape.
        p, _ = self.detector(fmap)                   # (B,Gh,Gw,A,5)
        _, Gh, Gw, A, _ = p.shape
        device = p.device
        anchor_wh = self.detector.anchors.to(device=device, dtype=p.dtype)  # (A,2)

        # 2) Build targets.
        det_target = torch.zeros_like(p)
        ignore_mask = torch.zeros((B, Gh, Gw, A), device=device, dtype=torch.bool)
        gt_corner_boxes = []

        for b in range(B):
            # reshape(-1, 5) turns an empty (0,) label tensor into (0, 5).
            labels = det_labels[b].to(device=device, dtype=p.dtype).reshape(-1, 5)
            xywh = labels[:, 1:5]                    # (N_i,4)
            # 3) GT boxes in corner form for the pose head.
            gt_corner_boxes.append(self._xywh_to_corners(xywh))

            for gt in xywh:
                x_c, y_c = gt[0].item(), gt[1].item()
                # Clamp so a centre exactly on the right/bottom border
                # (coordinate == 1.0) still maps to a valid cell.
                gj = min(max(int(x_c * Gw), 0), Gw - 1)
                gi = min(max(int(y_c * Gh), 0), Gh - 1)

                # Shape-only IoU between the GT box and each anchor.
                gt_wh = gt[2:4].unsqueeze(0)                       # (1,2)
                min_wh = torch.min(anchor_wh, gt_wh)
                max_wh = torch.max(anchor_wh, gt_wh)
                ious_wh = min_wh.prod(dim=1) / max_wh.prod(dim=1).clamp_min(1e-12)
                best_a = int(torch.argmax(ious_wh))

                # objectness target
                det_target[b, gi, gj, best_a, 4] = 1.0

                # tx, ty: offset of the centre inside its cell
                det_target[b, gi, gj, best_a, 0] = min(max(x_c * Gw - gj, 0.0), 1.0)
                det_target[b, gi, gj, best_a, 1] = min(max(y_c * Gh - gi, 0.0), 1.0)

                # tw, th: log-space ratio to the matched anchor
                det_target[b, gi, gj, best_a, 2:4] = torch.log(
                    gt[2:4] / anchor_wh[best_a] + 1e-6)

                # Accumulate rather than overwrite, so several GT boxes
                # landing in the same cell all contribute.
                ignore_mask[b, gi, gj] = ignore_mask[b, gi, gj] | (ious_wh > 0.5)

        # 4) Keypoint predictions on the GT boxes.
        pose_pred = self.pose_head(fmap, gt_corner_boxes)  # (sum_N, K*2)
        # Images without objects contribute no rows to pose_pred, so their
        # empty label tensors are dropped to keep the two aligned.
        non_empty = [t.to(device) for t in pose_labels if t.numel() > 0]
        pose_target = torch.cat(non_empty, dim=0) if non_empty else torch.empty_like(pose_pred)

        return p, det_target, ignore_mask, pose_pred, pose_target


# In [None]:
import os
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torchvision import transforms
from torch.utils.data import DataLoader

# (기존 클래스 import 후)

# Custom Dataset 정의
class CustomDataset(torch.utils.data.Dataset):
    """Image / label dataset with one text label file per image.

    Expected layout: ``<data_path>/images/<name>.<ext>`` and
    ``<data_path>/labels/<name>.txt``. Each label line is whitespace
    separated; fields 1-4 are the box values, everything from field 5 on is
    the keypoint vector, and field 0 is ignored. Lines with fewer than five
    fields are skipped.

    Args:
        data_path: Directory containing the ``images`` and ``labels`` folders.
        transform: Optional callable applied to the loaded RGB image.
    """

    def __init__(self, data_path, transform=None):
        self.img_dir = os.path.join(data_path, 'images')
        self.label_dir = os.path.join(data_path, 'labels')
        self.img_files = sorted(os.listdir(self.img_dir))
        self.transform = transform

    def __len__(self):
        return len(self.img_files)

    @staticmethod
    def _label_name(img_file):
        """Return the label file name for an image file of any extension."""
        return os.path.splitext(img_file)[0] + '.txt'

    @staticmethod
    def _read_labels(label_path):
        """Parse one label file into ``(det_labels, pose_labels)`` tensors."""
        det_rows = []
        pose_rows = []
        with open(label_path, 'r') as f:
            for line in f:
                parts = line.split()
                if len(parts) >= 5:
                    # objectness = 1.0 followed by the four box values
                    det_rows.append([1.0] + [float(v) for v in parts[1:5]])
                    pose_rows.append([float(v) for v in parts[5:]])

        # reshape keeps a well-formed (0, 5) tensor when there are no objects.
        det_labels = torch.tensor(det_rows, dtype=torch.float32).reshape(-1, 5)
        pose_labels = torch.tensor(pose_rows, dtype=torch.float32)  # (num_objs, num_kps*2)
        return det_labels, pose_labels

    def __getitem__(self, idx):
        """Return ``(img, det_labels, pose_labels)`` for sample ``idx``.

        ``det_labels`` is (num_objs, 5) as ``[1.0, f1, f2, f3, f4]``.
        ``pose_labels`` is (num_objs, num_kps*2), or an empty tensor when the
        file has no valid lines.

        Raises:
            OSError: If the image cannot be read.
        """
        img_file = self.img_files[idx]

        # Load the image.
        img_path = os.path.join(self.img_dir, img_file)
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise OSError(f"cv2.imread failed for {img_path}")
        # OpenCV decodes to BGR. Convert to RGB so the channel order matches
        # the RGB-ordered mean/std passed to Normalize and the RGB frames
        # used at inference.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if self.transform is not None:
            img = self.transform(img)

        # Load the labels for all objects in the image.
        label_path = os.path.join(self.label_dir, self._label_name(img_file))
        det_labels, pose_labels = self._read_labels(label_path)

        return img, det_labels, pose_labels


    
# Root of the training split; CustomDataset joins 'images' and 'labels' onto it.
base = '/home/otter/dataset/pallet/dataset/train'

# Training-time preprocessing: convert to a tensor, resize to 360x360, then
# normalize with per-channel mean/std (the ImageNet statistics).
# NOTE(review): the inference cell applies Resize before ToTensor, while here
# the resize runs on the tensor — confirm the two pipelines produce matching
# inputs.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((360, 360)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
def custom_collate_fn(batch):
    """Collate ``(image, det_labels, pose_labels)`` samples into a batch.

    Images share one size and are stacked along a new leading dimension.
    The per-image label tensors vary in length, so they are returned as
    plain lists in sample order.
    """
    image_list = []
    det_list = []
    pose_list = []
    for image, det, pose in batch:
        image_list.append(image)
        det_list.append(det)
        pose_list.append(pose)

    stacked_images = torch.stack(image_list, dim=0)
    return stacked_images, det_list, pose_list
dataset = CustomDataset(base, transform)
loader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=custom_collate_fn)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MultiObjectKeypointNet(num_keypoints=4, grid_size=7, pretrained_backbone=False)
model.to(device)

# Detection: objectness uses BCEWithLogits, box regression uses MSE.
det_loss_fn_conf = nn.BCEWithLogitsLoss()
det_loss_fn_bbox = nn.MSELoss()
pose_loss_fn     = nn.MSELoss()

optimizer = optim.Adam(model.parameters(), lr=1e-4)

# --- Training loop ---
epochs = 20
for epoch in range(epochs):
    model.train()
    total_det_loss  = 0.0
    total_pose_loss = 0.0

    for imgs, det_labels, pose_labels in tqdm(loader, desc=f"Epoch {epoch+1}/{epochs}"):
        # imgs: Tensor(B,3,H,W)
        # det_labels: List of Tensor(N_i,5) as [conf, xc, yc, w, h]
        # pose_labels: List of Tensor(N_i,K*2)
        imgs = imgs.to(device)
        det_labels  = [d.to(device) for d in det_labels]
        pose_labels = [p.to(device) for p in pose_labels]

        optimizer.zero_grad()
        # training_forward returns five values; ignore_mask is the third.
        det_out, det_target, ignore_mask, pose_out, pose_target = \
            model.training_forward(imgs, det_labels, pose_labels)

        # Detection loss. Channel layout is (tx, ty, tw, th, obj), so the
        # objectness logit lives at index 4 and the box terms at 0:4.
        obj_logit  = det_out[..., 4]
        obj_target = det_target[..., 4]
        pos_mask   = obj_target > 0.5
        # Objectness is supervised on positives and on negatives that are not
        # flagged as "close to a GT shape" by ignore_mask.
        conf_mask  = pos_mask | ~ignore_mask
        if conf_mask.any():
            loss_conf = det_loss_fn_conf(obj_logit[conf_mask], obj_target[conf_mask])
        else:
            loss_conf = det_out.sum() * 0.0

        if pos_mask.any():
            # Box regression only where an object is assigned. tx/ty targets
            # are in-cell offsets in [0, 1], so the logits go through sigmoid;
            # tw/th are regressed directly in log space.
            xy_pred   = torch.sigmoid(det_out[..., 0:2][pos_mask])
            wh_pred   = det_out[..., 2:4][pos_mask]
            xy_target = det_target[..., 0:2][pos_mask]
            wh_target = det_target[..., 2:4][pos_mask]
            loss_bbox = det_loss_fn_bbox(xy_pred, xy_target) \
                + det_loss_fn_bbox(wh_pred, wh_target)
        else:
            # Zero that stays attached to the graph for batches without objects.
            loss_bbox = det_out.sum() * 0.0
        det_loss = loss_conf + loss_bbox

        # Pose loss. MSE over an empty tensor is NaN, so object-free batches
        # are skipped.
        if pose_out.numel() > 0:
            pose_loss = pose_loss_fn(pose_out, pose_target)
        else:
            pose_loss = det_out.sum() * 0.0

        # Backprop & step
        loss = det_loss + pose_loss
        loss.backward()
        optimizer.step()

        total_det_loss  += det_loss.item()
        total_pose_loss += pose_loss.item()

    # Guard against an empty loader to avoid dividing by zero.
    num_batches = max(len(loader), 1)
    avg_det  = total_det_loss  / num_batches
    avg_pose = total_pose_loss / num_batches
    print(f"[Epoch {epoch+1}] det_loss: {avg_det:.4f} | pose_loss: {avg_pose:.4f}")
print("학습 완료!")


# In [None]:
# Save the whole model object (not just its state_dict); the inference cell
# reloads it with torch.load. The target directory is created first because
# torch.save does not create missing parent directories.
os.makedirs('checkpoints', exist_ok=True)
torch.save(model, 'checkpoints/resnet_based_pe.pt')


# In [None]:
import os
import cv2
import torch
from torchvision import transforms
from PIL import Image

# Device selection: GPU when CUDA is available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- 1) Load the model ---
# The checkpoint is a pickled model object, so the MultiObjectKeypointNet
# class must be importable here, e.g.:
# from your_model_file import MultiObjectKeypointNet
# NOTE(review): recent PyTorch releases default torch.load to
# weights_only=True, which rejects whole-model pickles — confirm against the
# installed version and pass weights_only=False if required.
model = torch.load('checkpoints/resnet_based_pe.pt', map_location=device)
model.to(device).eval()  # switch to evaluation mode

# --- 2) Input preprocessing ---
transform = transforms.Compose([
    transforms.Resize((360, 360)),             # resize to the model input size
    transforms.ToTensor(),                     # to tensor (CxHxW, [0,1])
    transforms.Normalize(mean=[0.485,0.456,0.406],
                         std =[0.229,0.224,0.225]),  # ImageNet normalization
])

# --- 3) Iterate over the .mp4 files in the video directory ---
video_dir = '/home/otter/workspace/Pallet/train_video'
video_files = sorted([
    os.path.join(video_dir, f)
    for f in os.listdir(video_dir)
    if f.lower().endswith('.mp4')
])

# Explicit flag instead of re-reading `key` after the inner loop: `key` is
# never assigned when a video yields no frames, which raised a NameError.
quit_requested = False

for video_path in video_files:
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"[경고] 비디오 열기 실패: {video_path}")
        continue

    while True:
        ret, frame_bgr = cap.read()
        if not ret:
            break

        H, W = frame_bgr.shape[:2]

        # OpenCV BGR -> RGB -> PIL -> transform -> add batch dimension
        frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
        pil_img    = Image.fromarray(frame_rgb)
        inp        = transform(pil_img).unsqueeze(0).to(device)  # (1,3,360,360)

        # Inference
        with torch.no_grad():
            boxes_batch, kps_batch = model(inp)

        # Batch size is 1, so only the first result is used.
        boxes = boxes_batch[0]  # Tensor(num_objs, 4): normalized [x1,y1,x2,y2]
        kps   = kps_batch[0]    # Tensor(num_objs, K*2): treated below as box-relative coords

        # Map each object's box and keypoints to the original resolution and draw.
        for obj_idx in range(boxes.shape[0]):
            x1n, y1n, x2n, y2n = boxes[obj_idx].tolist()
            # normalized -> pixel coordinates
            x1, y1 = int(x1n * W), int(y1n * H)
            x2, y2 = int(x2n * W), int(y2n * H)

            # keypoints: box-relative -> absolute pixel coordinates
            obj_kps = kps[obj_idx]  # (K*2,)
            xs = obj_kps[0::2] * (x2 - x1) + x1
            ys = obj_kps[1::2] * (y2 - y1) + y1
            xs = xs.cpu().numpy().astype(int)
            ys = ys.cpu().numpy().astype(int)

            for (kx, ky) in zip(xs, ys):
                cv2.circle(frame_bgr, (int(kx), int(ky)), 4, (0,0,255), -1)

        # Show the annotated frame.
        cv2.imshow("Inference", frame_bgr)
        key = cv2.waitKey(1) & 0xFF
        if key == ord('q'):  # press 'q' to quit
            quit_requested = True
            break

    cap.release()
    if quit_requested:
        break

cv2.destroyAllWindows()
