In [None]:
!pip install accelerate transformers diffusers torch torchvision pillow tqdm



In [None]:
import torch
import torch.nn as nn
from diffusers import StableVideoDiffusionPipeline, AutoencoderKL, ControlNetModel, DDPMScheduler
from diffusers.models import UNetSpatioTemporalConditionModel
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
from torch.optim import AdamW
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
import cv2
from pathlib import Path
from tqdm import tqdm
import gc
from torchvision import transforms
from accelerate import Accelerator
from typing import Dict, Any

In [None]:
# ----------------- 1. Configuration and Setup -----------------

MODEL_ID = "stabilityai/stable-video-diffusion-img2vid-xt"
DATASET_ROOT = "/content/data/rovi-reduc"  # Fixed: Consistent path
RESOLUTION = 400
BATCH_SIZE = 1
NUM_FRAMES = 8
NUM_EPOCHS = 10
GRADIENT_ACCUMULATION_STEPS = 2
CHECKPOINT_SAVE_EPOCHS = 2

In [None]:
# Initialize Accelerator
accelerator = Accelerator(
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    mixed_precision='fp16'
)
device = accelerator.device

print(f"Using device: {device}")

Using device: cuda


In [None]:
# ----------------- 2. Dataset Implementation -----------------

class RoVITrackingInpaintingDataset(Dataset):
    """
    Video tracking + inpainting dataset.
    Input: Original video + First frame mask only
    Target: Edited video + All frame masks
    """
    def __init__(self, dataset_path, sequence_length=8, image_size=(256, 256)):
        self.dataset_path = Path(dataset_path)
        self.jpeg_dir = self.dataset_path / 'JPEGImages'
        self.inpaint_dir = self.dataset_path / 'InpaintImages'
        self.annotations_dir = self.dataset_path / 'Annotations'
        self.sequence_length = sequence_length
        self.image_size = image_size

        if not self.jpeg_dir.exists():
            raise ValueError(f"JPEGImages not found")
        if not self.inpaint_dir.exists():
            raise ValueError(f"InpaintImages not found")
        if not self.annotations_dir.exists():
            raise ValueError(f"Annotations not found")

        all_sequences = sorted([p.name for p in self.jpeg_dir.iterdir() if p.is_dir()])
        self.sequences = [
            seq for seq in all_sequences
            if (self.inpaint_dir / seq / '1').is_dir()
            and (self.annotations_dir / seq).is_dir()
        ]

        if len(self.sequences) == 0:
            raise ValueError(f"No valid sequences found")

        print(f"Found {len(self.sequences)} sequences for tracking + inpainting")

        self.transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
        ])

        self.mask_transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        video_name = self.sequences[idx]

        original_frames_path = self.jpeg_dir / video_name
        edited_frames_path = self.inpaint_dir / video_name / '1'
        mask_frames_path = self.annotations_dir / video_name

        frame_files = sorted([f.name for f in original_frames_path.glob('*.jpg')])
        frame_files = frame_files[:self.sequence_length]

        if len(frame_files) < self.sequence_length:
            frame_files = frame_files + [frame_files[-1]] * (self.sequence_length - len(frame_files))

        original_video = []
        edited_video = []
        mask_video = []

        for filename in frame_files:
            # ORIGINAL (input)
            original_img = Image.open(original_frames_path / filename).convert("RGB")
            original_video.append(self.transform(original_img))

            # EDITED (target for inpainting)
            edited_img = Image.open(edited_frames_path / filename).convert("RGB")
            edited_video.append(self.transform(edited_img))

            # MASK (target for tracking)
            mask_filename = filename.replace('.jpg', '.png')
            try:
                mask_img = Image.open(mask_frames_path / mask_filename).convert("L")
            except FileNotFoundError:
                mask_img = Image.open(mask_frames_path / filename).convert("L")

            mask_tensor = self.mask_transform(mask_img)
            mask_video.append((mask_tensor > 0.1).float())

        # Stack and permute to (C, T, H, W)
        original_video = torch.stack(original_video).permute(1, 0, 2, 3)
        edited_video = torch.stack(edited_video).permute(1, 0, 2, 3)
        mask_video = torch.stack(mask_video).permute(1, 0, 2, 3)

        # First frames
        original_first_frame = original_video[:, 0, :, :]
        first_mask = mask_video[:, 0:1, :, :]  # Only first mask (1, H, W)

        return {
            "original_video": original_video,
            "original_first_frame": original_first_frame,
            "first_mask": first_mask,           # INPUT: only first mask
            "edited_video": edited_video,        # TARGET: inpainted video
            "mask_video": mask_video,            # TARGET: all masks for tracking
        }

