In [1]:
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader
from torchvision.ops import box_iou

from MaskDataset import MaskDataset
from SEPN import ResNet18Backbone, SEPN
from DY_PS import DySample, PSConv
from hybrid_encoder import HybridEncoderBlock
from Hungrian_match import HungarianMatcher, RTDETRDetection, SetCriterion

# -----------------------------
# 1. collate_fn (stack images, keep per-image targets as a list)
# -----------------------------
def collate_fn(batch):
    """Collate a list of ``(image, target)`` samples into one batch.

    Images are stacked along a new leading batch dimension. Targets are
    returned as a plain list in sample order rather than stacked. Presumably
    this is because the number of boxes differs per image.

    Returns:
        (images, targets): the stacked image tensor and the list of targets.
    """
    images = torch.stack(tuple(sample[0] for sample in batch))
    per_image_targets = [sample[1] for sample in batch]
    return images, per_image_targets


# -----------------------------
# 2. Collect image paths
# -----------------------------
def get_image_paths(img_dir, exts=(".jpg", ".png", ".jpeg", ".bmp")):
    """Return the full paths of image files located directly inside ``img_dir``.

    The result is sorted by file name. ``os.listdir`` returns entries in an
    arbitrary, filesystem-dependent order. Sorting makes a later seeded
    shuffle (and therefore the train/val split) reproducible across machines.

    Args:
        img_dir: directory to scan (not recursive).
        exts: a single extension suffix, or an iterable of suffixes, to keep.
            Matching is case-insensitive on both the file names and ``exts``.

    Returns:
        list[str]: ``os.path.join(img_dir, name)`` for every matching entry,
        in sorted name order.
    """
    if isinstance(exts, str):
        exts = (exts,)
    # str.endswith needs a str or tuple. Lower-case the suffixes so that e.g.
    # ".JPG" in exts still matches the lower-cased file name "a.jpg".
    suffixes = tuple(ext.lower() for ext in exts)
    return [
        os.path.join(img_dir, name)
        for name in sorted(os.listdir(img_dir))
        if name.lower().endswith(suffixes)
    ]


# Dataset locations. They can be overridden with the JFST_IMG_DIR /
# JFST_MASK_DIR environment variables; the defaults are the original paths.
img_dir = os.environ.get(
    "JFST_IMG_DIR", "/data2/project/2025summer/yjh0913/DB-1/Images"
)   # image directory
mask_dir = os.environ.get(
    "JFST_MASK_DIR", "/data2/project/2025summer/yjh0913/DB-1/Masks"
)   # mask directory

all_img_paths = get_image_paths(img_dir)
# Fail early with a clear message instead of an obscure error from an empty
# split / DataLoader further down.
if not all_img_paths:
    raise FileNotFoundError(f"No images found in {img_dir!r}")
print(f"Total images: {len(all_img_paths)}")


# -----------------------------
# 3. 8:2 split (80% train / 20% validation)
# -----------------------------
# Shuffle with a fixed seed so the same input ordering always yields the
# same 80/20 split.
random.seed(42)
random.shuffle(all_img_paths)

# First 80% -> training, remaining 20% -> validation (disjoint from train).
split_idx = int(len(all_img_paths) * 0.8)
train_img_paths, val_img_paths = all_img_paths[:split_idx], all_img_paths[split_idx:]

print(f"Train images (80%): {len(train_img_paths)}")
print(f"Val images (20%): {len(val_img_paths)}")


# -----------------------------
# 4. Train Dataset (80% split)
# -----------------------------
# Training dataset over the 80% split. img_size=640 is presumably the
# resize target used by MaskDataset (TODO confirm there).
train_dataset = MaskDataset(img_paths=train_img_paths, mask_dir=mask_dir, img_size=640)

train_loader = DataLoader(
    dataset=train_dataset,
    collate_fn=collate_fn,   # stack images, keep targets as a per-image list
    batch_size=1,            # 2 → 1 (memory saving)
    shuffle=True,
    num_workers=0,           # multiprocessing disabled
)


Total images: 1195
Train images (80%): 956
Val images (20%): 239


NameError: name 'MaskDataset' is not defined

In [None]:
from Hungrian_match import RTDETRDetection

