In [None]:
# Run from the repository root so the relative paths used below (e.g. "data/...")
# resolve, and presumably so the local `flux_control` / `data` packages import.
# NOTE(review): absolute, environment-specific path — adjust outside this container.
%cd /workspaces/torch-basics/
# Reload edited project modules automatically before each cell runs
# (the extension must be loaded before `%autoreload` is configured).
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import torchvision
import numpy as np
import cv2
import kornia

from flux_control.utils.describe import describe
from data.visualize import (
    visualize_flow,
    visualize_image,
    visualize_grayscale,
    visualize_grid,
    show_grayscale_colorbar,
    show_image_histogram
)

# Every model below is loaded onto this device: the GPU when CUDA is visible, else the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
from flux_control.datasets.collage.config import CollageConfig

# Collage-generation settings shared by every cell that takes `cfg=cfg`.
# All arguments are keywords, so their order here does not matter.
# NOTE(review): the chance_* values look like probabilities and the transform_*
# sizes like morphology radii where 0 disables the step — confirm in CollageConfig.
cfg = CollageConfig(
    chance_split_stem=0.7,
    chance_keep_stem=0.3,
    chance_keep_leaf=1.0,
    transform_dilate_size=0,
    transform_erode_size=0,
)

In [None]:
from flux_control.datasets.collage.video import load_video

# Clip used throughout this notebook; the path is relative to the working
# directory set by the %cd cell at the top.
video_path = "data/panda-ours/5pk7860iymE_5_0to124.mp4"
video = load_video(video_path)
# Summarise the loaded clip with the project's `describe` helper.
describe(video)

In [None]:
from flux_control.datasets.collage.flow import load_raft_model, unload_raft_model
from flux_control.datasets.collage.video import try_extract_frame

# Pick a (src, tgt) frame pair from the clip together with the flow between them.
# The RAFT model is only needed for this step, so it is loaded just for this call.
load_raft_model(device)
try:
    result = try_extract_frame(video, device=device, cfg=cfg)
finally:
    # Unload even if extraction raises, so a failed attempt does not leave the
    # RAFT model loaded in the kernel.
    unload_raft_model()
# try_extract_frame signals failure by returning None; check only after unloading.
assert result is not None, "Failed to extract frame"
flow, src, tgt = result

In [None]:
# Summarise the (flow, src, tgt) tuple, then preview source frame, target frame and flow.
describe(result)
display(visualize_image(src))
display(visualize_image(tgt))
display(visualize_flow(flow))

In [None]:
from flux_control.datasets.collage.rmbg import load_rmbg_model, unload_rmbg_model, estimate_foreground

# Estimate the foreground of the target frame (presumably with a background-removal
# model, given the `rmbg` name). The model is only loaded for this call.
load_rmbg_model(device)
try:
    foreground = estimate_foreground(tgt)
finally:
    # Release the model even if estimation raises.
    unload_rmbg_model()

describe(foreground)
display(visualize_grayscale(foreground))

In [None]:
from flux_control.datasets.collage.warp import forward_warp

# Forward-warp (splat) the source frame along `flow`. Returns the warped image,
# a grid and a mask, which are compared against the affine collage later.
splat, grid_splat, mask_splat = forward_warp(src, tgt, flow)

# Preview: warped image, grid (shown together with its mask), then the mask alone.
display(visualize_image(splat))
display(visualize_grid(grid_splat, mask_splat))
display(visualize_grayscale(mask_splat))

In [None]:
from flux_control.datasets.collage.segmentation import load_segmentation_model, unload_segmentation_model, generate_masks

# Segment the source frame into masks; the segmentation model is only loaded
# for this call.
load_segmentation_model(device)
try:
    # NOTE(review): pack_result=True presumably selects a packed/batched return
    # format rather than a per-mask list — confirm in generate_masks.
    masks = generate_masks(src, pack_result=True)
finally:
    # Release the model even if mask generation raises.
    unload_segmentation_model()

describe(masks)

In [None]:
from flux_control.datasets.collage.depth import load_depth_model, unload_depth_model, estimate_depth

# Depth estimate for the source frame; the depth model is only loaded for this call.
load_depth_model(device)
try:
    depth = estimate_depth(src)
finally:
    # Release the model even if estimation raises.
    unload_depth_model()

show_grayscale_colorbar(depth)

In [None]:
from flux_control.datasets.collage.affine import compute_transform_data_structured

# Split the segmentation masks into a "selected" set, which is passed to
# apply_transforms in the next cell, and a "dropped" set. Flow, depth and cfg
# are all inputs to the decision.
# NOTE(review): the actual selection criteria live in compute_transform_data_structured.
selected_masks, dropped_masks = compute_transform_data_structured(
    flow, depth, masks, cfg=cfg
)

# Preview at most 5 items of each collection.
describe(selected_masks, max_items=5)
describe(dropped_masks, max_items=5)

In [None]:
from flux_control.datasets.collage.affine import apply_transforms

