In [1]:
import transformers
import torch
import torchvision
import random
import matplotlib.pyplot as plt
import os

from transformers import DetrImageProcessor, DetrForObjectDetection, Trainer, TrainingArguments
from torchvision.datasets import CocoDetection
from torch.utils.data import DataLoader, Dataset

from matplotlib import cm
from PIL import Image


  from .autonotebook import tqdm as notebook_tqdm


Setup

In [2]:
# Pick the fastest available backend: CUDA first, then Apple-silicon MPS
# (Metal Performance Shaders), and finally plain CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print("Using device:", device)

# UI-element classes used by this notebook. Defined before the model is loaded
# so the same mapping can be handed to the model config.
id2label = {0: "Button", 1: "CheckBox",
            2: "ComboBox", 3: "Heading",
            4: "Image", 5: "Label",
            6: "Link", 7: "Paragraph",
            8: "RadioButton", 9: "TextBox"}
label2id = {v: k for k, v in id2label.items()}

# Load the pre-trained DETR model and processor
# Note: The model is downloaded from Hugging Face Hub, so make sure you have internet access
#
# The checkpoint ships with a COCO classification head. Passing our own
# id2label/label2id makes the config describe the 10 UI classes, and
# ignore_mismatched_sizes=True lets the differently-sized head be re-initialised
# instead of failing on the shape mismatch. The head must be trained before
# this model produces meaningful class predictions.
model = DetrForObjectDetection.from_pretrained(
    "facebook/detr-resnet-50",
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
).to(device)
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")

Using device: mps


Some weights of the model checkpoint at facebook/detr-resnet-50 were not used when initializing DetrForObjectDetection: ['model.backbone.conv_encoder.model.layer1.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer2.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer3.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer4.0.downsample.1.num_batches_tracked']
- This IS expected if you are initializing DetrForObjectDetection from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DetrForObjectDetection from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Data

In [4]:
class DetrCocoDataset(Dataset):
    """COCO-format detection dataset that yields raw images plus COCO-style targets.

    Each item is ``{"image": <image>, "target": {"image_id": ..., "annotations": [...]}}``,
    i.e. the structure that ``collate_fn`` forwards to the DETR image processor.

    Args:
        root: Directory containing the images.
        annFile: Path to the COCO annotation JSON file.
        processor: Image processor; stored on the instance but not used by the
            dataset itself (batch encoding happens in ``collate_fn``).
    """

    def __init__(self, root, annFile, processor: "DetrImageProcessor"):
        self.ds = CocoDetection(root, annFile)
        self.processor = processor

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        image, raw_anns = self.ds[idx]

        # 1) Normalise category ids: map any string label -> int via label2id.
        #    Work on shallow copies so the annotation dicts held by the
        #    underlying dataset are never modified.
        annotations = []
        for raw_ann in raw_anns:
            ann = dict(raw_ann)
            if isinstance(ann["category_id"], str):
                ann["category_id"] = label2id[ann["category_id"]]
            annotations.append(ann)

        # 2) Build the full COCO dict. The image id is taken from the dataset's
        #    own id list rather than from the first annotation, so images
        #    without any annotations no longer raise IndexError.
        target = {
            "image_id": self.ds.ids[idx],
            "annotations": annotations,
        }

        return {"image": image, "target": target}

def collate_fn(batch):
    """Collate dataset samples into a single processor encoding.

    Args:
        batch: List of dicts, each with an ``"image"`` entry and a ``"target"``
            entry (the full COCO-style dict for that image).

    Returns:
        Whatever the module-level ``processor`` returns when called with the
        images, their annotations and ``return_tensors="pt"``.
    """
    images, targets = [], []
    for sample in batch:
        images.append(sample["image"])
        targets.append(sample["target"])  # full COCO dicts

    return processor(images=images, annotations=targets, return_tensors="pt")

# Build the train and validation splits (in that order) from the sketch2code export.
_S2C_ROOT = "../data/sketch2code-data"
train_dataset, val_dataset = (
    DetrCocoDataset(
        f"{_S2C_ROOT}/{split}",
        f"{_S2C_ROOT}/annotations/instances_{split}.json",
        processor,
    )
    for split in ("train", "val")
)

loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!


Training Code

In [7]:
# NOTE: Trainer needs the `accelerate` package (>=0.26.0) to be installed;
# without it, constructing TrainingArguments/Trainer raises ImportError.

