# main: few-shot OWL-ViT/OWLv2 defect-detection fine-tuning and evaluation



## Dependencies



In [None]:
from transformers import pipeline
from transformers import OwlViTProcessor, OwlViTForObjectDetection, AdamW
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection

import argparse
import os
import json
import glob
from tqdm import tqdm
import numpy as np
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
import torch
import torch.optim.lr_scheduler as lr_scheduler
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader
from fewshot_dataset import FewShotDetectionDataset, collate_fn
from utils import explore_dataset
from iou_utils import compute_iou, generalized_iou_loss, ciou_loss

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", default=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    parser.add_argument("--training_folder", default="./raw_dataset/DL-RX/train", type=str, required=False)
    parser.add_argument("--test_folder",     default="./raw_dataset/DL-RX/test",  type=str, required=False)

    parser.add_argument("--pretrained_checkpoint", default="google/owlv2-base-patch16-ensemble", type=str, required=False)
    parser.add_argument("--apply_finetuning", default=True, type=bool, required=False)

    parser.add_argument("--freezing_method", default="tiny", choices=["simple", "tiny"], type=str, required=False)
    parser.add_argument("--num_epochs", default=500, type=int, required=False)
    parser.add_argument("--batch_size", default=4,  type=int, required=False)
    parser.add_argument("--learning_rate", default=1e-4, type=float, required=False)
    parser.add_argument("--early_stopping_patience", default=10, type=int, required=False)
    parser.add_argument("--model_saving_dir", default="./finetuned_record", type=str, required=False)
    parser.add_argument('--result_folder', default='./results', type=str, required=False)
    return parser.parse_args(args=[]) # [] Only when using jupyter notebook

args = parse_args()
# Echo the effective configuration so the notebook output records what was run.
for arg_name, arg_value in vars(args).items():
    print(f"{arg_name} => {arg_value}")

# Make sure every output location exists before later cells write into it.
for output_dir in (args.model_saving_dir, args.result_folder):
    if os.path.exists(output_dir):
        continue
    os.makedirs(output_dir)
    print(f"New folder created. => [{os.path.abspath(output_dir)}]")

## Explore Dataset



In [None]:
# Optional dataset inspection. Flip to True to call `explore_dataset` on the training
# folder with verbose output and visualization enabled; off by default so Run-All stays fast.
RUN_DATASET_EXPLORATION = False
if RUN_DATASET_EXPLORATION:
    explore_dataset(dataset_folder=args.training_folder,
                    verbose=True,
                    visualize=True)

## Fine-tuning



### Model, Processor



In [None]:
# Load the pre-trained zero-shot detector and its matching processor.
model = AutoModelForZeroShotObjectDetection.from_pretrained(args.pretrained_checkpoint)
processor = AutoProcessor.from_pretrained(args.pretrained_checkpoint)
# No optimizer here: it is created in the fine-tuning cell *after* layers are frozen, so the
# `requires_grad` filter sees the final set of trainable parameters.

### Freezing layers



In [None]:
# Select which parameters are trained. Each branch assigns `requires_grad` for EVERY
# parameter, so re-running this cell (e.g. after switching methods) gives the same result
# regardless of what an earlier run froze.
if args.freezing_method == "simple":
    # Freeze any parameter whose name contains "vision" or "text" (the encoders);
    # everything else, e.g. the detection heads, stays trainable.
    for name, param in model.named_parameters():
        frozen = "vision" in name or "text" in name
        param.requires_grad = not frozen
        if not frozen:
            print(f"Target layer: {name}")

elif args.freezing_method == "tiny":
    # Train only parameters whose names contain one of these keywords; freeze the rest.
    allowed_keywords = [
        # "class_head.logit_scale",
        "box_head",
        "objectness_head",
    ]
    for name, param in model.named_parameters():
        trainable = any(keyword in name for keyword in allowed_keywords)
        param.requires_grad = trainable
        if trainable:
            print(f"Target layer: {name}")
else:
    raise ValueError(f"Invalid `freezing method`: {args.freezing_method!r}")

### DataLoader



