In [None]:
"""
FastEdit: Lightweight SVD Implementation for One-Step Video Editing
Optimized for Kaggle (16GB GPU) and easy experimentation
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image, ImageDraw
import cv2
from pathlib import Path
import os
import gc
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# ============================================================================
# PART 1: INSTALLATION & SETUP
# ============================================================================

def setup_environment(packages=None):
    """Install the pip packages this script needs and report the GPU.

    Args:
        packages: optional list of pip requirement strings to install. When
            omitted, only the third-party packages this script imports besides
            torch/numpy are installed (opencv-python, pillow, tqdm). Pass an
            empty list to skip installation and only print the GPU report.
    """
    import subprocess
    import sys

    if packages is None:
        # diffusers==0.24.0 / transformers / accelerate / xformers used to be
        # installed here, but nothing in this script imports them. That
        # diffusers pin fails to import against current huggingface_hub
        # ("cannot import name 'cached_download'"), and xformers wheels may
        # pull in a different torch build.
        packages = ["opencv-python", "pillow", "tqdm"]

    for package in packages:
        subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", package])

    print("✓ Environment setup complete")

    # Check GPU
    if torch.cuda.is_available():
        print(f"✓ GPU: {torch.cuda.get_device_name(0)}")
        print(f"✓ Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    else:
        print("⚠ No GPU found, using CPU")

# ============================================================================
# PART 2: LIGHTWEIGHT SVD MODEL
# ============================================================================

class LightweightSVD(nn.Module):
    """
    Simplified SVD-style architecture for fast video generation.

    Pipeline: encode the edited first frame to a small latent (three stride-2
    convs, i.e. 1/8 of the input resolution), optionally add features that
    describe the edit (edited vs. original frame), repeat the latent over
    ``num_frames`` time steps, mix it with 3D convs, refine each time step
    with a tiny UNet and decode it back to an RGB frame in [0, 1] (sigmoid).

    Input height/width should be divisible by 8 so decoded frames come back
    at the input resolution (the first-frame overwrite in ``forward`` relies
    on this). ``image_size`` is stored for callers; no layer depends on it.
    """

    def __init__(self,
                 latent_channels=4,
                 hidden_dim=128,  # Reduced from typical 320
                 num_frames=8,
                 image_size=128):  # Smaller than typical 256/512
        super().__init__()

        self.num_frames = num_frames
        self.image_size = image_size

        # Lightweight VAE-style encoder for a single frame: H x W -> H/8 x W/8
        self.frame_encoder = nn.Sequential(
            nn.Conv2d(3, hidden_dim//4, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_dim//4, hidden_dim//2, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_dim//2, hidden_dim, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_dim, latent_channels, 3, padding=1)
        )

        # Temporal expansion network (frame -> video latents)
        self.temporal_net = nn.Sequential(
            nn.Conv3d(latent_channels, hidden_dim, (3, 3, 3), padding=1),
            nn.ReLU(),
            nn.Conv3d(hidden_dim, hidden_dim, (3, 3, 3), padding=1),
            nn.ReLU()
        )

        # Lightweight UNet for denoising
        self.denoiser = SimplifiedUNet(
            in_channels=hidden_dim,
            out_channels=hidden_dim,
            hidden_dim=hidden_dim
        )

        # VAE-style decoder (latents -> frames): H/8 x W/8 -> H x W
        self.frame_decoder = nn.Sequential(
            nn.ConvTranspose2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(hidden_dim, hidden_dim//2, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(hidden_dim//2, hidden_dim//4, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(hidden_dim//4, 3, 4, stride=2, padding=1),
            nn.Sigmoid()
        )

        # Edit injection layer. It runs at full input resolution (no striding),
        # so its output is pooled to the latent grid in forward().
        self.edit_encoder = nn.Sequential(
            nn.Conv2d(6, hidden_dim, 3, padding=1),  # 6ch: edited + original
            nn.ReLU(),
            nn.Conv2d(hidden_dim, latent_channels, 1)
        )

    def encode_frame(self, x):
        """Encode a single frame [B, 3, H, W] to a latent [B, C, H/8, W/8]."""
        return self.frame_encoder(x)

    def decode_frame(self, z):
        """Decode a latent [B, hidden_dim, h, w] to a frame [B, 3, 8h, 8w]."""
        return self.frame_decoder(z)

    def forward(self, edited_frame, original_frame=None):
        """
        Generate a video from an edited first frame.

        Args:
            edited_frame: [B, 3, H, W]
            original_frame: [B, 3, H, W] optional unedited frame. When given,
                an edit embedding computed from both frames is added to the latent.
        Returns:
            video: [B, T, 3, H, W] with ``video[:, 0]`` equal to ``edited_frame``.
        """
        # Encode edited frame
        z_edited = self.encode_frame(edited_frame)  # [B, C, H/8, W/8]

        # Extract edit information if original provided
        if original_frame is not None:
            combined = torch.cat([edited_frame, original_frame], dim=1)  # [B, 6, H, W]
            edit_features = self.edit_encoder(combined)  # [B, C, H, W] (full resolution)
            # The edit encoder does not downsample, so its output is 8x larger
            # than the latent. Adding them directly raised "size of tensor a
            # (16) must match the size of tensor b (128)". Pool to the latent grid.
            edit_features = F.adaptive_avg_pool2d(edit_features, tuple(z_edited.shape[-2:]))
            z_edited = z_edited + 0.5 * edit_features

        # Expand temporally
        z_video = z_edited.unsqueeze(2).repeat(1, 1, self.num_frames, 1, 1)  # [B, C, T, h, w]
        z_video = self.temporal_net(z_video)  # [B, hidden_dim, T, h, w]

        # Denoise and decode each time step
        video_frames = []
        for t in range(self.num_frames):
            z_t = z_video[:, :, t, :, :]  # [B, hidden_dim, h, w]
            z_t_denoised = self.denoiser(z_t)
            frame = self.decode_frame(z_t_denoised)
            video_frames.append(frame)

        video = torch.stack(video_frames, dim=1)  # [B, T, 3, H, W]

        # Force first frame to match edited input
        video[:, 0] = edited_frame

        return video

class SimplifiedUNet(nn.Module):
    """Tiny two-level UNet (one stride-2 down, one up, one skip connection)
    used to refine per-frame latents. The output keeps the input's spatial size."""

    def __init__(self, in_channels, out_channels, hidden_dim):
        super().__init__()

        # Encoder
        self.enc1 = nn.Conv2d(in_channels, hidden_dim, 3, padding=1)
        self.enc2 = nn.Conv2d(hidden_dim, hidden_dim*2, 3, stride=2, padding=1)

        # Bottleneck
        self.bottleneck = nn.Conv2d(hidden_dim*2, hidden_dim*2, 3, padding=1)

        # Decoder
        self.dec2 = nn.ConvTranspose2d(hidden_dim*2, hidden_dim, 3, stride=2, padding=1, output_padding=1)
        self.dec1 = nn.Conv2d(hidden_dim*2, out_channels, 3, padding=1)  # +skip connection

    def forward(self, x):
        # Encode
        e1 = F.relu(self.enc1(x))   # [B, hidden, h, w]
        e2 = F.relu(self.enc2(e1))  # [B, 2*hidden, ceil(h/2), ceil(w/2)]

        # Bottleneck
        b = F.relu(self.bottleneck(e2))

        # Decode with skip connections
        d2 = F.relu(self.dec2(b))   # [B, hidden, 2*ceil(h/2), 2*ceil(w/2)]
        # For odd h or w the upsampled map is one pixel larger than the skip
        # tensor; crop it so the concatenation works for any latent size.
        d2 = d2[..., :e1.shape[-2], :e1.shape[-1]]
        d1 = self.dec1(torch.cat([d2, e1], dim=1))

        return d1

# ============================================================================
# PART 3: DATASET
# ============================================================================

class SimpleVideoDataset(Dataset):
    """
    Synthetic dataset for testing: a red square moving diagonally over a blue
    background, next to a static green circle. The "edit" removes the square
    from the first frame. For real training, replace with DAVIS or your own data.
    """

    # Background colour; also used to erase the square in the edited frame.
    BACKGROUND = (100, 100, 200)
    # Coordinate extent passed to PIL for the square (x1 - x0).
    SQUARE_SIZE = 30

    def __init__(self, num_samples=100, num_frames=8, image_size=128):
        self.num_samples = num_samples
        self.num_frames = num_frames
        self.image_size = image_size

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        """Generate a synthetic video with an object to remove.

        Returns a dict with 'video' [T, 3, H, W], 'edited_frame' and
        'original_frame' [3, H, W] (float in [0, 1]), and 'masks' [T, H, W]
        marking the red square in each frame. ``idx`` is ignored: every sample
        is identical.
        """
        frames = []
        masks = []

        for t in range(self.num_frames):
            # Create frame
            img = Image.new('RGB', (self.image_size, self.image_size),
                            color=self.BACKGROUND)  # Blue background
            draw = ImageDraw.Draw(img)

            # Moving red square (object to remove)
            x = 20 + t * 5
            y = 30 + t * 3
            draw.rectangle([x, y, x + self.SQUARE_SIZE, y + self.SQUARE_SIZE], fill=(255, 0, 0))

            # Static green circle (should remain)
            draw.ellipse([80, 80, 110, 110], fill=(0, 255, 0))

            frames.append(np.array(img))

            # PIL rectangles include both corner pixels, so the square spans
            # SQUARE_SIZE + 1 pixels per side. The old y:y+30 slice missed the
            # last row/column and left a red fringe in the edited frame.
            side = self.SQUARE_SIZE + 1
            mask = np.zeros((self.image_size, self.image_size))
            mask[y:y + side, x:x + side] = 1
            masks.append(mask)

        frames = np.stack(frames)  # [T, H, W, 3]
        masks = np.stack(masks)    # [T, H, W]

        # Create edited first frame (remove red square)
        edited_frame = frames[0].copy()
        edited_frame[masks[0] > 0] = self.BACKGROUND  # Fill with background

        # Convert to tensors
        frames_tensor = torch.from_numpy(frames).permute(0, 3, 1, 2).float() / 255.0
        edited_tensor = torch.from_numpy(edited_frame).permute(2, 0, 1).float() / 255.0
        original_tensor = torch.from_numpy(frames[0]).permute(2, 0, 1).float() / 255.0

        return {
            'video': frames_tensor,           # [T, 3, H, W]
            'edited_frame': edited_tensor,     # [3, H, W]
            'original_frame': original_tensor, # [3, H, W]
            'masks': torch.from_numpy(masks)   # [T, H, W]
        }

# ============================================================================
# PART 4: TRAINING
# ============================================================================

class FastEditTrainer:
    """Trainer for the lightweight SVD model.

    Optimizes with AdamW (lr 1e-4) on a weighted sum of an MSE reconstruction
    term, an L1 temporal-difference term and an MSE on 4x4 average-pooled
    frames. Writes GIF samples to ``samples/`` and checkpoints to
    ``checkpoints/``.
    """

    def __init__(self, model, device='cuda'):
        self.model = model.to(device)
        self.device = device

        # Optimizer
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

        # Loss functions
        self.mse_loss = nn.MSELoss()
        self.l1_loss = nn.L1Loss()

    def compute_loss(self, generated, target, masks=None):
        """Compute training losses.

        Args:
            generated, target: video tensors [B, T, 3, H, W].
            masks: accepted for API compatibility; currently unused.
        Returns:
            dict with 'total' (tensor, used for backward) and the float
            components 'recon', 'temporal' and 'perceptual'.
        """
        # Reconstruction loss
        recon_loss = self.mse_loss(generated, target)

        # Temporal consistency loss: match frame-to-frame differences
        gen_diff = torch.abs(generated[:, 1:] - generated[:, :-1])
        target_diff = torch.abs(target[:, 1:] - target[:, :-1])
        temporal_loss = self.l1_loss(gen_diff, target_diff)

        # Perceptual loss (simplified - just feature matching on pooled frames)
        gen_features = F.avg_pool2d(generated.reshape(-1, *generated.shape[2:]), 4)
        target_features = F.avg_pool2d(target.reshape(-1, *target.shape[2:]), 4)
        perceptual_loss = self.mse_loss(gen_features, target_features)

        # Total loss
        total_loss = recon_loss + 0.1 * temporal_loss + 0.1 * perceptual_loss

        return {
            'total': total_loss,
            'recon': recon_loss.item(),
            'temporal': temporal_loss.item(),
            'perceptual': perceptual_loss.item()
        }

    def train_step(self, batch):
        """Run one optimization step on a batch dict and return its losses."""

        # Move to device
        edited_frame = batch['edited_frame'].to(self.device)
        original_frame = batch['original_frame'].to(self.device)
        target_video = batch['video'].to(self.device)

        # Forward pass
        generated_video = self.model(edited_frame, original_frame)

        # Compute loss
        losses = self.compute_loss(generated_video, target_video)

        # Backward pass
        self.optimizer.zero_grad()
        losses['total'].backward()
        self.optimizer.step()

        return losses

    def train(self, dataloader, num_epochs=10):
        """Full training loop: one pass per epoch, a GIF sample every 50
        batches and a checkpoint at the end of every epoch."""

        print("Starting training...")
        self.model.train()

        for epoch in range(num_epochs):
            epoch_losses = {'total': 0.0, 'recon': 0.0, 'temporal': 0.0, 'perceptual': 0.0}

            progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
            for batch_idx, batch in enumerate(progress_bar):
                losses = self.train_step(batch)
                # 'total' is the graph-carrying tensor; the components are floats.
                total_value = losses['total'].item()

                # Update metrics
                epoch_losses['total'] += total_value
                for key in ('recon', 'temporal', 'perceptual'):
                    epoch_losses[key] += losses[key]

                # 'loss' shows the optimized total (it previously showed recon only)
                progress_bar.set_postfix({
                    'loss': f"{total_value:.4f}",
                    'recon': f"{losses['recon']:.4f}",
                    'temp': f"{losses['temporal']:.4f}"
                })

                # Save sample every 50 batches
                if batch_idx % 50 == 0:
                    self.save_sample(batch, epoch, batch_idx)

            # Print epoch summary (max() avoids dividing by zero on an empty loader)
            num_batches = max(len(dataloader), 1)
            print(f"\nEpoch {epoch+1} Summary:")
            print(f"  Total Loss: {epoch_losses['total']/num_batches:.4f}")
            print(f"  Recon Loss: {epoch_losses['recon']/num_batches:.4f}")
            print(f"  Temporal Loss: {epoch_losses['temporal']/num_batches:.4f}")
            print(f"  Perceptual Loss: {epoch_losses['perceptual']/num_batches:.4f}")

            # Save checkpoint
            self.save_checkpoint(epoch)

    def save_sample(self, batch, epoch, batch_idx):
        """Generate a video for the first item of ``batch`` and save it as
        samples/epoch_{epoch}_batch_{batch_idx}.gif (100 ms per frame, looping)."""
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                edited = batch['edited_frame'][:1].to(self.device)
                original = batch['original_frame'][:1].to(self.device)
                generated = self.model(edited, original)
        finally:
            # Restore the previous mode even if generation fails.
            self.model.train(was_training)

        # Save as GIF
        frames = []
        for t in range(generated.shape[1]):
            # Clamp before the uint8 cast so out-of-range values cannot wrap.
            frame = generated[0, t].clamp(0, 1).cpu()
            frame = (frame.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
            frames.append(Image.fromarray(frame))

        output_path = f"samples/epoch_{epoch}_batch_{batch_idx}.gif"
        os.makedirs("samples", exist_ok=True)
        frames[0].save(
            output_path,
            save_all=True,
            append_images=frames[1:],
            duration=100,
            loop=0
        )

    def save_checkpoint(self, epoch):
        """Save model and optimizer state to checkpoints/checkpoint_epoch_{epoch}.pt."""
        os.makedirs("checkpoints", exist_ok=True)
        torch.save({
            'epoch': epoch,
            'model_state': self.model.state_dict(),
            'optimizer_state': self.optimizer.state_dict()
        }, f"checkpoints/checkpoint_epoch_{epoch}.pt")

# ============================================================================
# PART 5: INFERENCE & TESTING
# ============================================================================

class FastEditInference:
    """Inference pipeline for testing: erase an object in one image, then let
    LightweightSVD animate the edited image into a short clip."""

    def __init__(self, model_path=None, device='cuda'):
        """
        Args:
            model_path: optional checkpoint saved by FastEditTrainer (a dict
                with 'epoch' and 'model_state'). A path that does not exist is
                reported, and the model keeps its random initialisation.
            device: torch device string for the model and its inputs.
        """
        self.device = device

        # Architecture must match the one trained in main() for load_state_dict.
        self.model = LightweightSVD(
            hidden_dim=128,
            num_frames=8,
            image_size=128
        ).to(device)

        # Load checkpoint if provided
        if model_path:
            if Path(model_path).exists():
                checkpoint = torch.load(model_path, map_location=device)
                self.model.load_state_dict(checkpoint['model_state'])
                print(f"✓ Loaded checkpoint from epoch {checkpoint['epoch']}")
            else:
                # Previously this fell through silently and ran an untrained model.
                print(f"⚠ Checkpoint not found at {model_path}; using untrained weights")

        self.model.eval()

    def remove_object(self, image_path, mask_path=None):
        """Remove an object from an image and generate a video from the result.

        Args:
            image_path: path to an image (converted to RGB).
            mask_path: optional grayscale mask; pixels > 128 mark the object.
                Without it, strongly red pixels (R > 150, G < 100, B < 100)
                are treated as the object.

        Returns:
            Tensor [T, 3, S, S] on CPU, where S is the model's image_size.
        """
        size = self.model.image_size

        # Load image (context manager closes the underlying file)
        with Image.open(image_path) as src:
            img = src.convert('RGB').resize((size, size))
        img_np = np.array(img)

        # Create or load mask
        if mask_path:
            with Image.open(mask_path) as mask_src:
                mask_img = mask_src.convert('L').resize((size, size))
            mask = np.array(mask_img) > 128
        else:
            # Simple color-based mask (remove red objects)
            mask = (img_np[:, :, 0] > 150) & (img_np[:, :, 1] < 100) & (img_np[:, :, 2] < 100)

        # Naive inpainting: paint the object with the mean colour of the rest
        # of the image. If the mask covers every pixel there is nothing to
        # sample from (the mean would be NaN, giving garbage when cast to
        # uint8), so the image is left unchanged.
        edited_np = img_np.copy()
        background = ~mask
        if mask.any() and background.any():
            edited_np[mask] = img_np[background].mean(axis=0)

        # Convert to tensors in [0, 1]
        original_tensor = torch.from_numpy(img_np).permute(2, 0, 1).float() / 255.0
        edited_tensor = torch.from_numpy(edited_np).permute(2, 0, 1).float() / 255.0

        # Generate video
        with torch.no_grad():
            original_batch = original_tensor.unsqueeze(0).to(self.device)
            edited_batch = edited_tensor.unsqueeze(0).to(self.device)
            video = self.model(edited_batch, original_batch)

        return video[0].cpu()  # [T, 3, H, W]

    def test_synthetic(self):
        """Draw a red square + green circle test image, remove the square and
        save the generated clip as test_output.gif."""

        # Create test image
        img = Image.new('RGB', (128, 128), color='blue')
        draw = ImageDraw.Draw(img)
        draw.rectangle([30, 30, 60, 60], fill='red')  # Object to remove
        draw.ellipse([70, 70, 100, 100], fill='green')  # Should remain
        img.save("test_input.png")

        # Generate video
        print("Generating video with object removed...")
        video = self.remove_object("test_input.png")

        # Save result (clamp so the uint8 cast cannot wrap)
        frames = []
        for t in range(video.shape[0]):
            frame = (video[t].clamp(0, 1).permute(1, 2, 0).numpy() * 255).astype(np.uint8)
            frames.append(Image.fromarray(frame))

        frames[0].save(
            "test_output.gif",
            save_all=True,
            append_images=frames[1:],
            duration=100,
            loop=0
        )
        print("✓ Saved result to test_output.gif")

# ============================================================================
# PART 6: MAIN EXECUTION
# ============================================================================

def main():
    """Run the full pipeline: install dependencies, build the model and a
    synthetic dataset, train briefly, then run inference from the last checkpoint."""

    # Setup
    setup_environment()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Clear GPU memory
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()

    print("\n" + "="*50)
    print("FastEdit: Lightweight SVD Training")
    print("="*50)

    num_epochs = 5  # Quick training

    # 1. Create model
    print("\n1. Creating model...")
    model = LightweightSVD(
        hidden_dim=128,
        num_frames=8,
        image_size=128
    )

    # Print model size
    total_params = sum(p.numel() for p in model.parameters())
    print(f"✓ Model created: {total_params/1e6:.2f}M parameters")

    # 2. Create dataset
    print("\n2. Creating dataset...")
    dataset = SimpleVideoDataset(
        num_samples=200,
        num_frames=8,
        image_size=128
    )
    dataloader = DataLoader(
        dataset,
        batch_size=4,
        shuffle=True,
        num_workers=2,
        # Pinned host memory only speeds up host->GPU copies; skip it on CPU.
        pin_memory=(device == 'cuda')
    )
    print(f"✓ Dataset created: {len(dataset)} samples")

    # 3. Train model
    print("\n3. Training model...")
    trainer = FastEditTrainer(model, device)
    trainer.train(dataloader, num_epochs=num_epochs)

    # 4. Test model
    print("\n4. Testing model...")
    # save_checkpoint names files by 0-based epoch, so the final checkpoint is
    # epoch num_epochs - 1 (this was hard-coded to epoch_4 before).
    inference = FastEditInference(
        model_path=f"checkpoints/checkpoint_epoch_{num_epochs - 1}.pt",
        device=device
    )
    inference.test_synthetic()

    print("\n✓ Complete! Check 'samples/' and 'test_output.gif' for results.")

# ============================================================================
# PART 7: QUICK TEST SCRIPTS
# ============================================================================

def quick_test():
    """Smoke-test one forward pass of a shrunken LightweightSVD, no training.

    Returns:
        The constructed model, placed on the selected device.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Device: {device}")

    # Reduced width, clip length and resolution so this runs fast even on CPU.
    tiny_settings = {"hidden_dim": 64, "num_frames": 4, "image_size": 64}
    model = LightweightSVD(**tiny_settings).to(device)

    sample = torch.randn(1, 3, 64, 64).to(device)
    with torch.no_grad():
        result = model(sample)

    print(f"Input shape: {sample.shape}")
    print(f"Output shape: {result.shape}")
    print("✓ Model working correctly!")

    return model

def test_on_real_video(video_path, output_path="output.gif", model_path=None):
    """Animate the first frame of a real video with LightweightSVD.

    Args:
        video_path: any file cv2.VideoCapture can open.
        output_path: where to write the generated GIF.
        model_path: optional checkpoint saved by FastEditTrainer (dict with
            'model_state'). Without it the model keeps random weights, so the
            output only checks that the pipeline runs.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Load model
    model = LightweightSVD().to(device)
    if model_path is not None:
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state'])
    else:
        print("⚠ No checkpoint given; using untrained weights")
    model.eval()

    # Only the first frame is used as the (simplified) edited input, so only
    # that frame is read. try/finally releases the capture even on errors.
    cap = cv2.VideoCapture(video_path)
    try:
        ok, first_frame = cap.read()
    finally:
        cap.release()

    if not ok:
        print("Error: Could not load video")
        return

    first_frame = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
    first_frame = cv2.resize(first_frame, (model.image_size, model.image_size))
    first_tensor = torch.from_numpy(first_frame).permute(2, 0, 1).float() / 255.0
    first_tensor = first_tensor.unsqueeze(0).to(device)

    # Generate
    with torch.no_grad():
        generated = model(first_tensor)

    # Save (clamp so the uint8 cast cannot wrap)
    out_frames = []
    for t in range(generated.shape[1]):
        frame = generated[0, t].clamp(0, 1).cpu()
        frame = (frame.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
        out_frames.append(Image.fromarray(frame))

    out_frames[0].save(
        output_path,
        save_all=True,
        append_images=out_frames[1:],
        duration=100,
        loop=0
    )
    print(f"✓ Saved to {output_path}")

# ============================================================================
# RUN EVERYTHING
# ============================================================================

if __name__ == "__main__":
    # Run exactly one entry point; swap the active call for an alternative:
    #   quick_test()                                          -> forward-pass smoke test, no training
    #   test_on_real_video("input_video.mp4", "output.gif")   -> animate the first frame of a real clip
    main()  # full pipeline: setup, train on synthetic data, run inference

✓ Environment setup complete
✓ GPU: Tesla T4
✓ Memory: 15.83 GB

FastEdit: Lightweight SVD Training

1. Creating model...
✓ Model created: 2.57M parameters

2. Creating dataset...
✓ Dataset created: 200 samples

3. Training model...
Starting training...


Epoch 1/5:   0%|          | 0/50 [00:01<?, ?it/s]


RuntimeError: The size of tensor a (16) must match the size of tensor b (128) at non-singleton dimension 3

In [None]:
# Cell 1: Install packages
!pip install -q diffusers==0.24.0 transformers accelerate
!pip install -q opencv-python pillow tqdm
!pip install -q torch torchvision --index-url https://download.pytorch.org/whl/cu118

import torch
print(f"PyTorch version: {torch.__version__}")
print(f"GPU available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.8/1.8 MB 33.2 MB/s eta 0:00:00
PyTorch version: 2.8.0+cu126
GPU available: True
GPU: Tesla T4


In [None]:
# Cell 2: Get DAVIS dataset - just 5 sequences for testing
import os
import shutil

# Download DAVIS if not present
if not os.path.exists("DAVIS_small"):
    print("Downloading DAVIS dataset...")
    !wget -q https://data.vision.ee.ethz.ch/csergi/share/davis/DAVIS-2017-trainval-480p.zip
    !unzip -q DAVIS-2017-trainval-480p.zip

    # Create small subset (5 sequences only)
    os.makedirs("DAVIS_small/JPEGImages/480p", exist_ok=True)
    os.makedirs("DAVIS_small/Annotations/480p", exist_ok=True)

    # Copy only these sequences (good for object removal)
    sequences = ['blackswan', 'car-shadow', 'dog', 'parkour', 'scooter-black']

    for seq in sequences:
        print(f"Copying {seq}...")
        shutil.copytree(f"DAVIS/JPEGImages/480p/{seq}",
                       f"DAVIS_small/JPEGImages/480p/{seq}")
        shutil.copytree(f"DAVIS/Annotations/480p/{seq}",
                       f"DAVIS_small/Annotations/480p/{seq}")

    # Clean up large download
    !rm -rf DAVIS DAVIS-2017-trainval-480p.zip

print("✓ Dataset ready!")
print(f"Sequences available: {os.listdir('DAVIS_small/JPEGImages/480p/')}")

Downloading DAVIS dataset...
Copying blackswan...
Copying car-shadow...
Copying dog...
Copying parkour...
Copying scooter-black...
✓ Dataset ready!
Sequences available: ['blackswan', 'scooter-black', 'car-shadow', 'dog', 'parkour']


In [None]:
# Cell 3: Lightweight SVD model
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
import cv2
from pathlib import Path
from tqdm import tqdm

class MiniSVD(nn.Module):
    """Ultra-lightweight SVD-style model for testing.

    Encodes the edited first frame to a 128-channel latent at 1/8 resolution,
    repeats it over ``num_frames`` time steps, mixes the stack with a single
    3D conv and decodes every time step back to RGB in [0, 1] (sigmoid).
    Frame 0 of the output is the edited input itself.

    ``image_size`` is stored for reference; any H, W divisible by 8
    round-trips to the input resolution.
    """

    def __init__(self, num_frames=8, image_size=128):
        super().__init__()
        self.num_frames = num_frames
        self.image_size = image_size  # previously accepted but silently dropped

        # Simple encoder (image -> latent)
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2, padding=1),  # 128->64
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),  # 64->32
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1), # 32->16
            nn.ReLU(),
        )

        # Temporal generator (latent -> video latents)
        self.temporal = nn.Conv3d(128, 128, (3,3,3), padding=1)

        # Simple decoder (latent -> image)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),  # 16->32
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),   # 32->64
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1),    # 64->128
            nn.Sigmoid()
        )

    def forward(self, edited_frame):
        """
        Args:
            edited_frame: [B, 3, H, W]
        Returns:
            video: [B, num_frames, 3, H, W] with ``video[:, 0]`` equal to
            ``edited_frame``.
        """
        # Encode first frame
        z = self.encoder(edited_frame)  # [B, 128, H/8, W/8]

        # Expand temporally
        z_video = z.unsqueeze(2).repeat(1, 1, self.num_frames, 1, 1)  # [B, 128, T, h, w]
        z_video = self.temporal(z_video)

        # The decoder treats each time step independently, so all frames are
        # decoded in one batched call instead of a Python loop over T.
        b, c, t, h, w = z_video.shape
        flat = z_video.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
        decoded = self.decoder(flat)
        decoded = decoded.reshape(b, t, *decoded.shape[1:])  # [B, T, 3, H, W]

        # Keep the edited input as frame 0 (its decoded version is discarded).
        # cat avoids writing in place into a tensor that autograd produced.
        return torch.cat([edited_frame.unsqueeze(1), decoded[:, 1:]], dim=1)

