In [1]:
!pip install ultralytics
!pip install albumentations

import os
import json
import cv2
import copy
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from pathlib import Path
import yaml
import gc
from collections import defaultdict

# Report the PyTorch build and every visible CUDA device so that a run's
# hardware context is recorded alongside its outputs.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"Number of GPUs: {torch.cuda.device_count()}")
for i in range(torch.cuda.device_count()):
    print(f"GPU {i}: {torch.cuda.get_device_name(i)}")

def set_seed(seed=42):
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random``, NumPy's legacy global generator, and PyTorch's
    CPU and CUDA generators with the same value, and asks cuDNN to pick
    deterministic kernels.

    Args:
        seed: Integer seed shared by all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    torch.backends.cudnn.deterministic = True


set_seed(42)

PyTorch version: 2.6.0+cu124
CUDA available: True
Number of GPUs: 2
GPU 0: Tesla T4
GPU 1: Tesla T4


In [None]:
def load_and_preprocess_image(img_path, img_size=640):
    """Read an image from disk and turn it into a normalised CHW tensor.

    Args:
        img_path: Path of the image file (converted with ``str`` before use).
        img_size: Side length of the square the image is resized to.

    Returns:
        A float tensor of shape ``[3, img_size, img_size]`` with RGB values
        scaled to ``[0, 1]``, or ``None`` when the file cannot be decoded, is
        not a 3-channel image, or any step raises.
    """
    try:
        bgr = cv2.imread(str(img_path))

        # Reject unreadable files and anything that is not H x W x 3.
        is_three_channel = bgr is not None and bgr.ndim == 3 and bgr.shape[2] == 3
        if not is_three_channel:
            return None

        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        resized = cv2.resize(
            rgb, (img_size, img_size), interpolation=cv2.INTER_LINEAR
        )

        # HWC uint8 -> CHW float in [0, 1].
        return torch.from_numpy(resized).float().div(255.0).permute(2, 0, 1)

    except Exception as e:
        print(f"Error loading image {img_path}: {e}")
        return None


def extract_frame(video_path, frame_num, img_size=640):
    """Decode one frame of a video and turn it into a normalised CHW tensor.

    Args:
        video_path: Path of the video file. Converted with ``str`` so that
            ``pathlib.Path`` objects are accepted, as in
            ``load_and_preprocess_image``.
        frame_num: Zero-based index of the frame to seek to.
        img_size: Side length of the square the frame is resized to.

    Returns:
        ``(frame_tensor, (width, height))`` where ``frame_tensor`` is a float
        tensor of shape ``[3, img_size, img_size]`` with RGB values in
        ``[0, 1]`` and ``(width, height)`` is the frame size before resizing.
        ``(None, None)`` is returned when the video cannot be opened, the
        frame cannot be read, the frame is not 3-channel, or any step raises.
    """
    cap = None
    try:
        cap = cv2.VideoCapture(str(video_path))
        if not cap.isOpened():
            return None, None

        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
        ret, frame = cap.read()

        if not ret or frame is None:
            return None, None

        if len(frame.shape) != 3 or frame.shape[2] != 3:
            return None, None

        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        h, w = frame.shape[:2]

        frame = cv2.resize(frame, (img_size, img_size), interpolation=cv2.INTER_LINEAR)

        # HWC uint8 -> CHW float in [0, 1].
        frame_tensor = torch.from_numpy(frame).float().div(255.0).permute(2, 0, 1)

        return frame_tensor, (w, h)

    except Exception as e:
        print(f"Error extracting frame: {e}")
        return None, None

    finally:
        # Always give the decoder handle back, including on the error paths
        # above, so repeated calls do not leak open captures.
        if cap is not None:
            cap.release()


def convert_bbox_to_yolo(bbox, orig_w, orig_h, img_size=640):
    """Convert a pixel-space corner box to normalised YOLO ``xywh``.

    The corners are normalised by the original frame size, put in
    ascending order, and clipped to the image *before* the centre and size are
    derived. Clipping the corners (instead of clipping centre and size
    independently) keeps the resulting box consistent when the annotation
    extends past the frame border.

    Args:
        bbox: ``(x1, y1, x2, y2)`` in pixels of the original frame.
        orig_w: Width of the original frame in pixels (must be > 0).
        orig_h: Height of the original frame in pixels (must be > 0).
        img_size: Kept for backward compatibility. The resize factor cancels
            out once the box is normalised, so it does not affect the result.

    Returns:
        A float32 tensor ``[x_center, y_center, width, height]`` with every
        component in ``[0, 1]``.

    Raises:
        ValueError: If ``orig_w`` or ``orig_h`` is not positive.
    """
    if orig_w <= 0 or orig_h <= 0:
        raise ValueError(
            f"Original image size must be positive, got {orig_w}x{orig_h}"
        )

    x1, y1, x2, y2 = (float(v) for v in bbox)

    def _unit(value):
        # Clamp a normalised coordinate to the image.
        return min(max(value, 0.0), 1.0)

    # Normalise, then tolerate corners given in the wrong order.
    left, right = sorted((x1 / orig_w, x2 / orig_w))
    top, bottom = sorted((y1 / orig_h, y2 / orig_h))

    left, right = _unit(left), _unit(right)
    top, bottom = _unit(top), _unit(bottom)

    return torch.tensor(
        [
            (left + right) / 2.0,
            (top + bottom) / 2.0,
            right - left,
            bottom - top,
        ],
        dtype=torch.float32,
    )

