### FR5 로봇 데이터셋 전처리 과정
1. 카메라와 Joint angle의 Time stamps 맞추기
2. ROI에 맞게 적절한 label 생성 

### FR5 로봇 자세 추정 과정
1. 전체 이미지에서 ROI 탐색 - SegFormer
2. ROI 이미지에서 각 관절의 3차원 좌표 추정

In [None]:
import os
os.environ['MKL_THREADING_LAYER'] = 'GNU'
import numpy as np

import re
import glob
import json
import math
from pathlib import Path

import pandas as pd
from typing import Any, Dict, List, Optional, Tuple
from bisect import bisect_left

import cv2
from tqdm import tqdm
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from torchvision import transforms as T
import albumentations as A
from albumentations.pytorch import ToTensorV2
from transformers import SegformerForSemanticSegmentation, SegformerConfig
import timm

# ================= Device setup =================
# NOTE: the MKL threading-layer workaround itself is applied at the very top of
# the file (MKL_THREADING_LAYER is set right after `import os`), not here.

# Restrict this process to GPUs 0, 1 and 2.
# NOTE(review): this assignment runs after `import torch`; presumably it still
# takes effect because CUDA has not been initialised yet at this point -- confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2'
# Use CUDA when available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# ================= User settings =================
# Root directory that is scanned for the per-session dataset folders.
ROOT = "/home/najo/NAS/DIP/datasets/Fr5_intertek_dataset"
# Largest allowed |image timestamp - joint timestamp| for a pair to be matched.
MAX_TIME_DIFF = 0.025  # seconds
# ==============================================

# =========================================================
# 1) 데이터 준비: 이미지-조인트 매칭 및 인덱싱
# =========================================================

# File-name patterns. Both carry a decimal timestamp (digits, dot, digits)
# directly in front of the extension.
#   image: zed_<serial>_<view>_<ts>.jpg
#   joint: joint_<robotserial>_<ts>.json
IMG_RE = re.compile(r"zed_(?P<serial>\d+)_(?P<view>[a-zA-Z]+)_(?P<ts>\d+\.\d+)\.jpg$")
JNT_RE = re.compile(r"joint_(?P<robotserial>\d+)_(?P<ts>\d+\.\d+)\.json$")

def parse_img_fname(path: str) -> Optional[Dict[str, Any]]:
    """Parse the base name of an image path.

    Returns a dict with the string fields ``serial`` and ``view``, the float
    ``timestamp`` and the unmodified ``path``; returns ``None`` when the base
    name does not match ``IMG_RE``.
    """
    match = IMG_RE.search(os.path.basename(path))
    if match is None:
        return None
    return {
        "serial": match.group("serial"),
        "view": match.group("view"),
        "timestamp": float(match.group("ts")),
        "path": path,
    }

def parse_joint_fname(path: str) -> Optional[Dict[str, Any]]:
    """Parse the base name of a joint JSON path.

    Returns a dict with the string field ``robotserial``, the float
    ``timestamp`` and the unmodified ``path``; returns ``None`` when the base
    name does not match ``JNT_RE``.
    """
    match = JNT_RE.search(os.path.basename(path))
    if match is None:
        return None
    return {
        "robotserial": match.group("robotserial"),
        "timestamp": float(match.group("ts")),
        "path": path,
    }

def flatten_json(prefix: str, obj: Any) -> Dict[str, Any]:
    """Flatten nested dicts/lists into a single-level dict with dotted keys.

    Dict keys and list indices are appended to ``prefix`` separated by ``.``;
    when ``prefix`` is empty the key/index itself becomes the prefix. Any value
    that is neither a dict nor a list is a leaf stored under the accumulated
    prefix. Empty containers contribute no entries.
    """
    if isinstance(obj, dict):
        children = obj.items()
    elif isinstance(obj, list):
        children = enumerate(obj)
    else:
        # Leaf value: stored under whatever prefix has been built so far.
        return {prefix: obj}

    flat: Dict[str, Any] = {}
    for key, value in children:
        child_prefix = f"{prefix}.{key}" if prefix else str(key)
        flat.update(flatten_json(child_prefix, value))
    return flat

def load_joint_angles(path: str) -> Dict[str, Any]:
    """Load a joint JSON file and flatten it into ``joint.*`` columns.

    If the top-level JSON value is a dict containing one of the known container
    keys (checked in the order listed below), only that entry is flattened;
    otherwise the whole document is flattened. When the file cannot be opened
    or parsed, no exception escapes: a dict with the single key
    ``joint_load_error`` holding the error text is returned instead.
    """
    try:
        with open(path, "r") as fh:
            payload = json.load(fh)
    except Exception as exc:
        # Best effort: report the failure as data instead of raising.
        return {"joint_load_error": str(exc)}

    if isinstance(payload, dict):
        for name in ("joint_angles", "joints", "angles", "q", "positions", "joint"):
            if name in payload:
                return flatten_json("joint", payload[name])
    return flatten_json("joint", payload)