In [None]:
dataset = FewShotDetectionDataset(dataset_folder=args.training_folder, processor=processor)
dataloader = DataLoader(
    dataset,
    batch_size=args.batch_size,
    shuffle=True,
    collate_fn=collate_fn,
)

# Draw one batch to sanity-check what collate_fn produces before training.
sample = next(iter(dataloader))
print(f"Sample keys: {sample.keys()}")
for shape_prefix, tensor_key in (
    ("Sample pixel_values: ", "pixel_values"),
    ("Sample input_ids: ", "input_ids"),
    ("Sample attention_mask : ", "attention_mask"),
):
    print(f"{shape_prefix}{sample[tensor_key].shape}")
print(f"Sample targets: {sample['targets']}")  # label and bounding box

### Fine-tuning Loop



In [None]:
# Loss weights, hoisted out of the batch loop. Localization dominates because
# classification is assumed to already be well optimized.
LAMBDA_CLS = 0.05
LAMBDA_BOX = 0.95
LAMBDA_CIOU = 0.0  # > 0 adds the CIoU term (previously disabled; an earlier setting was 0.55)

if args.apply_finetuning:
    model.to(args.device)
    saving_dir = args.model_saving_dir
    best_loss = float("inf")
    early_stopping_count = 0

    # One learning rate for all trainable parameters. Built after freezing, so the filter
    # only keeps parameters left trainable by the freezing cell.
    optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=args.learning_rate)
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=4, verbose=True)

    for epoch in range(args.num_epochs):
        model.train()
        total_loss = 0.0
        for batch in tqdm(dataloader):
            optimizer.zero_grad()

            # Move batch data to device
            pixel_values   = batch["pixel_values"].to(args.device)
            input_ids      = batch["input_ids"].to(args.device)
            attention_mask = batch["attention_mask"].to(args.device)

            outputs = model(
                pixel_values=pixel_values,
                input_ids=input_ids,
                attention_mask=attention_mask
            )

            # NOTE(review): index 0 on dim 1 picks a single fixed prediction slot per image
            # (presumably the first image patch), not a matched/best box — confirm intended.
            pred_boxes = outputs.pred_boxes[:, 0, :]  # shape: [batch_size, 4]
            pred_logits = outputs.logits[:, 0, :]     # shape: [batch_size, num_queries]

            # Ground truth: first box of each target.
            # NOTE(review): confirm the dataset's box format matches `pred_boxes`.
            gt_boxes = torch.stack([t["boxes"][0] for t in batch["targets"]]).to(args.device)  # [batch_size, 4]

            # Single class => ground truth label is always zero.
            gt_labels = torch.zeros(len(batch["targets"]), dtype=torch.long, device=args.device)

            # NOTE(review): if each image carries one text query, pred_logits has a single
            # column and cross-entropy over one class is identically 0 — confirm.
            loss_cls = F.cross_entropy(pred_logits, gt_labels)
            loss_box = F.smooth_l1_loss(pred_boxes, gt_boxes)

            loss = (LAMBDA_CLS * loss_cls) + (LAMBDA_BOX * loss_box)
            if LAMBDA_CIOU > 0:
                # Only computed when it actually contributes to the objective.
                loss = loss + LAMBDA_CIOU * ciou_loss(pred_boxes, gt_boxes)
            loss.backward()

            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / max(len(dataloader), 1)
        print(f"Epoch [{epoch + 1}/{args.num_epochs}], Average Loss: {avg_loss:.4f}")

        scheduler.step(avg_loss)

        # Save best model (lowest average training loss so far).
        if avg_loss < best_loss:
            early_stopping_count = 0
            best_loss = avg_loss
            print(f"Best model saved at epoch={epoch + 1}")
            model.save_pretrained(saving_dir)
            processor.save_pretrained(saving_dir)
        else:
            early_stopping_count += 1

        # `>=` so a patience of 0 (or any skipped count) still stops rather than running on.
        if early_stopping_count >= args.early_stopping_patience:
            print(f"\nFine-tuning early-stopped at epoch={epoch + 1}")
            break
        print()
    print("OWL-ViT few-shot finetuning finished.")
else:
    print("No finetuning")

## Test-set Inference and Evaluation