In [None]:
"""
Prototypical Networks: leaner metric-learning approach for few-shot detection
that trains embeddings and classifies by distance to a prototype; simpler than
MAML for this use case.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import numpy as np
from tqdm.auto import tqdm
import gc

class PrototypicalYOLO(nn.Module):
    """
    Few-shot detector that reuses a frozen YOLO backbone and learns a
    prototypical embedding/detection head.

    The backbone output is projected to a 128-channel embedding map. A class
    prototype is the mean embedding inside the support boxes; bounding-box and
    objectness heads operate on the same embedding map.
    """

    def __init__(self, base_model_path='yolo11n.pt', device='cuda'):
        """Build the frozen backbone and the trainable heads.

        Args:
            base_model_path: Weights file handed to ``ultralytics.YOLO``.
            device: Device every sub-module is moved to.
        """
        super().__init__()
        self.device = device

        # Feature extractor
        from ultralytics import YOLO
        yolo = YOLO(base_model_path)
        self.backbone = yolo.model.to(device)

        # Freeze backbone
        for param in self.backbone.parameters():
            param.requires_grad = False

        # Embedding head (project features to metric space)
        self.embedding_head = nn.Sequential(
            nn.Conv2d(144, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 128, 1),  # Embedding dimension = 128
        ).to(device)

        # Detection head (same as the previous version)
        self.bbox_head = nn.Sequential(
            nn.Conv2d(128, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 4, 1)  # 4 coords
        ).to(device)

        self.obj_head = nn.Sequential(
            nn.Conv2d(128, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 1, 1)  # Objectness
        ).to(device)

    def extract_features(self, x):
        """Run the frozen backbone without gradients and return its first output.

        NOTE(review): the backbone is switched to train mode on every call,
        presumably so the Ultralytics model returns raw feature maps instead
        of decoded predictions - confirm. A side effect is that any BatchNorm
        layers in the backbone keep updating their running statistics even
        though their parameters are frozen, and the backbone stays in train
        mode after ``model.eval()``. Changing this would alter the features
        that existing checkpoints were trained on, so it is left as is.
        """
        self.backbone.train()

        with torch.no_grad():
            features = self.backbone(x)
            if isinstance(features, (list, tuple)):
                features = features[0]
        return features

    def forward(self, x, return_embeddings=False):
        """
        Forward pass

        Args:
            x: [B, 3, H, W]
            return_embeddings: If True, return the embedding map instead of
                the detections.

        Returns:
            embeddings: [B, 128, H', W'] when ``return_embeddings`` is True,
            otherwise ``(bbox, obj)`` with shapes [B, 4, H', W'] and
            [B, 1, H', W']. The spatial stride H/H' depends on which backbone
            output comes first (the original notes say 32 - TODO confirm).
        """
        # Extract features (the embedding head expects 144 channels)
        features = self.extract_features(x)

        # Get embeddings
        embeddings = self.embedding_head(features)  # [B, 128, H', W']

        if return_embeddings:
            return embeddings

        # Detection
        bbox = self.bbox_head(embeddings)
        obj = self.obj_head(embeddings)

        return bbox, obj

    def compute_prototype(self, support_imgs, support_bboxes):
        """
        Build a class prototype from the support set.

        Each support box is mapped onto the embedding grid and the embeddings
        inside it are averaged. The region is always at least one cell wide
        and high: without that guarantee a box narrower than two cells (or one
        touching the far border) produced an empty slice and a NaN prototype.

        Args:
            support_imgs: [N, 3, H, W]
            support_bboxes: [N, 4] normalized xywh

        Returns:
            prototype: [128] embedding vector for the target object
        """
        embeddings = self.forward(support_imgs, return_embeddings=True)
        # [N, 128, H', W']

        # RoI pooling: take the embeddings inside the bbox region
        N, C, H, W = embeddings.shape

        roi_embeddings = []
        for i in range(N):
            x_c, y_c, w, h = support_bboxes[i]

            # Convert to grid coords
            x_c = int(x_c * W)
            y_c = int(y_c * H)
            w = max(1, int(w * W))
            h = max(1, int(h * H))

            # Keep the start inside the grid and force the end past the start
            # so the slice below is never empty.
            x1 = min(max(0, x_c - w // 2), W - 1)
            y1 = min(max(0, y_c - h // 2), H - 1)
            x2 = max(x1 + 1, min(W, x_c + w // 2))
            y2 = max(y1 + 1, min(H, y_c + h // 2))

            # Average pooling inside the RoI
            roi = embeddings[i, :, y1:y2, x1:x2]  # [128, h, w]
            roi_feat = roi.mean(dim=[1, 2])  # [128]
            roi_embeddings.append(roi_feat)

        # Prototype is the mean of support embeddings
        prototype = torch.stack(roi_embeddings).mean(dim=0)  # [128]

        return prototype

    @staticmethod
    def _xywh_to_xyxy(boxes):
        """Convert ``[..., 4]`` centre/size boxes to corner boxes."""
        x, y, w, h = boxes[..., 0], boxes[..., 1], boxes[..., 2], boxes[..., 3]
        return torch.stack(
            [x - w / 2, y - h / 2, x + w / 2, y + h / 2], dim=-1
        )

    @classmethod
    def _anchor_iou(cls, pred_boxes, target_boxes):
        """IoU between every predicted box [B, N, 4] and its image's target [B, 4]."""
        pred_xyxy = cls._xywh_to_xyxy(pred_boxes)
        target_xyxy = cls._xywh_to_xyxy(target_boxes.unsqueeze(1))

        x1 = torch.max(pred_xyxy[..., 0], target_xyxy[..., 0])
        y1 = torch.max(pred_xyxy[..., 1], target_xyxy[..., 1])
        x2 = torch.min(pred_xyxy[..., 2], target_xyxy[..., 2])
        y2 = torch.min(pred_xyxy[..., 3], target_xyxy[..., 3])

        inter = torch.clamp(x2 - x1, min=0) * torch.clamp(y2 - y1, min=0)
        pred_area = (pred_xyxy[..., 2] - pred_xyxy[..., 0]) * (pred_xyxy[..., 3] - pred_xyxy[..., 1])
        target_area = (target_xyxy[..., 2] - target_xyxy[..., 0]) * (target_xyxy[..., 3] - target_xyxy[..., 1])
        union = pred_area + target_area - inter + 1e-7  # epsilon avoids 0/0

        return inter / union

    def compute_loss_with_prototype(self, query_imgs, query_targets, prototype):
        """
        Loss that ties detections to a class prototype.

        The network is run once: the embedding map feeds both the detection
        heads and the prototype distance, instead of repeating the backbone
        and embedding pass for each.

        Args:
            query_imgs: [B, 3, H, W]
            query_targets: [B, 4]
            prototype: [128] target embedding

        Returns:
            loss: scalar (2 * contrastive + 5 * bbox + 1 * objectness)
        """
        # Single pass: embeddings first, then both heads on top of them.
        embeddings = self.forward(query_imgs, return_embeddings=True)
        # [B, 128, H', W']
        bbox_pred = self.bbox_head(embeddings)
        obj_pred = self.obj_head(embeddings)

        B, C, H, W = embeddings.shape

        # Flatten the grid so that each cell is one anchor.
        bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(B, -1, 4)
        obj_pred = obj_pred.permute(0, 2, 3, 1).reshape(B, -1)
        embeddings = embeddings.permute(0, 2, 3, 1).reshape(B, -1, C)

        # Loss 1: embedding similarity to the prototype
        prototype_expanded = prototype.unsqueeze(0).unsqueeze(0)  # [1, 1, 128]
        distances = torch.norm(embeddings - prototype_expanded, dim=2)  # [B, N]

        # Pick the anchor with highest IoU as the positive
        iou = self._anchor_iou(bbox_pred, query_targets)  # [B, N]
        best_anchor = iou.argmax(dim=1)  # [B]
        batch_index = torch.arange(B, device=distances.device)

        # Contrastive loss: pull the best anchor towards the prototype and
        # push every other anchor away from it.
        positive_distances = distances[batch_index, best_anchor]

        # Negatives: all other anchors
        negative_mask = torch.ones_like(distances, dtype=torch.bool)
        negative_mask[batch_index, best_anchor] = False
        negative_distances = distances[negative_mask].view(B, -1)

        # Triplet-style margin loss
        margin = 0.5
        contrastive_loss = F.relu(
            positive_distances.unsqueeze(1) - negative_distances + margin
        ).mean()

        # Loss 2: bbox regression
        best_bbox = bbox_pred[batch_index, best_anchor]
        bbox_loss = F.smooth_l1_loss(best_bbox, query_targets)

        # Loss 3: objectness (only the best anchor is a positive)
        obj_targets = torch.zeros_like(obj_pred)
        obj_targets[batch_index, best_anchor] = 1.0
        obj_loss = F.binary_cross_entropy_with_logits(obj_pred, obj_targets)

        # Total loss
        total_loss = (
            contrastive_loss * 2.0 +
            bbox_loss * 5.0 +
            obj_loss * 1.0
        )

        return total_loss


