In [None]:
import os

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet50
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from transformers import DetrImageProcessor

# Snapshot directory of facebook/detr-resnet-50 inside the shared Hugging Face cache.
local_detr_path = "shared/mldatasets/huggingface/hub/models--facebook--detr-resnet-50/snapshots/1d5f47bd3bdd2c4bbfa585418ffe6da5028b4c0b"

# local_files_only=True: load the image processor strictly from the snapshot above.
processor = DetrImageProcessor.from_pretrained(local_detr_path, local_files_only=True)

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")


class FeatureExtractor(nn.Module):
    """ResNet-50 trunk (stem through ``layer3``) followed by a 1x1 convolution.

    The backbone is created with the ``IMAGENET1K_V1`` weights; the trailing
    1x1 convolution maps the 1024 channels coming out of ``layer3`` to 256.
    """

    def __init__(self):
        super().__init__()
        try:
            backbone = resnet50(weights='IMAGENET1K_V1')
        except TypeError:
            # Fallback for torchvision builds whose resnet50 rejects ``weights``.
            backbone = resnet50(pretrained=True)

        trunk_layers = [
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
            backbone.layer1,
            backbone.layer2,
            backbone.layer3,
        ]
        channel_projection = nn.Conv2d(1024, 256, kernel_size=1)
        self.feature_extractor = nn.Sequential(*trunk_layers, channel_projection)

    def forward(self, x):
        """Return the projected ``layer3`` feature map for the image batch ``x``."""
        return self.feature_extractor(x)

