In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the Kaggle input tree and print the full path of every file found,
# so the attached datasets are visible in the cell output.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
# Fetch the RAFT optical-flow reference implementation and move the working
# directory into RAFT/core before importing from `raft`.
# NOTE(review): this cell is not idempotent - re-running it in the same kernel
# repeats the clone and the relative directory changes.
!git clone https://github.com/princeton-vl/RAFT.git
%cd RAFT
import os
os.chdir('core')  # relative to the RAFT directory entered by %cd above
import torch
# NOTE(review): star import; the next cell imports RAFT explicitly from the same module.
from raft import *  # Assuming raft.py is saved as a module

In [None]:
import os
import cv2
import numpy as np
import torch
from raft import RAFT
import argparse
from scipy.ndimage import laplace
import torch.nn.functional as F
import torch.nn as nn
from tqdm import tqdm

class VideoProcessor:
    """Turn a folder of video frames into per-frame features for colorization.

    For every frame the processor stores the LAB ``L`` and ``AB`` planes and a
    histogram embedding; from the second frame on it also stores the RAFT
    optical flow from the previous frame. A temporal sharpness score is derived
    from consecutive ``L`` planes.

    NOTE(review): the histogram layers created in ``_init_histogram_network``
    are freshly initialised for every instance and are not trained anywhere in
    this class, so embeddings are only comparable when produced by the same
    instance.

    Args:
        bin_sizes: Histogram bin counts, one embedding branch per entry.
            Defaults to ``[32, 64, 128]``.
        embed_dim: Size of the histogram embedding.
    """

    def __init__(self, bin_sizes=None, embed_dim=64):
        # A list literal as default would be shared between all instances.
        if bin_sizes is None:
            bin_sizes = [32, 64, 128]

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self._init_raft()
        self.resize_dim = (256, 256)  # (width, height) as expected by cv2.resize

        self.bin_sizes = list(bin_sizes)
        self.max_bins = max(self.bin_sizes)
        self.embed_dim = embed_dim

        # Initialize histogram network
        self._init_histogram_network()

    @staticmethod
    def _strip_module_prefix(state_dict):
        """Return a copy of ``state_dict`` with a leading ``module.`` removed from every key."""
        prefix = 'module.'
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

    def _init_raft(self):
        """Build RAFT, load the checkpoint and return the model in eval mode on ``self.device``."""
        args = argparse.Namespace(
            name='raft', stage='chairs', validation='kitti',
            mixed_precision=False, small=False, dropout=0,
            alternate_corr=False, model='',
            restore_ckpt='/kaggle/input/vcdataset/models/raft-kitti.pth',
            path='', gpus=[0], iters=12
        )
        model = RAFT(args)
        checkpoint = torch.load(args.restore_ckpt, map_location=self.device, weights_only=True)  # Added weights_only=True for security

        # Checkpoints saved from an nn.DataParallel wrapper prefix every key with
        # "module.". With strict=False such keys are silently skipped and the
        # network keeps its random initialisation, so normalise the keys first.
        state_dict = self._strip_module_prefix(checkpoint)
        load_result = model.load_state_dict(state_dict, strict=False)
        if load_result.missing_keys or load_result.unexpected_keys:
            import warnings
            warnings.warn(
                "RAFT checkpoint does not fully match the model: "
                f"{len(load_result.missing_keys)} missing and "
                f"{len(load_result.unexpected_keys)} unexpected keys"
            )
        return model.to(self.device).eval()

    def _init_histogram_network(self):
        """Initialize learnable histogram processing layers"""
        # One Linear -> ReLU -> LayerNorm branch per histogram resolution.
        self.fc_layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(bins, self.embed_dim),
                nn.ReLU(),
                nn.LayerNorm(self.embed_dim)
            ) for bins in self.bin_sizes
        ]).to(self.device)

        self.attention = nn.MultiheadAttention(
            self.embed_dim, num_heads=4
        ).to(self.device)
        self.output_proj = nn.Linear(self.embed_dim, self.embed_dim).to(self.device)

    def _process_histograms(self, frame):
        """Embed the grayscale histograms of a BGR frame.

        Args:
            frame: BGR image as returned by ``cv2.imread``.

        Returns:
            1-D tensor of length ``embed_dim`` on ``self.device`` with no
            autograd history.
        """
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

        # One histogram per configured bin count, normalised with cv2.normalize's defaults.
        hists = [
            cv2.normalize(
                cv2.calcHist([gray], [0], None, [bins], [0, 256]),
                None
            ).flatten()
            for bins in self.bin_sizes
        ]

        # The result is only used as a fixed feature, so skip building a graph.
        with torch.no_grad():
            embeddings = torch.stack([
                fc(torch.from_numpy(hist).float().to(self.device))
                for hist, fc in zip(hists, self.fc_layers)
            ])  # [num_scales, embed_dim]
            attn_out, _ = self.attention(embeddings, embeddings, embeddings)
            return self.output_proj(attn_out.mean(dim=0)).detach()

    def calculate_temporal_sharpness(self, l_channel_sequence):
        """Score each consecutive frame pair by motion and spatial detail.

        Args:
            l_channel_sequence: Sequence of equally sized 2-D arrays.

        Returns:
            Array with ``len(l_channel_sequence) - 1`` entries (empty for fewer
            than two frames); each is ``0.7 * mean(|curr - prev|) +
            0.3 * var(laplace(curr))``.
        """
        sharpness_scores = []
        for i in range(1, len(l_channel_sequence)):
            prev = l_channel_sequence[i-1].astype(np.float32)
            curr = l_channel_sequence[i].astype(np.float32)

            # Frame difference (temporal gradient); for float32 inputs this is
            # the same value cv2.absdiff produces.
            temporal_diff = np.abs(curr - prev)

            # Spatial sharpness (Laplacian variance)
            spatial_sharpness = laplace(curr).var()

            # Combined metric (adjust weights as needed)
            combined = 0.7 * temporal_diff.mean() + 0.3 * spatial_sharpness
            sharpness_scores.append(combined)

        return np.array(sharpness_scores)

    def process_folder(self, input_folder, output_base):
        """Process every image in ``input_folder`` and write features under ``output_base``.

        Written files: ``L/<name>.npy``, ``AB/<name>.npy`` and
        ``hist_features/<name>.npy`` per frame, plus ``sharpness.npy``,
        ``flow.npy``, ``l_frames.npy`` and ``hist_features.npy`` for the
        whole sequence.

        Args:
            input_folder: Directory holding ``.png``/``.jpg``/``.jpeg`` frames;
                frames are processed in sorted filename order.
            output_base: Directory that receives the outputs (created if needed).

        Returns:
            Dict mapping ``'L'``, ``'AB'``, ``'flow'``, ``'sharpness'`` and
            ``'hist_features'`` to the corresponding output paths.

        Raises:
            IOError: If a listed frame cannot be decoded.
        """
        frame_paths = sorted([f for f in os.listdir(input_folder)
                              if f.endswith(('.png', '.jpg', '.jpeg'))])

        # Create output directories
        for sub_dir in ('L', 'AB', 'flow', 'hist_features'):
            os.makedirs(os.path.join(output_base, sub_dir), exist_ok=True)

        prev_frame = None
        flows = []
        l_channels = []  # Store L channels for sharpness calculation
        hist_features = []

        for frame_name in frame_paths:
            frame_path = os.path.join(input_folder, frame_name)
            frame = cv2.imread(frame_path)
            if frame is None:
                # cv2.imread signals failure by returning None.
                raise IOError(f"Could not read image: {frame_path}")

            # Decode once and reuse the frame for both LAB and histogram features.
            lab = cv2.cvtColor(frame, cv2.COLOR_BGR2LAB)
            lab = cv2.resize(lab, self.resize_dim)
            l_channel = lab[:, :, 0]

            # Save components
            base_name = os.path.splitext(frame_name)[0]
            np.save(os.path.join(output_base, 'L', f'{base_name}.npy'), l_channel)
            np.save(os.path.join(output_base, 'AB', f'{base_name}.npy'), lab[:, :, 1:])

            # Histogram features are computed on the original (un-resized) frame.
            hist_feat = self._process_histograms(frame).cpu().numpy()
            hist_features.append(hist_feat)
            np.save(
                os.path.join(output_base, 'hist_features', f'{base_name}.npy'),
                hist_feat
            )

            l_channels.append(l_channel)  # Store for sharpness

            # RAFT takes 3-channel images, so the L plane is replicated.
            current_frame = torch.from_numpy(np.stack([l_channel]*3, axis=2)) \
                          .permute(2, 0, 1).float().unsqueeze(0).to(self.device)

            if prev_frame is not None:
                with torch.no_grad():
                    # test_mode=True makes RAFT return (low-res flow, final
                    # upsampled flow). Without it the call returns one prediction
                    # per refinement iteration and index 1 is only the second,
                    # barely refined estimate.
                    _, flow_up = self.model(prev_frame, current_frame,
                                            iters=12, test_mode=True)
                flows.append(flow_up.cpu().numpy())
            prev_frame = current_frame

        # Calculate temporal sharpness
        sharpness = self.calculate_temporal_sharpness(l_channels)
        np.save(os.path.join(output_base, 'sharpness.npy'), sharpness)

        # np.concatenate rejects an empty list, which happens for folders with
        # fewer than two frames; store an empty flow array of matching layout.
        if flows:
            flow_array = np.concatenate(flows)
        else:
            width, height = self.resize_dim
            flow_array = np.zeros((0, 2, height, width), dtype=np.float32)

        np.save(os.path.join(output_base, 'flow.npy'), flow_array)
        np.save(os.path.join(output_base, 'l_frames.npy'), np.array(l_channels))
        np.save(os.path.join(output_base, 'hist_features.npy'), np.array(hist_features))

        return {
            'L': os.path.join(output_base, 'L'),
            'AB': os.path.join(output_base, 'AB'),
            'flow': os.path.join(output_base, 'flow.npy'),
            'sharpness': os.path.join(output_base, 'sharpness.npy'),
            'hist_features': os.path.join(output_base, 'hist_features')
        }

def process_all_videos(root_folder, output_base, processor=None):
    """Process all video folders in root directory.

    Args:
        root_folder: Directory whose immediate sub-directories each hold the
            frames of one video.
        output_base: Directory under which one output folder per video is written.
        processor: Object exposing ``process_folder(input_path, output_path)``.
            When omitted, a notebook-level ``processor`` variable is used if one
            exists (the previous implicit behaviour); otherwise a new
            ``VideoProcessor`` is created.

    Returns:
        Dict mapping each video folder name to the value returned by
        ``processor.process_folder`` for it.
    """
    if processor is None:
        # The parameter shadows the module-level name, so look it up explicitly.
        processor = globals().get('processor')
        if processor is None:
            processor = VideoProcessor()

    # os.listdir order is arbitrary; sort so runs are reproducible.
    video_folders = sorted(
        f for f in os.listdir(root_folder)
        if os.path.isdir(os.path.join(root_folder, f))
    )

    results = {}
    for folder in video_folders:
        print(f"Processing {folder}...")
        input_path = os.path.join(root_folder, folder)
        output_path = os.path.join(output_base, folder)
        results[folder] = processor.process_folder(input_path, output_path)

    return results

def prepare_fusion_input(flows, sharpness, l_frames, flow_scaling=0.1, l_scaling=1.0/255.0):
    """
    Concatenates optical flow, temporal sharpness, and L channels into a unified input tensor.

    Flow and sharpness describe frame *pairs*, so they hold one entry fewer than
    ``l_frames``; a zero entry is prepended to both so that frame 0 has values.

    Args:
        flows: array/tensor of shape (n_frames-1, 2, H, W) or (n_frames-1, 1, 2, H, W)
        sharpness: array/tensor with n_frames-1 values
        l_frames: array/tensor of shape (n_frames, H, W) - grayscale frames
        flow_scaling: factor applied to the flow values
        l_scaling: factor applied to the L channel values (default 1/255)

    Returns:
        torch.Tensor of shape (n_frames, 5, H, W)
        Channel order: [L, Flow_X, Flow_Y, Flow_Magnitude, Sharpness]

    Raises:
        ValueError: if the shapes of the three inputs are inconsistent.
    """
    # as_tensor accepts numpy arrays as well as tensors
    flows = torch.as_tensor(flows).float()
    sharpness = torch.as_tensor(sharpness).float().reshape(-1)
    l_frames = torch.as_tensor(l_frames).float()

    # Remove batch dimension if present in flows
    if flows.dim() == 5:
        flows = flows.squeeze(1)  # (n_frames-1, 2, H, W)

    # Validate up front; otherwise a mismatch only surfaces as an opaque
    # size error from torch.cat at the very end.
    if l_frames.dim() != 3:
        raise ValueError(f"l_frames must be (n_frames, H, W), got {tuple(l_frames.shape)}")
    n_frames, height, width = l_frames.shape
    if flows.dim() != 4 or tuple(flows.shape[1:]) != (2, height, width):
        raise ValueError(
            f"flows must be (n_frames-1, 2, {height}, {width}), got {tuple(flows.shape)}"
        )
    if flows.shape[0] != n_frames - 1:
        raise ValueError(
            f"Expected {n_frames - 1} flow fields for {n_frames} frames, got {flows.shape[0]}"
        )
    if sharpness.shape[0] != n_frames - 1:
        raise ValueError(
            f"Expected {n_frames - 1} sharpness values for {n_frames} frames, got {sharpness.shape[0]}"
        )

    # Prepend zeros for the first frame, which has no predecessor
    flows = torch.cat([flows.new_zeros(1, 2, height, width), flows], dim=0)  # (n_frames, 2, H, W)
    sharpness = torch.cat([sharpness.new_zeros(1), sharpness], dim=0)        # (n_frames,)

    # Scale inputs
    flows = flows * flow_scaling
    l_frames = l_frames * l_scaling

    # Calculate flow magnitude
    flow_magnitude = torch.norm(flows, dim=1, keepdim=True)  # (n_frames, 1, H, W)

    # Broadcast the per-frame sharpness scalar to a constant map
    sharpness_map = sharpness.view(-1, 1, 1, 1).expand(-1, 1, height, width)  # (n_frames, 1, H, W)

    # Add channel dimension to L frames
    l_frames = l_frames.unsqueeze(1)  # (n_frames, 1, H, W)

    # Concatenate all features
    fusion_input = torch.cat([
        l_frames,                # L channel: 1 channel
        flows,                   # Flow XY: 2 channels
        flow_magnitude,          # Flow magnitude: 1 channel
        sharpness_map            # Sharpness map: 1 channel
    ], dim=1)                   # Total: 5 channels

    return fusion_input  # (n_frames, 5, H, W)

In [None]:
if __name__ == "__main__":
    # Run the preprocessing pipeline on one sequence, then reload the arrays it
    # wrote to disk and assemble the 5-channel fusion input from them.
    processor = VideoProcessor()
    test_result = processor.process_folder(
        input_folder="/kaggle/input/vcdataset/DS/blackswan",
        output_base="/kaggle/working/test_output"
    )
    flows = np.load("/kaggle/working/test_output/flow.npy")        # One entry per consecutive frame pair (n_frames-1); presumably (n_frames-1, 2, H, W) - TODO confirm
    sharpness = np.load("/kaggle/working/test_output/sharpness.npy") # Shape: (n_frames-1,)
    l_frames = np.load("/kaggle/working/test_output/l_frames.npy")   # Stacked L planes, one per frame
    
    # Prepare fusion input
    fusion_input = prepare_fusion_input(flows, sharpness, l_frames)
    
    print("Fusion input shape:", fusion_input.shape)
    print("Channel order: [L, Flow_X, Flow_Y, Flow_Magnitude, Sharpness]")
    # The ranges below are taken from frame 0. prepare_fusion_input zero-pads
    # flow and sharpness for the first frame, so only the L channel is
    # informative at this index.
    print("Sample values:")
    print("L channel:", fusion_input[0, 0].min().item(), "to", fusion_input[0, 0].max().item())
    print("Flow X:", fusion_input[0, 1].min().item(), "to", fusion_input[0, 1].max().item())
    print("Flow Y:", fusion_input[0, 2].min().item(), "to", fusion_input[0, 2].max().item())
    print("Magnitude:", fusion_input[0, 3].min().item(), "to", fusion_input[0, 3].max().item())
    print("Sharpness:", fusion_input[0, 4].min().item(), "to", fusion_input[0, 4].max().item())

# *Core Architecture*

# TMM: ConvGRU-based temporal module

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvGRUCell(nn.Module):
    """One convolutional GRU step with group-normalised gates.

    The reset and update gates come from a single convolution over the
    concatenated input and hidden state; the candidate state comes from a second
    convolution over the input and the reset-scaled hidden state. When no hidden
    state is supplied, the learnable ``h0`` is broadcast over batch and space.

    Args:
        input_channels: Channels of the input feature map.
        hidden_channels: Channels of the hidden state (and of the output).
        kernel_size: Kernel size of both convolutions.
    """

    def __init__(self, input_channels, hidden_channels, kernel_size=3):
        super().__init__()
        self.hidden_channels = hidden_channels

        stacked_channels = input_channels + hidden_channels
        conv_options = dict(
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            padding_mode='reflect',
        )

        # Produces reset and update gates in one pass (2 * hidden_channels outputs).
        self.conv_gates = nn.Conv2d(stacked_channels, 2 * hidden_channels, **conv_options)
        # Produces the candidate hidden state.
        self.conv_candidate = nn.Conv2d(stacked_channels, hidden_channels, **conv_options)

        self.norm_gates = nn.GroupNorm(8, 2 * hidden_channels)
        self.norm_candidate = nn.GroupNorm(8, hidden_channels)

        # Learnable initial hidden state, one value per channel.
        self.h0 = nn.Parameter(torch.randn(1, hidden_channels, 1, 1) * 0.02)

        self._init_weights()

    def _init_weights(self):
        """Xavier-initialise both convolutions; set the update-gate half of the gate bias to -1."""
        gate_conv = self.conv_gates
        nn.init.xavier_uniform_(gate_conv.weight)
        nn.init.zeros_(gate_conv.bias)
        with torch.no_grad():
            # Second half of the gate channels is the update gate (see chunk in forward).
            gate_conv.bias[self.hidden_channels:].fill_(-1)

        cand_conv = self.conv_candidate
        nn.init.xavier_uniform_(cand_conv.weight)
        nn.init.zeros_(cand_conv.bias)

    def forward(self, x, h_prev=None):
        """Advance the hidden state by one step.

        Args:
            x: Input of shape [B, input_channels, H, W].
            h_prev: Previous hidden state [B, hidden_channels, H, W], or ``None``
                to start from ``h0``.

        Returns:
            New hidden state of shape [B, hidden_channels, H, W].
        """
        if h_prev is None:
            n_batch, _, rows, cols = x.shape
            h_prev = self.h0.expand(n_batch, -1, rows, cols)

        gate_logits = self.norm_gates(self.conv_gates(torch.cat([x, h_prev], dim=1)))
        reset_gate, update_gate = torch.sigmoid(gate_logits).chunk(2, 1)

        candidate_in = torch.cat([x, reset_gate * h_prev], dim=1)
        candidate = torch.tanh(self.norm_candidate(self.conv_candidate(candidate_in)))

        return update_gate * h_prev + (1 - update_gate) * candidate

