# Advanced Video Colorization with a Sequence-Based Model

This notebook demonstrates video colorization using an advanced sequence-based model, conceptually similar to architectures like BiSTNet. Such models consider multiple frames at once to improve temporal consistency and reduce flickering compared to single-frame colorization methods.

## Setup

Import necessary modules and set up the environment.

In [None]:
# NOTE: This must be the first DeOldify cell to run! Select a GPU runtime in Colab for best performance.
from deoldify import device
from deoldify.device_id import DeviceId

#Choices: CPU, GPU0...GPU7
device.set(device=DeviceId.GPU0)

In [None]:
import warnings
from deoldify.visualize import get_advanced_video_colorizer, show_video_in_notebook
import matplotlib.pyplot as plt

# Matplotlib style for dark background, optional
plt.style.use('dark_background')
warnings.filterwarnings("ignore", category=UserWarning, message=".*?Your .*? set is empty.*?")

## Initialize the Advanced Video Colorizer

This step loads the sequence-based model. The `n_frames_input` parameter (defaulting to 5 in the model setup) determines how many frames the model considers simultaneously. A higher number can lead to better temporal consistency but may be more computationally intensive.
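To make "considers multiple frames simultaneously" concrete, here is a minimal NumPy sketch of one plausible input layout: N grayscale frames stacked into a `(1, N, H, W)` array. The layout, the `stack_frame_window` name and the 336x336 frame size are assumptions for illustration only; the actual filter may order axes, normalize, or feed frames differently. The point it shows is that the input, and the memory it needs, grows linearly with `n_frames_input`.

```python
import numpy as np

def stack_frame_window(frames):
    """Hypothetical helper: stack N grayscale H x W frames into a (1, N, H, W) float array in [0, 1]."""
    window = np.stack([np.asarray(f, dtype=np.float32) / 255.0 for f in frames], axis=0)
    return window[np.newaxis]  # add a leading batch dimension

# Five dummy 336x336 frames, matching the default n_frames_input of 5
frames = [np.full((336, 336), 128, dtype=np.uint8) for _ in range(5)]
batch = stack_frame_window(frames)
print(batch.shape)   # (1, 5, 336, 336)
print(batch.nbytes)  # doubles if the number of frames doubles
```

With seven frames instead of five, the same call simply produces a `(1, 7, H, W)` array, which is where the extra computational cost comes from.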

In [None]:
# Parameters like weights_name, render_factor, n_frames_input can be adjusted here if needed,
# but we'll use defaults or values passed during get_advanced_video_colorizer setup.
colorizer = get_advanced_video_colorizer() 
# To use specific weights or model parameters (if future models allow):
# colorizer = get_advanced_video_colorizer(weights_name='MyAdvancedModel.pth', n_frames_input=7)

## Debug Visualizations Setup
The following cells help visualize various stages of the colorization process for a deeper understanding and debugging.

In [None]:
%matplotlib inline
import os
from pathlib import Path
import numpy as np
from PIL import Image as PilImage
import matplotlib.pyplot as plt
import torch
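from torchvision import transforms as T  # ToPILImage lives in torchvision, not fastai.vision.transform

def show_images_row(imgs_with_titles, main_title='', fig_size=(15, 5)):
    """Display a list of (image, title) pairs side by side in one row."""
    num_imgs = len(imgs_with_titles)
    if num_imgs == 0:
        print("No images to display.")
        return
    # squeeze=False always returns a 2-D array of Axes, even for a single image
    fig, axes = plt.subplots(1, num_imgs, figsize=fig_size, squeeze=False)
    for ax, (img, title) in zip(axes[0], imgs_with_titles):
        if isinstance(img, torch.Tensor):
            # Expects a C x H x W tensor with float values in [0, 1]
            img = T.ToPILImage()(img.detach().cpu())
        if getattr(img, 'mode', None) == 'L':
            # Show single-channel images in true grayscale instead of the default (viridis) colormap
            ax.imshow(img, cmap='gray', vmin=0, vmax=255)
        else:
            ax.imshow(img)
        ax.set_title(title)
        ax.axis('off')
    fig.suptitle(main_title, fontsize=16)
    plt.show()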

## Instructions

This notebook is primarily set up for processing local video files. You can also process videos from URLs.

### For Local Files (Recommended for Debugging Visualizations):
1. Upload your video file to the `video/source/` directory in your Colab/Jupyter environment.
2. In the **"Colorize Video!"** cell below, ensure `source_url` is set to `None`.
3. Set `file_name_with_ext` in that same cell to the exact name of your uploaded video file (e.g., 'my_local_video.mp4').

### `file_name_stem` (Used for naming output directories and files):
   - If processing a local file, the `file_name_stem` is automatically derived from `file_name_with_ext` (e.g., 'my_local_video' from 'my_local_video.mp4'). This stem is used for creating directories in `video/bwframes/` and `video/colorframes/`.
   - If processing from a URL, you'll define a `file_name_stem_for_url` which will be used for the downloaded and processed files.
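The stem derivation above can be reproduced with `pathlib`. The helper name `output_dirs` is ours, for illustration only; it simply mirrors the `video/bwframes/<stem>` and `video/colorframes/<stem>` layout described here.

```python
from pathlib import Path

def output_dirs(file_name_with_ext, root='video'):
    """Hypothetical helper: derive the stem and the per-video frame directories."""
    stem = Path(file_name_with_ext).stem
    root_path = Path(root)
    return stem, root_path / 'bwframes' / stem, root_path / 'colorframes' / stem

stem, bw_dir, color_dir = output_dirs('my_local_video.mp4')
print(stem)     # my_local_video
print(bw_dir)   # video/bwframes/my_local_video (separator depends on the OS)

# Caveat: .stem strips only the last suffix
print(Path('clip.v2.mp4').stem)  # clip.v2
```

Because only the final extension is removed, a file such as `clip.v2.mp4` gets the stem `clip.v2`, and its frame directories are named accordingly.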

### `source_url` (Optional):
   - To process a video from a URL (e.g., YouTube), uncomment the URL-related lines in the **"Colorize Video!"** cell and set `source_url` to the video's web address. Provide a `file_name_stem_for_url`.

### `render_factor`:
   - Determines the resolution at which the colorization model processes the video. Lower values (e.g., 21) render faster, and colors may appear more vibrant. Higher values can suit high-quality input but may slightly wash out colors. The default is 21; the maximum is roughly 44 on GPUs with 11 GB of memory.
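In the original DeOldify code, the color pass renders a square image whose side is `render_factor` times a base of 16 pixels. Assuming this sequence model's filter keeps that convention (not confirmed here), the sketch below shows how quickly the per-frame workload grows; `render_size` is a hypothetical helper.

```python
def render_size(render_factor, render_base=16):
    """Side length in pixels of the square render, assuming DeOldify's render_base of 16."""
    return render_factor * render_base

for rf in (21, 35, 44):
    side = render_size(rf)
    print(f"render_factor={rf}: {side}x{side} px ({side * side:,} pixels)")
```

Going from 21 (336x336) to 44 (704x704) more than quadruples the pixel count per frame, which, multiplied by the number of input frames, explains the GPU memory ceiling.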

### `result_path`:
   - This will be automatically determined. The final colorized video will be in the `video/result/` directory.

## Colorize Video!

In [None]:
# --- LOCAL FILE PROCESSING (Primary Example) ---
source_url = None 
# 1. Upload your video to the 'video/source/' directory.
# 2. Set file_name_with_ext to the exact name of your uploaded video file.
file_name_with_ext = 'my_local_video.mp4' # IMPORTANT: Change this to your uploaded video's name

render_factor = 21 # Default render_factor. Adjust if needed.

# --- Optional: URL PROCESSING ---
# To use a URL instead, uncomment the two lines below. They override `source_url = None` above,
# and the URL branch is checked first, so the local file settings are then ignored.
# source_url = 'https://www.youtube.com/watch?v=your_video_id_here' # Replace with a real video URL
# file_name_stem_for_url = 'my_url_video' # Base name for downloaded/output files

result_path = None
file_name_stem = None # Will be set based on the processing path

if source_url is not None:
    print(f"Attempting to colorize video from URL: {source_url}")
    file_name_stem = file_name_stem_for_url # Use the user-defined stem for URL processing
    result_path = colorizer.colorize_from_url(source_url, base_file_name=file_name_stem, render_factor=render_factor)
    if result_path and os.path.exists(result_path):
        print(f"Colorized video saved to: {result_path}")
        show_video_in_notebook(result_path)
    else:
        print(f"Colorization from URL failed or result_path is not valid: {result_path}")
elif file_name_with_ext: # Local file processing
    print(f"Attempting to colorize local file: video/source/{file_name_with_ext}")
    file_name_stem = Path(file_name_with_ext).stem 
    source_video_path = Path('./video/source') / file_name_with_ext
    if not source_video_path.exists():
        print(f"ERROR: Source video not found at {source_video_path}. Please upload the video and try again.")
    else:
        result_path = colorizer.colorize_from_file_name(file_name_with_ext, render_factor=render_factor)
        if result_path and os.path.exists(result_path):
            print(f"Colorized video saved to: {result_path}")
            show_video_in_notebook(result_path)
        else:
            print(f"Colorization of local file failed or result_path is not valid: {result_path}")
else:
    print("Please set `file_name_with_ext` for local file processing or provide and uncomment a `source_url`.")

## Debug Visualizations
The cells below are for inspecting the intermediate results of the video colorization process.
**Important:** You must first run the "Colorize Video!" cell above to generate the necessary frame data.

### Specify Video for Frame Inspection
This cell sets the `file_name_stem_for_inspection` variable. It should match the `file_name_stem` derived from the video processed in the "Colorize Video!" cell. If you processed a local file named `my_video.mp4`, then `file_name_stem` would be `my_video`.

In [None]:
# This variable should match the 'file_name_stem' from the video you just processed.
if 'file_name_stem' in locals() and file_name_stem:
    file_name_stem_for_inspection = file_name_stem
else:
    # Fallback if file_name_stem was not set (e.g., if colorization cell was not run or failed early)
    file_name_stem_for_inspection = 'my_local_video' # Replace with your default if needed
    print(f"Warning: 'file_name_stem' was not found from the colorization cell. Defaulting to '{file_name_stem_for_inspection}'. Ensure this directory exists in video/bwframes/ and video/colorframes/.")

print(f"Using file_name_stem_for_inspection: '{file_name_stem_for_inspection}' for debug visualizations.")

# Define root paths for black & white (bwframes) and colorized (colorframes) frames
bw_frames_root = Path('./video/bwframes')
color_frames_root = Path('./video/colorframes')

# Construct full paths to the specific video's frame directories
bw_frames_dir = bw_frames_root / file_name_stem_for_inspection
color_frames_dir = color_frames_root / file_name_stem_for_inspection

# Frame number to display (e.g., '00001.jpg'). Change as needed.
frame_to_show_name = '00001.jpg'
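To inspect a frame other than the first, the file name can be built from a frame number. This sketch assumes the naming shown above (numbering from 1, five-digit zero padding, `.jpg`); the helper names are hypothetical.

```python
from pathlib import Path

def frame_file_name(frame_number, ext='.jpg'):
    """Build a five-digit, zero-padded frame file name, e.g. 1 -> '00001.jpg'."""
    return f"{frame_number:05d}{ext}"

def frame_number_from_name(name):
    """Recover the frame number from a name such as '00042.jpg'."""
    return int(Path(name).stem)

example_name = frame_file_name(120)
print(example_name)                          # 00120.jpg
print(frame_number_from_name(example_name))  # 120
```

For example, setting `frame_to_show_name = frame_file_name(120)` would point the following cells at the 120th extracted frame.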

### Original Grayscale vs. Colorized Frame
Compares a selected original grayscale frame with its colorized version.

In [None]:
original_frame_path = bw_frames_dir / frame_to_show_name
colorized_frame_path = color_frames_dir / frame_to_show_name

imgs_to_display = []
if original_frame_path.exists():
    imgs_to_display.append((PilImage.open(original_frame_path), f'Original BW ({frame_to_show_name})'))
else:
    print(f'Original frame not found: {original_frame_path}')

if colorized_frame_path.exists():
    imgs_to_display.append((PilImage.open(colorized_frame_path), f'Colorized ({frame_to_show_name})'))
else:
    print(f'Colorized frame not found: {colorized_frame_path}')

if imgs_to_display:
    show_images_row(imgs_to_display, 'Original vs. Colorized Frame Comparison')

### Visualize L, A, B Channels
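Shows the individual L (Lightness), A (Green-Red), and B (Blue-Yellow) channels of a colorized frame. The A and B channels are shown as 8-bit grayscale images in which mid-gray (128) means neutral: brighter A pixels lean red and darker ones green, while brighter B pixels lean yellow and darker ones blue.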

In [None]:
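from PIL import ImageCms

if colorized_frame_path.exists(): # Uses colorized_frame_path from the cell above
    color_pil = PilImage.open(colorized_frame_path).convert('RGB')
    # Image.convert('LAB') does not support RGB input in common Pillow versions,
    # so convert through a color-managed sRGB -> LAB transform instead
    rgb_to_lab = ImageCms.buildTransformFromOpenProfiles(
        ImageCms.createProfile('sRGB'), ImageCms.createProfile('LAB'), 'RGB', 'LAB')
    lab_pil = ImageCms.applyTransform(color_pil, rgb_to_lab)
    l_pil, a_pil, b_pil = lab_pil.split()

    imgs_to_display_lab = [
        (color_pil, 'Colorized Original'),
        (l_pil, 'L Channel'),
        (a_pil, 'A Channel (Grayscale)'),
        (b_pil, 'B Channel (Grayscale)')
    ]
    show_images_row(imgs_to_display_lab, 'L, A, B Channel Visualization', fig_size=(20,5))
else:
    print(f'Colorized frame not found for LAB analysis: {colorized_frame_path}')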
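For intuition about what the A and B axes mean, here is the standard sRGB to CIELAB conversion written out in NumPy (sRGB linearization, the sRGB-to-XYZ matrix, then the CIE L\*a\*b\* formulas with a D65 white point). This is a reference sketch, not what the cell above calls: Pillow's `ImageCms` LAB profile defaults to a D50 white point and stores 8-bit values, so its numbers will not match these floating-point values exactly.

```python
import numpy as np

D65_WHITE = np.array([0.95047, 1.0, 1.08883])
SRGB_TO_XYZ = np.array([[0.4124564, 0.3575761, 0.1804375],
                        [0.2126729, 0.7151522, 0.0721750],
                        [0.0193339, 0.1191920, 0.9503041]])

def srgb_to_lab(rgb):
    """Convert 8-bit sRGB values (shape ... x 3) to CIELAB (L in 0..100, a/b signed)."""
    rgb = np.asarray(rgb, dtype=np.float64) / 255.0
    # Undo the sRGB gamma curve to get linear light
    linear = np.where(rgb <= 0.04045, rgb / 12.92, ((rgb + 0.055) / 1.055) ** 2.4)
    xyz = linear @ SRGB_TO_XYZ.T
    t = xyz / D65_WHITE
    delta = 6 / 29
    f = np.where(t > delta ** 3, np.cbrt(t), t / (3 * delta ** 2) + 4 / 29)
    L = 116 * f[..., 1] - 16
    a = 500 * (f[..., 0] - f[..., 1])  # positive = red, negative = green
    b = 200 * (f[..., 1] - f[..., 2])  # positive = yellow, negative = blue
    return np.stack([L, a, b], axis=-1)

print(srgb_to_lab([255, 0, 0]))  # strongly positive a for pure red
```

Neutral grays (equal R, G, B) come out with a and b near zero, which is why a frame the model left uncolored shows flat mid-gray A and B channels.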

### Visualize Input Sequence to Model (Conceptual)
Displays a sequence of grayscale frames that would be fed to the sequence-based model. The model uses such sequences to colorize the middle frame.

In [None]:
n_frames_for_model = colorizer.vis.filter.n_frames_input if hasattr(colorizer.vis.filter, 'n_frames_input') else 5
center_frame_file_idx = -1
imgs_sequence = []

if bw_frames_dir.exists():
    frame_files = sorted([f for f in os.listdir(bw_frames_dir) if f.endswith(('.jpg', '.png', '.jpeg'))])
    if not frame_files:
        print(f"No frames found in {bw_frames_dir}")
    else:
        try:
            center_frame_file_idx = frame_files.index(frame_to_show_name)
        except ValueError:
            print(f"Frame {frame_to_show_name} not found in {bw_frames_dir}. Defaulting to a frame near the middle if possible.")
            center_frame_file_idx = len(frame_files) // 2

        half_seq = n_frames_for_model // 2
        seq_start_file_idx = max(0, center_frame_file_idx - half_seq)
        seq_start_file_idx = min(seq_start_file_idx, len(frame_files) - n_frames_for_model)
        seq_start_file_idx = max(0, seq_start_file_idx) 
        seq_end_file_idx = seq_start_file_idx + n_frames_for_model

        if len(frame_files) >= n_frames_for_model and seq_end_file_idx <= len(frame_files):
            for i in range(seq_start_file_idx, seq_end_file_idx):
                frame_path = bw_frames_dir / frame_files[i]
                if frame_path.exists():
                    title = f'Seq F{i-seq_start_file_idx+1} ({frame_files[i]})'
                    if i == center_frame_file_idx:
                        # Matplotlib titles don't render Markdown, so mark the center frame in plain text
                        title = f'[CENTER] {title}'
                    imgs_sequence.append((PilImage.open(frame_path).convert('L'), title))
                else:
                    print(f'Frame {frame_files[i]} not found for sequence.')
                    imgs_sequence = [] 
                    break
            if len(imgs_sequence) == n_frames_for_model:
                show_images_row(imgs_sequence, f'Example Grayscale Input Sequence (Model sees {n_frames_for_model} frames)', fig_size=(n_frames_for_model * 3, 3))
        else:
            center_frame_name_for_print = frame_files[center_frame_file_idx] if center_frame_file_idx >=0 and center_frame_file_idx < len(frame_files) else frame_to_show_name
            print(f'Not enough frames in {bw_frames_dir} (found {len(frame_files)}, need {n_frames_for_model}) to display sequence around frame {center_frame_name_for_print}.')
else:
    print(f"Directory with BW frames not found: {bw_frames_dir}")
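The window selection in the cell above can be isolated into a small pure function, which makes its edge behavior easy to check. `window_bounds` is a hypothetical name; the logic mirrors the cell's clamping.

```python
def window_bounds(center_idx, n_frames, total_frames):
    """Return (start, end) of an n_frames window around center_idx, shifted to stay in range.

    Returns None when the clip has fewer than n_frames frames.
    """
    if total_frames < n_frames:
        return None
    half = n_frames // 2
    start = max(0, center_idx - half)
    start = min(start, total_frames - n_frames)  # shift left near the end of the clip
    return start, start + n_frames

print(window_bounds(50, 5, 100))  # (48, 53): frame 50 sits in the middle
print(window_bounds(0, 5, 100))   # (0, 5): shifted right at the start
print(window_bounds(99, 5, 100))  # (95, 100): shifted left at the end
```

Near the start or end of a clip the window is shifted rather than padded, so the selected frame is no longer in the middle. Whether the model itself shifts, or instead pads by repeating edge frames, at clip boundaries is not documented in this notebook.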

### Visualize Model's AB Output Contribution (Conceptual)
This cell attempts to isolate the color information (A and B channels) of the colorized middle frame by recombining it with a constant, mid-level L channel, so that lightness no longer varies across the image. Because the A and B channels are read back from the saved frame, any post-processing applied before saving is included in what you see.

In [None]:
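from PIL import ImageCms

if colorized_frame_path.exists():
    color_pil = PilImage.open(colorized_frame_path).convert('RGB')
    # Color-managed transforms in both directions (Image.convert does not handle RGB <-> LAB)
    srgb_profile = ImageCms.createProfile('sRGB')
    lab_profile = ImageCms.createProfile('LAB')
    rgb_to_lab = ImageCms.buildTransformFromOpenProfiles(srgb_profile, lab_profile, 'RGB', 'LAB')
    lab_to_rgb = ImageCms.buildTransformFromOpenProfiles(lab_profile, srgb_profile, 'LAB', 'RGB')

    _, a_pil, b_pil = ImageCms.applyTransform(color_pil, rgb_to_lab).split()

    # Constant mid-level L channel (128 of 255) so only the A/B color information varies
    gray_l_channel = PilImage.new('L', color_pil.size, color=128)

    model_color_contribution = ImageCms.applyTransform(
        PilImage.merge('LAB', (gray_l_channel, a_pil, b_pil)), lab_to_rgb)

    imgs_to_display_ab = [
        (color_pil, 'Full Colorized Frame'),
        (model_color_contribution, 'Model AB on Neutral L')
    ]
    show_images_row(imgs_to_display_ab, 'Visualizing Model Color (AB) Contribution')
else:
    print(f'Colorized frame not found for AB contribution analysis: {colorized_frame_path}')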
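A single number can complement this picture: CIELAB chroma, C\*ab = sqrt(a\*² + b\*²), measures how far each pixel is from neutral gray. The sketch below computes it per pixel; `chroma_from_pil_channels` is a hypothetical helper that assumes 8-bit A/B channels with 128 as the neutral value, as in Pillow's LAB mode.

```python
import numpy as np

def chroma_map(a, b):
    """Per-pixel CIELAB chroma C*ab = sqrt(a*^2 + b*^2) for signed a/b values."""
    return np.hypot(np.asarray(a, dtype=np.float64), np.asarray(b, dtype=np.float64))

def chroma_from_pil_channels(a_channel, b_channel):
    """Assumes 8-bit a/b channels where 128 means neutral; removes that offset first."""
    a = np.asarray(a_channel, dtype=np.float64) - 128.0
    b = np.asarray(b_channel, dtype=np.float64) - 128.0
    return chroma_map(a, b)

print(chroma_map([3.0], [4.0]))  # [5.]
```

Calling `chroma_from_pil_channels(a_pil, b_pil).mean()` after the cell above would give a rough "how much color was added" score for the frame; values near zero indicate an almost gray result.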

## Discussion & Comparison

The sequence-based model used here is designed to produce more temporally stable colorization with reduced flickering compared to models that process each frame independently. 

Consider colorizing the same video using the original `VideoColorizer.ipynb` (which uses a single-frame model) and compare the results. Pay attention to:
- Flickering in areas that should have consistent color.
- Color consistency of moving objects.
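Flicker can also be checked numerically. A simple heuristic (our own, not part of DeOldify) is the mean absolute change in the A and B channels between consecutive colorized frames: a stable colorization gives small, smooth scores, while flicker shows up as spikes. Real motion also changes A/B values, so compare the two notebooks on the same clip rather than reading the numbers in isolation.

```python
import numpy as np

def ab_flicker_scores(ab_frames):
    """Mean absolute A/B change between consecutive frames.

    ab_frames: list of H x W x 2 arrays (a and b channels) in frame order.
    Returns one score per consecutive pair of frames.
    """
    scores = []
    for prev, curr in zip(ab_frames, ab_frames[1:]):
        diff = np.abs(np.asarray(curr, dtype=np.float64) - np.asarray(prev, dtype=np.float64))
        scores.append(float(diff.mean()))
    return scores

steady = [np.zeros((4, 4, 2)) for _ in range(3)]
jumpy = [np.zeros((4, 4, 2)), np.zeros((4, 4, 2)), np.full((4, 4, 2), 4.0)]
print(ab_flicker_scores(steady))  # [0.0, 0.0]
print(ab_flicker_scores(jumpy))   # [0.0, 4.0]
```

The per-frame A/B arrays could be produced with the LAB conversion used in the debug cells, applied to each file in `video/colorframes/<stem>`.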

Note: The actual visual improvement depends heavily on the underlying model architecture and training. This notebook provides the *framework* for using such advanced models.

## Frame Preview (Optional)

The `plot_transformed_image` method visualizes how a single frame is colorized. With the `VideoColorizerFilter`, this shows the result of the filter's single-image mode, which duplicates the frame to form a sequence. Because every frame in that sequence is identical, the preview cannot show the temporal benefits seen in the full video, but it is useful for checking general color quality at a given `render_factor`.

In [None]:
# This uses the .P() method of the filter, which for VideoColorizerFilter duplicates the frame.
# Ensure you have processed a video first so that 'bwframes' are available, or provide a direct image path.

if 'file_name_stem_for_inspection' in locals() and file_name_stem_for_inspection and 'frame_to_show_name' in locals() and frame_to_show_name:
    preview_frame_path = bw_frames_dir / frame_to_show_name # Uses bw_frames_dir from a cell above
    if preview_frame_path.exists():
        print(f"Displaying preview for: {preview_frame_path}")
        colorizer.vis.plot_transformed_image(str(preview_frame_path), render_factor=render_factor, display_render_factor=True, figsize=(8,8))
    else:
        print(f"Preview frame {preview_frame_path} not found. Ensure video was processed and file_name_stem_for_inspection is correct.")
else:
    print("Variables 'file_name_stem_for_inspection' or 'frame_to_show_name' are not set. Ensure the colorization and inspection setup cells were run.")
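The single-image mode described above can be pictured as repeating one frame N times along a new time axis. This NumPy sketch is a hypothetical reconstruction (`duplicate_frame_sequence` is not a DeOldify function); it makes explicit why the preview carries no temporal information: every frame-to-frame difference is zero.

```python
import numpy as np

def duplicate_frame_sequence(frame, n_frames=5):
    """Hypothetical single-image mode: repeat one H x W frame -> (n_frames, H, W)."""
    frame = np.asarray(frame)
    return np.repeat(frame[np.newaxis], n_frames, axis=0)

seq = duplicate_frame_sequence(np.arange(12, dtype=np.uint8).reshape(3, 4))
print(seq.shape)                    # (5, 3, 4)
print(bool(np.all(seq == seq[0])))  # True: no motion for the model to exploit
```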