In [7]:
import torch
import torch.nn as nn
import cv2

class SimpleDetectorPoseEstimator(nn.Module):
    def __init__(self, num_keypoints=4, grid_size=7, img_size=224):
        super(SimpleDetectorPoseEstimator, self).__init__()
        self.grid_size = grid_size
        self.img_size = img_size

        self.feature_extractor = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1), nn.ReLU()
        )
        # YOLO-like Object Detector
        self.detector = nn.Sequential(
            nn.AdaptiveAvgPool2d(grid_size),
            nn.Conv2d(256, 5, kernel_size=1)  # [obj_conf, x, y, w, h]
        )

        # Pose Estimator
        self.pose_estimator = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(256, num_keypoints * 2)
        )


    def detect_objects(self, image_tensor, conf_threshold=0.5):
        B, _, _, _ = image_tensor.shape
        predictions = self.detector(image_tensor).permute(0, 2, 3, 1)  # (B, G, G, 5)

        obj_conf = torch.sigmoid(predictions[..., 0])
        x = torch.sigmoid(predictions[..., 1])
        y = torch.sigmoid(predictions[..., 2])
        w = torch.sigmoid(predictions[..., 3])
        h = torch.sigmoid(predictions[..., 4])

        grid_range = torch.arange(self.grid_size, dtype=torch.float32, device=image_tensor.device)
        gy, gx = torch.meshgrid(grid_range, grid_range, indexing='ij')
        gx = gx.unsqueeze(0)  # (1, G, G)
        gy = gy.unsqueeze(0)  # (1, G, G)

        cell_size = 1.0 / self.grid_size
        x_abs = (gx + x) * cell_size  # (B, G, G)
        y_abs = (gy + y) * cell_size

        w_abs = w
        h_abs = h

        mask = obj_conf > conf_threshold  # (B, G, G)
        boxes_batch = []
        for b in range(B):
            mask_b = mask[b]
            x_b = x_abs[b][mask_b]
            y_b = y_abs[b][mask_b]
            w_b = w_abs[b][mask_b]
            h_b = h_abs[b][mask_b]
            boxes = torch.stack([x_b, y_b, w_b, h_b], dim=1)  # (N, 4)
            boxes_batch.append(boxes.tolist())

        return boxes_batch  # List of List[box]

    def crop_objects(self, images, boxes_batch, margin=0.1):
        crops = []
        for img_idx, (image, boxes) in enumerate(zip(images, boxes_batch)):
            H, W, _ = image.shape
            for x_c, y_c, w, h in boxes:
                w_new = w * (1 + margin)
                h_new = h * (1 + margin)

                x1 = int((x_c - w_new / 2) * W)
                y1 = int((y_c - h_new / 2) * H)
                x2 = int((x_c + w_new / 2) * W)
                y2 = int((y_c + h_new / 2) * H)

                x1 = max(0, x1)
                y1 = max(0, y1)
                x2 = min(W, x2)
                y2 = min(H, y2)

                crop = image[y1:y2, x1:x2]
                crop_resized = cv2.resize(crop, (self.img_size, self.img_size))
                crops.append(crop_resized)

        return crops

    def estimate_pose(self, cropped_images):
        if not cropped_images:
            return []

        crop_tensors = torch.stack([
            torch.tensor(crop, dtype=torch.float32).permute(2, 0, 1) / 255.0
            for crop in cropped_images
        ])
        with torch.no_grad():
            features = self.feature_extractor(crop_tensors)
            keypoints = self.pose_estimator(features)
        return keypoints.cpu().numpy()

    def forward(self, x):
        features = self.feature_extractor(x)
        det_output = self.detector(features).permute(0, 2, 3, 1)  # (B, G, G, 5)
        center = self.grid_size // 2
        det_preds = det_output[:, center, center, :]  # (B, 5)

        pose_preds = self.pose_estimator(features)  # (B, K*2)

        return det_preds, pose_preds


In [9]:
import os
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torchvision import transforms
from torch.utils.data import DataLoader

# (기존 클래스 import 후)