In [None]:
class ConvGRUWrapper(nn.Module):
    """Run a stack of ``ConvGRUCell`` layers over a time-major sequence.

    Each time step is passed through all layers in order; the top layer's state
    is optionally scaled by a 1x1-conv sigmoid attention map and passed through
    dropout before being collected as that step's output.

    Args:
        input_channels: Channels of the input frames.
        hidden_channels: Channels of every layer's hidden state.
        num_layers: Number of stacked cells.
        dropout: ``Dropout2d`` probability on the outputs (identity when ``0``).
        use_attention: Whether to apply the sigmoid attention map.
        debug: Print shapes and statistics during ``forward``.
    """

    def __init__(self, input_channels, hidden_channels, num_layers=1, 
                 dropout=0.0, use_attention=True, debug=False):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.num_layers = num_layers
        self.debug = debug

        # Only the first layer sees the raw input; deeper layers consume hidden states.
        cells = []
        for depth in range(num_layers):
            cell_in = hidden_channels if depth > 0 else input_channels
            cells.append(ConvGRUCell(cell_in, hidden_channels))
        self.gru_layers = nn.ModuleList(cells)

        if use_attention:
            self.attention = nn.Sequential(
                nn.Conv2d(hidden_channels, 1, kernel_size=1),
                nn.Sigmoid()
            )
        else:
            self.attention = None

        if dropout > 0:
            self.dropout = nn.Dropout2d(dropout)
        else:
            self.dropout = nn.Identity()

    def init_hidden(self, batch_size, height, width):
        """Return each layer's learnable ``h0`` expanded to [batch_size, C, height, width]."""
        initial_states = []
        for layer_idx in range(self.num_layers):
            h0 = self.gru_layers[layer_idx].h0
            initial_states.append(h0.expand(batch_size, -1, height, width))
        return initial_states

    def forward(self, x, hidden_states=None):
        """Process a sequence.

        Args:
            x: Tensor of shape [T, B, C, H, W]; a 4-D tensor is treated as
                [T, C, H, W] and given a batch axis of size 1.
            hidden_states: Optional list with one initial state per layer.

        Returns:
            Tuple ``(outputs, hidden_states)``: outputs stacked along time as
            [T, B, hidden_channels, H, W] and the list of final per-layer states.

        Raises:
            ValueError: If ``x`` is neither 4-D nor 5-D.
        """
        verbose = self.debug
        if verbose:
            print("\nConvGRUWrapper forward:")
            print(f"Initial input shape: {x.shape}")

        # Input shape handling
        if x.dim() == 4:
            x = x.unsqueeze(1)
        if x.dim() != 5:
            raise ValueError(f"Input must be [T,C,H,W] or [T,B,C,H,W], got {x.shape}")

        num_steps = x.shape[0]

        if verbose:
            print(f"Processed input shape: {x.shape} (T,B,C,H,W)")

        # A None entry lets each cell fall back to its own learnable h0.
        if hidden_states is None:
            hidden_states = [None] * self.num_layers
            if verbose:
                print("Initialized all hidden states as None")

        outputs = []
        for t in range(num_steps):
            if verbose:
                print(f"\nProcessing time step {t}")

            layer_input = x[t]
            stepped_states = []
            for layer_idx, cell in enumerate(self.gru_layers):
                layer_input = cell(layer_input, hidden_states[layer_idx])
                stepped_states.append(layer_input)

            # Attention and dropout affect the emitted output only, not the carried state.
            out = layer_input
            if self.attention is not None:
                attn = self.attention(out)
                if verbose:
                    print(f"Attention weights mean: {attn.mean().item():.4f}")
                out = out * attn

            out = self.dropout(out)
            outputs.append(out)
            hidden_states = stepped_states

            if verbose:
                print(f"Time step {t} output shape: {out.shape}")
                if hidden_states[0] is not None:
                    print(f"Hidden state mean: {torch.mean(hidden_states[0]).item():.4f}")

        return torch.stack(outputs), hidden_states

In [None]:
# Smoke test: run a two-layer ConvGRUWrapper over a random 10-step sequence
# with debug printing enabled.
gru = ConvGRUWrapper(input_channels=3, 
                    hidden_channels=64, 
                    num_layers=2, 
                    dropout=0.1, 
                    debug=True)

# A 4-D input is treated as [T,C,H,W]; the wrapper inserts a batch axis of size 1.
x = torch.randn(10, 3, 32, 32)  # [T,C,H,W]
output, hidden = gru(x)  # output: [T,1,64,32,32], hidden: list of [1,64,32,32]

# SE block and cross-attention for histogram fusion

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SEBlock(nn.Module):
    """Squeeze-and-Excitation channel gate.

    The input is average-pooled to one value per channel, passed through a
    two-layer bottleneck ending in a sigmoid, and the resulting per-channel
    weights rescale the input.

    Args:
        channels: Number of input channels.
        reduction: Bottleneck ratio; the hidden width is ``channels // reduction``.
        debug: Print the input shape and weight statistics in ``forward``.

    Raises:
        ValueError: If ``channels`` is smaller than ``reduction``.
    """

    def __init__(self, channels, reduction=8, debug=False):
        super().__init__()
        self.debug = debug
        self.channels = channels
        self.reduction = reduction

        # A bottleneck narrower than one unit is not possible.
        if channels < reduction:
            raise ValueError(f"Channels ({channels}) must be >= reduction ({reduction})")

        bottleneck = channels // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, bottleneck),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck, channels),
            nn.Sigmoid()
        )

        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal weights and zero biases for the linear layers of ``fc``."""
        linear_layers = [layer for layer in self.fc if isinstance(layer, nn.Linear)]
        for layer in linear_layers:
            nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Rescale ``x`` ([B, C, H, W]) channel-wise.

        Raises:
            ValueError: If ``x`` is not 4-D or its channel count differs from
                ``channels``.
        """
        if x.dim() != 4:
            raise ValueError(f"Input must be 4D [B,C,H,W], got {x.dim()}D tensor")

        b, c, h, w = x.shape
        if c != self.channels:
            raise ValueError(f"Expected {self.channels} channels, got {c}")

        squeezed = self.avg_pool(x).view(b, c)
        y = self.fc(squeezed).view(b, c, 1, 1)

        if self.debug:
            print(f"[SEBlock] Input shape: {x.shape}")
            print(f"Attention weights - min: {y.min().item():.4f}, max: {y.max().item():.4f}, mean: {y.mean().item():.4f}")

        return x * y.expand_as(x)