print("✓ Model defined")

✓ Model defined


In [None]:
# Cell 4: Dataset loader for DAVIS
class DAVISDataset(Dataset):
    """Short RGB clips from a DAVIS-style folder.

    Expects ``<davis_path>/JPEGImages/480p/<sequence>/*.jpg``. Each item is
    built from the first ``num_frames`` frames of one sequence (the last frame
    is repeated if the sequence is shorter), resized to ``image_size`` x
    ``image_size`` with values in [0, 1], plus a synthetic "edited" first frame.
    """

    def __init__(self, davis_path="DAVIS_small", num_frames=8, image_size=128):
        self.davis_path = Path(davis_path)
        self.num_frames = num_frames
        self.image_size = image_size

        # Sorted so sample order (and dataset[0]) does not depend on the
        # filesystem's iterdir() order; stray files are skipped.
        frames_root = self.davis_path / "JPEGImages/480p"
        self.sequences = sorted(p for p in frames_root.iterdir() if p.is_dir())
        print(f"Found {len(self.sequences)} sequences")

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return {'video': [T, 3, S, S], 'edited_frame': [3, S, S], 'seq_name': str}."""
        seq_path = self.sequences[idx]

        # Load frames
        frame_files = sorted(seq_path.glob("*.jpg"))[:self.num_frames]
        if not frame_files:
            # Without this, padding below failed with an opaque IndexError.
            raise FileNotFoundError(f"No .jpg frames found in {seq_path}")

        frames = []
        for f in frame_files:
            with Image.open(f) as src:  # close the file handle promptly
                img = src.convert('RGB').resize((self.image_size, self.image_size))
            frames.append(np.array(img) / 255.0)

        # Pad short sequences by repeating the last frame
        frames.extend([frames[-1]] * (self.num_frames - len(frames)))

        frames = np.stack(frames)  # [T, H, W, 3]

        # Create "edited" first frame (simple color shift for testing): darken
        # the top 50 rows. Rows are axis 0 of an [H, W, 3] frame; the previous
        # [:, :50, :] slice darkened the left 50 columns instead.
        edited_frame = frames[0].copy()
        edited_frame[:50] = edited_frame[:50] * 0.5

        # Convert to tensors
        frames_tensor = torch.from_numpy(frames).permute(0, 3, 1, 2).float()
        edited_tensor = torch.from_numpy(edited_frame).permute(2, 0, 1).float()

        return {
            'video': frames_tensor,
            'edited_frame': edited_tensor,
            'seq_name': seq_path.name
        }

# Build the DAVIS dataset and a one-sequence-per-batch loader; the training and
# test cells below read these globals.
dataset = DAVISDataset()
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

print(f"✓ Dataset ready with {len(dataset)} sequences")

Found 5 sequences
✓ Dataset ready with 5 sequences


In [None]:
# Cell 5: Training
def train_mini_svd(num_epochs=5, lr=1e-3, loader=None, save_every=2):
    """Train a fresh MiniSVD to reproduce each clip from its edited first frame.

    Args:
        num_epochs: number of passes over the data (default 5).
        lr: Adam learning rate (default 1e-3).
        loader: DataLoader-like iterable supporting len(), yielding dicts with
            'edited_frame' [B, 3, H, W] and 'video' [B, T, 3, H, W]. Defaults
            to the global ``dataloader`` built in the dataset cell.
        save_every: positive int. ``model_epoch_{epoch}.pt`` is saved when the
            0-based epoch is a multiple of it; the final epoch is always saved.

    Returns:
        The trained MiniSVD.
    """
    if loader is None:
        loader = dataloader

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Training on: {device}")

    # Initialize model
    model = MiniSVD(num_frames=8, image_size=128).to(device)

    # Count parameters
    params = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {params/1e6:.2f}M")

    # Setup training
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    # Training loop
    print(f"\nTraining for {num_epochs} epochs...")
    num_batches = len(loader)

    for epoch in range(num_epochs):
        total_loss = 0

        for batch_idx, batch in enumerate(loader):
            # Get data
            edited = batch['edited_frame'].to(device)
            target = batch['video'].to(device)

            # Forward pass
            output = model(edited)

            # Compute loss
            loss = criterion(output, target)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            # Print progress
            if batch_idx % 2 == 0:
                print(f"  Epoch {epoch+1}/{num_epochs}, Batch {batch_idx+1}/{num_batches}, Loss: {loss.item():.4f}")

        # Epoch summary (max() avoids dividing by zero on an empty loader)
        avg_loss = total_loss / max(num_batches, 1)
        print(f"Epoch {epoch+1} - Average Loss: {avg_loss:.4f}")

        # Save checkpoint; always keep the final epoch's weights
        if epoch % save_every == 0 or epoch == num_epochs - 1:
            torch.save(model.state_dict(), f"model_epoch_{epoch}.pt")
            print(f"  Saved checkpoint: model_epoch_{epoch}.pt")

    print("\n✓ Training complete!")
    return model

# Train with the default settings; later cells use the global `model`.
model = train_mini_svd()

Training on: cuda
Model parameters: 0.77M

Training for 5 epochs...
  Epoch 1/5, Batch 1/5, Loss: 0.0446
  Epoch 1/5, Batch 3/5, Loss: 0.0566
  Epoch 1/5, Batch 5/5, Loss: 0.0339
Epoch 1 - Average Loss: 0.0486
  Saved checkpoint: model_epoch_0.pt
  Epoch 2/5, Batch 1/5, Loss: 0.0339
  Epoch 2/5, Batch 3/5, Loss: 0.0561
  Epoch 2/5, Batch 5/5, Loss: 0.0583
Epoch 2 - Average Loss: 0.0476
  Epoch 3/5, Batch 1/5, Loss: 0.0329
  Epoch 3/5, Batch 3/5, Loss: 0.0390
  Epoch 3/5, Batch 5/5, Loss: 0.0465
Epoch 3 - Average Loss: 0.0458
  Saved checkpoint: model_epoch_2.pt
  Epoch 4/5, Batch 1/5, Loss: 0.0556
  Epoch 4/5, Batch 3/5, Loss: 0.0496
  Epoch 4/5, Batch 5/5, Loss: 0.0329
Epoch 4 - Average Loss: 0.0431
  Epoch 5/5, Batch 1/5, Loss: 0.0457
  Epoch 5/5, Batch 3/5, Loss: 0.0314
  Epoch 5/5, Batch 5/5, Loss: 0.0516
Epoch 5 - Average Loss: 0.0412
  Saved checkpoint: model_epoch_4.pt

✓ Training complete!


In [None]:
# Cell 6: Test generation
def test_model(model, sample_dataset=None):
    """Generate a clip from the first dataset item and save it under output/.

    Writes output/frame_XXX.png for every frame and output/generated.gif
    (100 ms per frame, looping).

    Args:
        model: a MiniSVD-like module taking edited_frame [B, 3, H, W] and
            returning [B, T, 3, H, W].
        sample_dataset: dataset to draw item 0 from; defaults to the global
            ``dataset`` built in the dataset cell.

    Returns:
        The generated video tensor [1, T, 3, H, W].
    """
    import os

    if sample_dataset is None:
        sample_dataset = dataset

    # Use the device the model actually lives on rather than guessing.
    device = next(model.parameters()).device
    model.eval()

    print("Testing model on first sequence...")

    # Get one sample
    sample = sample_dataset[0]
    edited = sample['edited_frame'].unsqueeze(0).to(device)

    # Generate video
    with torch.no_grad():
        generated = model(edited)

    print(f"Generated video shape: {generated.shape}")

    os.makedirs("output", exist_ok=True)

    # Convert each frame once and reuse it for both the PNGs and the GIF.
    # Clamp before the uint8 cast so out-of-range values cannot wrap.
    frames = []
    for t, frame in enumerate(generated[0].clamp(0, 1).cpu()):
        array = (frame.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
        img = Image.fromarray(array)
        img.save(f"output/frame_{t:03d}.png")
        frames.append(img)

    frames[0].save(
        "output/generated.gif",
        save_all=True,
        append_images=frames[1:],
        duration=100,
        loop=0
    )

    print("✓ Saved output to 'output/' folder")
    print("  - Individual frames: output/frame_XXX.png")
    print("  - Animation: output/generated.gif")

    return generated

# Run the trained model on the first sequence and keep the result in `output`.
output = test_model(model)

Testing model on first sequence...
Generated video shape: torch.Size([1, 8, 3, 128, 128])
✓ Saved output to 'output/' folder
  - Individual frames: output/frame_XXX.png
  - Animation: output/generated.gif


In [None]:
# Cell 1: Better Approach - Use Pretrained SVD
!pip install -q diffusers==0.24.0 transformers accelerate xformers

import torch
from diffusers import StableVideoDiffusionPipeline, I2VGenXLPipeline
from PIL import Image
import numpy as np
import cv2

# Use a smaller, faster model
def setup_working_svd():
    """Load the pretrained I2VGen-XL image-to-video pipeline.

    On CUDA the weights are loaded in fp16 and ``enable_model_cpu_offload()``
    handles device placement, moving each sub-model to the GPU only while it
    runs. The pipeline is therefore NOT moved with ``.to("cuda")`` first: the
    diffusers docs advise against combining the two, and the full-model
    transfer is at best wasted work. Without CUDA, offloading is unavailable,
    so the fp16 weights are upcast to fp32 and the pipeline stays on the CPU.

    NOTE: requires a diffusers release that provides ``I2VGenXLPipeline`` and
    is compatible with the installed ``huggingface_hub``. The recorded run
    failed at import time with ``cannot import name 'cached_download'``.

    Returns:
        The loaded ``I2VGenXLPipeline``.
    """
    use_cuda = torch.cuda.is_available()

    print("Loading pretrained video model...")
    pipe = I2VGenXLPipeline.from_pretrained(
        "ali-vilab/i2vgen-xl",
        torch_dtype=torch.float16 if use_cuda else torch.float32,
        variant="fp16"
    )

    # Memory optimizations: CPU offload needs an accelerator; VAE slicing
    # works on any device.
    if use_cuda:
        pipe.enable_model_cpu_offload()
    pipe.enable_vae_slicing()

    return pipe

# NOTE: in the recorded run execution never reached this point; the diffusers
# import at the top of the cell raised ImportError (`cached_download` missing
# from huggingface_hub).
pipe = setup_working_svd()
print("✓ Model ready!")

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m117.2/117.2 MB[0m [31m8.6 MB/s[0m eta [36m0:00:00[0m
[?25h

ImportError: cannot import name 'cached_download' from 'huggingface_hub' (/usr/local/lib/python3.12/dist-packages/huggingface_hub/__init__.py)

In [None]:
# Cell 1: Complete reset and proper installation
import os
import sys

# Restart runtime after this cell. The uninstall removes the Hugging Face
# stack (diffusers, huggingface-hub, transformers, accelerate, peft,
# sentence-transformers), presumably so the next cell can install a mutually
# compatible set of pinned versions.
# NOTE(review): torch/torchvision are requested from the cu118 wheel index,
# but the runtime later reports torch 2.8.0+cu126. Confirm which build is
# actually installed and that it matches the GPU driver.
!pip uninstall -y diffusers huggingface-hub transformers accelerate peft sentence-transformers
!pip install --upgrade pip
!pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
!pip install pillow opencv-python numpy matplotlib tqdm

print("✓ Base packages installed")
print("⚠ Please restart runtime now: Runtime -> Restart runtime")

[0mFound existing installation: huggingface-hub 0.19.4
Uninstalling huggingface-hub-0.19.4:
  Successfully uninstalled huggingface-hub-0.19.4
[0mFound existing installation: accelerate 1.11.0
Uninstalling accelerate-1.11.0:
  Successfully uninstalled accelerate-1.11.0
Found existing installation: peft 0.17.1
Uninstalling peft-0.17.1:
  Successfully uninstalled peft-0.17.1
Found existing installation: sentence-transformers 5.1.2
Uninstalling sentence-transformers-5.1.2:
  Successfully uninstalled sentence-transformers-5.1.2
Collecting pip
  Downloading pip-25.3-py3-none-any.whl.metadata (4.7 kB)
Downloading pip-25.3-py3-none-any.whl (1.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m32.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 24.1.2
    Uninstalling pip-24.1.2:
      Successfully uninstalled pip-24.1.2
Successfully installed pip-25.3
Looking in index

In [None]:
# Cell 2: Install diffusion packages with correct versions
!pip install -q transformers==4.36.0
!pip install -q huggingface-hub==0.20.0
!pip install -q diffusers==0.25.0
!pip install -q accelerate==0.25.0

# Smoke-test the freshly pinned stack: import torch and diffusers, then report
# the torch build and CUDA visibility. Import failures are printed, not raised.
try:
    import torch
    from diffusers import DiffusionPipeline, DDPMScheduler
except ImportError as e:
    print(f"Error: {e}")
else:
    print("✓ Imports successful!")
    print(f"✓ Torch version: {torch.__version__}")
    print(f"✓ CUDA available: {torch.cuda.is_available()}")

[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
datasets 4.0.0 requires huggingface-hub>=0.24.0, but you have huggingface-hub 0.20.0 which is incompatible.
gradio 5.49.1 requires huggingface-hub<2.0,>=0.33.5, but you have huggingface-hub 0.20.0 which is incompatible.[0m[31m
[0m

  _torch_pytree._register_pytree_node(
The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


0it [00:00, ?it/s]

  _torch_pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(


✓ Imports successful!
✓ Torch version: 2.8.0+cu126
✓ CUDA available: True


In [None]:
# Cell 1: Install and import
!pip install -q diffusers transformers accelerate
!pip install -q torch torchvision
!pip install -q opencv-python pillow tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from diffusers import StableVideoDiffusionPipeline
from PIL import Image, ImageDraw
import numpy as np
import os
from tqdm import tqdm
import gc

# Report the torch build and the accelerator name ('CPU' when no GPU is visible).
print(f"✓ PyTorch: {torch.__version__}")
print("✓ GPU:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU')

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(


✓ PyTorch: 2.8.0+cu126
✓ GPU: Tesla T4


In [None]:
# Cell 2: Load pretrained SVD
def load_svd_model():
    """Load the pretrained Stable Video Diffusion (img2vid) pipeline.

    On CUDA the fp16 weights are used and ``enable_model_cpu_offload()`` is
    enabled, which moves each sub-model to the GPU only while it runs. The
    pipeline is deliberately not moved with ``.to("cuda")`` beforehand: the
    diffusers docs advise against combining the two, and the original call
    also crashed on machines without CUDA. Without CUDA the fp16 weights are
    upcast to fp32 and the pipeline stays on the CPU.

    Returns:
        The loaded ``StableVideoDiffusionPipeline``.
    """
    print("Loading SVD model...")
    use_cuda = torch.cuda.is_available()

    # Load pretrained model
    pipe = StableVideoDiffusionPipeline.from_pretrained(
        "stabilityai/stable-video-diffusion-img2vid",
        torch_dtype=torch.float16 if use_cuda else torch.float32,
        variant="fp16"
    )

    if use_cuda:
        pipe.enable_model_cpu_offload()

    print("✓ SVD loaded successfully!")
    return pipe

# Load model (weights are fetched from the Hugging Face Hub by repo id; see
# the HF_TOKEN notice in the output below).
pipe = load_svd_model()

Loading SVD model...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Loading pipeline components...:   0%|          | 0/5 [00:00<?, ?it/s]

✓ SVD loaded successfully!


In [None]:
# Cell 1: Load DAVIS dataset
import os
import cv2
from pathlib import Path
from PIL import Image
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch

class DAVISDataset(Dataset):
    """DAVIS sequences prepared for first-frame object removal.

    Each item is a dict holding:
      - the sequence's first frame,
      - a copy of it with the annotated object removed by OpenCV Telea
        inpainting,
      - the boolean object mask,
      - the first ``num_frames`` frames as uint8 arrays.
    All images are resized to 512x288 (width x height).

    NOTE: items contain PIL images, which PyTorch's default ``collate_fn``
    cannot batch. Pass a custom ``collate_fn`` when wrapping this in a
    ``DataLoader``.
    """

    # (width, height), as expected by PIL's ``Image.resize``.
    FRAME_SIZE = (512, 288)

    def __init__(self, davis_path="DAVIS", sequences=None, num_frames=8):
        self.davis_path = Path(davis_path)
        self.num_frames = num_frames

        image_root = self.davis_path / "JPEGImages/480p"
        if sequences is None:
            # Use all available sequences. Sort them so item indices are
            # reproducible (``iterdir`` order depends on the filesystem).
            self.sequences = sorted(p for p in image_root.iterdir() if p.is_dir())
        else:
            # Use specified sequences
            self.sequences = [image_root / s for s in sequences]

        # Filter to only existing sequences
        self.sequences = [s for s in self.sequences if s.exists()]

        print(f"Found {len(self.sequences)} sequences:")
        for seq in self.sequences[:5]:  # Show first 5
            print(f"  - {seq.name}")

    def __len__(self):
        return len(self.sequences)

    def _load_rgb(self, path):
        """Open ``path``, convert to RGB, resize to FRAME_SIZE, and close the file."""
        with Image.open(path) as im:
            return im.convert("RGB").resize(self.FRAME_SIZE)

    def __getitem__(self, idx):
        seq_path = self.sequences[idx]
        mask_path = self.davis_path / "Annotations/480p" / seq_path.name
        width, height = self.FRAME_SIZE

        # Load frames
        frame_files = sorted(seq_path.glob("*.jpg"))[:self.num_frames]
        if not frame_files:
            raise FileNotFoundError(f"No .jpg frames found in {seq_path}")
        mask_files = sorted(mask_path.glob("*.png"))[:self.num_frames] if mask_path.exists() else []

        # Decode every frame once; the first one doubles as the frame to edit.
        # RGB conversion guarantees the 3-channel uint8 input cv2.inpaint needs.
        resized_frames = [self._load_rgb(f) for f in frame_files]
        first_frame = resized_frames[0]

        if mask_files:
            with Image.open(mask_files[0]) as im:
                # NEAREST keeps label edges hard. PIL's default filter for 'L'
                # images is bicubic, which creates soft edge values that the
                # ``> 0`` threshold below would turn into a dilated mask.
                first_mask = im.convert('L').resize(self.FRAME_SIZE, Image.NEAREST)
            mask_array = np.array(first_mask) > 0
        else:
            # No mask available, create dummy
            mask_array = np.zeros((height, width), dtype=bool)

        # Create edited frame by inpainting
        frame_array = np.array(first_frame)
        if mask_array.any():
            # Use OpenCV inpainting to remove object
            edited_array = cv2.inpaint(
                frame_array,
                mask_array.astype(np.uint8) * 255,
                3,
                cv2.INPAINT_TELEA
            )
            edited_frame = Image.fromarray(edited_array)
        else:
            edited_frame = first_frame

        return {
            'edited_frame': edited_frame,      # First frame with object removed
            'original_frame': first_frame,     # Original first frame
            'mask': mask_array,                # Object mask
            'all_frames': [np.array(frame) for frame in resized_frames],  # All video frames
            'sequence_name': seq_path.name
        }

# Check if DAVIS exists, if not download a sample.
# NOTE(review): the guard only checks that a "DAVIS" directory exists, not
# that the three sequences below are inside it.
if not os.path.exists("DAVIS"):
    print("Downloading DAVIS samples...")
    os.makedirs("DAVIS/JPEGImages/480p", exist_ok=True)
    os.makedirs("DAVIS/Annotations/480p", exist_ok=True)

    # Download the full DAVIS-2017 trainval 480p archive, extract only three
    # sequences (frames + annotations), then delete the archive.
    !wget -q -O sample.zip "https://data.vision.ee.ethz.ch/csergi/share/davis/DAVIS-2017-trainval-480p.zip"
    !unzip -q sample.zip "DAVIS/JPEGImages/480p/blackswan/*" "DAVIS/Annotations/480p/blackswan/*"
    !unzip -q sample.zip "DAVIS/JPEGImages/480p/car-shadow/*" "DAVIS/Annotations/480p/car-shadow/*"
    !unzip -q sample.zip "DAVIS/JPEGImages/480p/dog/*" "DAVIS/Annotations/480p/dog/*"
    !rm sample.zip

# Create dataset
dataset = DAVISDataset(
    davis_path="DAVIS",
    sequences=['blackswan', 'car-shadow', 'dog'],  # Use specific sequences
    num_frames=8
)

# NOTE(review): this DAVISDataset returns PIL images ('edited_frame',
# 'original_frame'), which the default collate_fn cannot batch, so iterating
# this loader fails. A custom collate_fn is required.
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
print(f"✓ Dataset ready: {len(dataset)} sequences")

Downloading DAVIS samples...
Found 3 sequences:
  - blackswan
  - car-shadow
  - dog
✓ Dataset ready: 3 sequences


In [None]:
# Cell 3: Load SVD and prepare for training
from diffusers import StableVideoDiffusionPipeline
import torch.nn as nn

# Load pretrained SVD (fp16 weights). enable_model_cpu_offload() manages device
# placement itself, moving each sub-model to the GPU only while it runs. The
# pipeline is therefore deliberately not moved with .to("cuda") first; the
# diffusers docs advise against combining the two.
pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid",
    torch_dtype=torch.float16,
    variant="fp16"
)
pipe.enable_model_cpu_offload()

# Prepare for FastEdit training (freeze 85%)
def prepare_svd_for_training(pipe, freeze_ratio=0.85, trainable_keywords=None):
    """Freeze the VAE and the leading ``freeze_ratio`` of UNet parameters in place.

    UNet parameters are visited in ``named_parameters()`` order. Each one is
    frozen while the running count of frozen elements is still below
    ``freeze_ratio`` of the total. Every parameter after that prefix is made
    trainable, or, when ``trainable_keywords`` is given (e.g.
    ``("attn", "norm")``), only those whose name contains one of the keywords.

    The default (``trainable_keywords=None``) reproduces what the previous
    version actually did. Its ``elif 'attn' in name or 'norm' in name`` branch
    only re-enabled parameters that were already trainable, so the whole tail
    stayed trainable (~15%, as the recorded output shows).

    Args:
        pipe: pipeline exposing ``.unet`` and ``.vae`` (``nn.Module``s).
        freeze_ratio: fraction of UNet elements to freeze, counted from the start.
        trainable_keywords: optional iterable of substrings selecting which
            tail parameters stay trainable.

    Returns:
        ``pipe`` (modified in place).
    """
    unet = pipe.unet

    # Freeze VAE completely
    for param in pipe.vae.parameters():
        param.requires_grad = False

    total_params = sum(p.numel() for p in unet.parameters())
    freeze_budget = total_params * freeze_ratio
    frozen = 0

    for name, param in unet.named_parameters():
        if frozen < freeze_budget:
            param.requires_grad = False
            frozen += param.numel()
        elif trainable_keywords is None:
            param.requires_grad = True
        else:
            param.requires_grad = any(key in name for key in trainable_keywords)

    trainable = sum(p.numel() for p in unet.parameters() if p.requires_grad)
    print(f"✓ Trainable params: {trainable/1e6:.2f}M / {total_params/1e6:.2f}M")

    return pipe

# Freeze the VAE and most of the UNet in place; the training loop below only
# optimizes UNet parameters that still have requires_grad=True.
pipe = prepare_svd_for_training(pipe)

Loading pipeline components...:   0%|          | 0/5 [00:00<?, ?it/s]

✓ Trainable params: 228.61M / 1524.62M


In [None]:
# Cell 1: Fixed DAVIS Dataset that returns tensors
import torch
from torchvision import transforms

class DAVISDataset(Dataset):
    """DAVIS first frames for object removal, returned as tensors and PIL images.

    Each item contains:
      - the sequence's first frame,
      - a copy with the annotated object removed by OpenCV Telea inpainting,
      - the boolean object mask.
    Images come both as ``ToTensor`` outputs (for batching) and as PIL images
    (for the diffusers pipeline), at 512x288 (width x height).
    """

    # (width, height), as expected by PIL's ``Image.resize``.
    FRAME_SIZE = (512, 288)

    def __init__(self, davis_path="DAVIS", sequences=None, num_frames=8):
        self.davis_path = Path(davis_path)
        self.num_frames = num_frames

        # Transform to convert PIL to tensor
        self.transform = transforms.Compose([
            transforms.Resize((288, 512)),
            transforms.ToTensor()
        ])

        image_root = self.davis_path / "JPEGImages/480p"
        if sequences is None:
            # Sorted so item indices are reproducible across filesystems.
            self.sequences = sorted(p for p in image_root.iterdir() if p.is_dir())
        else:
            self.sequences = [image_root / s for s in sequences]

        self.sequences = [s for s in self.sequences if s.exists()]
        print(f"Found {len(self.sequences)} sequences")

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq_path = self.sequences[idx]
        mask_path = self.davis_path / "Annotations/480p" / seq_path.name
        width, height = self.FRAME_SIZE

        # Load frames
        frame_files = sorted(seq_path.glob("*.jpg"))[:self.num_frames]
        if not frame_files:
            raise FileNotFoundError(f"No .jpg frames found in {seq_path}")
        mask_files = sorted(mask_path.glob("*.png"))[:self.num_frames] if mask_path.exists() else []

        # Load first frame as RGB (cv2.inpaint needs 3 x uint8 channels);
        # ``with`` closes the file handle.
        with Image.open(frame_files[0]) as im:
            first_frame = im.convert("RGB").resize(self.FRAME_SIZE)

        if mask_files:
            with Image.open(mask_files[0]) as im:
                # NEAREST keeps label edges hard. PIL's default bicubic filter
                # for 'L' images produces soft edge values that the ``> 0``
                # threshold would turn into a dilated mask.
                first_mask = im.convert('L').resize(self.FRAME_SIZE, Image.NEAREST)
            mask_array = np.array(first_mask) > 0
        else:
            mask_array = np.zeros((height, width), dtype=bool)

        # Create edited frame by inpainting
        frame_array = np.array(first_frame)
        if mask_array.any():
            edited_array = cv2.inpaint(
                frame_array,
                mask_array.astype(np.uint8) * 255,
                3,
                cv2.INPAINT_TELEA
            )
            edited_frame = Image.fromarray(edited_array)
        else:
            edited_frame = first_frame

        # Tensors for batching, PIL copies for the pipeline.
        return {
            'edited_frame': self.transform(edited_frame),      # [3, 288, 512] tensor
            'original_frame': self.transform(first_frame),     # [3, 288, 512] tensor
            'edited_frame_pil': edited_frame,                  # Keep PIL for pipeline
            'original_frame_pil': first_frame,                 # Keep PIL for pipeline
            'mask': torch.tensor(mask_array),                  # bool [288, 512]
            'sequence_name': seq_path.name                     # String is fine
        }

# Create dataset with fixed version (tensor fields plus PIL copies for the
# pipeline). Requested sequences missing on disk are dropped by the
# constructor, which prints how many were found.
dataset = DAVISDataset(
    davis_path="DAVIS",
    sequences=['blackswan', 'car-shadow', 'dog'],
    num_frames=8
)

# Custom collate function to handle mixed types
def custom_collate(batch):
    """Collate DAVIS samples of mixed types.

    Tensor fields are stacked along a new batch dimension. PIL images and
    sequence names are returned as plain lists. Keys keep the same order as
    the per-sample dicts.
    """
    # (key, stack?) pairs: True -> torch.stack, False -> keep as a list.
    fields = (
        ('edited_frame', True),
        ('original_frame', True),
        ('edited_frame_pil', False),
        ('original_frame_pil', False),
        ('mask', True),
        ('sequence_name', False),
    )
    collated = {}
    for key, stack in fields:
        values = [sample[key] for sample in batch]
        collated[key] = torch.stack(values) if stack else values
    return collated

# Create dataloader with custom collate (the default collate cannot batch the
# PIL fields). batch_size=1 because train_on_davis only reads element [0] of
# each batch.
dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=custom_collate)

print(f"✓ Fixed dataset ready: {len(dataset)} sequences")

Found 3 sequences
✓ Fixed dataset ready: 3 sequences


In [None]:
# Cell 2: Fixed training loop
def train_on_davis(pipe, dataloader, num_epochs=2, height=320, width=512,
                   num_frames=4, num_inference_steps=10):
    """Temporal-consistency loop over DAVIS edited first frames.

    For each batch (batch size 1 assumed), the pipeline generates a short clip
    from the edited first frame (PIL). The loss is 0.1 times the summed MSE
    between consecutive generated frames.

    Important: the generated frames come back as PIL images and are converted
    with ``np.array`` / ``torch.tensor``, so the loss has no autograd path to
    ``pipe.unet``. (The diffusers pipeline call is also documented to run
    without gradient tracking.) The optimizer therefore only steps when the
    loss actually requires grad. Otherwise the loss is reported as a metric
    and a one-time warning is printed, instead of silently calling
    ``backward()`` on a detached leaf.

    Args:
        pipe: image-to-video pipeline with a ``.unet`` attribute.
        dataloader: yields dicts with ``'edited_frame_pil'`` and
            ``'sequence_name'`` lists (e.g. via ``custom_collate``).
        num_epochs: number of passes over ``dataloader``.
        height, width: generation size; both must be multiples of 64. The
            UNet downsamples the 1/8-resolution latent three times, and the
            previous 288-pixel height produced the "Expected size 10 but got
            size 9" skip-connection mismatch on every batch.
        num_frames, num_inference_steps: forwarded to ``pipe``.

    Returns:
        ``pipe``.

    Raises:
        ValueError: if ``height`` or ``width`` is not a multiple of 64.
    """
    if height % 64 or width % 64:
        raise ValueError(f"height and width must be multiples of 64, got {height}x{width}")

    trainable_params = [p for p in pipe.unet.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=5e-6)
    mse_loss = nn.MSELoss()
    warned_no_grad = False

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        total_loss = 0.0
        num_ok = 0

        for batch_idx, batch in enumerate(tqdm(dataloader)):
            # Get PIL image for the pipeline (it expects PIL)
            edited_frame_pil = batch['edited_frame_pil'][0]
            sequence_name = batch['sequence_name'][0]

            try:
                frames = pipe(
                    edited_frame_pil,
                    num_frames=num_frames,
                    num_inference_steps=num_inference_steps,
                    height=height,
                    width=width,
                    # Private generator: same seed-42 noise on every call
                    # without reseeding torch's global RNG.
                    generator=torch.Generator().manual_seed(42),
                ).frames[0]

                # Temporal consistency between consecutive frames
                loss = torch.zeros(())
                for prev_img, curr_img in zip(frames, frames[1:]):
                    curr = torch.tensor(np.array(curr_img)).float() / 255.0
                    prev = torch.tensor(np.array(prev_img)).float() / 255.0
                    loss = loss + mse_loss(curr, prev) * 0.1

                if loss.requires_grad:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                elif not warned_no_grad:
                    print("  Warning: loss has no gradient path to the UNet; "
                          "reporting it as a metric only (no optimizer step).")
                    warned_no_grad = True

                total_loss += loss.item()
                num_ok += 1
                print(f"  Batch {batch_idx}: {sequence_name}, Loss: {loss.item():.4f}")

                # Save sample
                if batch_idx == 0:
                    frames[0].save(f"epoch_{epoch}_sample.png")

            except Exception as e:  # best effort: keep going, but say what failed
                print(f"  Error on {sequence_name}: {type(e).__name__}: {e}")
            finally:
                # Free memory after every batch, including failed ones.
                torch.cuda.empty_cache()
                gc.collect()

        if num_ok:
            print(f"  Average loss: {total_loss/num_ok:.4f} "
                  f"({num_ok}/{len(dataloader)} batches succeeded)")
        else:
            print(f"  Average loss: n/a (all {len(dataloader)} batches failed)")

    return pipe

# Train with fixed dataset.
# NOTE(review): train_on_davis computes its loss from PIL frames converted via
# np.array, so confirm that gradients actually reach pipe.unet before trusting
# this as training.
pipe = train_on_davis(pipe, dataloader, num_epochs=2)


Epoch 1/2


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

 33%|███▎      | 1/3 [00:02<00:05,  2.71s/it]

  Error: Sizes of tensors must match except in dimension 1. Expected size 10 but got size 9 for tensor number 1 in the list.


  0%|          | 0/10 [00:00<?, ?it/s]

 67%|██████▋   | 2/3 [00:05<00:02,  2.71s/it]

  Error: Sizes of tensors must match except in dimension 1. Expected size 10 but got size 9 for tensor number 1 in the list.


  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 3/3 [00:08<00:00,  2.71s/it]


  Error: Sizes of tensors must match except in dimension 1. Expected size 10 but got size 9 for tensor number 1 in the list.
  Average loss: 0.0000

Epoch 2/2


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

 33%|███▎      | 1/3 [00:02<00:05,  2.94s/it]

  Error: Sizes of tensors must match except in dimension 1. Expected size 10 but got size 9 for tensor number 1 in the list.


  0%|          | 0/10 [00:00<?, ?it/s]

 67%|██████▋   | 2/3 [00:05<00:02,  2.81s/it]

  Error: Sizes of tensors must match except in dimension 1. Expected size 10 but got size 9 for tensor number 1 in the list.


  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 3/3 [00:08<00:00,  2.79s/it]

  Error: Sizes of tensors must match except in dimension 1. Expected size 10 but got size 9 for tensor number 1 in the list.
  Average loss: 0.0000





In [None]:
# Cell 5: Test the trained model
def test_on_davis_sequence(pipe, dataset, seq_idx=0, height=320, width=512,
                           num_frames=8, num_inference_steps=15):
    """Generate a clip for one DAVIS sample; save a GIF and a comparison plot.

    Writes ``test_<seq>.gif`` and ``comparison_<seq>.png`` to the working dir.

    Args:
        pipe: image-to-video pipeline (e.g. ``StableVideoDiffusionPipeline``).
        dataset: either DAVIS dataset variant from this notebook. The PIL
            ``'edited_frame_pil'`` is preferred; a CHW tensor
            ``'edited_frame'`` is converted to PIL. ``'all_frames'`` is used
            for the reference row when present, else the original first frame.
        seq_idx: index of the sample to use.
        height, width: generation size; should be multiples of 64 (a height of
            288 caused a skip-connection size mismatch in the training cell).
        num_frames, num_inference_steps: forwarded to ``pipe``.
    """
    # Imported locally so the cell does not depend on a module-level ``plt``.
    import matplotlib.pyplot as plt

    def _as_pil(frame):
        """Return ``frame`` as PIL (accepts PIL, CHW float tensors in [0, 1],
        or HWC uint8 arrays)."""
        if isinstance(frame, Image.Image):
            return frame
        if torch.is_tensor(frame):
            frame = frame.detach().cpu().float().clamp(0, 1).permute(1, 2, 0).numpy()
            frame = (frame * 255).round().astype(np.uint8)
        return Image.fromarray(np.asarray(frame))

    sample = dataset[seq_idx]
    seq_name = sample['sequence_name']

    print(f"Testing on sequence: {seq_name}")

    # The pipeline expects a PIL image. Passing the dataset's CHW tensor is the
    # likely cause of the recorded "Expected size 3 but got size 1280" error.
    edited = sample.get('edited_frame_pil')
    if edited is None:
        edited = _as_pil(sample['edited_frame'])

    # Generate video
    frames = pipe(
        edited,
        num_frames=num_frames,
        num_inference_steps=num_inference_steps,
        height=height,
        width=width,
        generator=torch.Generator().manual_seed(42),
    ).frames[0]

    # Save as GIF
    frames[0].save(
        f"test_{seq_name}.gif",
        save_all=True,
        append_images=frames[1:],
        duration=100,
        loop=0
    )

    # Reference frames: the tensor-returning dataset has no 'all_frames' key,
    # so fall back to the original first frame.
    references = sample.get('all_frames')
    if references is None:
        fallback = sample.get('original_frame_pil')
        if fallback is None:
            fallback = sample.get('original_frame')
        references = [] if fallback is None else [fallback]

    # Also save comparison
    fig, axes = plt.subplots(2, 4, figsize=(15, 8))

    # Top row: Generated frames
    for i, ax in enumerate(axes[0]):
        if i < len(frames):
            ax.imshow(frames[i])
            ax.set_title(f"Generated Frame {i}")
        ax.axis('off')

    # Bottom row: reference frames from the dataset
    for i, ax in enumerate(axes[1]):
        if i < len(references):
            ax.imshow(_as_pil(references[i]))
            ax.set_title(f"Original Frame {i}")
        ax.axis('off')

    plt.suptitle(f"Results for {seq_name}")
    plt.tight_layout()
    plt.savefig(f"comparison_{seq_name}.png")
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across calls

    print(f"✓ Saved test_{seq_name}.gif")

    torch.cuda.empty_cache()

# Test on each sequence (at most the first three); each call writes
# test_<name>.gif and comparison_<name>.png.
for i in range(min(3, len(dataset))):
    test_on_davis_sequence(pipe, dataset, i)

Testing on sequence: blackswan


RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 3 but got size 1280 for tensor number 1 in the list.

In [None]:
!pip install opencv-python-headless --quiet

import os
import math
import random
from glob import glob

import cv2
import numpy as np
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Prefer the GPU when PyTorch can see one.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)


Using device: cuda


In [None]:
!wget -q https://data.vision.ee.ethz.ch/csergi/share/davis/DAVIS-2017-trainval-480p.zip
!unzip -q DAVIS-2017-trainval-480p.zip
!ls
!ls DAVIS
!ls DAVIS/JPEGImages
!ls DAVIS/Annotations


DAVIS  DAVIS-2017-trainval-480p.zip  sample_data
Annotations  ImageSets	JPEGImages  README.md  SOURCES.md
480p
480p


In [None]:
# DAVIS 2017 layout after unzipping: one sub-directory per sequence under
# JPEGImages/480p (frames, .jpg) and Annotations/480p (masks, .png).
DAVIS_ROOT = "DAVIS"
IMG_ROOT = os.path.join(DAVIS_ROOT, "JPEGImages", "480p")
MASK_ROOT = os.path.join(DAVIS_ROOT, "Annotations", "480p")

# os.listdir order is arbitrary; this only shows a few names as a sanity check.
print("Example sequences:", os.listdir(IMG_ROOT)[:5])

def read_image_rgb(path):
    """Read an image from disk as an ``H x W x 3`` uint8 array in RGB order.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``. ``cv2.imread``
            returns ``None`` instead of raising, which previously surfaced as
            an opaque ``cv2.cvtColor`` error. This matches ``read_mask_binary``.
    """
    img = cv2.imread(path)  # OpenCV decodes to BGR channel order
    if img is None:
        raise FileNotFoundError(path)
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

def read_mask_binary(path):
    """Load an annotation PNG as an ``H x W`` uint8 mask with values in {0, 1}.

    Multi-object annotations are collapsed: every non-zero label becomes 1.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``.
    """
    raw = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if raw is None:
        raise FileNotFoundError(path)
    binary = raw > 0
    return binary.astype(np.uint8)

def resize_pair(img, mask, short_side=256, max_side=384):
    """Resize an image and its mask together, preserving aspect ratio.

    The shorter side is scaled to ``short_side`` unless that would push the
    longer side past ``max_side``, in which case the longer side is capped at
    ``max_side``. Output sizes are truncated to ints. The image uses area
    interpolation; the mask uses nearest-neighbour and is re-binarized to {0, 1}.
    """
    height, width = img.shape[:2]
    long_edge = max(height, width)

    scale = short_side / min(height, width)
    if long_edge * scale > max_side:
        scale = max_side / long_edge

    target_wh = (int(width * scale), int(height * scale))  # cv2 takes (w, h)
    resized_img = cv2.resize(img, target_wh, interpolation=cv2.INTER_AREA)
    resized_mask = cv2.resize(mask, target_wh, interpolation=cv2.INTER_NEAREST)
    return resized_img, (resized_mask > 0).astype(np.uint8)

def inpaint_frame(img_rgb, mask_binary, radius=3):
    """Fill the masked region of an RGB frame using OpenCV's Telea inpainting.

    ``mask_binary`` is presumably uint8 {0, 1} (as produced by
    ``read_mask_binary`` / ``resize_pair``); it is scaled to {0, 255} for
    ``cv2.inpaint``. The frame is round-tripped through BGR and returned as RGB.
    """
    inpaint_mask = mask_binary * 255
    frame_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
    filled_bgr = cv2.inpaint(frame_bgr, inpaint_mask, radius, cv2.INPAINT_TELEA)
    return cv2.cvtColor(filled_bgr, cv2.COLOR_BGR2RGB)

def video_to_tensor(frames):
    """Stack RGB frames into a channel-first video tensor.

    Args:
        frames: sequence of ``H x W x 3`` uint8 arrays of equal size.

    Returns:
        float32 tensor of shape ``[3, T, H, W]`` scaled to ``[0, 1]``.
    """
    stacked = np.stack(frames, axis=0).astype(np.float32)  # [T, H, W, 3]
    stacked /= 255.0
    return torch.from_numpy(np.moveaxis(stacked, -1, 0))   # [3, T, H, W]

def masks_to_tensor(masks):
    """Stack per-frame binary masks into a single-channel video tensor.

    Args:
        masks: sequence of ``H x W`` uint8 arrays with values in {0, 1}.

    Returns:
        float32 tensor of shape ``[1, T, H, W]``.
    """
    stacked = np.stack(masks, axis=0).astype(np.float32)  # [T, H, W]
    return torch.from_numpy(np.expand_dims(stacked, 0))   # [1, T, H, W]


class DAVISInpaintDataset(Dataset):
    """
    Fixed-length DAVIS clips paired with OpenCV-inpainted versions.

    Each item is a random window of ``num_frames`` consecutive frames from one
    sequence (the sequence is looped when shorter). Frames are resized with
    ``resize_pair`` and inpainted frame-by-frame with ``inpaint_frame``.

    Returns (dict per item):
      seq_name:    sequence directory name
      orig_video:  [3,T,H,W]  original frames, float32 in [0,1]
      edit_video:  [3,T,H,W]  inpainted (object-removed) frames, float32 in [0,1]
      mask_seq:    [1,T,H,W]  object region (1=object)

    Raises:
        ValueError: from ``__init__`` if ``num_frames < 1``; from
            ``__getitem__`` if a sequence has no (frame, mask) pairs.
    """
    def __init__(self, img_root, mask_root, num_frames=8, short_side=256, max_side=384):
        if num_frames < 1:
            raise ValueError(f"num_frames must be >= 1, got {num_frames}")
        self.img_root = img_root
        self.mask_root = mask_root
        self.seq_names = sorted(os.listdir(img_root))
        self.num_frames = num_frames
        self.short_side = short_side
        self.max_side = max_side

    def __len__(self):
        return len(self.seq_names)

    def _load_sequence_paths(self, seq):
        """Sorted frame / mask paths for ``seq``, truncated to equal length."""
        img_dir = os.path.join(self.img_root, seq)
        mask_dir = os.path.join(self.mask_root, seq)
        img_paths = sorted(glob(os.path.join(img_dir, "*.jpg")))
        mask_paths = sorted(glob(os.path.join(mask_dir, "*.png")))
        T = min(len(img_paths), len(mask_paths))
        return img_paths[:T], mask_paths[:T]

    def _select_window(self, seq, img_paths, mask_paths):
        """Pick ``num_frames`` aligned (frame, mask) paths.

        Uses a random contiguous window when the sequence is long enough,
        otherwise repeats the sequence from the start.
        """
        n = len(img_paths)
        if n == 0:
            # Guard: the looping branch below would divide by zero.
            raise ValueError(
                f"Sequence {seq!r} has no .jpg frames with matching .png masks"
            )
        if n >= self.num_frames:
            start = random.randint(0, n - self.num_frames)
            end = start + self.num_frames
            return img_paths[start:end], mask_paths[start:end]
        # loop to get num_frames
        rep = (self.num_frames + n - 1) // n
        return (img_paths * rep)[:self.num_frames], (mask_paths * rep)[:self.num_frames]

    def __getitem__(self, idx):
        seq = self.seq_names[idx]
        img_paths, mask_paths = self._load_sequence_paths(seq)
        img_paths, mask_paths = self._select_window(seq, img_paths, mask_paths)

        frames_orig = []
        frames_edit = []
        masks = []

        for ip, mp in zip(img_paths, mask_paths):
            img = read_image_rgb(ip)
            m = read_mask_binary(mp)
            img, m = resize_pair(img, m, self.short_side, self.max_side)
            frames_orig.append(img)
            masks.append(m)
            frames_edit.append(inpaint_frame(img, m))

        orig_video = video_to_tensor(frames_orig)  # [3,T,H,W]
        edit_video = video_to_tensor(frames_edit)  # [3,T,H,W]
        mask_seq = masks_to_tensor(masks)          # [1,T,H,W]

        return {
            "seq_name": seq,
            "orig_video": orig_video,
            "edit_video": edit_video,
            "mask_seq": mask_seq,
        }

# quick sanity check: build the dataset over every sequence under IMG_ROOT and
# print the tensor shapes of one item ([C, T, H, W]). resize_pair scales the
# shorter side to 256 but caps the longer side at 384, hence the 215x384
# frames in the output. dataset[0] picks a random window of frames.
dataset = DAVISInpaintDataset(IMG_ROOT, MASK_ROOT, num_frames=8)
sample = dataset[0]
print("orig_video:", sample["orig_video"].shape)
print("edit_video:", sample["edit_video"].shape)
print("mask_seq:", sample["mask_seq"].shape)


Example sequences: ['dance-twirl', 'train', 'scooter-gray', 'dancing', 'drift-straight']
orig_video: torch.Size([3, 8, 215, 384])
edit_video: torch.Size([3, 8, 215, 384])
mask_seq: torch.Size([1, 8, 215, 384])


In [None]:
# Linear beta schedule: betas rise linearly from 1e-4 to 0.02 over T_DIFF
# steps. The cumulative products of alphas give the closed-form coefficients
# used by q_sample.
T_DIFF = 1000
betas = torch.linspace(1e-4, 0.02, T_DIFF, device=device)
alphas = 1.0 - betas
alphas_cumprod = alphas.cumprod(dim=0)
sqrt_alphas_cumprod = alphas_cumprod.sqrt()
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod).sqrt()

def q_sample(x0, t, noise):
    """
    Sample x_t from q(x_t | x_0) in closed form:
        x_t = sqrt(alphas_cumprod[t]) * x0 + sqrt(1 - alphas_cumprod[t]) * noise

    x0:    [B, ...] clean sample. The training loop passes [B,3,T,H,W], but
           any rank with a leading batch dimension works.
    t:     [B] int64 timesteps indexing the module-level schedule tables
    noise: same shape as x0
    """
    B = x0.shape[0]
    # Reshape per-sample coefficients to [B, 1, ..., 1] so they broadcast over
    # every non-batch dimension of x0, whatever its rank.
    coeff_shape = (B,) + (1,) * (x0.dim() - 1)
    sqrt_ac = sqrt_alphas_cumprod[t].view(coeff_shape)
    sqrt_om = sqrt_one_minus_alphas_cumprod[t].view(coeff_shape)
    return sqrt_ac * x0 + sqrt_om * noise

def timestep_embedding(timesteps, dim):
    """
    Sinusoidal timestep embedding.

    timesteps: [B] (any numeric dtype; converted to float)
    returns:   [B, dim] = [sin(t * f_0..f_{h-1}), cos(t * f_0..f_{h-1})],
               where h = dim // 2 and f_i = 10000 ** (-i / h).
               For odd ``dim``, a zero column is appended.
    """
    half_dim = dim // 2
    exponent = -math.log(10000) * torch.arange(0, half_dim, device=timesteps.device) / float(half_dim)
    frequencies = exponent.exp()
    phases = timesteps.float()[:, None] * frequencies[None, :]
    embedding = torch.cat([phases.sin(), phases.cos()], dim=-1)
    return F.pad(embedding, (0, 1)) if dim % 2 == 1 else embedding


In [None]:
class ConvBlock3D(nn.Module):
    """3x3x3 convolution -> GroupNorm(8 groups) -> SiLU.

    Padding 1 keeps T, H and W unchanged. ``out_ch`` must be divisible by 8
    (a GroupNorm requirement).
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Submodule names (conv / norm / act) determine the state_dict keys.
        self.conv = nn.Conv3d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, padding=1)
        self.norm = nn.GroupNorm(num_groups=8, num_channels=out_ch)
        self.act = nn.SiLU()

    def forward(self, x):
        h = self.conv(x)
        h = self.norm(h)
        return self.act(h)