# Custom Dataset 정의
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, data_path, transform=None):
        self.img_dir = os.path.join(data_path, 'images')
        self.label_dir = os.path.join(data_path, 'labels')
        self.img_files = sorted(os.listdir(self.img_dir))
        self.transform = transform

    def __len__(self):
        return len(self.img_files)


    def __getitem__(self, idx):
        img_file = self.img_files[idx]
        label_file = img_file.replace('.jpg', '.txt')

        # 이미지 로드
        img_path = os.path.join(self.img_dir, img_file)
        img = cv2.imread(img_path)
        if self.transform is not None:
            img = self.transform(img)

        # 레이블 로드 (모든 객체)
        label_path = os.path.join(self.label_dir, label_file)
        det_labels = []
        pose_labels = []

        with open(label_path, 'r') as f:
            for line in f:
                parts = line.strip().split()
                if len(parts) >= 5:
                    det = [1.0] + [float(x) for x in parts[1:5]]  # objectness = 1.0
                    pose = [float(x) for x in parts[5:]]
                    det_labels.append(det)
                    pose_labels.append(pose)

        det_labels = torch.tensor(det_labels, dtype=torch.float32)  # (num_objs, 5)
        pose_labels = torch.tensor(pose_labels, dtype=torch.float32)  # (num_objs, num_kps*2)

        return img, det_labels, pose_labels

    
base = '/home/otter/dataset/pallet/dataset/train'

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((360, 360)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
def custom_collate_fn(batch):
    images, dets, poses = zip(*batch)  # 튜플을 분리

    images = torch.stack(images, dim=0)  # 이미지만은 고정 크기이므로 stack 가능

    return images, list(dets), list(poses)
dataset = CustomDataset(base, transform)
loader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=custom_collate_fn)

# Loss 정의
det_loss_fn = nn.MSELoss()
pose_loss_fn = nn.MSELoss()
device = 'cuda:0'

model = SimpleDetectorPoseEstimator(num_keypoints=4).to(device)

# Optimizer 정의
optimizer = optim.Adam([
    {'params': model.detector.parameters(), 'lr': 1e-4},
    {'params': model.pose_estimator.parameters(), 'lr': 1e-4}
])

# 학습 루프
epochs = 20
for epoch in range(epochs):
    total_det_loss, total_pose_loss = 0, 0
    pbar = tqdm(loader)
    for imgs, det_labels, pose_labels in pbar:
        imgs = imgs.to(device)
        # det_labels = det_labels.to(device)
        # pose_labels = pose_labels.to(device)
        det_labels = torch.stack([
            det[0] if len(det) > 0 else torch.zeros(5) for det in det_labels
        ]).to(device)
        pose_labels = torch.stack([
            det[0] if len(det) > 0 else torch.zeros(4) for det in pose_labels
        ]).to(device)
        optimizer.zero_grad()

        outputs = model(imgs)
        det_preds, pose_preds = model(imgs)
        det_loss = det_loss_fn(det_preds, det_labels)
        pose_loss = pose_loss_fn(pose_preds, pose_labels)

        loss = det_loss + pose_loss
        loss.backward()
        optimizer.step()

        total_det_loss += det_loss.item()
        total_pose_loss += pose_loss.item()
        # color_frame = imgs[0]
        # pred = model.pose_estimator(color_frame.unsqueeze(0))
        
        # color_frame = color_frame.permute(1, 2, 0).detach().cpu().numpy()
        # xs = pred[:, 0::2] * 360
        # ys = pred[:, 1::2] * 360
        # for k in range(len(xs)):
        #     x = xs[k].type(torch.int32).detach().cpu().numpy()
        #     y = ys[k].type(torch.int32).detach().cpu().numpy()
        #     for i in range(len(x)):
        #         cv2.circle(color_frame, (x, y), 5, (0, 0, 255), -1)
        # cv2.imshow('frame', color_frame)    
        # key = cv2.waitKey(1) & 0xFF
        # if key == ord('q') or key == ord('p'): break

    print(f"Epoch [{epoch+1}/{epochs}] Det Loss: {total_det_loss/len(loader):.4f}, Pose Loss: {total_pose_loss/len(loader):.4f}")

print("학습 완료!")


100%|██████████| 340/340 [00:51<00:00,  6.58it/s]


Epoch [1/20] Det Loss: 0.2252, Pose Loss: 0.2167


100%|██████████| 340/340 [00:51<00:00,  6.62it/s]


Epoch [2/20] Det Loss: 0.1198, Pose Loss: 0.1003


