In [None]:
# Notebook bootstrap: change the working directory, then enable autoreload.
# NOTE(review): the absolute path below is hard-coded; the relative asset path
# used later in this notebook ("./flux/assets/video.mp4") only resolves when
# the repository lives at exactly this location.
%cd /workspaces/torch-basics/
# Autoreload mode 2 lets edits to imported project modules be picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import torchvision
import numpy as np
import cv2
import kornia
import logging
from rich import print
from rich.logging import RichHandler

# Configure the root logger: INFO level, a message-only format string, and a
# single RichHandler with rich_tracebacks enabled as the only handler.
logging.basicConfig(
    level=logging.INFO,
    format="%(message)s",
    datefmt="[%X]",
    handlers=[RichHandler(rich_tracebacks=True)],
)

from flux_control.utils.describe import describe

In [None]:
# Read the demo clip, requesting frames in TCHW layout and timestamps in
# seconds. The path is relative to the working directory set in the first cell.
video = torchvision.io.read_video(
    "./flux/assets/video.mp4", output_format="TCHW", pts_unit="sec"
)
# Element 0 of the returned tuple is used as the frame tensor; it is converted
# to float and divided by 255 (presumably mapping 8-bit pixels into [0, 1] —
# confirm against the torchvision version in use).
video_frames = video[0].float() / 255.0
describe(video)
describe(video_frames)

In [None]:
from flux_control.datasets.collage.flow import (
    load_raft_model,
    compute_aggregated_flow,
    unload_raft_model,
)

# Prefer CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
load_raft_model(device=device)

# Work on a fixed window of the clip: frames 80..99, moved to `device`.
# NOTE(review): slicing does not raise on short clips, so a video with fewer
# than 100 frames silently yields fewer (possibly zero) frames here.
selected_frames = video_frames[80:100].to(device)
# The second return value is used by later cells as an index into
# selected_frames (selected_frames[target_idx]); the flow presumably maps
# frame 0 to that frame — confirm in flux_control.datasets.collage.flow.
flow, target_idx = compute_aggregated_flow(selected_frames)
describe(flow)
# Called once the flow is computed; presumably releases the model — verify.
unload_raft_model()

In [None]:
from einops import rearrange, repeat, reduce
from PIL import Image

def visualize_flow(flow, max_magnitude=24.0):
    """Render a dense flow field as a colour-coded PIL image.

    Flow direction is mapped to hue and flow magnitude to saturation, with the
    value channel fixed at full brightness, so zero flow renders as white.

    Args:
        flow: Tensor shaped (c, h, w) or (1, c, h, w). Channel 0 is used as the
            x component and channel 1 as the y component.
        max_magnitude: Magnitude that maps to full saturation; larger
            magnitudes are clipped. Defaults to 24, the previously hard-coded
            scale.

    Returns:
        A PIL image in RGB channel order.
    """
    if len(flow.shape) == 4:
        # The "1" in the pattern makes einops reject batches larger than one.
        flow = rearrange(flow, "1 c h w -> c h w")
    # detach() so tensors that still track gradients can be converted too.
    flow = flow.detach().cpu().numpy()
    flow = rearrange(flow, "c h w -> h w c")

    # Build the picture in HSV: value is constant, hue and saturation carry data.
    hsv = np.zeros((flow.shape[0], flow.shape[1], 3), dtype=np.uint8)
    hsv[..., 2] = 255

    mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1])
    # The angle is in radians; scale a full turn (2*pi) to 180 hue steps, the
    # range OpenCV uses for 8-bit hue.
    hsv[..., 0] = ang / np.pi / 2 * 180
    hsv[..., 1] = np.clip(mag * 255 / max_magnitude, 0, 255)

    # Image.fromarray interprets 3-channel uint8 data as RGB, so convert to RGB
    # rather than OpenCV's BGR order; BGR would swap red and blue in the output.
    rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
    return Image.fromarray(rgb)

def visualize_image(image):
    """Convert a float image tensor into a PIL image for display.

    Args:
        image: Tensor shaped (c, h, w) or (1, c, h, w) with values expected in
            [0, 1]. Values outside that range are clipped.

    Returns:
        A PIL image built from the (h, w, c) uint8 array.
    """
    if len(image.shape) == 4:
        # The "1" in the pattern makes einops reject batches larger than one.
        image = rearrange(image, "1 c h w -> c h w")
    # detach() so tensors that still track gradients can be converted too.
    image = image.detach().cpu().numpy()
    image = rearrange(image, "c h w -> h w c")
    # Clip before the uint8 cast: out-of-range values (e.g. slight overshoot
    # after warping) would otherwise wrap around and show up as speckles.
    image = np.clip(image, 0.0, 1.0)
    return Image.fromarray((image * 255).astype(np.uint8))