class DownBlock3D(nn.Module):
    """Encoder stage: two ConvBlock3D layers, then a stride-(1, 2, 2) conv.

    ``forward`` returns ``(downsampled, skip)``. ``skip`` is the feature map
    before downsampling; ``downsampled`` has the same T and ceil(H/2), ceil(W/2).
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = ConvBlock3D(in_ch, out_ch)
        self.conv2 = ConvBlock3D(out_ch, out_ch)
        # Time keeps stride 1; only the spatial axes are halved.
        self.down = nn.Conv3d(out_ch, out_ch, kernel_size=3, stride=(1, 2, 2), padding=1)

    def forward(self, x):
        features = self.conv2(self.conv1(x))
        return self.down(features), features

class UpBlock3D(nn.Module):
    """Decoder stage: 2x spatial upsampling, concat with an encoder skip, two convs.

    Args:
        in_ch_from_down: channels of the incoming (coarser) features.
        out_ch_this_block: output channels of this stage.
        skip_ch_from_down: channels of the skip tensor concatenated after upsampling.
    """
    def __init__(self, in_ch_from_down, out_ch_this_block, skip_ch_from_down):
        super().__init__()
        # Stride (1,2,2) with output_padding (0,1,1) gives H_out = 2*H_in,
        # W_out = 2*W_in and leaves T unchanged.
        self.up = nn.ConvTranspose3d(
            in_ch_from_down,
            out_ch_this_block,
            kernel_size=3,
            stride=(1, 2, 2),
            padding=1,
            output_padding=(0, 1, 1),
        )
        # conv1 consumes the upsampled features concatenated with the skip.
        self.conv1 = ConvBlock3D(out_ch_this_block + skip_ch_from_down, out_ch_this_block)
        self.conv2 = ConvBlock3D(out_ch_this_block, out_ch_this_block)

    def forward(self, x, skip):
        x = self.up(x)
        _, _, t_skip, h_skip, w_skip = skip.shape
        _, _, t_up, h_up, w_up = x.shape
        # Odd encoder sizes make 2x upsampling overshoot by one pixel, so
        # resample to the skip's exact size before concatenating.
        if t_up != t_skip:
            x = F.interpolate(x, size=(t_skip, h_up, w_up), mode="nearest")
        if (h_up, w_up) != (h_skip, w_skip):
            x = F.interpolate(x, size=(t_skip, h_skip, w_skip), mode="trilinear", align_corners=False)
        merged = torch.cat([x, skip], dim=1)
        return self.conv2(self.conv1(merged))

In [None]:
class SCEEncoder(nn.Module):
    """
    Selective Content Encoder (SCE).

    Three DownBlock3D stages over the conditioning video [B,3,T,H,W] (the
    training loop passes the edited first frame repeated along T). Returns a
    dict with the deepest features under ``"x3"`` and each stage's
    pre-downsampling skip under ``"s1"``, ``"s2"``, ``"s3"``.
    """
    def __init__(self, in_ch=3, base_ch=64):
        super().__init__()
        self.down1 = DownBlock3D(in_ch, base_ch)
        self.down2 = DownBlock3D(base_ch, 2 * base_ch)
        self.down3 = DownBlock3D(2 * base_ch, 4 * base_ch)

    def forward(self, x):
        x1, s1 = self.down1(x)
        x2, s2 = self.down2(x1)
        x3, s3 = self.down3(x2)
        return {"x3": x3, "s1": s1, "s2": s2, "s3": s3}

class MaskPredictionDecoder(nn.Module):
    """
    Mask Prediction Decoder (MPD): per-voxel edit-region logits.

    Two ConvBlock3D layers followed by a 1x1x1 projection to one channel.
    Maps [B,in_ch,T,H,W] to [B,1,T,H,W] logits at the input resolution.
    """
    def __init__(self, in_ch, mid_ch=64):
        super().__init__()
        self.conv1 = ConvBlock3D(in_ch, mid_ch)
        self.conv2 = ConvBlock3D(mid_ch, mid_ch)
        self.conv3 = nn.Conv3d(mid_ch, 1, kernel_size=1)  # raw logits, no activation

    def forward(self, x):
        return self.conv3(self.conv2(self.conv1(x)))

class GenPropUNet(nn.Module):
    """
    Simplified GenProp-like video diffusion model.

    A 3-level 3D UNet over the noisy video. At the bottleneck it is
    conditioned on:
      - SCE features of ``cond_video``, scaled by the learnable ``sce_scale``
        (initialized to 0, so conditioning starts switched off);
      - an MLP-projected sinusoidal timestep embedding.
    An MPD head predicts edit-region mask logits from the bottleneck.

    Input:
      noisy_video: [B,3,T,H,W]  (noised target video x_t)
      cond_video:  [B,3,T,H,W]  (edited first frame tiled across time)
      t:           [B]          (diffusion step)
    Output:
      noise_pred:  [B,3,T,H,W]
      mask_logits: [B,1,T,H',W']  (mask at bottleneck resolution; H', W' are
                   H, W halved three times with ceil rounding)

    NOTE: submodules are registered in a fixed order that determines
    state_dict layout and the RNG consumption of random init; keep the order
    if editing __init__.
    """
    def __init__(self, in_ch=3, base_ch=64, time_dim=256):
        super().__init__()
        self.in_ch = in_ch

        # Time embedding: sinusoidal(time_dim) -> MLP -> base_ch*4 channels,
        # matching the bottleneck width so it can be added to h3.
        self.time_dim = time_dim
        self.time_mlp = nn.Sequential(
            nn.Linear(time_dim, time_dim),
            nn.SiLU(),
            nn.Linear(time_dim, base_ch * 4),
        )

        # Main UNet encoder (each stage halves H and W; T is unchanged)
        self.down1 = DownBlock3D(in_ch, base_ch)
        self.down2 = DownBlock3D(base_ch, base_ch * 2)
        self.down3 = DownBlock3D(base_ch * 2, base_ch * 4)

        self.mid = ConvBlock3D(base_ch * 4, base_ch * 4)

        # Decoder: two up stages back to H/2, W/2; forward() interpolates to
        # full resolution. UpBlock3D args: in_ch_from_down, out_ch_this_block,
        # skip_ch_from_down.
        self.up3 = UpBlock3D(base_ch * 4, base_ch * 2, base_ch * 4) # in_ch_from_down, out_ch_this_block, skip_ch_from_down
        self.up2 = UpBlock3D(base_ch * 2, base_ch, base_ch * 2)
        self.out_conv = nn.Conv3d(base_ch, in_ch, kernel_size=3, padding=1)

        # SCE encoder for the conditioning video
        self.sce = SCEEncoder(in_ch=in_ch, base_ch=base_ch)
        self.sce_scale = nn.Parameter(torch.zeros(1))  # learnable scale, starts at 0 (SCE off)

        # MPD: mask logits from the bottleneck features
        self.mpd = MaskPredictionDecoder(in_ch=base_ch * 4, mid_ch=base_ch)

    def forward(self, noisy_video, cond_video, t):
        """
        noisy_video: [B,3,T,H,W], x_t produced by q_sample from x0 in [-1,1]
        cond_video:  [B,3,T,H,W], in [-1,1]
        t: [B] int timesteps
        Returns (noise_pred [B,3,T,H,W], mask_logits [B,1,T,H/8,W/8]).
        """
        B, C, T, H, W = noisy_video.shape

        # Time embedding
        t_emb = timestep_embedding(t, self.time_dim)  # [B,time_dim]
        t_emb = self.time_mlp(t_emb)                  # [B,base_ch*4]
        t_emb = t_emb.view(B, -1, 1, 1, 1)            # [B,base_ch*4,1,1,1] broadcasts over T,H,W

        # SCE features from cond_video; feats["x3"] has the same shape as h3
        sce_feats = self.sce(cond_video)

        # Main encoder (spatial sizes below are ceil-rounded halvings)
        h = noisy_video
        h1, skip1 = self.down1(h)                     # h1: [B,base_ch,T,H/2,W/2]; skip1 unused by the decoder
        h2, skip2 = self.down2(h1)                    # h2: [B,2base_ch,T,H/4,W/4]; skip2 at H/2
        h3, skip3 = self.down3(h2)                    # h3: [B,4base_ch,T,H/8,W/8]; skip3 at H/4

        # Inject SCE at bottleneck + time.
        # Resample the SCE features only if their shape differs from h3.
        content = sce_feats["x3"]
        if content.shape[2:] != h3.shape[2:]:
            content = F.interpolate(
                content, size=h3.shape[2:], mode="trilinear", align_corners=False
            )
        h3 = h3 + self.sce_scale * content + t_emb

        mid = self.mid(h3)                            # [B,4base_ch,T,H/8,W/8]

        # MPD on mid
        mask_logits = self.mpd(mid)                   # [B,1,T,H/8,W/8]

        # Decoder
        u3 = self.up3(mid, skip3)                     # [B,2base_ch,T,H/4,W/4]
        u2 = self.up2(u3, skip2)                      # [B,base_ch,T,H/2,W/2]
        out = self.out_conv(u2)                       # [B,3,T,H/2,W/2]

        # Always upsample back to the input resolution (the decoder stops at H/2)
        out = F.interpolate(
            out, size=(T, H, W), mode="trilinear", align_corners=False
        )

        return out, mask_logits

In [None]:
def region_aware_mse(noise_pred, noise_true, mask_seq, lambda_bg=2.0):
    """
    Weighted MSE in which background pixels count ``lambda_bg`` times as much.

    noise_pred, noise_true: [B,3,T,H,W]
    mask_seq: [B,1,T,H,W] (1 inside edited region, 0 background)
    Returns a scalar: the mean over all elements of weight * squared error.
    """
    sq_err = (noise_pred - noise_true).pow(2)
    weights = (mask_seq + lambda_bg * (1.0 - mask_seq)).expand_as(sq_err)
    return torch.mean(sq_err * weights)

def mask_bce_loss(mask_logits, mask_seq):
    """
    BCE-with-logits between predicted mask logits and the ground-truth mask.

    mask_logits: [B,1,T,H',W']
    mask_seq:    [B,1,T,H,W]  original mask
    The target is resampled to the logits' (T,H',W') grid with nearest-neighbour
    interpolation, so a binary mask stays binary.
    """
    target_size = tuple(mask_logits.shape[2:])
    target = F.interpolate(mask_seq, size=target_size, mode="nearest")
    return F.binary_cross_entropy_with_logits(mask_logits, target)


In [None]:
def train_genprop(
    num_epochs=1,
    batch_size=1,
    num_frames=8,
    lambda_bg=2.0,
    beta_mpd=0.1,
    lr=1e-4,
):
    """
    Train GenPropUNet to predict diffusion noise on edited DAVIS clips.

    Each step noises the edited clip at a random timestep, conditions the model
    on the clip's first frame tiled over time, and minimises
    region-aware noise MSE + ``beta_mpd`` * mask BCE. A checkpoint
    ``genprop_epoch{epoch}.pth`` is written after every epoch.

    Returns the trained model.
    """
    loader = DataLoader(
        DAVISInpaintDataset(IMG_ROOT, MASK_ROOT, num_frames=num_frames),
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
    )

    model = GenPropUNet(in_ch=3, base_ch=64, time_dim=256).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    for epoch in range(num_epochs):
        model.train()
        progress = tqdm(loader, desc=f"GenProp Epoch {epoch}")
        for batch in progress:
            clip = batch["edit_video"].to(device)      # [B,3,T,H,W] in [0,1]
            mask_seq = batch["mask_seq"].to(device)    # [B,1,T,H,W]
            n_batch, n_frames = clip.shape[0], clip.shape[2]

            x0 = 2.0 * clip - 1.0                       # rescale to [-1,1]

            # Condition on the first frame, tiled along the time axis
            cond_video = x0[:, :, :1].repeat(1, 1, n_frames, 1, 1)

            noise = torch.randn_like(x0)
            t = torch.randint(0, T_DIFF, (n_batch,), device=device, dtype=torch.long)
            x_t = q_sample(x0, t, noise)

            noise_pred, mask_logits = model(x_t, cond_video, t)
            loss_diff = region_aware_mse(noise_pred, noise, mask_seq, lambda_bg=lambda_bg)
            loss_mask = mask_bce_loss(mask_logits, mask_seq)
            loss = loss_diff + beta_mpd * loss_mask

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            progress.set_postfix(loss=loss.item(), diff=loss_diff.item(), mpd=loss_mask.item())

        torch.save(model.state_dict(), f"genprop_epoch{epoch}.pth")

    return model

# Prototype run: one short epoch; raise num_epochs for real training.
genprop_model = train_genprop(
    lr=1e-4,
    beta_mpd=0.1,
    lambda_bg=2.0,
    num_frames=8,
    batch_size=1,
    num_epochs=1,
)


GenProp Epoch 0:   0%|          | 0/90 [00:00<?, ?it/s]

In [None]:
@torch.no_grad()
def genprop_single_step_demo(model, dataset, index=0, t_val=500):
    """
    One-shot x0 reconstruction from a single noised timestep.

    Noises the sample's edited clip to timestep ``t_val``, predicts the noise and
    inverts the forward process:
        x0_hat = (x_t - sqrt(1 - alpha_bar_t) * eps_theta) / sqrt(alpha_bar_t)
    The model is switched to eval mode (and left there).

    Returns (edit_video, x0_hat, mask_seq) on CPU; edit_video and x0_hat in [0,1].
    """
    model.eval()
    sample = dataset[index]
    edit_video = sample["edit_video"].unsqueeze(0).to(device)  # [1,3,T,H,W]
    mask_seq = sample["mask_seq"].unsqueeze(0).to(device)      # [1,1,T,H,W]

    x0 = edit_video * 2.0 - 1.0
    num_t = x0.shape[2]
    cond_video = x0[:, :, :1].repeat(1, 1, num_t, 1, 1)  # first frame tiled in time

    t = torch.tensor([t_val], device=device, dtype=torch.long)
    noise = torch.randn_like(x0)
    x_t = q_sample(x0, t, noise)

    eps_pred, _ = model(x_t, cond_video, t)
    x0_hat = (x_t - sqrt_one_minus_alphas_cumprod[t_val] * eps_pred) / sqrt_alphas_cumprod[t_val]
    x0_hat = torch.clamp((x0_hat + 1.0) / 2.0, 0.0, 1.0)

    return edit_video.cpu(), x0_hat.cpu(), mask_seq.cpu()

# Quick demo: reconstruct the first sequence from timestep 500
demo_dataset = DAVISInpaintDataset(IMG_ROOT, MASK_ROOT, num_frames=8)
orig_edit, recon_edit, demo_mask = genprop_single_step_demo(
    genprop_model, demo_dataset, index=0, t_val=500
)
print("orig_edit:", orig_edit.shape, "recon_edit:", recon_edit.shape)


orig_edit: torch.Size([1, 3, 8, 215, 384]) recon_edit: torch.Size([1, 3, 8, 215, 384])


In [None]:
class OneStepVideoGenerator(nn.Module):
    """
    One-step generator:
      Input: edited first frame [B,3,H,W] in [-1,1]
      Output: full video [B,3,T,H,W] in [-1,1]

    A 2-D encoder (spatial /4) produces a feature map; its pooled summary plus a
    projected noise vector yields per-frame offsets that are broadcast over the
    spatial grid, and a 3-D transposed-conv decoder turns the result into a clip.
    """
    def __init__(self, num_frames=8, base_ch=64, z_dim=128):
        super().__init__()
        self.num_frames = num_frames
        self.z_dim = z_dim

        # 2D encoder (two stride-2 convs)
        self.enc = nn.Sequential(
            nn.Conv2d(3, base_ch, 3, padding=1),
            nn.GroupNorm(8, base_ch),
            nn.SiLU(),
            nn.Conv2d(base_ch, base_ch * 2, 4, stride=2, padding=1),
            nn.GroupNorm(8, base_ch * 2),
            nn.SiLU(),
            nn.Conv2d(base_ch * 2, base_ch * 4, 4, stride=2, padding=1),
            nn.GroupNorm(8, base_ch * 4),
            nn.SiLU(),
        )

        # noise + global fusion
        self.fc_z = nn.Linear(z_dim, base_ch * 4)
        self.fc_time = nn.Linear(base_ch * 4, base_ch * 4 * num_frames)

        # 3D decoder (upsamples H,W; time length is preserved)
        self.dec = nn.Sequential(
            nn.ConvTranspose3d(
                base_ch * 4, base_ch * 2,
                kernel_size=(3,4,4),
                stride=(1,2,2),
                padding=(1,1,1),
                output_padding=(0,1,1),
            ),
            nn.GroupNorm(8, base_ch * 2),
            nn.SiLU(),
            nn.ConvTranspose3d(
                base_ch * 2, base_ch,
                kernel_size=(3,4,4),
                stride=(1,2,2),
                padding=(1,1,1),
                output_padding=(0,1,1),
            ),
            nn.GroupNorm(8, base_ch),
            nn.SiLU(),
            nn.Conv3d(base_ch, 3, kernel_size=3, padding=1),
            nn.Tanh(),
        )

    def forward(self, first_frame, noise=None):
        """
        Generate a clip from a single frame.

        Args:
            first_frame: [B,3,H,W] in [-1,1].
            noise: optional [B,z_dim] latent; drawn from N(0,1) when None.
        Returns:
            [B,3,num_frames,H,W] in [-1,1], with the same H,W as ``first_frame``.
        """
        B, _, target_H, target_W = first_frame.shape
        x = self.enc(first_frame)                     # [B,4C,H',W']
        C_enc = x.shape[1]

        if noise is None:
            noise = torch.randn(B, self.z_dim, device=first_frame.device)
        z = self.fc_z(noise)                          # [B,4C]
        x_pool = F.adaptive_avg_pool2d(x, 1).view(B, -1)
        fused = x_pool + z                            # [B,4C]

        # One global offset vector per frame, broadcast over the spatial grid
        t_embed = self.fc_time(fused).view(B, C_enc, self.num_frames, 1, 1)

        x_spatial = x.unsqueeze(2).expand(-1, -1, self.num_frames, -1, -1)
        feats = x_spatial + t_embed                   # [B,4C,T,H',W']

        out = self.dec(feats)                         # [B,3,T,H_gen,W_gen] in [-1,1]

        # The transposed convs need not reproduce H,W exactly, so resize spatially.
        # Interpolating the 5-D tensor with its own time length kept is an exact
        # per-frame resize and does not depend on the input's channel count
        # (the previous view(B*T, C, ...) reshape did).
        if out.shape[-2:] != (target_H, target_W):
            out = F.interpolate(
                out,
                size=(out.shape[2], target_H, target_W),
                mode="trilinear",
                align_corners=False,
            )

        return out

class VideoDiscriminator(nn.Module):
    """
    3-D convolutional clip critic.

    Four 3x3x3 conv stages (the first two stride only H,W; the last two stride
    T,H,W), BatchNorm on all but the first, LeakyReLU(0.2) throughout, then a
    global average pool and a linear head.
    Input: [B,3,T,H,W] in [-1,1]; output: [B,1] unnormalised scores.
    """
    def __init__(self, base_ch=64):
        super().__init__()
        widths = [3, base_ch, base_ch * 2, base_ch * 4, base_ch * 8]
        strides = [(1, 2, 2), (1, 2, 2), (2, 2, 2), (2, 2, 2)]
        layers = []
        for stage, (c_in, c_out, stride) in enumerate(zip(widths[:-1], widths[1:], strides)):
            layers.append(nn.Conv3d(c_in, c_out, 3, stride=stride, padding=1))
            if stage > 0:
                layers.append(nn.BatchNorm3d(c_out))
            layers.append(nn.LeakyReLU(0.2))
        self.net = nn.Sequential(*layers)
        self.fc = nn.Linear(widths[-1], 1)

    def forward(self, video):
        feats = self.net(video)              # [B,C',T',H',W']
        pooled = feats.mean(dim=(2, 3, 4))   # average over time and space
        return self.fc(pooled)               # [B,1]

In [None]:
def train_onestep(
    num_epochs=1,
    batch_size=1,
    num_frames=8,
    lr_g=1e-4,
    lr_d=1e-4,
    lambda_l1=50.0,
    lambda_bg=2.0,
):
    """
    Adversarially train OneStepVideoGenerator against VideoDiscriminator.

    G maps the first frame of each edited DAVIS clip to a whole clip. D is
    trained with the hinge loss; G minimises -D(fake) plus ``lambda_l1`` times a
    pixel L1 in which background pixels are weighted ``lambda_bg``. G and D
    checkpoints are written after every epoch.

    Returns (G, D).
    """
    dataset = DAVISInpaintDataset(IMG_ROOT, MASK_ROOT, num_frames=num_frames)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

    G = OneStepVideoGenerator(num_frames=num_frames).to(device)
    D = VideoDiscriminator().to(device)
    adam_betas = (0.5, 0.999)
    opt_G = torch.optim.Adam(G.parameters(), lr=lr_g, betas=adam_betas)
    opt_D = torch.optim.Adam(D.parameters(), lr=lr_d, betas=adam_betas)

    for epoch in range(num_epochs):
        pbar = tqdm(loader, desc=f"FastEdit Epoch {epoch}")
        for batch in pbar:
            real_video = 2.0 * batch["edit_video"].to(device) - 1.0   # [B,3,T,H,W] in [-1,1]
            mask_seq = batch["mask_seq"].to(device)                   # [B,1,T,H,W]

            fake_video = G(real_video[:, :, 0])                       # generated from frame 0

            # Discriminator step (hinge); the fake is detached so G is untouched
            d_real = D(real_video)
            d_fake = D(fake_video.detach())
            loss_D = torch.relu(1.0 - d_real).mean() + torch.relu(1.0 + d_fake).mean()

            opt_D.zero_grad()
            loss_D.backward()
            opt_D.step()

            # Generator step: fool D while staying close to the target clip
            loss_G_adv = -D(fake_video).mean()
            weights = (mask_seq * 1.0 + (1.0 - mask_seq) * lambda_bg).expand_as(fake_video)
            loss_L1 = (weights * (fake_video - real_video).abs()).mean()
            loss_G = loss_G_adv + lambda_l1 * loss_L1

            opt_G.zero_grad()
            loss_G.backward()
            opt_G.step()

            pbar.set_postfix(D=loss_D.item(), G_adv=loss_G_adv.item(), L1=loss_L1.item())

        for net, tag in ((G, "G"), (D, "D")):
            torch.save(net.state_dict(), f"fastedit_{tag}_epoch{epoch}.pth")

    return G, D

# Tiny sanity run of the one-step GAN (single epoch)
G_fastedit, D_fastedit = train_onestep(
    lambda_bg=2.0,
    lambda_l1=50.0,
    lr_d=1e-4,
    lr_g=1e-4,
    num_frames=8,
    batch_size=1,
    num_epochs=1,
)


FastEdit Epoch 0:   0%|          | 0/90 [00:00<?, ?it/s]

In [None]:
!pip install -q opencv-python-headless imageio

import os
import math
import random
from glob import glob
from pathlib import Path

import cv2
import imageio
import numpy as np
from PIL import Image
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)


Using device: cuda


In [None]:
!wget https://graphics.ethz.ch/Downloads/Data/Davis/DAVIS-data.zip -O DAVIS.zip

--2025-11-17 15:15:19--  https://graphics.ethz.ch/Downloads/Data/Davis/DAVIS-data.zip
Resolving graphics.ethz.ch (graphics.ethz.ch)... 129.132.145.103
Connecting to graphics.ethz.ch (graphics.ethz.ch)|129.132.145.103|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: https://cgl.ethz.ch/Downloads/Data/Davis/DAVIS-data.zip [following]
--2025-11-17 15:15:19--  https://cgl.ethz.ch/Downloads/Data/Davis/DAVIS-data.zip
Resolving cgl.ethz.ch (cgl.ethz.ch)... 129.132.145.103
Connecting to cgl.ethz.ch (cgl.ethz.ch)|129.132.145.103|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1958016855 (1.8G) [application/zip]
Saving to: ‘DAVIS.zip’


2025-11-17 15:15:49 (62.4 MB/s) - ‘DAVIS.zip’ saved [1958016855/1958016855]



In [None]:
!unzip -q DAVIS.zip

In [None]:
from pathlib import Path
import cv2
import numpy as np
import torch

# Colab layout: the unzipped DAVIS archive lives under /content
DATA_ROOT = Path("/content") / "DAVIS"
IMG_ROOT, MASK_ROOT = (DATA_ROOT / kind / "480p" for kind in ("JPEGImages", "Annotations"))

print("Has images:", IMG_ROOT.exists(), "Has masks:", MASK_ROOT.exists())


def read_rgb(path):
    """
    Load an image file as an RGB array of shape [H,W,3].

    Raises FileNotFoundError if OpenCV cannot read ``path``. cv2.imread signals
    failure by returning None rather than raising, which would otherwise surface
    later as an opaque cvtColor error.
    """
    img = cv2.imread(str(path))
    if img is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

def read_mask(path):
    """
    Load an annotation image as a binary uint8 mask [H,W]: 1 = object, 0 = background.

    Any non-zero grayscale pixel counts as object. Raises FileNotFoundError if
    OpenCV cannot read ``path`` (cv2.imread returns None instead of raising).
    """
    m = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
    if m is None:
        raise FileNotFoundError(f"Could not read mask: {path}")
    return (m > 0).astype(np.uint8)

def resize_pair(img, mask, short_side=256, max_side=384):
    """
    Resize an image and its mask together, preserving aspect ratio.

    The short side is scaled to ``short_side`` unless that would push the long
    side past ``max_side``, in which case the long side is capped at
    ``max_side`` instead. The image uses area interpolation; the mask uses
    nearest-neighbour and is re-binarised to {0,1} uint8.
    """
    h, w = img.shape[:2]
    short_edge, long_edge = min(h, w), max(h, w)
    scale = short_side / short_edge
    if long_edge * scale > max_side:
        scale = max_side / long_edge
    dsize = (int(w * scale), int(h * scale))  # cv2.resize takes (width, height)
    img_r = cv2.resize(img, dsize, interpolation=cv2.INTER_AREA)
    mask_r = cv2.resize(mask, dsize, interpolation=cv2.INTER_NEAREST)
    return img_r, (mask_r > 0).astype(np.uint8)

def inpaint_frame(img_rgb, mask_binary, radius=3):
    """
    Fill the masked region of an RGB frame using OpenCV's Telea inpainting.

    img_rgb: [H,W,3] RGB image; mask_binary: [H,W] with 1 marking pixels to fill.
    Returns the inpainted frame, still in RGB order.
    """
    fill_mask = mask_binary * 255  # scale {0,1} to {0,255} for cv2.inpaint
    bgr_filled = cv2.inpaint(
        cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR), fill_mask, radius, cv2.INPAINT_TELEA
    )
    return cv2.cvtColor(bgr_filled, cv2.COLOR_BGR2RGB)

def video_to_tensor(frames):
    """Stack HxWx3 uint8 frames into a float32 tensor [3,T,H,W] scaled to [0,1]."""
    clip = np.stack(frames, axis=0).astype(np.float32) / 255.0   # [T,H,W,3]
    return torch.from_numpy(np.transpose(clip, (3, 0, 1, 2)))    # [3,T,H,W]

def masks_to_tensor(masks):
    """Stack HxW masks into a float32 tensor of shape [1,T,H,W]."""
    stacked = np.stack(masks, axis=0).astype(np.float32)  # [T,H,W]
    return torch.from_numpy(stacked[np.newaxis])           # [1,T,H,W]

Has images: True Has masks: True


In [None]:
class DAVISObjectRemovalDataset(Dataset):
    """
    DAVIS sequences paired with an object-removed ("edited") version.

    Each item is one sequence cut (or looped) to ``num_frames`` frames, resized
    with ``resize_pair`` and inpainted frame-by-frame with ``inpaint_frame``.
    Returns, per sequence, a dict with:
      seq:          sequence name
      orig_video:   [3,T,H,W]  original frames in [0,1]
      edit_video:   [3,T,H,W]  inpainted (object removed) frames in [0,1]
      mask_seq:     [1,T,H,W]  object masks
      orig_first:   [3,H,W]
      edit_first:   [3,H,W]
      first_mask:   [1,H,W]

    random_start: when True (default) the frame window starts at a random
        offset on every access; when False it always starts at frame 0, which
        gives reproducible samples for evaluation/visualisation.

    Raises FileNotFoundError from ``__getitem__`` when a sequence has no paired
    ``.jpg`` frames and ``.png`` masks.
    """
    def __init__(self, img_root, mask_root, num_frames=8,
                 short_side=256, max_side=384, random_start=True):
        self.img_root = img_root
        self.mask_root = mask_root
        self.seq_names = sorted([p.name for p in img_root.iterdir() if p.is_dir()])
        self.num_frames = num_frames
        self.short_side = short_side
        self.max_side = max_side
        self.random_start = random_start

    def __len__(self):
        return len(self.seq_names)

    def _paths_for_seq(self, seq):
        """Return sorted, equal-length lists of frame and mask paths for ``seq``."""
        img_dir = self.img_root / seq
        mask_dir = self.mask_root / seq
        img_paths = sorted(img_dir.glob("*.jpg"))
        mask_paths = sorted(mask_dir.glob("*.png"))
        # Pair by sorted order; surplus frames or masks are dropped.
        T = min(len(img_paths), len(mask_paths))
        return img_paths[:T], mask_paths[:T]

    def _select_window(self, img_paths, mask_paths):
        """Pick ``num_frames`` consecutive pairs, looping sequences that are too short."""
        n = len(img_paths)
        if n >= self.num_frames:
            start = random.randint(0, n - self.num_frames) if self.random_start else 0
            end = start + self.num_frames
            return img_paths[start:end], mask_paths[start:end]
        rep = (self.num_frames + n - 1) // n  # ceil(num_frames / n); n > 0 here
        return (img_paths * rep)[:self.num_frames], (mask_paths * rep)[:self.num_frames]

    def __getitem__(self, idx):
        seq = self.seq_names[idx]
        img_paths, mask_paths = self._paths_for_seq(seq)
        if not img_paths:
            # Without this guard the looping branch would divide by zero.
            raise FileNotFoundError(
                f"No paired .jpg frames / .png masks found for sequence {seq!r} "
                f"under {self.img_root} and {self.mask_root}"
            )
        img_paths, mask_paths = self._select_window(img_paths, mask_paths)

        frames_orig, frames_edit, masks = [], [], []
        for ip, mp in zip(img_paths, mask_paths):
            img = read_rgb(ip)
            m = read_mask(mp)
            img, m = resize_pair(img, m, self.short_side, self.max_side)
            frames_orig.append(img)
            masks.append(m)
            frames_edit.append(inpaint_frame(img, m))

        orig_video = video_to_tensor(frames_orig)   # [3,T,H,W]
        edit_video = video_to_tensor(frames_edit)   # [3,T,H,W]
        mask_seq   = masks_to_tensor(masks)         # [1,T,H,W]

        orig_first = orig_video[:, 0]               # [3,H,W]
        edit_first = edit_video[:, 0]               # [3,H,W]
        first_mask = mask_seq[:, 0]                 # [1,H,W]

        return {
            "seq": seq,
            "orig_video": orig_video,
            "edit_video": edit_video,
            "mask_seq": mask_seq,
            "orig_first": orig_first,
            "edit_first": edit_first,
            "first_mask": first_mask,
        }

# Quick sanity check: print each field of one sample (tensor shape, or the raw value)
dataset = DAVISObjectRemovalDataset(IMG_ROOT, MASK_ROOT, num_frames=8)
sample = dataset[0]
for k, v in sample.items():
    summary = v.shape if isinstance(v, torch.Tensor) else v
    print(k, summary)


seq bear
orig_video torch.Size([3, 8, 215, 384])
edit_video torch.Size([3, 8, 215, 384])
mask_seq torch.Size([1, 8, 215, 384])
orig_first torch.Size([3, 215, 384])
edit_first torch.Size([3, 215, 384])
first_mask torch.Size([1, 215, 384])


In [None]:
class SelectiveContentEncoder(nn.Module):
    """
    Encode the first-frame edit (original frame, edited frame, mask) into a
    conditioning feature map.

    Inputs:
      orig_first: [B,3,H,W]
      edit_first: [B,3,H,W]
      first_mask: [B,1,H,W]
    Output:
      cond_feat: [B,feat_ch,ceil(H/4),ceil(W/4)]  (two stride-2 3x3 convs)
    """
    def __init__(self, in_ch=3+3+1, feat_ch=32):
        super().__init__()
        # Three 3x3 convs (strides 1, 2, 2), each followed by an in-place ReLU
        layers = []
        for c_in, c_out, stride in ((in_ch, 32, 1), (32, 64, 2), (64, feat_ch, 2)):
            layers += [
                nn.Conv2d(c_in, c_out, 3, stride=stride, padding=1),
                nn.ReLU(inplace=True),
            ]
        self.encoder = nn.Sequential(*layers)

    def forward(self, orig_first, edit_first, first_mask):
        stacked = torch.cat((orig_first, edit_first, first_mask), dim=1)
        return self.encoder(stacked)

In [None]:
class VideoEncoder3D(nn.Module):
    """
    Small 3-D convolutional encoder (all kernels 3x3x3, padding 1, ReLU after each).

    Input : [B,in_ch,T,H,W]
    Output: [B,out_ch,ceil(T/2),ceil(H/4),ceil(W/4)]
            (conv2 strides H,W by 2; conv3 strides T,H,W by 2)
    """
    def __init__(self, in_ch=3, base_ch=32, out_ch=64):
        super().__init__()
        self.conv1 = nn.Conv3d(in_ch, base_ch, 3, padding=1)
        self.conv2 = nn.Conv3d(base_ch, base_ch, 3, stride=(1, 2, 2), padding=1)
        self.conv3 = nn.Conv3d(base_ch, base_ch * 2, 3, stride=(2, 2, 2), padding=1)
        self.conv4 = nn.Conv3d(base_ch * 2, out_ch, 3, padding=1)

    def forward(self, x):
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = F.relu(conv(x))
        return x


In [None]:
class VideoGenerator3D(nn.Module):
    """
    Latent → edited latent → decoded to video.
    Input:
      z: [B,latent_ch,Tz,Hf,Wf]  from VideoEncoder3D
      cond_feat: [B,cond_ch,Hf,Wf] from SCE (same Hf,Wf as z)
    Output:
      edited_video: [B,3,out_frames,target_H,target_W] in [-1,1]
    """
    def __init__(self, latent_ch=64, cond_ch=32, base_ch=64, out_frames=8):
        super().__init__()
        self.out_frames = out_frames

        # Project cond_feat to latent_ch channels so it can be concatenated with z
        self.cond_proj = nn.Conv2d(cond_ch, latent_ch, 1)

        # Simple 3D decoder in latent space
        self.dec1 = nn.Conv3d(latent_ch*2, base_ch*2, 3, padding=1)
        self.dec2 = nn.ConvTranspose3d(
            base_ch*2, base_ch, kernel_size=4, stride=(2,2,2), padding=1
        )  # upsample T,H,W by 2
        self.dec3 = nn.ConvTranspose3d(
            base_ch, 32, kernel_size=(1,4,4), stride=(1,2,2),
            padding=(0,1,1)
        )  # upsample H,W by 2

        self.to_rgb = nn.Conv3d(32, 3, 3, padding=1)

    def forward(self, z, cond_feat, target_H, target_W):
        """
        z: [B,Cz,Tz,Hf,Wf]
        cond_feat: [B,F,Hf,Wf]
        target_H, target_W: spatial size of the returned video
        Returns [B,3,out_frames,target_H,target_W] in [-1,1].
        """
        Tz = z.shape[2]
        cond = self.cond_proj(cond_feat)                     # [B,Cz,Hf,Wf]
        cond = cond.unsqueeze(2).expand(-1, -1, Tz, -1, -1)  # [B,Cz,Tz,Hf,Wf]

        x = torch.cat([z, cond], dim=1)             # [B,2Cz,Tz,Hf,Wf]
        x = F.relu(self.dec1(x))
        x = F.relu(self.dec2(x))                    # [B,base_ch,2Tz,H',W']
        x = F.relu(self.dec3(x))                    # [B,32,2Tz,H_gen,W_gen]

        rgb = torch.tanh(self.to_rgb(x))            # [-1,1], [B,3,2Tz,H_gen,W_gen]

        # The decoder yields 2*Tz frames. With VideoEncoder3D, Tz = ceil(T/2), so
        # that equals out_frames only for even clip lengths; the spatial size
        # generally differs too. Resize the 5-D tensor to the full target size.
        # When only H,W differ this is an exact per-frame bilinear resize; the old
        # view(B*out_frames, ...) reshape raised for odd out_frames instead.
        target_size = (self.out_frames, target_H, target_W)
        if tuple(rgb.shape[2:]) != target_size:
            rgb = F.interpolate(
                rgb, size=target_size, mode="trilinear", align_corners=False
            )

        return rgb


In [None]:
class VideoDiscriminator3D(nn.Module):
    """
    3-D GAN critic for clips.

    Four 3x3x3 conv stages (the first strides H,W only; the rest stride T,H,W),
    BatchNorm on all but the first, in-place LeakyReLU(0.2), then global average
    pooling and a linear head.
      Input: [B,3,T,H,W] in [-1,1]
      Output: [B,1]
    """
    def __init__(self, base_ch=32):
        super().__init__()
        widths = [3, base_ch, base_ch * 2, base_ch * 4, base_ch * 8]
        strides = [(1, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2)]
        layers = []
        for stage, (c_in, c_out, stride) in enumerate(zip(widths[:-1], widths[1:], strides)):
            layers.append(nn.Conv3d(c_in, c_out, 3, stride=stride, padding=1))
            if stage > 0:
                layers.append(nn.BatchNorm3d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        self.net = nn.Sequential(*layers)
        self.fc = nn.Linear(widths[-1], 1)

    def forward(self, v):
        feats = self.net(v)                   # [B,C,T',H',W']
        pooled = feats.mean(dim=(2, 3, 4))    # average over time and space
        return self.fc(pooled)                # [B,1]


In [None]:
def region_aware_l1(pred, target, mask_seq, lambda_bg=2.0):
    """
    Weighted L1 in which background pixels count ``lambda_bg`` times as much.

    pred, target: [B,3,T,H,W] in [-1,1]
    mask_seq:     [B,1,T,H,W] 1 = object region
    Returns a scalar: the mean over all elements of weight * |pred - target|.
    """
    abs_err = torch.abs(pred - target)
    weight = (mask_seq * 1.0 + (1.0 - mask_seq) * lambda_bg).expand_as(abs_err)
    return (weight * abs_err).mean()

def hinge_d_loss(D_real, D_fake):
    """Hinge discriminator loss: mean(relu(1 - D_real)) + mean(relu(1 + D_fake))."""
    real_term = torch.relu(1.0 - D_real).mean()
    fake_term = torch.relu(1.0 + D_fake).mean()
    return real_term + fake_term

def hinge_g_loss(D_fake):
    """Generator hinge loss: the negated mean critic score on generated samples."""
    return torch.mean(D_fake).neg()


In [None]:
class FullGeneratorPipeline(nn.Module):
    """
    Object-removal generator built from three stages:
      - VideoEncoder3D: frozen (requires_grad=False, run under no_grad)
      - SelectiveContentEncoder: trainable first-frame conditioning
      - VideoGenerator3D: trainable decoder
    """
    def __init__(self, num_frames=8):
        super().__init__()
        self.encoder = VideoEncoder3D(in_ch=3, base_ch=32, out_ch=64)
        self.sce     = SelectiveContentEncoder(in_ch=3+3+1, feat_ch=32)
        self.gen     = VideoGenerator3D(latent_ch=64, cond_ch=32, base_ch=64, out_frames=num_frames)

        # Keep the video encoder fixed (analogous to freezing most of a backbone)
        for param in self.encoder.parameters():
            param.requires_grad = False

    def forward(self, orig_video, orig_first, edit_first, first_mask):
        """
        orig_video : [B,3,T,H,W]  (original)
        orig_first : [B,3,H,W]
        edit_first : [B,3,H,W]
        first_mask : [B,1,H,W]
        Returns:
          edited_video with the same H,W as ``orig_video``
        """
        height, width = orig_video.shape[-2:]

        with torch.no_grad():
            latent = self.encoder(orig_video)  # frozen; no graph needed

        cond_feat = self.sce(orig_first, edit_first, first_mask)
        return self.gen(latent, cond_feat, height, width)

In [None]:
def train_pipeline(
    num_epochs=3,
    batch_size=1,
    num_frames=8,
    lr_g=2e-4,
    lr_d=2e-4,
    lambda_l1=50.0,
    lambda_bg=2.0,
):
    """
    Adversarially train FullGeneratorPipeline for object removal on DAVIS.

    G sees the original clip plus the (original, edited, mask) first frame and
    must produce the inpainted clip. D is trained with the hinge loss; G
    minimises the hinge generator loss plus ``lambda_l1`` times a region-aware
    L1 (background weighted ``lambda_bg``). Checkpoints are saved every epoch.

    Returns (G, D).
    """
    dataset = DAVISObjectRemovalDataset(IMG_ROOT, MASK_ROOT, num_frames=num_frames)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

    G = FullGeneratorPipeline(num_frames=num_frames).to(device)
    D = VideoDiscriminator3D(base_ch=32).to(device)

    # The frozen encoder's parameters are left out of G's optimizer
    trainable = [p for p in G.parameters() if p.requires_grad]
    opt_G = torch.optim.Adam(trainable, lr=lr_g, betas=(0.5, 0.999))
    opt_D = torch.optim.Adam(D.parameters(), lr=lr_d, betas=(0.5, 0.999))

    def to_signed(x):
        """Map a [0,1] tensor to [-1,1]."""
        return x * 2.0 - 1.0

    batch_keys = ("orig_video", "edit_video", "mask_seq",
                  "orig_first", "edit_first", "first_mask")

    for epoch in range(num_epochs):
        pbar = tqdm(loader, desc=f"Epoch {epoch}")
        for batch in pbar:
            inputs = {key: batch[key].to(device) for key in batch_keys}

            real_video = to_signed(inputs["edit_video"])
            fake_video = G(
                to_signed(inputs["orig_video"]),
                to_signed(inputs["orig_first"]),
                to_signed(inputs["edit_first"]),
                inputs["first_mask"],
            )  # [-1,1]

            # Discriminator update on real vs. detached fake
            loss_D = hinge_d_loss(D(real_video), D(fake_video.detach()))
            opt_D.zero_grad()
            loss_D.backward()
            opt_D.step()

            # Generator update: adversarial + region-aware reconstruction
            loss_G_adv = hinge_g_loss(D(fake_video))
            loss_G_l1 = region_aware_l1(
                fake_video, real_video, inputs["mask_seq"], lambda_bg=lambda_bg
            )
            loss_G = loss_G_adv + lambda_l1 * loss_G_l1
            opt_G.zero_grad()
            loss_G.backward()
            opt_G.step()

            pbar.set_postfix(D=loss_D.item(), G_adv=loss_G_adv.item(), L1=loss_G_l1.item())

        torch.save(G.state_dict(), f"video_gen_epoch{epoch}.pth")
        torch.save(D.state_dict(), f"video_disc_epoch{epoch}.pth")

    return G, D

# Short debugging run (5 epochs); increase num_epochs for a real training run
G_trained, D_trained = train_pipeline(
    lambda_bg=2.0,
    lambda_l1=50.0,
    lr_d=2e-4,
    lr_g=2e-4,
    num_frames=8,
    batch_size=1,
    num_epochs=5,
)

Epoch 0:   0%|          | 0/50 [00:00<?, ?it/s]

Epoch 1:   0%|          | 0/50 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7c1e18ccd260>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process


Epoch 2:   0%|          | 0/50 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7c1e18ccd260>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7c1e18ccd260>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16

Epoch 3:   0%|          | 0/50 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7c1e18ccd260>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7c1e18ccd260>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16

Epoch 4:   0%|          | 0/50 [00:00<?, ?it/s]

Exception ignored in: 
<function _MultiProcessingDataLoaderIter.__del__ at 0x7c1e18ccd260>Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7c1e18ccd260>Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__

Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()    
self._shutdown_workers()  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers

  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
Exception ignored in:         if w.is_alive():
if w.is_alive():Exception ignored in:  
<function _MultiProcessingDataLoaderIter.__del__ at 0x7c1e18ccd260>  
<function _MultiProcessingDataLoaderIter.__del__ at 0x7c1e18ccd260>  
Traceback (most recent call last)

In [None]:
@torch.no_grad()
def run_inference_on_sequence(G, dataset, index=0, save_path="gen_result.mp4"):
    """
    Run G on one dataset item and save a side-by-side clip to ``save_path``
    (left: target edited clip, right: generated clip) at 5 fps via imageio.

    G is switched to eval mode and left there.
    """
    G.eval()
    sample = dataset[index]
    batched = {
        key: sample[key].unsqueeze(0).to(device)
        for key in ("orig_video", "edit_video", "mask_seq",
                    "orig_first", "edit_first", "first_mask")
    }

    real_video = batched["edit_video"] * 2.0 - 1.0
    fake_video = G(
        batched["orig_video"] * 2.0 - 1.0,
        batched["orig_first"] * 2.0 - 1.0,
        batched["edit_first"] * 2.0 - 1.0,
        batched["first_mask"],
    )  # [-1,1]

    def to_uint8_frames(video):
        # [1,3,T,H,W] in [-1,1] -> [T,H,W,3] uint8
        arr = ((video[0].cpu().numpy() + 1.0) / 2.0).clip(0, 1)
        return (arr.transpose(1, 2, 3, 0) * 255).astype(np.uint8)

    real_np = to_uint8_frames(real_video)
    fake_np = to_uint8_frames(fake_video)

    # Concatenate along width: ground truth on the left, generated on the right
    frames = [np.concatenate([frame, fake_np[i]], axis=1) for i, frame in enumerate(real_np)]

    imageio.mimsave(save_path, frames, fps=5)
    print("Saved", save_path)

# Usage: ground truth (left) vs. generated (right) for the first sequence
test_dataset = DAVISObjectRemovalDataset(IMG_ROOT, MASK_ROOT, num_frames=8)
run_inference_on_sequence(
    G_trained, test_dataset, index=0, save_path="comparison_gt_vs_gen.gif"
)


Saved comparison_gt_vs_gen.gif


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvBlock3D(nn.Module):
    """3x3x3 convolution (padding 1), then GroupNorm with 8 groups, then SiLU.

    ``out_ch`` must be divisible by 8 for the GroupNorm.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=1)
        self.norm = nn.GroupNorm(num_groups=8, num_channels=out_ch)
        self.act = nn.SiLU()

    def forward(self, x):
        y = self.conv(x)
        y = self.norm(y)
        return self.act(y)

