# Relipa grid generation

In [24]:
import numpy as np
import os, glob
import matplotlib.pyplot as plt
from PIL import Image
import json

mani_light = 'rotate_sh_axis=1'
num_frames = 60
sample_pair_json = '/home/mint/Dev/DiFaReli++/difareli_pp/experiment_scripts/TPAMI/sample_json/TPAMI_MajorRevision/rotateSH.json'

# Read the "pair" mapping from the sample JSON. Its keys and values are also
# kept as two parallel lists so the loops below can address pairs by position
# (each value is later indexed with ['src'] and ['dst']).
with open(sample_pair_json, 'r') as f:
    sample_pairs = json.load(f)['pair']
sample_pairs_k = list(sample_pairs.keys())
sample_pairs_v = list(sample_pairs.values())


def gen(fid, src, dst, fn, root='/data/mint/TPAMI_MajorRevision/New_Baselines_Tuning/Relipa/ffhq_tuning'):
    """Collect the path of frame ``fid`` for every (scale_sh, gs) tuning setting.

    Rows iterate ``scale_sh`` over 0.5..1.0 and columns iterate the guidance
    scale ``gs`` over (2.0, 4.0, 4.5, 6.0, 8.0). The sampling directory is
    located under ``root`` using the module-level ``mani_light`` and
    ``num_frames`` globals.

    Args:
        fid: Index (negative allowed) into the numerically sorted
            ``{fn}_frame*.png`` sequence of each setting.
        src, dst: Source / destination identifiers of the pair.
        fn: File prefix of the frames, e.g. ``'res'`` or ``'ren'``.
        root: Base directory of the Relipa tuning outputs.

    Returns:
        list[list[str | None]]: A 6 x 5 grid of frame paths. An entry is
        ``None`` when the setting has no frames or has too few frames for ``fid``.
    """
    import re

    def _natural_key(path):
        # re.split with a capturing group alternates text / digit runs, so odd
        # positions are always digit runs; compare them as ints so 'frame2'
        # sorts before 'frame10' even without zero-padding.
        parts = re.split(r'(\d+)', os.path.basename(path))
        return [int(p) if i % 2 else p for i, p in enumerate(parts)]

    rows = []
    for scale_sh in [0.5, 0.6, 0.7, 0.8, 0.9, 1.0]:
        cols = []
        for gs in [2.0, 4.0, 4.5, 6.0, 8.0]:
            sampling_path = f'{root}/{mani_light}/src={src}_dst={dst}/scale_sh={scale_sh}/gs={gs}_ds=25/n_step={num_frames}'
            imgs = sorted(glob.glob(f'{sampling_path}/{fn}_frame*.png'), key=_natural_key)
            if not imgs:
                frame = None
            elif -len(imgs) <= fid < len(imgs):
                frame = imgs[fid]
            else:
                # Too few frames for this setting: render it as a missing tile
                # instead of aborting the whole grid with IndexError.
                print(f'[gen] frame {fid} not available ({len(imgs)} frames) in {sampling_path}')
                frame = None
            cols.append(frame)
        rows.append(cols)
    return rows

# Frame indices to export per pair; frame_id[idx] belongs to sample_pairs_v[idx].
frame_id = [[3, 20, 33, 50, 55], [2, 13, 38, 48, 54]]
TILE_SIZE = 256


def _load_tile(path, size=TILE_SIZE):
    """Return one grid tile as an RGB uint8 array of shape (size, size, 3)."""
    if path is None:
        # Missing tuning output -> black placeholder keeps the grid rectangular.
        return np.array(Image.new('RGB', (size, size), color='black'))
    # The context manager closes the file handle. convert('RGB') forces every
    # tile to 3 channels so np.concatenate cannot fail on RGBA/L/P inputs.
    with Image.open(path) as im:
        return np.array(im.convert('RGB').resize((size, size), Image.LANCZOS))


def _assemble_grid(path_grid, size=TILE_SIZE):
    """Stitch a 2-D list of image paths (or None) into one array, row by row."""
    return np.concatenate(
        [np.concatenate([_load_tile(p, size) for p in row], axis=1) for row in path_grid],
        axis=0,
    )


for idx in [0, 1]:
    pair = sample_pairs_v[idx]
    src = pair['src']
    dst = pair['dst']
    out_dir = f'./tuning_grid_relipa/{mani_light}_grid_{src}/'
    os.makedirs(out_dir, exist_ok=True)

    for fn in ['res', 'ren']:
        for fi in frame_id[idx]:
            # One PNG per (prefix, frame): rows = scale_sh, cols = guidance scale.
            out_grid = _assemble_grid(gen(fi, src, dst, fn))
            Image.fromarray(out_grid).save(f'{out_dir}{fn}_fid={fi}.png')

# Videos

In [23]:
import numpy as np
import os, glob
import matplotlib.pyplot as plt
from PIL import Image
import json
import subprocess

sample_pair_json = '/home/mint/Dev/DiFaReli++/difareli_pp/experiment_scripts/TPAMI/sample_json/TPAMI_MajorRevision/rotateSH.json'
mani_light = 'rotate_sh_axis=1'
num_frames = 60

# Load the "pair" mapping; expose its keys and values as parallel lists so
# pairs can be picked by position in the loop at the end of this cell.
with open(sample_pair_json, 'r') as f:
    sample_pairs = json.load(f)['pair']
sample_pairs_k, sample_pairs_v = list(sample_pairs), list(sample_pairs.values())