# Show the source frame, the target frame and the colour-coded flow, in that
# order, with a single variadic display call.
display(
    visualize_image(selected_frames[0]),
    visualize_image(selected_frames[target_idx]),
    visualize_flow(flow),
)

In [None]:
from flux_control.datasets.collage.warp import forward_warp

# Warp using the first selected frame, the target frame and the flow computed
# above. What forward_warp returns is not visible here; the next cell treats
# element 0 of the result as an image.
warp_result = forward_warp(selected_frames[0], selected_frames[target_idx], flow)
describe(warp_result)

In [None]:
# Preview element 0 of the forward-warp result.
display(visualize_image(warp_result[0]))

In [None]:
from flux_control.datasets.collage.depth import (
    load_depth_model,
    estimate_depth,
    unload_depth_model,
)

# Load the depth model, estimate depth for the first selected frame, inspect
# the result, then unload the model again.
load_depth_model(device=device)
depth = estimate_depth(selected_frames[0])
describe(depth)
# presumably releases the model's memory — verify in the depth module.
unload_depth_model()

In [None]:
def visualize_greyscale(image):
    """Convert a single-channel float map into a grey RGB PIL image.

    Args:
        image: Tensor that squeezes down to shape (h, w), with values expected
            in [0, 1]. Values outside that range are clipped.

    Returns:
        A PIL image whose three channels all carry the same intensity.
    """
    # detach() so tensors that still track gradients can be converted too.
    image = image.detach().squeeze().cpu().numpy()  # [h, w]
    image = repeat(image, "h w -> h w c", c=3)
    # Clip before the uint8 cast so out-of-range values saturate instead of
    # wrapping around.
    image = np.clip(image, 0.0, 1.0)
    image = (image * 255).astype(np.uint8)
    return Image.fromarray(image)

# Preview the estimated depth map as a grey image.
display(visualize_greyscale(depth))

In [None]:
from flux_control.datasets.collage.segmentation import load_segmentation_model, generate_masks, unload_segmentation_model

# Load the segmentation model, generate masks for the first selected frame,
# inspect them, then unload the model again.
load_segmentation_model(device=device)
masks = generate_masks(selected_frames[0])
describe(masks)
# presumably releases the model's memory — verify in the segmentation module.
unload_segmentation_model()

In [None]:
from flux_control.datasets.collage.affine import compute_transform_data_structured, apply_transforms

# Combine flow, depth and masks into transform data. The second return value,
# `dropped`, is later iterated as a collection of dicts with "area" and "mask"
# keys; why entries end up there is decided inside
# compute_transform_data_structured and is not visible in this notebook.
transform, dropped = compute_transform_data_structured(flow, depth, masks)
describe(transform)
describe(dropped)

In [None]:
# Apply the transform data to the first selected frame; four results are
# returned and all of them are consumed by later cells.
warped, grid, warped_regions, warped_alpha = apply_transforms(
    selected_frames[0], depth, transform
)
display(visualize_image(warped))
# Optional previews of the auxiliary outputs, left disabled:
# display(visualize_greyscale(warped_regions))
# display(visualize_greyscale(warped_alpha))

In [None]:
# Inspect all four outputs of apply_transforms.
describe(warped)
describe(grid)
describe(warped_regions)
describe(warped_alpha)

# Clamp alpha into [0, 1] before later cells consume it. The describe call
# above therefore reports the unclamped values.
warped_alpha = torch.clamp(warped_alpha, 0, 1)

In [None]:
from flux_control.datasets.collage.video import splat_lost_regions

# NOTE(review): this cell rebinds `warped`, `grid` and `warped_alpha`, which are
# also its own inputs. Running it a second time feeds the splatted results back
# in; re-run the apply_transforms and clamp cells first to get the original
# inputs again.
warped, grid, warped_alpha = splat_lost_regions(
    selected_frames[0],
    selected_frames[target_idx],
    flow,
    warped,
    grid,
    warped_regions,
    warped_alpha,
)

display(visualize_image(warped))

In [None]:
from flux_control.datasets.collage.dexined import load_dexined_model, estimate_edges

# NOTE(review): unlike the RAFT, depth and segmentation cells, no unload call
# follows here, so the model presumably stays loaded for the rest of the
# session — confirm whether the dexined module offers an unload helper.
load_dexined_model(device=device)
# Edges are estimated on the target frame, not on frame 0.
edges = estimate_edges(selected_frames[target_idx])
describe(edges)
display(visualize_greyscale(edges))