TRAIN_BATCH_SIZE = 4
EVAL_BATCH_SIZE = 4

# Optimizer steps per epoch (ceil division). Assumes a single device and no
# gradient accumulation, which matches the arguments below.
steps_per_epoch = -(-len(train_dataset) // TRAIN_BATCH_SIZE)

training_args = TrainingArguments(
    output_dir="./checkpoints", # If you use colab, set this to "/content/checkpoints"

    # batches
    per_device_train_batch_size=TRAIN_BATCH_SIZE,
    per_device_eval_batch_size=EVAL_BATCH_SIZE,

    # epochs
    num_train_epochs=100, # personal experimentation shows sketch2code data converges in 70-100 epochs

    dataloader_num_workers=2,  # Or try 4
    # Pinned host memory only helps host->GPU copies on CUDA; leave it off on MPS/CPU.
    dataloader_pin_memory=device.type == "cuda",

    # evaluation & checkpointing (I use it with wandb.ai to track training) then use ./checkpoints for outputdir
    do_eval=True,    # run validation periodically
    eval_strategy="epoch",  # evaluate at the end of every epoch
    save_strategy="epoch",  # save a checkpoint at the end of every epoch
    save_total_limit=6,
    load_best_model_at_end=True,    # reload the best checkpoint when training finishes

    # logging
    logging_steps=max(steps_per_epoch // 2, 1),  # twice per epoch

    # optimizer
    learning_rate=5e-5,             # slightly lower LR for small data
    weight_decay=1e-4,

    # mixed precision
    # fp16=device.type == "cuda",  # use fp16 if available
    # bf16=True, # Comment this line if you are not using MPS (MacOS)

    # pass through all fields
    remove_unused_columns=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=collate_fn,  # pack up our pixel_values + labels
)  # Trainer accepts any PyTorch dataset that returns dicts

# Training is intentionally left disabled; uncomment to fine-tune.
# trainer.train()

ImportError: Using the `Trainer` with `PyTorch` requires `accelerate>=0.26.0`: Please run `pip install transformers[torch]` or `pip install 'accelerate>=0.26.0'`

Testing (Loss) ~ 0.9679

In [16]:
from pycocotools.coco import COCO

# Locations and sizes used by the held-out test evaluation
TEST_IMAGE_DIR = "../data/sketch2code-data/test"
TEST_ANN_FILE = "../data/sketch2code-data/annotations/instances_test.json"
TEST_BATCH_SIZE = 8

# Fine-tuned weights saved by the training run
best_model = DetrForObjectDetection.from_pretrained("./checkpoints/checkpoint-3172")
best_model = best_model.to(device)

# Raw COCO annotation index for the test split
coco_test = COCO(TEST_ANN_FILE)

# Test split wrapped the same way as train/val: (image root, annotation file, processor)
test_dataset = DetrCocoDataset(TEST_IMAGE_DIR, TEST_ANN_FILE, processor)

# Deterministic (unshuffled) loader; later cells rely on this ordering
test_loader = DataLoader(
    test_dataset,
    batch_size=TEST_BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=0,
)

# Evaluation-only Trainer configuration
eval_args = TrainingArguments(
    output_dir="./test_results",
    per_device_eval_batch_size=TEST_BATCH_SIZE,
    remove_unused_columns=False,
    fp16=False,  # Disable for MPS
    report_to="none",
    do_train=False,
    do_eval=True,
)

tester = Trainer(
    model=best_model,
    args=eval_args,
    eval_dataset=test_dataset,
    data_collator=collate_fn,
)

# Metrics are reported under the "test_" key prefix
test_results = tester.evaluate(eval_dataset=test_dataset, metric_key_prefix="test")

print("Test results:", test_results)

loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!


Test results: {'test_loss': 0.9679349660873413, 'test_model_preparation_time': 0.0015, 'test_runtime': 13.6349, 'test_samples_per_second': 1.687, 'test_steps_per_second': 0.22}


Helpers (for visualization)

In [17]:
def box_xywh_to_xyxy(boxes: torch.Tensor) -> torch.Tensor:
    """Convert boxes from ``[x, y, w, h]`` to ``[x_min, y_min, x_max, y_max]``.

    Args:
        boxes: Tensor whose last dimension holds ``x, y, w, h``. A 1-D tensor is
            treated as a single box and promoted to a batch of one.

    Returns:
        Tensor of corner-format boxes. When the (promoted) input has no rows,
        an empty ``(0, 4)`` tensor with the input's dtype and device is returned.
    """
    batched = boxes if boxes.ndim != 1 else boxes.unsqueeze(0)

    if batched.shape[0] == 0:
        # Nothing to convert; hand back an empty, correctly typed result.
        return torch.empty((0, 4), dtype=batched.dtype, device=batched.device)

    left, top, width, height = batched.unbind(-1)
    right = left + width
    bottom = top + height
    return torch.stack((left, top, right, bottom), dim=-1)

def iou_match(pred_boxes_cls, gt_boxes_cls, iou_threshold, pred_scores=None):
    """Greedily match predicted boxes to ground-truth boxes one-to-one by IoU.

    Predictions are visited in order; each one claims the still-unmatched
    ground-truth box with which it has the highest IoU, provided that IoU is at
    least ``iou_threshold``.

    Args:
        pred_boxes_cls: ``(P, 4)`` tensor of predicted boxes in xyxy format.
        gt_boxes_cls: ``(G, 4)`` tensor of ground-truth boxes in xyxy format.
        iou_threshold: Minimum IoU for a prediction/ground-truth pair to match.
        pred_scores: Optional sequence or 1-D tensor of ``P`` confidence scores.
            When given, predictions are visited from highest to lowest score;
            when omitted, they are visited in index order (original behaviour).

    Returns:
        ``(matches, fp_indices, fn_indices)`` where ``matches`` is a list of
        ``(pred_idx, gt_idx)`` pairs, ``fp_indices`` lists unmatched prediction
        indices and ``fn_indices`` lists unmatched ground-truth indices. All
        indices refer to rows of the input tensors.

    Raises:
        ValueError: If ``pred_scores`` is given but its length differs from ``P``.
    """
    num_preds = pred_boxes_cls.shape[0]
    num_gts = gt_boxes_cls.shape[0]

    if num_preds == 0:
        return [], [], list(range(num_gts))
    if num_gts == 0:
        return [], list(range(num_preds)), []

    # Move the whole IoU matrix to plain Python floats once. Comparing tensor
    # elements one at a time would force a device sync per element on GPU/MPS.
    iou_rows = torchvision.ops.box_iou(pred_boxes_cls, gt_boxes_cls).tolist()

    if pred_scores is None:
        pred_order = range(num_preds)
    else:
        scores = torch.as_tensor(pred_scores).flatten().tolist()
        if len(scores) != num_preds:
            raise ValueError(
                f"pred_scores has {len(scores)} entries but there are {num_preds} predictions"
            )
        # Stable sort: equal scores keep their original relative order.
        pred_order = sorted(range(num_preds), key=lambda i: scores[i], reverse=True)

    matches = []
    matched_gts = set()

    for pred_idx in pred_order:
        # Best IoU among ground truths that are still available
        best_iou = -1.0
        best_gt_idx = -1
        for gt_idx, iou in enumerate(iou_rows[pred_idx]):
            if gt_idx not in matched_gts and iou > best_iou:
                best_iou = iou
                best_gt_idx = gt_idx

        # best_gt_idx stays -1 when every ground truth is already taken
        if best_gt_idx >= 0 and best_iou >= iou_threshold:
            matches.append((pred_idx, best_gt_idx))
            matched_gts.add(best_gt_idx)

    matched_preds = {pred_idx for pred_idx, _ in matches}
    fp_indices = [i for i in range(num_preds) if i not in matched_preds]
    fn_indices = [i for i in range(num_gts) if i not in matched_gts]

    return matches, fp_indices, fn_indices

In [18]:
## Configuration
best_model  = best_model.to(device).eval()  # inference mode for the evaluation below
threshold   = 0.5 # Detection confidence threshold
iou_thresh  = 0.5 # IoU threshold for matching
max_viz     = 8   # Maximum number of images to render
num_classes = len(id2label)
# tab20 has 20 colors, so cap the resampled size at 20 in case more classes are added.
# plt.get_cmap is used because matplotlib.cm.get_cmap is deprecated and emits a warning.
cmap        = plt.get_cmap("tab20", min(num_classes, 20))
label_colors= {cid: cmap(cid % cmap.N) for cid in range(num_classes)} # Use modulo for >20 classes

  cmap        = cm.get_cmap("tab20", num_classes if num_classes <= 20 else 20) # tab20 has 20 colors (incase we choose to have more classes)


In [19]:
# Metrics accumulators
total_tp_overall = 0
total_fp_overall = 0
total_fn_overall = 0
viz_candidates = []  # one record per image that has at least one FP or FN

with torch.no_grad():
    for batch_idx, batch_from_loader in enumerate(test_loader):
        # Only the image tensors are sent to the model; pixel_mask is optional.
        pixel_values_batch = batch_from_loader["pixel_values"].to(device)
        inputs = {"pixel_values": pixel_values_batch}
        pixel_mask_batch = batch_from_loader.get("pixel_mask")
        if pixel_mask_batch is not None:
            inputs["pixel_mask"] = pixel_mask_batch.to(device)

        outputs = best_model(**inputs)

        # test_loader is created with shuffle=False, so item i of batch b is
        # dataset index b * batch_size + i. Fetch each sample once and reuse it
        # for both the original image size and the ground-truth annotations.
        first_dataset_idx = batch_idx * test_loader.batch_size
        current_batch_size = pixel_values_batch.shape[0]
        samples_in_batch = [
            test_dataset[dataset_idx] if dataset_idx < len(test_dataset) else None
            for dataset_idx in range(first_dataset_idx, first_dataset_idx + current_batch_size)
        ]

        # Original (height, width) per image, so predicted boxes are scaled back
        # to the coordinate system of the ground-truth annotations.
        original_image_sizes_list = []
        for i_in_batch, sample_from_dataset in enumerate(samples_in_batch):
            if sample_from_dataset is not None:
                original_w, original_h = sample_from_dataset["image"].size  # PIL size is (width, height)
                original_image_sizes_list.append([original_h, original_w])
            else:
                # Not expected with an unshuffled loader over test_dataset.
                # Falls back to the resized dimensions, which is likely inaccurate.
                pv_h, pv_w = pixel_values_batch.shape[2:]
                original_image_sizes_list.append([pv_h, pv_w])
                print(f"Warning: Dataset index {first_dataset_idx + i_in_batch} out of bounds. Using fallback size for postprocessing.")

        target_sizes_for_postprocessing = torch.tensor(original_image_sizes_list, device=device)

        # Post-process with original target_sizes
        results_batch = processor.post_process_object_detection(
            outputs, target_sizes=target_sizes_for_postprocessing, threshold=threshold
        )

        # Per-image matching
        for det_per_image, sample_from_dataset in zip(results_batch, samples_in_batch):
            if sample_from_dataset is None:
                continue  # no ground truth available for an out-of-range index

            # Ground truth (in original image coordinates)
            ann = sample_from_dataset["target"]["annotations"]
            if not ann:  # image without any ground-truth annotations
                gt_boxes_xyxy = torch.empty((0, 4), device=device, dtype=torch.float)
                gt_labels = torch.empty((0,), device=device, dtype=torch.long)
            else:
                gt_boxes_xywh = torch.tensor([a["bbox"] for a in ann], device=device, dtype=torch.float)
                gt_boxes_xyxy = box_xywh_to_xyxy(gt_boxes_xywh)
                gt_labels = torch.tensor([a["category_id"] for a in ann], device=device, dtype=torch.long)

            # Predictions (scaled to the target sizes by post_process_object_detection)
            pred_boxes_xyxy = det_per_image["boxes"].to(device)
            pred_labels = det_per_image["labels"].to(device)

            # --- Match by IoU per class ---
            # Indices below refer to rows of pred_boxes_xyxy / gt_boxes_xyxy for this image.
            current_img_tp_pred_indices = []
            pred_matched_as_tp = [False] * pred_boxes_xyxy.shape[0]
            gt_matched_as_tp = [False] * gt_boxes_xyxy.shape[0]

            # Every class that appears in either the predictions or the ground truth
            unique_classes_on_img = torch.unique(torch.cat((gt_labels, pred_labels)))

            for cls_id in unique_classes_on_img:
                pred_indices_cls_mask = (pred_labels == cls_id)
                gt_indices_cls_mask = (gt_labels == cls_id)

                # Map positions within the per-class subsets back to per-image indices
                pred_original_indices_cls = pred_indices_cls_mask.nonzero(as_tuple=True)[0]
                gt_original_indices_cls = gt_indices_cls_mask.nonzero(as_tuple=True)[0]

                # matches_cls holds (index into class preds, index into class GTs) pairs
                matches_cls, _, _ = iou_match(
                    pred_boxes_xyxy[pred_indices_cls_mask],
                    gt_boxes_xyxy[gt_indices_cls_mask],
                    iou_thresh,
                )

                for pred_cls_idx, gt_cls_idx in matches_cls:
                    original_pred_idx = pred_original_indices_cls[pred_cls_idx].item()
                    original_gt_idx = gt_original_indices_cls[gt_cls_idx].item()

                    # Keep the matching strictly one-to-one across the whole image
                    if not pred_matched_as_tp[original_pred_idx] and not gt_matched_as_tp[original_gt_idx]:
                        current_img_tp_pred_indices.append(original_pred_idx)
                        pred_matched_as_tp[original_pred_idx] = True
                        gt_matched_as_tp[original_gt_idx] = True

            # False Positives: predictions not marked as TPs
            current_img_fp_pred_indices = [i for i, matched in enumerate(pred_matched_as_tp) if not matched]
            # False Negatives: ground truths not marked as TPs
            current_img_fn_gt_indices = [i for i, matched in enumerate(gt_matched_as_tp) if not matched]

            total_tp_overall += len(current_img_tp_pred_indices)
            total_fp_overall += len(current_img_fp_pred_indices)
            total_fn_overall += len(current_img_fn_gt_indices)

            # Save for visualization if any errors occurred
            if current_img_fp_pred_indices or current_img_fn_gt_indices:
                viz_candidates.append({
                    "img":         sample_from_dataset["image"],  # Original image
                    "pred_boxes":  pred_boxes_xyxy.cpu(),         # In original image coords
                    "pred_labels": pred_labels.cpu(),
                    "tp_idx":      current_img_tp_pred_indices,   # Indices of pred_boxes that are TPs
                    "fp_idx":      current_img_fp_pred_indices,   # Indices of pred_boxes that are FPs
                    "gt_boxes":    gt_boxes_xyxy.cpu(),           # Original GT boxes
                    "gt_labels":   gt_labels.cpu(),
                    "fn_idx":      current_img_fn_gt_indices,     # Indices of gt_boxes that are FNs
                })

# Overall precision & recall (the epsilon avoids division by zero)
precision = total_tp_overall / (total_tp_overall + total_fp_overall + 1e-8)
recall    = total_tp_overall / (total_tp_overall + total_fn_overall + 1e-8)
f1_score  = 2 * (precision * recall) / (precision + recall + 1e-8) if (precision + recall > 0) else 0.0

print(f"Precision={precision:.4f}, Recall={recall:.4f}, F1-Score={f1_score:.4f}")
print(f"TP: {total_tp_overall}, FP: {total_fp_overall}, FN: {total_fn_overall}")

Precision=0.7638, Recall=0.9174, F1-Score=0.8335
TP: 333, FP: 103, FN: 30


Visualization

In [20]:
def _draw_labeled_boxes(ax, boxes, labels, indices, tag, box_color, linestyle, text_color):
    """Outline the selected boxes on ``ax`` and caption each with its class name and ``tag``.

    Args:
        ax: Matplotlib axes to draw on.
        boxes: Tensor of xyxy boxes.
        labels: Tensor of class ids aligned with ``boxes``.
        indices: Row indices of ``boxes``/``labels`` to draw.
        tag: Short marker appended to the caption (e.g. ``"TP"``).
        box_color: Colour of the outline and of the caption background.
        linestyle: Matplotlib line style for the outline.
        text_color: Caption text colour.
    """
    for i in indices:
        cls_id = labels[i].item()
        label = id2label.get(cls_id, f"CLS {cls_id}")
        # Plain floats rather than 0-dim tensors for matplotlib
        x1, y1, x2, y2 = boxes[i].tolist()

        ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                   edgecolor=box_color, facecolor="none",
                                   linewidth=2, linestyle=linestyle))
        ax.text(x1, y1 - 3, f"{label} ({tag})",
                fontsize=7, color=text_color,
                bbox=dict(facecolor=box_color, alpha=0.6, pad=1, edgecolor='none'))


