## 1. Setup and Imports

In [1]:
import os
import sys
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Transformers and Diffusers
from transformers import CLIPTokenizer, CLIPTextModel
from diffusers import (
    StableDiffusionInstructPix2PixPipeline,
    EulerAncestralDiscreteScheduler,
    DDPMScheduler,
    AutoencoderKL,
    UNet2DConditionModel
)
from diffusers.optimization import get_cosine_schedule_with_warmup
from torchvision import transforms

# Check GPU
USE_GPU = torch.cuda.is_available()
DEVICE = torch.device('cuda:0' if USE_GPU else 'cpu')

print(f"PyTorch Version: {torch.__version__}")
print(f"Device: {DEVICE}")
print(f"GPU Available: {USE_GPU}")
if USE_GPU:
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

PyTorch Version: 2.9.0+cu128
Device: cuda:0
GPU Available: True
GPU Memory: 191.51 GB


## 2. Configuration

In [2]:
from dataclasses import dataclass

@dataclass
class Config:
    # Model selection
    model_id: str = "timbrooks/instruct-pix2pix"  # Pretrained InstructPix2Pix
    
    # Image configuration
    image_size: int = 512
    
    # Training configuration
    batch_size: int = 2  # Small batch for 8GB+ GPU
    gradient_accumulation_steps: int = 2  # Effective batch = 4
    num_epochs: int = 20
    learning_rate: float = 5e-5
    weight_decay: float = 1e-2
    warmup_steps: int = 500
    max_grad_norm: float = 1.0
    use_mixed_precision: bool = True
    
    # Fine-tuning strategy
    fine_tune_unet: bool = True  # Fine-tune UNet
    fine_tune_text_encoder: bool = False  # Keep text encoder frozen
    fine_tune_vae: bool = False  # Keep VAE frozen
    
    # Data
    num_workers: int = 4
    
    # Paths
    checkpoints_dir: str = "./models/instructpix2pix_checkpoints"
    results_dir: str = "./results/instructpix2pix"
    log_dir: str = "./logs/instructpix2pix"

config = Config()

# Create directories
os.makedirs(config.checkpoints_dir, exist_ok=True)
os.makedirs(config.results_dir, exist_ok=True)
os.makedirs(config.log_dir, exist_ok=True)

print("Configuration:")
print(f"  Model: {config.model_id}")
print(f"  Image size: {config.image_size}x{config.image_size}")
print(f"  Batch size: {config.batch_size} (effective: {config.batch_size * config.gradient_accumulation_steps})")
print(f"  Learning rate: {config.learning_rate}")
print(f"  Epochs: {config.num_epochs}")
print(f"  Fine-tune UNet: {config.fine_tune_unet}")
print(f"  Fine-tune Text Encoder: {config.fine_tune_text_encoder}")

Configuration:
  Model: timbrooks/instruct-pix2pix
  Image size: 512x512
  Batch size: 2 (effective: 4)
  Learning rate: 5e-05
  Epochs: 20
  Fine-tune UNet: True
  Fine-tune Text Encoder: False


## 3. Load SSID Dataset

In [3]:
# Load SSID annotations
def load_annotations(json_path, split_name):
    with open(json_path, 'r') as f:
        data = json.load(f)
    flat_data = [storylet for story in data['annotations'] for storylet in story]
    df = pd.DataFrame(flat_data)
    df['split'] = split_name
    return df

# Paths
annotations_dir = '../data/SSID_Annotations/SSID_Annotations'
images_dir = '../data/SSID_Images/SSID_Images'

train_json = os.path.join(annotations_dir, "SSID_Train.json")
val_json = os.path.join(annotations_dir, "SSID_Validation.json")

# Load splits
df_train = load_annotations(train_json, 'train')
df_val = load_annotations(val_json, 'val')

print(f"Train storylets: {len(df_train)}")
print(f"Validation storylets: {len(df_val)}")
print(f"Unique stories (train): {df_train['story_id'].nunique()}")
print(f"Unique stories (val): {df_val['story_id'].nunique()}")

Train storylets: 62500
Validation storylets: 3480
Unique stories (train): 12500
Unique stories (val): 696


## 4. Create Training Pairs (Previous Image + Text → Next Image)

