In [None]:
import matplotlib.pyplot as plt
import torch
import numpy as np
from matplotlib.colors import LinearSegmentedColormap

def visualize_segmentation(model, data_path, game=None, num_samples=5, save_path=None):
    """
    Visualize video object segmentation results similar to Figure 3 in the paper.

    One figure row is drawn per sample, with five panels: frame 0, frame 1,
    the most salient object mask overlaid on frame 1, the translation-weighted
    sum of all masks overlaid on frame 1, and an HSV rendering of the summed
    translation masks (hue = direction, value = magnitude).

    Args:
        model: Trained VideoObjectSegmentationModel. After a forward pass it
            must expose ``object_masks`` (expected [BS, K, H, W]) and
            ``translation_masks`` (expected [BS, K, 2, H, W]), plus the
            ``cnn``, ``fc_conv``, ``relu``, ``obj_trans`` submodules and ``K``.
        data_path: Path to the dataset
        game: Specific game to visualize (optional)
        num_samples: Number of samples to visualize. Capped at the number of
            samples in the loaded batch, so no empty rows are drawn.
        save_path: Path to save the figure (optional)

    Returns:
        The matplotlib Figure that was drawn.

    Raises:
        ValueError: If ``num_samples`` is smaller than 1.
    """
    # Function-scope import, hoisted so it is not re-executed for every row.
    from matplotlib.colors import hsv_to_rgb

    if num_samples < 1:
        raise ValueError(f"num_samples must be a positive integer, got {num_samples}")

    batch_size = 32
    num_frames = 2
    device = next(model.parameters()).device

    # Load data
    data = VOSDataset(batch_size, num_frames, data_path, game=game)
    inp = data.get_batch("train").to(device)

    # Prepare input (two consecutive frames stacked along dim 1)
    x_input = torch.stack([inp[:, 0], inp[:, 1]], dim=1)

    # Remember the caller's train/eval mode so that visualising in the middle
    # of a training run does not silently leave the model in eval mode.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # The forward pass is run for its side effect: it populates
            # model.object_masks and model.translation_masks.
            model(x_input)
            masks = model.object_masks  # expected [BS, K, H, W]
            translation_masks = model.translation_masks  # expected [BS, K, 2, H, W]

            # Per-object translation vectors; their norm is used as the weight
            # (model confidence) of each object mask.
            object_translations = model.obj_trans(
                model.relu(model.fc_conv(model.cnn(x_input)))
            ).view(-1, model.K, 2)
            translation_norms = torch.norm(object_translations, dim=2)  # [BS, K]

            # Weighted sum over the first K masks, done as one broadcast
            # multiply instead of a Python loop over objects.
            weighted_masks = (
                masks[:, :model.K] * translation_norms[:, :, None, None]
            ).sum(dim=1)  # [BS, H, W]

            # Normalize each sample by its own spatial maximum; the epsilon
            # keeps an all-zero map from dividing by zero.
            peak = weighted_masks.amax(dim=(1, 2), keepdim=True)
            weighted_masks = weighted_masks / (peak + 1e-8)

            # Most salient mask per sample: the object whose translation mask
            # has the largest summed per-pixel flow magnitude.
            salient_mask_indices = (
                torch.norm(translation_masks, dim=2).sum(dim=(2, 3)).argmax(dim=1)
            )  # [BS]
    finally:
        model.train(was_training)

    # Never allocate more rows than there are samples to draw.
    rows = min(num_samples, inp.shape[0])

    # Move everything needed for plotting to host memory once.
    frames = inp[:rows, :2].cpu().numpy()
    masks_np = masks[:rows].cpu().numpy()
    weighted_np = weighted_masks[:rows].cpu().numpy()
    flows = translation_masks[:rows].sum(dim=1).cpu().numpy()  # [rows, 2, H, W]
    salient = salient_mask_indices[:rows].tolist()

    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(rows, 5, figsize=(20, 4 * rows), squeeze=False)

    # Green colormap whose alpha ramps from transparent to opaque, so low mask
    # values leave the underlying frame visible.
    cmap_green = LinearSegmentedColormap.from_list(
        'green_alpha', [(0, 0, 0, 0), (0, 1, 0, 1)], N=256
    )

    def draw_panel(ax, title, background, overlay=None):
        """Draw a grayscale frame, optionally with a green mask overlay."""
        ax.imshow(background, cmap='gray', vmin=0, vmax=1)
        if overlay is not None:
            ax.imshow(overlay, cmap=cmap_green, alpha=0.6, vmin=0, vmax=1)
        ax.set_title(title)
        ax.axis('off')

    for i in range(rows):
        frame0, frame1 = frames[i, 0], frames[i, 1]
        salient_idx = salient[i]

        draw_panel(axes[i, 0], 'Frame 0 (x₀)', frame0)
        draw_panel(axes[i, 1], 'Frame 1 (x₁)', frame1)
        draw_panel(
            axes[i, 2],
            f'Most Salient Object (Mask {salient_idx})',
            frame1,
            masks_np[i, salient_idx],
        )
        draw_panel(axes[i, 3], 'All Objects (Weighted Sum)', frame1, weighted_np[i])

        # Optical flow visualization: sum of all objects' translation masks.
        flow = flows[i]  # [2, H, W]
        magnitude = np.hypot(flow[0], flow[1])
        # NOTE(review): np.arctan2 takes (y, x); channel 0 is passed first
        # here, exactly as in the original code - confirm the channel order.
        angle = np.arctan2(flow[0], flow[1])

        # HSV image: hue = direction mapped to [0, 1], full saturation,
        # value = magnitude normalized by this sample's maximum.
        hsv = np.stack(
            [
                (angle + np.pi) / (2 * np.pi),
                np.ones_like(magnitude),
                magnitude / (magnitude.max() + 1e-8),
            ],
            axis=-1,
        )

        axes[i, 4].imshow(hsv_to_rgb(hsv))
        axes[i, 4].set_title('Optical Flow')
        axes[i, 4].axis('off')

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')

    plt.show()

    return fig