In [None]:
# ----------------- 3. Model Components -----------------

class SelectiveContentEncoder(nn.Module):
    """
    SCE: Encodes first frame mask into features.
    Input: First frame mask (B, 1, H, W)
    Output: Encoded features (B, feature_dim)
    """
    def __init__(self, feature_dim=256):
        super().__init__()

        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(128, feature_dim),
        )
        print(f"SCE initialized with {feature_dim} feature dimensions")

    def forward(self, first_mask):
        """first_mask: (B, 1, H, W) -> features: (B, feature_dim)"""
        return self.encoder(first_mask)

class ControlAdapter3D(nn.Module):
    """
    Control Adapter: Processes mask video for inpainting guidance.
    Input: Mask video (B, 1, T, H, W)
    Output: Control features (B, 4, T, H', W')
    """
    def __init__(self, in_channels=1, base_channels=32, latent_channels=4):
        super().__init__()

        self.conv_in = nn.Conv3d(in_channels, base_channels, kernel_size=3, padding=1)
        self.to_latent = nn.Conv3d(base_channels, latent_channels, kernel_size=1)
        self.zero_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1)

        nn.init.zeros_(self.zero_conv.weight)
        nn.init.zeros_(self.zero_conv.bias)

        print(f"ControlAdapter3D: {base_channels} base -> {latent_channels} latent channels")

    def forward(self, mask_video):
        h = self.conv_in(mask_video)
        h = self.to_latent(h)
        h = self.zero_conv(h)
        return h