In [None]:
from flux_control.datasets.collage.palette import extract_palette_from_masked_image, show_color_palette
# Extract a palette from the first selected frame using an all-ones tensor
# shaped like warped_alpha as the mask (presumably meaning "no pixel is masked
# out" — confirm). The positional 5 is presumably the maximum number of
# colours, matching the max_colors keyword used later — confirm.
# The second return value is discarded here.
palette, _ = extract_palette_from_masked_image(
    selected_frames[0], torch.ones_like(warped_alpha), 5
)
describe(palette)
show_color_palette(palette)

In [None]:
from flux_control.datasets.collage.video import encode_color_palette

# Encode palettes and their locations from the first frame and the dropped
# masks. area_threshold=0.05 presumably filters masks by relative area, like
# the manual `h * w * 0.05` filter in a later cell — confirm.
palettes, locations = encode_color_palette(selected_frames[0], dropped, area_threshold=0.05)
describe((palettes, locations))

In [None]:
# Display the palettes returned by encode_color_palette.
show_color_palette(palettes)

In [None]:
import matplotlib.pyplot as plt
from flux_control.utils.common import meshgrid_to_ij

def show_palette_with_locations(image, palettes, locations):
    """Plot an image and mark every palette colour at its location.

    Args:
        image: Tensor shaped (c, h, w); drawn with plt.imshow after being
            rearranged to (h, w, c).
        palettes: Tensor whose rows are the colours used to fill the markers.
        locations: Marker locations, converted with meshgrid_to_ij using the
            image height and width (presumably into pixel (i, j) indices —
            confirm in flux_control.utils.common).
    """
    _channels, height, width = image.shape
    pixels_hwc = rearrange(image, "c h w -> h w c").cpu().numpy()
    plt.imshow(pixels_hwc)

    marker_positions = meshgrid_to_ij(locations, height, width)
    marker_colours = palettes.cpu().numpy()
    marker_positions = marker_positions.cpu().numpy()

    for index in range(marker_colours.shape[0]):
        colour = marker_colours[index]
        position = marker_positions[index]
        # Element 1 of each position is the x coordinate, element 0 the y.
        plt.scatter(
            position[1],
            position[0],
            color=colour,
            s=100,
            marker="o",
            edgecolors="black",
        )

    plt.axis("off")
    plt.show()

# Overlay the encoded palettes on the first frame multiplied by
# (1 - warped_regions); presumably this blanks the regions that were warped,
# assuming warped_regions is a 0/1 mask — confirm.
show_palette_with_locations(
    selected_frames[0] * (1 - warped_regions), palettes, locations
)

In [None]:
# Keep only the dropped masks whose area exceeds 5% of the frame, then extract
# and display a small palette for each of them.
_, h, w = selected_frames[0].shape
min_mask_area = h * w * 0.05
dropped_masks_qualify = [mask for mask in dropped if mask["area"] > min_mask_area]
describe(dropped_masks_qualify)

for mask in dropped_masks_qualify:
    mask_torch = torch.from_numpy(mask["mask"]).to(device)
    # Loop-local names: rebinding `palettes` / `locations` here would overwrite
    # the encode_color_palette results from the earlier cell, leaving stale
    # state behind for any cell that is re-run afterwards.
    mask_palette, mask_locations = extract_palette_from_masked_image(
        selected_frames[0], mask_torch, max_colors=3, min_colors=1
    )
    show_color_palette(mask_palette)
    show_palette_with_locations(
        selected_frames[0] * mask_torch, mask_palette, mask_locations
    )
    print(mask_locations)

In [None]:
# Same whole-frame extraction as in the earlier palette cell (all-ones mask,
# positional 5), this time keeping the second return value so the colours can
# be drawn on the frame.
palette, locations = extract_palette_from_masked_image(
    selected_frames[0], torch.ones_like(warped_alpha), 5
)
show_color_palette(palette)
show_palette_with_locations(selected_frames[0], palette, locations)

In [None]:
from flux_control.datasets.collage.palette import extract_palette_from_masked_image_with_spatial

# Repeat the whole-frame extraction with the "_with_spatial" variant, using the
# same arguments, so its palette and locations can be compared visually with
# the previous cell.
palette, locations = extract_palette_from_masked_image_with_spatial(
    selected_frames[0], torch.ones_like(warped_alpha), 5
)
show_color_palette(palette)
show_palette_with_locations(selected_frames[0], palette, locations)