class MovingObjectDataset(Dataset):
    """Dataset of consecutive frame pairs turned into DETR training samples.

    Each sample is the absolute per-pixel difference of two resized frames,
    run through the module-level ``processor``, together with detection
    targets built from the ``new_bbox`` entries of the paired annotation file.

    Annotation files are paired with frame pairs purely by position: the k-th
    frame pair (sub-directories and image files in sorted order) receives the
    k-th ``.txt`` file of ``annotation_dir`` (sorted).
    """

    # Layouts accepted for the four bbox numbers of an annotation line.
    _BBOX_FORMATS = ("xywh", "xyxy", "cxcywh")

    def __init__(self, frame_pairs_dir, annotation_dir, img_size=(800, 1333), bbox_format="xywh"):
        """Index every (frame1, frame2, annotation file) triple.

        Args:
            frame_pairs_dir: Directory whose sub-directories contain the frames;
                images 0/1, 2/3, ... of each sub-directory form one pair.
            annotation_dir: Directory holding the ``.txt`` annotation files.
            img_size: ``(height, width)`` both frames are resized to before
                the difference image is computed.
            bbox_format: Layout of the four bbox values in the annotation
                files. ``"xywh"`` = top-left corner + width/height (default;
                NOTE(review): assumed layout of the annotation files, confirm
                against whatever wrote them), ``"xyxy"`` = two corners,
                ``"cxcywh"`` = centre + width/height (values used as-is).

        Raises:
            ValueError: If ``bbox_format`` is not one of the supported layouts.
        """
        if bbox_format not in self._BBOX_FORMATS:
            raise ValueError(
                f"bbox_format must be one of {self._BBOX_FORMATS}, got {bbox_format!r}"
            )

        self.annotation_dir = annotation_dir
        self.img_size = img_size
        self.bbox_format = bbox_format
        self.device = device
        self.resize = Resize(img_size, interpolation=Image.BILINEAR)

        # NOTE(review): neither `feature_extractor` nor `transform` is used by
        # any method of this class; they are kept only so code that reads these
        # attributes keeps working.
        self.feature_extractor = FeatureExtractor().to(device)
        self.feature_extractor.eval()
        self.transform = Compose([
            Resize(self.img_size),
            ToTensor(),
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        self.all_annotation_files = sorted(
            os.path.join(self.annotation_dir, f)
            for f in os.listdir(self.annotation_dir)
            if f.endswith('.txt')
        )

        self.samples = self._pair_frames_with_annotations(
            frame_pairs_dir, self.all_annotation_files
        )

    @staticmethod
    def _pair_frames_with_annotations(frame_pairs_dir, annotation_files):
        """Zip frame pairs (in sorted order) with ``annotation_files`` by position.

        Pairing stops as soon as the annotation files run out, so having more
        frame pairs than annotation files truncates the dataset instead of
        raising ``IndexError``.
        """
        samples = []
        remaining = iter(annotation_files)

        for subdir_name in sorted(os.listdir(frame_pairs_dir)):
            subdir_path = os.path.join(frame_pairs_dir, subdir_name)
            if not os.path.isdir(subdir_path):
                continue

            frames = sorted(
                os.path.join(subdir_path, f)
                for f in os.listdir(subdir_path)
                if f.lower().endswith(('.jpg', '.jpeg', '.png'))
            )

            # Step of 2: frames (0,1), (2,3), ... — a trailing unpaired frame is dropped.
            for i in range(0, len(frames) - 1, 2):
                ann_file = next(remaining, None)
                if ann_file is None:
                    return samples
                samples.append((frames[i], frames[i + 1], ann_file))

        return samples

    def __len__(self):
        """Number of indexed (frame1, frame2, annotation) triples."""
        return len(self.samples)

    def parse_annotations(self, annotation_path):
        """Read one annotation file into a list of object dicts.

        Non-empty lines are consumed two at a time ("old" line, "new" line).
        From each line token 0 is the id and tokens 1-4 the bbox; the class is
        token 5 of the "old" line. A trailing unpaired line is ignored.

        Parsing is best-effort: on the first unreadable or malformed record an
        error is printed and the objects parsed so far are returned.

        Returns:
            list[dict] with keys ``'id'``, ``'old_bbox'``, ``'new_bbox'``, ``'class'``.
        """
        objects = []
        try:
            with open(annotation_path) as f:
                lines = [line.strip() for line in f if line.strip()]
            for i in range(0, len(lines) - 1, 2):
                old = lines[i].split()
                new = lines[i + 1].split()
                old_bbox = list(map(float, old[1:5]))
                new_bbox = list(map(float, new[1:5]))
                # Slicing never fails on short lines, so check explicitly;
                # otherwise the error would only surface later in __getitem__.
                if len(old_bbox) != 4 or len(new_bbox) != 4:
                    raise ValueError(f"expected 4 bbox values on lines {i + 1}-{i + 2}")
                objects.append({
                    'id': int(old[0]),
                    'old_bbox': old_bbox,
                    'new_bbox': new_bbox,
                    'class': int(old[5])
                })
        except (OSError, ValueError, IndexError) as e:
            print(f"Error parsing {annotation_path}: {e}")
        return objects

    @staticmethod
    def _to_normalized_cxcywh(bbox, img_w, img_h, bbox_format="xywh"):
        """Convert a pixel bbox to the normalised ``[cx, cy, w, h]`` list DETR targets use.

        ``img_w`` / ``img_h`` are the dimensions of the image the pixel
        coordinates refer to.
        """
        a, b, c, d = bbox
        if bbox_format == "xywh":
            cx, cy, w, h = a + c / 2.0, b + d / 2.0, c, d
        elif bbox_format == "xyxy":
            cx, cy, w, h = (a + c) / 2.0, (b + d) / 2.0, c - a, d - b
        elif bbox_format == "cxcywh":
            cx, cy, w, h = a, b, c, d
        else:
            raise ValueError(f"unknown bbox_format: {bbox_format!r}")
        return [cx / img_w, cy / img_h, w / img_w, h / img_h]

    def __getitem__(self, idx):
        """Build the model inputs and targets for sample ``idx``.

        Returns:
            dict with ``"pixel_values"`` and ``"pixel_mask"`` (processor output
            for the difference image, batch dimension removed) and ``"labels"``
            (dict with ``class_labels``, ``boxes``, ``orig_size``, ``size``).
        """
        frame1_path, frame2_path, ann_path = self.samples[idx]

        with Image.open(frame1_path) as raw1, Image.open(frame2_path) as raw2:
            img1 = raw1.convert("RGB")
            img2 = raw2.convert("RGB")

        # Normalise boxes against the frame as stored on disk: normalised
        # coordinates must not change with the `img_size` resize target.
        # NOTE(review): assumes annotations are in original-frame pixels — confirm.
        orig_w, orig_h = img1.size

        img1 = self.resize(img1)
        img2 = self.resize(img2)

        # int16 keeps the subtraction from wrapping around before abs().
        diff = Image.fromarray(
            np.abs(np.array(img1, dtype=np.int16) - np.array(img2, dtype=np.int16))
            .astype(np.uint8)
        )

        encoding = processor(images=diff, return_tensors="pt", do_pad = True)
        pixel_values = encoding.pixel_values.squeeze(0)
        pixel_mask   = encoding.pixel_mask.squeeze(0)

        annotations = self.parse_annotations(ann_path)

        boxes = []
        labels = []
        for obj in annotations:
            boxes.append(
                self._to_normalized_cxcywh(obj['new_bbox'], orig_w, orig_h, self.bbox_format)
            )
            labels.append(obj['class'])

        diff_w, diff_h = diff.size
        targets = {
            "class_labels": torch.tensor(labels, dtype=torch.long),
            # reshape keeps the (N, 4) layout even when there are no objects.
            "boxes": torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            "orig_size": torch.tensor([diff_h, diff_w], dtype=torch.int64),
            "size": torch.tensor([diff_h, diff_w], dtype=torch.int64),
        }
        return {
            "pixel_values": pixel_values,
            "pixel_mask":   pixel_mask,
            "labels":       targets
        }

In [None]:
def collate_fn(batch):
    """Merge per-sample dicts into one batch dict.

    ``pixel_values`` and ``pixel_mask`` are stacked along a new leading
    dimension; the per-sample ``labels`` dicts are kept as a plain list
    because the number of boxes differs from sample to sample.
    """
    images, masks, targets = [], [], []
    for sample in batch:
        images.append(sample["pixel_values"])
        masks.append(sample["pixel_mask"])
        targets.append(sample["labels"])

    collated = {}
    collated["pixel_values"] = torch.stack(images)
    collated["pixel_mask"] = torch.stack(masks)
    collated["labels"] = targets
    return collated

In [None]:
from transformers import DetrConfig, DetrForObjectDetection, DetrImageProcessor

local_detr_path = "shared/mldatasets/huggingface/hub/models--facebook--detr-resnet-50/snapshots/1d5f47bd3bdd2c4bbfa585418ffe6da5028b4c0b"

# Start from the checkpoint's own configuration, then adapt the label set.
config = DetrConfig.from_pretrained(local_detr_path)
config.use_pretrained_backbone = False

# Six classes; the list position is the class id.
config.num_labels = 6
config.id2label = dict(enumerate([
    'Unknown', 'Person', 'Car', 'Other Vehicle', 'Other Object', 'Bike',
]))
config.label2id = {name: class_id for class_id, name in config.id2label.items()}

# ignore_mismatched_sizes=True: checkpoint tensors whose shape no longer matches
# the 6-label config are skipped at load time instead of raising.
model = DetrForObjectDetection.from_pretrained(
    local_detr_path,
    config=config,
    local_files_only=True,
    ignore_mismatched_sizes=True,
)
model.to(device)

In [None]:
from torch.optim.lr_scheduler import ReduceLROnPlateau

def train_model(train_loader, val_loader, model, device, epochs=1):
    """Fine-tune ``model`` and report the train/validation loss per epoch.

    Uses AdamW (lr=1e-5) with a ReduceLROnPlateau scheduler driven by the
    mean validation loss (factor 0.1, patience 5).

    Args:
        train_loader: Iterable of batches with ``"pixel_values"``,
            ``"pixel_mask"`` and ``"labels"`` (list of dicts of tensors).
        val_loader: Same batch format; may yield no batches, in which case the
            validation loss is reported as NaN and the scheduler is not stepped.
        model: Module whose forward pass accepts ``pixel_values``,
            ``pixel_mask`` and ``labels`` and returns an object with ``.loss``.
        device: Device the model and every batch are moved to.
        epochs: Number of passes over ``train_loader``.

    Returns:
        list[dict]: One ``{"epoch", "train_loss", "val_loss"}`` entry per epoch.
    """
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    # `verbose=True` is deprecated since PyTorch 2.2, so it is not passed;
    # learning-rate reductions are reported explicitly below instead.
    scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.1, patience=5)

    history = []
    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        train_batches = 0
        for batch in train_loader:
            optimizer.zero_grad()
            loss = _detr_batch_loss(model, batch, device)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_batches += 1

        # Count batches instead of dividing by len(loader): an empty loader
        # must not end the run with ZeroDivisionError.
        avg_train_loss = train_loss / train_batches if train_batches else float("nan")

        model.eval()
        val_loss = 0.0
        val_batches = 0
        with torch.no_grad():
            for batch in val_loader:
                val_loss += _detr_batch_loss(model, batch, device).item()
                val_batches += 1

        if val_batches:
            avg_val_loss = val_loss / val_batches
            lrs_before = [group["lr"] for group in optimizer.param_groups]
            scheduler.step(avg_val_loss)
            for group_idx, (old_lr, group) in enumerate(zip(lrs_before, optimizer.param_groups)):
                if group["lr"] != old_lr:
                    print(
                        f"Epoch {epoch+1}: reducing learning rate of group "
                        f"{group_idx} to {group['lr']:.4e}."
                    )
        else:
            # No validation signal: leave the learning-rate schedule untouched.
            avg_val_loss = float("nan")

        print(
            f"Epoch {epoch+1}/{epochs} — "
            f"Train Loss: {avg_train_loss:.4f}, "
            f"Val Loss: {avg_val_loss:.4f}"
        )
        history.append({
            "epoch": epoch + 1,
            "train_loss": avg_train_loss,
            "val_loss": avg_val_loss,
        })

    return history


def _detr_batch_loss(model, batch, device):
    """Move one collated batch to ``device`` and return the model's loss tensor."""
    pixel_values = batch["pixel_values"].to(device)
    pixel_mask = batch["pixel_mask"].to(device)
    targets = [
        {k: v.to(device) for k, v in t.items()}
        for t in batch["labels"]
    ]
    outputs = model(
        pixel_values=pixel_values,
        pixel_mask=pixel_mask,
        labels=targets,
    )
    return outputs.loss

In [None]:
import numpy as np  
def evaluate_and_visualize(model, dataset, device, idx=0, class_names=None, threshold=0.5):
    """Run ``model`` on sample ``idx`` and plot the confident detections.

    Predictions are drawn on the second frame of the sample, loaded with
    ``load_raw_image``. Boxes handed to ``plot_detections`` are
    ``[x, y, w, h]`` rows in raw-image pixels, where ``(x, y)`` is obtained as
    centre minus half the extent of the model's normalised ``pred_boxes``.

    Args:
        model: Detection model returning ``logits`` and ``pred_boxes``.
        dataset: Dataset providing ``pixel_values``/``pixel_mask`` items and a
            ``samples`` list whose entries hold the second frame path at index 1.
        device: Device used for inference.
        idx: Index of the sample to visualise.
        class_names: Forwarded unchanged to ``plot_detections``.
        threshold: Minimum class probability for a detection to be kept.
    """
    model.eval()
    sample = dataset[idx]
    batch_pixels = sample["pixel_values"].unsqueeze(0).to(device, non_blocking=True)
    batch_mask = sample["pixel_mask"].unsqueeze(0).to(device, non_blocking=True)

    # Spatial size of the tensor the model actually receives.
    _, _, H_processed, W_processed = batch_pixels.shape

    with torch.no_grad():
        prediction = model(pixel_values=batch_pixels, pixel_mask=batch_mask)

    # Class probabilities per query; the final column is dropped (the
    # "no object" slot for DETR heads).
    class_probs = prediction.logits.softmax(-1)[0, :, :-1]
    best_scores, best_labels = class_probs.max(-1)
    normalized_boxes = prediction.pred_boxes[0]

    second_frame_path = dataset.samples[idx][1]
    raw_img = load_raw_image(second_frame_path)
    H_orig, W_orig = raw_img.shape[:2]

    # Scale factors from processed-tensor pixels to raw-image pixels.
    sx = W_orig / W_processed
    sy = H_orig / H_processed

    cx, cy, w, h = normalized_boxes.unbind(-1)
    left_px = (cx - 0.5 * w) * W_processed
    top_px = (cy - 0.5 * h) * H_processed
    width_px = w * W_processed
    height_px = h * H_processed

    scaled_columns = [left_px * sx, top_px * sy, width_px * sx, height_px * sy]
    abs_boxes = torch.stack(scaled_columns, dim=-1).cpu().numpy()

    scores_np = best_scores.cpu().numpy()
    labels_np = best_labels.cpu().numpy()

    # Keep only queries whose best class probability clears the threshold.
    confident = scores_np > threshold

    plot_detections(
        raw_img,
        abs_boxes[confident],
        labels_np[confident],
        scores_np[confident],
        class_names=class_names,
        threshold=threshold
    )

In [None]:
if __name__ == "__main__":
    frame_pairs_dir = "shared/data/cv_data_hw2/data"
    annotation_dir  = "shared/hw3/matched_annotations"
    batch_size      = 2
    split_seed      = 42  # fixes the train/validation membership across kernel restarts
    device          = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset = MovingObjectDataset(frame_pairs_dir, annotation_dir)
    print(f"Found {len(dataset)} valid image pairs")

    # 80/20 train/validation split.
    n = len(dataset)
    split = int(0.8 * n)
    if split == 0:
        # Fewer than two samples would otherwise fail inside the shuffling
        # sampler with an unrelated-looking error.
        raise ValueError(
            f"Need at least 2 image pairs to build a train/validation split, found {n}"
        )

    # A seeded generator makes the split reproducible, so validation losses
    # from different runs refer to the same held-out pairs.
    split_rng = torch.Generator().manual_seed(split_seed)
    train_set, val_set = torch.utils.data.random_split(
        dataset, [split, n - split], generator=split_rng
    )

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True,  collate_fn=collate_fn)
    val_loader   = DataLoader(val_set,   batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    train_model(train_loader, val_loader, model, device, epochs=25)

    # Persist the fine-tuned weights together with the processor config so the
    # checkpoint directory can be reloaded on its own.
    output_dir = "./detr_moving_checkpoint"
    model.save_pretrained(output_dir)
    processor.save_pretrained(output_dir)

    # NOTE(review): index 0 of the full dataset may belong to the training split.
    evaluate_and_visualize(model, dataset, device, idx=0)

In [None]:
# Display names indexed by class id; id 0 has no display name here.
viz_class_names = [None, "Person", "Car", "Other Vehicle", "Other Object", "Bike"]

# Show at most the first 10 samples; the min() guard avoids an IndexError
# when the dataset holds fewer than 10 pairs.
for idx in range(min(10, len(dataset))):
    print(f"\n--- Visualizing sample {idx} ---")
    evaluate_and_visualize(
        model,
        dataset,
        device,
        idx=idx,
        class_names=viz_class_names,
        threshold=0.3,
    )