class MaskPredictionDecoder(nn.Module):
    """
    MPD: Predicts masks for all frames from latent features.
    Input: Latent features (B, 4, T, H, W) + SCE features
    Output: Predicted masks (B, 1, T, H, W)
    """
    def __init__(self, latent_channels=4, sce_feature_dim=256):
        super().__init__()

        # Project SCE features to spatial
        self.sce_projection = nn.Sequential(
            nn.Linear(sce_feature_dim, 32 * 32),
            nn.Unflatten(1, (1, 32, 32)),
        )

        # Combine latents + SCE features -> masks
        self.decoder = nn.Sequential(
            nn.Conv3d(latent_channels + 1, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(32, 16, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(16, 1, kernel_size=1),
            nn.Sigmoid(),  # Output binary masks
        )
        print("MPD initialized")

    def forward(self, latent_features, sce_features, num_frames):
        """
        latent_features: (B, 4, T, H, W)
        sce_features: (B, feature_dim)
        -> predicted_masks: (B, 1, T, H, W)
        """
        B = latent_features.shape[0]

        # Project SCE to spatial and expand to all frames
        sce_spatial = self.sce_projection(sce_features)  # (B, 1, H, W)
        sce_spatial = F.interpolate(sce_spatial, size=latent_features.shape[-2:], mode='bilinear')
        sce_spatial = sce_spatial.unsqueeze(2).repeat(1, 1, num_frames, 1, 1)  # (B, 1, T, H, W)

        # Concatenate and decode
        combined = torch.cat([latent_features, sce_spatial], dim=1)  # (B, 5, T, H, W)
        masks = self.decoder(combined)  # (B, 1, T, H, W)

        return masks

In [None]:
# ----------------- 4. Model Loading -----------------

print("Loading models...")

try:
    # Load models directly to GPU with low_cpu_mem_usage=False to avoid meta device
    vae = AutoencoderKL.from_pretrained(
        MODEL_ID,
        subfolder="vae",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=False
    ).to(device)

    unet = UNetSpatioTemporalConditionModel.from_pretrained(
        MODEL_ID,
        subfolder="unet",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=False
    ).to(device)

    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
        MODEL_ID,
        subfolder="image_encoder",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=False
    ).to(device)

    print("Models loaded successfully!")

except Exception as e:
    print(f"Error loading models: {e}")
    raise


Loading models...


Some weights of the model checkpoint at stabilityai/stable-video-diffusion-img2vid-xt were not used when initializing AutoencoderKL: 
 ['decoder.up_blocks.1.resnets.2.temporal_res_block.norm2.bias, decoder.up_blocks.1.resnets.0.temporal_res_block.norm2.bias, decoder.up_blocks.2.resnets.0.temporal_res_block.norm1.bias, decoder.mid_block.resnets.0.spatial_res_block.conv2.weight, decoder.up_blocks.2.resnets.2.temporal_res_block.conv2.bias, decoder.up_blocks.1.resnets.1.temporal_res_block.norm1.weight, decoder.up_blocks.3.resnets.1.spatial_res_block.conv2.weight, decoder.up_blocks.2.resnets.1.spatial_res_block.conv1.bias, decoder.up_blocks.1.resnets.2.temporal_res_block.norm1.weight, decoder.up_blocks.2.resnets.0.temporal_res_block.conv2.bias, decoder.mid_block.resnets.0.time_mixer.mix_factor, decoder.up_blocks.3.resnets.1.temporal_res_block.norm1.bias, decoder.up_blocks.0.resnets.0.temporal_res_block.conv1.weight, decoder.up_blocks.0.resnets.1.temporal_res_block.norm1.bias, decoder.up_blo

Models loaded successfully!


In [None]:
# Freeze base models
vae.requires_grad_(False)
unet.requires_grad_(False)
image_encoder.requires_grad_(False)

# Initialize SCE
sce = SelectiveContentEncoder(feature_dim=256).to(device)

# Initialize Control Adapter
control_adapter = ControlAdapter3D(in_channels=1, base_channels=32, latent_channels=4).to(device)

# Initialize MPD
mpd = MaskPredictionDecoder(latent_channels=4, sce_feature_dim=256).to(device)

# Trainable parameters
sce.train()
control_adapter.train()
mpd.train()

trainable_params = list(sce.parameters()) + list(control_adapter.parameters()) + list(mpd.parameters())
print(f"\nTotal Trainable Parameters: {sum(p.numel() for p in trainable_params):,}")

optimizer = AdamW(trainable_params, lr=1e-5, weight_decay=0.01)

# Load noise scheduler
noise_scheduler = DDPMScheduler.from_pretrained(MODEL_ID, subfolder="scheduler")

# Prepare with accelerator
sce, control_adapter, mpd, optimizer = accelerator.prepare(
    sce, control_adapter, mpd, optimizer
)

# Keep frozen models in eval
vae.eval()
unet.eval()
image_encoder.eval()

print("All models ready for training!")

SCE initialized with 256 feature dimensions
ControlAdapter3D: 32 base -> 4 latent channels
MPD initialized

Total Trainable Parameters: 408,121
All models ready for training!


In [None]:
# ----------------- 5. Data Loading -----------------

from google.colab import drive
import os, zipfile

drive.mount('/content/drive/')

# Extract data
zip_path = "/content/drive/MyDrive/data-reduc.zip"
extract_to = "/content/data"

print("Extracting data...")
os.makedirs(extract_to, exist_ok=True)

if not os.path.exists(DATASET_ROOT):
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        for file in tqdm(zip_ref.infolist(), desc="Extracting"):
            zip_ref.extract(file, extract_to)
else:
    print("Data already extracted, skipping...")

# Create dataset and dataloader
try:
    dataset = RoVITrackingInpaintingDataset(DATASET_ROOT, NUM_FRAMES, (RESOLUTION, RESOLUTION))  # Changed class name!
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    dataloader = accelerator.prepare(dataloader)
    print(f"Dataset loaded: {len(dataset)} sequences")
except Exception as e:
    print(f"Error loading dataset: {e}")
    raise

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).
Extracting data...


Extracting: 100%|██████████| 1203/1203 [00:01<00:00, 605.25it/s]

Error loading dataset: name 'RoVITrackingInpaintingDataset' is not defined





NameError: name 'RoVITrackingInpaintingDataset' is not defined

In [None]:
# ----------------- 6. Helper Functions -----------------

def encode_image(image: torch.Tensor, image_encoder, device):
    """Encode image with CLIP for SVD conditioning."""
    # image shape: (B, C, H, W)
    image_resized = F.interpolate(
        image,
        size=(224, 224),
        mode='bilinear',
        align_corners=False
    )

    # Convert to float16 to match model dtype
    image_resized = image_resized.to(dtype=torch.float16)
    image_embeddings = image_encoder(image_resized).image_embeds
    return image_embeddings

@torch.no_grad()
def get_latents(video: torch.Tensor, vae: AutoencoderKL):
    """Convert video tensor to VAE latents."""
    B, C, T, H, W = video.shape
    video_flat = video.permute(0, 2, 1, 3, 4).reshape(B * T, C, H, W)

    # Convert to float16 to match VAE dtype
    video_flat = video_flat.to(dtype=torch.float16)

    latents = vae.encode(video_flat).latent_dist.sample()
    latents = latents * vae.config.scaling_factor
    # Keep as (B, C, T, H, W) - correct order for SVD UNet
    latents = latents.reshape(B, T, *latents.shape[1:]).permute(0, 2, 1, 3, 4)
    return latents  # Returns (B, 4, T, H/8, W/8)

In [None]:
# ----------------- TRAINING (Completely rewritten) -----------------

def train_epoch(epoch):
    control_adapter.train()
    mpd.train()
    sce.train()

    total_losses = []
    inpaint_losses = []
    mask_losses = []
    temporal_losses = []

    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch}")

    for step, batch in enumerate(progress_bar):
        with accelerator.accumulate(control_adapter, mpd, sce):
            # Get data
            original_video = batch["original_video"].to(torch.float16)
            original_first_frame = batch["original_first_frame"].to(torch.float16)
            first_mask = batch["first_mask"].to(torch.float16)  # ONLY first mask as input
            edited_video = batch["edited_video"].to(torch.float16)  # TARGET
            mask_video = batch["mask_video"].to(torch.float16)  # TARGET (all masks)

            batch_size = original_video.shape[0]
            num_frames = original_video.shape[2]

            # Encode
            with torch.no_grad():
                target_latents = get_latents(edited_video, vae)
                image_embeddings = encode_image(original_first_frame, image_encoder, device).unsqueeze(1)
                first_frame_latents = vae.encode(original_first_frame).latent_dist.sample()
                first_frame_latents = first_frame_latents * vae.config.scaling_factor

            # SCE: Encode first mask
            first_mask_2d = first_mask.squeeze(2)  # Remove T dimension: (B, 1, 1, H, W) -> (B, 1, H, W)
            sce_features = sce(first_mask_2d)  # (B, feature_dim)
            # sce_features = sce(first_mask)  # (B, feature_dim)

            # Add noise to target
            noise = torch.randn_like(target_latents)
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (batch_size,), device=device).long()

            with torch.no_grad():
                noisy_latents = noise_scheduler.add_noise(target_latents, noise, timesteps)

            # Control adapter: Use full mask video (during training we have ground truth)
            control_cond = control_adapter(mask_video)
            control_cond = F.interpolate(control_cond, size=noisy_latents.shape[-3:], mode='trilinear', align_corners=False)

            # Add control
            controlled_latents = noisy_latents + control_cond

            # Expand first frame
            with torch.no_grad():
                first_frame_latents_expanded = first_frame_latents.unsqueeze(2).repeat(1, 1, num_frames, 1, 1)

            # Concatenate
            conditioned_latents = torch.cat([controlled_latents, first_frame_latents_expanded], dim=1)
            conditioned_latents = conditioned_latents.permute(0, 2, 1, 3, 4).to(torch.float16)

            # Time IDs
            with torch.no_grad():
                added_time_ids = torch.tensor([[7, 127, 0.02]], dtype=torch.float16, device=device).repeat(batch_size, 1)
                image_embeddings_fp16 = image_embeddings.to(torch.float16)

            # UNet forward (inpainting)
            model_pred = unet(conditioned_latents, timesteps, encoder_hidden_states=image_embeddings_fp16,
                            added_time_ids=added_time_ids, return_dict=False)[0]
            model_pred = model_pred.permute(0, 2, 1, 3, 4)

            # MPD: Predict masks from controlled latents + SCE features
            predicted_masks = mpd(controlled_latents, sce_features, num_frames)

            # Upsample predicted masks to match mask_video size
            predicted_masks_upsampled = F.interpolate(
                predicted_masks,
                size=mask_video.shape[-3:],
                mode='trilinear',
                align_corners=False
            )

            # === LOSSES ===

            # 1. Inpainting loss (denoising)
            inpaint_loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean")

            # 2. Mask prediction loss (tracking)
            mask_loss = F.binary_cross_entropy(
                predicted_masks_upsampled.float(),
                mask_video.float(),
                reduction="mean"
            )
            # 3. Temporal consistency
            pred_diff = torch.abs(model_pred[:, :, 1:] - model_pred[:, :, :-1])
            temporal_loss = pred_diff.mean()

            # 4. Region-aware loss (focus on masked regions)
            denoise_loss_per_pixel = F.mse_loss(model_pred.float(), noise.float(), reduction="none")
            with torch.no_grad():
                latent_mask = F.interpolate(mask_video, size=target_latents.shape[-3:], mode='trilinear', align_corners=False)

            ra_inside = (denoise_loss_per_pixel * latent_mask).sum() / (latent_mask.sum() + 1e-6)
            ra_outside = (denoise_loss_per_pixel * (1 - latent_mask)).sum() / ((1 - latent_mask).sum() + 1e-6)

            # Total loss
            lambda_inpaint = 1.0
            lambda_mask = 2.0  # Mask prediction is important
            lambda_temp = 0.2
            lambda_ra = 0.5

            total_loss = (lambda_inpaint * inpaint_loss) + \
                        (lambda_mask * mask_loss) + \
                        (lambda_temp * temporal_loss) + \
                        (lambda_ra * (ra_inside + ra_outside))

            # Backprop
            accelerator.backward(total_loss)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(trainable_params, 1.0)
            optimizer.step()
            optimizer.zero_grad()

            # Logging
            total_losses.append(total_loss.item())
            inpaint_losses.append(inpaint_loss.item())
            mask_losses.append(mask_loss.item())
            temporal_losses.append(temporal_loss.item())

            progress_bar.set_postfix({
                'total': f"{total_loss.item():.4f}",
                'inpaint': f"{inpaint_loss.item():.4f}",
                'mask': f"{mask_loss.item():.4f}",
                'temp': f"{temporal_loss.item():.4f}"
            })

            if step % 5 == 0:
                torch.cuda.empty_cache()

    return {
        'total': np.mean(total_losses),
        'inpaint': np.mean(inpaint_losses),
        'mask': np.mean(mask_losses),
        'temporal': np.mean(temporal_losses)
    }