def visualize_individual_masks(model, data_path, game=None, sample_idx=0, num_masks=10):
    """
    Visualize individual object masks for a single sample.

    Masks are ranked by the summed per-pixel magnitude of their translation
    mask, and the top ``num_masks`` are overlaid in green on frame 1 of the
    chosen sample, five panels per row.

    Args:
        model: Trained VideoObjectSegmentationModel. After a forward pass it
            must expose ``object_masks`` (expected [BS, K, H, W]) and
            ``translation_masks`` (expected [BS, K, 2, H, W]).
        data_path: Path to the dataset
        game: Specific game to visualize (optional)
        sample_idx: Index of sample to visualize
        num_masks: Number of top masks to display. If the model provides
            fewer masks, all of them are shown.

    Returns:
        The matplotlib Figure that was drawn.

    Raises:
        ValueError: If ``num_masks`` is smaller than 1.
    """
    if num_masks < 1:
        raise ValueError(f"num_masks must be a positive integer, got {num_masks}")

    batch_size = 32
    num_frames = 2
    device = next(model.parameters()).device

    # Load data
    data = VOSDataset(batch_size, num_frames, data_path, game=game)
    inp = data.get_batch("train").to(device)

    # Prepare input (two consecutive frames stacked along dim 1)
    x_input = torch.stack([inp[:, 0], inp[:, 1]], dim=1)

    # Restore the caller's train/eval mode afterwards so this helper can be
    # called during training without leaving the model in eval mode.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # Forward pass is run for its side effect of populating the masks.
            model(x_input)
            masks = model.object_masks  # expected [BS, K, H, W]
            translation_masks = model.translation_masks  # expected [BS, K, 2, H, W]

            # Importance of each mask: per-pixel flow magnitude summed over
            # the spatial dimensions.
            mask_importance = torch.norm(
                translation_masks[sample_idx], dim=1
            ).sum(dim=(1, 2))  # [K]
            top_mask_indices = torch.argsort(mask_importance, descending=True)[:num_masks]
    finally:
        model.train(was_training)

    # Convert to plain Python / NumPy once, outside the drawing loop.
    ranked = top_mask_indices.tolist()
    importances = mask_importance.tolist()
    sample_masks = masks[sample_idx].cpu().numpy()
    frame = inp[sample_idx, 1].cpu().numpy()

    # Size the grid from the masks actually available so no empty rows appear.
    cols = 5
    rows = max(1, (len(ranked) + cols - 1) // cols)
    # squeeze=False always yields a 2-D array of axes. The previous
    # `[axes]` special case for num_masks == 1 wrapped an array of five axes
    # in a list and failed on the first imshow call.
    fig, axes = plt.subplots(rows, cols, figsize=(3 * cols, 3 * rows), squeeze=False)
    axes = axes.ravel()

    for ax, mask_idx in zip(axes, ranked):
        ax.imshow(frame, cmap='gray', vmin=0, vmax=1)
        ax.imshow(sample_masks[mask_idx], cmap='Greens', alpha=0.6, vmin=0, vmax=1)
        ax.set_title(f'Mask {mask_idx} (imp: {importances[mask_idx]:.2f})')
        ax.axis('off')

    # Hide unused subplots
    for ax in axes[len(ranked):]:
        ax.axis('off')

    fig.tight_layout()
    plt.show()

    return fig