# ==========================================
# TRAINING LOOP - FULL VERSION
# ==========================================

class PrototypicalTrainer:
    """
    End-to-end trainer for the prototypical detector, with multi-GPU,
    AMP, checkpointing, and tracking utilities.

    Only the embedding, bbox and objectness heads are optimised; the model's
    ``compute_prototype`` / ``compute_loss_with_prototype`` methods define the
    training objective.
    """

    def __init__(self, model, cfg):
        """Wire up the model, optimizer, scheduler and AMP scaler.

        Args:
            model: Module exposing ``embedding_head``, ``bbox_head``,
                ``obj_head``, ``compute_prototype`` and
                ``compute_loss_with_prototype``.
            cfg: Settings object; reads ``DEVICE``, ``USE_MULTI_GPU``,
                ``META_LR``, ``META_EPOCHS``, ``MIXED_PRECISION`` here and
                ``META_BATCH_SIZE``, ``N_SUPPORT``, ``N_QUERY``, ``IMG_SIZE``
                in other methods.
        """
        self.model = model
        self.cfg = cfg
        self.device = cfg.DEVICE

        # Multi-GPU support
        # NOTE(review): every training/eval call below goes through
        # ``self.base_model`` directly, so the DataParallel wrapper's forward
        # is never invoked and the work stays on one device - confirm intent.
        if cfg.USE_MULTI_GPU and torch.cuda.device_count() > 1:
            print(f"‚úì Using {torch.cuda.device_count()} GPUs")
            self.model = nn.DataParallel(model)

        self.model = self.model.to(self.device)

        # Get model for parameter access
        self.base_model = self.model.module if hasattr(self.model, 'module') else self.model

        # Optimizer: only train embedding and detection heads
        self.optimizer = torch.optim.AdamW([
            {'params': self.base_model.embedding_head.parameters(), 'lr': cfg.META_LR},
            {'params': self.base_model.bbox_head.parameters(), 'lr': cfg.META_LR},
            {'params': self.base_model.obj_head.parameters(), 'lr': cfg.META_LR},
        ], weight_decay=0.01)

        # Learning rate scheduler
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=cfg.META_EPOCHS
        )

        # Mixed precision training
        self.scaler = torch.amp.GradScaler('cuda') if cfg.MIXED_PRECISION else None

        # Tracking metrics
        self.train_losses = []
        self.val_losses = []
        self.best_val_loss = float('inf')
        self.current_epoch = 0

    def train_step(self, task_batch):
        """
        One training step over a batch of few-shot tasks.

        The per-task losses are averaged and a single optimizer step is taken
        for the whole batch, with gradients clipped to a norm of 1.0.

        task_batch: List of {
            'support': [N_support, 3, H, W],
            'query': [N_query, 3, H, W],
            'support_targets': [N_support, 4],
            'query_targets': [N_query, 4]
        }

        Returns:
            The averaged loss as a Python float (0.0 for an empty batch).
        """
        self.model.train()
        total_loss = 0.0
        task_count = 0

        for task in task_batch:
            support = task['support'].to(self.device)
            query = task['query'].to(self.device)
            support_targets = task['support_targets'].to(self.device)
            query_targets = task['query_targets'].to(self.device)

            # Mixed precision context
            with torch.amp.autocast('cuda', enabled=self.cfg.MIXED_PRECISION):
                # 1. Compute prototype
                prototype = self.base_model.compute_prototype(support, support_targets)

                # 2. Compute loss on query set
                loss = self.base_model.compute_loss_with_prototype(
                    query, query_targets, prototype
                )

            total_loss += loss
            task_count += 1

        # Backward
        if task_count > 0:
            total_loss = total_loss / task_count

            self.optimizer.zero_grad()

            if self.scaler is not None:
                self.scaler.scale(total_loss).backward()
                # Unscale first so clipping sees the true gradient magnitudes.
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.optimizer.step()

        return total_loss.item() if task_count > 0 else 0.0

    def train_epoch(self, tasks, sample_task_batch_fn):
        """
        Run a full training epoch.

        Args:
            tasks: List of task dictionaries
            sample_task_batch_fn: Function to sample a task batch; called as
                ``fn(sampled_tasks, N_SUPPORT, N_QUERY)`` and expected to
                return a list in the format ``train_step`` accepts.

        Returns:
            Mean step loss as a plain Python float (0.0 if no step ran), so
            the value can be stored in checkpoints without NumPy scalar types.
        """
        import random

        self.model.train()
        epoch_losses = []

        n_batches = max(1, len(tasks) // self.cfg.META_BATCH_SIZE)

        pbar = tqdm(range(n_batches), desc=f"Epoch {self.current_epoch+1}")

        for batch_idx in pbar:
            # Sample tasks
            sampled_tasks = random.sample(
                tasks,
                min(self.cfg.META_BATCH_SIZE, len(tasks))
            )

            # Get task batch
            task_batch = sample_task_batch_fn(
                sampled_tasks,
                self.cfg.N_SUPPORT,
                self.cfg.N_QUERY
            )

            if len(task_batch) == 0:
                continue

            # Train step
            loss = self.train_step(task_batch)
            epoch_losses.append(loss)

            # Update progress bar
            pbar.set_postfix({
                'loss': f'{loss:.4f}',
                'lr': f'{self.scheduler.get_last_lr()[0]:.6f}'
            })

            # Memory cleanup
            if batch_idx % 10 == 0:
                torch.cuda.empty_cache()
                gc.collect()

        avg_loss = float(np.mean(epoch_losses)) if epoch_losses else 0.0
        return avg_loss

    def evaluate(self, tasks, sample_task_batch_fn, n_eval=5):
        """
        Evaluate on a handful of validation tasks.

        Args:
            tasks: Validation tasks
            sample_task_batch_fn: Sampler function
            n_eval: Number of tasks to evaluate

        Returns:
            Mean loss over the evaluated tasks as a plain Python float.
            NOTE(review): 0.0 is returned when no task could be evaluated,
            which a caller tracking the best loss would read as a perfect
            score - verify against the training loop.
        """
        import random

        self.model.eval()
        eval_losses = []

        eval_tasks = random.sample(tasks, min(n_eval, len(tasks)))

        with torch.no_grad():
            for task in tqdm(eval_tasks, desc="Evaluating"):
                task_batch = sample_task_batch_fn(
                    [task],
                    self.cfg.N_SUPPORT,
                    self.cfg.N_QUERY
                )

                if len(task_batch) == 0:
                    continue

                task_data = task_batch[0]
                support = task_data['support'].to(self.device)
                query = task_data['query'].to(self.device)
                support_targets = task_data['support_targets'].to(self.device)
                query_targets = task_data['query_targets'].to(self.device)

                with torch.amp.autocast('cuda', enabled=self.cfg.MIXED_PRECISION):
                    prototype = self.base_model.compute_prototype(support, support_targets)
                    loss = self.base_model.compute_loss_with_prototype(
                        query, query_targets, prototype
                    )

                eval_losses.append(loss.item())

        avg_loss = float(np.mean(eval_losses)) if eval_losses else 0.0
        return avg_loss

    def save_checkpoint(self, epoch, path, is_best=False):
        """
        Save a full checkpoint (model + optimizer + scheduler + history).

        Args:
            epoch: Current epoch
            path: Save path
            is_best: Whether this is the best model so far
        """
        # Get base model (without the DataParallel wrapper)
        model_to_save = self.model.module if hasattr(self.model, 'module') else self.model

        checkpoint = {
            # Model weights
            'epoch': epoch,
            'model_state_dict': model_to_save.state_dict(),

            # Optimizer & Scheduler
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),

            # Training history
            'train_losses': self.train_losses,
            'val_losses': self.val_losses,
            'best_val_loss': self.best_val_loss,

            # Config
            'config': {
                'META_LR': self.cfg.META_LR,
                'IMG_SIZE': self.cfg.IMG_SIZE,
                'N_SUPPORT': self.cfg.N_SUPPORT,
                'N_QUERY': self.cfg.N_QUERY,
            },

            # Scaler state (if mixed precision is used)
            'scaler_state_dict': self.scaler.state_dict() if self.scaler else None,
        }

        torch.save(checkpoint, path)

        suffix = " (BEST)" if is_best else ""
        print(f"‚úì Checkpoint saved: {path}{suffix}")

    def load_checkpoint(self, path):
        """
        Load a full checkpoint.

        The file is read with ``weights_only=False`` because checkpoints
        written by ``save_checkpoint`` hold more than tensors (loss history,
        config dict) and history saved by earlier runs may contain NumPy
        scalars, which the restricted loader that is the default since
        PyTorch 2.6 rejects. This unpickles arbitrary objects, so only load
        checkpoints from a trusted source.

        Args:
            path: Checkpoint path

        Returns:
            epoch: Epoch number
        """
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)

        # Load model weights
        model_to_load = self.model.module if hasattr(self.model, 'module') else self.model
        model_to_load.load_state_dict(checkpoint['model_state_dict'])

        # Load optimizer
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        # Load scheduler
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        # Load training history
        self.train_losses = checkpoint.get('train_losses', [])
        self.val_losses = checkpoint.get('val_losses', [])
        self.best_val_loss = checkpoint.get('best_val_loss', float('inf'))

        # Load scaler
        if self.scaler and checkpoint.get('scaler_state_dict'):
            self.scaler.load_state_dict(checkpoint['scaler_state_dict'])

        print(f"‚úì Checkpoint loaded: {path}")
        print(f"  - Epoch: {checkpoint['epoch']}")
        print(f"  - Best Val Loss: {self.best_val_loss:.4f}")

        return checkpoint['epoch']

    def step_scheduler(self):
        """Update learning rate"""
        self.scheduler.step()

    def get_lr(self):
        """Get current learning rate"""
        return self.scheduler.get_last_lr()[0]