In [None]:
# ----------------- 8. Main Training Loop -----------------

print(f"\n{'='*60}")
print(f"Starting Training for {NUM_EPOCHS} Epochs")
print(f"{'='*60}\n")

for epoch in range(10):
    losses = train_epoch(epoch + 1)

    print(f"\n{'='*60}")
    print(f"EPOCH {epoch+1}/{NUM_EPOCHS} Summary")
    print(f"{'='*60}")
    print(f"Total Loss:     {losses['total']:.4f}")
    print(f"Inpaint Loss:   {losses['inpaint']:.4f}")  # Changed
    print(f"Mask Loss:      {losses['mask']:.4f}")      # Changed
    print(f"Temporal Loss:  {losses['temporal']:.4f}")
    print(f"{'='*60}\n")

    # Save checkpoint
    if (epoch + 1) % CHECKPOINT_SAVE_EPOCHS == 0:
        checkpoint_dir = Path(f"./checkpoints/epoch_{epoch+1}")
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Save with accelerator
        accelerator.save_state(str(checkpoint_dir))
        print(f"Checkpoint saved to {checkpoint_dir}\n")

        # Clear memory
        torch.cuda.empty_cache()
        gc.collect()

print("Training Complete!")


Starting Training for 10 Epochs



Epoch 1: 100%|██████████| 10/10 [00:20<00:00,  2.10s/it, total=21.0448, inpaint=3.8780, mask=0.6217, temp=0.0917]