class DownBlock3D(nn.Module):
    """Two ConvBlock3D layers, then a stride-(1,2,2) conv that halves H and W only.

    ``forward`` returns ``(downsampled, skip)``. ``skip`` is the full-resolution
    feature map taken before the strided conv.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = ConvBlock3D(in_ch, out_ch)
        self.conv2 = ConvBlock3D(out_ch, out_ch)
        self.down = nn.Conv3d(out_ch, out_ch, kernel_size=3, stride=(1, 2, 2), padding=1)

    def forward(self, x):
        feat = self.conv2(self.conv1(x))
        return self.down(feat), feat

class UpBlock3D(nn.Module):
    """Decoder block: upsample H/W by 2, fuse with an encoder skip, then two ConvBlock3D.

    Args:
        in_ch_from_down: channels of the incoming (coarser) feature map.
        out_ch_this_block: channels produced by this block.
        skip_ch_from_down: channels of the encoder skip tensor. When this differs
            from ``out_ch_this_block``, the skip is first projected with a
            ConvBlock3D so both halves of the concatenation have equal width.
        verbose: print shape diagnostics at construction and on every forward
            call. Off by default so training loops are not flooded with output.
    """

    def __init__(self, in_ch_from_down, out_ch_this_block, skip_ch_from_down, verbose=False):
        super().__init__()
        self.verbose = verbose
        self.up = nn.ConvTranspose3d(
            in_ch_from_down, out_ch_this_block,
            kernel_size=3, stride=(1, 2, 2),
            padding=1, output_padding=(0, 1, 1),
        )

        # Adapter for skip connection to match out_ch_this_block before concatenation
        if skip_ch_from_down != out_ch_this_block:
            self.skip_adapter = ConvBlock3D(skip_ch_from_down, out_ch_this_block)
            if verbose:
                print(f"UpBlock3D init: skip_adapter created for {skip_ch_from_down}->{out_ch_this_block}")
        else:
            self.skip_adapter = nn.Identity()
            if verbose:
                print(f"UpBlock3D init: skip_adapter is Identity for {skip_ch_from_down}->{out_ch_this_block}")

        self.conv1 = ConvBlock3D(out_ch_this_block * 2, out_ch_this_block)
        self.conv2 = ConvBlock3D(out_ch_this_block, out_ch_this_block)
        if verbose:
            print(f"UpBlock3D init: conv1 expects {out_ch_this_block * 2} channels.")

    def forward(self, x, skip):
        """Upsample ``x``, resize it to the skip's (T, H, W), concatenate, and convolve."""
        if self.verbose:
            print(f"UpBlock3D forward: Input x shape (from upsampled): {x.shape}, skip shape (original): {skip.shape}")
        x = self.up(x)
        if self.verbose:
            print(f"UpBlock3D forward: After self.up, x shape: {x.shape}")

        skip_adapted = self.skip_adapter(skip)
        if self.verbose:
            print(f"UpBlock3D forward: After skip_adapter, skip_adapted shape: {skip_adapted.shape}")

        # The transposed conv can be off by a pixel for odd sizes; resample to the
        # skip's exact (T, H, W) so the channel concatenation is well-defined.
        _, _, T, Hs, Ws = skip_adapted.shape
        x = F.interpolate(x, size=(T, Hs, Ws), mode="trilinear", align_corners=False)
        if self.verbose:
            print(f"UpBlock3D forward: After interpolate, x shape: {x.shape}")

        x = torch.cat([x, skip_adapted], dim=1)
        if self.verbose:
            print(f"UpBlock3D forward: After concatenate, x shape: {x.shape}")
        x = self.conv1(x)
        x = self.conv2(x)
        return x