In [4]:
def create_training_pairs(df, images_dir, split='train'):
    """
    Create (input_image, edit_prompt, output_image) triplets.
    input_image: previous frame in story
    edit_prompt: text description of what changes to next frame
    output_image: actual next frame
    """
    pairs = []
    missing_count = 0
    
    for story_id in df['story_id'].unique():
        story_data = df[df['story_id'] == story_id].sort_values('image_order').reset_index(drop=True)
        
        # Need at least 2 images per story
        if len(story_data) < 2:
            continue
        
        # Create pairs: (image_t, text_t+1, image_t+1)
        for i in range(len(story_data) - 1):
            prev_row = story_data.iloc[i]
            next_row = story_data.iloc[i + 1]
            
            # Try multiple image path patterns
            image_id_prev = str(prev_row['youtube_image_id']).strip()
            image_id_next = str(next_row['youtube_image_id']).strip()
            
            # Try different path formats
            possible_prev_paths = [
                os.path.join(images_dir, f"{image_id_prev}.jpg"),
                os.path.join(images_dir, image_id_prev),
                os.path.join(images_dir, f"{image_id_prev}.png"),
            ]
            possible_next_paths = [
                os.path.join(images_dir, f"{image_id_next}.jpg"),
                os.path.join(images_dir, image_id_next),
                os.path.join(images_dir, f"{image_id_next}.png"),
            ]
            
            prev_img_path = None
            next_img_path = None
            
            for path in possible_prev_paths:
                if os.path.exists(path):
                    prev_img_path = path
                    break
            
            for path in possible_next_paths:
                if os.path.exists(path):
                    next_img_path = path
                    break
            
            if prev_img_path and next_img_path:
                pairs.append({
                    'input_image': prev_img_path,  # Previous image
                    'edit_prompt': next_row['storytext'],  # Text describing next scene
                    'output_image': next_img_path,  # Target next image
                    'story_id': story_id,
                    'split': split
                })
            else:
                missing_count += 1
    
    return pairs, missing_count

# Create pairs
train_pairs, train_missing = create_training_pairs(df_train, images_dir, 'train')
val_pairs, val_missing = create_training_pairs(df_val, images_dir, 'val')

print(f"Training pairs: {len(train_pairs)} (missing: {train_missing})")
print(f"Validation pairs: {len(val_pairs)} (missing: {val_missing})")

# Debug: List what's in the images directory
print(f"\nImages directory: {images_dir}")
print(f"Images directory exists: {os.path.exists(images_dir)}")
if os.path.exists(images_dir):
    img_files = os.listdir(images_dir)
    print(f"Number of files: {len(img_files)}")
    if img_files:
        print(f"Sample files: {img_files[:5]}")

# Debug: Show sample storylet data
if len(df_train) > 0:
    print(f"\nSample storylet data:")
    print(df_train[['youtube_image_id', 'story_id', 'image_order']].head())

if train_pairs:
    print(f"\nExample pair:")
    pair = train_pairs[0]
    print(f"  Input image: {os.path.basename(pair['input_image'])}")
    print(f"  Edit prompt: {pair['edit_prompt'][:60]}...")
    print(f"  Output image: {os.path.basename(pair['output_image'])}")

Training pairs: 50000 (missing: 0)
Validation pairs: 2784 (missing: 0)

Images directory: ../data/SSID_Images/SSID_Images
Images directory exists: True
Number of files: 17367
Sample files: ['265.jpg', '10883.jpg', '10195.jpg', '3475.jpg', '25.jpg']

Sample storylet data:
  youtube_image_id  story_id  image_order
0             2001      5887            1
1             2002      5887            2
2             2003      5887            3
3             2004      5887            4
4             2005      5887            5

Example pair:
  Input image: 2001.jpg
  Edit prompt: He is telling me about his car....
  Output image: 2002.jpg


## 5. Custom Dataset Class

In [5]:
class InstructPix2PixDataset(Dataset):
    """
    Dataset for InstructPix2Pix fine-tuning.
    Input: (input_image, edit_prompt)
    Target: output_image
    """
    
    def __init__(self, pairs, tokenizer, image_size=512):
        self.pairs = pairs
        self.tokenizer = tokenizer
        self.image_size = image_size
        
        # Image transformations
        self.image_transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.CenterCrop((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.5, 0.5, 0.5],
                std=[0.5, 0.5, 0.5]
            )
        ])
    
    def __len__(self):
        return len(self.pairs)
    
    def __getitem__(self, idx):
        pair = self.pairs[idx]
        
        # Load and transform images
        try:
            input_img = Image.open(pair['input_image']).convert('RGB')
            output_img = Image.open(pair['output_image']).convert('RGB')
        except Exception as e:
            print(f"Error loading images: {e}")
            # Return blank images as fallback
            input_img = Image.new('RGB', (self.image_size, self.image_size), color='gray')
            output_img = Image.new('RGB', (self.image_size, self.image_size), color='gray')
        
        input_tensor = self.image_transform(input_img)
        output_tensor = self.image_transform(output_img)
        
        # Tokenize edit prompt
        prompt = pair['edit_prompt']
        tokens = self.tokenizer(
            prompt,
            padding='max_length',
            max_length=77,
            truncation=True,
            return_tensors='pt'
        )
        
        return {
            'input_image': input_tensor,
            'prompt_input_ids': tokens['input_ids'].squeeze(),
            'prompt_attention_mask': tokens['attention_mask'].squeeze(),
            'output_image': output_tensor,
            'prompt': prompt
        }

# Create tokenizer
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

# Create datasets
train_dataset = InstructPix2PixDataset(train_pairs, tokenizer, config.image_size)
val_dataset = InstructPix2PixDataset(val_pairs, tokenizer, config.image_size)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")

# Create dataloaders
train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.num_workers,
    pin_memory=USE_GPU
)
val_loader = DataLoader(
    val_dataset,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=config.num_workers,
    pin_memory=USE_GPU
)