class JFSTDETR(nn.Module):
    """DETR-style detector assembled from the project modules.

    Pipeline, as implemented in ``forward``:
    ResNet18Backbone -> SEPN -> 1x1 projection of three scales to
    ``hidden_dim`` -> shared DySample refinement + 1x1 channel restore ->
    HybridEncoderBlock -> flattened multi-scale memory -> RTDETRDetection.

    Args:
        num_classes: number of object classes, forwarded to the detection head.
        hidden_dim: common channel width of all scales after projection.
        num_queries: number of object queries, forwarded to the detection head.
    """

    def __init__(
        self,
        num_classes=1,
        hidden_dim=256,
        num_queries=100
    ):
        super().__init__()

        # Backbone (its forward output is unpacked into four maps p2..p5).
        self.backbone = ResNet18Backbone()

        # SEPN (fuses p2..p5 into three maps p3, p4, p5).
        self.sepn = SEPN()

        # Channel alignment to hidden_dim. The input widths 128/256/512
        # presumably match SEPN's p3/p4/p5 outputs — TODO confirm.
        self.proj3 = nn.Conv2d(128, hidden_dim, 1)
        self.proj4 = nn.Conv2d(256, hidden_dim, 1)
        self.proj5 = nn.Conv2d(512, hidden_dim, 1)

        # DySample refinement. NOTE: a single instance (shared weights) is
        # applied to all three scales. PSConv is imported but not used here.
        self.refine = DySample(hidden_dim)

        # Channel restore after DySample. Assumes DySample's pixel shuffle
        # leaves hidden_dim // 4 channels — TODO confirm against DY_PS.
        self.restore3 = nn.Conv2d(hidden_dim // 4, hidden_dim, 1)
        self.restore4 = nn.Conv2d(hidden_dim // 4, hidden_dim, 1)
        self.restore5 = nn.Conv2d(hidden_dim // 4, hidden_dim, 1)

        # Hybrid Encoder over the three refined scales.
        self.encoder = HybridEncoderBlock(hidden_dim)

        # Detection head consuming the concatenated token memory.
        self.detector = RTDETRDetection(
            hidden_dim=hidden_dim,
            num_classes=num_classes,
            num_queries=num_queries
        )

    def forward(self, x):
        """Run the detector on an image batch and return the head's output.

        Presumably ``x`` is (B, 3, H, W), and the result is a dict with
        "pred_logits" and "pred_boxes" entries — TODO confirm in RTDETRDetection.
        """
        # Backbone
        p2, p3, p4, p5 = self.backbone(x)

        # SEPN: four levels in, three levels out
        p3, p4, p5 = self.sepn(p2, p3, p4, p5)

        # Channel align
        p3 = self.proj3(p3)
        p4 = self.proj4(p4)
        p5 = self.proj5(p5)

        # Shared DySample, then per-scale channel restore
        p3 = self.restore3(self.refine(p3))
        p4 = self.restore4(self.refine(p4))
        p5 = self.restore5(self.refine(p5))

        # Hybrid Encoder
        p3, p4, p5 = self.encoder(p3, p4, p5)

        # Flatten (RT-DETR style): each map (assumed B, C, H, W) becomes
        # (B, H*W, C) tokens; the three scales are concatenated along the
        # token axis.
        memory = torch.cat(
            [feat.flatten(2).permute(0, 2, 1) for feat in (p3, p4, p5)],
            dim=1,
        )

        # Detection
        return self.detector(memory)


In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Reduced model size for memory saving: hidden_dim 256 → 128,
# num_queries 100 → 50.
model = JFSTDETR(num_classes=1, hidden_dim=128, num_queries=50).to(device)

# Matching costs for the class, box and GIoU terms.
matcher = HungarianMatcher(cost_class=1, cost_bbox=5, cost_giou=2)

criterion = SetCriterion(
    num_classes=1,
    matcher=matcher,
    weight_dict=dict(loss_ce=2, loss_bbox=5, loss_giou=2),
)

# Learning rate lowered (1e-4 → 1e-5).
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=1e-4)


In [None]:
import os

save_dir = "./checkpoints"
os.makedirs(save_dir, exist_ok=True)

best_loss = float("inf")
num_epochs = 100

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    # Per-epoch sums of the logged loss terms. Previously the CE/BBox/GIoU
    # columns showed only the LAST batch's loss_dict while "Loss" was the
    # epoch average; averaging both keeps the log line consistent.
    component_sums = {"loss_ce": 0.0, "loss_bbox": 0.0, "loss_giou": 0.0}
    num_batches = 0

    for imgs, targets in train_loader:
        imgs = imgs.to(device)
        # Keep only boxes and labels from each target and move them to device.
        targets = [
            {
                "boxes": t["boxes"].to(device),
                "labels": t["labels"].to(device)
            }
            for t in targets
        ]

        outputs = model(imgs)
        loss, loss_dict = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        for key in component_sums:
            component_sums[key] += float(loss_dict[key])
        num_batches += 1

    # Count batches actually seen (works for any loader) and fail clearly
    # instead of dividing by zero.
    if num_batches == 0:
        raise RuntimeError("train_loader produced no batches; check the training split")

    avg_loss = total_loss / num_batches
    avg_components = {key: value / num_batches for key, value in component_sums.items()}

    print(
        f"[Epoch {epoch:03d}] "
        f"Loss: {avg_loss:.4f} | "
        f"CE: {avg_components['loss_ce']:.3f}, "
        f"BBox: {avg_components['loss_bbox']:.3f}, "
        f"GIoU: {avg_components['loss_giou']:.3f}"
    )

    # Save the best weights.
    # NOTE(review): "best" is selected by average TRAINING loss, not by a
    # validation metric — consider evaluating on the val split here instead.
    if avg_loss < best_loss:
        best_loss = avg_loss
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "loss": best_loss
            },
            os.path.join(save_dir, "jfst_detr_best.pth")
        )
        print(f"‚úÖ Saved best model (loss={best_loss:.4f})")