In [None]:
class SelectiveContentEncoder(nn.Module):
    """Encode the first original/edited frame pair and its mask into a condition map.

    Inputs (concatenated along channels in this order):
        orig_first: [B,3,H,W]
        edit_first: [B,3,H,W]
        first_mask: [B,1,H,W]
    Output:
        cond_feat: [B,cond_ch,H/4,W/4], produced by two stride-2 convolutions.
    """

    def __init__(self, in_ch=3+3+1, cond_ch=64):
        super().__init__()
        layers = [
            nn.Conv2d(in_ch, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),       # -> H/2
            nn.ReLU(inplace=True),
            nn.Conv2d(64, cond_ch, kernel_size=4, stride=2, padding=1),  # -> H/4
            nn.ReLU(inplace=True),
        ]
        self.encoder = nn.Sequential(*layers)

    def forward(self, orig_first, edit_first, first_mask):
        stacked = torch.cat((orig_first, edit_first, first_mask), dim=1)
        return self.encoder(stacked)


In [None]:
class VideoUNetWithSCE(nn.Module):
    """3-D U-Net whose bottleneck is conditioned on a SelectiveContentEncoder map.

    Encoder: three DownBlock3D stages (H/W halved each time, T unchanged).
    Bottleneck: the condition map is 1x1-projected to ``base_ch*4`` channels,
    resized to the bottleneck H/W, broadcast over time, added to the features,
    and passed through a ConvBlock3D.
    Decoder: two UpBlock3D stages fused with the down3/down2 skips; the down1
    skip is not used. A final conv maps back to ``in_ch`` channels. The result
    is trilinearly resized to the input (T, H, W) and squashed with tanh.
    """

    def __init__(self, in_ch=3, base_ch=64, cond_ch=64):
        super().__init__()
        c1, c2, c4 = base_ch, base_ch * 2, base_ch * 4
        self.down1 = DownBlock3D(in_ch, c1)
        self.down2 = DownBlock3D(c1, c2)
        self.down3 = DownBlock3D(c2, c4)

        self.mid = ConvBlock3D(c4, c4)

        # maps SCE features into bottleneck channels
        self.cond_proj = nn.Conv2d(cond_ch, c4, kernel_size=1)

        # positional args: (in_ch_from_down, out_ch_this_block, skip_ch_from_down)
        self.up3 = UpBlock3D(c4, c2, c4)
        self.up2 = UpBlock3D(c2, c1, c2)
        self.out = nn.Conv3d(c1, in_ch, kernel_size=3, padding=1)

    def forward(self, orig_video, cond_feat):
        """
        orig_video: [B,3,T,H,W] in [-1,1]
        cond_feat : [B,cond_ch,h,w] from the SCE (H/4, W/4 for SelectiveContentEncoder)
        returns   : [B,3,T,H,W] in [-1,1]
        """
        _, _, T, H, W = orig_video.shape

        h1, _ = self.down1(orig_video)
        h2, s2 = self.down2(h1)
        h3, s3 = self.down3(h2)

        # Bottleneck conditioning: project, resize to h3's spatial size, repeat over T.
        bottleneck_hw = tuple(h3.shape[-2:])
        cond = F.interpolate(self.cond_proj(cond_feat), size=bottleneck_hw,
                             mode="bilinear", align_corners=False)
        cond = cond.unsqueeze(2).expand(-1, -1, T, -1, -1)

        mid = self.mid(h3 + cond)

        decoded = self.up2(self.up3(mid, s3), s2)
        out = F.interpolate(self.out(decoded), size=(T, H, W),
                            mode="trilinear", align_corners=False)
        return torch.tanh(out)