def build_joint_index(joint_dir: str) -> Tuple[List[float], List[str]]:
    """Index the ``joint_*.json`` files of ``joint_dir`` by timestamp.

    Files whose name cannot be parsed by ``parse_joint_fname`` are skipped.
    Returns two parallel lists, timestamps and file paths, ordered by ascending
    timestamp; both are empty when nothing matches.
    """
    pattern = os.path.join(joint_dir, "joint_*.json")
    parsed = (parse_joint_fname(candidate) for candidate in sorted(glob.glob(pattern)))
    entries = [(info["timestamp"], info["path"]) for info in parsed if info is not None]
    # Stable sort on the timestamp only, so equal timestamps keep path order.
    entries.sort(key=lambda entry: entry[0])
    timestamps = [stamp for stamp, _ in entries]
    files = [file_path for _, file_path in entries]
    return timestamps, files

def find_nearest_joint_any(ts: float, ts_index: List[float], paths: List[str]):
    """Find the joint sample whose timestamp is closest to ``ts``.

    ``ts_index`` is searched with ``bisect_left``, so it is expected to be in
    ascending order, with ``paths`` parallel to it. Returns the tuple
    ``(joint_ts, joint_path, abs_dt)``, or ``None`` when ``ts_index`` is empty.
    When both neighbours are equally far away, the earlier one is returned.
    """
    if not ts_index:
        return None

    right = bisect_left(ts_index, ts)
    # Only the neighbours on either side of the insertion point can be nearest.
    neighbours = [i for i in (right - 1, right) if 0 <= i < len(ts_index)]
    # min() keeps the first of several equal minima, i.e. the left neighbour.
    nearest = min(neighbours, key=lambda i: abs(ts_index[i] - ts))
    joint_ts = ts_index[nearest]
    return joint_ts, paths[nearest], abs(joint_ts - ts)

def scan_images(img_dirs: List[str]) -> List[Dict[str, Any]]:
    imgs = []
    for d in img_dirs:
        for p in sorted(glob.glob(os.path.join(d, "*.jpg"))):
            info = parse_img_fname(p)
            if info:
                imgs.append(info)
    return imgs

def process_dataset_indexing(dataset_root: str, max_time_diff: float = 0.2):
    """Match every image of one dataset folder to its nearest joint sample.

    Images are read from the ``left``, ``right`` and ``top`` sub-folders and
    joint JSON files from ``joint``. Each image is paired with the joint file
    whose timestamp is closest; pairs further apart than ``max_time_diff``
    seconds are reported and dropped. The matched rows (image info, joint info,
    absolute time difference and the flattened joint values) are sorted by
    image timestamp and written to ``matched_index.csv`` and
    ``matched_index.jsonl`` inside ``dataset_root``.

    Nothing is written, and the function returns early, when there are no
    images, no joint files, or no matched pairs. Always returns ``None``.
    """
    tag = os.path.basename(dataset_root)
    csv_path = os.path.join(dataset_root, "matched_index.csv")
    jsonl_path = os.path.join(dataset_root, "matched_index.jsonl")

    joint_ts_index, joint_paths = build_joint_index(os.path.join(dataset_root, "joint"))
    images = scan_images([os.path.join(dataset_root, view) for view in ("left", "right", "top")])

    if not images:
        print(f"[{tag}] 이미지가 없습니다.")
        return
    if not joint_ts_index:
        print(f"[{tag}] 조인트(JSON)가 없습니다.")
        return

    matched = []
    n_too_far = 0
    n_no_joint = 0

    for img in images:
        img_ts = img["timestamp"]
        hit = find_nearest_joint_any(img_ts, joint_ts_index, joint_paths)

        if hit is None:
            print(f"[{tag}] UNMATCHED(no_joint): {img['path']}")
            n_no_joint += 1
            continue

        joint_ts, joint_path, gap = hit

        if gap > max_time_diff:
            # Nearest joint sample is outside the tolerance: report and skip.
            print(
                f"[{tag}] UNMATCHED(threshold) "
                f"dt={gap:.9f}s > {max_time_diff:.9f}s | img_ts={img_ts:.9f} "
                f"-> nearest_joint={os.path.basename(joint_path)} (joint_ts={joint_ts:.9f})"
            )
            n_too_far += 1
            continue

        # Fixed bookkeeping columns first, then the flattened joint values.
        matched.append({
            "img.path": img["path"],
            "img.serial": img["serial"],
            "img.view": img["view"],
            "img.ts": img_ts,
            "joint.path": joint_path,
            "joint.ts": joint_ts,
            "abs_dt": gap,
            **load_joint_angles(joint_path),
        })

    if not matched:
        print(
            f"[{tag}] 매칭된 쌍이 없음 "
            f"(threshold={max_time_diff}s, images={len(images)}, "
            f"too_far={n_too_far}, no_joint={n_no_joint})"
        )
        return

    table = pd.DataFrame(matched).sort_values(by=["img.ts"]).reset_index(drop=True)
    table.to_csv(csv_path, index=False)
    with open(jsonl_path, "w", encoding="utf-8") as out:
        # One JSON object per line, in the same order as the CSV rows.
        out.writelines(
            json.dumps(row.to_dict(), ensure_ascii=False) + "\n"
            for _, row in table.iterrows()
        )

    print(
        f"[{tag}] 완료: "
        f"matched={len(matched)} / images={len(images)} "
        f"(too_far={n_too_far}, no_joint={n_no_joint}) "
        f"-> {csv_path}, {jsonl_path}"
    )