100%|██████████| 340/340 [00:50<00:00,  6.78it/s]


Epoch [3/20] Det Loss: 0.0667, Pose Loss: 0.0509


100%|██████████| 340/340 [00:52<00:00,  6.48it/s]


Epoch [4/20] Det Loss: 0.0397, Pose Loss: 0.0338


100%|██████████| 340/340 [00:52<00:00,  6.49it/s]


Epoch [5/20] Det Loss: 0.0262, Pose Loss: 0.0291


100%|██████████| 340/340 [00:51<00:00,  6.61it/s]


Epoch [6/20] Det Loss: 0.0200, Pose Loss: 0.0281


100%|██████████| 340/340 [00:52<00:00,  6.52it/s]


Epoch [7/20] Det Loss: 0.0174, Pose Loss: 0.0278


100%|██████████| 340/340 [00:52<00:00,  6.46it/s]


Epoch [8/20] Det Loss: 0.0164, Pose Loss: 0.0277


100%|██████████| 340/340 [00:51<00:00,  6.60it/s]


Epoch [9/20] Det Loss: 0.0160, Pose Loss: 0.0275


100%|██████████| 340/340 [00:52<00:00,  6.53it/s]


Epoch [10/20] Det Loss: 0.0156, Pose Loss: 0.0273


100%|██████████| 340/340 [00:52<00:00,  6.50it/s]


Epoch [11/20] Det Loss: 0.0154, Pose Loss: 0.0272


100%|██████████| 340/340 [00:51<00:00,  6.56it/s]


Epoch [12/20] Det Loss: 0.0151, Pose Loss: 0.0270


100%|██████████| 340/340 [00:52<00:00,  6.47it/s]


Epoch [13/20] Det Loss: 0.0148, Pose Loss: 0.0268


100%|██████████| 340/340 [00:57<00:00,  5.87it/s]


Epoch [14/20] Det Loss: 0.0145, Pose Loss: 0.0266


100%|██████████| 340/340 [01:00<00:00,  5.61it/s]


Epoch [15/20] Det Loss: 0.0142, Pose Loss: 0.0264


100%|██████████| 340/340 [01:00<00:00,  5.58it/s]


Epoch [16/20] Det Loss: 0.0140, Pose Loss: 0.0262


100%|██████████| 340/340 [00:58<00:00,  5.84it/s]


Epoch [17/20] Det Loss: 0.0137, Pose Loss: 0.0260


100%|██████████| 340/340 [00:59<00:00,  5.76it/s]


Epoch [18/20] Det Loss: 0.0134, Pose Loss: 0.0258


100%|██████████| 340/340 [00:59<00:00,  5.73it/s]


Epoch [19/20] Det Loss: 0.0132, Pose Loss: 0.0257


100%|██████████| 340/340 [00:54<00:00,  6.23it/s]

Epoch [20/20] Det Loss: 0.0129, Pose Loss: 0.0255
학습 완료!





In [10]:
torch.save(model, 'checkpoints/ype.pt')


In [15]:

from torchvision import transforms
model = torch.load('checkpoints/ype.pt').cuda()
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((360, 360)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 모델 사용 예시
file_base = '/home/otter/workspace/Pallet/train_video'
files_list = os.listdir(file_base)
files = [os.path.join(file_base, file) for file in files_list if file.endswith('.mp4')]
for file in files:
    key = None
    cap = cv2.VideoCapture(file)
    while cap.isOpened():
        ret, color_frame = cap.read()
        if ret: 
            input_frame = transform(color_frame).to('cuda').unsqueeze(0)
            det_preds, outputs = model(input_frame)
            xs = outputs[:, 0::2] * 1920
            ys = outputs[:, 1::2] * 1080
            for k in range(len(xs)):
                x = xs[k].type(torch.int32).detach().cpu().numpy()
                y = ys[k].type(torch.int32).detach().cpu().numpy()
                for i in range(len(x)):
                    cv2.circle(color_frame, (x[i], y[i]), 5, (0, 0, 255), -1)
            cv2.imshow('frame', color_frame)    
            key = cv2.waitKey(1) & 0xFF
            if key == ord('q') or key == ord('p'): break
        else:
            break 
    if key == ord('q'):
        break

    cap.release()
cv2.destroyAllWindows()

  model = torch.load('checkpoints/ype.pt').cuda()