In [None]:
def _pair_images_with_annotations(image_paths, annotation_paths):
    """Pair each image with the annotation JSON that shares its file name.

    glob() returns files in arbitrary filesystem order, so zipping two independently globbed
    lists can pair an image with another image's annotation (or shift every pair when one
    file is missing). Annotations are matched on "<stem>.json" or "<image name>.json";
    images without one are skipped with a warning.

    Returns:
        List of (image_path, annotation_path) tuples, sorted by image path.
    """
    annotations_by_stem = {
        os.path.splitext(os.path.basename(path))[0]: path for path in annotation_paths
    }
    pairs = []
    for image_path in sorted(image_paths):
        image_name = os.path.basename(image_path)
        stem = os.path.splitext(image_name)[0]
        annotation_path = annotations_by_stem.get(stem) or annotations_by_stem.get(image_name)
        if annotation_path is None:
            print(f"[WARN] No annotation found for {image_name}; skipped.")
            continue
        pairs.append((image_path, annotation_path))
    return pairs


def _norm_box_to_pixels(box, width, height):
    """Scale a normalized 4-value box to integer pixels, with corners ordered as (min, min, max, max)."""
    x0, y0 = int(box[0] * width), int(box[1] * height)
    x1, y1 = int(box[2] * width), int(box[3] * height)
    return [min(x0, x1), min(y0, y1), max(x0, x1), max(y0, y1)]


test_folder = args.test_folder
test_image_paths = glob.glob(os.path.join(test_folder, '*.bmp'))
test_annotation_paths = glob.glob(os.path.join(test_folder, '*.json'))
test_pairs = _pair_images_with_annotations(test_image_paths, test_annotation_paths)

# Fine-tuned checkpoint or raw model
finetuned_checkpoint = args.model_saving_dir if args.apply_finetuning else args.pretrained_checkpoint
model = AutoModelForZeroShotObjectDetection.from_pretrained(finetuned_checkpoint)
processor = AutoProcessor.from_pretrained(finetuned_checkpoint)

device = args.device
model.to(device)
model.eval()

IOU_THRESHOLD = 0.5
query_label = 'industrial stabbed defect on metal bearing surface'

predictions = []   # {"image_id": ..., "score": ..., "box": ...} (Normalized coordinates, ROI)
ground_truths = {} # key: image_id, value: ground truth bbox (Normalized coordinates, ROI)