# =========================================================
# 2) SegFormer 모델 정의
# =========================================================
class SegFormerForRobotArm(nn.Module):
    """SegFormer wrapper whose output logits have the input's spatial size.

    The wrapped network is built with
    ``SegformerForSemanticSegmentation.from_pretrained`` using ``num_classes``
    labels; ``ignore_mismatched_sizes=True`` lets weights whose shape does not
    match the checkpoint be re-initialised instead of raising.
    """

    def __init__(self, num_classes=2, model_name="nvidia/mit-b2"):
        super().__init__()
        self.num_classes = num_classes
        print(f"🏗️ SegFormer 모델 로딩: {model_name}")
        self.segformer = SegformerForSemanticSegmentation.from_pretrained(
            model_name,
            num_labels=num_classes,
            ignore_mismatched_sizes=True,
        )

        # Count all parameters and the trainable subset in a single pass.
        total_params = 0
        trainable_params = 0
        for param in self.segformer.parameters():
            count = param.numel()
            total_params += count
            if param.requires_grad:
                trainable_params += count
        print(f"📊 총 파라미터: {total_params:,}")
        print(f"📊 훈련 가능 파라미터: {trainable_params:,}")

    def forward(self, pixel_values):
        """Return class logits bilinearly resized to ``pixel_values``' H x W."""
        coarse_logits = self.segformer(pixel_values=pixel_values).logits
        target_hw = pixel_values.shape[-2:]
        return F.interpolate(
            coarse_logits,
            size=target_hw,
            mode='bilinear',
            align_corners=False,
        )

# =========================================================
# 3) Dataset
# =========================================================
def mask_to_bbox(mask: np.ndarray, min_area: int = 100) -> Optional[Tuple[int, int, int, int]]:
    """Return the tight bounding box of the positive pixels of a 2-D mask.

    The box is ``(x1, y1, x2, y2)`` with inclusive corners. ``None`` is
    returned when the mask has no pixel ``> 0`` or when the box area (in
    pixels, corners included) is smaller than ``min_area``.
    """
    rows, cols = np.nonzero(mask > 0)
    if cols.size == 0:
        return None

    left, right = int(cols.min()), int(cols.max())
    top, bottom = int(rows.min()), int(rows.max())
    width = right - left + 1
    height = bottom - top + 1
    # The area test uses the bounding box, not the number of mask pixels.
    if width * height < min_area:
        return None
    return left, top, right, bottom

def crop_with_padding(img: np.ndarray, box: Tuple[int, int, int, int], pad: int = 10) -> np.ndarray:
    """Crop ``img`` to ``box`` grown by ``pad`` pixels on every side.

    ``box`` is ``(x1, y1, x2, y2)`` with inclusive corners; the grown box is
    clamped to the image bounds. The result is a slice of ``img``, not a copy.
    """
    height, width = img.shape[:2]
    left, top, right, bottom = box

    left = max(0, left - pad)
    top = max(0, top - pad)
    right = min(width - 1, right + pad)
    bottom = min(height - 1, bottom + pad)

    # Corners are inclusive, hence the +1 on the slice ends.
    row_span = slice(top, bottom + 1)
    col_span = slice(left, right + 1)
    return img[row_span, col_span]

def dh_transform(a: float, alpha: float, d: float, theta: float) -> np.ndarray:
    """Build the 4x4 homogeneous transform for one set of DH parameters.

    ``alpha`` and ``theta`` are passed straight to ``math.cos``/``math.sin``,
    i.e. they are in radians. The translation column is
    ``(a*cos(theta), a*sin(theta), d)``. Returns a ``float64`` array.
    """
    cos_t, sin_t = math.cos(theta), math.sin(theta)
    cos_a, sin_a = math.cos(alpha), math.sin(alpha)
    rows = (
        (cos_t, -sin_t * cos_a, sin_t * sin_a, a * cos_t),
        (sin_t, cos_t * cos_a, -cos_t * sin_a, a * sin_t),
        (0.0, sin_a, cos_a, d),
        (0.0, 0.0, 0.0, 1.0),
    )
    return np.array(rows, dtype=np.float64)