# Create output directory if it doesn't exist
viz_output_dir = "visualizations_output"
os.makedirs(viz_output_dir, exist_ok=True)
print(f"Saving visualizations to: {os.path.abspath(viz_output_dir)}")

if viz_candidates:
    samples_to_visualize = random.sample(viz_candidates, min(max_viz, len(viz_candidates)))
else:
    samples_to_visualize = []
    print("No candidates with errors to visualize.")

for idx, sample_data in enumerate(samples_to_visualize):
    img_pil = sample_data["img"].convert("RGB")

    fig, ax = plt.subplots(figsize=(10,10)) # Larger figure for clarity
    ax.imshow(img_pil)
    ax.axis("off")
    ax.set_title(f"Visualization {idx+1}")

    # True Positives: solid green outline, taken from the predicted boxes
    _draw_labeled_boxes(ax, sample_data["pred_boxes"], sample_data["pred_labels"],
                        sample_data["tp_idx"], "TP", "green", "-", "white")

    # False Positives: dashed red outline, taken from the predicted boxes
    _draw_labeled_boxes(ax, sample_data["pred_boxes"], sample_data["pred_labels"],
                        sample_data["fp_idx"], "FP", "red", "--", "white")

    # False Negatives: dotted orange outline, taken from the ground-truth boxes
    # (black caption text for contrast on orange)
    _draw_labeled_boxes(ax, sample_data["gt_boxes"], sample_data["gt_labels"],
                        sample_data["fn_idx"], "FN", "orange", ":", "black")

    # Figures are written to disk instead of shown inline
    fig.savefig(os.path.join(viz_output_dir, f"visualization_{idx}.png"), bbox_inches="tight")
    plt.close(fig) # Close the figure to free memory