‚úÖ Model / DataLoader / Criterion / Optimizer confirmed
[Epoch 000] Loss: 10.4443 | CE: 0.693 | BBox: 0.887 | GIoU: 1.936 | Skipped batches: 0 | NaN batches: 0
‚úÖ Saved best model (loss=10.4443)
[Epoch 001] Loss: 10.4665 | CE: 0.693 | BBox: 1.005 | GIoU: 1.980 | Skipped batches: 0 | NaN batches: 0
[Epoch 002] Loss: 10.4735 | CE: 0.693 | BBox: 0.971 | GIoU: 1.947 | Skipped batches: 0 | NaN batches: 0
[Epoch 003] Loss: 10.4453 | CE: 0.693 | BBox: 1.143 | GIoU: 1.966 | Skipped batches: 0 | NaN batches: 0
[Epoch 004] Loss: 10.4507 | CE: 0.693 | BBox: 0.767 | GIoU: 1.985 | Skipped batches: 0 | NaN batches: 0
[Epoch 005] Loss: 10.4531 | CE: 0.693 | BBox: 1.179 | GIoU: 1.860 | Skipped batches: 0 | NaN batches: 0


In [None]:
import torch
import torch.nn.functional as F
import numpy as np
from torchvision.ops import box_iou

# ===============================
# box utils
# ===============================
def cxcywh_to_xyxy(boxes):
    """Convert boxes from (cx, cy, w, h) to corner form (x1, y1, x2, y2).

    The last dimension of ``boxes`` must have size 4; any leading dimensions
    are preserved.
    """
    cx, cy, w, h = boxes.unbind(dim=-1)
    half_w, half_h = w / 2, h / 2
    corners = (cx - half_w, cy - half_h, cx + half_w, cy + half_h)
    return torch.stack(corners, dim=-1)

# ===============================
# postprocess (DETR style)
# ===============================
def postprocess(pred_logits, pred_boxes, score_thresh=0.5):
    """Keep predictions whose best foreground probability exceeds a threshold.

    Class probabilities come from a softmax over the last dimension. The last
    class is treated as background and excluded before taking the max.

    Returns:
        (boxes, scores): the kept ``pred_boxes`` rows and their scores.
    """
    foreground_probs = pred_logits.softmax(dim=-1)[..., :-1]
    scores = foreground_probs.max(dim=-1).values
    mask = scores > score_thresh
    return pred_boxes[mask], scores[mask]