# Banner printed once the trainer class is defined. The Vietnamese text lists
# the trainer's features and states that checkpoint['model_state_dict'] stores
# all weights (backbone + heads): the backbone stays frozen but is saved in full.
print("""
‚úì PrototypicalTrainer ƒê·∫¶Y ƒê·ª¶ CH·ª®C NƒÇNG:
  1. ‚úì Multi-GPU support
  2. ‚úì Mixed precision training
  3. ‚úì Gradient clipping
  4. ‚úì Learning rate scheduling
  5. ‚úì Save/Load checkpoint ƒê·∫¶Y ƒê·ª¶
  6. ‚úì Training history tracking
  7. ‚úì Memory optimization
  
BACKBONE ƒê∆Ø·ª¢C L∆ØU TRONG:
- checkpoint['model_state_dict'] ch·ª©a T·∫§T C·∫¢ weights (backbone + heads)
- Backbone v·∫´n frozen (kh√¥ng update), nh∆∞ng ƒê∆Ø·ª¢C L∆ØU HO√ÄN CH·ªàNH
""")


‚úì PrototypicalTrainer ƒê·∫¶Y ƒê·ª¶ CH·ª®C NƒÇNG:
  1. ‚úì Multi-GPU support
  2. ‚úì Mixed precision training
  3. ‚úì Gradient clipping
  4. ‚úì Learning rate scheduling
  5. ‚úì Save/Load checkpoint ƒê·∫¶Y ƒê·ª¶
  6. ‚úì Training history tracking
  7. ‚úì Memory optimization
  
