# Sequential Story Image Generation

**Transform Previous Story Images + Story Text → Next Coherent Illustration**

This notebook generates sequential story illustrations that maintain visual continuity across narrative segments using **three alternatives to InstructPix2Pix**:

1. **ControlNet Approach** - Condition Stable Diffusion on edge/pose maps extracted from the previous image
2. **Latent Concatenation** - Fast, direct image conditioning in latent space
3. **LoRA Adaptation** - Parameter-efficient fine-tuning for temporal consistency

**Key Features:**
- ✅ Works with SSID (Sequential Story Illustration Dataset)
- ✅ Conditions each frame on the previous one to maintain visual continuity
- ✅ Latent concatenation reuses a single Stable Diffusion UNet (no extra model to load)
- ✅ Memory-efficient option (LoRA trains only small adapter matrices)
- ✅ Evaluates temporal consistency (SSIM) and text alignment (CLIP)
- ✅ Generates story progression GIFs

## 1. Import Libraries and Setup

In [None]:
import os
import sys
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.animation import PillowWriter
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Transformers and Diffusers
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import StableDiffusionImg2ImgPipeline, AutoencoderKL, UNet2DConditionModel, DDPMScheduler, DiffusionPipeline
from diffusers.utils import load_image

# Metrics
try:
    from skimage.metrics import structural_similarity as ssim
    from skimage.color import rgb2gray
except ImportError:
    print("Installing scikit-image...")
    os.system(f"{sys.executable} -m pip install scikit-image -q")
    from skimage.metrics import structural_similarity as ssim
    from skimage.color import rgb2gray

# Setup
def set_seed(seed=42):
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(42)

# Device setup
USE_GPU = torch.cuda.is_available()
DEVICE = torch.device('cuda' if USE_GPU else 'cpu')
print(f"✓ Device: {DEVICE}")
if USE_GPU:
    print(f"  GPU: {torch.cuda.get_device_name(0)}")
    print(f"  Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

## 2. Load and Explore SSID Dataset

In [None]:
# Load SSID annotations
def load_ssid_annotations(json_path):
    """Load SSID JSON annotations."""
    with open(json_path, 'r') as f:
        data = json.load(f)
    return data

# Paths
ssid_dir = '../data/SSID_Annotations/SSID_Annotations'
images_dir = '../data/SSID_Images/SSID_Images'

# Load datasets
train_data = load_ssid_annotations(os.path.join(ssid_dir, 'SSID_Train.json'))
val_data = load_ssid_annotations(os.path.join(ssid_dir, 'SSID_Validation.json'))
test_data = load_ssid_annotations(os.path.join(ssid_dir, 'SSID_Test.json'))

print("=" * 70)
print("SSID DATASET STRUCTURE")
print("=" * 70)

# Analyze structure
def analyze_ssid_structure(data, split_name):
    """Analyze SSID data structure."""
    if 'annotations' in data:
        stories = data['annotations']
    else:
        stories = [data] if isinstance(data, list) else []
    
    print(f"\n{split_name}:")
    print(f"  Total stories: {len(stories)}")
    
    if stories:
        first_story = stories[0]
        print(f"  Story length: {len(first_story)} frames")
        if first_story:
            first_frame = first_story[0]
            print(f"  Frame keys: {list(first_frame.keys())}")
            print(f"  Sample caption: '{first_frame.get('caption', 'N/A')}'")
            print(f"  Sample image_id: {first_frame.get('image_id', 'N/A')}")

analyze_ssid_structure(train_data, "Train")
analyze_ssid_structure(val_data, "Validation")
analyze_ssid_structure(test_data, "Test")

# Get statistics
def get_caption_stats(data):
    """Get caption length statistics."""
    lengths = []
    if 'annotations' in data:
        for story in data['annotations']:
            for frame in story:
                caption = frame.get('caption', '')
                lengths.append(len(caption.split()))
    return lengths

train_lengths = get_caption_stats(train_data)
print(f"\nCaption Statistics (Train):")
print(f"  Mean length: {np.mean(train_lengths):.1f} words")
print(f"  Max length: {np.max(train_lengths)} words")
print(f"  Min length: {np.min(train_lengths)} words")
print(f"  Std dev: {np.std(train_lengths):.1f} words")

## 3. Configure Model Selection: Three Approaches

**Comparison of Methods** (qualitative ratings, not benchmarked in this notebook):

| Approach | Speed | Memory | Quality | Complexity | Best For |
|----------|-------|--------|---------|------------|----------|
| **Latent Concat** | ⚡⚡⚡ Fast | 💾 Low | ⭐⭐⭐ Good | 🔨 Simple | Quick iteration, limited VRAM |
| **LoRA Adapt** | ⚡⚡ Medium | 💾💾 Medium | ⭐⭐⭐⭐ Excellent | 🔨🔨 Moderate | Balanced quality/speed |
| **ControlNet** | ⚡ Slow | 💾💾💾 High | ⭐⭐⭐⭐⭐ Best | 🔨🔨🔨 Complex | Maximum quality, VRAM available |

The Memory column is best read as inference cost: the latent-concat training below fine-tunes every UNet weight, so its *training* memory is higher than LoRA's.

**Recommendation:** Start with **Latent Concat** for testing, then upgrade to **LoRA** for production.

In [None]:
from dataclasses import dataclass

@dataclass
class Config:
    # Fields need type annotations; without them @dataclass registers no fields
    # Model selection
    # If the runwayml repo is unavailable, "stable-diffusion-v1-5/stable-diffusion-v1-5" mirrors it
    model_id: str = "runwayml/stable-diffusion-v1-5"  # Base model
    approach: str = "latent_concat"  # Options: "latent_concat", "lora_adapt", "controlnet"
    
    # Training
    image_size: int = 512
    batch_size: int = 2
    num_epochs: int = 5
    learning_rate: float = 1e-4
    weight_decay: float = 1e-5
    
    # LoRA specific (if using LoRA approach)
    lora_rank: int = 16
    lora_alpha: int = 32
    
    # Paths
    checkpoints_dir: str = "../models/sequential_checkpoints"
    results_dir: str = "../results/sequential"
    log_dir: str = "../logs/sequential"
    
    # Evaluation
    compute_ssim: bool = True
    compute_clip: bool = True
    num_inference_steps: int = 30
    guidance_scale: float = 7.5

config = Config()

# Create directories
for d in [config.checkpoints_dir, config.results_dir, config.log_dir]:
    os.makedirs(d, exist_ok=True)

print("=" * 70)
print("APPROACH SELECTION")
print("=" * 70)
print(f"\nSelected Approach: {config.approach.upper()}")
print(f"Model: {config.model_id}")
print(f"Image Size: {config.image_size}x{config.image_size}")
print(f"Batch Size: {config.batch_size}")
print(f"Learning Rate: {config.learning_rate}")
print(f"\nDirectories created:")
print(f"  Checkpoints: {config.checkpoints_dir}")
print(f"  Results: {config.results_dir}")

## 4. Load Pre-trained Image-to-Image Model Components
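The components loaded below follow the Stable Diffusion v1 layout: the VAE shrinks each RGB image 8× per side into a 4-channel latent, and latents are multiplied by 0.18215 (the constant this notebook hardcodes, which matches `vae.config.scaling_factor` for SD v1.5) before entering the UNet. A minimal sketch of the shape arithmetic under those assumptions; `latent_shape` and `unet_input_channels` are hypothetical helpers, not diffusers APIs:

```python
# Assumed SD v1 VAE properties (same constants used throughout this notebook)
VAE_DOWNSAMPLE = 8        # spatial downsampling factor per side
LATENT_CHANNELS = 4       # channels in the VAE latent
SCALING_FACTOR = 0.18215  # latents are multiplied by this before the UNet


def latent_shape(batch_size, image_size):
    """Return the (B, C, H, W) latent shape for square RGB inputs (hypothetical helper)."""
    if image_size % VAE_DOWNSAMPLE != 0:
        raise ValueError(f"image_size must be divisible by {VAE_DOWNSAMPLE}")
    side = image_size // VAE_DOWNSAMPLE
    return (batch_size, LATENT_CHANNELS, side, side)


def unet_input_channels(approach):
    """Noise latent alone, or noise latent + previous-frame latent for latent concat."""
    return LATENT_CHANNELS * 2 if approach == "latent_concat" else LATENT_CHANNELS


shape = latent_shape(2, 512)
pixels = 3 * 512 * 512
latents = shape[1] * shape[2] * shape[3]
print(shape, pixels // latents)  # (2, 4, 64, 64) 48
```

Each latent therefore holds 48× fewer values than the image it encodes, which is why diffusion in latent space is affordable at 512×512.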

In [None]:
print("=" * 70)
print("LOADING PRE-TRAINED MODELS")
print("=" * 70)

# Load components
print("\nLoading tokenizer and encoders...")
tokenizer = CLIPTokenizer.from_pretrained(config.model_id, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(config.model_id, subfolder="text_encoder")

print("Loading VAE and UNet...")
vae = AutoencoderKL.from_pretrained(config.model_id, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(config.model_id, subfolder="unet")

print("Loading noise scheduler...")
noise_scheduler = DDPMScheduler.from_pretrained(config.model_id, subfolder="scheduler")

# Move to device
text_encoder = text_encoder.to(DEVICE)
vae = vae.to(DEVICE)
unet = unet.to(DEVICE)

# Freeze VAE and text encoder; only the UNet is trained/fine-tuned
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
vae.eval()
text_encoder.eval()
unet.train()

print("\n✓ All components loaded successfully")
print(f"  Text Encoder: {text_encoder.__class__.__name__} (768-dim)")
print(f"  VAE: {vae.__class__.__name__}")
print(f"  UNet: {unet.__class__.__name__} (4-channel input)")
print(f"  Noise Scheduler: {noise_scheduler.__class__.__name__}")

# For latent concat approach: expand UNet input channels
if config.approach == "latent_concat":
    print("\n⚠️ Expanding UNet for latent concatenation...")
    
    # Original UNet expects 4 channels (noise latent)
    # For image conditioning, we concatenate previous image latent (4 channels)
    # Total input: 4 (noise) + 4 (previous image) = 8 channels
    
    original_conv_in = unet.conv_in
    new_conv_in = nn.Conv2d(8, original_conv_in.out_channels, 
                             kernel_size=original_conv_in.kernel_size,
                             stride=original_conv_in.stride,
                             padding=original_conv_in.padding).to(DEVICE)
    
    # Copy the pretrained weights for the noise channels and zero-initialize the
    # previous-image channels (as InstructPix2Pix does), so the expanded UNet
    # initially behaves exactly like the pretrained one
    with torch.no_grad():
        new_conv_in.weight.zero_()
        new_conv_in.weight[:, :4] = original_conv_in.weight
        if original_conv_in.bias is not None:
            new_conv_in.bias.copy_(original_conv_in.bias)
    
    unet.conv_in = new_conv_in
    # Record the new input width so save_pretrained()/from_pretrained()
    # rebuild an 8-channel conv_in when a checkpoint is reloaded
    unet.register_to_config(in_channels=8)
    print("  ✓ UNet expanded to 8-channel input (4 noise + 4 previous image)")

## 5. Prepare Sequential Data Pipeline
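Each story of N frames yields N-1 training triplets: frame i's caption and image are paired with frame i-1's image. A minimal sketch of that pairing on an in-memory story list, skipping the file-existence checks the dataset class below performs; the frame dicts only assume the `image_id` and `caption` keys the notebook already reads, and the ids and captions are made up:

```python
def build_triplets(stories):
    """Pair each frame (from the 2nd on) with its predecessor."""
    triplets = []
    for story_idx, story in enumerate(stories):
        # range(1, len(story)) is empty for single-frame stories, so they contribute nothing
        for frame_idx in range(1, len(story)):
            prev_frame, curr_frame = story[frame_idx - 1], story[frame_idx]
            triplets.append({
                "prev_image": prev_frame.get("image_id", ""),
                "caption": curr_frame.get("caption", ""),
                "target_image": curr_frame.get("image_id", ""),
                "story_id": story_idx,
                "frame_idx": frame_idx,
            })
    return triplets


# Toy stories with made-up ids and captions
stories = [
    [{"image_id": "a1", "caption": "A fox wakes up."},
     {"image_id": "a2", "caption": "The fox leaves its den."},
     {"image_id": "a3", "caption": "The fox meets a crow."}],
    [{"image_id": "b1", "caption": "A single-frame story."}],
]
triplets = build_triplets(stories)
print(len(triplets))  # 2
print(triplets[0]["prev_image"], "->", triplets[0]["target_image"], "|", triplets[0]["caption"])
# a1 -> a2 | The fox leaves its den.
```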

In [None]:
class SSIDSequentialDataset(Dataset):
    """
    Sequential Story dataset: (prev_image, current_text) ‚Üí target_image
    """
    def __init__(self, data, images_dir, split='train', max_stories=None):
        self.data = data
        self.images_dir = images_dir
        self.split = split
        self.triplets = []
        
        # Build triplets: (previous_image_path, caption, target_image_path)
        if 'annotations' in data:
            stories = data['annotations']
        else:
            stories = [data] if isinstance(data, list) else []
        
        if max_stories:
            stories = stories[:max_stories]
        
        for story_idx, story in enumerate(stories):
            # Skip stories with < 2 frames (need previous and target)
            if len(story) < 2:
                continue
            
            for frame_idx in range(1, len(story)):
                prev_frame = story[frame_idx - 1]
                curr_frame = story[frame_idx]
                
                prev_img_id = prev_frame.get('image_id', '')
                curr_caption = curr_frame.get('caption', '')
                curr_img_id = curr_frame.get('image_id', '')
                
                # Construct full paths
                prev_img_path = os.path.join(images_dir, f"{prev_img_id}.jpg")
                curr_img_path = os.path.join(images_dir, f"{curr_img_id}.jpg")
                
                # Only add if images exist
                if os.path.exists(prev_img_path) and os.path.exists(curr_img_path):
                    self.triplets.append({
                        'prev_image': prev_img_path,
                        'caption': curr_caption,
                        'target_image': curr_img_path,
                        'story_id': story_idx,
                        'frame_idx': frame_idx
                    })
        
        print(f"Loaded {len(self.triplets)} triplets for {split} split")
    
    def __len__(self):
        return len(self.triplets)
    
    def __getitem__(self, idx):
        triplet = self.triplets[idx]
        
        # Load images
        prev_img = Image.open(triplet['prev_image']).convert('RGB')
        target_img = Image.open(triplet['target_image']).convert('RGB')
        
        # Resize
        prev_img = prev_img.resize((config.image_size, config.image_size), Image.LANCZOS)
        target_img = target_img.resize((config.image_size, config.image_size), Image.LANCZOS)
        
        # Convert to tensors [0, 1]
        prev_img_tensor = torch.from_numpy(np.array(prev_img)).float() / 255.0
        target_img_tensor = torch.from_numpy(np.array(target_img)).float() / 255.0
        
        # Normalize to [-1, 1]
        prev_img_tensor = prev_img_tensor * 2 - 1
        target_img_tensor = target_img_tensor * 2 - 1
        
        # CHW format
        prev_img_tensor = prev_img_tensor.permute(2, 0, 1)
        target_img_tensor = target_img_tensor.permute(2, 0, 1)
        
        return {
            'prev_image': prev_img_tensor,
            'caption': triplet['caption'],
            'target_image': target_img_tensor,
            'story_id': triplet['story_id']
        }

# Create datasets
print("\n" + "=" * 70)
print("PREPARING DATA PIPELINE")
print("=" * 70)

train_dataset = SSIDSequentialDataset(train_data, images_dir, 'train', max_stories=50)
val_dataset = SSIDSequentialDataset(val_data, images_dir, 'val', max_stories=10)

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False)

print(f"\n✓ Data pipeline ready:")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}")

## 6. Fine-tune Model on Sequential Story Images
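All three approaches fine-tune with the same epsilon-prediction objective: noise the target latent to a random timestep t, then regress the UNet output onto the added noise with MSE. `DDPMScheduler.add_noise` computes x_t = sqrt(ᾱ_t)·x₀ + sqrt(1 - ᾱ_t)·ε. A NumPy sketch of that forward process, assuming the SD v1.5 scheduler settings (1000 steps, `scaled_linear` betas from 0.00085 to 0.012); it re-derives the schedule rather than calling diffusers:

```python
import numpy as np

# Assumed SD v1.5 scheduler config: scaled_linear betas, 1000 training steps
num_train_timesteps = 1000
betas = np.linspace(0.00085 ** 0.5, 0.012 ** 0.5, num_train_timesteps) ** 2
alphas_cumprod = np.cumprod(1.0 - betas)  # ᾱ_t


def add_noise(x0, noise, t):
    """Forward diffusion q(x_t | x_0); t holds one timestep per sample."""
    a = np.sqrt(alphas_cumprod[t]).reshape(-1, 1, 1, 1)
    s = np.sqrt(1.0 - alphas_cumprod[t]).reshape(-1, 1, 1, 1)
    return a * x0 + s * noise


rng = np.random.default_rng(0)
x0 = rng.standard_normal((2, 4, 64, 64))  # stand-in for scaled target latents
noise = rng.standard_normal(x0.shape)
t = rng.integers(0, num_train_timesteps, size=2)
x_t = add_noise(x0, noise, t)

# The regression target is `noise`; a perfect prediction gives zero MSE
perfect_pred = noise
loss = np.mean((perfect_pred - noise) ** 2)
print(loss)  # 0.0
```

Because x_t is a known mix of x₀ and ε, predicting ε is equivalent to predicting the clean latent, which is what lets the sampler walk back from pure noise at inference time.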

### Approach 1: Latent Concatenation (Recommended for Quick Testing)

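**How it works:**
1. Encode previous image → latent space (4 channels)
2. Sample noise and add it to the target image latent at a random timestep (4 channels)
3. Concatenate: [noisy_target_latent, prev_image_latent] (8 channels)
4. UNet predicts the added noise from the 8-channel input and the caption embedding; training minimizes the MSE to the true noise
5. At inference, start from pure noise, denoise step by step while concatenating the previous-frame latent at every step, then decode the final latent → next image

**Advantages:**
- ✅ Fast (one UNet evaluation per denoising step, no auxiliary network)
- ✅ Low inference memory (roughly the footprint of plain Stable Diffusion); full-UNet fine-tuning in fp32 with AdamW still needs a large GPU
- ✅ Simple implementation
- ✅ No additional models needed

**Disadvantages:**
- ⚠️ Less precise than ControlNet
- ⚠️ May lose some fine details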
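Steps 3-4 above also explain why the expanded input layer should start with zero weights on the new channels (the initialization InstructPix2Pix uses): with zeros there, the 8-channel layer initially produces exactly what the pretrained 4-channel layer did, so fine-tuning starts from the base model's behaviour. A NumPy sketch using a 1×1 convolution (a per-pixel channel-mixing matrix) as a simplified stand-in for the UNet's 3×3 `conv_in`; the weights are random placeholders and the spatial size is shrunk to 8×8 to keep it cheap:

```python
import numpy as np

rng = np.random.default_rng(0)
out_channels = 320  # SD v1.5 conv_in output width


def conv1x1(x, weight, bias):
    """1x1 convolution: x (B, C_in, H, W), weight (C_out, C_in), bias (C_out,)."""
    return np.einsum("oc,bchw->bohw", weight, x) + bias[None, :, None, None]


# "Pretrained" 4-channel input layer (random placeholder weights)
w4 = rng.standard_normal((out_channels, 4))
b4 = rng.standard_normal(out_channels)

# Expanded 8-channel layer: copy the original weights, zero the new channels
w8 = np.zeros((out_channels, 8))
w8[:, :4] = w4
b8 = b4.copy()

noisy_latents = rng.standard_normal((2, 4, 8, 8))  # noised target latent
prev_latents = rng.standard_normal((2, 4, 8, 8))   # clean previous-frame latent
unet_input = np.concatenate([noisy_latents, prev_latents], axis=1)

same = np.allclose(conv1x1(unet_input, w8, b8), conv1x1(noisy_latents, w4, b4))
print(unet_input.shape, same)  # (2, 8, 8, 8) True
```

Initializing the new channels with non-zero values (for example, copies or averages of the pretrained weights) would inject the previous-frame latent into the first layer before the model has learned how to use it.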

In [None]:
print("\n" + "=" * 70)
print("TRAINING SETUP")
print("=" * 70)

# Setup optimizer - only train UNet (text encoder and VAE are frozen)
optimizer = AdamW(
    unet.parameters(),
    lr=config.learning_rate,
    weight_decay=config.weight_decay
)

# Learning rate scheduler
from diffusers.optimization import get_cosine_schedule_with_warmup

num_update_steps_per_epoch = len(train_loader)
max_train_steps = config.num_epochs * num_update_steps_per_epoch

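lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    # Cap warmup at 10% of training: a fixed 500 steps would exceed the
    # entire run on the small subset used here (max_stories=50)
    num_warmup_steps=min(500, max_train_steps // 10),
    num_training_steps=max_train_steps
)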

print(f"\nOptimizer: AdamW")
print(f"  Learning rate: {config.learning_rate}")
print(f"  Weight decay: {config.weight_decay}")
print(f"  Total steps: {max_train_steps}")
print(f"  Steps per epoch: {num_update_steps_per_epoch}")

def train_epoch(epoch_num):
    """Train one epoch."""
    unet.train()
    progress_bar = tqdm(total=len(train_loader), desc=f"Epoch {epoch_num}")
    losses = []
    
    for step, batch in enumerate(train_loader):
        # Load batch data
        prev_images = batch['prev_image'].to(DEVICE)
        captions = batch['caption']
        target_images = batch['target_image'].to(DEVICE)
        
        # Encode images to latent space
        with torch.no_grad():
            # Previous image latent
            prev_latents = vae.encode(prev_images).latent_dist.sample()
            prev_latents = prev_latents * 0.18215
            
            # Target image latent
            target_latents = vae.encode(target_images).latent_dist.sample()
            target_latents = target_latents * 0.18215
        
        # Tokenize and encode text
        with torch.no_grad():
            tokens = tokenizer(
                captions,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt"
            )
            encoder_hidden_states = text_encoder(tokens.input_ids.to(DEVICE))[0]
        
        # Sample random timesteps
        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, 
                                 (target_latents.shape[0],), device=DEVICE).long()
        
        # Sample noise
        noise = torch.randn_like(target_latents)
        
        # Add noise to target latent
        noisy_latents = noise_scheduler.add_noise(target_latents, noise, timesteps)
        
        # For latent concat approach: concatenate previous image latent
        if config.approach == "latent_concat":
            noisy_latents = torch.cat([noisy_latents, prev_latents], dim=1)
        
        # Predict noise with UNet
        model_pred = unet(
            noisy_latents,
            timesteps,
            encoder_hidden_states=encoder_hidden_states
        ).sample
        
        # Loss: L2 between predicted and actual noise
        loss = F.mse_loss(model_pred, noise, reduction="mean")
        
        # Backward
        loss.backward()
        torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        
        losses.append(loss.detach().item())
        progress_bar.update(1)
        progress_bar.set_postfix({"loss": f"{loss.item():.4f}"})
    
    progress_bar.close()
    return np.mean(losses)

# Training loop
print("\n" + "=" * 70)
print("STARTING TRAINING")
print("=" * 70)

history = {"epoch": [], "train_loss": [], "val_loss": []}

for epoch in range(config.num_epochs):
    print(f"\nEpoch {epoch+1}/{config.num_epochs}")
    
    train_loss = train_epoch(epoch+1)
    history["epoch"].append(epoch+1)
    history["train_loss"].append(train_loss)
    
    print(f"Train loss: {train_loss:.4f}")
    
    # Save checkpoint
    checkpoint_dir = os.path.join(config.checkpoints_dir, f"epoch-{epoch+1}")
    os.makedirs(checkpoint_dir, exist_ok=True)
    unet.save_pretrained(os.path.join(checkpoint_dir, "unet"))
    print(f"✓ Checkpoint saved: {checkpoint_dir}")

print("\n✓ Training complete!")

## 7. Generate Sequential Story Illustrations
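`generate_sequential_story` takes a `guidance_scale`; in Stable Diffusion that parameter conventionally drives classifier-free guidance, which runs the UNet on an empty-prompt embedding and on the caption embedding and extrapolates between the two noise predictions. A NumPy sketch of just the combination step, with random arrays standing in for the two UNet outputs:

```python
import numpy as np


def classifier_free_guidance(noise_pred_uncond, noise_pred_text, guidance_scale):
    """eps = eps_uncond + s * (eps_text - eps_uncond)."""
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)


rng = np.random.default_rng(0)
eps_uncond = rng.standard_normal((1, 4, 64, 64))  # UNet output for the "" prompt
eps_text = rng.standard_normal((1, 4, 64, 64))    # UNet output for the caption

guided = classifier_free_guidance(eps_uncond, eps_text, 7.5)

# s = 1 reduces to the plain conditional prediction; s > 1 pushes further toward the caption
print(np.allclose(classifier_free_guidance(eps_uncond, eps_text, 1.0), eps_text))  # True
```

The cost is a second UNet evaluation per step (usually batched with the conditional one), which is why guidance roughly doubles sampling compute.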

In [None]:
def generate_sequential_story(prev_image, caption, num_inference_steps=30, guidance_scale=7.5):
    """
    Generate next image in sequence given previous image and caption.
    
    Args:
        prev_image: PIL Image, or tensor in [-1, 1] of shape (3, H, W) or (1, 3, H, W)
        caption: Text description for next frame
        num_inference_steps: Number of denoising steps
        guidance_scale: Classifier-free guidance scale (values <= 1.0 disable guidance)
        
    Returns:
        PIL Image of generated frame
    """
    # Convert PIL to tensor if needed
    if isinstance(prev_image, Image.Image):
        prev_image = prev_image.convert('RGB').resize((config.image_size, config.image_size), Image.LANCZOS)
        prev_img_tensor = torch.from_numpy(np.array(prev_image)).float() / 255.0
        prev_img_tensor = prev_img_tensor * 2 - 1
        prev_img_tensor = prev_img_tensor.permute(2, 0, 1).unsqueeze(0).to(DEVICE)
    else:
        prev_img_tensor = prev_image.to(DEVICE)
        if prev_img_tensor.dim() == 3:  # add batch dimension for a single (C, H, W) frame
            prev_img_tensor = prev_img_tensor.unsqueeze(0)
    
    with torch.no_grad():
        # Encode previous image
        prev_latents = vae.encode(prev_img_tensor).latent_dist.sample()
        prev_latents = prev_latents * 0.18215
        
        # Encode the caption and an empty prompt (for classifier-free guidance)
        tokens = tokenizer(
            [caption, ""],
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt"
        )
        cond_embeds, uncond_embeds = text_encoder(tokens.input_ids.to(DEVICE))[0].chunk(2)
        do_cfg = guidance_scale > 1.0
        encoder_hidden_states = torch.cat([uncond_embeds, cond_embeds]) if do_cfg else cond_embeds
        
        # Generate noise schedule
        noise_scheduler.set_timesteps(num_inference_steps)
        
        # Start with noise
        latents = torch.randn((1, 4, config.image_size//8, config.image_size//8), device=DEVICE)
        latents = latents * noise_scheduler.init_noise_sigma
        
        # Denoising loop
        for t in noise_scheduler.timesteps:
            latent_model_input = noise_scheduler.scale_model_input(latents, t)
            # Concatenate previous image latent for context
            if config.approach == "latent_concat":
                latent_model_input = torch.cat([latent_model_input, prev_latents], dim=1)
            # Duplicate the batch: one unconditional and one caption-conditioned pass
            if do_cfg:
                latent_model_input = torch.cat([latent_model_input] * 2)
            
            # Predict noise
            noise_pred = unet(
                latent_model_input,
                t,
                encoder_hidden_states=encoder_hidden_states
            ).sample
            
            # Classifier-free guidance: push the prediction toward the caption
            if do_cfg:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            
            # Denoise
            latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
        
        # Decode latents
        image = vae.decode(latents / 0.18215).sample
        
        # Convert to PIL
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.permute(0, 2, 3, 1).float().cpu().numpy()
        image = (image[0] * 255).astype(np.uint8)
        
        return Image.fromarray(image)

# Load the latest fine-tuned checkpoint if available
checkpoint_path = os.path.join(config.checkpoints_dir, f"epoch-{config.num_epochs}")
if os.path.exists(checkpoint_path):
    print(f"\nLoading checkpoint: {checkpoint_path}")
    unet = UNet2DConditionModel.from_pretrained(os.path.join(checkpoint_path, "unet"))
    unet = unet.to(DEVICE)
    unet.eval()
    print("✓ Fine-tuned model loaded")
else:
    print("\n⚠️ No checkpoint found, using the UNet currently in memory for generation")
    unet.eval()

# Test on a validation story
print("\n" + "=" * 70)
print("GENERATING TEST SEQUENCE")
print("=" * 70)

val_sample = val_dataset[0]
prev_img = val_sample['prev_image'].unsqueeze(0)
caption = val_sample['caption']

print(f"\nCaption: '{caption}'")
print("Generating next frame...")

generated_img = generate_sequential_story(prev_img, caption)  # prev_img already has a batch dim

# Display
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
axes[0].imshow(prev_img[0].permute(1, 2, 0) * 0.5 + 0.5)
axes[0].set_title("Previous Frame")
axes[0].axis('off')

axes[1].imshow(generated_img)
axes[1].set_title(f"Generated Next Frame\n'{caption[:30]}'")
axes[1].axis('off')

plt.tight_layout()
plt.savefig(os.path.join(config.results_dir, "sample_generation.png"), dpi=150, bbox_inches='tight')
plt.show()

print("\n✓ Generation complete!")

## 8. Evaluate Temporal Consistency
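The SSIM score below compares the *previous* frame with the generated frame, so it rewards copying: returning the previous frame unchanged scores a perfect 1.0 while ignoring the new caption. A minimal NumPy sketch of a single-window (global) SSIM, a simplification of the sliding-window `skimage.metrics.structural_similarity`, using the standard constants C1 = (0.01·L)² and C2 = (0.03·L)²; the frames are random stand-ins:

```python
import numpy as np


def global_ssim(x, y, data_range=1.0):
    """Single-window SSIM over whole grayscale images in [0, data_range]."""
    c1 = (0.01 * data_range) ** 2
    c2 = (0.03 * data_range) ** 2
    mu_x, mu_y = x.mean(), y.mean()
    var_x, var_y = x.var(), y.var()
    cov_xy = ((x - mu_x) * (y - mu_y)).mean()
    return ((2 * mu_x * mu_y + c1) * (2 * cov_xy + c2)) / (
        (mu_x ** 2 + mu_y ** 2 + c1) * (var_x + var_y + c2)
    )


rng = np.random.default_rng(0)
prev_frame = rng.random((64, 64))   # stand-in grayscale frame in [0, 1]
copied = prev_frame.copy()          # a "generator" that just copies its input
changed = np.clip(prev_frame + 0.3 * rng.standard_normal((64, 64)), 0, 1)

print(global_ssim(prev_frame, copied))          # close to 1.0
print(global_ssim(prev_frame, changed) < 1.0)   # True
```

This is why the SSIM numbers are best read alongside the CLIP alignment score, which penalizes a frame that ignores its caption.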

In [None]:
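def compute_ssim_score(img1, img2):
    """Compute Structural Similarity Index between two RGB images (PIL, array, or tensor)."""
    def to_float_rgb(img):
        # Bring every input to a float64 HxWx3 array in [0, 1]
        if isinstance(img, Image.Image):
            return np.asarray(img.convert('RGB'), dtype=np.float64) / 255.0
        if torch.is_tensor(img):
            img = img.detach().cpu().numpy()
        img = np.asarray(img)
        if img.dtype == np.uint8:
            return img.astype(np.float64) / 255.0
        return img.astype(np.float64)  # float input is assumed to already be in [0, 1]
    
    # Convert to grayscale
    img1_gray = rgb2gray(to_float_rgb(img1))
    img2_gray = rgb2gray(to_float_rgb(img2))
    
    # Both images are now in [0, 1]
    score = ssim(img1_gray, img2_gray, data_range=1.0)
    return score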

_clip_cache = {}

def compute_clip_alignment(image, caption):
    """Compute CLIP alignment score (cosine similarity x 100) between an image and a caption."""
    try:
        import clip  # OpenAI CLIP package
    except ImportError:
        print("⚠️ CLIP not available, skipping alignment score")
        return None
    
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Load the model once and reuse it across calls
    if "model" not in _clip_cache:
        _clip_cache["model"], _clip_cache["preprocess"] = clip.load("ViT-B/32", device=device)
    clip_model, preprocess = _clip_cache["model"], _clip_cache["preprocess"]
    
    # Preprocess image (CLIP's transform expects a PIL image)
    if not isinstance(image, Image.Image):
        image = Image.fromarray(np.asarray(image))
    image_tensor = preprocess(image).unsqueeze(0).to(device)
    
    # Tokenize text (truncate captions longer than CLIP's 77-token context)
    text_tokens = clip.tokenize([caption], truncate=True).to(device)
    
    # Get embeddings
    with torch.no_grad():
        image_features = clip_model.encode_image(image_tensor)
        text_features = clip_model.encode_text(text_tokens)
        
        # Normalize
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        
        # Cosine similarity lies in [-1, 1]; scaled by 100 for readability
        similarity = (image_features @ text_features.t()).squeeze()
        score = float(similarity.item()) * 100
    
    return score

# Evaluate on validation set
print("\n" + "=" * 70)
print("EVALUATING TEMPORAL CONSISTENCY")
print("=" * 70)

ssim_scores = []
clip_scores = []
story_ids = []

for idx in tqdm(range(min(len(val_dataset), 20)), desc="Evaluating"):
    sample = val_dataset[idx]
    
    prev_img = sample['prev_image'].unsqueeze(0)
    caption = sample['caption']
    
    # Generate next frame
    with torch.no_grad():
        gen_img = generate_sequential_story(prev_img, caption)
    
    # Compute SSIM (visual similarity)
    ssim_score = compute_ssim_score(prev_img[0].permute(1, 2, 0) * 0.5 + 0.5, gen_img)
    ssim_scores.append(ssim_score)
    
    # Compute CLIP alignment
    clip_score = compute_clip_alignment(gen_img, caption)
    if clip_score is not None:
        clip_scores.append(clip_score)
    
    story_ids.append(sample['story_id'])

# Summary
print("\n" + "=" * 70)
print("EVALUATION RESULTS")
print("=" * 70)

if ssim_scores:
    print(f"\nStructural Similarity (SSIM) - Temporal Continuity:")
    print(f"  Mean SSIM: {np.mean(ssim_scores):.4f}")
    print(f"  Std Dev: {np.std(ssim_scores):.4f}")
    print(f"  Min: {np.min(ssim_scores):.4f}, Max: {np.max(ssim_scores):.4f}")
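    print(f"\n  Interpretation (rough heuristic bands):")
    print(f"    > 0.8: Excellent temporal continuity")
    print(f"    > 0.6: Good continuity")
    print(f"    > 0.4: Moderate continuity")
    print(f"    < 0.4: Low continuity")
    print(f"  Note: SSIM is measured against the previous frame, so copying it scores 1.0;")
    print(f"  read it together with CLIP alignment")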

if clip_scores:
    print(f"\nCLIP Text-Image Alignment (cosine x 100):")
    print(f"  Mean Score: {np.mean(clip_scores):.2f}")
    print(f"  Std Dev: {np.std(clip_scores):.2f}")
    print(f"  Range: {np.min(clip_scores):.2f}-{np.max(clip_scores):.2f}")
    print(f"\n  Quality Assessment (rough bands; raw CLIP cosines for matching")
    print(f"  image-caption pairs rarely exceed ~35, so compare across checkpoints):")
    if np.mean(clip_scores) >= 30:
        print(f"    ✅ Strong alignment")
    elif np.mean(clip_scores) >= 25:
        print(f"    ✓ Reasonable alignment")
    else:
        print(f"    ⚠️ Weak alignment - consider longer training")

## 9. Visualize Story Sequence Results and Generate GIFs
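`generate_story_sequence` in the next cell is autoregressive: each generated frame becomes the conditioning image for the next one, whereas training always conditions on ground-truth previous frames, so errors can compound over a story. A minimal sketch contrasting the two conditioning schemes, with a hypothetical stand-in `toy_generate` in place of the diffusion model (any callable of `(image, caption)` works):

```python
def rollout(first_image, captions, generate_fn):
    """Autoregressive rollout: feed each generated frame back in as the next condition."""
    frames, current = [], first_image
    for caption in captions:
        current = generate_fn(current, caption)
        frames.append(current)
    return frames


def teacher_forced(true_prev_images, captions, generate_fn):
    """Training-style conditioning: always use the ground-truth previous frame."""
    return [generate_fn(prev, cap) for prev, cap in zip(true_prev_images, captions)]


def toy_generate(image, caption):
    """Hypothetical stand-in generator that records its conditioning chain."""
    return f"{image}>{caption}"


captions = ["c1", "c2", "c3"]
print(rollout("f0", captions, toy_generate))
# ['f0>c1', 'f0>c1>c2', 'f0>c1>c2>c3']
print(teacher_forced(["f0", "f1", "f2"], captions, toy_generate))
# ['f0>c1', 'f1>c2', 'f2>c3']
```

The growing chains in the rollout show that frame k depends on every earlier generated frame, which is where drift in character appearance or scene layout can accumulate.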

In [None]:
def generate_story_sequence(story_data, max_frames=5):
    """
    Generate full sequence from story data.
    
    Args:
        story_data: List of dataset samples (dicts with 'prev_image' and 'caption')
        max_frames: Maximum frames to generate
        
    Returns:
        List of generated PIL images
    """
    generated_sequence = []
    current_image = story_data[0]['prev_image'].unsqueeze(0)  # Start from the ground-truth first image
    
    for idx, frame_data in enumerate(story_data[:max_frames]):
        caption = frame_data['caption']
        
        print(f"  Frame {idx+1}: {caption[:40]}...", end=" ", flush=True)
        
        # Generate next frame
        gen_img = generate_sequential_story(current_image, caption)
        generated_sequence.append(gen_img)
        
        # Update current for next iteration
        current_image = gen_img
        print("✓")
    
    return generated_sequence

# Generate multiple story sequences
print("\n" + "=" * 70)
print("GENERATING FULL STORY SEQUENCES")
print("=" * 70)

# Group samples by story_id
story_groups = {}
for idx, sample in enumerate(val_dataset):
    story_id = sample['story_id']
    if story_id not in story_groups:
        story_groups[story_id] = []
    story_groups[story_id].append(sample)

# Generate and visualize sequences
for story_idx, (story_id, frames) in enumerate(list(story_groups.items())[:3]):
    print(f"\nGenerating Story #{story_idx+1} (ID: {story_id}, {len(frames)} frames)")
    
    generated = generate_story_sequence(frames[:4])
    
    # Visualize side-by-side
    fig, axes = plt.subplots(2, min(4, len(generated)), figsize=(15, 6))
    if len(generated) == 1:
        axes = axes.reshape(2, 1)
    
    for frame_idx, (ax_row, sample) in enumerate(zip(axes.T, frames[:len(generated)])):
        # Original sequence
        orig_img = sample['target_image'].permute(1, 2, 0) * 0.5 + 0.5
        ax_row[0].imshow(orig_img)
        ax_row[0].set_title(f"Original Frame {frame_idx+1}", fontsize=9)
        ax_row[0].axis('off')
        
        # Generated sequence
        ax_row[1].imshow(generated[frame_idx])
        caption = sample['caption'][:25]
        ax_row[1].set_title(f"Generated\n'{caption}'", fontsize=9)
        ax_row[1].axis('off')
    
    plt.suptitle(f"Story #{story_idx+1} - Original vs Generated", fontsize=12, fontweight='bold')
    plt.tight_layout()
    
    save_path = os.path.join(config.results_dir, f"story_sequence_{story_idx+1}.png")
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()
    print(f"  ✓ Saved to {save_path}")

# Create GIF animations
print("\n" + "=" * 70)
print("CREATING STORY PROGRESSION GIFs")
print("=" * 70)

for story_idx, (story_id, frames) in enumerate(list(story_groups.items())[:2]):
    print(f"\nCreating GIF for Story #{story_idx+1}")
    
    generated = generate_story_sequence(frames[:5])
    
    # Create GIF
    gif_path = os.path.join(config.results_dir, f"story_{story_idx+1}.gif")
    
    images_pil = [img.convert('RGB') for img in generated]
    images_pil[0].save(
        gif_path,
        save_all=True,
        append_images=images_pil[1:],
        duration=500,  # 500ms per frame
        loop=0  # Infinite loop
    )
    
    print(f"  ✓ GIF saved to {gif_path}")

print("\n" + "=" * 70)
print("VISUALIZATION COMPLETE")
print("=" * 70)
print(f"\nResults saved to: {config.results_dir}")
print("  - sample_generation.png: Single frame example")
print("  - story_sequence_*.png: Full sequence comparisons")
print("  - story_*.gif: Story progression animations")

## 10. Alternative Approaches: LoRA Fine-tuning

### Approach 2: LoRA Adaptation (Recommended for Production)

**What is LoRA?**
- Low-Rank Adaptation: Fine-tune with ~1-2% of parameters
- Add small learnable matrices to attention layers
- Keep base model frozen
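Concretely, LoRA keeps a pretrained weight W (d_out × d_in) frozen and learns a low-rank update, so the layer computes W·x + (α/r)·B·A·x with A of shape r × d_in and B of shape d_out × r. PEFT initializes B to zero by default, so the adapted layer starts identical to the base layer. A NumPy sketch for one 320 × 320 attention projection (the width of SD v1.5's first UNet block), using the `lora_rank = 16` and `lora_alpha = 32` values from `Config`; the weights are random placeholders:

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, r, alpha = 320, 320, 16, 32

W = rng.standard_normal((d_out, d_in))     # frozen pretrained weight
A = rng.standard_normal((r, d_in)) * 0.01  # trainable, small random init
B = np.zeros((d_out, r))                   # trainable, zero init


def lora_forward(x):
    """Base projection plus scaled low-rank update; x has shape (..., d_in)."""
    return x @ W.T + (alpha / r) * (x @ A.T @ B.T)


x = rng.standard_normal((2, 77, d_in))
print(np.allclose(lora_forward(x), x @ W.T))  # True: B = 0 means no change at init

full_params = W.size
lora_params = A.size + B.size
print(full_params, lora_params, f"{lora_params / full_params:.1%}")  # 102400 10240 10.0%
```

The 10% here is a per-layer ratio; only the targeted projections receive adapters while the rest of the UNet stays frozen, so the trainable fraction of the whole model is much smaller.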

**Advantages:**
- ✅ Quality often close to full fine-tuning (reported gaps vary by task)
- ✅ Lower training memory (frozen base weights need no gradients or optimizer state)
- ✅ Faster training steps than full fine-tuning
- ✅ Small checkpoints (adapter weights only: megabytes rather than the multi-GB full UNet)
- ✅ Composable with other LoRAs

**Disadvantages:**
- ⚠️ Requires `peft` library
- ⚠️ Slightly longer convergence
**Installation:**
```bash
pip install peft
```

**Usage in Training Loop:**
```python
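from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=config.lora_rank,            # 16
    lora_alpha=config.lora_alpha,  # 32
    target_modules=["to_q", "to_v", "to_k"],  # UNet attention projections
    lora_dropout=0.05,
    bias="none"
)

# Wraps the UNet, freezes its base weights, and adds trainable LoRA matrices
unet = get_peft_model(unet, lora_config)
unet.print_trainable_parameters()

# Optimize only the trainable (LoRA) parameters
optimizer = AdamW(
    [p for p in unet.parameters() if p.requires_grad],
    lr=config.learning_rate,
    weight_decay=config.weight_decay,
)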
```

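**To switch to LoRA:**
1. Set `config.approach = "lora_adapt"` (no cell branches on this value yet; it only disables the latent-concat path)
2. Install peft: `pip install peft`
3. Run the LoRA setup above after loading the UNet (Section 4), then create the optimizer over the trainable parameters only
4. `train_epoch` itself is unchanged, but `unet.save_pretrained` now writes only the adapter weights, so reload checkpoints with `PeftModel.from_pretrained(base_unet, path)` instead of `UNet2DConditionModel.from_pretrained`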

## 11. ControlNet Approach (Advanced: Maximum Quality)

### Approach 3: ControlNet (Maximum Quality but Higher VRAM)

**What is ControlNet?**
- A separate network, attached to a frozen Stable Diffusion UNet, that guides generation
- Conditions the output on edge maps, poses, or other control images derived from the previous frame
- Preserves scene structure more explicitly than latent concatenation

**Advantages:**
- ✅ Best visual quality expected (highest SSIM anticipated; not measured in this notebook)
- ✅ Better scene structure preservation
- ✅ More stable generation (base model weights stay frozen)
- ✅ Research-backed approach

**Disadvantages:**
- ⚠️ Requires more VRAM (estimated ~10-14GB)
- ⚠️ Slower inference (estimated 2-3x vs latent concat, since the ControlNet runs at every denoising step)
- ⚠️ Separate model to fine-tune
- ⚠️ More complex training setup

**Installation:**
```bash
pip install diffusers transformers accelerate safetensors opencv-python
```

**Usage Pattern:**
```python
import cv2
import numpy as np
import torch
from PIL import Image
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel

controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16
)

pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    torch_dtype=torch.float16,
).to("cuda")

# The Canny ControlNet is conditioned on an edge map, not the raw frame:
# extract edges from the previous frame (prev_image: PIL image)
edges = cv2.Canny(np.array(prev_image.convert("RGB")), 100, 200)
control_image = Image.fromarray(np.stack([edges] * 3, axis=-1))

# Generate with the previous frame's structure as guidance
image = pipe(
    prompt="a lion in the forest",
    image=control_image,
    num_inference_steps=30,
    controlnet_conditioning_scale=0.5
).images[0]
```

**When to use ControlNet:**
- ✅ Production system (best quality matters)
- ✅ Have 12GB+ GPU VRAM
- ✅ Quality over speed trade-off acceptable
- ✅ Need maximum temporal consistency

**Recommendation:**
- Start with **Latent Concat** for prototyping
- Upgrade to **LoRA** for balanced solution
- Use **ControlNet** for final production release

## 12. Summary: Quick Start Guide

**Which approach should YOU use?**

### 🚀 Quick Testing (Now)
**Use: Latent Concatenation**
```python
config.approach = "latent_concat"
# Run training right away!
```
- Time to first result: 30-60 min
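- Memory needed: well above 6GB for training; full-UNet fine-tuning with AdamW in fp32 needs roughly 14GB for weights, gradients, and optimizer states alone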
- Quality: Good

### ⭐ Recommended (Balance)
**Use: LoRA Adaptation**
```bash
pip install peft
```
```python
config.approach = "lora_adapt"
# Add the LoRA setup from Section 10 before building the optimizer
```
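- Time to results: 1-2 hours (estimate)
- Memory needed: ~8GB (estimate; lower than full fine-tuning because only adapter weights receive gradients)
- Quality: Excellent
- Checkpoint size: adapter weights only (megabytes rather than the multi-GB full UNet)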

### 🏆 Production (Best Quality)
**Use: ControlNet**
```bash
# See Section 11 for the ControlNet guide
```
- Time to results: 2-4 hours
- Memory needed: 12GB
- Quality: Best possible
- Inference: Slower but better structure

**Next Steps:**
1. ✅ Choose approach above
2. ✅ Update `config.approach` value
3. ✅ Run training cells
4. ✅ Evaluate results
5. ✅ Scale to full dataset