for test_image_path, test_annotation_path in test_pairs:
    image_id = os.path.basename(test_image_path)
    # Context manager releases the file handle once the RGB copy is made.
    with Image.open(test_image_path) as raw_image:
        full_image = raw_image.convert('RGB')

    with open(test_annotation_path, 'r', encoding='utf-8') as f:
        annotation = json.load(f)

    # First ROI, read as (x, y, width, height) in full-image pixels.
    roi_x, roi_y, roi_width, roi_height = annotation['rois'][0][:4]

    # Crop ROI
    cropped_image = full_image.crop((roi_x, roi_y, roi_x + roi_width, roi_y + roi_height))
    crop_width, crop_height = cropped_image.size

    # Ground truth: first shape's bbox dict {"x", "y", "width", "height"} -> corner form.
    gt_bbox = annotation['shapes'][0]['bbox']
    gt_full_box = [gt_bbox["x"],
                   gt_bbox["y"],
                   gt_bbox["x"] + gt_bbox["width"],
                   gt_bbox["y"] + gt_bbox["height"]]

    # Normalize ground truth bbox to ROI coordinates
    gt_roi_box = [(gt_full_box[0] - roi_x) / roi_width,
                  (gt_full_box[1] - roi_y) / roi_height,
                  (gt_full_box[2] - roi_x) / roi_width,
                  (gt_full_box[3] - roi_y) / roi_height]
    ground_truths[image_id] = gt_roi_box  # Single object

    # Inference on cropped image
    inputs = processor(text=[query_label], images=cropped_image, return_tensors='pt')
    pixel_values = inputs['pixel_values'].to(device)
    input_ids = inputs.get('input_ids')
    attention_mask = inputs.get('attention_mask')
    if input_ids is not None:
        input_ids = input_ids.to(device)
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)

    with torch.no_grad():
        outputs = model(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask)

    # NOTE(review): slot 0 on dim 1 mirrors the training cell's selection. Hugging Face
    # documents OWL-ViT/OWLv2 `pred_boxes` as normalized (center_x, center_y, width, height)
    # relative to the processed (for OWLv2, square-padded) image, while the ground truth
    # here is corner form relative to the crop. The comparison is meaningful only if
    # fine-tuning taught the head the ground-truth format — confirm.
    pred_box_norm = outputs.pred_boxes[0, 0, :].cpu().numpy()
    pred_logits = outputs.logits[0, 0, :].cpu()
    # Sigmoid, not softmax: with a single text query softmax is always exactly 1.0, which
    # made every score identical and the score-ranked AP ordering meaningless.
    score = torch.sigmoid(pred_logits[0]).item()

    predictions.append({
        "image_id": image_id,
        "score": score,
        "box": pred_box_norm
    })

    # Visualize: ground truth (green) and prediction (red) on the ROI image.
    vis_image = cropped_image.copy()
    draw = ImageDraw.Draw(vis_image)

    gt_abs_box = _norm_box_to_pixels(gt_roi_box, crop_width, crop_height)
    draw.rectangle(gt_abs_box, outline=(0, 255, 0), width=4)
    draw.text((gt_abs_box[0], max(0, gt_abs_box[1] - 16)), "Ground Truth", fill=(0, 255, 0))

    pred_abs_box = _norm_box_to_pixels(pred_box_norm, crop_width, crop_height)
    draw.rectangle(pred_abs_box, outline=(255, 0, 0), width=4)
    # Clamp like the ground-truth label so boxes touching the top edge keep a visible label.
    draw.text((pred_abs_box[0], max(0, pred_abs_box[1] - 16)), f"Prediction ({score:.4f})", fill=(255, 0, 0))

    fig, ax = plt.subplots(figsize=(14, 8))
    ax.imshow(vis_image)
    ax.set_title(f"Prediction vs Ground Truth (Image: {image_id})")
    ax.axis("off")
    fig.tight_layout()
    # Honour the configured results folder instead of a hard-coded './results'.
    fig.savefig(os.path.join(args.result_folder, f"{os.path.splitext(image_id)[0]}_result.png"), dpi=300)
    plt.close(fig)

# Evaluation: average precision (AP) at IoU >= IOU_THRESHOLD over score-ranked predictions.
# ground_truths holds one box per image; a prediction is a true positive when it reaches the
# IoU threshold against its image's box and that box has not already been matched.
predictions = sorted(predictions, key=lambda x: x["score"], reverse=True)

TPs = np.zeros(len(predictions))
FPs = np.zeros(len(predictions))
total_gt = len(ground_truths)

detected = {image_id: False for image_id in ground_truths}

for idx, pred in enumerate(predictions):
    image_id = pred["image_id"]
    iou = compute_iou(pred["box"], ground_truths.get(image_id))
    if iou >= IOU_THRESHOLD and not detected[image_id]:
        TPs[idx] = 1
        detected[image_id] = True
    else:
        # Insufficient overlap, or a duplicate hit on an already-matched box.
        FPs[idx] = 1

cum_TP = np.cumsum(TPs)
cum_FP = np.cumsum(FPs)
# Every prediction is counted as TP or FP, so the denominator is >= 1 for each rank;
# max(..., 1) only guards the degenerate empty cases.
precisions = cum_TP / np.maximum(cum_TP + cum_FP, 1)
recalls = cum_TP / max(total_gt, 1)

# Area under the step-wise precision/recall curve. Prepending recall=0 makes the first
# ranked prediction contribute; starting the sum at index 1 dropped it (a single perfect
# detection scored AP=0).
AP = float(np.sum(np.diff(recalls, prepend=0.0) * precisions))
final_recall = float(recalls[-1]) if len(recalls) else 0.0

print("===== Evaluation Results =====")
print(f"Total Ground Truth Boxes: {total_gt}")
print(f"Average Precision (AP) at IoU >= {IOU_THRESHOLD}: {AP:.4f}")
print(f"Final Recall: {final_recall:.4f}")
print("Test inference finished.")

# End