class CrossAttentionFusion(nn.Module):
    """Inject a histogram embedding into a visual feature map.

    Every spatial position of the visual features forms a query; the histogram
    embedding supplies a single key/value pair. The scaled query-key similarity
    is turned into a per-position gate in (0, 1) that controls how much of the
    projected histogram value is added to that position (residual connection).

    A sigmoid gate is used instead of a softmax because there is exactly one
    key: a softmax over that single-element axis is always 1.0, which makes the
    query/key projections irrelevant and leaves them without gradient.

    Args:
        visual_dim: Channel count ``C`` of the visual features.
        hist_dim: Size ``D`` of the histogram embedding.
        debug: Print statistics in ``forward`` (only while autograd is enabled).
    """

    def __init__(self, visual_dim, hist_dim, debug=False):
        super().__init__()
        self.debug = debug
        self.visual_dim = visual_dim
        self.hist_dim = hist_dim

        self.query = nn.Linear(visual_dim, visual_dim)
        self.key = nn.Linear(hist_dim, visual_dim)
        self.value = nn.Linear(hist_dim, visual_dim)
        # Kept only so existing references to the attribute keep working; it has
        # no parameters and is no longer used in forward.
        self.softmax = nn.Softmax(dim=-1)

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform weights and zero biases for the three projections."""
        for proj in (self.query, self.key, self.value):
            nn.init.xavier_uniform_(proj.weight)
        for proj in (self.query, self.key, self.value):
            nn.init.zeros_(proj.bias)

    def forward(self, visual_feat, hist_feat):
        """Fuse ``hist_feat`` into ``visual_feat``.

        Args:
            visual_feat: Tensor of shape [B, C, H, W] with ``C == visual_dim``.
            hist_feat: Tensor of shape [B, D] with ``D == hist_dim``.

        Returns:
            Tensor of shape [B, C, H, W]: gated histogram values plus ``visual_feat``.

        Raises:
            ValueError: On wrong rank or channel/embedding size.
        """
        # Input validation
        if visual_feat.dim() != 4:
            raise ValueError(f"Visual features should be 4D [B,C,H,W], got {visual_feat.shape}")
        if hist_feat.dim() != 2:
            raise ValueError(f"Hist features should be 2D [B,D], got {hist_feat.shape}")

        B, C, H, W = visual_feat.shape
        if C != self.visual_dim:
            raise ValueError(f"Visual dim mismatch: expected {self.visual_dim}, got {C}")
        if hist_feat.size(1) != self.hist_dim:
            raise ValueError(f"Hist dim mismatch: expected {self.hist_dim}, got {hist_feat.size(1)}")

        # flatten/transpose (unlike view) also accepts non-contiguous inputs
        visual_flat = visual_feat.flatten(2).transpose(1, 2)  # [B, HW, C]

        # Project features
        q = self.query(visual_flat)  # [B, HW, C]
        k = self.key(hist_feat).unsqueeze(1)  # [B, 1, C]
        v = self.value(hist_feat).unsqueeze(1)  # [B, 1, C]

        # Per-position gate from the scaled dot product
        scale = (C) ** -0.5
        attn_logits = (q @ k.transpose(-2, -1)) * scale  # [B, HW, 1]
        attn = torch.sigmoid(attn_logits)
        output = (attn * v).transpose(1, 2).reshape(B, C, H, W)

        if self.debug and torch.is_grad_enabled():
            print(f"\n[CrossAttention] Debug:")
            print(f"Visual in: {visual_feat.shape}, Hist in: {hist_feat.shape}")
            print(f"Q mean: {q.mean().item():.4f}, K mean: {k.mean().item():.4f}")
            print(f"Attention weights - min: {attn.min().item():.4f}, max: {attn.max().item():.4f}")
            print(f"Output delta mean: {(output - visual_feat).mean().item():.4f}")

        return output + visual_feat  # Residual connection

In [None]:
# With debugging enabled
# With debugging enabled
# NOTE(review): `se` is constructed here but not called in this cell.
se = SEBlock(channels=64, reduction=8, debug=True)
cross_attn = CrossAttentionFusion(visual_dim=64, hist_dim=64, debug=True)

# Forward pass on random inputs; cross_attn was built with debug=True, so its
# debug statistics are printed.
visual_feat = torch.randn(2, 64, 32, 32)  # [B, C, H, W]
hist_feat = torch.randn(2, 64)  # [B, D]
out = cross_attn(visual_feat, hist_feat)

# UNet

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SpatialAttention(nn.Module):
    """Spatial gate built from channel-wise mean and max maps.

    The per-pixel mean and maximum over channels are stacked into a two-channel
    map, convolved down to one channel and squashed with a sigmoid; the result
    multiplies the input at every position.

    Args:
        kernel_size: Odd kernel size of the convolution.
        debug: Print the input shape and attention range in ``forward``.
    """

    def __init__(self, kernel_size=7, debug=False):
        super().__init__()
        assert kernel_size % 2 == 1, "Kernel size should be odd for symmetric padding"
        half_kernel = kernel_size // 2
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=half_kernel, bias=False)
        self.debug = debug

    def forward(self, x):
        """Return ``x`` ([B, C, H, W]) scaled by its [B, 1, H, W] attention map."""
        verbose = self.debug
        if verbose:
            print(f"\n[SpatialAttention] Input shape: {x.shape}")

        # Collapse the channel axis two ways and stack the results.
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True)[0]
        pooled = torch.cat((channel_mean, channel_max), dim=1)

        attention = torch.sigmoid(self.conv(pooled))

        if verbose:
            print(f"Attention map - min: {attention.min().item():.4f}, max: {attention.max().item():.4f}")

        return x * attention

class ConvBlock(nn.Module):
    """Two 3x3 convolutions with GroupNorm and ReLU, optionally downsampling by 2.

    The same ``GroupNorm`` instance follows both convolutions. When
    ``downsample`` is set, the first convolution uses stride 2 and a strided
    1x1-conv shortcut is added before the final ReLU; without it there is no
    residual addition.

    Args:
        in_channels: Input channels.
        out_channels: Output channels.
        downsample: Halve the spatial size and add the projected shortcut.
        debug: Print shapes in ``forward``.
    """

    def __init__(self, in_channels, out_channels, downsample=False, debug=False):
        super().__init__()
        self.debug = debug
        self.downsample = downsample

        first_stride = 2 if downsample else 1
        self.conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=first_stride, padding=1, bias=False,
        )
        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, padding=1, bias=False,
        )
        self.norm = nn.GroupNorm(8, out_channels)
        self.relu = nn.ReLU(inplace=True)

        # Projection shortcut exists only on the downsampling path.
        if not downsample:
            self.down = None
        else:
            self.down = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2, bias=False),
                nn.GroupNorm(8, out_channels)
            )

    def forward(self, x):
        """Apply the block to ``x`` ([B, in_channels, H, W])."""
        verbose = self.debug
        if verbose:
            print(f"\n[ConvBlock] Input shape: {x.shape}, downsample: {self.downsample}")

        shortcut = None
        if self.downsample:
            shortcut = self.down(x)
            if verbose:
                print(f"Shortcut output shape: {shortcut.shape}")

        features = self.relu(self.norm(self.conv1(x)))
        features = self.norm(self.conv2(features))

        if self.downsample:
            features = features + shortcut

        out = self.relu(features)

        if verbose:
            print(f"Output shape: {out.shape}")

        return out

class UpConvBlock(nn.Module):
    """Upsampling block with skip connections and attention.

    The input is upsampled 2x with a transposed convolution, concatenated with
    the skip tensor along the channel axis, passed through a ``ConvBlock`` and
    finally through ``SpatialAttention``.

    Args:
        in_channels: Channels of the tensor to upsample.
        out_channels: Channels after upsampling and of the block output.
        skip_channels: Channels of the skip tensor; defaults to ``out_channels``.
        debug: Print shapes during construction and ``forward``.
    """

    def __init__(self, in_channels, out_channels, skip_channels=None, debug=False):
        super().__init__()
        self.debug = debug
        self.skip_channels = skip_channels if skip_channels is not None else out_channels

        # Upsampling layer (kernel 2, stride 2 doubles height and width)
        self.up = nn.ConvTranspose2d(
            in_channels, out_channels,
            kernel_size=2,
            stride=2,
            bias=False
        )

        # The conv block always expects upsampled + skip channels.
        self.conv = ConvBlock(
            out_channels + self.skip_channels,
            out_channels,
            debug=debug
        )

        self.attention = SpatialAttention(debug=debug)

        if self.debug:
            print(f"\n[UpConvBlock] Initialized with:")
            print(f"in_channels: {in_channels}, out_channels: {out_channels}")
            print(f"skip_channels: {self.skip_channels}")

    def forward(self, x, skip=None):
        """Upsample ``x`` and fuse it with ``skip``.

        Args:
            x: Tensor of shape [B, in_channels, H, W].
            skip: Optional tensor of shape [B, skip_channels, H', W']. It is
                bilinearly resized when its spatial size differs from the
                upsampled ``x``. When omitted, zeros stand in for it so the conv
                block still receives the channel count it was built for.

        Returns:
            Tensor of shape [B, out_channels, 2H, 2W].

        Raises:
            ValueError: If ``skip`` has an unexpected number of channels.
        """
        if self.debug:
            print(f"\n[UpConvBlock] Input shape: {x.shape}")
            if skip is not None:
                print(f"Skip connection shape: {skip.shape}")
            else:
                print("No skip connection provided")

        # Upsample main input
        x = self.up(x)
        if self.debug:
            print(f"After upsampling: {x.shape}")

        if skip is None:
            # Without this, calling forward(x) fails inside self.conv with a
            # channel-count error whenever skip_channels > 0.
            skip = x.new_zeros(x.shape[0], self.skip_channels, *x.shape[2:])
        else:
            # Validate skip connection channels
            if skip.shape[1] != self.skip_channels:
                raise ValueError(
                    f"Skip channels mismatch: expected {self.skip_channels}, got {skip.shape[1]}"
                )

            # Resize if needed (should be rare in properly constructed networks)
            if skip.shape[2:] != x.shape[2:]:
                if self.debug:
                    print(f"Resizing skip from {skip.shape[2:]} to {x.shape[2:]}")
                skip = F.interpolate(
                    skip,
                    size=x.shape[2:],
                    mode='bilinear',
                    align_corners=False
                )

        x = torch.cat([x, skip], dim=1)
        if self.debug:
            print(f"After skip concatenation: {x.shape}")

        # Process through conv block
        x = self.conv(x)
        if self.debug:
            print(f"After conv block: {x.shape}")

        # Apply spatial attention
        out = self.attention(x)

        if self.debug:
            print(f"Final output shape: {out.shape}")

        return out

In [None]:
# With debugging
# With debugging
# The stride-2 transposed conv takes 16x16 to 32x32, which matches the skip's
# spatial size, so the skip is not resized here.
up_block = UpConvBlock(64, 32, skip_channels=64, debug=True)
x = torch.randn(2, 64, 16, 16)
skip = torch.randn(2, 64, 32, 32)
out = up_block(x, skip)

# Normal usage (rebinds `up_block`; this instance is not called in this cell)
up_block = UpConvBlock(64, 32)  # Defaults skip_channels=32

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class BidirectionalVideoColorizationNet(nn.Module):
    """Encoder / bidirectional-ConvGRU / decoder network predicting AB chroma for L-channel video.

    Each frame is encoded by five ``ConvBlock`` stages (built with
    ``downsample=True``). The deepest features are run through a forward and a
    backward ``ConvGRUWrapper``, fused by a 1x1 convolution, modulated by
    projected per-frame histogram embeddings and decoded by three
    ``UpConvBlock`` stages that consume encoder skip connections. The AB map
    supplied for the first frame is returned unchanged as frame 0.

    Args:
        in_channels: Input channels of the first encoder stage.
            NOTE(review): ``forward`` only accepts 1-channel frames, so the
            default of 5 looks inconsistent with it -- confirm against
            ``ConvBlock`` and pass ``in_channels=1`` if needed.
        out_channels: Channels of the predicted map and of ``first_frame_ab``.
        base_channels: Width multiplier ``C``; the encoder stages use
            C, 2C, 4C, 8C and 16C channels.
        hist_embed_dim: Size ``D`` of each per-frame histogram embedding.
        debug: When true, print tensor shapes during ``forward`` and the
            architecture summary at construction time.
    """

    def __init__(self, in_channels=5, out_channels=2, base_channels=64, hist_embed_dim=64, debug=False):
        super().__init__()
        self.debug = debug
        self.base_channels = base_channels
        self.hist_embed_dim = hist_embed_dim

        # ================ ENCODER ================
        self.encoder = nn.ModuleDict({
            'enc1': ConvBlock(in_channels, base_channels, downsample=True),
            'enc2': ConvBlock(base_channels, base_channels*2, downsample=True),
            'enc3': ConvBlock(base_channels*2, base_channels*4, downsample=True),
            'enc4': ConvBlock(base_channels*4, base_channels*8, downsample=True),
            'enc5': ConvBlock(base_channels*8, base_channels*16, downsample=True)
        })

        # ================ TEMPORAL PROCESSING ================
        # One recurrent module per direction; both map 16C -> 8C channels.
        self.temporal = nn.ModuleDict({
            'temporal_forward': ConvGRUWrapper(
                input_channels=base_channels*16,
                hidden_channels=base_channels*8,
                num_layers=1,
                debug=debug
            ),
            'temporal_backward': ConvGRUWrapper(
                input_channels=base_channels*16,
                hidden_channels=base_channels*8,
                num_layers=1,
                debug=debug
            )
        })

        # ================ HISTOGRAM PROCESSING ================
        # Project the D-dim histogram embedding to the channel width of each
        # decoder level it modulates (8C / 4C / 2C).
        self.hist_proj = nn.ModuleDict({
            'high': nn.Sequential(
                nn.Linear(hist_embed_dim, base_channels*8),
                nn.ReLU(),
                nn.LayerNorm(base_channels*8)
            ),
            'mid': nn.Sequential(
                nn.Linear(hist_embed_dim, base_channels*4),
                nn.ReLU(),
                nn.LayerNorm(base_channels*4)
            ),
            'low': nn.Sequential(
                nn.Linear(hist_embed_dim, base_channels*2),
                nn.ReLU(),
                nn.LayerNorm(base_channels*2)
            )
        })

        # ================ FUSION MODULES ================
        self.fusion = nn.ModuleDict({
            # 1x1 conv merging the concatenated forward/backward states.
            'proj': nn.Conv2d(base_channels*8*2, base_channels*8, 1),
            'high': CrossAttentionFusion(
                visual_dim=base_channels*8,
                hist_dim=base_channels*8,
                debug=debug
            ),
            'mid': SEBlock(base_channels*4, debug=debug),
            'low': SEBlock(base_channels*2, debug=debug)
        })

        # ================ FIRST FRAME PROCESSING ================
        # Two stride-2 convolutions turn the first-frame AB map into an
        # 8C-channel feature map that is later pooled to the GRU resolution.
        self.first_frame_proc = nn.Sequential(
            nn.Conv2d(out_channels, base_channels*4, 3, stride=2, padding=1),
            nn.GroupNorm(8, base_channels*4),
            nn.ReLU(),
            nn.Conv2d(base_channels*4, base_channels*8, 3, stride=2, padding=1),
            nn.GroupNorm(8, base_channels*8),
            nn.ReLU()
        )

        # ================ DECODER ================
        self.decoder = nn.ModuleDict({
            'dec1': UpConvBlock(base_channels*8, base_channels*4,
                              skip_channels=base_channels*8, debug=debug),
            'dec2': UpConvBlock(base_channels*4, base_channels*2,
                              skip_channels=base_channels*4, debug=debug),
            'dec3': UpConvBlock(base_channels*2, base_channels,
                              skip_channels=base_channels*2, debug=debug)
        })

        # ================ OUTPUT ================
        # Final Tanh bounds the predicted AB values to [-1, 1].
        self.output = nn.Sequential(
            nn.Conv2d(base_channels, base_channels, 3, padding=1),
            nn.GroupNorm(8, base_channels),
            nn.ReLU(),
            nn.Conv2d(base_channels, base_channels, 3, padding=1),
            nn.Conv2d(base_channels, out_channels, 3, padding=1),
            nn.Tanh()
        )

        if self.debug:
            self.print_architecture()

    def forward(self, x, hist_features, first_frame_ab):
        """Colorize a batch of grayscale clips.

        Videos in the batch are processed one at a time, and frames within a
        video are encoded/decoded one at a time.

        Args:
            x: L-channel frames, shape ``[B, T, 1, H, W]``.
            hist_features: Per-frame histogram embeddings, shape ``[B, T, D]``.
            first_frame_ab: AB map of each clip's first frame,
                shape ``[B, 1, 2, H, W]``.

        Returns:
            Tensor of shape ``[B, T, 2, H, W]``. Frame 0 of every clip is
            exactly ``first_frame_ab``.

        Raises:
            ValueError: If an input has the wrong rank, ``first_frame_ab`` has
                more than one frame, or ``x`` is not single-channel.

        Shape comments below write ``C`` for ``base_channels`` and assume each
        encoder stage halves H and W (TODO confirm against ``ConvBlock``).
        """
        # ===================== INPUT VALIDATION =====================
        if self.debug:
            print("\n" + "="*30 + " FORWARD PASS " + "="*30)
            print(f"Input frames shape: {x.shape}")
            print(f"Hist features shape: {hist_features.shape}")
            print(f"First frame AB shape: {first_frame_ab.shape}")

        # Validate input dimensions
        if x.dim() != 5:
            raise ValueError(f"Input must be 5D [B,T,1,H,W], got {x.shape}")
        if hist_features.dim() != 3:
            raise ValueError(f"Hist features must be 3D [B,T,D], got {hist_features.shape}")
        if first_frame_ab.dim() != 5 or first_frame_ab.shape[1] != 1:
            raise ValueError(f"First frame AB must be [B,1,2,H,W], got {first_frame_ab.shape}")

        B, T, C, H, W = x.shape
        if C != 1:
            raise ValueError(f"Input frames should have 1 channel (L), got {C}")

        # ===================== PROCESS EACH VIDEO IN BATCH =====================
        batch_output = []
        for b in range(B):  # Process each video independently
            if self.debug:
                print(f"\nProcessing video {b+1}/{B}")

            # Get current video data
            video_L = x[b]  # [T,1,H,W]
            video_hist = hist_features[b]  # [T,D]
            video_first_ab = first_frame_ab[b]  # [1,2,H,W]

            # ===================== ENCODER PASS =====================
            encoder_features = []
            skip_connections = {'e4': [], 'e3': [], 'e2': [], 'e1': []}

            for t in range(T):
                if self.debug:
                    print(f"  Frame {t+1}/{T}")

                x_t = video_L[t].unsqueeze(0)  # [1,1,H,W]

                # Encoder forward pass
                e1 = self.encoder['enc1'](x_t)
                e2 = self.encoder['enc2'](e1)
                e3 = self.encoder['enc3'](e2)
                e4 = self.encoder['enc4'](e3)
                e5 = self.encoder['enc5'](e4)

                if self.debug:
                    print(f"    Encoder features:")
                    print(f"    - e1: {e1.shape}")
                    print(f"    - e5: {e5.shape}")

                encoder_features.append(e5.squeeze(0))  # Remove batch dim
                skip_connections['e4'].append(e4.squeeze(0))
                skip_connections['e3'].append(e3.squeeze(0))
                skip_connections['e2'].append(e2.squeeze(0))
                skip_connections['e1'].append(e1.squeeze(0))

            encoder_features = torch.stack(encoder_features)  # [T,16C,H/32,W/32]

            # ===================== HISTOGRAM PROCESSING =====================
            hist_high = self.hist_proj['high'](video_hist)  # [T,8C]
            hist_mid = self.hist_proj['mid'](video_hist)    # [T,4C]
            hist_low = self.hist_proj['low'](video_hist)    # [T,2C]

            if self.debug:
                print("\n  Histogram projections:")
                print(f"    - high: {hist_high.shape}")
                print(f"    - mid: {hist_mid.shape}")
                print(f"    - low: {hist_low.shape}")

            # ===================== TEMPORAL PROCESSING =====================
            # First frame initialization: pool the first-frame AB features
            # down to the spatial size of the recurrent states.
            init_state = self.first_frame_proc(video_first_ab)  # [1,8C,H/4,W/4]
            init_state = F.adaptive_avg_pool2d(init_state, (H//32, W//32))  # [1,8C,H/32,W/32]

            # Forward pass
            forward_out, _ = self.temporal['temporal_forward'](
                encoder_features.unsqueeze(1))  # [T,1,8C,H/32,W/32]
            forward_out = forward_out.squeeze(1)  # [T,8C,H/32,W/32]
            forward_out[0] = init_state.squeeze(0)  # Overwrite first step with the AB-derived state

            # Backward pass: run on the time-reversed sequence, then flip back
            backward_out, _ = self.temporal['temporal_backward'](
                torch.flip(encoder_features, [0]).unsqueeze(1))  # [T,1,8C,H/32,W/32]
            backward_out = torch.flip(backward_out.squeeze(1), [0])  # [T,8C,H/32,W/32]
            backward_out[-1] = init_state.squeeze(0)

            if self.debug:
                print("\n  Temporal processing:")
                print(f"    - Forward out: {forward_out.shape}")
                print(f"    - Backward out: {backward_out.shape}")

            # ===================== FUSION AND DECODING =====================
            video_output = []
            for t in range(T):
                if self.debug:
                    print(f"  Decoding frame {t+1}/{T}")

                # Bidirectional fusion
                fused = torch.cat([
                    forward_out[t].unsqueeze(0),
                    backward_out[t].unsqueeze(0)
                ], dim=1)  # [1,16C,H/32,W/32]
                fused = self.fusion['proj'](fused)  # [1,8C,H/32,W/32]

                # High-level fusion with attention
                fused_high = self.fusion['high'](
                    fused,
                    hist_high[t].unsqueeze(0)  # [1,8C]
                )

                # Get skip connections
                e4_t = skip_connections['e4'][t].unsqueeze(0)  # [1,8C,H/16,W/16]
                e3_t = skip_connections['e3'][t].unsqueeze(0)  # [1,4C,H/8,W/8]
                e2_t = skip_connections['e2'][t].unsqueeze(0)  # [1,2C,H/4,W/4]
                e1_t = skip_connections['e1'][t].unsqueeze(0)  # [1,C,H/2,W/2]

                # Decoder with skip connections
                d1 = self.decoder['dec1'](fused_high, e4_t)  # [1,4C,H/16,W/16]

                # Mid-level fusion. The channel count is inferred (-1) so the
                # network works for any base_channels, not only 64.
                hist_mid_t = hist_mid[t].view(1, -1, 1, 1).expand(-1, -1, *d1.shape[2:])
                d1 = d1 * self.fusion['mid'](hist_mid_t)

                d2 = self.decoder['dec2'](d1, e3_t)  # [1,2C,H/8,W/8]

                # Low-level fusion (channel count inferred, see above)
                hist_low_t = hist_low[t].view(1, -1, 1, 1).expand(-1, -1, *d2.shape[2:])
                d2 = d2 * self.fusion['low'](hist_low_t)

                d3 = self.decoder['dec3'](d2, e2_t)  # [1,C,H/4,W/4]

                # Shallow feature refinement: add the first encoder stage,
                # resized to the decoder resolution.
                e1_t_resized = F.interpolate(
                    e1_t, size=d3.shape[2:], mode='bilinear', align_corners=False)
                d3_refined = d3 + e1_t_resized

                # Final output: upsample to full resolution, then predict AB
                ab = F.interpolate(
                    d3_refined, size=(H, W), mode='bilinear', align_corners=False)
                ab = self.output(ab)  # [1,2,H,W]

                if t == 0:  # Ensure first frame matches input exactly
                    ab = video_first_ab

                if self.debug:
                    print(f"    Output AB range: {ab.min().item():.2f} to {ab.max().item():.2f}")

                video_output.append(ab.squeeze(0))  # [2,H,W]

            batch_output.append(torch.stack(video_output))  # [T,2,H,W]

        return torch.stack(batch_output)  # [B,T,2,H,W]

    def print_architecture(self):
        """Print the encoder, temporal and decoder sub-modules for debugging."""
        print("\n" + "="*50)
        print("Bidirectional Video Colorization Network")
        print("="*50)
        print(f"Base channels: {self.base_channels}")
        print(f"Histogram embedding dim: {self.hist_embed_dim}")
        print("\nEncoder:")
        for name, module in self.encoder.items():
            print(f"- {name}: {module}")
        print("\nTemporal Processing:")
        for name, module in self.temporal.items():
            print(f"- {name}: {module}")
        print("\nDecoder:")
        for name, module in self.decoder.items():
            print(f"- {name}: {module}")
        print("="*50 + "\n")

    def debug_shapes(self, T=5, H=256, W=256):
        """Print the tensor shapes expected for a clip of ``T`` frames of size ``H`` x ``W``.

        This only prints; it does not run the network.
        NOTE(review): the "Input frames" line reports 5 channels, whereas
        ``forward`` requires 1-channel frames -- confirm which is intended.
        """
        print("\n" + "="*50)
        print("Expected Tensor Shapes (Debug Mode)")
        print("="*50)
        print(f"Input frames: [{T},5,{H},{W}]")
        print(f"Hist features: [{T},{self.hist_embed_dim}]")
        print(f"First frame AB: [1,2,{H},{W}]")

        print("\nEncoder Features:")
        print(f"enc1 out: [{T},{self.base_channels},{H//2},{W//2}]")
        print(f"enc5 out: [{T},{self.base_channels*16},{H//32},{W//32}]")

        print("\nTemporal Processing:")
        print(f"Forward out: [{T},{self.base_channels*8},{H//32},{W//32}]")

        print("\nDecoder Features:")
        print(f"dec1 out: [{T},{self.base_channels*4},{H//16},{W//16}]")
        print(f"Final AB: [{T},2,{H},{W}]")
        print("="*50 + "\n")

In [None]:
# Instantiate the network with its default hyper-parameters. With debug=True
# the constructor also prints the architecture summary (print_architecture).
model = BidirectionalVideoColorizationNet(debug=True)

In [None]:
!pip install git+https://github.com/moskomule/sam.pytorch.git

In [None]:
def count_parameters(model):
    """Return the total number of scalar elements in ``model``'s trainable parameters.

    Only parameters with ``requires_grad`` set are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
# Report how many trainable parameters `model` has, formatted with thousands separators.
print(f"Total trainable parameters: {count_parameters(model):,}")

In [None]:
class VideoColorizationPipeline:
    """Bundle the colorization network with its training-time helpers.

    The constructor builds the model and attaches a gradient scaler, an EMA
    tracker, a SAM-wrapped AdamW optimizer (itself wrapped in Lookahead) and
    a warp-consistency loss. The model architecture itself is not altered.
    """

    def __init__(self, model_config, train_config):
        """Create the model and all training utilities.

        Args:
            model_config: Keyword arguments forwarded to
                ``BidirectionalVideoColorizationNet``.
            train_config: Mapping providing ``use_amp``, ``ema_decay``, ``lr``,
                ``sam_rho``, ``l2_weight`` and ``warp_weight``.
        """
        cfg = train_config

        # The network being trained.
        self.model = BidirectionalVideoColorizationNet(**model_config)

        # Performance wrappers (mixed-precision scaler and weight averaging).
        self.scaler = GradScaler(enabled=cfg['use_amp'])
        self.ema = ExponentialMovingAverage(
            self.model.parameters(),
            decay=cfg['ema_decay'],
        )

        # Optimizer stack: AdamW -> SAM -> Lookahead.
        adamw = optim.AdamW(self.model.parameters(), lr=cfg['lr'])
        sharpness_aware = SAM(adamw, rho=cfg['sam_rho'])
        self.optimizer = Lookahead(sharpness_aware)

        # Loss combining an L2 term with a flow-warp consistency term.
        loss_weights = {
            'l2_weight': cfg['l2_weight'],
            'warp_weight': cfg['warp_weight'],
        }
        self.loss_fn = WarpConsistencyLoss(**loss_weights)

In [None]:
# Clone the repo manually
!git clone https://github.com/alphadl/lookahead.pytorch.git

# Copy the optimizer module into the working directory as lookahead_optimizer.py (the import below depends on this file)
!cp lookahead.pytorch/lookahead.py lookahead_optimizer.py
from lookahead_optimizer import Lookahead

In [None]:
!pip install torch-ema

In [None]:
!pip install sam-pytorch

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch_ema import ExponentialMovingAverage
from sam import SAM  # pip install sam-pytorch