BACKBONE ƒê∆Ø·ª¢C L∆ØU TRONG:
- checkpoint['model_state_dict'] ch·ª©a T·∫§T C·∫¢ weights (backbone + heads)
- Backbone v·∫´n frozen (kh√¥ng update), nh∆∞ng ƒê∆Ø·ª¢C L∆ØU HO√ÄN CH·ªàNH



In [4]:
class Config:
    """Static settings shared by the training and inference cells."""

    # --- Filesystem locations (Kaggle layout) ---
    DATASET_DIR = '/kaggle/input/raw-data/train'
    OUTPUT_DIR = '/kaggle/working/yolomaml_output'

    # --- Backbone weights and square input resolution ---
    BASE_MODEL = 'yolo11n.pt'
    IMG_SIZE = 640

    # --- Episode layout: tasks per step, support/query images per task ---
    META_BATCH_SIZE = 2
    N_SUPPORT = 3
    N_QUERY = 5

    # --- Optimisation schedule ---
    META_EPOCHS = 50
    INNER_STEPS = 3
    META_LR = 0.0005
    INNER_LR = 0.01

    # --- Fraction of the data used for meta-training ---
    META_TRAIN_RATIO = 0.8

    # --- Hardware: fall back to CPU when CUDA is unavailable ---
    USE_MULTI_GPU = True
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

    # --- Checkpoint cadence ---
    SAVE_FREQ = 5

    # --- Memory savers ---
    GRADIENT_ACCUMULATION_STEPS = 4
    MIXED_PRECISION = True
    