def angle_to_joint_coordinate(joint_angles_deg: List[float]) -> np.ndarray:
    """Chain six DH transforms and return the origin of every frame.

    ``joint_angles_deg`` must hold six angles in degrees (they are converted
    with ``math.radians``). Returns a ``(7, 3)`` ``float64`` array: row 0 is
    the base origin ``(0, 0, 0)`` and row ``i`` is the origin of the frame
    reached after applying the first ``i`` transforms.
    """
    # One row per joint: (alpha_deg, a, d, theta_offset_deg).
    # NOTE(review): a and d look like metres -- confirm against the FR5 datasheet.
    dh_table = (
        (90, 0.0, 0.152, 0),
        (0, -0.425, 0.0, 0),
        (0, -0.395, 0.0, 0),
        (90, 0.0, 0.102, 0),
        (-90, 0.0, 0.102, 0),
        (0, 0.0, 0.100, 0),
    )
    assert len(joint_angles_deg) == 6, "joint_angles_deg must have length 6."

    pose = np.eye(4, dtype=np.float64)
    origins = [pose[:3, 3].copy()]
    for i, (alpha_deg, a, d, theta_offset) in enumerate(dh_table):
        theta = math.radians(joint_angles_deg[i] + theta_offset)
        link = dh_transform(a, math.radians(alpha_deg), d, theta)
        pose = pose @ link
        # Copy: the slice would otherwise alias the matrix it was taken from.
        origins.append(pose[:3, 3].copy())
    return np.stack(origins, axis=0)

def maybe_to_degrees(joint_angles: List[float]) -> List[float]:
    """Convert angles to degrees when they look like radians.

    Heuristic: if every value has magnitude at most ``1.25 * pi`` the input is
    treated as radians and a new list in degrees is returned; otherwise the
    input list itself is returned unchanged.
    """
    limit = math.pi * 1.25
    for value in joint_angles:
        if not abs(value) <= limit:
            # At least one value is too large to be radians: leave as is.
            return joint_angles
    return list(map(math.degrees, joint_angles))

def extract_joint_angles_from_row(row: pd.Series) -> Optional[List[float]]:
    """Pull six joint angles out of one row of the matched index.

    Column schemes are tried in this order:

    1. ``joint.0`` ... ``joint.5``
    2. ``joint.q1`` ... ``joint.q6``
    3. fallback: the first six finite numeric ``joint.*`` columns in sorted
       column-name order.

    A scheme only applies when all six of its values are present *and* finite.
    Concatenating index files with different columns fills the missing cells
    with NaN, so a column being present in the row does not mean the row has a
    value for it.

    The fallback ignores the bookkeeping columns ``joint.path`` and
    ``joint.ts``; ``joint.ts`` is a numeric timestamp and must never be
    interpreted as an angle.

    Returns the angles after ``maybe_to_degrees``, or ``None`` when no scheme
    yields six usable values.
    """
    def _finite_values(keys: List[str]) -> Optional[List[float]]:
        # Every key must exist and hold a finite number, else the scheme fails.
        values: List[float] = []
        for key in keys:
            if key not in row:
                return None
            try:
                value = float(row[key])
            except (TypeError, ValueError, OverflowError):
                return None
            if not math.isfinite(value):
                return None
            values.append(value)
        return values

    named_schemes = (
        [f"joint.{i}" for i in range(6)],
        [f"joint.q{i + 1}" for i in range(6)],
    )
    for keys in named_schemes:
        vals = _finite_values(keys)
        if vals is not None:
            return maybe_to_degrees(vals)

    # Columns written by the indexing step that describe the joint *file*.
    metadata_cols = {"joint.path", "joint.ts"}
    joint_cols = sorted(
        k for k in row.index
        if isinstance(k, str) and k.startswith("joint.") and k not in metadata_cols
    )
    nums: List[float] = []
    for k in joint_cols:
        try:
            v = float(row[k])
        except (TypeError, ValueError, OverflowError):
            continue
        if math.isfinite(v):
            nums.append(v)
    if len(nums) >= 6:
        return maybe_to_degrees(nums[:6])
    return None