print(f"Train loader batches: {len(train_loader)}")
print(f"Validation loader batches: {len(val_loader)}")

Train dataset size: 50000
Validation dataset size: 2784
Train loader batches: 25000
Validation loader batches: 1392


## 6. Load Pretrained InstructPix2Pix Model

In [None]:
print(f"Loading pretrained InstructPix2Pix model...")
print(f"Model: {config.model_id}\n")

try:
    # Load the full pipeline in float32 for training
    pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
        config.model_id,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=False
    )
    
    # Extract individual components
    tokenizer = pipe.tokenizer
    text_encoder = pipe.text_encoder
    vae = pipe.vae
    unet = pipe.unet
    scheduler = pipe.scheduler
    
    # Move to device FIRST
    text_encoder = text_encoder.to(DEVICE)
    vae = vae.to(DEVICE)
    unet = unet.to(DEVICE)
    
    # THEN convert to float32 explicitly (handles both params and buffers)
    text_encoder = text_encoder.float()
    vae = vae.float()
    unet = unet.float()
    
    print("✓ Model components loaded successfully!")
    print(f"\nModel Architecture:")
    print(f"  Text Encoder: {sum(p.numel() for p in text_encoder.parameters()) / 1e6:.1f}M parameters (dtype: {next(text_encoder.parameters()).dtype})")
    print(f"  VAE: {sum(p.numel() for p in vae.parameters()) / 1e6:.1f}M parameters (dtype: {next(vae.parameters()).dtype})")
    print(f"  UNet: {sum(p.numel() for p in unet.parameters()) / 1e6:.1f}M parameters (dtype: {next(unet.parameters()).dtype})")
    print(f"  Total: {(sum(p.numel() for p in text_encoder.parameters()) + sum(p.numel() for p in vae.parameters()) + sum(p.numel() for p in unet.parameters())) / 1e6:.1f}M parameters")
    
except Exception as e:
    print(f"✗ Error loading model: {e}")
    print(f"Make sure you have internet connection to download the model")

Loading pretrained InstructPix2Pix model...
Model: timbrooks/instruct-pix2pix



Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

`torch_dtype` is deprecated! Use `dtype` instead!


✓ Model components loaded successfully!

Model Architecture:
  Text Encoder: 123.1M parameters
  VAE: 83.7M parameters
  UNet: 859.5M parameters
  Total: 1066.2M parameters


## 7. Setup for Fine-tuning

In [7]:
# Freeze components we're not fine-tuning
if not config.fine_tune_text_encoder:
    text_encoder.requires_grad_(False)
    print("Text encoder frozen")

if not config.fine_tune_vae:
    vae.requires_grad_(False)
    print("VAE frozen")

# Only fine-tune UNet
if config.fine_tune_unet:
    unet.requires_grad_(True)
    print("UNet unfrozen for fine-tuning")
else:
    unet.requires_grad_(False)

# Count trainable parameters
trainable_params = sum(p.numel() for p in unet.parameters() if p.requires_grad)
print(f"\nTrainable parameters: {trainable_params / 1e6:.1f}M")

# Create optimizer
optimizer = AdamW(
    unet.parameters(),
    lr=config.learning_rate,
    weight_decay=config.weight_decay,
    betas=(0.9, 0.999)
)

# Learning rate scheduler
num_update_steps_per_epoch = len(train_loader) // config.gradient_accumulation_steps
max_train_steps = config.num_epochs * num_update_steps_per_epoch

lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=config.warmup_steps,
    num_training_steps=max_train_steps
)

# Gradient scaler for mixed precision
scaler = GradScaler() if config.use_mixed_precision and USE_GPU else None

print(f"\nOptimizer: AdamW")
print(f"Learning rate: {config.learning_rate}")
print(f"Total training steps: {max_train_steps}")
print(f"Mixed precision: {config.use_mixed_precision}")

Text encoder frozen
VAE frozen
UNet unfrozen for fine-tuning

Trainable parameters: 859.5M

Optimizer: AdamW
Learning rate: 5e-05
Total training steps: 250000
Mixed precision: True


## 8. Training Function

In [11]:
def encode_text_embedding(tokenizer, text_encoder, prompt, device, dtype):
    """Encode text prompt to embedding"""
    tokens = tokenizer(
        prompt,
        padding='max_length',
        max_length=77,
        truncation=True,
        return_tensors='pt'
    )
    
    with torch.no_grad():
        text_embedding = text_encoder(
            input_ids=tokens['input_ids'].to(device),
            attention_mask=tokens['attention_mask'].to(device)
        )[0]
    
    return text_embedding.to(dtype)