# Shared settings instance used by the cells below; the output directory is
# created up front (no error if it already exists).
cfg = Config()
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

In [5]:
import random 

def _first_confident_box(bbox_pred, obj_pred, distances, confidence,
                         frame_width, frame_height, top_k=10):
    """Pick one box for a frame from the anchors closest to the prototype.

    Anchors are visited from the smallest prototype distance upwards (at most
    ``top_k`` of them). The first anchor whose objectness exceeds
    ``confidence`` and whose box is non-empty after clamping to the frame wins.

    Args:
        bbox_pred: [N, 4] normalised xywh predictions for one frame.
        obj_pred: [N] objectness logits.
        distances: [N] embedding distances to the prototype.
        confidence: Objectness threshold (exclusive).
        frame_width: Width of the original frame in pixels.
        frame_height: Height of the original frame in pixels.
        top_k: Maximum number of anchors to inspect.

    Returns:
        ``(x1, y1, x2, y2)`` as ints in original-frame pixels, or ``None``.
    """
    # Never ask topk for more anchors than exist.
    k = min(top_k, distances.shape[0])
    if k == 0:
        return None

    _, nearest = torch.topk(distances, k=k, largest=False)

    for anchor_idx in nearest.tolist():
        obj_score = torch.sigmoid(obj_pred[anchor_idx]).item()
        if not obj_score > confidence:
            continue

        # Convert to pixel coords
        x_c, y_c, w, h = bbox_pred[anchor_idx].cpu().numpy()
        x_c_pixel = x_c * frame_width
        y_c_pixel = y_c * frame_height
        w_pixel = w * frame_width
        h_pixel = h * frame_height

        x1 = int(x_c_pixel - w_pixel / 2)
        y1 = int(y_c_pixel - h_pixel / 2)
        x2 = int(x_c_pixel + w_pixel / 2)
        y2 = int(y_c_pixel + h_pixel / 2)

        # Clamp
        x1 = max(0, min(x1, frame_width))
        y1 = max(0, min(y1, frame_height))
        x2 = max(0, min(x2, frame_width))
        y2 = max(0, min(y2, frame_height))

        # Valid box check
        if x2 > x1 and y2 > y1:
            return x1, y1, x2, y2

    return None