class RobotArmSegFKDataset(Dataset):
    """Image / joint-position dataset built from matched index files.

    For every index row the image is loaded, optionally cropped to the robot
    ROI predicted by a SegFormer model, resized and normalised, and paired with
    the joint positions computed from the row's joint angles by
    ``angle_to_joint_coordinate``.

    Each sample is a dict with the same keys for every index, so samples can be
    batched by the default DataLoader collate function:

    * ``image``        -- tensor produced by the resize/normalise pipeline
    * ``joints_xyz``   -- float tensor from ``angle_to_joint_coordinate``
    * ``img_path``     -- source image path
    * ``view``         -- value of the ``img.view`` column (``""`` if absent)
    * ``img_ts`` / ``joint_ts`` -- timestamps (``-1.0`` if absent)
    * ``roi_box_xyxy`` -- int32 ``(x1, y1, x2, y2)`` mask box in source-image
      pixels, or ``(-1, -1, -1, -1)`` when no ROI was found / segmentation is
      disabled and the centre-square fallback crop was used

    NOTE(review): when ``run_segmentation`` is true the dataset owns a model on
    ``device``; with ``num_workers > 0`` the DataLoader pickles the dataset
    into every worker -- confirm that this is intended.
    """

    def __init__(
        self,
        index_paths: List[str],
        segformer_model_path: str,
        device: torch.device,
        view_filter: Optional[str] = None,
        image_size: int = 384,
        run_segmentation: bool = True,
        fg_class_id: int = 1,
        threshold: float = 0.5,
        pad: int = 10,
        img_key: str = "img.path",
    ):
        """Load the index files and, if requested, the segmentation model.

        Args:
            index_paths: ``.csv`` or ``.jsonl`` index files to concatenate.
            segformer_model_path: checkpoint file holding a
                ``model_state_dict`` entry; only read when
                ``run_segmentation`` is true.
            device: device used for the segmentation model.
            view_filter: keep only rows whose ``img.view`` equals this value.
            image_size: side length of the square output image.
            run_segmentation: crop to the predicted ROI when true, otherwise
                always use the centre-square crop.
            fg_class_id: channel index of the foreground class.
            threshold: foreground probability threshold for the mask.
            pad: padding in pixels added around the ROI box before cropping.
            img_key: index column holding the image path.

        Raises:
            ValueError: an index file has an unsupported extension.
            KeyError: ``view_filter`` is given but the index has no
                ``img.view`` column.
        """
        self.device = device
        self.run_seg = run_segmentation
        self.fg_class_id = fg_class_id
        self.threshold = threshold
        self.image_size = image_size
        self.pad = pad
        self.img_key = img_key

        frames = []
        for p in index_paths:
            if p.endswith(".csv"):
                frames.append(pd.read_csv(p))
            elif p.endswith(".jsonl"):
                frames.append(pd.read_json(p, lines=True))
            else:
                raise ValueError(f"Unsupported index file: {p}")
        self.df = pd.concat(frames, ignore_index=True)

        if view_filter is not None:
            # Fail with a readable message instead of the opaque
            # "KeyError: False" that indexing with a scalar boolean produces.
            if "img.view" not in self.df.columns:
                raise KeyError(
                    "view_filter was given but the index has no 'img.view' column"
                )
            self.df = self.df[self.df["img.view"] == view_filter].reset_index(drop=True)

        if self.run_seg:
            self.model = SegFormerForRobotArm(num_classes=2, model_name="nvidia/mit-b2").to(device)
            checkpoint = torch.load(segformer_model_path, map_location=self.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.model.eval()
        else:
            self.model = None

        # Pose-network input pipeline: resize, ImageNet mean/std normalisation,
        # HWC numpy -> CHW tensor.
        self.transform = A.Compose([
            A.Resize(image_size, image_size),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ])

        # Segmentation input pipeline.
        # NOTE(review): only ToTensor is applied (no resize / normalisation) --
        # confirm this matches how the SegFormer checkpoint was trained.
        self.seg_input_transform = T.Compose([
            T.ToTensor(),
        ])

    def __len__(self):
        return len(self.df)

    @torch.no_grad()
    def _infer_mask(self, pil_img: Image.Image) -> np.ndarray:
        """Return a uint8 {0, 1} foreground mask for ``pil_img``."""
        inp = self.seg_input_transform(pil_img).unsqueeze(0).to(self.device)
        self.model.eval()
        logits = self.model(inp)
        if isinstance(logits, (list, tuple)):
            logits = logits[0]
        probs = torch.softmax(logits, dim=1)[0]
        fg_prob = probs[self.fg_class_id]
        return (fg_prob > self.threshold).cpu().numpy().astype(np.uint8)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Build the sample for index row ``idx``.

        Raises:
            FileNotFoundError: the image could not be read.
            ValueError: the row holds no usable joint angles.
        """
        row = self.df.iloc[idx]
        img_path = row[self.img_key]
        bgr = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if bgr is None:
            raise FileNotFoundError(f"Image not found: {img_path}")
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

        box = None
        if self.run_seg:
            mask = self._infer_mask(Image.fromarray(rgb))
            box = mask_to_bbox(mask, min_area=100)

        if box is None:
            # No ROI available: fall back to the centred square crop.
            H, W = rgb.shape[:2]
            side = min(H, W)
            cy, cx = H // 2, W // 2
            x1 = max(0, cx - side // 2)
            y1 = max(0, cy - side // 2)
            x2 = min(W - 1, x1 + side - 1)
            y2 = min(H - 1, y1 + side - 1)
            crop = rgb[y1:y2+1, x1:x2+1]
        else:
            crop = crop_with_padding(rgb, box, pad=self.pad)

        img_tensor = self.transform(image=crop)['image']

        joint_angles = extract_joint_angles_from_row(row)
        if joint_angles is None:
            raise ValueError(f"No valid joint angles in row index={idx} (columns like 'joint.*' expected).")

        joints_tensor = torch.from_numpy(angle_to_joint_coordinate(joint_angles)).float()

        view = row.get("img.view", None)
        return {
            "image": img_tensor,
            "joints_xyz": joints_tensor,
            "img_path": img_path,
            # default_collate cannot batch None, so a missing view becomes "".
            "view": "" if view is None else view,
            "img_ts": float(row.get("img.ts", -1)),
            "joint_ts": float(row.get("joint.ts", -1)),
            # Always present so that every sample has the same keys; a batch
            # mixing samples with and without this key cannot be collated.
            "roi_box_xyxy": np.array(
                box if box is not None else (-1, -1, -1, -1), dtype=np.int32
            ),
        }
    
# =========================================================
# 4) Model (HRNet-like with ViT backbone, 7 keypoints)
# =========================================================
def reshape_vit_output(x, patch_size=16, img_size=512):
    """Turn a ViT token sequence ``(B, N, D)`` into a feature map ``(B, D, S, S)``.

    If the sequence length equals ``(img_size // patch_size) ** 2 + 1`` the
    first token is dropped before reshaping. A 4-D input is returned
    unchanged.

    Raises:
        ValueError: the input is neither 4-D nor 3-D with more than one token,
            or the remaining token count is not a perfect square.
    """
    if x.ndim == 4:
        # Already a spatial feature map.
        return x
    if x.ndim != 3 or x.shape[1] <= 1:
        raise ValueError(f"Unsupported ViT output shape for reshaping: {x.shape}")

    grid = img_size // patch_size
    if x.shape[1] == grid ** 2 + 1:
        # One token more than the patch grid: strip the leading extra token.
        x = x[:, 1:]

    batch, n_tokens, dim = x.shape
    side = int(n_tokens ** 0.5)
    if side * side != n_tokens:
        raise ValueError(f"ViT output sequence length {n_tokens} is not a perfect square.")
    return x.permute(0, 2, 1).reshape(batch, dim, side, side)

class BasicBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm stages with a skip connection.

    The skip path is the identity unless ``stride != 1`` or the channel count
    changes, in which case a bias-free 1x1 convolution with the same stride
    projects the input to the output shape.
    """

    def __init__(self, in_c, out_c, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_c)
        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_c)
        self.relu = nn.ReLU(inplace=True)
        needs_projection = stride != 1 or in_c != out_c
        if needs_projection:
            self.down = nn.Conv2d(in_c, out_c, kernel_size=1, stride=stride, bias=False)
        else:
            self.down = None

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + skip(x))``."""
        shortcut = x if self.down is None else self.down(x)
        hidden = self.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        return self.relu(hidden + shortcut)

class PoseEstimationHRViT(nn.Module):
    """ViT backbone with a two-resolution conv head that regresses 3-D keypoints.

    The ViT token sequence is reshaped to a feature map, processed by a
    high-resolution branch and a stride-2 low-resolution branch, fused back at
    high resolution and pooled into ``num_kp * 3`` regression outputs.

    Args:
        num_kp: number of keypoints to regress.
        vit_model_name: timm model name of the ViT backbone.
        pretrained: load pretrained backbone weights.
        img_size: side length of the square input images. It is forwarded to
            timm so that the backbone's patch embedding and position embedding
            are built for this resolution; without it the backbone expects its
            native size (224 for the default model) and rejects other inputs.
    """

    def __init__(self, num_kp=7, vit_model_name='vit_base_patch16_224', pretrained=True, img_size=512):
        super().__init__()
        self.num_kp = num_kp
        self.img_size = img_size
        self.vit_backbone = timm.create_model(
            vit_model_name,
            pretrained=pretrained,
            num_classes=0,
            img_size=img_size,
        )
        print(f"Loaded pretrained ViT model: {vit_model_name}")
        vit_embed_dim = self.vit_backbone.embed_dim

        # High-resolution branch: project the ViT embedding to 256 channels.
        self.high_res_branch_conv = nn.Sequential(
            nn.Conv2d(vit_embed_dim, 256, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            self._make_layer(256, 256, 2, 1)
        )
        # Low-resolution branch: stride-2 downsample to 512 channels.
        self.low_res_branch_conv = nn.Sequential(
            nn.Conv2d(256, 512, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            self._make_layer(512, 512, 2, 1)
        )
        # Bring the low-resolution features back to 256 channels and 2x size.
        self.fuse_l_to_h = nn.Sequential(
            nn.Conv2d(512, 256, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='nearest')
        )
        self.regression_head = nn.Sequential(
            nn.Conv2d(256, 128, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(128, num_kp * 3)
        )

    def _make_layer(self, in_c, out_c, blocks, stride):
        """Stack ``blocks`` BasicBlocks; only the first uses ``stride``/``in_c``."""
        layers = [BasicBlock(in_c, out_c, stride)]
        for _ in range(1, blocks):
            layers.append(BasicBlock(out_c, out_c))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return keypoint predictions of shape ``(batch, num_kp, 3)``."""
        # NOTE(review): assumes forward_features returns the token sequence
        # (B, N, D); reshape_vit_output raises for any other rank -- confirm
        # for the installed timm version.
        vit_output = self.vit_backbone.forward_features(x)
        patch_size = self.vit_backbone.patch_embed.patch_size[0]
        vit_spatial_features = reshape_vit_output(vit_output, patch_size, self.img_size)
        high_res_feat = self.high_res_branch_conv(vit_spatial_features)
        low_res_feat = self.low_res_branch_conv(high_res_feat)
        upsampled = self.fuse_l_to_h(low_res_feat)
        if upsampled.shape[-2:] != high_res_feat.shape[-2:]:
            # An odd feature-map side (e.g. 15 -> 8 -> 16) does not survive the
            # stride-2 / x2 round trip; resize so the addition is well defined.
            upsampled = F.interpolate(upsampled, size=high_res_feat.shape[-2:], mode='nearest')
        high_res_fused = high_res_feat + upsampled
        output = self.regression_head(high_res_fused)
        return output.view(-1, self.num_kp, 3)