def train_epoch(epoch, unet, vae, text_encoder, tokenizer, train_loader, optimizer, lr_scheduler, criterion, device, scaler, dtype):
    """Train for one epoch"""
    unet.train()
    total_loss = 0.0
    
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}", ncols=80)
    
    for batch_idx, batch in enumerate(progress_bar):
        try:
            # Load batch to device with correct dtype
            input_images = batch['input_image'].to(device, dtype=dtype)
            output_images = batch['output_image'].to(device, dtype=dtype)
            prompts = batch['prompt']
            
            # Encode images to latent space
            with torch.no_grad():
                input_latents = vae.encode(input_images).latent_dist.sample() * 0.18215
                output_latents = vae.encode(output_images).latent_dist.sample() * 0.18215
                input_latents = input_latents.to(dtype)
                output_latents = output_latents.to(dtype)
            
            # Encode prompts
            prompt_embeds_list = []
            for prompt in prompts:
                prompt_embed = encode_text_embedding(tokenizer, text_encoder, prompt, device, dtype)
                prompt_embeds_list.append(prompt_embed)
            prompt_embeds = torch.cat(prompt_embeds_list, dim=0)
            
            # Sample random timesteps
            timesteps = torch.randint(
                0,
                scheduler.config.num_train_timesteps,
                (input_latents.shape[0],),
                device=device
            ).long()
            
            # Sample noise
            noise = torch.randn_like(output_latents)
            
            # Add noise to output latents (forward process)
            noisy_latents = scheduler.add_noise(output_latents, noise, timesteps)
            
            # Concatenate input and noisy latents
            latent_model_input = torch.cat([input_latents, noisy_latents], dim=1)
            
            with autocast(enabled=config.use_mixed_precision and USE_GPU):
                # Predict noise
                noise_pred = unet(
                    latent_model_input,
                    timesteps,
                    encoder_hidden_states=prompt_embeds
                ).sample
                
                # MSE loss
                loss = F.mse_loss(noise_pred, noise, reduction='mean')
            
            # Backward pass with gradient accumulation
            if config.use_mixed_precision and USE_GPU:
                scaler.scale(loss / config.gradient_accumulation_steps).backward()
            else:
                (loss / config.gradient_accumulation_steps).backward()
            
            # Gradient accumulation step
            if (batch_idx + 1) % config.gradient_accumulation_steps == 0:
                if config.use_mixed_precision and USE_GPU:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(unet.parameters(), config.max_grad_norm)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    torch.nn.utils.clip_grad_norm_(unet.parameters(), config.max_grad_norm)
                    optimizer.step()
                
                optimizer.zero_grad()
                lr_scheduler.step()
            
            total_loss += loss.item()
            progress_bar.set_postfix({"loss": f"{loss.item():.4f}"})
            
        except Exception as e:
            print(f"\nError in batch {batch_idx}: {e}")
            continue
    
    avg_loss = total_loss / len(train_loader)
    return avg_loss


def validate(unet, vae, text_encoder, tokenizer, val_loader, device, dtype):
    """Validate model"""
    unet.eval()
    total_loss = 0.0
    
    with torch.no_grad():
        progress_bar = tqdm(val_loader, desc="Validating", ncols=80)
        
        for batch in progress_bar:
            try:
                input_images = batch['input_image'].to(device, dtype=dtype)
                output_images = batch['output_image'].to(device, dtype=dtype)
                prompts = batch['prompt']
                
                # Encode images
                input_latents = vae.encode(input_images).latent_dist.sample() * 0.18215
                output_latents = vae.encode(output_images).latent_dist.sample() * 0.18215
                input_latents = input_latents.to(dtype)
                output_latents = output_latents.to(dtype)
                
                # Encode prompts
                prompt_embeds_list = []
                for prompt in prompts:
                    prompt_embed = encode_text_embedding(tokenizer, text_encoder, prompt, device, dtype)
                    prompt_embeds_list.append(prompt_embed)
                prompt_embeds = torch.cat(prompt_embeds_list, dim=0)
                
                # Random timesteps
                timesteps = torch.randint(
                    0,
                    scheduler.config.num_train_timesteps,
                    (input_latents.shape[0],),
                    device=device
                ).long()
                
                # Noise
                noise = torch.randn_like(output_latents)
                
                noisy_latents = scheduler.add_noise(output_latents, noise, timesteps)
                latent_model_input = torch.cat([input_latents, noisy_latents], dim=1)
                
                # Forward
                noise_pred = unet(
                    latent_model_input,
                    timesteps,
                    encoder_hidden_states=prompt_embeds
                ).sample
                
                loss = F.mse_loss(noise_pred, noise, reduction='mean')
                total_loss += loss.item()
                progress_bar.set_postfix({"loss": f"{loss.item():.4f}"})
                
            except Exception as e:
                continue
    
    avg_loss = total_loss / len(val_loader)
    return avg_loss

print("Training functions defined")

Training functions defined


## 9. Training Loop

In [16]:
train_losses = []
val_losses = []
best_val_loss = float('inf')
patience = 5
patience_counter = 0

# Use float32 for stable training (regardless of GPU)
dtype = torch.float32

print("Starting training...\n")
print(f"Model: {config.model_id}")
print(f"Epochs: {config.num_epochs}")
print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Device: {DEVICE}\n")
print(f"Dtype: {dtype}\n")
print("="*60)

