# LongCat Interactive Inference Notebook

This notebook enables persistent model loading for LongCat video generation. 
Models are loaded once and remain in VRAM, allowing multiple inference runs without reloading.

## Usage
1. **Cell 1 (Setup)**: Run once to import libraries and define helper functions.
2. **Cell 2 (Model Loading)**: Run once to load all models into VRAM (~40GB).
3. **Cell 3 (Inference)**: Run repeatedly with different parameters. Supports batch inference and optional 720p upscaling.

## Data Directory Structure
```
data/scene_name/
├── imgs/              # Input frames and masks (VIDEO_REF_DIR points here)
│   ├── 00000.png      # Video frames
│   ├── 00001.png
│   ├── mask_00000.png # Binary masks (white=inpaint region)
│   └── mask_00001.png
└── ref/               # (Optional) High-res reference frame for 720p upscaling
    └── 00000.png
```

## Notes
- Requires a single GPU with at least 40GB VRAM (e.g., A100/H100/A800).
- First inference may be slower due to CUDA kernel compilation.
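- Modify `CHECKPOINT_DIR` in Cell 2 to point to your model weights directory, and `VIDEO_REF_DIR` / `OUTPUT_BASE_DIR` in Cell 3 to point to your input frames and output location.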


In [None]:
# ============================================================
# Cell 1: Imports & Setup (run once)
# ============================================================

import sys
import os

# Use the current working directory as the project root and make it importable.
# The previous expression, dirname(abspath("__file__")), passed the *string*
# "__file__" rather than the variable, so it always evaluated to the working
# directory as well; this spells that out instead of implying auto-detection.
PROJECT_DIR = os.getcwd()
if PROJECT_DIR not in sys.path:
    # Prepend so the project-local imports further down resolve from here first.
    sys.path.insert(0, PROJECT_DIR)
os.chdir(PROJECT_DIR)

import datetime
import PIL.Image
import numpy as np
import glob
import torch
import torch.distributed as dist
from scipy.ndimage import distance_transform_edt
import torchvision.io as tvio

from transformers import AutoTokenizer, UMT5EncoderModel
from torchvision.io import write_video
from diffusers.utils import load_image

from longcat_video.pipeline_longcat_video import LongCatVideoPipeline
from prompts import get_prompt, list_available_scenes
from longcat_video.modules.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from longcat_video.modules.autoencoder_kl_wan import AutoencoderKLWan
from longcat_video.modules.longcat_video_dit import LongCatVideoTransformer3DModel
from longcat_video.context_parallel import context_parallel_util
from longcat_video.context_parallel.context_parallel_util import init_context_parallel

# Pick the first CUDA device when one is present, otherwise fall back to the CPU,
# and report the choice in the cell output.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    print(f"Using device: {device}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
else:
    device = torch.device("cpu")
    print(f"Using device: {device}")

def torch_gc():
    """Release memory that is no longer needed between inference runs.

    Python's garbage collector runs first so that tensors kept alive only by
    reference cycles are freed before the CUDA caching allocator is asked to
    give back its unused blocks. The CUDA calls are skipped when CUDA is not
    available, because ``torch.cuda.ipc_collect()`` initialises CUDA and raises
    on such machines; this keeps the function safe to call from error handlers.
    """
    import gc

    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

def load_video(video_path):
    """Decode the video at ``video_path`` into a list of PIL images, one per frame."""
    # read_video returns three values; only the first (the frame tensor) is used here.
    frame_tensor, _audio, _info = tvio.read_video(video_path, pts_unit='sec')
    return [PIL.Image.fromarray(frame.numpy()) for frame in frame_tensor]

def read_frames_from_directory(directory):
    """Load an image sequence and its per-frame masks from one directory.

    Images whose file name starts with ``mask_`` are masks and are opened as
    single-channel ('L') images; every other image is a video frame and is
    opened as RGB. Both groups are ordered by sorted file path, and masks are
    matched to frames by position in that order, not by name.

    If the directory holds frames but no masks, all-zero masks are created.
    If the counts differ, the mask list is padded by repeating its last mask
    or truncated so it ends up the same length as the frame list.

    Returns:
        ``(frames, masks, first_frame, first_frame_path)``; the last two are
        ``None`` when the directory contains no frame images.

    Raises:
        ValueError: if no image file of a supported extension is found.
    """
    print(f"Reading frames from: {directory}")

    patterns = ('*.jpg', '*.jpeg', '*.png', '*.bmp', '*.tiff')
    image_paths = sorted(
        path
        for pattern in patterns
        for path in glob.glob(os.path.join(directory, pattern))
    )
    if len(image_paths) == 0:
        raise ValueError(f"No images found in {directory}")

    # Split into frames and masks in a single pass, preserving sorted order.
    frame_files, mask_files = [], []
    for path in image_paths:
        is_mask = os.path.basename(path).startswith('mask_')
        (mask_files if is_mask else frame_files).append(path)

    frames = [PIL.Image.open(path).convert('RGB') for path in frame_files]
    masks = [PIL.Image.open(path).convert('L') for path in mask_files]

    if frames and not masks:
        print("No mask images found, creating zero masks")
        masks = [PIL.Image.new('L', frames[0].size, 0) for _ in frames]

    if len(masks) != len(frames):
        print(f"Warning: mask count ({len(masks)}) != frame count ({len(frames)})")
        shortfall = len(frames) - len(masks)
        if shortfall > 0:
            # masks cannot be empty here: the frames-without-masks case was
            # already filled in above, so repeating the last mask is always possible.
            masks.extend([masks[-1]] * shortfall)
        masks = masks[:len(frames)]

    if frames:
        first_frame, first_frame_path = frames[0], frame_files[0]
    else:
        first_frame, first_frame_path = None, None
    print(f"Loaded {len(frames)} frames and {len(masks)} masks")
    return frames, masks, first_frame, first_frame_path

def soften_mask(mask_array, transition_distance=15, decay_type='sine'):
    """Feather the inner edge of every frame's mask with a distance-based ramp.

    For each frame (indexed along axis 0), nonzero pixels whose Euclidean
    distance to the nearest zero pixel is at most ``transition_distance`` are
    replaced by ``curve(distance / transition_distance)``, where ``curve`` is
    selected by ``decay_type``. All other pixels keep their original value.
    Frames that are entirely zero or entirely nonzero are left untouched.

    Args:
        mask_array: numpy array whose first axis indexes frames.
        transition_distance: width of the softened band, in pixels.
        decay_type: one of 'linear', 'exponential', 'sine', 'cosine'.

    Returns:
        A new float32 array with the same shape as ``mask_array``; the input
        is not modified.

    Raises:
        ValueError: if ``decay_type`` is not supported. This is only detected
            once some frame actually has pixels inside the transition band.
    """
    def ramp(ratio):
        # Map normalised distance in [0, 1] onto the selected decay curve.
        ratio = np.clip(ratio, 0.0, 1.0)
        if decay_type == 'linear':
            return ratio
        if decay_type == 'exponential':
            return 1.0 - np.exp(-3.0 * ratio)
        if decay_type == 'sine':
            return np.sin(np.pi / 2 * ratio)
        if decay_type == 'cosine':
            return 1.0 - np.cos(np.pi / 2 * ratio)
        raise ValueError(f"Unsupported decay type: {decay_type}")

    result = mask_array.astype(np.float32)

    for index, frame in enumerate(mask_array):
        inside = frame.astype(bool)
        if inside.all() or not inside.any():
            # Uniform frame: there is no boundary to soften.
            continue

        # Distance of every masked pixel to the nearest unmasked pixel.
        depth = distance_transform_edt(inside)
        band = inside & (depth <= transition_distance)
        if band.any():
            result[index][band] = ramp(depth[band] / transition_distance)

    return result

# Completion marker: makes it visible in the cell output that every import and
# helper definition above ran without error.
print("Cell 1 complete: environment configured, helper functions defined.")


In [None]:
# ============================================================
# Cell 2: Model Loading (run once)
# ============================================================
# Models will remain in VRAM after this cell completes.

# ==================== Configuration ====================
# Root of the checkpoint. The tokenizer/, text_encoder/, vae/, scheduler/ and
# dit/ subfolders and the lora/*.safetensors files are all loaded from here.
CHECKPOINT_DIR = "/path/to/LongCat-Video"  # Path to model weights
# Passed to init_context_parallel together with world_size=1 (this cell always
# creates a single-process group).
CONTEXT_PARALLEL_SIZE = 1
ENABLE_COMPILE = False    # Enable torch.compile (slower first run)
# When True, cfg_step_lora is enabled after loading; Cell 3 then forces
# guidance_scale=1.0 and writes under the "dt16" output folder instead of "fu50".
USE_DISTILL = False       # Use 16-step distillation mode
# =======================================================

# Performance settings: allow TF32 for matmul and cuDNN, and raise the
# torch.compile recompilation cache limit where that option exists.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('high')
if hasattr(torch._dynamo.config, 'cache_size_limit'):
    torch._dynamo.config.cache_size_limit = 128

# Single GPU setup: one process that is both local and global rank 0.
local_rank = global_rank = 0
num_processes = 1

# This cell cannot do anything useful without a GPU, so fail early.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available.")
torch.cuda.set_device(local_rank)
print(f"Using GPU {local_rank}: {torch.cuda.get_device_name(local_rank)}")

# Bring up a one-process NCCL group, unless a previous run of this cell in the
# same kernel has already done so.
if not dist.is_initialized():
    import socket

    def find_free_port():
        """Ask the OS for a currently unused TCP port and return its number."""
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            probe.bind(('', 0))  # port 0 lets the OS pick a free port
            probe.listen(1)
            _host, port = probe.getsockname()
            return port

    master_port = str(find_free_port())
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = master_port
    os.environ['RANK'] = '0'
    os.environ['LOCAL_RANK'] = '0'
    os.environ['WORLD_SIZE'] = '1'
    dist.init_process_group(
        backend="nccl",
        init_method=f"tcp://localhost:{master_port}",
        rank=0,
        world_size=1,
        timeout=datetime.timedelta(seconds=3600),
    )

# Initialize context parallel and derive the split that the DiT is loaded with.
init_context_parallel(
    context_parallel_size=CONTEXT_PARALLEL_SIZE,
    global_rank=global_rank,
    world_size=num_processes,
)
cp_size = context_parallel_util.get_cp_size()
cp_split_hw = context_parallel_util.get_optimal_split(cp_size)

# Load models. Transformers logging is lowered to ERROR while loading to
# suppress the harmless LOAD REPORT from transformers 5.x about T5 weight-tying.
# The previous verbosity is restored in `finally`, so a failed load (for
# example a wrong CHECKPOINT_DIR) does not leave transformers logging silenced
# for the rest of the notebook session.
print("Loading models...")
import transformers as _tf

_prev_verbosity = _tf.logging.get_verbosity()
_tf.logging.set_verbosity_error()
try:
    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR, subfolder="tokenizer")
    text_encoder = UMT5EncoderModel.from_pretrained(CHECKPOINT_DIR, subfolder="text_encoder", torch_dtype=torch.bfloat16)
    print("Tokenizer + Text encoder loaded")

    vae = AutoencoderKLWan.from_pretrained(CHECKPOINT_DIR, subfolder="vae", torch_dtype=torch.bfloat16)
    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(CHECKPOINT_DIR, subfolder="scheduler")
    print("VAE + Scheduler loaded")

    # cp_split_hw comes from the context-parallel initialisation earlier in this cell.
    dit = LongCatVideoTransformer3DModel.from_pretrained(CHECKPOINT_DIR, subfolder="dit", cp_split_hw=cp_split_hw, torch_dtype=torch.bfloat16)
    print("DiT model loaded")
finally:
    _tf.logging.set_verbosity(_prev_verbosity)

# Pre-load both LoRAs once; later code only toggles which of them is enabled.
cfg_step_lora_path = os.path.join(CHECKPOINT_DIR, 'lora/cfg_step_lora.safetensors')
refinement_lora_path = os.path.join(CHECKPOINT_DIR, 'lora/refinement_lora.safetensors')
dit.load_lora(cfg_step_lora_path, 'cfg_step_lora')
dit.load_lora(refinement_lora_path, 'refinement_lora')
print("LoRAs loaded (cfg_step_lora + refinement_lora)")

# Initial LoRA state for first-stage generation.
if not USE_DISTILL:
    dit.disable_all_loras()
    print("Standard mode: all LoRAs disabled")
else:
    dit.enable_loras(['cfg_step_lora'])
    print("Distill mode: cfg_step_lora enabled")

if ENABLE_COMPILE:
    print("Compiling DiT model...")
    dit = torch.compile(dit)

# Create pipeline and move it to the selected GPU.
pipe = LongCatVideoPipeline(
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    vae=vae,
    scheduler=scheduler,
    dit=dit,
)
pipe.to(local_rank)

print(f"\nModels loaded. VRAM: {torch.cuda.memory_allocated()/1024**3:.2f} GB")


In [None]:
# ============================================================
# Cell 3: Batch Inference (run repeatedly with different params)
# ============================================================

# ==================== Parameters (modify as needed) ====================
# The optional upscaling reference image is read from a `ref/` directory that
# sits next to VIDEO_REF_DIR.
VIDEO_REF_DIR = "/path/to/data/scene_name/imgs"   # Directory with frames + mask_ files
OUTPUT_BASE_DIR = "/path/to/output"               # Output base directory
SCENE_NAME = "coffee"            # Scene name (for prompt lookup, see prompts.py)
RESOLUTION = "480p"             # Inference resolution
NUM_FRAMES = 29                 # Number of frames to generate
STATIC_MODE = False             # Static scene mode (less motion)
FPS = 16                        # Output video FPS
# Used for the first-stage generation only; the upscaling pass always runs 50 steps.
NUM_INFERENCE_STEPS = 50        # Diffusion steps (16 for distill, 50 for standard)

# Parameter grid for batch inference: every combination of the lists below is run.
# NOTE: the guidance scale is not part of the output file name, so several values
# in guidance_scales_list map to the same file and all but the first are skipped
# as already existing. It is also replaced by 1.0 when USE_DISTILL is True.
max_channels_list = [1]         # Max FLF replacement channels
guidance_scales_list = [4]      # CFG scale
omegas_list = [4]               # Auto-guidance omega (used for both omega and omega_resample)
transition_distances_list = [15] # Mask softening distance (0=no softening)
step_additions_list = [0]       # Addition to guide_steps for resample_round
step_guide_list = [20]          # Guide steps
seeds_list = [42]               # Random seeds

# Fixed parameters
RESAMPLE_STEPS = 2
USE_PCA_CHANNEL_SELECTION = True
ENABLE_SOFTEN_MASK = True       # Softening also requires a transition distance > 0
DECAY_TYPE = "sine"
SAVE_PNG = False

# Upscaling parameters. The upscaled size is taken from the reference image
# (rounded down to a model-compatible multiple); upscaling is skipped with a
# warning when no reference image is found.
ENABLE_UPSCALE = True          # Enable 720p upscaling
T_THRESH = 0.6                 # Upscale denoise threshold (0.0-1.0)
# =======================================================================

import itertools

# One tuple per run. The axis order below is the order in which
# run_single_inference unpacks its `params` argument.
param_combinations = [
    *itertools.product(
        guidance_scales_list,
        max_channels_list,
        transition_distances_list,
        step_additions_list,
        omegas_list,
        step_guide_list,
        seeds_list,
    )
]

print(f"Total {len(param_combinations)} parameter combinations")
if ENABLE_UPSCALE:
    print(f"Upscaling: enabled (t_thresh={T_THRESH})")
print("=" * 60)

def read_ref_frame_for_upscale(video_ref_dir):
    """Load the first high-res reference image from the sibling ``ref/`` directory.

    ``ref/`` is looked up next to ``video_ref_dir`` (``<scene>/ref`` beside
    ``<scene>/imgs``). The path is normalised first so that a trailing path
    separator on ``video_ref_dir`` does not make the lookup land inside the
    frames directory itself.

    Args:
        video_ref_dir: directory holding the input frames and masks.

    Returns:
        ``(image, path)`` for the first image in sorted order, converted to
        RGB, or ``(None, None)`` after printing a warning when the directory
        does not exist or contains no supported image.
    """
    # normpath strips any trailing separator; without it dirname("…/imgs/")
    # would return "…/imgs" rather than the scene directory.
    scene_dir = os.path.dirname(os.path.normpath(video_ref_dir))
    ref_dir = os.path.join(scene_dir, "ref")
    if not os.path.exists(ref_dir):
        print(f"Warning: ref directory not found: {ref_dir}")
        return None, None

    ref_files = []
    for ext in ['*.jpg', '*.jpeg', '*.png', '*.bmp', '*.tiff']:
        ref_files.extend(glob.glob(os.path.join(ref_dir, ext)))
    ref_files.sort()

    if not ref_files:
        print(f"Warning: No reference images found in {ref_dir}")
        return None, None

    ref_path = ref_files[0]
    ref_frame = PIL.Image.open(ref_path).convert('RGB')
    print(f"Loaded reference frame: {ref_path} ({ref_frame.size[0]}x{ref_frame.size[1]})")
    return ref_frame, ref_path

def run_single_inference(params):
    """Run one generation, plus the optional upscale, for one parameter tuple.

    Args:
        params: tuple ``(guidance_scale, max_channels, transition_distance,
            step_addition, omega, guide_steps, seed)``.

    Returns:
        True when the outputs already existed or were produced successfully.
        False when any exception occurred; the error and traceback are printed
        instead of being propagated so a batch loop can continue.
    """
    g, mc, m, addition, o, step, seed_val = params
    r_r = step + addition

    # Build output filename.
    # NOTE(review): the guidance scale `g` is not encoded in the name, so runs
    # that differ only in guidance scale share one output file and later ones
    # are skipped by the existence check below.
    distill_suffix = "dt16" if USE_DISTILL else "fu50"
    output_file = f"{OUTPUT_BASE_DIR}/{distill_suffix}/{RESOLUTION}/o{o}_re{RESAMPLE_STEPS}_guide{step}_round{r_r}_mask{m}_max{mc}_seed{seed_val}.mp4"

    # Skip if already exists (with upscaling enabled, both files must exist)
    if os.path.exists(output_file):
        if ENABLE_UPSCALE:
            output_720p_file = f"{os.path.splitext(output_file)[0]}_720p.mp4"
            if os.path.exists(output_720p_file):
                print(f"Skipping (exists): {os.path.basename(output_file)}")
                return True
        else:
            print(f"Skipping (exists): {os.path.basename(output_file)}")
            return True

    print(f"\n{'='*60}")
    print(f"omega={o}, resample={RESAMPLE_STEPS}, guide_steps={step}, round={r_r}, mask={m}, max_replace={mc}, seed={seed_val}")
    print(f"Output: {output_file}")
    print("=" * 60)

    try:
        frames, masks, first_frame, _ = read_frames_from_directory(VIDEO_REF_DIR)

        image = first_frame
        if image is None:
            raise ValueError("No first frame available")

        # Compute target dimensions
        scale_factor_spatial = pipe.vae_scale_factor_spatial * 2
        if pipe.dit.cp_split_hw is not None:
            scale_factor_spatial *= max(pipe.dit.cp_split_hw)
        height, width = pipe.get_condition_shape(image, RESOLUTION, scale_factor_spatial=scale_factor_spatial)

        # Process video frames and masks
        resized_frames = [frame.resize((width, height)) for frame in frames]
        video_frames = torch.stack([
            torch.tensor(np.array(frame)).permute(2, 0, 1).float() / 255.0
            for frame in resized_frames
        ])
        video_ref = video_frames.unsqueeze(0).permute(0, 2, 1, 3, 4)

        mask = None
        if masks:
            resized_masks = [mk.resize((width, height)) for mk in masks]
            mask_array = np.stack([np.array(mk) / 255.0 for mk in resized_masks])
            if ENABLE_SOFTEN_MASK and m > 0:
                mask_array = soften_mask(mask_array, m, DECAY_TYPE)
            mask = torch.from_numpy(mask_array).unsqueeze(0).unsqueeze(0)

        prompt = get_prompt(SCENE_NAME)
        negative_prompt = (
            "Blink, twinkle, waggle, speak, wind, windy, leaves shaking, leaves tremble, "
            "background dynamics, dynamic imagery, gray sky, hazy sky, overcast, "
            "gloomy sky, dim, murky, smoggy, shake, object motion blur, streaking objects, "
            "object jitter, camera shake, illogical composition, bright tones, "
            "overexposed, blurred details, subtitles, text, logo, worst quality, "
            "low quality, ugly, incomplete, sudden scene shift, incoherent scene jump, "
            "extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, "
            "misshapen limbs, fused fingers, any movement, character motion, "
            "object vibration, messy background, scene changes, object disintegration."
        ) if STATIC_MODE else (
            "Streaking objects, mosaic, grainy, pixelated, noise, flickering, cropped, glitch, "
            "fragmented, broken, artifacts, chromatic aberration, camera shake, "
            "blurry, sudden scene shift, incoherent scene jump, sudden object appearance, "
            "blinking, object jitter, illogical composition, bright tones, overexposed, "
            "blurred details, subtitles, worst quality, low quality, "
            "ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, "
            "deformed, disfigured, misshapen limbs, fused fingers, messy background"
        )

        generator = torch.Generator(device='cpu')
        generator.manual_seed(seed_val)
        os.makedirs(os.path.dirname(output_file), exist_ok=True)

        # Run first-stage inference at RESOLUTION
        output = pipe.generate_i2v(
            image=image,
            prompt=prompt,
            negative_prompt=negative_prompt,
            resolution=RESOLUTION,
            num_frames=NUM_FRAMES,
            num_inference_steps=NUM_INFERENCE_STEPS,
            use_distill=USE_DISTILL,
            guidance_scale=1.0 if USE_DISTILL else g,
            num_videos_per_prompt=1,
            generator=generator,
            video_ref=video_ref,
            mask=mask,
            guided=True,
            resample_steps=RESAMPLE_STEPS,
            guide_steps=step,
            resample_round=r_r,
            omega=o,
            omega_resample=o,
            use_pca_channel_selection=USE_PCA_CHANNEL_SELECTION,
            static=STATIC_MODE,
            max_replace_threshold=mc,
        )[0]

        # Save first-stage video
        output_pil = [PIL.Image.fromarray((output[i] * 255).astype(np.uint8)) for i in range(output.shape[0])]
        output_tensor = torch.from_numpy(np.array([np.array(img) for img in output_pil]))
        write_video(output_file, output_tensor, fps=FPS, video_codec="libx264", options={"crf": "10"})

        if SAVE_PNG:
            png_dir = os.path.join(os.path.dirname(output_file), f"imgs_{os.path.splitext(os.path.basename(output_file))[0]}")
            os.makedirs(png_dir, exist_ok=True)
            for i, frame in enumerate(output_pil):
                frame.save(os.path.join(png_dir, f"{i:05d}.png"))

        print(f"480p done: {os.path.basename(output_file)}")

        # Optional upscaling
        if ENABLE_UPSCALE:
            print("Starting 720p upscaling...")
            output_720p_file = f"{os.path.splitext(output_file)[0]}_720p.mp4"

            ref_frame, _ = read_ref_frame_for_upscale(VIDEO_REF_DIR)
            if ref_frame is None:
                print("Warning: no reference frame found, skipping upscale")
            else:
                # The pipeline is shared by every run in the notebook session,
                # so the LoRA/BSA switch is undone in `finally`: a failure
                # during refinement must not leave the refinement LoRA active
                # for the next first-stage generation.
                try:
                    # Switch to refinement LoRA
                    pipe.dit.disable_all_loras()
                    pipe.dit.enable_loras(['refinement_lora'])
                    pipe.dit.enable_bsa()

                    # The first-stage result is read back from the file just written.
                    video_frames_480p = load_video(output_file)
                    print(f"Loaded {len(video_frames_480p)} frames for upscaling")

                    # Compute model-compatible target resolution: round the
                    # reference size down to a multiple of the spatial factor.
                    ref_width, ref_height = ref_frame.size
                    sf = pipe.vae_scale_factor_spatial * 2
                    if pipe.dit.cp_split_hw is not None:
                        sf *= max(pipe.dit.cp_split_hw)
                    target_h = (ref_height // sf) * sf
                    target_w = (ref_width // sf) * sf
                    ref_frame_resized = ref_frame.resize((target_w, target_h))
                    print(f"Upscale target: {target_w}x{target_h}")

                    generator_upscale = torch.Generator(device='cpu')
                    generator_upscale.manual_seed(seed_val)

                    output_refine = pipe.generate_refine(
                        image=ref_frame_resized,
                        prompt=prompt,
                        stage1_video=video_frames_480p,
                        num_cond_frames=1,
                        num_inference_steps=50,
                        generator=generator_upscale,
                        spatial_refine_only=True,
                        t_thresh=T_THRESH,
                    )[0]
                finally:
                    # Restore LoRA state
                    pipe.dit.disable_all_loras()
                    pipe.dit.disable_bsa()
                    if USE_DISTILL:
                        pipe.dit.enable_loras(['cfg_step_lora'])

                # Save upscaled video
                output_720p_frames = [(output_refine[i] * 255).astype(np.uint8) for i in range(output_refine.shape[0])]
                output_720p_tensor = torch.from_numpy(np.array(output_720p_frames))
                write_video(output_720p_file, output_720p_tensor, fps=FPS, video_codec="libx264", options={"crf": "10"})

                if SAVE_PNG:
                    png_720p_dir = os.path.join(os.path.dirname(output_720p_file), f"imgs_{os.path.splitext(os.path.basename(output_720p_file))[0]}")
                    os.makedirs(png_720p_dir, exist_ok=True)
                    for i, frame_data in enumerate(output_720p_frames):
                        PIL.Image.fromarray(frame_data).save(os.path.join(png_720p_dir, f"{i:05d}.png"))

                print(f"720p done: {os.path.basename(output_720p_file)}")

        torch_gc()
        return True

    except Exception as e:
        # Batch boundary: report the failure and let the caller move on to the
        # next parameter combination.
        import traceback
        print(f"Error: {e}")
        traceback.print_exc()
        torch_gc()
        return False

# Run every parameter combination in turn and tally the outcomes.
# run_single_inference reports its own errors and returns False on failure,
# so one bad combination does not stop the batch.
print("\nStarting batch inference...")
success_count = 0
fail_count = 0

for idx, params in enumerate(param_combinations):
    print(f"\n[{idx+1}/{len(param_combinations)}] ", end="")
    if not run_single_inference(params):
        fail_count += 1
        continue
    success_count += 1

print(f"\nBatch complete: {success_count} success, {fail_count} failed. VRAM: {torch.cuda.memory_allocated()/1024**3:.2f} GB")