# =========================================================
# 5) 학습 및 검증 루프
# =========================================================
def _default_num_workers() -> int:
    """Pick a DataLoader worker count that is safe for the current session."""
    import __main__ as main_module

    # Heuristic: interactive sessions (notebook, REPL) have no __main__.__file__.
    # Worker processes cannot re-import classes defined there, which surfaces
    # as "Can't get attribute ... on <module '__main__'>", so load in-process.
    if not hasattr(main_module, "__file__"):
        return 0
    # os.cpu_count() may return None when the count cannot be determined.
    return (os.cpu_count() or 1) // 2


def train_and_validate(
    model: nn.Module,
    train_dataset: Dataset,
    val_dataset: Dataset,
    epochs: int,
    batch_size: int,
    learning_rate: float,
    save_path: str,
    device: torch.device,
    num_workers: Optional[int] = None,
):
    """Train ``model`` with Adam + SmoothL1 loss and keep the best checkpoint.

    Batches must be dicts with an ``image`` tensor (model input) and a
    ``joints_xyz`` tensor (regression target). After every epoch the model is
    evaluated on ``val_dataset``; whenever the mean validation loss improves,
    ``model.state_dict()`` is written to
    ``<save_path>/best_model_checkpoint.pt``.

    Args:
        model: network to train; moved to ``device``.
        train_dataset: training samples (shuffled every epoch).
        val_dataset: validation samples.
        epochs: number of passes over ``train_dataset``.
        batch_size: batch size for both loaders.
        learning_rate: Adam learning rate.
        save_path: directory for the checkpoint (created if missing).
        device: device used for the model and the batches.
        num_workers: DataLoader worker processes. ``None`` (default) uses half
            of the CPUs, or ``0`` in an interactive session where worker
            processes cannot import notebook-defined dataset classes.

    Raises:
        ValueError: one of the datasets is empty.
    """
    if len(train_dataset) == 0 or len(val_dataset) == 0:
        raise ValueError(
            f"Datasets must not be empty (train={len(train_dataset)}, val={len(val_dataset)})."
        )
    if num_workers is None:
        num_workers = _default_num_workers()

    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.SmoothL1Loss(reduction='mean')

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )
    best_val_loss = float('inf')
    os.makedirs(save_path, exist_ok=True)
    print(f"Starting training for {epochs} epochs...")
    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        train_progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]")
        for batch in train_progress_bar:
            images = batch['image'].to(device)
            joints_xyz = batch['joints_xyz'].to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, joints_xyz)
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch mean is per sample even when
            # the last batch is smaller.
            train_loss += loss.item() * images.size(0)
            train_progress_bar.set_postfix(loss=loss.item())
        epoch_train_loss = train_loss / len(train_loader.dataset)

        model.eval()
        val_loss = 0.0
        val_progress_bar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} [Val]")
        with torch.no_grad():
            for batch in val_progress_bar:
                images = batch['image'].to(device)
                joints_xyz = batch['joints_xyz'].to(device)
                outputs = model(images)
                loss = criterion(outputs, joints_xyz)
                val_loss += loss.item() * images.size(0)
                val_progress_bar.set_postfix(loss=loss.item())
        epoch_val_loss = val_loss / len(val_loader.dataset)

        print(f"Epoch {epoch+1}/{epochs} | Train Loss: {epoch_train_loss:.4f} | Val Loss: {epoch_val_loss:.4f}")
        if epoch_val_loss < best_val_loss:
            print(f"Validation loss improved from {best_val_loss:.4f} to {epoch_val_loss:.4f}. Saving model...")
            best_val_loss = epoch_val_loss
            checkpoint_path = os.path.join(save_path, "best_model_checkpoint.pt")
            torch.save(model.state_dict(), checkpoint_path)
    print("Training finished.")