for epoch in range(config.num_epochs):
    print(f"\nEpoch {epoch + 1}/{config.num_epochs}")
    print("-" * 60)
    
    # Train
    train_loss = train_epoch(
        epoch,
        unet,
        vae,
        text_encoder,
        tokenizer,
        train_loader,
        optimizer,
        lr_scheduler,
        F.mse_loss,
        DEVICE,
        scaler,
        dtype
    )
    train_losses.append(train_loss)
    
    # Validate
    val_loss = validate(unet, vae, text_encoder, tokenizer, val_loader, DEVICE, dtype)
    val_losses.append(val_loss)
    
    print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
    
    # Early stopping
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        
        # Save best model
        checkpoint_path = os.path.join(config.checkpoints_dir, "best_unet")
        unet.save_pretrained(checkpoint_path)
        print(f"✓ Best model saved (val_loss: {val_loss:.4f})")
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch + 1}")
            break
    
    # Save checkpoint every 5 epochs
    if (epoch + 1) % 5 == 0:
        checkpoint_path = os.path.join(config.checkpoints_dir, f"unet_epoch_{epoch + 1}")
        unet.save_pretrained(checkpoint_path)
        print(f"✓ Checkpoint saved")

print("\n" + "="*60)
print("Training completed!")
print(f"Final train loss: {train_losses[-1]:.4f}")
print(f"Final val loss: {val_losses[-1]:.4f}")

Starting training...

Model: timbrooks/instruct-pix2pix
Epochs: 20
Training samples: 50000
Validation samples: 2784
Device: cuda:0

Dtype: torch.float32


Epoch 1/20
------------------------------------------------------------


Epoch 1:   0%|                               | 18/25000 [00:01<31:07, 13.38it/s]


Error in batch 0: Input type (float) and bias type (c10::Half) should be the same

Error in batch 1: Input type (float) and bias type (c10::Half) should be the same

Error in batch 2: Input type (float) and bias type (c10::Half) should be the same

Error in batch 3: Input type (float) and bias type (c10::Half) should be the same

Error in batch 4: Input type (float) and bias type (c10::Half) should be the same

Error in batch 5: Input type (float) and bias type (c10::Half) should be the same

Error in batch 6: Input type (float) and bias type (c10::Half) should be the same

Error in batch 7: Input type (float) and bias type (c10::Half) should be the same

Error in batch 8: Input type (float) and bias type (c10::Half) should be the same

Error in batch 9: Input type (float) and bias type (c10::Half) should be the same

Error in batch 10: Input type (float) and bias type (c10::Half) should be the same

Error in batch 11: Input type (float) and bias type (c10::Half) should be the same

E

Epoch 1:   0%|                               | 44/25000 [00:02<11:29, 36.18it/s]


Error in batch 27: Input type (float) and bias type (c10::Half) should be the same

Error in batch 28: Input type (float) and bias type (c10::Half) should be the same

Error in batch 29: Input type (float) and bias type (c10::Half) should be the same

Error in batch 30: Input type (float) and bias type (c10::Half) should be the same

Error in batch 31: Input type (float) and bias type (c10::Half) should be the same

Error in batch 32: Input type (float) and bias type (c10::Half) should be the same

Error in batch 33: Input type (float) and bias type (c10::Half) should be the same

Error in batch 34: Input type (float) and bias type (c10::Half) should be the same

Error in batch 35: Input type (float) and bias type (c10::Half) should be the same

Error in batch 36: Input type (float) and bias type (c10::Half) should be the same

Error in batch 37: Input type (float) and bias type (c10::Half) should be the same

Error in batch 38: Input type (float) and bias type (c10::Half) should be t

Epoch 1:   0%|                               | 76/25000 [00:02<06:03, 68.50it/s]


Error in batch 55: Input type (float) and bias type (c10::Half) should be the same

Error in batch 56: Input type (float) and bias type (c10::Half) should be the same

Error in batch 57: Input type (float) and bias type (c10::Half) should be the same

Error in batch 58: Input type (float) and bias type (c10::Half) should be the same

Error in batch 59: Input type (float) and bias type (c10::Half) should be the same

Error in batch 60: Input type (float) and bias type (c10::Half) should be the same

Error in batch 61: Input type (float) and bias type (c10::Half) should be the same

Error in batch 62: Input type (float) and bias type (c10::Half) should be the same

Error in batch 63: Input type (float) and bias type (c10::Half) should be the same

Error in batch 64: Input type (float) and bias type (c10::Half) should be the same

Error in batch 65: Input type (float) and bias type (c10::Half) should be the same

Error in batch 66: Input type (float) and bias type (c10::Half) should be t

Epoch 1:   0%|▏                             | 108/25000 [00:02<04:16, 97.22it/s]


Error in batch 85: Input type (float) and bias type (c10::Half) should be the same

Error in batch 86: Input type (float) and bias type (c10::Half) should be the same

Error in batch 87: Input type (float) and bias type (c10::Half) should be the same

Error in batch 88: Input type (float) and bias type (c10::Half) should be the same

Error in batch 89: Input type (float) and bias type (c10::Half) should be the same

Error in batch 90: Input type (float) and bias type (c10::Half) should be the same