class EnhancedTrainer:
    """Training / validation harness for a video colorization model.

    Combines an AdamW optimizer wrapped in SAM and Lookahead, cosine warm
    restarts, mixed-precision scaling, EMA weight averaging, an L1
    reconstruction loss and an optical-flow warp-consistency loss.

    Args:
        model: The network to train. It is called as
            ``model(L, first_ab, flows)``.
            NOTE(review): ``BidirectionalVideoColorizationNet.forward`` in this
            notebook takes ``(x, hist_features, first_frame_ab)`` -- confirm
            the call signature matches the model actually passed in.
        train_dir: Passed to ``VideoFrameDataset`` for the training split.
        val_dir: Passed to ``VideoFrameDataset`` for the validation split.
        config: Mapping with the keys read below (``device``, ``batch_size``,
            ``lr``, ``weight_decay``, ``sam_rho``, ``sam_adaptive``,
            ``lookahead_steps``, ``lookahead_alpha``, ``restart_period``,
            ``restart_multiplier``, ``min_lr``, ``use_amp``, ``ema_decay``,
            ``warp_weight``, ``save_interval``).
    """

    def __init__(self, model, train_dir, val_dir, config):
        self.model = model
        self.config = config
        self.device = torch.device(config['device'])

        # Data loading
        self.train_dataset = VideoFrameDataset(train_dir)
        self.val_dataset = VideoFrameDataset(val_dir)
        self.train_loader = torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=config['batch_size'],
            shuffle=True,
            num_workers=4,
            pin_memory=True
        )
        self.val_loader = torch.utils.data.DataLoader(
            self.val_dataset,
            batch_size=1,  # Process videos one at a time for validation
            num_workers=2,
            pin_memory=True
        )

        # Optimization setup
        base_optimizer = optim.AdamW(
            model.parameters(),
            lr=config['lr'],
            weight_decay=config['weight_decay']
        )

        # SAM + Lookahead
        # NOTE(review): the SAM constructor signature depends on which `sam`
        # package is installed -- verify the argument order used here.
        self.optimizer = SAM(
            base_optimizer,
            model.parameters(),
            rho=config['sam_rho'],
            adaptive=config['sam_adaptive']
        )
        self.optimizer = Lookahead(
            self.optimizer,
            k=config['lookahead_steps'],
            alpha=config['lookahead_alpha']
        )

        # Learning rate scheduling
        self.scheduler = CosineAnnealingWarmRestarts(
            self.optimizer,
            T_0=config['restart_period'],
            T_mult=config['restart_multiplier'],
            eta_min=config['min_lr']
        )

        # Mixed precision and EMA
        self.scaler = GradScaler(enabled=config['use_amp'])
        self.ema = ExponentialMovingAverage(model.parameters(), decay=config['ema_decay'])

        # Loss functions
        self.l1_loss = nn.L1Loss()
        self.mse_loss = nn.MSELoss()

        # Metrics tracking (one entry appended per epoch)
        self.metrics = {
            'train': {'loss': [], 'psnr': [], 'ssim': []},
            'val': {'psnr': [], 'ssim': []}
        }

    def compute_metrics(self, pred, target):
        """Return the mean PSNR and mean SSIM over all AB frames.

        Args:
            pred: Predicted AB maps shaped ``[..., C, H, W]``; any number of
                leading dimensions (e.g. ``[B, T]``) is flattened so every
                frame is scored individually.
            target: Ground truth with the same shape as ``pred``.

        Returns:
            Tuple ``(mean_psnr, mean_ssim)``. Both metrics are evaluated with
            ``data_range=2.0``, i.e. values are taken to span [-1, 1].
        """
        # .float() keeps half-precision (AMP) outputs comparable with float32 targets.
        pred_np = pred.detach().float().cpu().numpy()
        target_np = target.detach().float().cpu().numpy()

        # Collapse every leading dimension into one "frame" axis: [N, C, H, W].
        pred_np = pred_np.reshape(-1, *pred_np.shape[-3:])
        target_np = target_np.reshape(-1, *target_np.shape[-3:])

        batch_psnr = []
        batch_ssim = []

        for i in range(pred_np.shape[0]):
            # Channels-last layout expected by the image metrics.
            p = pred_np[i].transpose(1,2,0)
            t = target_np[i].transpose(1,2,0)

            batch_psnr.append(psnr(t, p, data_range=2.0))
            batch_ssim.append(ssim(t, p, multichannel=True,
                               data_range=2.0, channel_axis=-1))

        return np.mean(batch_psnr), np.mean(batch_ssim)

    def train_epoch(self, epoch):
        """Run one pass over the training loader and record loss / PSNR / SSIM averages."""
        self.model.train()
        epoch_loss = 0
        epoch_psnr = 0
        epoch_ssim = 0

        for batch in tqdm(self.train_loader, desc=f"Epoch {epoch}"):
            L = batch['L'].to(self.device)  # [B,T,1,H,W]
            AB = batch['AB'].to(self.device)  # [B,T,2,H,W]
            first_ab = batch['first_ab'].to(self.device)  # [B,1,2,H,W]
            flows = batch['flows'].to(self.device)  # [B,T-1,2,H,W]

            # SAM requires closure for two forward-backward passes
            def closure():
                self.optimizer.zero_grad()
                with autocast(enabled=self.config['use_amp']):
                    pred_ab = self.model(L, first_ab, flows)
                    # Lock first frame AB
                    pred_ab[:,0] = first_ab.squeeze(1)

                    # Combined loss
                    l1_loss = self.l1_loss(pred_ab, AB)
                    warp_loss = self.compute_warp_loss(pred_ab, flows)
                    loss = l1_loss + self.config['warp_weight'] * warp_loss

                self.scaler.scale(loss).backward()
                return loss

            # SAM optimization step
            # NOTE(review): optimizer.step(closure) followed by
            # scaler.step(optimizer) looks like it steps twice; confirm against
            # the SAM / Lookahead / GradScaler APIs in use.
            loss = self.optimizer.step(closure)
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.ema.update()

            # Calculate metrics
            with torch.no_grad():
                pred_ab = self.model(L, first_ab, flows)
                pred_ab[:,0] = first_ab.squeeze(1)
                batch_psnr, batch_ssim = self.compute_metrics(pred_ab, AB)

            epoch_loss += loss.item()
            epoch_psnr += batch_psnr
            epoch_ssim += batch_ssim

        # Store metrics (guard against an empty loader)
        num_batches = max(len(self.train_loader), 1)
        self.metrics['train']['loss'].append(epoch_loss / num_batches)
        self.metrics['train']['psnr'].append(epoch_psnr / num_batches)
        self.metrics['train']['ssim'].append(epoch_ssim / num_batches)

        # LR schedule step
        self.scheduler.step()

    def compute_warp_loss(self, pred_ab, flows):
        """Temporal consistency loss using optical flow.

        For every ``t >= 1`` the prediction for frame ``t-1`` is warped with
        ``flows[:, t-1]`` and compared (L1) with the prediction for frame ``t``.

        Args:
            pred_ab: Predicted AB maps, ``[B, T, C, H, W]``.
            flows: Flow fields, ``[B, T-1, 2, H, W]``.

        Returns:
            Mean L1 discrepancy over the ``T-1`` frame pairs; a zero tensor
            when the clip has fewer than two frames.
        """
        T = pred_ab.shape[1]
        if T < 2:
            # No frame pair to compare; avoid dividing by zero.
            return pred_ab.new_zeros(())

        loss = 0
        for t in range(1, T):
            warped = self.warp(pred_ab[:,t-1], flows[:,t-1])
            loss += F.l1_loss(warped, pred_ab[:,t])

        return loss / (T - 1)

    def flow_to_grid(self, flow):
        """Convert a pixel-displacement field into a ``grid_sample`` sampling grid.

        Output location ``(x, y)`` samples the source at
        ``(x + flow[:, 0], y + flow[:, 1])``.
        NOTE(review): assumes channel 0 is the horizontal and channel 1 the
        vertical displacement in pixels, defined on the target frame's grid --
        confirm against how the dataset's flows were computed.

        Args:
            flow: Tensor ``[B, 2, H, W]``.

        Returns:
            Tensor ``[B, H, W, 2]`` of (x, y) coordinates normalised to
            [-1, 1] using the ``align_corners=False`` convention.
        """
        _, _, H, W = flow.shape
        cols = torch.arange(W, device=flow.device, dtype=flow.dtype).view(1, 1, W)
        rows = torch.arange(H, device=flow.device, dtype=flow.dtype).view(1, H, 1)

        sample_x = cols + flow[:, 0]  # [B,H,W]
        sample_y = rows + flow[:, 1]  # [B,H,W]

        # Pixel centre i maps to (2*i + 1) / size - 1 when align_corners=False.
        grid_x = (2.0 * sample_x + 1.0) / W - 1.0
        grid_y = (2.0 * sample_y + 1.0) / H - 1.0
        return torch.stack((grid_x, grid_y), dim=-1)

    def warp(self, ab, flow):
        """Differentiable bilinear warping of AB channels by a flow field.

        Args:
            ab: Source maps, ``[B, C, H, W]``.
            flow: Displacements, ``[B, 2, H, W]`` (see ``flow_to_grid``).

        Returns:
            Warped maps ``[B, C, H, W]``; samples falling outside the image
            use the nearest border value.
        """
        grid = self.flow_to_grid(flow).to(ab.dtype)
        return F.grid_sample(
            ab, grid, mode='bilinear', padding_mode='border', align_corners=False
        )

    def validate(self):
        """Evaluate PSNR / SSIM on the validation loader using the EMA weights."""
        self.model.eval()
        val_psnr = 0
        val_ssim = 0

        with torch.no_grad(), self.ema.average_parameters():
            for batch in tqdm(self.val_loader, desc="Validating"):
                L = batch['L'].to(self.device)
                AB = batch['AB'].to(self.device)
                first_ab = batch['first_ab'].to(self.device)
                flows = batch['flows'].to(self.device)

                with autocast(enabled=self.config['use_amp']):
                    pred_ab = self.model(L, first_ab, flows)
                    pred_ab[:,0] = first_ab.squeeze(1)

                batch_psnr, batch_ssim = self.compute_metrics(pred_ab, AB)
                val_psnr += batch_psnr
                val_ssim += batch_ssim

        # Store validation metrics (guard against an empty loader)
        num_batches = max(len(self.val_loader), 1)
        self.metrics['val']['psnr'].append(val_psnr / num_batches)
        self.metrics['val']['ssim'].append(val_ssim / num_batches)

        print(f"Validation PSNR: {self.metrics['val']['psnr'][-1]:.2f}, "
              f"SSIM: {self.metrics['val']['ssim'][-1]:.4f}")

    def train(self, epochs):
        """Train for ``epochs`` epochs, validating after each and checkpointing periodically."""
        for epoch in range(1, epochs + 1):
            self.train_epoch(epoch)
            self.validate()

            # Print metrics
            print(f"Train Loss: {self.metrics['train']['loss'][-1]:.4f}, "
                  f"PSNR: {self.metrics['train']['psnr'][-1]:.2f}, "
                  f"SSIM: {self.metrics['train']['ssim'][-1]:.4f}")

            # Save checkpoint
            if epoch % self.config['save_interval'] == 0:
                self.save_checkpoint(epoch)

    def save_checkpoint(self, epoch):
        """Write model / optimizer / scheduler / EMA state and metrics to ``checkpoint_epoch{epoch}.pth``."""
        state = {
            'epoch': epoch,
            'model_state': self.model.state_dict(),
            'optimizer_state': self.optimizer.state_dict(),
            'scheduler_state': self.scheduler.state_dict(),
            'ema_state': self.ema.state_dict(),
            'metrics': self.metrics
        }
        torch.save(state, f"checkpoint_epoch{epoch}.pth")

# Training configuration consumed by EnhancedTrainer (keys are read via config[...]).
config = {
    'device': 'cuda',  # passed to torch.device
    'lr': 3e-4,  # AdamW learning rate
    'min_lr': 1e-6,  # eta_min of CosineAnnealingWarmRestarts
    'batch_size': 4,  # training DataLoader batch size (validation uses 1)
    'epochs': 100,  # not read by EnhancedTrainer itself; train() takes the epoch count as an argument
    'weight_decay': 1e-4,  # AdamW weight decay
    'use_amp': True,  # enables GradScaler / autocast
    'sam_rho': 0.05,  # rho passed to SAM
    'sam_adaptive': True,  # adaptive flag passed to SAM
    'lookahead_steps': 5,  # k passed to Lookahead
    'lookahead_alpha': 0.5,  # alpha passed to Lookahead
    'restart_period': 20,  # T_0 of the cosine warm-restart schedule
    'restart_multiplier': 1,  # T_mult of the cosine warm-restart schedule
    'ema_decay': 0.999,  # decay of ExponentialMovingAverage
    'warp_weight': 0.3,  # weight of the warp-consistency term in the training loss
    'save_interval': 5  # a checkpoint is saved when epoch % save_interval == 0
}

In [None]:
import torch
import numpy as np
from torch.utils.data import DataLoader

# Create dummy dataset (5-frame synthetic videos)
class DummyDataset(torch.utils.data.Dataset):
    """Synthetic video dataset producing random tensors for pipeline smoke tests.

    Args:
        num_videos: Number of items reported by ``len()``.
        num_frames: Frames ``T`` per generated clip.
        size: Spatial size ``(H, W)`` of every frame.
    """

    def __init__(self, num_videos=5, num_frames=5, size=(256,256)):
        self.num_videos, self.num_frames, self.size = num_videos, num_frames, size

    def __len__(self):
        """Number of synthetic videos."""
        return self.num_videos

    def __getitem__(self, idx):
        """Generate a fresh random sample; ``idx`` only affects ``video_name``.

        Returns a dict with ``L`` [T,1,H,W] in [0,1), ``AB`` [T,2,H,W] in
        [-1,1), ``first_ab`` [1,2,H,W] (frame 0 of ``AB``), ``flows``
        [T-1,2,H,W] in [0,1) and the string ``video_name``.
        """
        n_frames = self.num_frames
        frame_hw = self.size

        # Grayscale luminance in [0, 1).
        luminance = torch.rand(n_frames, 1, *frame_hw)

        # Colour channels rescaled from [0, 1) to [-1, 1).
        chroma = torch.rand(n_frames, 2, *frame_hw) * 2 - 1

        # The reference colour of the first frame, with a leading frame axis.
        reference_ab = chroma[0][None]

        # One flow field per consecutive frame pair.
        flow_fields = torch.rand(n_frames - 1, 2, *frame_hw)

        sample = {}
        sample['L'] = luminance.float()
        sample['AB'] = chroma.float()
        sample['first_ab'] = reference_ab.float()
        sample['flows'] = flow_fields.float()
        sample['video_name'] = f"dummy_video_{idx}"
        return sample

# Test configuration
config = {
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'lr': 3e-4,
    'batch_size': 2,
    'use_amp': False,  # Disable for dummy test
    'sam_rho': 0.05,
    'restart_period': 5,
    'ema_decay': 0.99,
    'warp_weight': 0.3
}

# def test_pipeline():
#     print("=== Starting Dummy Test ===")
#     print(f"Using device: {config['device']}")
    
#     # 1. Initialize model
#     print("\n1. Initializing model...")
#     model = BidirectionalVideoColorizationNet(
#         in_channels=5,  # L + AB + flow
#         out_channels=2,
#         base_channels=16,  # Smaller for testing
#         debug=True
#     ).to(config['device'])
#     print("Model initialized successfully!")
#     print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
    
#     # 2. Create dummy data
#     print("\n2. Creating dummy data...")
#     dataset = DummyDataset(num_videos=4, num_frames=5)
#     loader = DataLoader(dataset, batch_size=config['batch_size'])
#     sample = next(iter(loader))
#     print(f"Batch shapes:")
#     print(f"- L: {sample['L'].shape} (input luminance)")
#     print(f"- AB: {sample['AB'].shape} (target color)")
#     print(f"- first_ab: {sample['first_ab'].shape} (fixed first frame)")
#     print(f"- flows: {sample['flows'].shape} (optical flow)")
    
#     # 3. Test forward pass
#     print("\n3. Testing forward pass...")
#     with torch.no_grad():
#         output = model(
#             sample['L'].to(config['device']),
#             sample['first_ab'].to(config['device']),
#             sample['flows'].to(config['device'])
#         )
#     print("Forward pass successful!")
#     print(f"Output shape: {output.shape} (should match AB shape)")

#     # Check whether first frame is preserved (t == 0)
#     first_frame_preserved = torch.allclose(
#         output[:, 0],  # predicted AB for first frame in each video
#         sample['first_ab'].squeeze(1).to(config['device']),
#         atol=1e-4
#     )
#     print(f"First frame AB preserved: {first_frame_preserved}")
    
#     # 4. Test training step
#     print("\n4. Testing training step...")
#     trainer = EnhancedTrainer(model, None, None, config)  # No real data
#     loss = trainer.optimizer.step(lambda: trainer.compute_loss(
#         output, 
#         sample['AB'].to(config['device']), 
#         sample['flows'].to(config['device'])
#     ))
#     print(f"Training step completed with loss: {loss.item():.4f}")
    
#     # 5. Test validation metrics
#     print("\n5. Testing validation metrics...")
#     psnr, ssim = trainer.compute_metrics(
#         output, 
#         sample['AB'].to(config['device'])
#     )
#     print(f"PSNR: {psnr:.2f} (higher is better)")
#     print(f"SSIM: {ssim:.4f} (0-1, higher is better)")
    
#     # 6. Verify EMA
#     print("\n6. Testing EMA...")
#     base_weight = model.encoder['enc1'].conv1.weight.data.clone()
#     print(f"Original first weight mean: {base_weight.mean().item():.4f}")
#     with trainer.ema.average_parameters():
#         ema_weight = model.encoder['enc1'].conv1.weight.data.mean().item()
#     print(f"EMA weight mean: {ema_weight:.4f}")
#     weights_differ = not torch.allclose(
#         base_weight, 
#         model.encoder['enc1'].conv1.weight.data,
#         atol=1e-6
#     )
#     print(f"Values differ: {weights_differ}")
    
#     print("\n=== All tests passed! ===")


# if __name__ == "__main__":
#     test_pipeline()
# Smoke test on random inputs
def test_forward_pass():
    """Run the default network once on random tensors and check the output contract.

    Verifies that the output has shape [B,T,2,H,W] and that frame 0 of each
    clip equals the supplied first-frame AB map.
    """
    batch, frames, height, width = 2, 5, 256, 256
    model = BidirectionalVideoColorizationNet(debug=True)

    # Random stand-ins for the three model inputs
    luminance = torch.rand(batch, frames, 1, height, width)    # L channels
    hist = torch.rand(batch, frames, 64)                       # Hist features
    first_ab = torch.rand(batch, 1, 2, height, width)          # First frame AB

    output = model(luminance, hist, first_ab)

    expected_shape = (batch, frames, 2, height, width)
    assert output.shape == expected_shape, f"Wrong output shape: {output.shape}"

    first_frame_pred = output[:,0]
    assert torch.allclose(first_frame_pred, first_ab.squeeze(1)), "First frame not preserved"
    print("Forward pass test successful!")


test_forward_pass()

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from tqdm import tqdm
import numpy as np

# Fixed imports - install with: pip install sam-pytorch lookahead-pytorch torch-ema
from lookahead_pytorch import Lookahead  # Changed from 'lookahead'
from torch_ema import ExponentialMovingAverage
from sam import SAM  # pip install sam-pytorch

# You'll need to implement these metrics or install scikit-image
try:
    from skimage.metrics import peak_signal_noise_ratio as psnr
    from skimage.metrics import structural_similarity as ssim
except ImportError:
    print("Warning: scikit-image not found. Install with: pip install scikit-image")
    # Fallback simple implementations
    def psnr(target, pred, data_range=2.0):
        mse = np.mean((target - pred) ** 2)
        return 20 * np.log10(data_range / np.sqrt(mse))
    
    def ssim(target, pred, multichannel=True, data_range=2.0, channel_axis=-1):
        # Simplified SSIM - replace with proper implementation
        return 0.9  # Placeholder