EPOCH 1/10 Summary
Total Loss:     15.0342
Inpaint Loss:   2.7578
Mask Loss:      0.6387
Temporal Loss:  0.1629



Epoch 2: 100%|██████████| 10/10 [00:21<00:00,  2.14s/it, total=17.8813, inpaint=3.2598, mask=0.6485, temp=0.1039]



EPOCH 2/10 Summary
Total Loss:     14.9526
Inpaint Loss:   2.6988
Mask Loss:      0.6351
Temporal Loss:  0.1713

Checkpoint saved to checkpoints/epoch_2



Epoch 3: 100%|██████████| 10/10 [00:20<00:00,  2.09s/it, total=12.3363, inpaint=2.1894, mask=0.6429, temp=0.1569]



EPOCH 3/10 Summary
Total Loss:     15.6775
Inpaint Loss:   2.8569
Mask Loss:      0.6337
Temporal Loss:  0.1479



Epoch 4: 100%|██████████| 10/10 [00:20<00:00,  2.06s/it, total=9.9912, inpaint=1.8634, mask=0.6260, temp=0.2063]



EPOCH 4/10 Summary
Total Loss:     13.6095
Inpaint Loss:   2.5058
Mask Loss:      0.6310
Temporal Loss:  0.1618

Checkpoint saved to checkpoints/epoch_4