# ===============================
# TP / FP / FN
# ===============================
def calc_tp_fp_fn(pred_boxes, gt_boxes, iou_thresh=0.5):
    """Count true positives, false positives and false negatives for one image.

    Predictions are visited in the given order. Each prediction is matched
    to the still-unmatched GT box with the highest IoU, and it counts as a TP
    when that IoU is >= ``iou_thresh``. Each GT box can be matched at most once.

    Args:
        pred_boxes: (N, 4) predicted boxes in xyxy form.
        gt_boxes: (M, 4) ground-truth boxes in xyxy form.
        iou_thresh: minimum IoU for a match.

    Returns:
        (tp, fp, fn) as Python ints.
    """
    num_pred, num_gt = len(pred_boxes), len(gt_boxes)
    if num_pred == 0:
        return 0, 0, num_gt
    # Without GT every prediction is a false positive. This case is handled
    # explicitly because .max() over an empty IoU row would raise.
    if num_gt == 0:
        return 0, num_pred, 0

    ious = box_iou(pred_boxes, gt_boxes)
    gt_taken = torch.zeros(num_gt, dtype=torch.bool, device=ious.device)
    tp = 0

    for i in range(num_pred):
        # Hide already-matched GTs so a prediction can still claim another
        # overlapping GT instead of being counted as FP.
        best_iou, best_gt = ious[i].masked_fill(gt_taken, -1.0).max(dim=0)
        if best_iou >= iou_thresh:
            tp += 1
            gt_taken[best_gt] = True

    return tp, num_pred - tp, num_gt - tp

# ===============================
# AP (mAP@0.5)
# ===============================
def compute_ap(recalls, precisions):
    """Area under a precision/recall curve (VOC-style all-point interpolation).

    ``recalls`` must be in ascending order, with ``precisions`` aligned to it.
    The curve is padded with (0, 0) at the start and (1, 0) at the end.
    Precision is replaced by its running maximum from the right (the
    monotone envelope), and the area is summed over every recall step.
    """
    mrec = np.concatenate(([0.], recalls, [1.]))
    mpre = np.concatenate(([0.], precisions, [0.]))

    # Suffix maximum: each point takes the best precision at any higher recall.
    mpre = np.maximum.accumulate(mpre[::-1])[::-1]

    steps = np.flatnonzero(mrec[1:] != mrec[:-1])
    return np.sum((mrec[steps + 1] - mrec[steps]) * mpre[steps + 1])

# ===============================
# evaluation
# ===============================
def _rank_and_match(pred_xyxy, scores, gt_xyxy, iou_thresh):
    """Greedy score-ordered matching of predictions to GT boxes for one image.

    Predictions are sorted by descending score. Each one claims the still
    unmatched GT with the highest IoU if that IoU is >= ``iou_thresh``.

    Returns:
        (sorted_scores, hits): scores in descending order, and a bool tensor
        marking which of those predictions are true positives.
    """
    order = torch.argsort(scores, descending=True)
    sorted_scores = scores[order]
    hits = torch.zeros(len(order), dtype=torch.bool)
    if len(order) == 0 or len(gt_xyxy) == 0:
        return sorted_scores, hits

    ious = box_iou(pred_xyxy[order], gt_xyxy)
    gt_taken = torch.zeros(len(gt_xyxy), dtype=torch.bool)
    for k in range(len(order)):
        best_iou, best_gt = ious[k].masked_fill(gt_taken, -1.0).max(dim=0)
        if best_iou >= iou_thresh:
            hits[k] = True
            gt_taken[best_gt] = True
    return sorted_scores, hits


def _average_precision(scores, hits, num_gt):
    """AP from dataset-wide detection scores / TP flags, via ``compute_ap``."""
    if num_gt == 0:
        return 0.0
    order = np.argsort(-scores, kind="stable")
    hits = hits[order]
    tp_cum = np.cumsum(hits)
    fp_cum = np.cumsum(~hits)
    recalls = tp_cum / num_gt
    # tp_cum + fp_cum is the 1-based rank, so this never divides by zero.
    precisions = tp_cum / (tp_cum + fp_cum)
    return compute_ap(recalls, precisions)