class EnhancedTrainer:
    """Training/validation driver combining SAM, Lookahead, AMP and EMA.

    Args:
        model: network called as ``model(L, first_ab, flows)`` whose output is
            compared against the ``AB`` entry of each batch.
        train_dir: directory handed to ``VideoFrameDataset`` for training.
        val_dir: directory handed to ``VideoFrameDataset`` for validation.
        config: flat dict. Keys read by this class: ``device``, ``batch_size``,
            ``lr``, ``weight_decay``, ``sam_rho``, ``sam_adaptive``,
            ``lookahead_steps``, ``lookahead_alpha``, ``restart_period``,
            ``restart_multiplier``, ``min_lr``, ``use_amp``, ``ema_decay``,
            ``warp_weight`` and ``save_interval``.
    """

    def __init__(self, model, train_dir, val_dir, config):
        self.config = config
        self.device = torch.device(config['device'])
        # Batches are moved to ``self.device`` in the loops below, so the model
        # has to live there too. Done before the optimizer and EMA are built so
        # they track the moved parameters.
        self.model = model.to(self.device)

        # Data loading - you'll need to implement VideoFrameDataset
        self.train_dataset = VideoFrameDataset(train_dir)
        self.val_dataset = VideoFrameDataset(val_dir)
        self.train_loader = torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=config['batch_size'],
            shuffle=True,
            num_workers=4,
            pin_memory=True
        )
        self.val_loader = torch.utils.data.DataLoader(
            self.val_dataset,
            batch_size=1,  # Process videos one at a time for validation
            num_workers=2,
            pin_memory=True
        )

        # Optimization setup
        base_optimizer = optim.AdamW(
            self.model.parameters(),
            lr=config['lr'],
            weight_decay=config['weight_decay']
        )

        # SAM + Lookahead
        # NOTE(review): the constructor signatures of the third-party ``SAM``
        # and ``Lookahead`` classes are not visible in this notebook; confirm
        # that they accept an optimizer *instance* as used here.
        self.optimizer = SAM(
            base_optimizer,
            rho=config['sam_rho'],
            adaptive=config['sam_adaptive']
        )
        self.optimizer = Lookahead(
            self.optimizer,
            k=config['lookahead_steps'],
            alpha=config['lookahead_alpha']
        )

        # Learning rate scheduling (stepped once per epoch in train_epoch)
        self.scheduler = CosineAnnealingWarmRestarts(
            self.optimizer,
            T_0=config['restart_period'],
            T_mult=config['restart_multiplier'],
            eta_min=config['min_lr']
        )

        # Mixed precision and EMA
        self.scaler = GradScaler(enabled=config['use_amp'])
        self.ema = ExponentialMovingAverage(self.model.parameters(), decay=config['ema_decay'])

        # Loss functions
        self.l1_loss = nn.L1Loss()
        self.mse_loss = nn.MSELoss()

        # Metrics tracking: one entry appended per epoch
        self.metrics = {
            'train': {'loss': [], 'psnr': [], 'ssim': []},
            'val': {'psnr': [], 'ssim': []}
        }

    def compute_metrics(self, pred, target):
        """Calculate mean PSNR and SSIM for AB channels.

        Args:
            pred: tensor shaped ``[..., C, H, W]`` (e.g. ``[B, T, 2, H, W]`` as
                produced by the training loop, or ``[N, 2, H, W]``).
            target: tensor with the same shape as ``pred``.

        Returns:
            ``(mean_psnr, mean_ssim)`` averaged over every frame.
        """
        # Cast to float32 so half-precision autocast outputs and float32
        # targets reach the metric functions with matching dtypes.
        pred_np = pred.detach().float().cpu().numpy()
        target_np = target.detach().float().cpu().numpy()

        # Collapse all leading (batch/time) axes so each item is one [C, H, W]
        # frame; a plain ``transpose(1, 2, 0)`` fails on 4-D clips otherwise.
        pred_np = pred_np.reshape((-1,) + pred_np.shape[-3:])
        target_np = target_np.reshape((-1,) + target_np.shape[-3:])

        batch_psnr = []
        batch_ssim = []

        for i in range(pred_np.shape[0]):
            p = pred_np[i].transpose(1, 2, 0)    # [H, W, C]
            t = target_np[i].transpose(1, 2, 0)  # [H, W, C]

            batch_psnr.append(psnr(t, p, data_range=2.0))
            batch_ssim.append(ssim(t, p, multichannel=True,
                               data_range=2.0, channel_axis=-1))

        return np.mean(batch_psnr), np.mean(batch_ssim)

    def train_epoch(self, epoch):
        """Run one pass over ``train_loader`` and append epoch-mean metrics."""
        self.model.train()
        epoch_loss = 0
        epoch_psnr = 0
        epoch_ssim = 0

        for batch in tqdm(self.train_loader, desc=f"Epoch {epoch}"):
            L = batch['L'].to(self.device)  # [B,T,1,H,W]
            AB = batch['AB'].to(self.device)  # [B,T,2,H,W]
            first_ab = batch['first_ab'].to(self.device)  # [B,1,2,H,W]
            flows = batch['flows'].to(self.device)  # [B,T-1,2,H,W]

            # SAM requires closure for two forward-backward passes
            def closure():
                self.optimizer.zero_grad()
                with autocast(enabled=self.config['use_amp']):
                    pred_ab = self.model(L, first_ab, flows)
                    # Lock first frame AB
                    pred_ab[:,0] = first_ab.squeeze(1)

                    # Combined loss
                    l1_loss = self.l1_loss(pred_ab, AB)
                    warp_loss = self.compute_warp_loss(pred_ab, flows)
                    loss = l1_loss + self.config['warp_weight'] * warp_loss

                self.scaler.scale(loss).backward()
                return loss

            # SAM optimization step
            # NOTE(review): ``optimizer.step(closure)`` already applies an
            # update, and ``scaler.step(optimizer)`` then calls ``step`` again
            # without a closure. PyTorch documents that GradScaler.step does
            # not support closures, and gradients are never unscaled before the
            # first step. Left unchanged because the SAM/Lookahead APIs are not
            # visible here - confirm and restructure against those packages.
            loss = self.optimizer.step(closure)
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.ema.update()

            # Calculate metrics (extra forward pass, no gradients)
            with torch.no_grad():
                pred_ab = self.model(L, first_ab, flows)
                pred_ab[:,0] = first_ab.squeeze(1)
                batch_psnr, batch_ssim = self.compute_metrics(pred_ab, AB)

            epoch_loss += loss.item()
            epoch_psnr += batch_psnr
            epoch_ssim += batch_ssim

        # Store metrics; guard against an empty loader
        num_batches = max(len(self.train_loader), 1)
        self.metrics['train']['loss'].append(epoch_loss / num_batches)
        self.metrics['train']['psnr'].append(epoch_psnr / num_batches)
        self.metrics['train']['ssim'].append(epoch_ssim / num_batches)

        # LR schedule step
        self.scheduler.step()

    def compute_warp_loss(self, pred_ab, flows):
        """Temporal consistency loss using optical flow.

        Warps the prediction of frame ``t-1`` with ``flows[:, t-1]`` and takes
        the L1 distance to the prediction of frame ``t``, averaged over the
        ``T-1`` consecutive pairs.

        Args:
            pred_ab: ``[B, T, C, H, W]`` predictions.
            flows: ``[B, T-1, 2, H, W]`` flow fields in pixels.

        Returns:
            Scalar tensor; zero when the clip has fewer than two frames.
        """
        T = pred_ab.shape[1]
        if T < 2:
            # No frame pairs: nothing to compare (also avoids dividing by 0).
            return pred_ab.new_zeros(())

        loss = 0
        for t in range(1, T):
            warped = self.warp(pred_ab[:,t-1], flows[:,t-1])
            loss += F.l1_loss(warped, pred_ab[:,t])

        return loss / (T - 1)

    def flow_to_grid(self, flow):
        """Convert a pixel-unit flow field into a ``grid_sample`` grid.

        Args:
            flow: ``[B, 2, H, W]``; channel 0 is the x displacement and
                channel 1 the y displacement.

        Returns:
            ``[B, H, W, 2]`` grid of (x, y) sampling positions normalised to
            ``[-1, 1]`` for ``align_corners=True``.
        """
        B, C, H, W = flow.shape

        # Denominators clamped to 1 so 1-pixel-wide/high inputs don't divide by 0
        w_span = max(W - 1, 1)
        h_span = max(H - 1, 1)

        # Base pixel grid, both of shape [H, W]
        x = torch.arange(W, dtype=torch.float32, device=flow.device)
        y = torch.arange(H, dtype=torch.float32, device=flow.device)
        x, y = torch.meshgrid(x, y, indexing='xy')

        # Normalize to [-1, 1]
        x = 2.0 * x / w_span - 1.0
        y = 2.0 * y / h_span - 1.0

        # Flow expressed in the same normalised units
        flow_x = 2.0 * flow[:, 0] / w_span  # [B, H, W]
        flow_y = 2.0 * flow[:, 1] / h_span  # [B, H, W]

        grid_x = x.unsqueeze(0) + flow_x
        grid_y = y.unsqueeze(0) + flow_y

        return torch.stack([grid_x, grid_y], dim=-1)  # [B, H, W, 2]

    def warp(self, ab, flow):
        """Differentiable warping of AB channels (bilinear, border padding)."""
        grid = self.flow_to_grid(flow)
        return F.grid_sample(ab, grid, padding_mode='border', align_corners=True)

    def validate(self):
        """Evaluate on ``val_loader`` using the EMA weights and log the result."""
        self.model.eval()
        val_psnr = 0
        val_ssim = 0

        with torch.no_grad(), self.ema.average_parameters():
            for batch in tqdm(self.val_loader, desc="Validating"):
                L = batch['L'].to(self.device)
                AB = batch['AB'].to(self.device)
                first_ab = batch['first_ab'].to(self.device)
                flows = batch['flows'].to(self.device)

                with autocast(enabled=self.config['use_amp']):
                    pred_ab = self.model(L, first_ab, flows)
                    pred_ab[:,0] = first_ab.squeeze(1)

                batch_psnr, batch_ssim = self.compute_metrics(pred_ab, AB)
                val_psnr += batch_psnr
                val_ssim += batch_ssim

        # Store validation metrics; guard against an empty loader
        num_batches = max(len(self.val_loader), 1)
        self.metrics['val']['psnr'].append(val_psnr / num_batches)
        self.metrics['val']['ssim'].append(val_ssim / num_batches)

        print(f"Validation PSNR: {self.metrics['val']['psnr'][-1]:.2f}, "
              f"SSIM: {self.metrics['val']['ssim'][-1]:.4f}")

    def train(self, epochs):
        """Train for ``epochs`` epochs, validating and printing after each one."""
        for epoch in range(1, epochs + 1):
            self.train_epoch(epoch)
            self.validate()

            # Print metrics
            print(f"Train Loss: {self.metrics['train']['loss'][-1]:.4f}, "
                  f"PSNR: {self.metrics['train']['psnr'][-1]:.2f}, "
                  f"SSIM: {self.metrics['train']['ssim'][-1]:.4f}")

            # Save checkpoint
            if epoch % self.config['save_interval'] == 0:
                self.save_checkpoint(epoch)

    def save_checkpoint(self, epoch):
        """Write ``checkpoint_epoch{epoch}.pth`` to the current directory."""
        state = {
            'epoch': epoch,
            'model_state': self.model.state_dict(),
            'optimizer_state': self.optimizer.state_dict(),
            'scheduler_state': self.scheduler.state_dict(),
            'ema_state': self.ema.state_dict(),
            # Loss-scale state, so a resumed AMP run keeps its current scale
            'scaler_state': self.scaler.state_dict(),
            'metrics': self.metrics
        }
        torch.save(state, f"checkpoint_epoch{epoch}.pth")


# Placeholder for VideoFrameDataset - you'll need to implement this
class VideoFrameDataset(torch.utils.data.Dataset):
    """Stand-in dataset that yields random 8-frame clips.

    Nothing is read from ``data_dir``; every item is freshly sampled noise with
    the keys and shapes the trainer expects. Replace with real data loading.
    """

    def __init__(self, data_dir):
        # Only the path is remembered for a future real implementation.
        self.data_dir = data_dir

    def __len__(self):
        # Fixed placeholder size.
        return 100  # Placeholder

    def __getitem__(self, idx):
        """Return a dict with keys 'L', 'AB', 'first_ab', 'flows' (``idx`` is ignored)."""
        clip_len, side = 8, 256
        # Tensors are drawn in this order so seeded runs stay reproducible.
        sample = {}
        sample['L'] = torch.randn(clip_len, 1, side, side)           # [T, 1, H, W]
        sample['AB'] = torch.randn(clip_len, 2, side, side)          # [T, 2, H, W]
        sample['first_ab'] = torch.randn(1, 2, side, side)           # [1, 2, H, W]
        sample['flows'] = torch.randn(clip_len - 1, 2, side, side)   # [T-1, 2, H, W]
        return sample


# Placeholder for model - you'll need to implement this
class BidirectionalVideoColorizationNet(nn.Module):
    """Placeholder network that returns noise shaped like an AB prediction.

    NOTE(review): this cell reuses the class name that ``test_forward_pass``
    instantiates with ``debug=True`` and calls as ``model(x, hist, first_ab)``;
    running this cell shadows that definition with a different signature.
    """

    def __init__(self):
        super().__init__()
        # Add your model architecture here
        self.dummy = nn.Linear(1, 1)  # Placeholder so the module has parameters

    def forward(self, L, first_ab, flows):
        """Return random AB channels matching the input clip.

        Args:
            L: luminance tensor ``[B, T, 1, H, W]``; only its shape, device and
                dtype are used.
            first_ab: unused by the placeholder.
            flows: unused by the placeholder.

        Returns:
            ``[B, T, 2, H, W]`` tensor on ``L``'s device, attached to the
            module's parameters so ``loss.backward()`` works in the trainer.
        """
        B, T = L.shape[:2]
        H, W = L.shape[-2:]
        dtype = L.dtype if L.is_floating_point() else torch.get_default_dtype()
        noise = torch.randn(B, T, 2, H, W, device=L.device, dtype=dtype)
        # Adding ``0 * weight`` leaves the values untouched but gives autograd
        # a path to a parameter; without it the output has no grad_fn.
        return noise + 0.0 * self.dummy.weight.sum()


# Configuration
config = {
    # Use the GPU when present, otherwise fall back to CPU (same selection
    # rule as the rest of the notebook) instead of hard-coding 'cuda'.
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    # Optimizer
    'lr': 3e-4,
    'min_lr': 1e-6,            # eta_min of the cosine schedule
    'batch_size': 4,
    'epochs': 100,
    'weight_decay': 1e-4,
    'use_amp': True,           # mixed precision (GradScaler + autocast)
    # SAM
    'sam_rho': 0.05,
    'sam_adaptive': True,
    # Lookahead
    'lookahead_steps': 5,
    'lookahead_alpha': 0.5,
    # CosineAnnealingWarmRestarts: T_0 and T_mult
    'restart_period': 20,
    'restart_multiplier': 1,
    'ema_decay': 0.999,
    'warp_weight': 0.3,        # weight of the temporal warp loss
    'save_interval': 5         # checkpoint every N epochs
}

# Usage
# In a notebook kernel ``__name__`` is "__main__", so executing this cell
# starts training immediately. With the placeholder dataset and model defined
# in this cell the run operates on random tensors only.
if __name__ == "__main__":
    model = BidirectionalVideoColorizationNet()
    # "data/train" / "data/val" are relative to the kernel's working directory.
    trainer = EnhancedTrainer(model, "data/train", "data/val", config)
    trainer.train(config['epochs'])

In [None]:
class WarpConsistencyLoss(nn.Module):
    """L2 reconstruction loss plus a flow-based temporal consistency term.

    Args:
        l2_weight: multiplier for the MSE between prediction and target.
        warp_weight: multiplier for the mean L1 distance between each frame
            and its flow-warped predecessor.
    """

    def __init__(self, l2_weight=1.0, warp_weight=0.5):
        super().__init__()
        self.l2 = nn.MSELoss()
        self.l2_weight = l2_weight
        self.warp_weight = warp_weight

    def warp(self, ab, flow):
        """Bilinearly resample ``ab`` at positions displaced by ``flow``.

        Args:
            ab: ``[B, C, H, W]`` tensor to sample from.
            flow: ``[B, 2, H, W]`` or ``[2, H, W]`` displacement in pixels
                (channel 0 = x, channel 1 = y).

        Returns:
            ``[B, C, H, W]`` warped tensor (border padding).
        """
        if flow.dim() == 3:
            flow = flow.unsqueeze(0)
        flow = flow.to(ab.dtype)
        H, W = flow.shape[-2:]

        ys, xs = torch.meshgrid(
            torch.arange(H, dtype=ab.dtype, device=flow.device),
            torch.arange(W, dtype=ab.dtype, device=flow.device),
            indexing='ij',
        )
        # Absolute sample positions, normalised to [-1, 1] for align_corners=True.
        grid_x = 2.0 * (xs.unsqueeze(0) + flow[:, 0]) / max(W - 1, 1) - 1.0
        grid_y = 2.0 * (ys.unsqueeze(0) + flow[:, 1]) / max(H - 1, 1) - 1.0
        grid = torch.stack([grid_x, grid_y], dim=-1)  # [B, H, W, 2]

        return F.grid_sample(ab, grid, mode='bilinear',
                             padding_mode='border', align_corners=True)

    def forward(self, pred_ab, target_ab, flows):
        """Compute the combined loss for one clip.

        Args:
            pred_ab: ``[T, C, H, W]`` predicted frames.
            target_ab: tensor of the same shape as ``pred_ab``.
            flows: indexable by ``t-1`` for ``t`` in ``1..T-1``; each entry is
                a ``[2, H, W]`` or ``[1, 2, H, W]`` flow field.

        Returns:
            ``l2_weight * mse + warp_weight * mean_pair_l1``. The warp term is
            averaged over the ``T-1`` frame pairs and is zero for ``T < 2``.
        """
        # Original L2 loss
        l2_loss = self.l2(pred_ab, target_ab)

        # Temporal consistency term
        num_frames = pred_ab.size(0)
        warp_loss = pred_ab.new_zeros(())
        for t in range(1, num_frames):
            warped = self.warp(pred_ab[t-1].unsqueeze(0), flows[t-1])
            warp_loss = warp_loss + F.l1_loss(warped, pred_ab[t].unsqueeze(0))

        # Average over the number of pairs actually summed.
        num_pairs = max(num_frames - 1, 1)
        return self.l2_weight*l2_loss + self.warp_weight*warp_loss/num_pairs

In [None]:
# Two-section settings dict: 'model' holds the network hyper-parameters,
# 'train' holds the optimisation recipe.
config = dict(
    # Model config
    model=dict(
        in_channels=5,
        out_channels=2,
        base_channels=64,
        hist_embed_dim=64,
        debug=False,
    ),
    # Training enhancements
    train=dict(
        use_amp=True,
        ema_decay=0.999,
        sam_rho=0.05,
        l2_weight=1.0,
        warp_weight=0.3,
        lr=3e-4,
    ),
)

In [None]:
def train_step(self, x, hist, first_ab, targets, flows):
    """Run one closure-based optimisation step and update the EMA.

    ``self`` must expose ``optimizer``, ``model``, ``loss_fn``, ``scaler`` and
    ``ema``. The forward/backward pass is packaged as a closure and handed to
    ``self.optimizer.step``; whatever that call returns is returned here.
    """
    def _forward_backward():
        # Fresh gradients, forward under autocast, then backward on the
        # scaled loss. The unscaled loss is what the optimizer gets back.
        self.optimizer.zero_grad()
        with autocast():
            predictions = self.model(x, hist, first_ab)
            batch_loss = self.loss_fn(predictions, targets, flows)
        scaled_loss = self.scaler.scale(batch_loss)
        scaled_loss.backward()
        return batch_loss

    # SAM step (the optimizer decides how often to invoke the closure)
    step_loss = self.optimizer.step(_forward_backward)

    # EMA update
    self.ema.update()

    return step_loss