def gen(src, dst, fn):
    """Return a 6 x 5 grid of ``{fn}_rt.mp4`` paths, one per tuning setting.

    Rows follow ``scale_sh`` (0.5 .. 1.0) and columns follow the guidance scale
    ``gs``. The directory layout depends on the module-level ``mani_light`` and
    ``num_frames``. Paths are only built here; their existence is not checked.
    """
    def _video_path(scale_sh, gs):
        sampling_path = f'/data/mint/TPAMI_MajorRevision/New_Baselines_Tuning/Relipa/ffhq_tuning/{mani_light}/src={src}_dst={dst}/scale_sh={scale_sh}/gs={gs}_ds=25/n_step={num_frames}/'
        return f'{sampling_path}/{fn}_rt.mp4'

    scale_sh_values = [0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
    gs_values = [2.0, 4.0, 4.5, 6.0, 8.0]
    return [[_video_path(s, g) for g in gs_values] for s in scale_sh_values]

def _grid_filter(n_inputs, stack, out_label, size=None):
    """Build an ffmpeg filter graph that (optionally) scales and stacks n_inputs streams.

    ``stack`` is 'hstack' or 'vstack'. When ``size`` is given, each input is first
    scaled to exactly ``size x size``. ffmpeg's hstack/vstack accept only >= 2
    inputs, so a single stream is routed through the ``null`` pass-through filter.
    """
    if size is None:
        labels = [f"[{k}:v]" for k in range(n_inputs)]
        parts = []
    else:
        labels = [f"[v{k}]" for k in range(n_inputs)]
        parts = [f"[{k}:v]scale={size}:{size}{labels[k]}" for k in range(n_inputs)]
    if n_inputs == 1:
        parts.append(f"{labels[0]}null{out_label}")
    else:
        parts.append("".join(labels) + f"{stack}=inputs={n_inputs}:shortest=1{out_label}")
    return ";".join(parts)


def _ffmpeg_encode(inputs, fgraph, out_label, out_path, fps, ff_quiet):
    """Run one ffmpeg call: read inputs, apply fgraph, encode out_label as H.264/yuv420p."""
    cmd = ["ffmpeg", "-y"]
    for p in inputs:
        cmd += ["-i", p]
    cmd += ff_quiet
    cmd += ["-filter_complex", fgraph, "-map", out_label]
    if fps is not None:
        cmd += ["-r", str(fps)]
    cmd += ["-c:v", "libx264", "-pix_fmt", "yuv420p", "-movflags", "+faststart", out_path]
    subprocess.run(cmd, check=True)


def make_grid_from_2d(video_grid, out_path, fps=None, quiet=True, size=256, keep_tmp=False):
    """
    Build a grid video from a 2D array of video paths using ffmpeg.

    - video_grid: list[list[str]]  (rectangular: all rows same length)
    - Every input is resized to exactly `size x size` (no AR preservation).
    - Output: H.264 mp4 with yuv420p at `out_path` (its directory is created).
    - If `fps` is provided, the output frame rate is forced to `fps`.
    - If `quiet` is True, ffmpeg logs are minimized.
    - Each row is first encoded into a fresh temporary directory next to
      `out_path`; that directory is deleted afterwards unless `keep_tmp=True`.
    - 1-row and 1-column grids are supported.

    Raises:
        ValueError: `video_grid` is empty, has an empty row, or is ragged.
        FileNotFoundError: an input path does not exist.
        subprocess.CalledProcessError: an ffmpeg invocation fails.
    """
    import shutil
    import tempfile

    # Explicit raises instead of assert: asserts vanish under `python -O`.
    if not video_grid or not all(video_grid):
        raise ValueError("video_grid must be non-empty 2D list")
    rows = len(video_grid)
    cols = len(video_grid[0])
    if any(len(r) != cols for r in video_grid):
        raise ValueError("all rows must have the same length")

    # Validate inputs exist before spawning any ffmpeg process.
    for r in video_grid:
        for p in r:
            if not os.path.exists(p):
                raise FileNotFoundError(f"Input not found: {p}")

    out_path = str(out_path)
    out_dir = os.path.dirname(out_path) or "."
    os.makedirs(out_dir, exist_ok=True)
    ff_quiet = ["-hide_banner", "-loglevel", "error"] if quiet else []

    # A unique temp dir per call avoids clobbering row files of concurrent runs.
    tmpdir = tempfile.mkdtemp(prefix=".grid_tmp_", dir=out_dir)
    try:
        # 1) Scale every cell and hstack each row into its own intermediate video.
        row_files = []
        for i, row in enumerate(video_grid):
            row_out = os.path.join(tmpdir, f"r_{i}.mp4")
            _ffmpeg_encode(row, _grid_filter(cols, "hstack", "[row]", size=size),
                           "[row]", row_out, fps, ff_quiet)
            row_files.append(row_out)

        # 2) All rows share width cols*size and height size, so vstack them directly.
        _ffmpeg_encode(row_files, _grid_filter(rows, "vstack", "[out]"),
                       "[out]", out_path, fps, ff_quiet)
    finally:
        if keep_tmp:
            print(f"Row videos kept in {tmpdir}")
        else:
            shutil.rmtree(tmpdir, ignore_errors=True)

for idx in [0, 1]:
    pair = sample_pairs_v[idx]
    src = pair['src']
    dst = pair['dst']
    out_dir = f'./tuning_grid_relipa/{mani_light}_grid_{src}/'
    os.makedirs(out_dir, exist_ok=True)

    for fn in ['res', 'ren']:
        # Stitch the 6 x 5 (scale_sh x gs) tuning videos into one grid video;
        # make_grid_from_2d encodes the intermediate row videos in a temporary folder.
        make_grid_from_2d(
            gen(src, dst, fn),
            out_path=f'{out_dir}{fn}_{src}.mp4',
            fps=24,      # forces 24 fps output; pass None to let ffmpeg derive it from the inputs
            quiet=True,  # hides ffmpeg output
        )