Error in batch 91: Input type (float) and bias type (c10::Half) should be the same

Error in batch 92: Input type (float) and bias type (c10::Half) should be the same

Error in batch 93: Input type (float) and bias type (c10::Half) should be the same

Error in batch 94: Input type (float) and bias type (c10::Half) should be the same

Error in batch 95: Input type (float) and bias type (c10::Half) should be the same

Error in batch 96: Input type (float) and bias type (c10::Half) should be t

Epoch 1:   1%|▏                            | 140/25000 [00:02<03:31, 117.82it/s]


Error in batch 115: Input type (float) and bias type (c10::Half) should be the same

Error in batch 116: Input type (float) and bias type (c10::Half) should be the same

Error in batch 117: Input type (float) and bias type (c10::Half) should be the same

Error in batch 118: Input type (float) and bias type (c10::Half) should be the same

Error in batch 119: Input type (float) and bias type (c10::Half) should be the same

Error in batch 120: Input type (float) and bias type (c10::Half) should be the same

Error in batch 121: Input type (float) and bias type (c10::Half) should be the same

Error in batch 122: Input type (float) and bias type (c10::Half) should be the same

Error in batch 123: Input type (float) and bias type (c10::Half) should be the same

Error in batch 124: Input type (float) and bias type (c10::Half) should be the same

Error in batch 125: Input type (float) and bias type (c10::Half) should be the same

Error in batch 126: Input type (float) and bias type (c10::Half)

Epoch 1:   1%|▏                            | 172/25000 [00:02<03:11, 129.45it/s]


Error in batch 147: Input type (float) and bias type (c10::Half) should be the same

Error in batch 148: Input type (float) and bias type (c10::Half) should be the same

Error in batch 149: Input type (float) and bias type (c10::Half) should be the same

Error in batch 150: Input type (float) and bias type (c10::Half) should be the same

Error in batch 151: Input type (float) and bias type (c10::Half) should be the same

Error in batch 152: Input type (float) and bias type (c10::Half) should be the same

Error in batch 153: Input type (float) and bias type (c10::Half) should be the same

Error in batch 154: Input type (float) and bias type (c10::Half) should be the same

Error in batch 155: Input type (float) and bias type (c10::Half) should be the same

Error in batch 156: Input type (float) and bias type (c10::Half) should be the same

Error in batch 157: Input type (float) and bias type (c10::Half) should be the same

Error in batch 158: Input type (float) and bias type (c10::Half)

Epoch 1:   1%|▏                            | 202/25000 [00:03<03:09, 130.98it/s]


Error in batch 175: Input type (float) and bias type (c10::Half) should be the same

Error in batch 176: Input type (float) and bias type (c10::Half) should be the same

Error in batch 177: Input type (float) and bias type (c10::Half) should be the same

Error in batch 178: Input type (float) and bias type (c10::Half) should be the same

Error in batch 179: Input type (float) and bias type (c10::Half) should be the same

Error in batch 180: Input type (float) and bias type (c10::Half) should be the same

Error in batch 181: Input type (float) and bias type (c10::Half) should be the same

Error in batch 182: Input type (float) and bias type (c10::Half) should be the same

Error in batch 183: Input type (float) and bias type (c10::Half) should be the same

Error in batch 184: Input type (float) and bias type (c10::Half) should be the same

Error in batch 185: Input type (float) and bias type (c10::Half) should be the same

Error in batch 186: Input type (float) and bias type (c10::Half)

Epoch 1:   1%|▎                            | 216/25000 [00:03<03:11, 129.65it/s]


Error in batch 203: Input type (float) and bias type (c10::Half) should be the same

Error in batch 204: Input type (float) and bias type (c10::Half) should be the same

Error in batch 205: Input type (float) and bias type (c10::Half) should be the same

Error in batch 206: Input type (float) and bias type (c10::Half) should be the same

Error in batch 207: Input type (float) and bias type (c10::Half) should be the same

Error in batch 208: Input type (float) and bias type (c10::Half) should be the same

Error in batch 209: Input type (float) and bias type (c10::Half) should be the same

Error in batch 210: Input type (float) and bias type (c10::Half) should be the same

Error in batch 211: Input type (float) and bias type (c10::Half) should be the same

Error in batch 212: Input type (float) and bias type (c10::Half) should be the same

Error in batch 213: Input type (float) and bias type (c10::Half) should be the same

Error in batch 214: Input type (float) and bias type (c10::Half)

Epoch 1:   1%|▎                            | 248/25000 [00:03<03:02, 135.47it/s]


Error in batch 231: Input type (float) and bias type (c10::Half) should be the same

Error in batch 232: Input type (float) and bias type (c10::Half) should be the same

Error in batch 233: Input type (float) and bias type (c10::Half) should be the same

Error in batch 234: Input type (float) and bias type (c10::Half) should be the same

Error in batch 235: Input type (float) and bias type (c10::Half) should be the same

Error in batch 236: Input type (float) and bias type (c10::Half) should be the same

Error in batch 237: Input type (float) and bias type (c10::Half) should be the same