Epoch 5: 100%|██████████| 10/10 [00:20<00:00,  2.08s/it, total=17.6049, inpaint=2.7859, mask=0.6260, temp=0.0710]



EPOCH 5/10 Summary
Total Loss:     15.7320
Inpaint Loss:   2.8388
Mask Loss:      0.6276
Temporal Loss:  0.1248



Epoch 6: 100%|██████████| 10/10 [00:20<00:00,  2.10s/it, total=17.6415, inpaint=3.2491, mask=0.6138, temp=0.1295]



EPOCH 6/10 Summary
Total Loss:     15.7318
Inpaint Loss:   2.8949
Mask Loss:      0.6272
Temporal Loss:  0.1452

Checkpoint saved to checkpoints/epoch_6



Epoch 7: 100%|██████████| 10/10 [00:20<00:00,  2.08s/it, total=11.7046, inpaint=2.1264, mask=0.6174, temp=0.2429]



EPOCH 7/10 Summary
Total Loss:     14.9527
Inpaint Loss:   2.7418
Mask Loss:      0.6236
Temporal Loss:  0.1477



Epoch 8: 100%|██████████| 10/10 [00:20<00:00,  2.08s/it, total=18.5052, inpaint=3.3762, mask=0.6388, temp=0.0776]



EPOCH 8/10 Summary
Total Loss:     16.3070
Inpaint Loss:   2.9740
Mask Loss:      0.6207
Temporal Loss:  0.1404

Checkpoint saved to checkpoints/epoch_8



Epoch 9: 100%|██████████| 10/10 [00:20<00:00,  2.09s/it, total=14.3713, inpaint=2.7069, mask=0.6028, temp=0.1263]



EPOCH 9/10 Summary
Total Loss:     13.8158
Inpaint Loss:   2.5498
Mask Loss:      0.6171
Temporal Loss:  0.1799



Epoch 10: 100%|██████████| 10/10 [00:20<00:00,  2.09s/it, total=13.6323, inpaint=2.8993, mask=0.5993, temp=0.2070]



EPOCH 10/10 Summary
Total Loss:     16.0338
Inpaint Loss:   2.9509
Mask Loss:      0.6149
Temporal Loss:  0.1499

Checkpoint saved to checkpoints/epoch_10

Training Complete!


In [None]:
# Extract data
zip_path = "/content/bird.zip"
extract_to = "/content/"

print("Extracting data...")
os.makedirs(extract_to, exist_ok=True)


with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    for file in tqdm(zip_ref.infolist(), desc="Extracting"):
        zip_ref.extract(file, extract_to)