def few_shot_inference(model_path, test_video_dir, output_json='predictions.json', confidence=0.05):
    """Run prototype-based detection over every video folder and save the result.

    Each sub-folder of ``test_video_dir`` is expected to hold
    ``drone_video.mp4`` and ``object_images/img_{1,2,3}.jpg``. The three
    reference images form the support set; a prototype is computed from a
    fixed central box on each of them (this assumes the object is roughly
    centred in the reference images - verify against the dataset). For every
    frame at most one box is emitted.

    Each frame is embedded once and both detection heads run on that single
    embedding map. The video capture is released even if processing fails.

    Args:
        model_path: Checkpoint containing ``'model_state_dict'``. Loaded with
            ``weights_only=False``, so it must come from a trusted source.
        test_video_dir: Directory with one sub-folder per video.
        output_json: Path of the JSON file to write.
        confidence: Objectness threshold (exclusive) for accepting a box.

    Returns:
        List of ``{"video_id": str, "annotations": [...]}`` dictionaries, the
        same structure that is written to ``output_json``.
    """
    print("="*80)
    print("YOLOMAML Few-Shot Inference")
    print("="*80)

    # Load model
    print("\n[1/4] Loading model...")
    model = PrototypicalYOLO(cfg.BASE_MODEL, cfg.DEVICE)
    checkpoint = torch.load(model_path, map_location=cfg.DEVICE, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # Find videos
    print("\n[2/4] Finding test videos...")
    video_folders = sorted([
        f for f in os.listdir(test_video_dir)
        if os.path.isdir(os.path.join(test_video_dir, f))
    ])
    print(f"Found {len(video_folders)} test videos")

    # Process videos
    print("\n[3/4] Processing videos...")
    all_predictions = []

    # Maximum number of prototype-nearest anchors inspected per frame.
    TOP_K = 10

    for video_id in tqdm(video_folders, desc="Inference"):
        video_dir = os.path.join(test_video_dir, video_id)
        video_path = os.path.join(video_dir, 'drone_video.mp4')
        ref_dir = os.path.join(video_dir, 'object_images')

        # Check video path
        if not os.path.exists(video_path):
            print(f"\n‚ö†Ô∏è  Video not found: {video_path}")
            all_predictions.append({"video_id": video_id, "annotations": []})
            continue

        # Load support images
        support_images = []
        for img_name in ['img_1.jpg', 'img_2.jpg', 'img_3.jpg']:
            img_path = os.path.join(ref_dir, img_name)
            if os.path.exists(img_path):
                img = load_and_preprocess_image(img_path, cfg.IMG_SIZE)
                if img is not None:
                    support_images.append(img)

        # Check support images
        if len(support_images) < 3:
            print(f"\n‚ö†Ô∏è  Not enough reference images for {video_id}: {len(support_images)}/3")
            all_predictions.append({"video_id": video_id, "annotations": []})
            continue

        support_tensor = torch.stack(support_images).to(cfg.DEVICE)

        # Create support targets: the same central box for every reference image
        support_targets = torch.tensor([
            [0.5, 0.5, 0.4, 0.4]
        ]).repeat(len(support_images), 1).to(cfg.DEVICE)

        # Compute prototype
        with torch.no_grad():
            prototype = model.compute_prototype(support_tensor, support_targets)
        prototype_expanded = prototype.unsqueeze(0).unsqueeze(0)

        # Process video
        cap = cv2.VideoCapture(video_path)
        frame_idx = 0
        video_bboxes = []
        detection_count = 0
        total_frames = 0

        try:
            orig_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            orig_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

            with torch.no_grad():
                while True:
                    ret, frame = cap.read()
                    if not ret:
                        break

                    # Preprocess
                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frame_resized = cv2.resize(frame_rgb, (cfg.IMG_SIZE, cfg.IMG_SIZE))
                    frame_tensor = torch.from_numpy(frame_resized).permute(2, 0, 1).float() / 255.0
                    frame_tensor = frame_tensor.unsqueeze(0).to(cfg.DEVICE)

                    # One pass through backbone + embedding head; the heads
                    # reuse the resulting map.
                    embeddings = model(frame_tensor, return_embeddings=True)
                    bbox_pred = model.bbox_head(embeddings)
                    obj_pred = model.obj_head(embeddings)

                    B, C, H, W = embeddings.shape
                    bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(B, -1, 4)
                    obj_pred = obj_pred.permute(0, 2, 3, 1).reshape(B, -1)
                    embeddings = embeddings.permute(0, 2, 3, 1).reshape(B, -1, C)

                    # Compute distances
                    distances = torch.norm(embeddings - prototype_expanded, dim=2)

                    box = _first_confident_box(
                        bbox_pred[0], obj_pred[0], distances[0],
                        confidence, orig_width, orig_height, top_k=TOP_K,
                    )
                    if box is not None:
                        x1, y1, x2, y2 = box
                        video_bboxes.append({
                            "frame": frame_idx,
                            "x1": x1, "y1": y1,
                            "x2": x2, "y2": y2
                        })
                        detection_count += 1

                    frame_idx += 1
        finally:
            # Release the decoder even when a frame fails mid-video.
            cap.release()

        # Log summary for this video
        print(f"\nüìπ {video_id}: {detection_count}/{total_frames} detections")

        # Log random sample of 5 bboxes (if available)
        if len(video_bboxes) > 0:
            sample_size = min(5, len(video_bboxes))
            sample_bboxes = random.sample(video_bboxes, sample_size)
            print(f"   Sample detections:")
            for bbox in sample_bboxes:
                print(f"   - Frame {bbox['frame']:3d}: [{bbox['x1']:4d},{bbox['y1']:4d},{bbox['x2']:4d},{bbox['y2']:4d}]")

        # Format annotations
        annotations = []
        if len(video_bboxes) > 0:
            annotations.append({"bboxes": video_bboxes})

        all_predictions.append({
            "video_id": video_id,
            "annotations": annotations
        })

    # Save
    print(f"\n[4/4] Saving predictions to {output_json}...")
    with open(output_json, 'w') as f:
        json.dump(all_predictions, f, indent=4)

    print("\n‚úì Inference complete!")
    return all_predictions

In [6]:
def run_complete_inference():
    """Run few-shot inference on the public test set and print a summary.

    Uses fixed Kaggle paths for the checkpoint, the test samples and the
    output JSON. When the checkpoint is missing an error is printed and
    ``None`` is returned.

    Returns:
        The list of per-video predictions produced by ``few_shot_inference``,
        or ``None`` if the checkpoint does not exist.
    """
    checkpoint_path = '/kaggle/input/prototypical-model/prototypical_best.pt'
    samples_dir = '/kaggle/input/public-test/public_test/samples'
    predictions_file = 'yolomaml_predictions.json'
    rule = "="*80

    # Nothing to do without trained weights.
    if not os.path.exists(checkpoint_path):
        print(f"Error: Model not found at {checkpoint_path}")
        print("Please train the model first!")
        return

    print("\nRunning Few-Shot Inference...")
    predictions = few_shot_inference(
        model_path=checkpoint_path,
        test_video_dir=samples_dir,
        output_json=predictions_file,
        confidence=0.05
    )

    # Detections per video, in prediction order.
    per_video = []
    for entry in predictions:
        n_boxes = 0
        for ann in entry['annotations']:
            n_boxes += len(ann['bboxes'])
        per_video.append((entry['video_id'], n_boxes))

    total_detections = sum(n_boxes for _, n_boxes in per_video)

    print("\n" + rule)
    print("INFERENCE SUMMARY")
    print(rule)
    print(f"‚úì Predictions saved to: {predictions_file}")
    print(f"‚úì Total videos processed: {len(predictions)}")
    print(f"‚úì Total detections: {total_detections}")

    print(f"\nPer-video detections:")
    for video_id, det_count in per_video:
        print(f"  ‚Ä¢ {video_id}: {det_count} detections")

    print(rule)

    return predictions

In [7]:
# Run the whole pipeline; `predictions` is None when the checkpoint is missing.
predictions = run_complete_inference()


Running Few-Shot Inference...
YOLOMAML Few-Shot Inference

[1/4] Loading model...

[2/4] Finding test videos...
Found 6 test videos

[3/4] Processing videos...


Inference:   0%|          | 0/6 [00:00<?, ?it/s]


üìπ BlackBox_0: 708/5443 detections
   Sample detections:
   - Frame 5371: [ 427, 278, 475, 293]
   - Frame 860: [ 431, 282, 475, 292]
   - Frame 369: [ 453, 259, 488, 283]
   - Frame 2488: [ 494, 316, 522, 335]
   - Frame 2167: [ 498, 317, 524, 332]

üìπ BlackBox_1: 666/5776 detections
   Sample detections:
   - Frame 1989: [ 489, 304, 519, 315]
   - Frame 1071: [ 434, 258, 473, 274]
   - Frame 905: [ 458, 272, 500, 287]
   - Frame 5008: [ 492, 307, 535, 321]
   - Frame 840: [ 524, 336, 556, 354]

üìπ CardboardBox_0: 554/5285 detections
   Sample detections:
   - Frame 3905: [ 410, 237, 434, 258]
   - Frame 480: [ 483, 315, 506, 333]
   - Frame 475: [ 463, 277, 478, 292]
   - Frame 1131: [ 479, 275, 511, 295]
   - Frame 2327: [ 444, 302, 492, 318]

üìπ CardboardBox_1: 700/5942 detections
   Sample detections:
   - Frame 1616: [ 449, 278, 474, 295]
   - Frame 2884: [ 475, 293, 483, 314]
   - Frame 4957: [ 439, 250, 468, 268]
   - Frame 363: [ 470, 309, 507, 322]
   - Frame 4581: [