Saving visualizations to: /Users/yunseolee/Documents/main/GitHub/Personal/dinov2-personalized-federated-learning/code/visualizations_output


Singleton Test on new sketches

In [None]:
from PIL import ImageDraw, ImageFont

# Load a separate dataset image
image = Image.open("./data/UIED-data/1243_1.jpg").convert('RGB')
inputs = processor(images=image, return_tensors="pt").to(device)

# Run the fine-tuned checkpoint (best_model), not the base `model`: unless
# training was run in this session, `model` has not learned the UI classes.
best_model.eval()
with torch.no_grad():
    outputs = best_model(**inputs)

# Get results, scaled back to the original (height, width)
target_sizes = torch.tensor([image.size[::-1]])
results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.40)[0]

# Convert once to plain Python values: (score, class name, [x0, y0, x1, y1])
detections = []
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
    cls_id = label.item()
    # .get() keeps an unexpected class id from raising KeyError
    detections.append((score.item(), id2label.get(cls_id, f"CLS {cls_id}"), box.tolist()))

# Print detected objects
for score, label_name, box in detections:
    print(f"Detected {label_name} with confidence {round(score, 3)} at location {box}")

# Text size and padding scale with the image dimensions; computed once for all boxes
img_width, img_height = image.size
scale_factor = min(img_width, img_height) / 500  # More aggressive scaling for larger text
padding = int(4 * scale_factor)
try:
    font = ImageFont.truetype("Arial.ttf", size=max(1, int(20 * scale_factor)))