In [None]:
from torch.utils.data import DataLoader

def train_unet_sce(
    num_epochs=10,
    batch_size=2,
    num_frames=8,
    lr=1e-4,
    lambda_bg=2.0,
    lambda_obj=10.0,
):
    """Jointly train a SelectiveContentEncoder and a VideoUNetWithSCE generator.

    Args:
        num_epochs: passes over the dataset. After each one, a checkpoint
            ``unet_sce_epoch{epoch}.pth`` holding both state dicts is written.
        batch_size: DataLoader batch size.
        num_frames: frames per training clip.
        lr: Adam learning rate shared by both networks.
        lambda_bg: weight of the background term of ``object_aware_l1``.
        lambda_obj: weight of the object-region term of ``object_aware_l1``
            (defaults to 10.0).

    Returns:
        ``(sce, G)``: the trained encoder and generator.
    """
    dataset = DAVISObjectRemovalDataset(
        IMG_ROOT, MASK_ROOT, num_frames=num_frames,
        short_side=192, max_side=224,  # smaller → easier to train
    )
    loader = DataLoader(dataset, batch_size=batch_size,
                        shuffle=True, num_workers=2, pin_memory=True)

    sce = SelectiveContentEncoder(in_ch=3+3+1, cond_ch=64).to(device)
    G = VideoUNetWithSCE(in_ch=3, base_ch=64, cond_ch=64).to(device)

    params = list(sce.parameters()) + list(G.parameters())
    optimizer = torch.optim.Adam(params, lr=lr)

    for epoch in range(num_epochs):
        G.train()
        sce.train()
        pbar = tqdm(loader, desc=f"Epoch {epoch}")
        for batch in pbar:
            orig_video = batch["orig_video"].to(device)   # [B,3,T,H,W] in [0,1]
            edit_video = batch["edit_video"].to(device)   # [B,3,T,H,W]
            mask_seq = batch["mask_seq"].to(device)       # [B,1,T,H,W]
            orig_first = batch["orig_first"].to(device)   # [B,3,H,W]
            edit_first = batch["edit_first"].to(device)   # [B,3,H,W]
            first_mask = batch["first_mask"].to(device)   # [B,1,H,W]

            # Normalise to [-1,1]
            orig_v = orig_video * 2.0 - 1.0
            target = edit_video * 2.0 - 1.0

            cond_feat = sce(orig_first, edit_first, first_mask)
            pred = G(orig_v, cond_feat)   # [-1,1], [B,3,T,H,W]

            # Region weights come from the function arguments.
            loss = object_aware_l1(pred, target, mask_seq,
                                   lambda_obj=lambda_obj, lambda_bg=lambda_bg)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(params, 1.0)
            optimizer.step()

            pbar.set_postfix(loss=loss.item())

        torch.save({
            "sce": sce.state_dict(),
            "G": G.state_dict(),
        }, f"unet_sce_epoch{epoch}.pth")

    return sce, G