Error in batch 238: Input type (float) and bias type (c10::Half) should be the same

Error in batch 239: Input type (float) and bias type (c10::Half) should be the same

Error in batch 240: Input type (float) and bias type (c10::Half) should be the same

Error in batch 241: Input type (float) and bias type (c10::Half) should be the same

Error in batch 242: Input type (float) and bias type (c10::Half)

Epoch 1:   1%|▎                            | 280/25000 [00:03<02:58, 138.46it/s]


Error in batch 263: Input type (float) and bias type (c10::Half) should be the same

Error in batch 264: Input type (float) and bias type (c10::Half) should be the same

Error in batch 265: Input type (float) and bias type (c10::Half) should be the same

Error in batch 266: Input type (float) and bias type (c10::Half) should be the same

Error in batch 267: Input type (float) and bias type (c10::Half) should be the same

Error in batch 268: Input type (float) and bias type (c10::Half) should be the same

Error in batch 269: Input type (float) and bias type (c10::Half) should be the same

Error in batch 270: Input type (float) and bias type (c10::Half) should be the same

Error in batch 271: Input type (float) and bias type (c10::Half) should be the same

Error in batch 272: Input type (float) and bias type (c10::Half) should be the same

Error in batch 273: Input type (float) and bias type (c10::Half) should be the same

Error in batch 274: Input type (float) and bias type (c10::Half)

Epoch 1:   1%|▎                             | 307/25000 [00:03<05:18, 77.61it/s]



Error in batch 291: Input type (float) and bias type (c10::Half) should be the same

Error in batch 292: Input type (float) and bias type (c10::Half) should be the same

Error in batch 293: Input type (float) and bias type (c10::Half) should be the same

Error in batch 294: Input type (float) and bias type (c10::Half) should be the same

Error in batch 295: Input type (float) and bias type (c10::Half) should be the same

Error in batch 296: Input type (float) and bias type (c10::Half) should be the same

Error in batch 297: Input type (float) and bias type (c10::Half) should be the same

Error in batch 298: Input type (float) and bias type (c10::Half) should be the same

Error in batch 299: Input type (float) and bias type (c10::Half) should be the same

Error in batch 300: Input type (float) and bias type (c10::Half) should be the same

Error in batch 301: Input type (float) and bias type (c10::Half) should be the same

Error in batch 302: Input type (float) and bias type (c10::Half)

KeyboardInterrupt: 

## 10. Plot Training History

