## 1. Setup and Imports

In [None]:
import os
import sys
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Transformers and Vision models
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel, CLIPProcessor
from torchvision import transforms
from torchvision.models import resnet50

# Select the compute device: the first CUDA GPU when one is visible, otherwise the CPU.
USE_GPU = torch.cuda.is_available()
DEVICE = torch.device('cuda:0') if USE_GPU else torch.device('cpu')
print(f"Device: {DEVICE}")
print(f"GPU Available: {USE_GPU}")
if USE_GPU:
    # total_memory is reported in bytes; display it in decimal gigabytes.
    _gpu_total_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"GPU Memory: {_gpu_total_bytes / 1e9:.2f} GB")

## 2. Configuration
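**Model Architecture**: Lightweight Sequential Generator
- Text encoder: CLIP text model (frozen)
- Image encoder: lightweight two-layer strided CNN, trained jointly with the generator (the imported ResNet50 is not used)
- Generator: fusion convolution, 3 residual blocks, then two 2× upsampling stages
- Output: 256×256 RGB image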

In [None]:
from dataclasses import dataclass

@dataclass
class Config:
    """Hyperparameters and output locations for the sequential generator run."""

    # Model architecture
    image_size: int = 256  # Side length of the square input/output images
    # Hidden size of the CLIP text encoder used in this notebook.
    # "openai/clip-vit-base-patch32" has a 512-wide text transformer; 768 is the
    # width of its vision tower, and feeding 512-dim embeddings into a
    # Linear(768, ...) layer fails on every batch.
    text_embed_dim: int = 512
    # ResNet50 output dimension; not referenced by the cells shown here.
    image_embed_dim: int = 2048
    latent_dim: int = 512  # Hidden dimension for generator

    # Training
    batch_size: int = 8
    num_epochs: int = 30
    learning_rate: float = 1e-4
    weight_decay: float = 1e-5

    # Data
    num_workers: int = 4
    # NOTE(review): the DataLoader cells shown here pass shuffle literally and
    # do not read this field.
    shuffle: bool = True

    # Checkpoints
    checkpoints_dir: str = "./models/sequential_checkpoints"
    results_dir: str = "./results/sequential"
    log_dir: str = "./logs/sequential"

config = Config()

# Make sure every output location exists before anything tries to write to it.
for _output_dir in (config.checkpoints_dir, config.results_dir, config.log_dir):
    os.makedirs(_output_dir, exist_ok=True)

# Echo the settings that matter most for a run.
print("Configuration:")
print(f"  Image size: {config.image_size}x{config.image_size}")
print(f"  Batch size: {config.batch_size}")
print(f"  Learning rate: {config.learning_rate}")
print(f"  Epochs: {config.num_epochs}")

## 3. Load SSID Dataset

In [None]:
# Load SSID annotations
def load_annotations(json_path, split_name):
    """Load one SSID annotation file into a flat DataFrame.

    The JSON document must contain an ``'annotations'`` list whose items are
    stories, each of which is itself a list of storylet records. All storylets
    are flattened into a single table, one row per storylet.

    Args:
        json_path: Path of the annotation JSON file.
        split_name: Label written to the added ``'split'`` column of every row.

    Returns:
        pandas.DataFrame with one row per storylet plus a ``'split'`` column.
    """
    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    storylets = [storylet for story in data['annotations'] for storylet in story]
    df = pd.DataFrame(storylets)
    df['split'] = split_name
    return df

# Dataset locations (relative to the notebook's working directory)
annotations_dir = '../data/SSID_Annotations/SSID_Annotations'
images_dir = '../data/SSID_Images/SSID_Images'

# One annotation file per split
train_json, val_json, test_json = (
    os.path.join(annotations_dir, file_name)
    for file_name in ("SSID_Train.json", "SSID_Validation.json", "SSID_Test.json")
)

# Parse each split into its own frame, tagged with the split name
df_train = load_annotations(train_json, 'train')
df_val = load_annotations(val_json, 'val')
df_test = load_annotations(test_json, 'test')

# Combined view, used for whole-dataset statistics
df_all = pd.concat([df_train, df_val, df_test], ignore_index=True)

print(f"Train storylets: {len(df_train)}")
print(f"Validation storylets: {len(df_val)}")
print(f"Test storylets: {len(df_test)}")
print(f"Total storylets: {len(df_all)}")
print(f"Unique stories: {df_all['story_id'].nunique()}")

## 4. Create Sequential Pairs Dataset
Build pairs of (previous_image, text, next_image) for training

In [None]:
def create_sequential_pairs(df, images_dir, split='train'):
    """
    Create (prev_image, text, next_image) pairs from stories.

    Each story's rows are ordered by ``image_order`` and every pair of
    consecutive rows yields one training example: the earlier image, the later
    row's text, and the later image. Stories with fewer than 2 rows are
    skipped, as are pairs for which either ``<youtube_image_id>.jpg`` file is
    missing from ``images_dir``.

    Args:
        df: DataFrame with ``story_id``, ``image_order``, ``youtube_image_id``
            and ``storytext`` columns.
        images_dir: Directory containing the ``.jpg`` images.
        split: Label stored under the ``'split'`` key of every pair.

    Returns:
        List of dicts with keys ``prev_image``, ``text``, ``next_image``,
        ``story_id`` and ``split``. Stories appear in order of first occurrence
        in ``df``.
    """
    pairs = []
    on_disk = {}  # path -> bool; interior images are looked up twice

    def resolve(image_id):
        """Return the image path for image_id, or None when the file is absent."""
        path = os.path.join(images_dir, f"{image_id}.jpg")
        if path not in on_disk:
            on_disk[path] = os.path.exists(path)
        return path if on_disk[path] else None

    # A single groupby pass replaces one full-frame boolean mask per story.
    # sort=False keeps stories in first-appearance order, like unique().
    for story_id, story_rows in df.groupby('story_id', sort=False):
        # Skip stories with only 1 image
        if len(story_rows) < 2:
            continue

        records = story_rows.sort_values('image_order').to_dict('records')

        # Create pairs: (image_t, text_t+1, image_t+1)
        for prev_row, next_row in zip(records, records[1:]):
            prev_img_path = resolve(prev_row['youtube_image_id'])
            next_img_path = resolve(next_row['youtube_image_id'])

            # Only add if both images exist
            if prev_img_path is None or next_img_path is None:
                continue

            pairs.append({
                'prev_image': prev_img_path,
                'text': next_row['storytext'],
                'next_image': next_img_path,
                'story_id': story_id,
                'split': split
            })

    return pairs

# Build the training and validation pair lists (the test split is not paired here)
train_pairs = create_sequential_pairs(df_train, images_dir, 'train')
val_pairs = create_sequential_pairs(df_val, images_dir, 'val')

print(f"Train pairs: {len(train_pairs)}")
print(f"Validation pairs: {len(val_pairs)}")
print(f"\nExample pair:")
if len(train_pairs) > 0:
    # Preview the first pair: file names only, prompt cut to 60 characters
    pair = train_pairs[0]
    prev_name = os.path.basename(pair['prev_image'])
    next_name = os.path.basename(pair['next_image'])
    print(f"  Previous image: {prev_name}")
    print(f"  Text prompt: {pair['text'][:60]}...")
    print(f"  Next image: {next_name}")

## 5. Custom Dataset Class

In [None]:
class SequentialImageDataset(Dataset):
    """
    Dataset for sequential image generation:
    Input: (previous_image, text_prompt)
    Output: next_image

    The conditioning image and the target image are normalised differently on
    purpose: the conditioning image keeps ImageNet statistics, while the target
    is scaled to [-1, 1] so that it lies in the range a Tanh-terminated
    generator can actually produce.
    """

    def __init__(self, pairs, tokenizer, image_size=256):
        """
        Args:
            pairs: Sequence of dicts with 'prev_image', 'next_image' (file
                paths) and 'text' keys.
            tokenizer: Callable tokenizer returning 'input_ids' and
                'attention_mask' tensors (called with return_tensors='pt').
            image_size: Side length images are resized to.
        """
        self.pairs = pairs
        self.tokenizer = tokenizer
        self.image_size = image_size

        # Conditioning (previous) image: ImageNet mean/std normalisation.
        self.image_transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.CenterCrop((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

        # Target (next) image: map [0, 1] -> [-1, 1]. ImageNet-normalised
        # targets span roughly [-2.1, 2.6], which a Tanh output cannot reach.
        self.target_transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.5, 0.5, 0.5],
                std=[0.5, 0.5, 0.5]
            )
        ])

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return one sample as a dict of tensors plus the raw prompt text."""
        pair = self.pairs[idx]

        # Load and transform images
        try:
            prev_img = Image.open(pair['prev_image']).convert('RGB')
            next_img = Image.open(pair['next_image']).convert('RGB')
        except Exception as e:
            # Best effort: keep the epoch going with black placeholder images.
            print(f"Error loading images: {e}")
            prev_img = Image.new('RGB', (self.image_size, self.image_size))
            next_img = Image.new('RGB', (self.image_size, self.image_size))

        prev_img_tensor = self.image_transform(prev_img)
        next_img_tensor = self.target_transform(next_img)

        # Tokenize text, padded/truncated to 77 tokens
        text = pair['text']
        tokens = self.tokenizer(
            text,
            padding='max_length',
            max_length=77,
            truncation=True,
            return_tensors='pt'
        )

        return {
            'prev_image': prev_img_tensor,
            # squeeze(0) removes only the batch axis added by return_tensors='pt'
            'text_input_ids': tokens['input_ids'].squeeze(0),
            'text_attention_mask': tokens['attention_mask'].squeeze(0),
            'next_image': next_img_tensor,
            'text': text
        }

# Tokenizer matching the CLIP checkpoint used for text encoding
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

# Wrap the pair lists as datasets
train_dataset = SequentialImageDataset(train_pairs, tokenizer, config.image_size)
val_dataset = SequentialImageDataset(val_pairs, tokenizer, config.image_size)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")

# Settings shared by both loaders; only the shuffle flag differs
_loader_options = dict(
    batch_size=config.batch_size,
    num_workers=config.num_workers,
    pin_memory=USE_GPU,
)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_options)

print(f"Train loader batches: {len(train_loader)}")
print(f"Validation loader batches: {len(val_loader)}")

## 6. Lightweight Sequential Generator Model

In [None]:
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers wrapped in a skip connection."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Main path. Layers are created in the same order as before so that
        # seeded weight initialisation and state_dict keys are unchanged.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Skip path: pass-through when channel counts match, otherwise a
        # 1x1 projection so the two paths can be added.
        if in_channels == out_channels:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        skip = self.shortcut(x)
        main = self.conv1(x)
        main = self.bn1(main)
        main = self.relu(main)
        main = self.conv2(main)
        main = self.bn2(main)
        main += skip
        return self.relu(main)


class SequentialImageGenerator(nn.Module):
    """
    Lightweight sequential image generator.
    Takes previous image + text embedding as input, generates next image.
    Architecture:
    - Previous image: 3 channels → 64 channels at 1/4 of the input resolution
    - Text embedding: text_embed_dim → spatial features (128 x 8 x 8)
    - Fusion + Residual blocks + Upsampling (x4)
    - Output: RGB image at the input resolution (256x256 for 256x256 input)

    Args:
        text_embed_dim: Width of the incoming text embedding vectors.
        latent_dim: Channel count of the fused feature map.
        image_size: Nominal side length of the input images. The forward pass
            sizes the text features from the actual input, so inputs whose
            height and width are multiples of 4 are accepted.
    """

    def __init__(self, text_embed_dim=768, latent_dim=512, image_size=256):
        super().__init__()
        self.latent_dim = latent_dim
        self.image_size = image_size
        # Two stride-2 convolutions reduce the image to 1/4 of its side length.
        feature_size = image_size // 4

        # ========== Image Encoder (lightweight) ==========
        # Process previous image: 3 → 64 channels
        self.image_encoder = nn.Sequential(
            nn.Conv2d(3, 32, 7, stride=2, padding=3),  # H → H/2
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),  # H/2 → H/4
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )

        # ========== Text Projection ==========
        # Project text embedding to spatial features: text_embed_dim → (128 x 8 x 8)
        self.text_projection = nn.Sequential(
            nn.Linear(text_embed_dim, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 8 * 8 * 128)  # Spatial reshape
        )

        # ========== Fusion Module ==========
        # Text features are resized to the image feature map and reduced to 64
        # channels. The Sequential layout is kept so the parameter names
        # (text_upsample.1.*) stay compatible with existing checkpoints;
        # forward() resizes dynamically and then applies the 1x1 conv at index 1.
        self.text_upsample = nn.Sequential(
            nn.Upsample(size=(feature_size, feature_size), mode='nearest'),
            nn.Conv2d(128, 64, 1)
        )

        # Fused features: 64 (image) + 64 (text) = 128
        self.fusion = nn.Sequential(
            nn.Conv2d(128, latent_dim, 3, padding=1),
            nn.BatchNorm2d(latent_dim),
            nn.ReLU(inplace=True)
        )

        # ========== Generator (residual + upsampling) ==========
        self.residual_blocks = nn.Sequential(
            ResidualBlock(latent_dim, latent_dim),
            ResidualBlock(latent_dim, latent_dim),
            ResidualBlock(latent_dim, latent_dim)
        )

        # Upsample back to the input resolution (x4 overall)
        self.upsampler = nn.Sequential(
            # H/4 → H/2
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(latent_dim, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            # H/2 → H
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(256, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            # Final layer
            nn.Conv2d(128, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 3, 3, padding=1),
            nn.Tanh()  # Output [-1, 1]
        )

    def forward(self, prev_image, text_embedding):
        """
        Args:
            prev_image: (B, 3, H, W) image batch
            text_embedding: (B, text_embed_dim) text embedding
        Returns:
            generated_image: (B, 3, H, W) for H, W divisible by 4, with values
            in [-1, 1] (Tanh output)
        """
        # Encode previous image
        img_features = self.image_encoder(prev_image)  # (B, 64, H/4, W/4)

        # Project text embedding to spatial features
        text_spatial = self.text_projection(text_embedding)  # (B, 8*8*128)
        text_spatial = text_spatial.view(-1, 128, 8, 8)  # (B, 128, 8, 8)

        # Resize text features to the *actual* image feature map instead of a
        # hard-coded 64x64, then reduce channels with the 1x1 conv.
        text_resized = F.interpolate(
            text_spatial,
            size=tuple(img_features.shape[-2:]),
            mode='nearest'
        )
        text_features = self.text_upsample[1](text_resized)  # (B, 64, H/4, W/4)

        # Fuse image and text features
        fused = torch.cat([img_features, text_features], dim=1)  # (B, 128, H/4, W/4)
        fused = self.fusion(fused)  # (B, latent_dim, H/4, W/4)

        # Apply residual blocks
        features = self.residual_blocks(fused)  # (B, latent_dim, H/4, W/4)

        # Upsample to final resolution
        output = self.upsampler(features)  # (B, 3, H, W)

        return output


# Instantiate the generator from the config and move it to the selected device
model = SequentialImageGenerator(
    text_embed_dim=config.text_embed_dim,
    latent_dim=config.latent_dim,
    image_size=config.image_size,
)
model = model.to(DEVICE)

# Report the parameter count (all parameters, in millions)
total_params = sum(param.numel() for param in model.parameters())
print(f"Model Parameters: {total_params / 1e6:.2f}M")
print(f"Model created successfully!")

## 7. Load Encoders (CLIP)

In [None]:
# CLIP text encoder: used only to embed prompts, so it is kept in eval mode
# and none of its weights receive gradients.
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
text_encoder = text_encoder.to(DEVICE)
text_encoder.eval()
for frozen_param in text_encoder.parameters():
    frozen_param.requires_grad = False

print("✓ CLIP text encoder loaded (frozen)")
print(f"  Text embedding dimension: {config.text_embed_dim}")

## 8. Loss Function and Optimizer

In [None]:
# Loss function: weighted sum of L1 and MSE (no perceptual term is computed)
class CombinedLoss(nn.Module):
    """Pixel-space reconstruction loss: ``0.7 * L1 + 0.3 * MSE``."""

    def __init__(self):
        super().__init__()
        self.l1_loss = nn.L1Loss()
        self.mse_loss = nn.MSELoss()

    def forward(self, pred, target):
        """Return the scalar weighted loss between ``pred`` and ``target``."""
        # L1 term: mean absolute pixel error
        l1_term = 0.7 * self.l1_loss(pred, target)
        # MSE term: mean squared pixel error
        mse_term = 0.3 * self.mse_loss(pred, target)
        return l1_term + mse_term

# Reconstruction loss and optimizer over the generator's parameters only
criterion = CombinedLoss()
optimizer = Adam(
    model.parameters(),
    betas=(0.9, 0.999),
    lr=config.learning_rate,
    weight_decay=config.weight_decay,
)

print(f"Loss Function: Combined L1 + MSE")
print(f"Optimizer: Adam (lr={config.learning_rate}, weight_decay={config.weight_decay})")

## 9. Training Function

In [None]:
def train_epoch(model, text_encoder, train_loader, optimizer, criterion, device):
    """Train for one epoch.

    Args:
        model: Generator called as ``model(prev_images, text_embeddings)``.
        text_encoder: Text encoder called with ``input_ids`` and
            ``attention_mask``; run under ``torch.no_grad()``.
        train_loader: Iterable of batches with 'prev_image', 'text_input_ids',
            'text_attention_mask' and 'next_image' tensors.
        optimizer: Optimizer over the generator parameters.
        criterion: Loss called as ``criterion(generated, target)``.
        device: Device the batch tensors are moved to.

    Returns:
        Mean loss over the batches that completed successfully, or ``nan`` if
        no batch completed (so a fully failed epoch cannot look like loss 0).
    """
    model.train()
    total_loss = 0.0
    completed_batches = 0
    failed_batches = 0

    progress_bar = tqdm(train_loader, desc="Training", ncols=80)

    for batch in progress_bar:
        try:
            # Load batch
            prev_images = batch['prev_image'].to(device)
            text_input_ids = batch['text_input_ids'].to(device)
            text_attention_mask = batch['text_attention_mask'].to(device)
            next_images = batch['next_image'].to(device)

            # Get text embeddings from CLIP (no grad)
            # NOTE(review): position 0 is the start-of-text token; with CLIP's
            # causally masked text transformer it cannot attend to the prompt.
            # `pooler_output` is the model's sentence embedding — if adopted,
            # validate() and generate_next_image() must change with it.
            with torch.no_grad():
                text_embeddings = text_encoder(
                    input_ids=text_input_ids,
                    attention_mask=text_attention_mask
                ).last_hidden_state[:, 0]

            # Forward pass
            generated_images = model(prev_images, text_embeddings)

            # Compute loss
            loss = criterion(generated_images, next_images)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            batch_loss = loss.item()
            total_loss += batch_loss
            completed_batches += 1
            progress_bar.set_postfix({"loss": f"{batch_loss:.4f}"})

        except Exception as e:
            # Best effort: report the failure and move on to the next batch.
            failed_batches += 1
            print(f"\nError in batch: {e}")
            continue

    if failed_batches:
        print(f"\nSkipped {failed_batches} failed training batch(es)")

    # Average only over batches that contributed a loss value.
    if completed_batches == 0:
        return float('nan')
    return total_loss / completed_batches


def validate(model, text_encoder, val_loader, criterion, device):
    """Validate model.

    Runs the generator over ``val_loader`` in eval mode without gradients.

    Args:
        model: Generator called as ``model(prev_images, text_embeddings)``.
        text_encoder: Text encoder called with ``input_ids`` and
            ``attention_mask``.
        val_loader: Iterable of batches with 'prev_image', 'text_input_ids',
            'text_attention_mask' and 'next_image' tensors.
        criterion: Loss called as ``criterion(generated, target)``.
        device: Device the batch tensors are moved to.

    Returns:
        Mean loss over the batches that completed successfully, or ``nan`` if
        no batch completed.
    """
    model.eval()
    total_loss = 0.0
    completed_batches = 0
    failed_batches = 0

    with torch.no_grad():
        progress_bar = tqdm(val_loader, desc="Validating", ncols=80)

        for batch in progress_bar:
            try:
                # Load batch
                prev_images = batch['prev_image'].to(device)
                text_input_ids = batch['text_input_ids'].to(device)
                text_attention_mask = batch['text_attention_mask'].to(device)
                next_images = batch['next_image'].to(device)

                # Get text embeddings
                # NOTE(review): position 0 is the start-of-text token; see the
                # same indexing in train_epoch() — the two must stay in sync.
                text_embeddings = text_encoder(
                    input_ids=text_input_ids,
                    attention_mask=text_attention_mask
                ).last_hidden_state[:, 0]

                # Forward pass
                generated_images = model(prev_images, text_embeddings)

                # Compute loss
                batch_loss = criterion(generated_images, next_images).item()
                total_loss += batch_loss
                completed_batches += 1
                progress_bar.set_postfix({"loss": f"{batch_loss:.4f}"})

            except Exception as e:
                # Best effort, but never silent: a swallowed error here would
                # otherwise show up only as a suspiciously low loss.
                failed_batches += 1
                print(f"\nError in validation batch: {e}")
                continue

    if failed_batches:
        print(f"\nSkipped {failed_batches} failed validation batch(es)")

    # Average only over batches that contributed a loss value.
    if completed_batches == 0:
        return float('nan')
    return total_loss / completed_batches

# Confirms that the cell defining train_epoch() and validate() ran to completion.
print("Training functions defined")

## 10. Training Loop

In [None]:
# Per-epoch loss history and early-stopping state
train_losses, val_losses = [], []
best_val_loss = float('inf')
patience = 5          # epochs without validation improvement tolerated
patience_counter = 0  # consecutive epochs without improvement so far

print("Starting training...\n")
print(f"Epochs: {config.num_epochs}")
print(f"Steps per epoch: {len(train_loader)}")
print(f"Device: {DEVICE}\n")
print("="*60)

for epoch in range(config.num_epochs):
    epoch_number = epoch + 1
    print(f"\nEpoch {epoch_number}/{config.num_epochs}")
    print("-" * 60)

    # One optimisation pass over the training data
    train_loss = train_epoch(model, text_encoder, train_loader, optimizer, criterion, DEVICE)
    train_losses.append(train_loss)

    # Evaluation pass over the validation data
    val_loss = validate(model, text_encoder, val_loader, criterion, DEVICE)
    val_losses.append(val_loss)

    print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

    # Track the best validation loss; stop once it has stalled for `patience` epochs
    improved = val_loss < best_val_loss
    if improved:
        best_val_loss = val_loss
        patience_counter = 0

        # Keep the weights of the best epoch so far
        checkpoint_path = os.path.join(config.checkpoints_dir, "best_model.pt")
        torch.save(model.state_dict(), checkpoint_path)
        print(f"✓ Best model saved (val_loss: {val_loss:.4f})")
    else:
        patience_counter += 1

    if not improved and patience_counter >= patience:
        print(f"Early stopping at epoch {epoch_number}")
        break

    # Periodic snapshot every 5 epochs (skipped on the epoch that stops early)
    if epoch_number % 5 == 0:
        checkpoint_path = os.path.join(config.checkpoints_dir, f"checkpoint_epoch_{epoch_number}.pt")
        torch.save(model.state_dict(), checkpoint_path)
        print(f"✓ Checkpoint saved")

print("\n" + "="*60)
print("Training completed!")
print(f"Final train loss: {train_losses[-1]:.4f}")
print(f"Final val loss: {val_losses[-1]:.4f}")

## 11. Plot Training History

In [None]:
# Plot both loss curves on one axis and save the figure next to the other results
loss_plot_path = os.path.join(config.results_dir, 'training_loss.png')

fig, ax = plt.subplots(figsize=(10, 6))
for history, curve_label, curve_marker in (
    (train_losses, 'Train Loss', 'o'),
    (val_losses, 'Val Loss', 's'),
):
    ax.plot(history, label=curve_label, marker=curve_marker, linewidth=2, markersize=6)

ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Loss', fontsize=12)
ax.set_title('Sequential Image Generator - Training History', fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(loss_plot_path, dpi=150, bbox_inches='tight')
plt.show()

print(f"Loss plot saved to: {loss_plot_path}")

## 12. Load Best Model and Generate Samples

In [None]:
# Restore the best checkpoint when one was written; otherwise keep the current weights
best_model_path = os.path.join(config.checkpoints_dir, "best_model.pt")
if not os.path.exists(best_model_path):
    print(f"No best model found at {best_model_path}")
else:
    model.load_state_dict(torch.load(best_model_path, map_location=DEVICE))
    print(f"✓ Best model loaded")

# Inference mode in either case
model.eval()
print("Model ready for inference")

## 13. Test Inference on Validation Stories

In [None]:
def denormalize(tensor):
    """Map a [-1, 1] or [0, 1] float tensor to uint8 pixel values in [0, 255].

    The input range is detected from the data: any negative value means the
    tensor is treated as [-1, 1] and shifted to [0, 1] first. Values are
    clamped to [0, 1], scaled by 255 and truncated to bytes.
    """
    has_negative = bool(tensor.min() < 0)
    unit_range = (tensor + 1) / 2 if has_negative else tensor
    return (unit_range.clamp(0, 1) * 255).byte()


def generate_next_image(prev_image_path, text, model, text_encoder, device, image_size=256):
    """
    Generate the next image from a previous image file and a text prompt.

    Reads the module-level ``tokenizer``. The previous image is resized to
    ``image_size`` and ImageNet-normalised; the model output is converted back
    with ``denormalize`` and returned as a PIL image.
    """
    model.eval()

    # Preprocess the conditioning image into a single-item batch
    preprocess = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.CenterCrop((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])
    source = Image.open(prev_image_path).convert('RGB')
    img_batch = preprocess(source).unsqueeze(0).to(device)

    # Tokenize the prompt, padded/truncated to 77 tokens
    encoded = tokenizer(
        text,
        padding='max_length',
        max_length=77,
        truncation=True,
        return_tensors='pt'
    )
    input_ids = encoded['input_ids'].to(device)
    attention_mask = encoded['attention_mask'].to(device)

    with torch.no_grad():
        # NOTE(review): position 0 is the start-of-text token; this indexing
        # must stay identical to the one used in train_epoch()/validate().
        encoder_out = text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        text_embedding = encoder_out.last_hidden_state[:, 0]

        generated = model(img_batch, text_embedding)

    # Drop the batch axis, move to CPU and convert to a PIL image
    pixels = denormalize(generated.squeeze(0).cpu())
    return transforms.ToPILImage()(pixels)


# Qualitative check: previous / generated / ground-truth panels for up to three validation pairs
num_demo_pairs = min(3, len(val_pairs))
print(f"Testing inference on {num_demo_pairs} validation pairs...\n")

demo_size = (config.image_size, config.image_size)

for idx, pair in enumerate(val_pairs[:num_demo_pairs]):
    print(f"\nPair {idx + 1}:")
    print(f"  Text: {pair['text'][:60]}...")

    try:
        generated_img = generate_next_image(
            pair['prev_image'],
            pair['text'],
            model,
            text_encoder,
            DEVICE,
            config.image_size
        )

        # Reference images, resized to the generator's output size for display
        ground_truth = Image.open(pair['next_image']).convert('RGB').resize(demo_size)

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))

        prev_img = Image.open(pair['prev_image']).convert('RGB').resize(demo_size)

        # (image, title, extra title styling) for each panel, left to right
        panels = (
            (prev_img, 'Previous Image', {}),
            (generated_img, 'Generated Next Image', {'color': 'green'}),
            (ground_truth, 'Ground Truth Next Image', {'color': 'blue'}),
        )
        for axis, (panel_img, panel_title, title_style) in zip(axes, panels):
            axis.imshow(panel_img)
            axis.set_title(panel_title, fontsize=10, **title_style)
            axis.axis('off')

        plt.tight_layout()
        sample_path = os.path.join(config.results_dir, f'inference_sample_{idx + 1}.png')
        plt.savefig(sample_path, dpi=100, bbox_inches='tight')
        plt.show()
        print(f"  ✓ Sample saved")

    except Exception as e:
        # Report and continue with the next pair
        print(f"  ✗ Error: {e}")

print(f"\nInference samples saved to: {config.results_dir}")

## 14. Generate Full Story Sequence

In [None]:
def generate_story_sequence(story_id, df, model, text_encoder, device, images_dir, image_size=256):
    """
    Generate a full sequence of images for a story.
    Autoregressively: use generated image as input for next generation.

    The first image of the story is the real one from ``images_dir``; every
    later image is generated from the previous output and that row's
    ``storytext``. Generation stops at the first failure, so the returned
    sequence may be shorter than the story.

    Args:
        story_id: Value matched against the ``story_id`` column of ``df``.
        df: DataFrame with ``story_id``, ``image_order``, ``youtube_image_id``
            and ``storytext`` columns.
        model: Generator passed through to ``generate_next_image``.
        text_encoder: Text encoder passed through to ``generate_next_image``.
        device: Device passed through to ``generate_next_image``.
        images_dir: Directory holding ``<youtube_image_id>.jpg`` files.
        image_size: Side length used for resizing and generation.

    Returns:
        ``None`` when the story has fewer than 2 rows; otherwise a tuple
        ``(generated_sequence, story_data)`` of the PIL images and the story's
        rows sorted by ``image_order``.
    """
    import tempfile

    story_data = df[df['story_id'] == story_id].sort_values('image_order')

    if len(story_data) < 2:
        print(f"Story {story_id} has less than 2 images")
        return None

    model.eval()

    # First image (use real)
    first_img_path = os.path.join(images_dir, f"{story_data.iloc[0]['youtube_image_id']}.jpg")
    first_img = Image.open(first_img_path).convert('RGB')
    first_img = first_img.resize((image_size, image_size))
    generated_sequence = [first_img]

    current_img_path = first_img_path

    # generate_next_image() takes a file path, so each output is written to a
    # scratch file before being fed back. A private temporary directory avoids
    # leaving a stray file in the results directory, and PNG is lossless, so
    # compression artefacts do not accumulate from step to step.
    with tempfile.TemporaryDirectory() as scratch_dir, torch.no_grad():
        temp_path = os.path.join(scratch_dir, 'temp_gen.png')

        # Generate remaining images
        for i in range(1, len(story_data)):
            text = story_data.iloc[i]['storytext']

            try:
                generated_img = generate_next_image(
                    current_img_path,
                    text,
                    model,
                    text_encoder,
                    device,
                    image_size
                )
                generated_sequence.append(generated_img)

                # Save temporary for next iteration
                generated_img.save(temp_path)
                current_img_path = temp_path

            except Exception as e:
                print(f"Error generating image {i}: {e}")
                break

    return generated_sequence, story_data


# Generate and display one full story sequence from the validation split
if len(val_pairs) > 0:
    # Get a story_id from validation pairs
    story_id = val_pairs[0]['story_id']

    print(f"Generating story sequence for story_id: {story_id}\n")

    result = generate_story_sequence(
        story_id,
        df_val,
        model,
        text_encoder,
        DEVICE,
        images_dir,
        config.image_size
    )

    # generate_story_sequence returns None for stories with fewer than 2 images
    if result:
        generated_sequence, story_data = result

        # Display sequence
        num_images = len(generated_sequence)
        fig, axes = plt.subplots(1, num_images, figsize=(4*num_images, 4))

        # plt.subplots returns a bare Axes (not an array) for a single panel
        if num_images == 1:
            axes = [axes]

        for idx, img in enumerate(generated_sequence):
            axes[idx].imshow(img)
            axes[idx].set_title(f"Image {idx + 1}", fontsize=10)
            axes[idx].axis('off')

        story_plot_path = os.path.join(config.results_dir, 'story_sequence.png')
        plt.tight_layout()
        plt.savefig(story_plot_path, dpi=100, bbox_inches='tight')
        plt.show()

        print(f"\nGenerated {num_images} images for story")
        print("Full Story:")
        for _, row in story_data.iterrows():
            print(f"{row['image_order']}. {row['storytext']}")

        print(f"\nSequence saved to: {story_plot_path}")

## 15. Summary and Results

In [None]:
# Final run summary. Loss statistics are guarded so that an empty history or a
# zero initial loss produces a message instead of IndexError/ZeroDivisionError.
param_count_m = sum(p.numel() for p in model.parameters()) / 1e6

print("="*60)
print("SEQUENTIAL IMAGE GENERATION - TRAINING SUMMARY")
print("="*60)
print(f"\nModel Architecture: Lightweight Sequential Generator")
print(f"  Total Parameters: {param_count_m:.2f}M")
print(f"\nDataset:")
print(f"  Training pairs: {len(train_pairs)}")
print(f"  Validation pairs: {len(val_pairs)}")
print(f"  Unique stories (train): {df_train['story_id'].nunique()}")
print(f"  Unique stories (val): {df_val['story_id'].nunique()}")
print(f"\nTraining Configuration:")
print(f"  Batch size: {config.batch_size}")
print(f"  Epochs: {config.num_epochs}")
print(f"  Learning rate: {config.learning_rate}")
print(f"  Image size: {config.image_size}x{config.image_size}")
print(f"\nResults:")
if train_losses:
    initial_train_loss = train_losses[0]
    final_train_loss = train_losses[-1]
    print(f"  Initial train loss: {initial_train_loss:.4f}")
    print(f"  Final train loss: {final_train_loss:.4f}")
    if val_losses:
        print(f"  Best validation loss: {min(val_losses):.4f}")
    if initial_train_loss != 0:
        print(f"  Loss reduction: {(1 - final_train_loss/initial_train_loss)*100:.1f}%")
    else:
        print("  Loss reduction: n/a (initial train loss is 0)")
else:
    print("  No training history recorded")
print(f"\nCheckpoints saved in: {config.checkpoints_dir}")
print(f"Results saved in: {config.results_dir}")
print("\n" + "="*60)
print("\nKey Features:")
# Report the measured size rather than a hard-coded estimate.
print(f"  ✓ Lightweight model ({param_count_m:.1f}M parameters)")
print("  ✓ Efficient training on moderate GPUs")
print("  ✓ Takes previous image + text as input")
print("  ✓ Generates contextually coherent next images")
print("  ✓ Autoregressive inference for full story generation")
print("  ✓ Quality images at 256x256 resolution")