except OSError:
    # Arial is not available on this system; use Pillow's built-in font
    font = ImageFont.load_default()

# Draw bounding boxes on a copy of the image.
# The "RGBA" drawing mode lets fills with an alpha value blend into the RGB image.
image_with_boxes = image.copy()
draw = ImageDraw.Draw(image_with_boxes, "RGBA")
for score, label_name, box in detections:
    draw.rectangle(box, outline="red", width=3)

    label_text = f"{label_name}: {round(score, 2)}"

    # Get text dimensions using font
    text_bbox = draw.textbbox((0, 0), label_text, font=font)
    text_width = text_bbox[2] - text_bbox[0]
    text_height = text_bbox[3] - text_bbox[1]

    # Semi-transparent black background just above the box's top-left corner
    text_box = [
        box[0],  # x0
        max(0, box[1] - text_height - padding * 2),  # y0
        box[0] + text_width + padding * 2,  # x1
        box[1]  # y1
    ]
    draw.rectangle(text_box, fill=(0, 0, 0, 180))

    # Draw text in white for better contrast
    draw.text(
        (box[0] + padding, max(0, box[1] - text_height - padding)),
        label_text,
        fill="white",
        font=font
    )

# Show the image with detections
plt.figure(figsize=(10, 8))
plt.imshow(image_with_boxes)
plt.axis("off")
plt.show()