In [None]:
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(train_losses, label='Train Loss', marker='o', linewidth=2, markersize=6)
ax.plot(val_losses, label='Val Loss', marker='s', linewidth=2, markersize=6)
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Loss', fontsize=12)
ax.set_title('InstructPix2Pix Fine-tuning - Training History', fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(os.path.join(config.results_dir, 'training_loss.png'), dpi=150, bbox_inches='tight')
plt.show()

print(f"Loss plot saved to: {os.path.join(config.results_dir, 'training_loss.png')}")

## 11. Load Fine-tuned Model and Create Pipeline

In [None]:
# Load fine-tuned UNet
best_unet_path = os.path.join(config.checkpoints_dir, "best_unet")

if os.path.exists(best_unet_path):
    print(f"Loading fine-tuned UNet from: {best_unet_path}")
    
    # Create new pipeline with fine-tuned UNet
    pipe_finetuned = StableDiffusionInstructPix2PixPipeline.from_pretrained(
        config.model_id,
        unet=UNet2DConditionModel.from_pretrained(best_unet_path),
        torch_dtype=torch.float16 if USE_GPU else torch.float32
    ).to(DEVICE)
    
    print("✓ Fine-tuned pipeline created successfully")
else:
    print(f"Best model not found at {best_unet_path}")
    print("Using original pretrained model")
    pipe_finetuned = StableDiffusionInstructPix2PixPipeline.from_pretrained(
        config.model_id,
        torch_dtype=torch.float16 if USE_GPU else torch.float32
    ).to(DEVICE)

## 12. Test Inference

In [None]:
print(f"Testing inference on validation samples...\n")

for idx in range(min(3, len(val_pairs))):
    pair = val_pairs[idx]
    
    print(f"\nSample {idx + 1}:")
    print(f"  Edit prompt: {pair['edit_prompt'][:60]}...")
    
    try:
        # Load input image
        input_img = Image.open(pair['input_image']).convert('RGB')
        
        # Generate next image
        with torch.no_grad():
            generated = pipe_finetuned(
                prompt=pair['edit_prompt'],
                image=input_img,
                guidance_scale=7.5,
                num_inference_steps=30
            ).images[0]
        
        # Load ground truth
        ground_truth = Image.open(pair['output_image']).convert('RGB')
        
        # Display
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        
        axes[0].imshow(input_img)
        axes[0].set_title('Input Image (Previous Frame)', fontsize=10)
        axes[0].axis('off')
        
        axes[1].imshow(generated)
        axes[1].set_title('Generated (Fine-tuned)', fontsize=10, color='green')
        axes[1].axis('off')
        
        axes[2].imshow(ground_truth)
        axes[2].set_title('Ground Truth (Next Frame)', fontsize=10, color='blue')
        axes[2].axis('off')
        
        plt.tight_layout()
        plt.savefig(
            os.path.join(config.results_dir, f'inference_sample_{idx + 1}.png'),
            dpi=100,
            bbox_inches='tight'
        )
        plt.show()
        print(f"  ✓ Sample saved")
        
    except Exception as e:
        print(f"  ✗ Error: {e}")

print(f"\nInference samples saved to: {config.results_dir}")

## 13. Generate Full Story Sequences

In [None]:
def generate_story_with_instructpix2pix(story_id, df, pipe, images_dir):
    """
    Generate a full story sequence using InstructPix2Pix.
    """
    story_data = df[df['story_id'] == story_id].sort_values('image_order')
    generated_imgs = []
    
    if len(story_data) < 2:
        print(f"Story {story_id} has less than 2 images")
        return None
    
    # First image (use real)
    first_img_path = os.path.join(images_dir, f"{story_data.iloc[0]['youtube_image_id']}.jpg")
    first_img = Image.open(first_img_path).convert('RGB')
    generated_imgs.append(first_img)
    current_img = first_img
    
    # Generate remaining frames
    print(f"\nGenerating story sequence (story_id: {story_id})...")
    
    with torch.no_grad():
        for i in range(1, len(story_data)):
            prompt = story_data.iloc[i]['storytext']
            print(f"  Frame {i + 1}/{len(story_data)}: {prompt[:40]}...")
            
            try:
                generated = pipe(
                    prompt=prompt,
                    image=current_img,
                    guidance_scale=7.5,
                    num_inference_steps=30
                ).images[0]
                
                generated_imgs.append(generated)
                current_img = generated  # Use generated as input for next
                
            except Exception as e:
                print(f"    Error: {e}")
                break
    
    return generated_imgs, story_data


# Generate a story
if len(val_pairs) > 0:
    story_id = val_pairs[0]['story_id']
    
    result = generate_story_with_instructpix2pix(
        story_id,
        df_val,
        pipe_finetuned,
        images_dir
    )
    
    if result:
        generated_imgs, story_data = result
        
        # Display sequence
        num_imgs = len(generated_imgs)
        fig, axes = plt.subplots(2, (num_imgs + 1) // 2, figsize=(4*(num_imgs), 8))
        axes = axes.flatten() if num_imgs > 1 else [axes]
        
        for idx, img in enumerate(generated_imgs):
            axes[idx].imshow(img)
            axes[idx].set_title(f"Frame {idx + 1}", fontsize=10)
            axes[idx].axis('off')
        
        # Hide unused subplots
        for idx in range(num_imgs, len(axes)):
            axes[idx].axis('off')
        
        plt.tight_layout()
        plt.savefig(
            os.path.join(config.results_dir, 'full_story_sequence.png'),
            dpi=100,
            bbox_inches='tight'
        )
        plt.show()
        
        print(f"\nGenerated {num_imgs} frames for story {story_id}")
        print("\nStory narrative:")
        for idx, row in story_data.iterrows():
            print(f"{row['image_order']}. {row['storytext']}")

## 14. Summary

In [None]:
print("="*70)
print("INSTRUCTPIX2PIX FINE-TUNING - SUMMARY")
print("="*70)

print(f"\nModel: InstructPix2Pix (Pretrained)")
print(f"  Base model: {config.model_id}")
print(f"  Architecture: Stable Diffusion + instruction tuning")
print(f"  Input: (Previous image, Text instruction)")
print(f"  Output: Next image in sequence")

print(f"\nDataset:")
print(f"  Training pairs: {len(train_pairs)}")
print(f"  Validation pairs: {len(val_pairs)}")
print(f"  Stories (train): {df_train['story_id'].nunique()}")
print(f"  Stories (val): {df_val['story_id'].nunique()}")

print(f"\nTraining:")
print(f"  Batch size: {config.batch_size} (effective: {config.batch_size * config.gradient_accumulation_steps})")
print(f"  Learning rate: {config.learning_rate}")
print(f"  Epochs: {config.num_epochs}")
print(f"  Mixed precision: {config.use_mixed_precision}")

print(f"\nResults:")
if train_losses:
    print(f"  Initial loss: {train_losses[0]:.4f}")
    print(f"  Final loss: {train_losses[-1]:.4f}")
    print(f"  Loss reduction: {(1 - train_losses[-1]/train_losses[0])*100:.1f}%")

print(f"\nCheckpoints: {config.checkpoints_dir}")
print(f"Results: {config.results_dir}")

print(f"\n" + "="*70)
print("KEY ADVANTAGES:")
print("✓ Pretrained on large-scale image-text datasets")
print("✓ Optimized for image editing with text instructions")
print("✓ Easy to fine-tune on custom dataset")
print("✓ High-quality image generation (512×512)")
print("✓ Contextually aware (uses previous image as reference)")
print("✓ Autoregressive inference for full story generation")
print("✓ Fast inference (~5-10 seconds per image)")
print("="*70)