Extracting data...


FileNotFoundError: [Errno 2] No such file or directory: '/content/bird.zip'

In [None]:
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

In [None]:
# Load SD 1.5 VAE for decoding
print("Loading Stable Diffusion VAE for decoding...")
sd_vae = AutoencoderKL.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    subfolder="vae",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=False
).to(device)
sd_vae.eval()
print("✓ SD VAE loaded!")

@torch.no_grad()
def decode_latents_with_sd_vae(latents, sd_vae):
    """Decode video latents frame by frame using SD VAE."""
    B, C, T, H, W = latents.shape
    decoded_frames = []

    for t in range(T):
        # Get single frame
        frame_latent = latents[:, :, t, :, :] / sd_vae.config.scaling_factor
        frame_latent = frame_latent.to(torch.float16)

        # Decode
        decoded = sd_vae.decode(frame_latent).sample
        decoded_frames.append(decoded)

    # Stack: (B, T, 3, H, W)
    return torch.stack(decoded_frames, dim=1)

In [None]:
@torch.no_grad()
def track_and_inpaint(video_folder, first_mask_path, num_frames=6, num_inference_steps=25):
    """
    Track object + inpaint video.
    Input: Original video + ONLY first frame mask
    Output: Inpainted video + Predicted masks for all frames
    """

    sce.eval()
    control_adapter.eval()
    mpd.eval()
    unet.eval()
    image_encoder.eval()

    print("="*60)
    print("TRACKING + INPAINTING")
    print("="*60)

    # Load original video
    print("\n1. Loading original video...")
    video_path = Path(video_folder)
    frame_files = sorted(list(video_path.glob('*.jpg')) + list(video_path.glob('*.png')))[:num_frames]

    original_frames = []
    for frame_file in frame_files:
        img = Image.open(frame_file).convert('RGB').resize((RESOLUTION, RESOLUTION))
        original_frames.append(np.array(img))

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ])

    original_first = Image.fromarray(original_frames[0])
    original_first_tensor = transform(original_first).unsqueeze(0).to(device).to(torch.float16)

    # Load ONLY first frame mask
    print("\n2. Loading first frame mask ONLY...")
    first_mask = Image.open(first_mask_path).convert('L').resize((RESOLUTION, RESOLUTION))
    first_mask_array = np.array(first_mask) / 255.0
    first_mask_tensor = torch.from_numpy(first_mask_array).unsqueeze(0).unsqueeze(0).to(device).to(torch.float16)
    print(f"   First mask shape: {first_mask_tensor.shape}")

    # Encode first frame
    print("\n3. Encoding original first frame...")
    image_embeddings = encode_image(original_first_tensor, image_encoder, device).unsqueeze(1)
    first_frame_latents = vae.encode(original_first_tensor).latent_dist.sample()
    first_frame_latents = first_frame_latents * vae.config.scaling_factor

    # SCE: Encode first mask
    print("\n4. Encoding first mask with SCE...")
    sce_features = sce(first_mask_tensor)  # (B, feature_dim)
    print(f"   SCE features shape: {sce_features.shape}")

    # Initialize latents
    print("\n5. Initializing latents...")
    B, C, T = 1, 4, num_frames
    H, W = RESOLUTION // 8, RESOLUTION // 8
    latents = torch.randn(B, C, T, H, W, device=device, dtype=torch.float16)

    noise_scheduler.set_timesteps(num_inference_steps, device=device)
    first_frame_latents_expanded = first_frame_latents.unsqueeze(2).repeat(1, 1, T, 1, 1)

    # Denoising loop with mask prediction
    print(f"\n6. Tracking + Inpainting ({num_inference_steps} steps)...")

    predicted_masks_final = None

    for i, t in enumerate(tqdm(noise_scheduler.timesteps, desc="Tracking & Inpainting")):
        # MPD: Predict masks for ALL frames from current latents + SCE
        predicted_masks = mpd(latents, sce_features, num_frames)  # (B, 1, T, H, W) in latent space

        # Save final predicted masks
        if i == len(noise_scheduler.timesteps) - 1:
            predicted_masks_final = predicted_masks.clone()

        # Control adapter: Use predicted masks (not ground truth!)
        control_cond = control_adapter(predicted_masks)
        control_cond = F.interpolate(
            control_cond,
            size=latents.shape[-3:],
            mode='trilinear',
            align_corners=False
        )

        # Add control
        latents_with_control = latents + control_cond

        # Concatenate with first frame
        latent_model_input = torch.cat([latents_with_control, first_frame_latents_expanded], dim=1)
        latent_model_input = latent_model_input.permute(0, 2, 1, 3, 4).to(torch.float16)

        timestep = torch.tensor([t], device=device, dtype=torch.long)
        added_time_ids = torch.tensor([[7, 127, 0.02]], dtype=torch.float16, device=device)

        # UNet prediction
        noise_pred = unet(
            latent_model_input,
            timestep,
            encoder_hidden_states=image_embeddings.to(torch.float16),
            added_time_ids=added_time_ids,
            return_dict=False
        )[0]

        noise_pred = noise_pred.permute(0, 2, 1, 3, 4)
        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

    # Decode inpainted video
    print("\n7. Decoding inpainted video...")
    decoded_video = decode_latents_with_sd_vae(latents, sd_vae)

    inpainted_frames = []
    for t in range(T):
        frame = decoded_video[0, t].cpu()
        frame = (frame * 0.5 + 0.5).clamp(0, 1)
        frame = frame.permute(1, 2, 0).numpy().astype(np.float32)
        inpainted_frames.append(frame)

    # Upsample predicted masks for visualization
    print("\n8. Upsampling predicted masks...")
    predicted_masks_upsampled = F.interpolate(
        predicted_masks_final,
        size=(num_frames, RESOLUTION, RESOLUTION),
        mode='trilinear',
        align_corners=False
    )
    predicted_masks_np = predicted_masks_upsampled[0, 0].cpu().numpy()  # (T, H, W)

    print("   ✓ Complete!")

    # Visualize
    print("\n9. Visualizing...")
    fig, axes = plt.subplots(3, num_frames, figsize=(20, 12))

    for i in range(num_frames):
        # Row 1: Original (with object)
        axes[0, i].imshow(original_frames[i])
        axes[0, i].set_title(f'Original {i}', fontsize=10)
        axes[0, i].axis('off')

        # Row 2: Predicted masks (tracking!)
        axes[1, i].imshow(predicted_masks_np[i], cmap='gray', vmin=0, vmax=1)
        if i == 0:
            axes[1, i].set_title(f'Mask {i} (Input)', color='green', fontsize=10, fontweight='bold')
        else:
            axes[1, i].set_title(f'Mask {i} (Tracked)', color='blue', fontsize=10, fontweight='bold')
        axes[1, i].axis('off')

        # Row 3: Inpainted (object removed)
        axes[2, i].imshow(inpainted_frames[i])
        axes[2, i].set_title(f'Inpainted {i}', color='green', fontsize=10, fontweight='bold')
        axes[2, i].axis('off')

    plt.suptitle('Original | Predicted Masks (Tracking) | Inpainted (Object Removed)', fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.savefig('tracking_inpainting_result.png', dpi=150, bbox_inches='tight')

    from IPython.display import Image as IPImage, display
    plt.close()
    display(IPImage('tracking_inpainting_result.png'))

    print("\n" + "="*60)
    print("✓ TRACKING + INPAINTING COMPLETE!")
    print("="*60)
    print("Row 1: Original video (with cat)")
    print("Row 2: Mask 0 = Your input | Masks 1-5 = Model tracked!")
    print("Row 3: Cat removed using predicted masks")
    print("="*60)

# Run inference
VIDEO_FOLDER = "/content/__KkKB4wzrY"  # Original video WITH cat
FIRST_MASK = "/content/mask_bird.png"     # ONLY first frame mask

track_and_inpaint(VIDEO_FOLDER, FIRST_MASK, num_frames=6, num_inference_steps=25)

In [1]:
# Copy checkpoint to your mounted Drive
!cp -r ./checkpoints/epoch_10 /content/drive/MyDrive/genprop_checkpoints/

print("Checkpoint saved to Google Drive!")



Checkpoint saved to Google Drive!