In [None]:
class ColorizationLoss(nn.Module):
    """MSE colour loss with an optional bidirectional warp-consistency term.

    Args:
        warp_weight: multiplier applied to the warp term in ``forward``.
    """

    def __init__(self, warp_weight=0.5):
        super().__init__()
        self.l2_loss = nn.MSELoss()
        self.warp_weight = warp_weight

    def warp_frame(self, ab, flow):
        """Differentiable warping using grid_sample.

        Args:
            ab: ``[B, C, H, W]`` tensor to sample from.
            flow: ``[B, 2, H, W]`` (or ``[2, H, W]``) displacement vectors in
                pixels; channel 0 is x, channel 1 is y.

        Returns:
            ``[B, C, H, W]`` tensor sampled at ``position + flow``.
        """
        grid = self.flow_to_grid(flow).to(ab.dtype)  # Convert to sampling grid
        return F.grid_sample(ab, grid, mode='bilinear', padding_mode='border', align_corners=True)

    def flow_to_grid(self, flow):
        """Turn a pixel-unit flow field into a normalised sampling grid.

        Returns:
            ``[B, H, W, 2]`` tensor whose last axis is (x, y) in ``[-1, 1]``,
            the layout ``F.grid_sample`` expects with ``align_corners=True``.
        """
        if flow.dim() == 3:
            flow = flow.unsqueeze(0)
        H, W = flow.shape[-2:]

        # Pixel coordinate grid, each of shape (H, W)
        ys, xs = torch.meshgrid(
            torch.arange(H, dtype=torch.float32, device=flow.device),
            torch.arange(W, dtype=torch.float32, device=flow.device),
            indexing='ij',
        )

        # Add displacements per channel, then normalise each axis by its own
        # extent (x by width, y by height).
        x_abs = xs.unsqueeze(0) + flow[:, 0].float()
        y_abs = ys.unsqueeze(0) + flow[:, 1].float()
        x_norm = 2.0 * x_abs / max(W - 1, 1) - 1.0
        y_norm = 2.0 * y_abs / max(H - 1, 1) - 1.0

        return torch.stack((x_norm, y_norm), dim=-1)  # (B,H,W,2)

    def forward(self, pred_ab, target_ab, prev_data=None, next_data=None):
        """Compute the loss for the current frame.

        Args:
            pred_ab: predicted AB for the current frame, ``[B, 2, H, W]``.
            target_ab: ground-truth AB, same shape.
            prev_data: optional ``(prev_pred, prev_flow)`` pair.
            next_data: optional ``(next_pred, next_flow)`` pair.

        Returns:
            ``mse + warp_weight * warp`` where the warp term is the mean of the
            two L1 distances to the warped neighbours. The warp term is used
            only when BOTH neighbours are supplied; otherwise it is 0.
        """
        color_loss = self.l2_loss(pred_ab, target_ab)

        warp_loss = 0
        if prev_data is not None and next_data is not None:
            prev_pred, prev_flow = prev_data
            next_pred, next_flow = next_data

            # Warp using PyTorch (differentiable)
            warped_prev = self.warp_frame(prev_pred, prev_flow)
            warped_next = self.warp_frame(next_pred, next_flow)

            warp_loss = F.l1_loss(pred_ab, warped_prev) + F.l1_loss(pred_ab, warped_next)
            warp_loss = warp_loss / 2  # Average two directions

        return color_loss + self.warp_weight * warp_loss


In [None]:
def visualize_results(input_frames, output_ab, num_frames=3):
    """Plot L, A, B and the recombined RGB image for the first few frames.

    Args:
        input_frames: tensor indexed as ``input_frames[i, 0]`` to get the L
            channel of frame ``i``; assumed to be scaled to [0, 1] since it is
            multiplied by 100 here - TODO confirm against the data pipeline.
        output_ab: tensor indexed as ``output_ab[i, 0]`` / ``output_ab[i, 1]``
            for the A and B channels; assumed to already be in Lab units
            (about -127..127) - TODO confirm against the model's output range.
        num_frames: maximum number of frames (rows) to draw.

    The figure has one row per drawn frame; nothing is drawn for empty input.
    """
    # Draw only as many rows as there are frames, so the figure is not padded
    # with blank rows when fewer than ``num_frames`` frames are available.
    rows = min(num_frames, len(input_frames), len(output_ab))
    if rows <= 0:
        return

    fig, axes = plt.subplots(rows, 4, figsize=(18, 6 * rows), squeeze=False)

    for i in range(rows):
        # ``detach`` so tensors that still require grad can be converted.
        l_channel = input_frames[i, 0].detach().cpu().numpy() * 100  # 0-100
        a_channel = output_ab[i, 0].detach().cpu().numpy()
        b_channel = output_ab[i, 1].detach().cpu().numpy()

        # Create LAB image (H, W, 3) and convert to RGB
        lab_image = np.stack([l_channel, a_channel, b_channel], axis=-1).astype(np.float32)
        rgb_image = cv2.cvtColor(lab_image, cv2.COLOR_LAB2RGB)
        rgb_image = np.clip(rgb_image, 0, 1)  # Ensure valid range

        ax_l, ax_a, ax_b, ax_rgb = axes[i]

        ax_l.imshow(l_channel / 100, cmap='gray')  # Show normalized L
        ax_l.set_title(f"Frame {i} - L Channel")

        for ax, channel, label in ((ax_a, a_channel, "A"), (ax_b, b_channel, "B")):
            image = ax.imshow(channel, cmap='coolwarm', vmin=-127, vmax=127)
            ax.set_title(f"Frame {i} - {label} Channel")
            fig.colorbar(image, ax=ax)

        ax_rgb.imshow(rgb_image)
        ax_rgb.set_title(f"Frame {i} - Colorized")

        for ax in axes[i]:
            ax.axis('off')

    fig.tight_layout()
    plt.show()

In [None]:
import cv2
import numpy as np

def get_ab_channels_tensor(image_path, size=256):
    """
    Read an image from disk and return its Lab a/b channels as ``[1, 2, H, W]``.

    The image is loaded with OpenCV (BGR), resized to ``size`` x ``size`` and
    converted with ``cv2.COLOR_BGR2Lab``; the L channel is discarded.

    NOTE(review): the conversion runs on whatever dtype ``cv2.imread``
    returns. For 8-bit images OpenCV stores a and b shifted by +128, so the
    values here would not be signed -127..127 - confirm what callers expect.

    Args:
        image_path (str): Path to image file.
        size (int): Height and width of the output (default 256).

    Returns:
        np.ndarray: float32 array of shape [1, 2, size, size] holding a and b.

    Raises:
        ValueError: if the file cannot be read as an image.
    """
    bgr = cv2.imread(image_path)
    if bgr is None:
        raise ValueError(f"Could not load image at {image_path}")

    # Resize first, then change colour space.
    lab = cv2.cvtColor(cv2.resize(bgr, (size, size)), cv2.COLOR_BGR2Lab)

    # Channel 0 is L; keep channels 1 (a) and 2 (b) as a [2, H, W] stack and
    # prepend the batch axis.
    ab = np.stack((lab[..., 1], lab[..., 2]), axis=0)[np.newaxis]

    return ab.astype(np.float32)

# Garbage

In [None]:
# import matplotlib.pyplot as plt
# # 1. Load your preprocessed data
# flows = np.load("/kaggle/working/test_output/flow.npy")        # Shape: (n_frames-1, 1, 2, H, W)
# sharpness = np.load("/kaggle/working/test_output/sharpness.npy") # Shape: (n_frames-1,)
# l_frames = np.load("/kaggle/working/test_output/l_frames.npy")   # Shape: (n_frames, H, W)

# # 2. Prepare fusion input
# fusion_input = prepare_fusion_input(flows, sharpness, l_frames)  # (n_frames, 5, H, W)

# # 3. Initialize model
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model = VideoColorizationStreamer().to(device)

# # 4. Run test prediction
# with torch.no_grad():
#     # Select first 8 frames for testing (avoid OOM)
#     test_input = fusion_input[:8].to(device)
#     output = model(test_input)
    
#     print("Input shape:", test_input.shape)
#     print("Output shape:", output.shape)
#     print("Output range: [{:.1f}, {:.1f}]".format(
#         output.min().item(), output.max().item()))

# # 5. Visual verification
# def visualize_results(input_frames, ab_output, num_frames=3):
#     fig, axes = plt.subplots(num_frames, 3, figsize=(15, 5*num_frames))
    
#     for i in range(num_frames):
#         # Input L channel
#         l_channel = input_frames[i,0].cpu().numpy()
#         axes[i,0].imshow(l_channel, cmap='gray', vmin=0, vmax=1)
#         axes[i,0].set_title(f'Frame {i} L channel')
        
#         # Predicted AB channels - show each channel separately
#         ab_vis = ab_output[i].cpu().numpy()
#         axes[i,1].imshow(ab_vis[0], cmap='coolwarm', vmin=-127, vmax=127)  # A channel
#         axes[i,1].set_title('Predicted A channel')
        
#         # Combined LAB->RGB
#         L = l_channel * 100  # Scale to [0,100]
#         AB = ab_vis.transpose(1,2,0)  # Convert to HxWx2
#         lab = np.dstack((L, AB))
#         rgb = cv2.cvtColor(lab.astype(np.float32), cv2.COLOR_LAB2RGB)
#         axes[i,2].imshow(rgb)
#         axes[i,2].set_title('Colorized Result')
    
#     plt.tight_layout()
#     plt.show()

# # Run visualization
# visualize_results(test_input, output)

In [None]:
import torch
import numpy as np

# End-to-end smoke run: load precomputed flow/sharpness/L arrays, build the
# fused network input, run the colorization network and plot the result.
# Depends on `prepare_fusion_input`, `VideoColorizationNet` and
# `visualize_results` having been defined by earlier cells.

# 1. Load precomputed data
flows_np = np.load("/kaggle/working/test_output/flow.npy")        # (n_frames - 1, 1, 2, H, W)
sharpness_np = np.load("/kaggle/working/test_output/sharpness.npy")  # (n_frames - 1,)
l_frames_np = np.load("/kaggle/working/test_output/l_frames.npy")    # (n_frames, H, W)

# 2. Convert to tensors (no padding yet)
# squeeze(1) drops the singleton axis noted in the flow.npy shape comment above.
flows_tensor = torch.from_numpy(flows_np).float().squeeze(1)      # (n_frames - 1, 2, H, W)
sharpness_tensor = torch.from_numpy(sharpness_np).float()         # (n_frames - 1,)
l_frames_tensor = torch.from_numpy(l_frames_np).float()           # (n_frames, H, W)

# 3. Define number of frames to process
num_frames = 9  # Total frames including padding, e.g., to predict frames 2-9

# Only the L-frame count is checked; flow/sharpness are assumed to have one
# entry fewer than the frames - TODO confirm for the saved arrays.
assert num_frames <= l_frames_tensor.shape[0], "Requested more frames than available."

# 4. Slice consistent inputs
flows_slice = flows_tensor[:num_frames - 1]       # (num_frames - 1, 2, H, W)
sharpness_slice = sharpness_tensor[:num_frames - 1]  # (num_frames - 1,)
l_slice = l_frames_tensor[:num_frames]            # (num_frames, H, W)

# 5. Prepare fusion input (padding happens inside)
# The helper is given NumPy arrays; `.to(device)` below implies it returns a tensor.
fusion_input = prepare_fusion_input(flows_slice.numpy(), sharpness_slice.numpy(), l_slice.numpy())

print(f"Fusion input shape: {fusion_input.shape}")  # Expect (num_frames, 5, H, W)

# 6. Build the colorization model
# NOTE(review): no checkpoint is loaded here - the network is constructed and
# used as-is, so unless its constructor loads weights the output below comes
# from an untrained model.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VideoColorizationNet().to(device)
model.eval()

# 7. Run model on frames
with torch.no_grad():
    test_input = fusion_input.to(device)  # Shape: (num_frames, 5, H, W)
    output = model(test_input)            # Output: (num_frames, 2, H, W)

    print(f"Output shape: {output.shape}")
    print(f"Output range: [{output.min().item():.1f}, {output.max().item():.1f}]")

# 8. Visualize results
# visualize_results reads channel 0 of each input frame as the L channel.
visualize_results(test_input.cpu(), output.cpu())


In [None]:
class UNetBlock(nn.Module):
    """Two 3x3 convolutions with a ReLU in between (spatial size preserved).

    The block no longer reads a notebook-global ``device``: like any
    ``nn.Module`` it is created on the default device and is moved by calling
    ``.to(...)`` on it or on a parent module.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        """Map ``[B, in_channels, H, W]`` to ``[B, out_channels, H, W]``."""
        # Keep the old convenience of accepting inputs from another device,
        # but follow the block's own parameters rather than a global.
        x = x.to(self.conv1.weight.device)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        return x

In [None]:
import os
from skimage.color import rgb2lab
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as transforms

class VideoFramesDataset(Dataset):
    """One item per sub-folder of ``root_dir``; each sub-folder is one video.

    Args:
        root_dir: directory whose immediate sub-directories each hold the
            frames (``.png`` / ``.jpg``) of one video.
        transform: optional callable applied to each PIL RGB frame. It may
            return a ``(C, H, W)`` tensor (e.g. via ``ToTensor``) or an image /
            array in ``(H, W, C)`` layout; ``None`` uses the raw image.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # Sorted so that index -> video is stable across runs and machines
        # (``os.listdir`` order is arbitrary).
        self.video_folders = sorted(
            path
            for path in (os.path.join(root_dir, name) for name in os.listdir(root_dir))
            if os.path.isdir(path)
        )

    def __len__(self):
        return len(self.video_folders)

    def __getitem__(self, idx):
        """Load every frame of video ``idx`` as normalised Lab.

        Returns:
            ``(y_frames, ab_frames)`` with shapes ``(num_frames, 1, H, W)`` and
            ``(num_frames, 2, H, W)``; L is divided by 100 and a/b by 127.

        Raises:
            RuntimeError: if the folder contains no ``.png``/``.jpg`` frames.
        """
        folder_path = self.video_folders[idx]
        frame_paths = sorted(
            os.path.join(folder_path, f)
            for f in os.listdir(folder_path)
            if f.endswith(('.png', '.jpg'))
        )
        if not frame_paths:
            raise RuntimeError(f"No .png/.jpg frames found in {folder_path}")

        frames_lab = []
        for frame_path in frame_paths:
            # Context manager closes the file handle once the RGB copy exists.
            with Image.open(frame_path) as img:
                frame = img.convert('RGB')
            if self.transform is not None:
                frame = self.transform(frame)

            if torch.is_tensor(frame):
                frame_np = frame.permute(1, 2, 0).numpy()  # (C, H, W) -> (H, W, C)
            else:
                # No transform (or one that returns an image): PIL images are
                # already (H, W, C) once viewed as an array.
                frame_np = np.asarray(frame)
            frame_lab = rgb2lab(frame_np).astype(np.float32)

            # Normalize LAB values
            frame_lab[:, :, 0] = frame_lab[:, :, 0] / 100.0  # L channel to [0, 1]
            frame_lab[:, :, 1:] = frame_lab[:, :, 1:] / 127.0  # ab channels to about [-1, 1]

            frame_lab = torch.from_numpy(frame_lab).permute(2, 0, 1)  # back to (C, H, W)
            frames_lab.append(frame_lab)

        frames_lab = torch.stack(frames_lab)  # Shape: (num_frames, C, H, W)

        # Split luminance (channel 0) from chrominance (channels 1-2)
        y_frames = frames_lab[:, 0:1, :, :]
        ab_frames = frames_lab[:, 1:, :, :]

        return y_frames, ab_frames  # Shape: (num_frames, 1, H, W), (num_frames, 2, H, W)


# Define transforms
# Every frame is resized to 224x224 and converted with ToTensor, so the dataset
# receives a channel-first tensor that `__getitem__` permutes back to (H, W, C).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Create dataset and dataloader
# Each dataset item is a whole video (all frames of one folder), so
# batch_size=1 yields one video per step.
# NOTE(review): the root path is specific to this Kaggle dataset mount.
dataset = VideoFramesDataset(root_dir='/kaggle/input/vcdataset/DS', transform=transform)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)

In [None]:
class JFHM(nn.Module):
    """Fuses histogram features of a feature map with flow features via attention.

    Relies on two names defined outside this cell: ``AdaptiveMultiScaleBinning``
    and the notebook-global ``device``. ``input_channels`` is accepted but not
    used anywhere in the class.

    NOTE(review): several steps below look inconsistent (see inline notes);
    the shape comments describe the apparent intent and are unverified.
    """
    def __init__(self, input_channels, hidden_dim, num_frames=3):
        super().__init__()
        self.histogram_extractor = AdaptiveMultiScaleBinning().to(device)
        # Both attention layers are built without `batch_first`, so per the
        # PyTorch default they interpret inputs as (seq, batch, embed).
        self.spatial_attention = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=8).to(device)
        self.temporal_attention = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=8).to(device)
        self.num_frames = num_frames
        self.hidden_dim = hidden_dim

    def forward(self, x, flow_features):
        # x is unpacked as a 4-D tensor; height/width are not used afterwards.
        batch_size, channels, height, width = x.shape

        # Compute histograms
        hist_features = self.histogram_extractor(x)  # intended: (batch_size, num_bins, channels) - TODO confirm extractor output
        hist_features = hist_features.view(batch_size, -1, channels).to(device)  # reshaped to (batch_size, -1, channels)

        # Ensure hist_features has the correct embedding dimension
        if hist_features.size(-1) != self.hidden_dim:
            hist_features = hist_features.permute(0, 2, 1)  # Swap the last two dimensions
            # NOTE(review): the projection weight is a fresh all-zeros tensor,
            # so the result of this call is all zeros and nothing is learned.
            hist_features = F.linear(hist_features, torch.zeros(self.hidden_dim, hist_features.size(-1)).to(device))  # Project to hidden_dim

        # Self-attention over the histogram features (query = key = value)
        spatial_features, _ = self.spatial_attention(hist_features, hist_features, hist_features)  # same shape as hist_features

        # Process flow features
        # Concatenates the slices flow_features[:, i] along dim=1.
        # NOTE(review): if flow_features is (B, num_frames-1, 2, H, W) the
        # result is 4-D, which MultiheadAttention does not accept - confirm.
        temporal_input = torch.cat([flow_features[:, i] for i in range(self.num_frames-1)], dim=1).to(device)
        temporal_features, _ = self.temporal_attention(temporal_input, temporal_input, temporal_input)

        # Combine spatial and temporal features (element-wise sum; the two
        # tensors must have broadcast-compatible shapes)
        combined_features = (spatial_features + temporal_features).to(device)

        return combined_features

In [None]:
class FlowWeightingNetwork(nn.Module):
    """Blend forward and backward flow with a learned per-pixel gate.

    A two-layer MLP looks at the concatenated flow vectors of each pixel and
    outputs a weight ``w`` in (0, 1); the result is ``w * forward + (1 - w) *
    backward``.

    Args:
        input_channels: channels of ONE flow field (the MLP sees twice that).
        hidden_dim: width of the hidden layer.
    """

    def __init__(self, input_channels, hidden_dim):
        super().__init__()
        stacked_channels = input_channels * 2  # forward + backward side by side
        self.fc1 = nn.Linear(stacked_channels, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, forward_flow, backward_flow):
        """Return the gated mix of two ``(B, num_frames-1, C, H, W)`` flows."""
        # Channels-last view so nn.Linear acts on each pixel's flow vector:
        # (B, N, 2C, H, W) -> (B, N, H, W, 2C)
        per_pixel = torch.cat((forward_flow, backward_flow), dim=2).permute(0, 1, 3, 4, 2)

        hidden = self.relu(self.fc1(per_pixel))                       # (B, N, H, W, hidden_dim)
        gate = self.sigmoid(self.fc2(hidden)).permute(0, 1, 4, 2, 3)  # (B, N, 1, H, W)

        # The single-channel gate broadcasts over the flow channels.
        return forward_flow * gate + backward_flow * (1 - gate)