# =========================================================
# 6) 실행 메인 함수
# =========================================================
def main():
    """Index every dataset folder under ROOT, then train the pose model.

    Sessions 1st-6th are used for training and the 7th for validation; the
    best model is saved under ``checkpoints``.
    """
    # Dataset indexing (only needs to be run once).
    subdirs = sorted(
        os.path.join(ROOT, name)
        for name in os.listdir(ROOT)
        if os.path.isdir(os.path.join(ROOT, name)) and name.startswith("Fr5_intertek_")
    )
    print(f"총 {len(subdirs)}개 데이터셋 세트 탐색: {subdirs}")
    for dataset_dir in subdirs:
        process_dataset_indexing(dataset_dir, MAX_TIME_DIFF)

    # Only the checkpoint path is kept here; the datasets load the model.
    SegFormer_model_path = "/home/najo/NAS/DIP/Fr5_robot_SegFormer/best_segformer_robot_arm.pth"

    train_index_files = [
        os.path.join(ROOT, f'Fr5_intertek_{ordinal}_250526/matched_index.csv')
        for ordinal in ("1st", "2nd", "3rd", "4th", "5th", "6th")
    ]
    val_index_files = [os.path.join(ROOT, 'Fr5_intertek_7th_250526/matched_index.csv')]

    IMG_SIZE = 512

    # Settings shared by the training and the validation dataset.
    dataset_kwargs = dict(
        segformer_model_path=SegFormer_model_path,
        device=device,
        image_size=IMG_SIZE,
        run_segmentation=True,
    )
    train_dataset = RobotArmSegFKDataset(index_paths=train_index_files, **dataset_kwargs)
    val_dataset = RobotArmSegFKDataset(index_paths=val_index_files, **dataset_kwargs)

    model = PoseEstimationHRViT(num_kp=7, pretrained=True, img_size=IMG_SIZE).to(device)

    train_and_validate(
        model=model,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        epochs=50,
        batch_size=4,
        learning_rate=1e-4,
        save_path="checkpoints",
        device=device,
    )