# Render the coarse collage by applying the transforms of the selected masks to
# `src` (depth is passed in as well). The outputs are the rendered image, its
# grid, and two masks named for the source and target side.
(affine, grid_affine,
 mask_affine_src, mask_affine_tgt) = apply_transforms(src, depth, selected_masks, cfg=cfg)

# Preview: collage image, grid (with the target-side mask), then both masks.
display(visualize_image(affine))
display(visualize_grid(grid_affine, mask_affine_tgt))
display(visualize_grayscale(mask_affine_src))
display(visualize_grayscale(mask_affine_tgt))

In [None]:
from einops import rearrange
import torch.nn.functional as F

# Per-pixel disagreement between the affine-collage grid and the forward-warp
# (splat) grid: L2 norm over dim 0, squashed into [0, 1] with tanh.
DIFF_SCALE = 10  # tanh slope: a grid difference of 0.1 already maps to ~0.76
POOL_SIZE = 16   # side of the square cells the per-pixel map is averaged over

grid_diff = torch.norm(grid_affine - grid_splat, dim=0, p=2)
grid_diff = torch.tanh(grid_diff * DIFF_SCALE)
# Pixels not covered by both warps have no meaningful comparison; mark them as
# maximal disagreement (1.0). torch.where is used instead of multiplying by the
# mask so that non-finite grid values in masked-out pixels cannot leak through
# (NaN * 0 is still NaN).
mask_bool = (mask_splat > 0.5) & (mask_affine_tgt > 0.5)
grid_diff = torch.where(mask_bool, grid_diff, torch.ones_like(grid_diff))
# Average over POOL_SIZE x POOL_SIZE cells -> shape (1, 1, H // POOL_SIZE, W // POOL_SIZE).
# NOTE(review): despite its name, `confidence` is HIGH where the two warps disagree
# (or are invalid), i.e. it behaves like an error map — confirm the intended polarity.
confidence = F.avg_pool2d(
    rearrange(grid_diff, "h w -> 1 1 h w"), kernel_size=POOL_SIZE, stride=POOL_SIZE
)
show_grayscale_colorbar(confidence, interpolation="nearest")
show_image_histogram(confidence, show_cdf=True)

In [None]:
import kornia.filters as KF

# Hint region: pixels whose warp disagreement (grid_diff from the previous cell)
# exceeds 0.5, low-pass filtered and re-thresholded at 0.5. This drops islands
# and fills holes that are small relative to the blur.
# NOTE(review): a 51-px kernel at sigma=20 truncates the Gaussian at ~1.25 sigma
# (a 3-sigma kernel would be ~121 px); confirm this is intended.
kernel_size = 51
sigma = 20.0
gaussian_blur = KF.GaussianBlur2d(
    kernel_size=(kernel_size,) * 2,
    sigma=(sigma,) * 2,
    border_type="reflect",
).to(device)
# Result is a bool tensor of shape (1, 1, H, W).
hint_mask = gaussian_blur(rearrange(grid_diff > 0.5, "h w -> 1 1 h w").float()) > 0.5
display(visualize_grayscale(hint_mask))

In [None]:
from flux_control.datasets.collage.palette import palette_downsample

# Colour hint: palette-reduced (colors=4) versions of `tgt`, computed separately
# for the foreground-weighted and background-weighted parts of `hint_mask`, then summed.
# NOTE(review): assumes palette_downsample leaves pixels outside its mask at zero,
# so the two parts combine cleanly when added — confirm.
hint_fg = palette_downsample(tgt, hint_mask * foreground, colors=4)
hint_bg = palette_downsample(tgt, hint_mask * (1 - foreground), colors=4)
hint = hint_fg + hint_bg
display(visualize_image(hint))
# Preview of the composite, blended by mask_affine_tgt: the affine collage where the
# mask is 1 and the colour hint where it is 0.
display(visualize_image(affine * mask_affine_tgt + hint * (1 - mask_affine_tgt)))

In [None]:
# Summarise the final hint mask (a bool tensor after the blur-and-threshold step).
describe(hint_mask)

In [None]:
# Collect the per-sample tensors that feed the model inputs listed in the notes
# below. Nothing is written to disk in this cell; `describe` only summarises the dict.
save_data = {
    "video_path": video_path,
    "src": src,
    "tgt": tgt,
    "flow": flow,
    "affine": affine,                    # coarse edit
    "mask_affine_tgt": mask_affine_tgt,  # coarse edit alpha mask
    "foreground": foreground,
    "confidence": confidence,            # local confidence map
    # Colour hint image: listed as a model input ("Color Hint Latent") in the notes,
    # so it has to be kept alongside the coarse edit.
    "hint": hint,
}

describe(save_data)

## Model inputs

1. Noisy latent
2. Text prompt embedding
3. Coarse edit latent (F.1 Fill, D-Concat)
4. Coarse edit alpha mask (F.1 Fill; pixel-shuffled, then concatenated)
5. Local confidence map (requires training a new modulation module)
6. Color hint latent (F.1 Fill, D-Concat)