# Train the first UNet+SCE variant; ≥10 epochs strongly recommended.
sce_trained, G_trained = train_unet_sce(
    lr=1e-4,
    num_frames=8,
    batch_size=2,
    num_epochs=10,
    lambda_bg=2.0,
)

Epoch 0:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 1:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 2:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 3:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 4:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 5:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 6:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 7:   0%|          | 0/25 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7c1e18ccd260>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7c1e18ccd260>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16

Epoch 8:   0%|          | 0/25 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7c1e18ccd260>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^assert self._parent_pid == os.getpid(), 'can only test a child process'      File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive


           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7c1e18ccd260>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16

Epoch 9:   0%|          | 0/25 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7c1e18ccd260>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7c1e18ccd260>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16

In [None]:
import imageio

@torch.no_grad()
def test_on_sequence(sce, G, dataset, index=0, save_path="comparison_unet.gif"):
    """Save a side-by-side GIF (left: target, right: prediction) for one dataset item.

    The SCE encodes the item's first original/edited frame and mask. ``G`` maps
    the [-1, 1]-scaled original clip plus that condition to a [-1, 1] prediction.
    Both networks are left in eval mode. Frames are written with
    ``imageio.mimsave`` at 5 fps.
    """
    sce.eval()
    G.eval()
    item = dataset[index]
    fields = ("orig_video", "edit_video", "mask_seq", "orig_first", "edit_first", "first_mask")
    orig_video, edit_video, mask_seq, orig_first, edit_first, first_mask = (
        item[name].unsqueeze(0).to(device) for name in fields
    )

    orig_v = orig_video * 2.0 - 1.0
    target = edit_video * 2.0 - 1.0
    pred = G(orig_v, sce(orig_first, edit_first, first_mask))  # [-1,1]

    def as_frames(clip):
        # [1,3,T,H,W] in [-1,1] -> [T,H,W,3] uint8
        arr = ((clip[0].cpu().numpy() + 1.0) / 2.0).clip(0, 1)
        return (arr.transpose(1, 2, 3, 0) * 255).astype(np.uint8)

    left, right = as_frames(target), as_frames(pred)
    frames = [np.concatenate([left[t], right[t]], axis=1) for t in range(left.shape[0])]

    imageio.mimsave(save_path, frames, fps=5)
    print("Saved", save_path)

# Rebuild the dataset with the same resolution settings used for training.
test_dataset = DAVISObjectRemovalDataset(IMG_ROOT, MASK_ROOT, num_frames=8,
                                         short_side=192, max_side=224)
test_on_sequence(sce_trained, G_trained, test_dataset,
                 index=0, save_path="comparison_unet.gif")


Saved comparison_unet.gif


In [None]:
def object_aware_l1(pred, target, mask_seq,
                    lambda_obj=10.0, lambda_bg=1.0):
    """Mask-weighted L1: separate mean absolute errors inside and outside the mask.

    pred, target: [B,3,T,H,W] in [-1,1]
    mask_seq    : [B,1,T,H,W] (1 = object region), broadcast over channels

    Each term divides the masked error sum by (region voxel count × channels).
    A 1e-8 epsilon keeps an empty region from dividing by zero.
    Returns ``lambda_obj * object_term + lambda_bg * background_term``.
    """
    eps = 1e-8
    abs_err = torch.abs(pred - target)
    n_ch = abs_err.shape[1]

    def region_mean(weight):
        return (abs_err * weight).sum() / (weight.sum() * n_ch + eps)

    return lambda_obj * region_mean(mask_seq) + lambda_bg * region_mean(1.0 - mask_seq)


In [None]:
def read_rgb(path):
    """Load an image file as an RGB ``uint8`` array of shape [H, W, 3].

    Args:
        path: file path (``str`` or ``Path``).

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``. ``cv2.imread`` signals
            failure by returning ``None`` instead of raising, which would
            otherwise surface as an opaque ``cv2.cvtColor`` assertion.
    """
    img = cv2.imread(str(path))
    if img is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

def read_mask(path):
    """Load a mask image as grayscale and binarise it to a ``uint8`` 0/1 array.

    Any pixel value > 0 becomes 1.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path`` (``cv2.imread`` returns
            ``None``; comparing that with 0 would raise a confusing TypeError).
    """
    m = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
    if m is None:
        raise FileNotFoundError(f"Could not read mask: {path}")
    return (m > 0).astype(np.uint8)  # 0/1

def resize_pair(img, mask, short_side=256, max_side=320):
    """Resize an image and its mask together, preserving aspect ratio.

    The shorter edge is scaled to ``short_side``, unless that would push the
    longer edge past ``max_side``; then the longer edge is scaled to
    ``max_side`` instead. New sizes are truncated with ``int``. The image uses
    INTER_AREA. The mask uses INTER_NEAREST and is re-binarised to uint8 0/1.
    """
    h, w = img.shape[:2]
    long_edge = max(h, w)
    scale = short_side / min(h, w)
    if long_edge * scale > max_side:
        scale = max_side / long_edge

    dsize = (int(w * scale), int(h * scale))  # cv2 expects (width, height)
    img_out = cv2.resize(img, dsize, interpolation=cv2.INTER_AREA)
    mask_out = cv2.resize(mask, dsize, interpolation=cv2.INTER_NEAREST)
    return img_out, (mask_out > 0).astype(np.uint8)

def dilate_mask(mask, k=5):
    """Grow a binary mask with one pass of a k x k square dilation.

    Used so the mask covers the whole object plus a thin band of context.
    """
    square = np.ones((k, k), dtype=np.uint8)
    return cv2.dilate(mask.astype(np.uint8), square, iterations=1)

def video_to_tensor(frames):
    """Stack T frames [H,W,3] (values 0..255) into a float32 tensor [3,T,H,W] in [0,1]."""
    stacked = np.stack(frames, axis=0).astype(np.float32) / 255.0  # [T,H,W,3]
    channels_first = np.moveaxis(stacked, -1, 0)                   # [3,T,H,W]
    return torch.from_numpy(channels_first)

def masks_to_tensor(masks):
    """Stack T masks of shape [H,W] into a float32 tensor [1,T,H,W]."""
    stacked = np.stack(masks, axis=0).astype(np.float32)  # [T,H,W]
    return torch.from_numpy(np.expand_dims(stacked, axis=0))


In [None]:
def inpaint_cv2(img_rgb, mask_binary, radius=3):
    """Fill masked pixels of an RGB image using OpenCV's Telea inpainting.

    img_rgb: HxWx3 uint8 RGB image.
    mask_binary: HxW array; any value > 0 marks a pixel to fill.
    radius: ``inpaintRadius`` passed to ``cv2.inpaint``.
    Returns the inpainted image as RGB.
    """
    inpaint_mask = np.where(mask_binary > 0, 255, 0).astype(np.uint8)
    as_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
    filled = cv2.inpaint(as_bgr, inpaint_mask, radius, cv2.INPAINT_TELEA)
    return cv2.cvtColor(filled, cv2.COLOR_BGR2RGB)


In [None]:
# Set True to build ground-truth frames with Stable Diffusion inpainting
# instead of OpenCV's Telea inpainting.
USE_SD_INPAINT = False

sd_pipe = None
if USE_SD_INPAINT:
    # Import explicitly so this cell does not depend on an earlier import of
    # this pipeline class (otherwise enabling the flag raises NameError).
    from diffusers import StableDiffusionInpaintPipeline

    sd_pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-inpainting",
        torch_dtype=torch.float16,
        safety_checker=None,
    ).to(device)
    try:
        sd_pipe.enable_xformers_memory_efficient_attention()
    except (ImportError, ValueError) as err:
        # xformers is optional (diffusers raises ModuleNotFoundError when it is
        # missing, ValueError without CUDA); keep the default attention instead.
        print(f"⚠ xformers attention unavailable ({err}); using default attention")

def inpaint_sd(img_rgb, mask_binary, prompt="remove the animal, keep background natural"):
    """Inpaint one frame with the global Stable Diffusion pipeline ``sd_pipe``.

    img_rgb: HxWx3 uint8
    mask_binary: HxW 0/1, 1 = region to remove
    prompt: text prompt for the inpainting pipeline.

    Returns an HxWx3 uint8 array with the same H and W as ``img_rgb``. The
    pipeline renders at its own default resolution when no height/width is
    given, so the result is resized back. Otherwise the edited frames would not
    match the size of the original frames they are paired with.
    """
    h, w = img_rgb.shape[:2]
    image = Image.fromarray(img_rgb)
    mask = Image.fromarray((mask_binary * 255).astype(np.uint8))

    with torch.autocast(device.type):
        result = sd_pipe(
            prompt=prompt,
            image=image,
            mask_image=mask,
            guidance_scale=7.5,
            num_inference_steps=30,
        ).images[0]

    if result.size != (w, h):  # PIL sizes are (width, height)
        result = result.resize((w, h), Image.BICUBIC)
    return np.array(result)


In [None]:
def inpaint_frame(img_rgb, mask_binary):
    """Inpaint with Stable Diffusion when enabled and loaded, else with OpenCV Telea."""
    use_sd = USE_SD_INPAINT and sd_pipe is not None
    return inpaint_sd(img_rgb, mask_binary) if use_sd else inpaint_cv2(img_rgb, mask_binary)


In [None]:
class DAVISInpaintDataset(Dataset):
    """DAVIS clips paired with per-frame inpainted ("object removed") targets.

    Every sub-directory of ``img_root`` is one sequence. Its masks are read from
    the same-named sub-directory of ``mask_root``. Frames (``*.jpg``) and masks
    (``*.png``) are paired by sorted order and truncated to the shorter list.

    Each item is a dict with:
      seq       : sequence name
      orig_video: [3,T,H,W] float in [0,1], original frames
      edit_video: [3,T,H,W] float in [0,1], frames inpainted inside the mask
      mask_seq  : [1,T,H,W] float 0/1 masks after a 5x5 dilation
      orig_first, edit_first, first_mask: frame 0 of the above (for the SCE)

    T == num_frames. On every access, a random contiguous window is drawn when
    the sequence is long enough; shorter sequences are repeated to fill it.
    Inpainting runs on every ``__getitem__`` call (results are not cached).
    """

    def __init__(self, img_root, mask_root,
                 num_frames=8, short_side=256, max_side=320):
        # Path() also accepts plain strings, so callers may pass either.
        self.img_root = Path(img_root)
        self.mask_root = Path(mask_root)
        self.seq_names = sorted(p.name for p in self.img_root.iterdir() if p.is_dir())
        self.num_frames = num_frames
        self.short_side = short_side
        self.max_side = max_side

    def __len__(self):
        return len(self.seq_names)

    def _paths_for_seq(self, seq):
        """Return sorted (frame, mask) path lists for ``seq``, truncated to equal length."""
        img_dir = self.img_root / seq
        mask_dir = self.mask_root / seq
        imgs = sorted(img_dir.glob("*.jpg"))
        masks = sorted(mask_dir.glob("*.png"))
        T = min(len(imgs), len(masks))
        return imgs[:T], masks[:T]

    def __getitem__(self, idx):
        # Local import so this method does not rely on a notebook-level import.
        import random

        seq = self.seq_names[idx]
        img_paths, mask_paths = self._paths_for_seq(seq)
        if not img_paths:
            # Without this guard, an empty or unpaired sequence fails below
            # with a bare ZeroDivisionError in the repeat branch.
            raise FileNotFoundError(
                f"No paired frames/masks for sequence {seq!r}: expected *.jpg in "
                f"{self.img_root / seq} and *.png in {self.mask_root / seq}"
            )

        # sample contiguous window or loop
        if len(img_paths) >= self.num_frames:
            start = random.randint(0, len(img_paths) - self.num_frames)
            img_paths = img_paths[start:start + self.num_frames]
            mask_paths = mask_paths[start:start + self.num_frames]
        else:
            rep = (self.num_frames + len(img_paths) - 1) // len(img_paths)  # ceil
            img_paths = (img_paths * rep)[:self.num_frames]
            mask_paths = (mask_paths * rep)[:self.num_frames]

        frames_orig, frames_edit, masks = [], [], []
        for ip, mp in zip(img_paths, mask_paths):
            img = read_rgb(ip)
            m = read_mask(mp)
            img, m = resize_pair(img, m, self.short_side, self.max_side)
            m = dilate_mask(m, k=5)          # enlarge region to cover object edges
            frames_orig.append(img)
            masks.append(m)
            frames_edit.append(inpaint_frame(img, m))  # GT object-removed frame

        orig_video = video_to_tensor(frames_orig)  # [3,T,H,W]
        edit_video = video_to_tensor(frames_edit)  # [3,T,H,W]
        mask_seq = masks_to_tensor(masks)          # [1,T,H,W]

        return {
            "seq": seq,
            "orig_video": orig_video,
            "edit_video": edit_video,
            "mask_seq": mask_seq,
            "orig_first": orig_video[:, 0],        # [3,H,W]
            "edit_first": edit_video[:, 0],        # [3,H,W]
            "first_mask": mask_seq[:, 0],          # [1,H,W]
        }

# Smoke test: build the dataset and print each field of item 0.
# Tensors are shown by shape; other values are printed as-is.
dataset = DAVISInpaintDataset(IMG_ROOT, MASK_ROOT, num_frames=8)
sample = dataset[0]
for k, v in sample.items():
    print(k, v.shape if torch.is_tensor(v) else v)


seq bear
orig_video torch.Size([3, 8, 179, 320])
edit_video torch.Size([3, 8, 179, 320])
mask_seq torch.Size([1, 8, 179, 320])
orig_first torch.Size([3, 179, 320])
edit_first torch.Size([3, 179, 320])
first_mask torch.Size([1, 179, 320])


In [None]:
def visualize_masks(dataset, index=0, out_path="mask_overlay.gif"):
    """Save a GIF of one dataset item with its mask region tinted red.

    Reads ``orig_video`` ([3,T,H,W] in [0,1]) and ``mask_seq`` ([1,T,H,W]).
    Each output frame is 60% original + 40% a copy whose masked pixels are
    pure red. Written with ``imageio.mimsave`` at 4 fps.
    """
    item = dataset[index]
    video = item["orig_video"].numpy()
    masks = item["mask_seq"].numpy()

    def overlay(t):
        frame = (video[:, t].transpose(1, 2, 0) * 255).astype(np.uint8)
        tinted = frame.copy()
        tinted[masks[0, t] > 0] = [255, 0, 0]
        return (0.6 * frame + 0.4 * tinted).astype(np.uint8)

    frames = [overlay(t) for t in range(video.shape[1])]
    imageio.mimsave(out_path, frames, fps=4)
    print("Saved", out_path)

visualize_masks(dataset, out_path="mask_overlay.gif", index=0)  # red mask overlay for item 0


Saved mask_overlay.gif


In [None]:
class ConvBlock3D(nn.Module):
    """Conv3d(3x3x3, padding=1) -> GroupNorm(8 groups) -> SiLU.

    ``out_ch`` must be divisible by 8 for the GroupNorm.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv3d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, padding=1)
        self.norm = nn.GroupNorm(8, out_ch)
        self.act = nn.SiLU()

    def forward(self, x):
        normalized = self.norm(self.conv(x))
        return self.act(normalized)

class DownBlock3D(nn.Module):
    """Encoder stage: two ConvBlock3D, then a stride-(1,2,2) conv (time not strided).

    ``forward`` returns ``(downsampled, skip)``. ``skip`` is the pre-downsampling
    feature map.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = ConvBlock3D(in_ch, out_ch)
        self.conv2 = ConvBlock3D(out_ch, out_ch)
        self.down = nn.Conv3d(out_ch, out_ch, kernel_size=3, stride=(1, 2, 2), padding=1)

    def forward(self, x):
        skip = self.conv2(self.conv1(x))
        downsampled = self.down(skip)
        return downsampled, skip