In [None]:
class VideoColorizationUNet(nn.Module):
    """U-Net style colorizer that predicts AB for one reference frame.

    Built from ``UNetBlock`` encoders/decoders, ``JFHM`` fusion modules, an
    ``OpticalFlowEstimator`` (defined outside this cell) and a
    ``FlowWeightingNetwork``. ``forward`` reads the notebook-global ``device``.

    NOTE(review): the forward pass mixes two indexing conventions for
    ``frames`` and concatenates JFHM outputs with convolutional feature maps;
    see the inline notes. Shape comments state the apparent intent only.
    """
    def __init__(self, num_frames=3):
        super().__init__()
        self.num_frames = num_frames
        self.encoder1 = UNetBlock(1, 64)  # Input is Y channel (1 channel)
        self.encoder2 = UNetBlock(64, 128)
        self.encoder3 = UNetBlock(128, 256)
        self.encoder4 = UNetBlock(256, 512)
        self.bottleneck = JFHM(512, 256, num_frames)
        # Decoder input widths assume a channel-wise concat with the skip tensor
        self.decoder4 = UNetBlock(1024, 256)  # Adjusted for skip connection
        self.decoder3 = UNetBlock(512, 128)
        self.decoder2 = UNetBlock(256, 64)
        self.decoder1 = UNetBlock(128, 2)  # Output 2 channels for ab
        self.jfhm1 = JFHM(256, 128, num_frames)
        self.jfhm2 = JFHM(128, 64, num_frames)
        self.flow_estimator = OpticalFlowEstimator()
        self.flow_weighting = FlowWeightingNetwork(input_channels=2, hidden_dim=64)  # Weighting network

    def forward(self, frames):
        frames = frames.to(device)  # intended shape: (B, num_frames, 1, H, W)
        print(f"Input frames shape: {frames.shape}")

        # Handle cases where num_frames = 1
        # NOTE(review): this debug print runs an extra flow estimation on every call.
        print(self.flow_estimator(frames[0], frames[0]).shape)
        # NOTE(review): frames[0] / frames[i] below index the FIRST axis,
        # whereas the encoder further down selects a frame with frames[:, idx]
        # (second axis). One of the two conventions is presumably wrong - confirm.
        if self.num_frames == 1:
            # Use the same frame for both inputs (no optical flow)
            forward_flow = [self.flow_estimator(frames[0], frames[0])]
            backward_flow = [self.flow_estimator(frames[0], frames[0])]
        else:
            # Forward flow: frame[t] -> frame[t+1]
            forward_flow = [self.flow_estimator(frames[i], frames[i+1]) for i in range(self.num_frames-1)]
            # Backward flow: frame[t+1] -> frame[t]
            backward_flow = [self.flow_estimator(frames[i+1], frames[i]) for i in range(self.num_frames-1)]

        # Stack forward and backward flows along a new dim=1
        forward_flow = torch.stack(forward_flow, dim=1)  # intended: (B, num_frames-1, 2, H, W)
        backward_flow = torch.stack(backward_flow, dim=1)  # intended: (B, num_frames-1, 2, H, W)
        print(f"Forward flow shape: {forward_flow.shape}")
        print(f"Backward flow shape: {backward_flow.shape}")

        # Use the weighting network to combine forward and backward flows
        flow_features = self.flow_weighting(forward_flow, backward_flow)  # same shape as the flow stacks

        # Encoder
        e1 = self.encoder1(frames[:, self.num_frames // 2])  # Use middle frame as reference
        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)
        e4 = self.encoder4(e3)

        # Bottleneck with flow features
        b = self.bottleneck(e4, flow_features)

        # Decoder with skip connections
        # NOTE(review): each JFHM output is concatenated with a convolutional
        # feature map on dim=1; this needs matching ranks and spatial sizes,
        # which JFHM's attention output does not obviously provide - confirm.
        d4 = self.decoder4(torch.cat([b, e4], dim=1))
        d3 = self.decoder3(torch.cat([self.jfhm1(d4, flow_features), e3], dim=1))
        d2 = self.decoder2(torch.cat([self.jfhm2(d3, flow_features), e2], dim=1))
        d1 = self.decoder1(torch.cat([d2, e1], dim=1))  # intended: (B, 2, H, W)

        # Repeat the output for all frames: every frame receives the SAME
        # prediction, computed from the middle frame only.
        predicted_ab = d1.unsqueeze(1).repeat(1, self.num_frames, 1, 1, 1)  # (B, num_frames, 2, H, W)

        return predicted_ab

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchmetrics import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from torch.cuda.amp import GradScaler, autocast
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
import logging

# Initialize logging
# INFO-level records go to 'training.log' in the current working directory,
# each prefixed with a timestamp.
logging.basicConfig(filename='training.log', level=logging.INFO, format='%(asctime)s - %(message)s')

# Ensure device selection
# This module-level `device` is read by `Trainer` below and also by classes in
# earlier cells (e.g. `JFHM`, `VideoColorizationUNet.forward`), so those only
# work once a cell defining `device` has been executed.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define Trainer Class
class Trainer:
    """Bundles a model with Adam, an MSE criterion, AMP scaling and PSNR/SSIM.

    Placement uses the notebook-level ``device``. The forward pass is done by
    the caller: ``train_step`` and ``evaluate`` receive predictions that were
    already computed.
    """

    def __init__(self, model, lr=1e-4, checkpoint_dir='checkpoints'):
        self.model = model.to(device)
        self.optimizer = Adam(self.model.parameters(), lr=lr)
        self.criterion = nn.MSELoss().to(device)
        # torchmetrics modules used by ``evaluate``
        self.psnr = PeakSignalNoiseRatio().to(device)
        self.ssim = StructuralSimilarityIndexMeasure().to(device)
        self.checkpoint_dir = checkpoint_dir
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        self.scaler = GradScaler()  # loss scaling for mixed precision

    def save_checkpoint(self, epoch, loss, checkpoint_name='checkpoint.pth'):
        """Serialize model, optimizer and scaler state to ``checkpoint_dir``."""
        target_path = os.path.join(self.checkpoint_dir, checkpoint_name)
        snapshot = dict(
            epoch=epoch,
            model_state_dict=self.model.state_dict(),
            optimizer_state_dict=self.optimizer.state_dict(),
            scaler_state_dict=self.scaler.state_dict(),
            loss=loss,
        )
        torch.save(snapshot, target_path)
        print(f"Checkpoint saved at {target_path}")

    def load_checkpoint(self, checkpoint_name='checkpoint.pth'):
        """Restore state if the file exists.

        Returns:
            ``(epoch, loss)`` from the checkpoint, or ``(0, inf)`` when no
            checkpoint file is present.
        """
        source_path = os.path.join(self.checkpoint_dir, checkpoint_name)
        if not os.path.exists(source_path):
            print(f"No checkpoint found at {source_path}")
            return 0, float('inf')  # nothing saved yet: start from scratch

        saved = torch.load(source_path, map_location=device)
        self.model.load_state_dict(saved['model_state_dict'])
        self.optimizer.load_state_dict(saved['optimizer_state_dict'])
        self.scaler.load_state_dict(saved['scaler_state_dict'])
        resumed_epoch = saved['epoch']
        resumed_loss = saved['loss']
        print(f"Checkpoint loaded from {source_path}")
        return resumed_epoch, resumed_loss

    def train_step(self, y, ab_gt, predicted_ab):
        """Back-propagate the MSE between ``predicted_ab`` and ``ab_gt``.

        ``y`` is accepted but not used. Returns the loss as a Python float.
        """
        self.model.train()
        self.optimizer.zero_grad()

        # Only the criterion runs under autocast; the prediction was computed
        # by the caller.
        with autocast():
            step_loss = self.criterion(predicted_ab, ab_gt)

        # Scaled backward pass, optimizer step, then scale update
        self.scaler.scale(step_loss).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()

        return step_loss.item()

    def evaluate(self, y, ab_gt, predicted_ab):
        """Return ``(psnr, ssim)`` floats for the given prediction (``y`` unused)."""
        self.model.eval()
        with torch.no_grad():
            psnr_value = self.psnr(predicted_ab, ab_gt)
            ssim_value = self.ssim(predicted_ab, ab_gt)
        return psnr_value.item(), ssim_value.item()

# # Define Model (Replace with actual implementation)
# class VideoColorizationUNet(nn.Module):
#     def __init__(self):
#         super(VideoColorizationUNet, self).__init__()
#         self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1)
#         self.conv2 = nn.Conv2d(64, 2, kernel_size=3, padding=1)

#     def forward(self, x):
#         x = torch.relu(self.conv1(x))
#         x = self.conv2(x)
#         return x

# Initialize model and move to device
# Uses the class's default constructor arguments; `VideoColorizationUNet`
# must already be defined by an earlier cell.
model = VideoColorizationUNet().to(device)

# Apply weight initialization
def init_weights(m):
    """Kaiming-initialise ``nn.Conv2d`` weights and zero their biases.

    Intended for ``module.apply(init_weights)``; any module that is not an
    ``nn.Conv2d`` instance is left untouched.
    """
    if not isinstance(m, nn.Conv2d):
        return

    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    bias = m.bias
    if bias is not None:
        nn.init.constant_(bias, 0)

model.apply(init_weights)

# Initialize optimizer and scheduler
# NOTE(review): `Trainer` below is constructed without `optimizer`/`scheduler`;
# confirm whether it builds its own optimizer and where `scheduler.step()` is called.
optimizer = Adam(model.parameters(), lr=1e-4)
scheduler = StepLR(optimizer, step_size=5, gamma=0.1)  # Reduce LR by 0.1 every 5 epochs

# Initialize trainer
trainer = Trainer(model, checkpoint_dir='checkpoints')


def _latest_epoch_checkpoint(checkpoint_dir, prefix='checkpoint_epoch_', suffix='.pth'):
    """Return the file name of the highest-numbered per-epoch checkpoint, or None.

    Only files named ``<prefix><digits><suffix>`` directly inside
    ``checkpoint_dir`` are considered.
    """
    if not os.path.isdir(checkpoint_dir):
        return None
    numbered = []
    for name in os.listdir(checkpoint_dir):
        if name.startswith(prefix) and name.endswith(suffix):
            stem = name[len(prefix):-len(suffix)]
            if stem.isdigit():
                numbered.append((int(stem), name))
    return max(numbered)[1] if numbered else None


# Load checkpoint if it exists.
# The training loop saves `checkpoint_epoch_<n>.pth`, so resuming from the default
# `checkpoint.pth` would never find those files. Prefer the newest per-epoch file
# and fall back to the default name when none is present.
_resume_name = _latest_epoch_checkpoint(trainer.checkpoint_dir)
if _resume_name is None:
    start_epoch, best_loss = trainer.load_checkpoint()
else:
    start_epoch, best_loss = trainer.load_checkpoint(checkpoint_name=_resume_name)


In [None]:
# Training loop
# Each clip in a batch is forwarded and optimised on its own, which keeps the
# indexing consistent for any batch size:
#   * the previous `frames_y.squeeze(0)` only removed the batch axis when B == 1,
#     and `predicted_ab[i]` then indexed the frame axis as if it were the batch axis;
#   * one forward pass per clip gives every `train_step` call its own autograd graph.
# For B == 1 the tensor handed to the model is the same one `squeeze(0)` produced.
for epoch in range(start_epoch, 10):
    model.train()

    # Last per-clip loss of the epoch. Stays at +inf when the dataloader yields
    # nothing, so the checkpoint call below never sees an undefined name.
    loss = float('inf')

    for frames_y, frames_ab in dataloader:
        # Move tensors to device
        frames_y = frames_y.to(device, non_blocking=True)  # expected (B, num_frames, 1, H, W)
        frames_ab = frames_ab.to(device, non_blocking=True)  # expected (B, num_frames, 2, H, W)

        for i in range(frames_y.shape[0]):
            # Extract frames for the current sample
            y_frames = frames_y[i]  # (num_frames, 1, H, W)
            ab_frames = frames_ab[i]  # (num_frames, 2, H, W)

            # Predict AB channels for every frame of this clip.
            # Assumes the model maps (N, 1, H, W) -> (N, 2, H, W) — TODO confirm.
            pred_ab_frames = model(y_frames)

            # Optimise on all frames of the clip (previously only frame 0 was compared).
            loss = trainer.train_step(y_frames, ab_frames, pred_ab_frames)

            # Compute PSNR and SSIM for each frame in the sample. Metrics need no
            # gradients, so detach the prediction and disable autograd.
            psnr_values = []
            ssim_values = []
            with torch.no_grad():
                pred_eval = pred_ab_frames.detach()
                for j in range(y_frames.shape[0]):
                    # Slicing keeps a leading batch axis of size 1; image metrics
                    # commonly expect (N, C, H, W) — TODO confirm for trainer.psnr/ssim.
                    psnr = trainer.psnr(pred_eval[j:j + 1], ab_frames[j:j + 1])
                    ssim = trainer.ssim(pred_eval[j:j + 1], ab_frames[j:j + 1])
                    psnr_values.append(psnr.item())
                    ssim_values.append(ssim.item())

            # Log metrics for the sample
            logging.info(f"Epoch {epoch+1}, Sample {i+1}, Loss: {loss:.4f}, Avg PSNR: {np.mean(psnr_values):.4f}, Avg SSIM: {np.mean(ssim_values):.4f}")

    # Save checkpoint at the end of each epoch
    trainer.save_checkpoint(epoch + 1, loss, checkpoint_name=f'checkpoint_epoch_{epoch+1}.pth')

In [None]:
import torch
import urllib.request
import os

# Download pretrained SIGGRAPH17 weights (from Colorization repo).
# The URL previously used here (colorization_release_v2-9b330a0b.pth) is the
# ECCV16 checkpoint, which does not match a SIGGRAPH17 architecture.
# NOTE: a `siggraph17.pth` left over from an earlier run was fetched from the old
# URL and is not replaced by the guard below; delete it once to refresh.
model_url = "https://colorizers.s3.us-east-2.amazonaws.com/siggraph17-df00044c.pth"
model_path = "siggraph17.pth"
if not os.path.exists(model_path):
    urllib.request.urlretrieve(model_url, model_path)

In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms

# Build the hub architecture without weights, restore them from `model_path`
# on the CPU, then switch to inference mode
model = torch.hub.load(
    'harvard-visionlab/pytorch-colorization',
    'siggraph17',
    pretrained=False,
)
model.load_state_dict(
    torch.load(model_path, map_location=torch.device('cpu'))
)
model.eval()

def colorize_siggraph17(img_path):
    """Colorize the image at ``img_path`` using the module-level ``model``.

    The image is converted to Lab, its L channel is resized to 256x256 and fed
    to the network, and the predicted a/b channels are resized back and
    recombined with the full-resolution L channel.

    Args:
        img_path: Path to an image readable by ``cv2.imread``.

    Returns:
        Float32 RGB image with the same height and width as the input
        (output of OpenCV's float Lab -> RGB conversion).

    Raises:
        FileNotFoundError: If OpenCV cannot read ``img_path``.
    """
    # Load image (convert to LAB space)
    img_bgr = cv2.imread(img_path)
    if img_bgr is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {img_path}")

    # Dividing a uint8 array by 255.0 yields float64, which OpenCV's Lab
    # conversion does not accept; cast to float32 first.
    img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    img_lab = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)
    img_l = img_lab[:, :, 0]  # Extract L channel (grayscale)

    # Preprocess: fixed network resolution, shape (1, 1, 256, 256)
    l_small = cv2.resize(img_l, (256, 256))
    l_tensor = torch.from_numpy(l_small).unsqueeze(0).unsqueeze(0).float()

    # Predict ab channels
    with torch.no_grad():
        pred_ab = model(l_tensor).cpu().numpy()

    # Resize and combine with L (cv2 sizes are (width, height))
    pred_ab = np.ascontiguousarray(pred_ab[0].transpose(1, 2, 0))
    pred_ab = cv2.resize(pred_ab, (img.shape[1], img.shape[0]))
    pred_lab = np.concatenate([img_l[:, :, np.newaxis], pred_ab], axis=2)
    pred_rgb = cv2.cvtColor(pred_lab, cv2.COLOR_LAB2RGB)

    return pred_rgb

In [None]:
import torch
import cv2
import numpy as np
from torchvision import transforms

class VideoColorizer:
    """Colorize single L-channel frames with a hub-loaded ``siggraph17`` network (CPU)."""

    def __init__(self, model_path):
        """Load the network and build the L-channel preprocessing pipeline.

        Args:
            model_path: Path to a state dict for the hub ``siggraph17`` model.
        """
        # Load model
        self.model = torch.hub.load('harvard-visionlab/pytorch-colorization',
                                    'siggraph17', pretrained=False)
        self.model.load_state_dict(torch.load(model_path,
                                              map_location=torch.device('cpu')))
        self.model.eval()

        # Preprocessing transforms
        # NOTE(review): this maps L to [-1, 1]; confirm that is the input range
        # the loaded network expects.
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(256),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5])  # For L channel
        ])

    def colorize_from_l(self, l_frame, original_size=None):
        """
        Colorize from preprocessed L channel
        Args:
            l_frame: numpy array (H,W) in [0,255] range
            original_size: (width, height) to resize output to
        Returns:
            colorized RGB image (uint8), sized ``original_size`` when given,
            otherwise the size of ``l_frame``
        """
        # ToPILImage/ToTensor only rescale 8-bit input to [0, 1]; make the dtype
        # explicit so float arrays in [0, 255] are preprocessed the same way.
        l_u8 = np.clip(l_frame, 0, 255).astype(np.uint8)

        # Convert to tensor and preprocess
        l_tensor = self.transform(l_u8).unsqueeze(0)  # (1,1,H',W')

        # Predict ab channels
        with torch.no_grad():
            pred_ab = self.model(l_tensor).cpu().numpy()[0]  # (2,H',W')

        # Reshape to channel-last
        pred_ab = np.ascontiguousarray(pred_ab.transpose(1, 2, 0), dtype=np.float32)

        if original_size is not None:
            l_u8 = cv2.resize(l_u8, original_size)

        # Resize(256) changes the spatial size, so bring the prediction back to
        # the L frame's size before stacking (cv2 sizes are (width, height)).
        if pred_ab.shape[:2] != l_u8.shape[:2]:
            pred_ab = cv2.resize(pred_ab, (l_u8.shape[1], l_u8.shape[0]))

        # 8-bit OpenCV Lab stores a/b with a +128 offset; casting signed values
        # straight to uint8 would wrap around.
        # Assumes the network returns a/b in Lab units — TODO confirm for this model.
        ab_u8 = np.clip(np.rint(pred_ab + 128.0), 0, 255).astype(np.uint8)

        # Combine with L channel and convert to RGB
        pred_lab = np.dstack((l_u8, ab_u8))
        pred_rgb = cv2.cvtColor(pred_lab, cv2.COLOR_LAB2RGB)
        return pred_rgb

# Usage example:
if __name__ == "__main__":
    # Build the colorizer from a weights file (placeholder path)
    colorizer = VideoColorizer("path/to/siggraph17.pth")

    # One preprocessed L frame written by the earlier pipeline
    l_frame = np.load("/kaggle/working/test_output/L/frame_0000.npy")  # Shape (H,W)

    # Colorize at 512x512
    colorized = colorizer.colorize_from_l(l_frame, original_size=(512, 512))

    # Side-by-side display: input luminance next to the colorized result
    plt.figure(figsize=(10, 5))
    for _slot, (_image, _title, _cmap) in enumerate(
        [(l_frame, "Input L Channel", 'gray'), (colorized, "Colorized Output", None)],
        start=1,
    ):
        plt.subplot(1, 2, _slot)
        plt.imshow(_image, cmap=_cmap)
        plt.title(_title)
    plt.show()