if __name__ == "__main__":
    # Use the 'spawn' start method for worker processes: CUDA cannot be used in
    # forked subprocesses, and the dataset runs a model on the GPU.
    # force=True overrides a start method that was already set.
    torch.multiprocessing.set_start_method('spawn', force=True)
    main()

Using device: cuda
총 7개 데이터셋 세트 탐색: ['/home/najo/NAS/DIP/datasets/Fr5_intertek_dataset/Fr5_intertek_1st_250526', '/home/najo/NAS/DIP/datasets/Fr5_intertek_dataset/Fr5_intertek_2nd_250526', '/home/najo/NAS/DIP/datasets/Fr5_intertek_dataset/Fr5_intertek_3rd_250526', '/home/najo/NAS/DIP/datasets/Fr5_intertek_dataset/Fr5_intertek_4th_250526', '/home/najo/NAS/DIP/datasets/Fr5_intertek_dataset/Fr5_intertek_5th_250526', '/home/najo/NAS/DIP/datasets/Fr5_intertek_dataset/Fr5_intertek_6th_250526', '/home/najo/NAS/DIP/datasets/Fr5_intertek_dataset/Fr5_intertek_7th_250526']
[Fr5_intertek_1st_250526] 완료: matched=1296 / images=1296 (too_far=0, no_joint=0) -> /home/najo/NAS/DIP/datasets/Fr5_intertek_dataset/Fr5_intertek_1st_250526/matched_index.csv, /home/najo/NAS/DIP/datasets/Fr5_intertek_dataset/Fr5_intertek_1st_250526/matched_index.jsonl
[Fr5_intertek_2nd_250526] 완료: matched=1310 / images=1310 (too_far=0, no_joint=0) -> /home/najo/NAS/DIP/datasets/Fr5_intertek_dataset/Fr5_intertek_2nd_250526/match

Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b2 and are newly initialized: ['decode_head.batch_norm.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.batch_norm.running_mean', 'decode_head.batch_norm.running_var', 'decode_head.batch_norm.weight', 'decode_head.classifier.bias', 'decode_head.classifier.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.linear_c.0.proj.weight', 'decode_head.linear_c.1.proj.bias', 'decode_head.linear_c.1.proj.weight', 'decode_head.linear_c.2.proj.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.3.proj.weight', 'decode_head.linear_fuse.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  checkpoint = torch.load(segformer_model_path, map_location=self.device)


📊 총 파라미터: 27,348,162
📊 훈련 가능 파라미터: 27,348,162
🏗️ SegFormer 모델 로딩: nvidia/mit-b2


Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b2 and are newly initialized: ['decode_head.batch_norm.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.batch_norm.running_mean', 'decode_head.batch_norm.running_var', 'decode_head.batch_norm.weight', 'decode_head.classifier.bias', 'decode_head.classifier.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.linear_c.0.proj.weight', 'decode_head.linear_c.1.proj.bias', 'decode_head.linear_c.1.proj.weight', 'decode_head.linear_c.2.proj.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.3.proj.weight', 'decode_head.linear_fuse.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  checkpoint = torch.load(segformer_model_path, map_location=self.device)


📊 총 파라미터: 27,348,162
📊 훈련 가능 파라미터: 27,348,162
Loaded pretrained ViT model: vit_base_patch16_224
Starting training for 50 epochs...


Epoch 1/50 [Train]:   0%|          | 0/1959 [00:00<?, ?it/s]Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/najo/.conda/envs/dip/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/home/najo/.conda/envs/dip/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'RobotArmSegFKDataset' on <module '__main__' (built-in)>