@torch.no_grad()
def evaluate(model, dataloader, device, score_thresh=0.5, iou_thresh=0.5):
    """Evaluate a single-class detector on ``dataloader``.

    Precision, recall and F1 are computed from the detections whose score is
    > ``score_thresh``, summed over the whole dataset. ``map50`` is the AP at
    IoU ``iou_thresh``. It is computed over ALL detections ranked by score
    across the dataset, not from per-image precision/recall pairs, which do
    not form a PR curve.

    Returns:
        (precision, recall, f1, map50)
    """
    model.eval()
    device_type = torch.device(device).type
    use_cuda = device_type == "cuda"

    total_tp = total_fp = total_fn = 0
    num_gt_total = 0
    all_scores, all_hits = [], []

    for imgs, targets in dataloader:
        imgs = imgs.to(device)

        # Mixed precision only on CUDA; on CPU run in full precision.
        with torch.autocast(device_type=device_type, enabled=use_cuda):
            outputs = model(imgs)

        # Move only what is needed to CPU, in float32 so IoU math is not fp16.
        pred_logits = outputs["pred_logits"].float().cpu()
        pred_boxes = outputs["pred_boxes"].float().cpu()

        # Release GPU memory right away (original memory-saving behaviour).
        del outputs, imgs
        if use_cuda:
            torch.cuda.empty_cache()

        for i in range(len(pred_logits)):
            gt = cxcywh_to_xyxy(targets[i]["boxes"].float().cpu().reshape(-1, 4))

            # Keep every query (threshold -inf) so AP sees the full ranking.
            boxes_i, scores_i = postprocess(
                pred_logits[i], pred_boxes[i], score_thresh=float("-inf")
            )
            ranked_scores, hits = _rank_and_match(
                cxcywh_to_xyxy(boxes_i), scores_i, gt, iou_thresh
            )

            # Matching is greedy in score order, so the matches of the
            # thresholded subset (a score prefix) are exactly these hits.
            kept = ranked_scores > score_thresh
            tp = int(hits[kept].sum())
            total_tp += tp
            total_fp += int(kept.sum()) - tp
            total_fn += len(gt) - tp

            all_scores.append(ranked_scores.numpy())
            all_hits.append(hits.numpy())
            num_gt_total += len(gt)

    precision = total_tp / (total_tp + total_fp + 1e-6)
    recall    = total_tp / (total_tp + total_fn + 1e-6)
    f1        = 2 * precision * recall / (precision + recall + 1e-6)

    scores_np = np.concatenate(all_scores) if all_scores else np.zeros(0)
    hits_np = np.concatenate(all_hits) if all_hits else np.zeros(0, dtype=bool)
    map50 = _average_precision(scores_np, hits_np, num_gt_total)

    return precision, recall, f1, map50


# ===============================
# RUN (load weight ‚Üí test 20%)
# ===============================
device = "cuda" if torch.cuda.is_available() else "cpu"

# Rebuild the network with the same reduced settings used for training
# (hidden_dim 256 → 128, num_queries 100 → 50) so the checkpoint fits.
model = JFSTDETR(num_classes=1, hidden_dim=128, num_queries=50).to(device)

# Load the best checkpoint written by the training loop.
ckpt = torch.load("./checkpoints/jfst_detr_best.pth", map_location=device)
model.load_state_dict(ckpt["model_state_dict"])

# Build the DataLoader from val_img_paths (the held-out 20%), in fixed order.
val_dataset = MaskDataset(img_paths=val_img_paths, mask_dir=mask_dir, img_size=640)
val_loader = DataLoader(
    dataset=val_dataset,
    collate_fn=collate_fn,
    batch_size=1,   # 2 → 1 (memory saving)
    shuffle=False,
    num_workers=0,  # multiprocessing disabled
)

precision, recall, f1, map50 = evaluate(
    model, val_loader, device, score_thresh=0.5, iou_thresh=0.5
)

print(f"Precision : {precision:.4f}")
print(f"Recall    : {recall:.4f}")
print(f"F1-score  : {f1:.4f}")
print(f"mAP@0.5   : {map50:.4f}")