In [None]:
import torch
import torch.nn as nn
import cv2
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
import os
from tqdm import tqdm

class ColorizationModel(nn.Module):
    """Simplified stand-in for the SIGGRAPH17 colorization model.

    A plain encoder/decoder CNN mapping 1 input channel to 2 output channels.
    Three stride-2 convolutions downsample, three stride-2 transposed
    convolutions upsample, and the final activation is ``tanh``. All layers
    live in one ``nn.Sequential`` stored as ``self.model``.
    """

    # (kind, in_channels, out_channels, stride)
    # "conv" -> 3x3 Conv2d with padding 1, "up" -> 4x4 ConvTranspose2d with padding 1
    _LAYER_SPECS = (
        ("conv", 1, 64, 1),
        ("conv", 64, 64, 2),
        ("conv", 64, 128, 1),
        ("conv", 128, 128, 2),
        ("conv", 128, 256, 1),
        ("conv", 256, 256, 1),
        ("conv", 256, 256, 2),
        ("up", 256, 128, 2),
        ("conv", 128, 128, 1),
        ("up", 128, 64, 2),
        ("conv", 64, 64, 1),
        ("up", 64, 2, 2),
    )

    def __init__(self):
        super().__init__()
        stack = []
        final_index = len(self._LAYER_SPECS) - 1
        for index, (kind, c_in, c_out, stride) in enumerate(self._LAYER_SPECS):
            if kind == "conv":
                stack.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1))
            else:
                stack.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=stride, padding=1))
            # Every layer is followed by ReLU except the last, which uses Tanh.
            stack.append(nn.Tanh() if index == final_index else nn.ReLU())
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the sequential stack to ``x``."""
        return self.model(x)

class VideoColorizer:
    """Colorize L-channel frames with ``ColorizationModel`` on GPU when available."""

    def __init__(self, model_path=None):
        """Build the model and the L-channel preprocessing pipeline.

        Args:
            model_path: Optional path to a ``ColorizationModel`` state dict.
                Weights are loaded only when the path is given and exists;
                otherwise the model keeps its random initialisation.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = ColorizationModel().to(self.device)

        if model_path and os.path.exists(model_path):
            self.model.load_state_dict(torch.load(model_path, map_location=self.device))

        self.model.eval()

        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(256),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5])
        ])

    def colorize_from_l(self, l_frame, original_size=None):
        """
        Colorize from preprocessed L channel
        Args:
            l_frame: numpy array (H,W) in [0,255] range
            original_size: (width, height) to resize output to
        Returns:
            colorized RGB image (uint8), sized ``original_size`` when given,
            otherwise the size of ``l_frame``
        """
        # ToPILImage/ToTensor only rescale 8-bit input to [0, 1]; make the dtype
        # explicit so float arrays in [0, 255] are preprocessed the same way.
        l_u8 = np.clip(l_frame, 0, 255).astype(np.uint8)

        # Convert to tensor and preprocess
        l_tensor = self.transform(l_u8).unsqueeze(0).to(self.device)  # (1,1,H',W')

        # Predict ab channels
        with torch.no_grad():
            pred_ab = self.model(l_tensor).cpu().numpy()[0]  # (2,H',W')

        # Post-process: channel-last, tanh output scaled to [-127.5, 127.5]
        pred_ab = np.ascontiguousarray(pred_ab.transpose(1, 2, 0))
        pred_ab = (pred_ab * 127.5).astype(np.float32)

        if original_size is not None:
            l_u8 = cv2.resize(l_u8, original_size)

        # Resize(256) changes the spatial size, so bring the prediction back to
        # the L frame's size before stacking (cv2 sizes are (width, height)).
        # Without this, calls with original_size=None fail on mismatched shapes.
        if pred_ab.shape[:2] != l_u8.shape[:2]:
            pred_ab = cv2.resize(pred_ab, (l_u8.shape[1], l_u8.shape[0]))

        # 8-bit OpenCV Lab stores a/b with a +128 offset; casting signed values
        # straight to uint8 would wrap around.
        ab_u8 = np.clip(np.rint(pred_ab + 128.0), 0, 255).astype(np.uint8)

        # Combine with L channel and convert to RGB
        pred_lab = np.dstack((l_u8, ab_u8))
        pred_rgb = cv2.cvtColor(pred_lab, cv2.COLOR_LAB2RGB)
        return pred_rgb

    def colorize_video_frames(self, l_folder, output_folder):
        """Colorize every ``.npy`` L frame in ``l_folder`` and write PNGs to ``output_folder``.

        Files are processed in sorted name order; each output keeps the input's
        base name with a ``.png`` extension. ``output_folder`` is created if missing.
        """
        os.makedirs(output_folder, exist_ok=True)
        l_files = sorted(f for f in os.listdir(l_folder) if f.endswith('.npy'))

        for l_file in tqdm(l_files, desc="Colorizing frames"):
            l_path = os.path.join(l_folder, l_file)
            l_frame = np.load(l_path)

            colorized = self.colorize_from_l(l_frame)
            output_path = os.path.join(output_folder, l_file.replace('.npy', '.png'))
            # cv2.imwrite expects BGR channel order.
            cv2.imwrite(output_path, cv2.cvtColor(colorized, cv2.COLOR_RGB2BGR))

# Usage example
if __name__ == "__main__":
    # Colorizer with randomly initialised weights (no checkpoint given)
    colorizer = VideoColorizer()

    # Colorize one L frame from disk
    l_frame = np.load("/kaggle/working/test_output/L/00000.npy")  # Replace with your actual L frame
    colorized = colorizer.colorize_from_l(l_frame)

    # Side-by-side display: input luminance next to the colorized result
    plt.figure(figsize=(10, 5))
    for _slot, (_image, _title, _cmap) in enumerate(
        [(l_frame, "Input L Channel", 'gray'), (colorized, "Colorized Output", None)],
        start=1,
    ):
        plt.subplot(1, 2, _slot)
        plt.imshow(_image, cmap=_cmap)
        plt.title(_title)
    plt.show()
    
    # To process a whole folder:
    # colorizer.colorize_video_frames(
    #     l_folder="/path/to/L_frames",
    #     output_folder="/path/to/colorized_output"
    # )

In [None]:
!pip install "numpy<2.0.0" --force-reinstall

In [None]:

import torch
from torch import nn

class BaseColor(nn.Module):
    """Base module holding the Lab scaling constants and (un)normalisation helpers.

    L is centred on ``l_cent`` and divided by ``l_norm``; a/b are divided by
    ``ab_norm``. The ``unnormalize_*`` methods invert those mappings.
    """

    def __init__(self):
        super().__init__()

        self.l_cent = 50.
        self.l_norm = 100.
        self.ab_norm = 110.

    def normalize_l(self, in_l):
        """Return ``(in_l - l_cent) / l_norm``."""
        centred = in_l - self.l_cent
        return centred / self.l_norm

    def unnormalize_l(self, in_l):
        """Return ``in_l * l_norm + l_cent``."""
        rescaled = in_l * self.l_norm
        return rescaled + self.l_cent

    def normalize_ab(self, in_ab):
        """Return ``in_ab / ab_norm``."""
        scale = self.ab_norm
        return in_ab / scale

    def unnormalize_ab(self, in_ab):
        """Return ``in_ab * ab_norm``."""
        scale = self.ab_norm
        return in_ab * scale

In [None]:
import torch
import torch.nn as nn


class SIGGRAPHGenerator(BaseColor):
    """Encoder/decoder colorization network with additive skip connections.

    The first convolution takes 4 channels: the normalised L channel, two
    normalised a/b hint channels and a hint mask. Downsampling between encoder
    stages is done in ``forward`` by strided slicing (``[:, :, ::2, ::2]``);
    upsampling uses stride-2 transposed convolutions whose outputs are summed
    with 3x3 "shortcut" convolutions of the matching encoder features.

    Module attribute names and the layer order inside every ``nn.Sequential``
    determine the ``state_dict`` keys and must stay as they are for existing
    checkpoints to load. ``model_class``, ``upsample4`` and ``softmax`` are
    constructed but not used by ``forward``.
    """

    def __init__(self, norm_layer=nn.BatchNorm2d, classes=529):
        """Build all stages.

        Args:
            norm_layer: Callable ``norm_layer(num_channels)`` returning the
                normalisation module appended to each stage.
            classes: Output channels of the (unused in ``forward``) classification head.
        """
        super(SIGGRAPHGenerator, self).__init__()

        def conv3x3(c_in, c_out, dilation=1):
            # padding == dilation keeps the spatial size for a 3x3 kernel
            return nn.Conv2d(c_in, c_out, kernel_size=3, dilation=dilation,
                             stride=1, padding=dilation, bias=True)

        def conv_stage(widths, dilation=1):
            # [conv, ReLU] for each consecutive pair in `widths`, then one norm layer
            layers = []
            for c_in, c_out in zip(widths[:-1], widths[1:]):
                layers += [conv3x3(c_in, c_out, dilation), nn.ReLU(True)]
            layers.append(norm_layer(widths[-1]))
            return layers

        # Modules are created in the same order as before so that, for a given
        # RNG seed, random initialisation is unchanged.

        # Conv1 .. Conv4 (a subsampling step follows conv1-conv3 in forward)
        model1 = conv_stage((4, 64, 64))
        model2 = conv_stage((64, 128, 128))
        model3 = conv_stage((128, 256, 256, 256))
        model4 = conv_stage((256, 512, 512, 512))

        # Conv5, Conv6 (dilated)
        model5 = conv_stage((512, 512, 512, 512), dilation=2)
        model6 = conv_stage((512, 512, 512, 512), dilation=2)

        # Conv7
        model7 = conv_stage((512, 512, 512, 512))

        # Conv8: upsample conv7 and add a shortcut from conv3
        model8up = [nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=True)]
        model3short8 = [conv3x3(256, 256)]
        model8 = [nn.ReLU(True),
                  conv3x3(256, 256), nn.ReLU(True),
                  conv3x3(256, 256), nn.ReLU(True),
                  norm_layer(256)]

        # Conv9: upsample conv8 and add a shortcut from conv2
        model9up = [nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=True)]
        model2short9 = [conv3x3(128, 128)]
        model9 = [nn.ReLU(True),
                  conv3x3(128, 128), nn.ReLU(True),
                  norm_layer(128)]

        # Conv10: upsample conv9 and add a shortcut from conv1
        model10up = [nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1, bias=True)]
        model1short10 = [conv3x3(64, 128)]
        model10 = [nn.ReLU(True),
                   conv3x3(128, 128),
                   nn.LeakyReLU(negative_slope=.2)]

        # classification output
        model_class = [nn.Conv2d(256, classes, kernel_size=1, padding=0, dilation=1, stride=1, bias=True)]

        # regression output
        model_out = [nn.Conv2d(128, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=True),
                     nn.Tanh()]

        # Registration order is kept as well (it fixes the parameter order).
        self.model1 = nn.Sequential(*model1)
        self.model2 = nn.Sequential(*model2)
        self.model3 = nn.Sequential(*model3)
        self.model4 = nn.Sequential(*model4)
        self.model5 = nn.Sequential(*model5)
        self.model6 = nn.Sequential(*model6)
        self.model7 = nn.Sequential(*model7)
        self.model8up = nn.Sequential(*model8up)
        self.model8 = nn.Sequential(*model8)
        self.model9up = nn.Sequential(*model9up)
        self.model9 = nn.Sequential(*model9)
        self.model10up = nn.Sequential(*model10up)
        self.model10 = nn.Sequential(*model10)
        self.model3short8 = nn.Sequential(*model3short8)
        self.model2short9 = nn.Sequential(*model2short9)
        self.model1short10 = nn.Sequential(*model1short10)

        self.model_class = nn.Sequential(*model_class)
        self.model_out = nn.Sequential(*model_out)

        self.upsample4 = nn.Sequential(*[nn.Upsample(scale_factor=4, mode='bilinear'),])
        self.softmax = nn.Sequential(*[nn.Softmax(dim=1),])

    def forward(self, input_A, input_B=None, mask_B=None):
        """Predict a/b channels from L plus optional colour hints.

        Args:
            input_A: L channel tensor; normalised internally with ``normalize_l``.
            input_B: Optional a/b hint tensor (normalised with ``normalize_ab``).
                Defaults to zeros built from two copies of ``input_A * 0``.
            mask_B: Optional hint mask. Defaults to ``input_A * 0``.

        Returns:
            The tanh regression output rescaled by ``unnormalize_ab``.
        """
        if input_B is None:
            input_B = torch.cat((input_A * 0, input_A * 0), dim=1)
        if mask_B is None:
            mask_B = input_A * 0

        net_input = torch.cat(
            (self.normalize_l(input_A), self.normalize_ab(input_B), mask_B), dim=1)

        # Encoder: strided slicing halves the resolution before conv2, conv3 and conv4
        conv1_2 = self.model1(net_input)
        conv2_2 = self.model2(conv1_2[:, :, ::2, ::2])
        conv3_3 = self.model3(conv2_2[:, :, ::2, ::2])
        conv4_3 = self.model4(conv3_3[:, :, ::2, ::2])
        conv5_3 = self.model5(conv4_3)
        conv6_3 = self.model6(conv5_3)
        conv7_3 = self.model7(conv6_3)

        # Decoder with additive skip connections. This path is evaluated once:
        # the earlier second, identical conv9 -> out_reg pass produced the same
        # output, doubled the decoder cost and, in training mode, updated
        # model9's normalisation statistics twice per step.
        conv8_up = self.model8up(conv7_3) + self.model3short8(conv3_3)
        conv8_3 = self.model8(conv8_up)
        conv9_up = self.model9up(conv8_3) + self.model2short9(conv2_2)
        conv9_3 = self.model9(conv9_up)
        conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2)
        conv10_2 = self.model10(conv10_up)
        out_reg = self.model_out(conv10_2)

        return self.unnormalize_ab(out_reg)

def siggraph17(pretrained=True):
    """Construct a ``SIGGRAPHGenerator``, optionally with the released weights.

    Args:
        pretrained: When truthy, download (hash-checked) the siggraph17
            checkpoint onto the CPU and load it into the model.

    Returns:
        The constructed model.
    """
    net = SIGGRAPHGenerator()
    if not pretrained:
        return net

    import torch.utils.model_zoo as model_zoo
    weights = model_zoo.load_url(
        'https://colorizers.s3.us-east-2.amazonaws.com/siggraph17-df00044c.pth',
        map_location='cpu',
        check_hash=True,
    )
    net.load_state_dict(weights)
    return net


In [None]:
import numpy as np
import cv2
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import urllib.request
import os

class VideoColorizer:
    """Colorize L frames with the notebook's ``SIGGRAPHGenerator`` and its released weights."""

    # Side length the L frame is resized to before it is fed to the network.
    # A fixed size divisible by 8 keeps the encoder/decoder skip connections
    # shape-compatible for arbitrary frame sizes.
    _NET_SIZE = 256

    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = self._load_correct_model()
        self.model.eval()

    def _load_correct_model(self):
        """Return the pretrained SIGGRAPH17 generator on ``self.device``.

        Uses the ``siggraph17`` factory defined in this notebook, which pairs
        the architecture with its own checkpoint. The checkpoint previously
        downloaded here (``colorization_release_v2-9b330a0b.pth``) is the
        ECCV16 release and does not match a SIGGRAPH17 architecture.
        """
        model = siggraph17(pretrained=True)
        return model.to(self.device)

    def colorize_L_frame(self, L):
        """Colorize a single L-frame (numpy array)

        Args:
            L: 2-D array of luminance values in [0, 255].

        Returns:
            uint8 RGB image with the same height and width as ``L``.
        """
        L = np.asarray(L, dtype=np.float32)
        height, width = L.shape[:2]

        # Rescale 8-bit luminance to Lab lightness in [0, 100]. The generator
        # centres and scales L itself (normalize_l), so it is fed unshifted.
        L_lab = L / 255.0 * 100.0

        L_net = cv2.resize(L_lab, (self._NET_SIZE, self._NET_SIZE))
        L_tensor = torch.from_numpy(L_net).unsqueeze(0).unsqueeze(0).float().to(self.device)

        # Predict AB channels
        with torch.no_grad():
            pred_ab = self.model(L_tensor).cpu().numpy()[0]

        # Resize AB to match original dimensions (cv2 sizes are (width, height))
        pred_ab = cv2.resize(np.ascontiguousarray(pred_ab.transpose(1, 2, 0)), (width, height))

        # Combine with L and convert to RGB; float32 Lab -> RGB expects L in
        # [0, 100] and yields RGB in [0, 1]
        pred_lab = np.concatenate([L_lab[:, :, np.newaxis], pred_ab], axis=2)
        pred_rgb = cv2.cvtColor(pred_lab.astype(np.float32), cv2.COLOR_LAB2RGB)
        return np.clip(pred_rgb * 255, 0, 255).astype(np.uint8)

# Usage
if __name__ == "__main__":
    colorizer = VideoColorizer()

    # Read one L-frame from disk
    L = np.load("/kaggle/working/test_output/L/00000.npy")  # Replace with your path

    # Run the colorizer on it
    colorized = colorizer.colorize_L_frame(L)

    # Side-by-side display with axes hidden
    plt.figure(figsize=(12, 6))
    for _slot, (_image, _title, _cmap) in enumerate(
        [(L, "Original L-frame", 'gray'), (colorized, "Colorized", None)],
        start=1,
    ):
        plt.subplot(1, 2, _slot)
        plt.imshow(_image, cmap=_cmap)
        plt.title(_title)
        plt.axis('off')
    plt.show()

In [None]:
import urllib.request
# SIGGRAPH17 checkpoint. The URL previously used here
# (colorization_release_v2-9b330a0b.pth) is the ECCV16 release, which does not
# match the `siggraph17.pth` file name it was saved under.
model_url = "https://colorizers.s3.us-east-2.amazonaws.com/siggraph17-df00044c.pth"
model_path = "siggraph17.pth"
urllib.request.urlretrieve(model_url, model_path)

In [None]:
# In your loss function
def loss_fn(pred_ab, target_ab, temporal_weight=0.3):
    """L1 colour loss plus a weighted temporal-difference loss.

    Args:
        pred_ab: Predicted tensor; the first axis is treated as the frame axis.
        target_ab: Target tensor, same shape as ``pred_ab``.
        temporal_weight: Multiplier for the temporal term (default 0.3).

    Returns:
        ``l1(pred, target) + temporal_weight * mse(diff(pred), diff(target))``,
        where ``diff`` is the difference between consecutive frames. With fewer
        than two frames there are no consecutive pairs, so only the L1 term is
        returned.
    """
    # L1 loss for color accuracy
    color_loss = F.l1_loss(pred_ab, target_ab)

    # A mean over an empty tensor is NaN, so skip the temporal term when there
    # is no pair of consecutive frames.
    if pred_ab.shape[0] < 2:
        return color_loss

    # Temporal smoothness loss (between consecutive frames)
    flow_loss = F.mse_loss(pred_ab[1:] - pred_ab[:-1],
                           target_ab[1:] - target_ab[:-1])

    return color_loss + temporal_weight * flow_loss  # Weighted sum