In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as T
from torch.nn import functional as F
from PIL import Image
import clip
import numpy as np

class ZeroShotTracker:
    """
    Zero-shot single-object tracker driven by CLIP text/image similarity.

    The target is described by a text prompt. Each frame, candidate boxes are
    cropped from the image, embedded with CLIP and compared with the prompt
    embedding; the best-scoring box becomes the new track if it clears
    ``confidence_threshold``. While a track exists, candidates are laid out
    around the previous box; otherwise the whole frame is scanned.

    All bounding boxes are ``(x, y, w, h)`` tuples in pixel coordinates.
    """

    def __init__(self, clip_model="ViT-B/32", search_window=1.5,
                 confidence_threshold=0.7):
        """
        Load the CLIP model and reset the tracking state.

        Args:
            clip_model: Model name handed to ``clip.load``.
            search_window: Size of the local search window as a multiple of
                the previous bounding box size.
            confidence_threshold: Minimum similarity score a candidate must
                exceed to be accepted as the tracked object.
        """
        # Load CLIP model for zero-shot detection
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.clip_model, self.preprocess = clip.load(clip_model, device=self.device)

        # Initialize tracking state
        self.tracking_state = None  # kept for compatibility; not read by this class
        self.previous_bbox = None
        self.target_embedding = None

        # Configuration
        self.search_window = search_window  # Search window multiplier
        # NOTE(review): scores are raw cosine similarities between CLIP image
        # and text embeddings; these are commonly reported to be far below
        # 0.7, so this default may reject every candidate - verify on real data.
        self.confidence_threshold = confidence_threshold

    def encode_target(self, target_name):
        """
        Encode the target description into a unit-length CLIP text embedding.

        Args:
            target_name: Text prompt describing the object to track.

        The result is stored in ``self.target_embedding``.
        """
        text = clip.tokenize([target_name]).to(self.device)
        with torch.no_grad():
            text_features = self.clip_model.encode_text(text)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        self.target_embedding = text_features

    def get_region_proposals(self, frame, prev_bbox):
        """
        Generate candidate boxes for the current frame.

        Args:
            frame: Image object exposing ``size`` as ``(width, height)``.
            prev_bbox: Previous ``(x, y, w, h)`` box, or ``None`` if there is
                no active track.

        Returns:
            A list of ``(x, y, w, h)`` candidates lying inside the frame.
        """
        if prev_bbox is None:
            # If no previous bbox, search the entire frame
            return self._get_sliding_windows(frame.size)

        x, y, w, h = prev_bbox
        center_x = x + w / 2
        center_y = y + h / 2

        # Search window centred on the previous box. It is deliberately not
        # clamped here: every candidate is clipped to the frame individually,
        # which treats all four image borders the same way.
        search_w = w * self.search_window
        search_h = h * self.search_window
        search_x = center_x - search_w / 2
        search_y = center_y - search_h / 2

        proposals = self._generate_dense_proposals(
            (search_x, search_y, search_w, search_h),
            frame.size,
            num_scales=3,
            base_size=(w, h),
        )
        # If the previous box lies completely outside the frame nothing
        # survives clipping; fall back to a global search.
        return proposals or self._get_sliding_windows(frame.size)

    def _generate_dense_proposals(self, search_window, frame_size, num_scales=3,
                                  base_size=None):
        """
        Generate dense proposals within the search window at multiple scales.

        Args:
            search_window: ``(x, y, w, h)`` region the candidates are laid
                out in.
            frame_size: ``(width, height)`` of the frame; candidates are
                clipped to it.
            num_scales: Number of scale factors taken evenly from 0.8 to 1.2.
            base_size: ``(w, h)`` the scale factors are applied to. Defaults
                to the search window size.

        Returns:
            A list of unique ``(x, y, w, h)`` candidates inside the frame.
        """
        x, y, w, h = search_window
        base_w, base_h = (w, h) if base_size is None else base_size

        proposals = []
        for scale in np.linspace(0.8, 1.2, num_scales):
            scaled_w = float(base_w * scale)
            scaled_h = float(base_h * scale)

            xs = self._centered_positions(x, w, scaled_w)
            ys = self._centered_positions(y, h, scaled_h)

            for px in xs:
                for py in ys:
                    box = self._clip_to_frame(
                        (px, py, scaled_w, scaled_h), frame_size
                    )
                    if box is not None:
                        proposals.append(box)

        # Clipping can collapse neighbouring candidates onto the same box;
        # drop duplicates (order preserved) so each crop is scored only once.
        return list(dict.fromkeys(proposals))

    @staticmethod
    def _centered_positions(start, extent, size, overlap=0.5):
        """
        Return box origins along one axis for boxes of ``size`` inside the
        span ``[start, start + extent]``.

        Origins are evenly spaced from the first to the last position that
        keeps the box inside the span, with neighbouring boxes overlapping by
        at least ``overlap``. The count is always odd so that the position
        centred in the span is included. A box that does not fit in the span
        yields a single origin that centres the box on the span.
        """
        slack = extent - size
        if size <= 0 or slack <= 1e-6:
            return [float(start + slack / 2)]

        step = size * (1.0 - overlap)
        # The small epsilon keeps float noise from adding a spurious position.
        count = int(np.ceil(slack / step - 1e-9)) + 1
        if count % 2 == 0:
            count += 1
        return [float(p) for p in np.linspace(start, start + slack, count)]

    @staticmethod
    def _clip_to_frame(bbox, frame_size):
        """
        Intersect ``bbox`` with the frame.

        Returns the clipped ``(x, y, w, h)`` box, or ``None`` when less than
        one pixel remains along either axis.
        """
        x, y, w, h = bbox
        frame_w, frame_h = frame_size

        left = max(0.0, float(x))
        top = max(0.0, float(y))
        right = min(float(frame_w), float(x + w))
        bottom = min(float(frame_h), float(y + h))

        if right - left < 1 or bottom - top < 1:
            return None
        return (left, top, right - left, bottom - top)

    def track(self, frame, target_name=None):
        """
        Track the target object in the given frame.

        Args:
            frame: A PIL image, or any object accepted by
                ``Image.fromarray``.
            target_name: Optional text prompt. When given, the target
                embedding is re-encoded before tracking, so it only needs to
                be passed when the target is first set or changes.

        Returns:
            ``(bbox, score)``. ``bbox`` is the best ``(x, y, w, h)`` candidate
            when its score exceeds ``confidence_threshold``, otherwise
            ``None``. ``score`` is the best similarity found
            (``-inf`` if there were no candidates).

        Raises:
            ValueError: If no target has been specified yet.
        """
        if target_name is not None:
            self.encode_target(target_name)

        if self.target_embedding is None:
            raise ValueError("No target object specified")

        # Convert frame to PIL if needed
        if not isinstance(frame, Image.Image):
            frame = Image.fromarray(frame)

        # Get region proposals
        proposals = self.get_region_proposals(frame, self.previous_bbox)

        # Score each proposal using CLIP
        best_score = -float('inf')
        best_bbox = None

        for bbox in proposals:
            x, y, w, h = [int(v) for v in bbox]
            if w <= 0 or h <= 0:
                # A zero-area crop cannot be scored.
                continue
            region = frame.crop((x, y, x + w, y + h))
            score = self._compute_clip_score(region)

            if score > best_score:
                best_score = score
                best_bbox = bbox

        # Update tracking state
        if best_score > self.confidence_threshold:
            self.previous_bbox = best_bbox
            return best_bbox, best_score

        # Track lost: clearing the box makes the next call scan the full frame.
        self.previous_bbox = None
        return None, best_score

    def _compute_clip_score(self, region):
        """
        Compute the CLIP similarity between an image region and the target.

        Args:
            region: Cropped PIL image.

        Returns:
            Cosine similarity between the region embedding and
            ``self.target_embedding`` as a Python float.
        """
        # Preprocess image
        image = self.preprocess(region).unsqueeze(0).to(self.device)

        # Get image features
        with torch.no_grad():
            image_features = self.clip_model.encode_image(image)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        # Compute similarity
        similarity = F.cosine_similarity(image_features, self.target_embedding)
        return similarity.item()

    def _get_sliding_windows(self, frame_size, min_size=64):
        """
        Generate sliding-window proposals covering the whole frame.

        Windows are produced at five scales from 10% to 50% of the frame
        size with a stride of half a window; the last window on each axis is
        aligned with the frame border so the right and bottom margins are
        covered as well.

        Args:
            frame_size: ``(width, height)`` of the frame.
            min_size: Scales whose window is narrower or shorter than this
                many pixels are skipped.

        Returns:
            A list of ``(x, y, w, h)`` windows. If every scale is skipped,
            the full frame is returned as the only window so that small
            frames can still be searched.
        """
        width, height = frame_size
        proposals = []

        # Generate windows at multiple scales
        for scale in np.linspace(0.1, 0.5, 5):
            w = int(width * scale)
            h = int(height * scale)

            if w < min_size or h < min_size or w < 1 or h < 1:
                continue

            xs = self._stride_positions(width - w, max(1, w // 2))
            ys = self._stride_positions(height - h, max(1, h // 2))
            proposals.extend((x, y, w, h) for x in xs for y in ys)

        if not proposals and width > 0 and height > 0:
            proposals.append((0, 0, width, height))

        return proposals

    @staticmethod
    def _stride_positions(last, step):
        """
        Return integer origins ``0, step, 2*step, ...`` up to ``last``,
        always ending exactly on ``last``.
        """
        positions = list(range(0, last + 1, step))
        if positions[-1] != last:
            positions.append(last)
        return positions