class UpBlock3D(nn.Module):
    """Decoder stage: transposed conv (H/W x2, T unchanged), resize to skip, concat, 2 convs.

    Args:
        in_ch: channels of the incoming coarse features.
        skip_ch: channels of the encoder skip (concatenated without projection).
        out_ch: channels after upsampling and of the block output.
    """

    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        self.up = nn.ConvTranspose3d(in_ch, out_ch, kernel_size=3, stride=(1, 2, 2),
                                     padding=1, output_padding=(0, 1, 1))
        self.conv1 = ConvBlock3D(out_ch + skip_ch, out_ch)
        self.conv2 = ConvBlock3D(out_ch, out_ch)

    def forward(self, x, skip):
        upsampled = self.up(x)
        _, _, t, hs, ws = skip.shape
        # force an exact spatial/temporal match with the skip before concatenation
        upsampled = F.interpolate(upsampled, size=(t, hs, ws),
                                  mode="trilinear", align_corners=False)
        fused = torch.cat((upsampled, skip), dim=1)
        return self.conv2(self.conv1(fused))



In [None]:
class SelectiveContentEncoder(nn.Module):
    """Condition encoder over the first frame pair and its mask.

    orig_first: [B,3,H,W], edit_first: [B,3,H,W], first_mask: [B,1,H,W]
    (concatenated along channels) -> cond_feat: [B,cond_ch,H/4,W/4].
    """

    def __init__(self, in_ch=3+3+1, cond_ch=64):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_ch, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),       # H/2
            nn.ReLU(inplace=True),
            nn.Conv2d(64, cond_ch, kernel_size=4, stride=2, padding=1),  # H/4
            nn.ReLU(inplace=True),
        )

    def forward(self, orig_first, edit_first, first_mask):
        inputs = [orig_first, edit_first, first_mask]
        return self.encoder(torch.cat(inputs, dim=1))


In [None]:
class VideoUNetWithSCE(nn.Module):
    """3-D U-Net conditioned at the bottleneck on SelectiveContentEncoder features.

    Three DownBlock3D stages (H/W halved, T kept), a ConvBlock3D bottleneck to
    which the 1x1-projected, resized, time-broadcast condition is added, then two
    UpBlock3D stages using the down3/down2 skips. A final conv returns ``in_ch``
    channels; the result is resized to the input (T, H, W) and passed through tanh.
    """

    def __init__(self, in_ch=3, base_ch=64, cond_ch=64):
        super().__init__()
        c1, c2, c4 = base_ch, base_ch * 2, base_ch * 4
        self.down1 = DownBlock3D(in_ch, c1)
        self.down2 = DownBlock3D(c1, c2)
        self.down3 = DownBlock3D(c2, c4)

        self.mid = ConvBlock3D(c4, c4)

        self.cond_proj = nn.Conv2d(cond_ch, c4, 1)

        self.up3 = UpBlock3D(in_ch=c4, skip_ch=c4, out_ch=c2)  # bottleneck + down3 skip
        self.up2 = UpBlock3D(in_ch=c2, skip_ch=c2, out_ch=c1)  # up3 output + down2 skip

        self.out = nn.Conv3d(c1, in_ch, 3, padding=1)

    def forward(self, orig_video, cond_feat):
        """
        orig_video: [B,3,T,H,W] in [-1,1]
        cond_feat : [B,cond_ch,H/4,W/4]
        returns   : [B,3,T,H,W] in [-1,1]
        """
        B, C, T, H, W = orig_video.shape

        feats, _ = self.down1(orig_video)
        feats, skip2 = self.down2(feats)
        feats, skip3 = self.down3(feats)

        # SCE injection: project, match bottleneck H/W, repeat over its time axis
        _, _, t_b, h_b, w_b = feats.shape
        cond = F.interpolate(self.cond_proj(cond_feat), size=(h_b, w_b),
                             mode="bilinear", align_corners=False)
        cond = cond.unsqueeze(2).expand(-1, -1, t_b, -1, -1)
        feats = self.mid(feats + cond)

        feats = self.up3(feats, skip3)
        feats = self.up2(feats, skip2)
        out = F.interpolate(self.out(feats), size=(T, H, W),
                            mode="trilinear", align_corners=False)
        return torch.tanh(out)


In [None]:
def blend_with_mask(orig_v, pred_v, mask_seq):
    """Composite that keeps ``orig_v`` outside the mask and ``pred_v`` inside it.

    orig_v, pred_v: [B,3,T,H,W] in [-1,1]
    mask_seq      : [B,1,T,H,W], 1 = object region (broadcast over channels)

    Computes ``orig_v * (1 - mask) + pred_v * mask``; fractional mask values
    blend linearly.
    """
    keep = 1.0 - mask_seq
    return orig_v * keep + pred_v * mask_seq


In [None]:
def object_aware_l1(pred, target, mask_seq,
                    lambda_obj=10.0, lambda_bg=1.0):
    """Weighted sum of mean absolute errors inside and outside the object mask.

    pred, target: [B,3,T,H,W] in [-1,1]
    mask_seq    : [B,1,T,H,W], 1 = object region

    Each region's error sum is divided by (mask voxel count × channels + 1e-8),
    so an empty region contributes 0 instead of dividing by zero.
    """
    eps = 1e-8
    err = (pred - target).abs()
    channels = err.shape[1]
    inside = mask_seq
    outside = 1.0 - mask_seq

    loss_obj = (err * inside).sum() / (inside.sum() * channels + eps)
    loss_bg = (err * outside).sum() / (outside.sum() * channels + eps)
    return lambda_obj * loss_obj + lambda_bg * loss_bg


In [None]:
def train_unet_sce(
    num_epochs=10,
    batch_size=1,
    num_frames=8,
    lr=1e-4,
    lambda_obj=10.0,
    lambda_bg=1.0,
):
    """Jointly train the SelectiveContentEncoder and VideoUNetWithSCE generator.

    The generator output is composited into the original clip with
    ``blend_with_mask``, so only masked voxels come from the generator. It is
    then scored with ``object_aware_l1`` against the inpainted target. Adam
    updates both networks, with gradient-norm clipping at 1.0.
    ``unet_sce_epoch{epoch}.pth`` (both state dicts) is saved after every epoch.

    Returns:
        ``(sce, G)``
    """
    train_set = DAVISInpaintDataset(
        IMG_ROOT, MASK_ROOT,
        num_frames=num_frames,
        short_side=256, max_side=320,
    )
    loader = DataLoader(train_set, batch_size=batch_size,
                        shuffle=True, num_workers=2, pin_memory=True)

    sce = SelectiveContentEncoder(in_ch=3 + 3 + 1, cond_ch=64).to(device)
    G = VideoUNetWithSCE(in_ch=3, base_ch=64, cond_ch=64).to(device)

    params = [*sce.parameters(), *G.parameters()]
    optimizer = torch.optim.Adam(params, lr=lr)

    keys = ("orig_video", "edit_video", "mask_seq", "orig_first", "edit_first", "first_mask")
    for epoch in range(num_epochs):
        G.train()
        sce.train()
        pbar = tqdm(loader, desc=f"Epoch {epoch}")
        for batch in pbar:
            orig_video, edit_video, mask_seq, orig_first, edit_first, first_mask = (
                batch[k].to(device) for k in keys
            )

            orig_v = orig_video * 2.0 - 1.0   # [0,1] -> [-1,1]
            target = edit_video * 2.0 - 1.0

            cond_feat = sce(orig_first, edit_first, first_mask)
            pred = blend_with_mask(orig_v, G(orig_v, cond_feat), mask_seq)

            loss = object_aware_l1(pred, target, mask_seq,
                                   lambda_obj=lambda_obj, lambda_bg=lambda_bg)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(params, 1.0)
            optimizer.step()

            pbar.set_postfix(loss=loss.item())

        checkpoint = {"sce": sce.state_dict(), "G": G.state_dict()}
        torch.save(checkpoint, f"unet_sce_epoch{epoch}.pth")

    return sce, G

# Short run first; raise num_epochs to 10+ once the outputs look right.
sce_trained, G_trained = train_unet_sce(
    lr=1e-4,
    num_frames=8,
    batch_size=1,
    num_epochs=5,
    lambda_obj=10.0,
    lambda_bg=1.0,
)


Epoch 0:   0%|          | 0/50 [00:00<?, ?it/s]

Epoch 1:   0%|          | 0/50 [00:00<?, ?it/s]

Epoch 2:   0%|          | 0/50 [00:00<?, ?it/s]

Epoch 3:   0%|          | 0/50 [00:00<?, ?it/s]

Epoch 4:   0%|          | 0/50 [00:00<?, ?it/s]

In [None]:
@torch.no_grad()
def test_on_sequence(sce, G, dataset, index=0, save_path="gt_vs_gen.gif"):
    """Save a GIF of target (left) vs. mask-blended prediction (right) for one item.

    The generator output is composited into the original clip with
    ``blend_with_mask``, as in training. Both networks are left in eval mode.
    Frames are written with ``imageio.mimsave`` at 5 fps.
    """
    sce.eval()
    G.eval()
    item = dataset[index]

    def batched(key):
        return item[key].unsqueeze(0).to(device)

    orig_video = batched("orig_video")
    edit_video = batched("edit_video")
    mask_seq = batched("mask_seq")
    orig_first = batched("orig_first")
    edit_first = batched("edit_first")
    first_mask = batched("first_mask")

    orig_v = orig_video * 2.0 - 1.0
    target = edit_video * 2.0 - 1.0

    raw = G(orig_v, sce(orig_first, edit_first, first_mask))
    pred = blend_with_mask(orig_v, raw, mask_seq)

    def to_frames(clip):
        # [1,3,T,H,W] in [-1,1] -> [T,H,W,3] uint8
        arr = ((clip[0].cpu().numpy() + 1.0) / 2.0).clip(0, 1)
        return (arr.transpose(1, 2, 3, 0) * 255).astype(np.uint8)

    gt, gen = to_frames(target), to_frames(pred)
    frames = [np.concatenate([gt[t], gen[t]], axis=1) for t in range(gt.shape[0])]

    imageio.mimsave(save_path, frames, fps=5)
    print("Saved", save_path)

# Evaluate the trained pair on item 0 of a freshly built dataset.
test_dataset = DAVISInpaintDataset(IMG_ROOT, MASK_ROOT, num_frames=8)
test_on_sequence(sce_trained, G_trained, test_dataset,
                 save_path="gt_vs_gen.gif", index=0)


Saved gt_vs_gen.gif


In [None]:
@torch.no_grad()
def compare_full_sequence(sce, G, dataset, index=0,
                          save_path="gt_vs_gen_full.gif"):
    """Save a GIF of target (left) vs. mask-blended prediction (right) for one item.

    Despite the name, this renders only the clip returned by ``dataset[index]``.
    For DAVISInpaintDataset that is a ``num_frames`` window placed at a random
    start, not the entire sequence. Both networks are left in eval mode. Frames
    are written with ``imageio.mimsave`` at 5 fps.
    """
    sce.eval()
    G.eval()
    item = dataset[index]
    names = ("orig_video", "edit_video", "mask_seq", "orig_first", "edit_first", "first_mask")
    orig_video, edit_video, mask_seq, orig_first, edit_first, first_mask = [
        item[n].unsqueeze(0).to(device) for n in names
    ]

    orig_v = orig_video * 2.0 - 1.0
    target = edit_video * 2.0 - 1.0

    cond_feat = sce(orig_first, edit_first, first_mask)
    pred = blend_with_mask(orig_v, G(orig_v, cond_feat), mask_seq)

    gt_np = ((target[0].cpu().numpy() + 1.0) / 2.0).clip(0, 1)
    gen_np = ((pred[0].cpu().numpy() + 1.0) / 2.0).clip(0, 1)
    gt_np = (gt_np.transpose(1, 2, 3, 0) * 255).astype(np.uint8)    # [T,H,W,3]
    gen_np = (gen_np.transpose(1, 2, 3, 0) * 255).astype(np.uint8)

    frames = []
    for t, gt_frame in enumerate(gt_np):
        frames.append(np.concatenate([gt_frame, gen_np[t]], axis=1))

    imageio.mimsave(save_path, frames, fps=5)
    print("Saved", save_path)

# Side-by-side comparison for item 0 of the training dataset.
compare_full_sequence(sce_trained, G_trained, dataset,
                      save_path="gt_vs_gen_full.gif", index=0)


Saved gt_vs_gen_full.gif


## New Section

In [None]:
# Environment setup for the teacher pipelines.
#
# Work with the runtime's preinstalled PyTorch instead of replacing it:
#   * torch==2.1.0 has no wheel for this runtime's Python 3.12, so the old
#     "uninstall torch, then pin 2.1.0" sequence left the kernel with no torch.
#   * transformers refuses to torch.load() .bin checkpoints (the inpainting
#     teacher ships several) unless torch >= 2.6 (CVE-2025-32434).
#   * diffusers==0.24.0 imports huggingface_hub.cached_download, which current
#     huggingface_hub releases no longer provide (ImportError on import).
# xformers is deliberately NOT installed here: its wheels require one exact
# torch build, and pip will replace torch to satisfy that requirement.
import subprocess
import sys


def _pip(*args):
    """Run pip for this kernel's own interpreter; raise CalledProcessError on failure."""
    subprocess.check_call([sys.executable, "-m", "pip", *args])


_pip("install", "-q", "diffusers>=0.30", "transformers", "accelerate", "safetensors")

import os
import random
from pathlib import Path
from glob import glob

import cv2
import imageio
import numpy as np
from PIL import Image
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from diffusers import (
    StableDiffusionControlNetInpaintPipeline,
    ControlNetModel,
    StableVideoDiffusionPipeline,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
print(f"✓ PyTorch version: {torch.__version__}")
print(f"✓ GPU available: {torch.cuda.is_available()}")

# torch.__version__ is a TorchVersion, so this is a proper version comparison
# (not a lexical string comparison).
if torch.__version__ < "2.6":
    print("⚠ torch < 2.6: transformers will refuse to load .bin checkpoints. "
          "Reinstall torch>=2.6 (or reset the runtime) and restart the kernel.")
print(f"✓ GPU available: {torch.cuda.is_available()}")

[0mFound existing installation: torch 2.8.0+cu126
Uninstalling torch-2.8.0+cu126:
  Successfully uninstalled torch-2.8.0+cu126
Found existing installation: torchvision 0.23.0+cu126
Uninstalling torchvision-0.23.0+cu126:
  Successfully uninstalled torchvision-0.23.0+cu126
Found existing installation: torchaudio 2.8.0+cu126
Uninstalling torchaudio-2.8.0+cu126:
  Successfully uninstalled torchaudio-2.8.0+cu126
[31mERROR: Could not find a version that satisfies the requirement torch==2.1.0 (from versions: 2.2.0+cu121, 2.2.1+cu121, 2.2.2+cu121, 2.3.0+cu121, 2.3.1+cu121, 2.4.0+cu121, 2.4.1+cu121, 2.5.0+cu121, 2.5.1+cu121)[0m[31m
[0m[31mERROR: No matching distribution found for torch==2.1.0[0m[31m
[0mFound existing installation: diffusers 0.35.2
Uninstalling diffusers-0.35.2:
  Successfully uninstalled diffusers-0.35.2
Found existing installation: transformers 4.57.1
Uninstalling transformers-4.57.1:
  Successfully uninstalled transformers-4.57.1
Found existing installation: accelera

ImportError: cannot import name 'cached_download' from 'huggingface_hub' (/usr/local/lib/python3.12/dist-packages/huggingface_hub/__init__.py)

In [None]:
# Try to install an xformers build that matches the torch already installed.
#
# Pinning torch to its installed version stops pip from replacing it: every
# xformers wheel requires one exact torch release, and the cu121 wheels pull
# in an older torch than the one this runtime ships. If no compatible xformers
# exists, pip fails without touching anything and we say so. The previous
# version printed "installed" even when pip had errored out.
import subprocess
import sys

_torch_pin = f"torch=={torch.__version__}"
_xformers_result = subprocess.run(
    [sys.executable, "-m", "pip", "install", "-q", "xformers", _torch_pin]
)

if _xformers_result.returncode == 0:
    print(f"✓ xformers installed (kept {_torch_pin})")
else:
    print(f"⚠ No xformers wheel is compatible with {_torch_pin}; "
          "continuing without xformers")

[31mERROR: Could not find a version that satisfies the requirement xformers==0.0.23.post1 (from versions: 0.0.27, 0.0.27.post1, 0.0.27.post2, 0.0.28, 0.0.28.post1, 0.0.28.post2, 0.0.28.post3, 0.0.29, 0.0.29.post1)[0m[31m
[0m[31mERROR: No matching distribution found for xformers==0.0.23.post1[0m[31m
[0m✓ xformers installed for CUDA 12.1


In [None]:
# Install an xformers build compatible with the torch that is already installed.
#
# The previous pin (xformers==0.0.29.post1 from the cu121 index) made pip
# download a replacement torch wheel (the ~780 MB download in this cell's
# output). That downgraded torch below 2.6, after which transformers refused
# to load the inpainting teacher's .bin checkpoints. Constraining torch to its
# installed version lets pip pick a matching xformers or fail without side effects.
# An already-installed, mismatched xformers conflicts with the pin, so pip
# replaces it; no separate uninstall step is needed.
import subprocess
import sys

_torch_pin = f"torch=={torch.__version__}"
_xformers_result = subprocess.run(
    [sys.executable, "-m", "pip", "install", "-q", "xformers", _torch_pin]
)

if _xformers_result.returncode == 0:
    print(f"✓ xformers installed alongside {_torch_pin}")
    print("Please re-run the import cell (9i36KlLvvd-H) to confirm the fix.")
else:
    print(f"⚠ pip could not install xformers without changing {_torch_pin}; "
          "skipping xformers")

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m780.4/780.4 MB[0m [31m?[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.7/23.7 MB[0m [31m87.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m823.6/823.6 kB[0m [31m41.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m14.1/14.1 MB[0m [31m87.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m?[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m410.6/410.6 MB[0m [31m?[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m121.6/121.6 MB[0m [31m7.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.5/56.5 MB[0m [31m15.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
# ---------- ControlNet Inpainting Teacher ----------

USE_SD_CONTROLNET = True  # set False if you just want SVD later

controlnet = None
sd_cn_pipe = None

if USE_SD_CONTROLNET:
    _on_cuda = torch.device(device).type == "cuda"
    # Half precision only on GPU; many CPU ops have no fp16 kernels.
    _cn_dtype = torch.float16 if _on_cuda else torch.float32

    controlnet = ControlNetModel.from_pretrained(
        "lllyasviel/control_v11p_sd15_canny",
        torch_dtype=_cn_dtype,
    )

    # NOTE: several components of this checkpoint are .bin files, which
    # transformers only loads with torch >= 2.6 (CVE-2025-32434).
    sd_cn_pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-inpainting",
        controlnet=controlnet,
        torch_dtype=_cn_dtype,
        safety_checker=None,
    )

    if _on_cuda:
        # xformers is optional: enabling it raises if the package is missing
        # or was built for a different torch, so fall back instead of crashing.
        try:
            sd_cn_pipe.enable_xformers_memory_efficient_attention()
        except (ImportError, ValueError, RuntimeError) as err:
            print(f"⚠ xformers attention unavailable ({err}); using default attention")
        # Model CPU offload moves each sub-model onto the GPU only while it
        # runs. Do not call .to("cuda") first: that would load the whole
        # pipeline onto the GPU and defeat the memory saving.
        sd_cn_pipe.enable_model_cpu_offload()
    else:
        # Offloading needs a GPU; on CPU just place the pipeline directly.
        sd_cn_pipe = sd_cn_pipe.to(device)

    print("Loaded SD ControlNet Inpaint pipeline")


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/996 [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/1.45G [00:00<?, ?B/s]

model_index.json:   0%|          | 0.00/548 [00:00<?, ?B/s]

Fetching 14 files:   0%|          | 0/14 [00:00<?, ?it/s]

merges.txt: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/617 [00:00<?, ?B/s]

scheduler_config.json:   0%|          | 0.00/313 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/472 [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/342 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/748 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/806 [00:00<?, ?B/s]

text_encoder/pytorch_model.bin:   0%|          | 0.00/492M [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

unet/diffusion_pytorch_model.bin:   0%|          | 0.00/3.44G [00:00<?, ?B/s]

config.json:   0%|          | 0.00/552 [00:00<?, ?B/s]

vae/diffusion_pytorch_model.bin:   0%|          | 0.00/335M [00:00<?, ?B/s]

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

An error occurred while trying to fetch /root/.cache/huggingface/hub/models--runwayml--stable-diffusion-inpainting/snapshots/8a4288a76071f7280aedbdb3253bdb9e9d5d84bb/vae: Error no file named diffusion_pytorch_model.safetensors found in directory /root/.cache/huggingface/hub/models--runwayml--stable-diffusion-inpainting/snapshots/8a4288a76071f7280aedbdb3253bdb9e9d5d84bb/vae.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.
`torch_dtype` is deprecated! Use `dtype` instead!


ValueError: Due to a serious vulnerability issue in `torch.load`, even with `weights_only=True`, we now require users to upgrade torch to at least v2.6 in order to use the function. This version restriction does not apply when loading files with safetensors.
See the vulnerability report here https://nvd.nist.gov/vuln/detail/CVE-2025-32434