# Wan 2.1 Video Generation on Kaggle

This notebook generates videos from Mahabharata stories using the Wan 2.1 video generation model.

**Workflow:**
1. Create Vector Database from PDF (first time only)
2. Retrieve passages from vector DB
3. Generate story scripts (~5 scenes) and narration
4. Generate videos from scripts

## Setup Instructions

1. **Add GitHub as Data Source**:
   - Click "Add Data" → "GitHub"
   - Enter your repository URL: `https://github.com/VibhavVPandit/Mahabharata`

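2. **Add the PDF**:
   - Make `Mahabharata.pdf` available under `/kaggle/input/`, e.g. by adding it as a Kaggle dataset (`/kaggle/input/` is read-only; dataset files appear under `/kaggle/input/<dataset-name>/`).
   - The notebook looks for `/kaggle/input/Mahabharata.pdf` and `/kaggle/input/data/raw/Mahabharata.pdf`; if your file is elsewhere, set `PDF_PATH` in section 2.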

3. **Configure Notebook Settings** (in sidebar):
   - **Accelerator**: GPU T4 x2
   - **Language**: Python 3
   - **Internet**: On
   - **Persistence**: Files Only (optional)

4. **Set API Key**:
   - Go to Settings → Secrets → Add Secret
   - Name: `GEMINI_API_KEY`
   - Value: Your Gemini API key


## 1. Setup and Installation


In [None]:
# Locate the repository (a GitHub data source mounted under /kaggle/input/, or a
# manual clone in /kaggle/working/repo) and put its `kaggle/` directory on sys.path.
import os
import sys
from pathlib import Path


def find_repo_path(working_repo: Path, input_dir: Path) -> Path:
    """Return the repository root to import the pipeline modules from.

    An existing checkout at ``working_repo`` wins. Otherwise the immediate
    subdirectories of ``input_dir`` (where Kaggle mounts data sources) are
    searched, in sorted order, for one containing a ``kaggle`` directory.
    If nothing matches, ``working_repo`` is returned so it can be cloned into.

    (The previous inline version had the branches inverted: it only searched
    ``input_dir`` when a clone already existed, so a repository added as a
    data source was never found.)
    """
    if working_repo.exists():
        return working_repo
    if input_dir.is_dir():  # guard: iterdir() raises if the directory is missing
        for item in sorted(input_dir.iterdir()):
            if item.is_dir() and (item / "kaggle").exists():
                return item
    # Not found anywhere: clone it manually, e.g.
    # !git clone <YOUR_REPO_URL> /kaggle/working/repo
    return working_repo


repo_path = find_repo_path(Path("/kaggle/working/repo"), Path("/kaggle/input"))
print(f"Repository path: {repo_path}")

# Add kaggle directory to path
kaggle_dir = repo_path / "kaggle"
if not kaggle_dir.exists():
    # Try alternative location
    kaggle_dir = Path("/kaggle/input") / "kaggle"
    if not kaggle_dir.exists():
        kaggle_dir = Path("/kaggle/working") / "kaggle"

sys.path.insert(0, str(kaggle_dir.parent))
sys.path.insert(0, str(kaggle_dir))
print(f"Kaggle directory: {kaggle_dir}")


In [None]:
# Install dependencies
!pip install -q diffusers>=0.21.0 transformers>=4.30.0 accelerate>=0.20.0
!pip install -q imageio>=2.31.0 imageio-ffmpeg>=0.4.8 sentencepiece
!pip install -q chromadb>=0.4.15 sentence-transformers>=2.2.0
!pip install -q google-generativeai>=0.3.0 pydantic>=2.0.0 pyyaml>=6.0.1 numpy>=1.24.0
!pip install -q pypdf>=3.0.0

print("✓ Dependencies installed")


In [None]:
# Point the Hugging Face caches at /kaggle/tmp/huggingface.
# All three variables get the same directory. They are set here, before later
# cells import the pipeline modules, because the Hugging Face libraries
# typically read them when they are imported.
import os
from pathlib import Path

cache_dir = Path("/kaggle/tmp") / "huggingface"
cache_dir.mkdir(parents=True, exist_ok=True)

for env_var in ('HF_HOME', 'TRANSFORMERS_CACHE', 'HF_DATASETS_CACHE'):
    os.environ[env_var] = str(cache_dir)

print(f"✓ Hugging Face cache set to: {cache_dir}")


## 2. Configuration


In [None]:
# Run configuration
# Topic for targeted retrieval (e.g. "Arjuna", "Kurukshetra"); None picks a random passage.
QUERY = None

# Source PDF; edit this if your copy lives elsewhere under /kaggle/input/.
PDF_PATH = "/kaggle/input/Mahabharata.pdf"

print(f"Query: {QUERY or 'Random (will retrieve random passage)'}")
print(f"PDF Path: {PDF_PATH}")


## 3. Create Vector Database from PDF

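**Run this section first** to create the vector database from the Mahabharata PDF. If the database already contains passages, PDF processing is skipped; delete the `vector_db` directory to rebuild it.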


In [None]:
# Create Vector Database from PDF
import sys
import re
import unicodedata
import numpy as np
from pathlib import Path
from dataclasses import dataclass, field
from typing import List, Optional, Set, Tuple
from collections import Counter

sys.path.insert(0, str(kaggle_dir))
from story_pipeline import KaggleVectorStore, KaggleEmbedder, Passage

# ============================================================================
# TEXT CLEANING FUNCTIONS (Preserves Paragraph Structure)
# ============================================================================

def normalize_unicode(text: str) -> str:
    """Return *text* in Unicode normalization form C (canonical composition).

    A base letter followed by a combining mark (e.g. "a" + U+0304) becomes the
    single precomposed code point ("ā") where one exists, so later regexes and
    character filters see one character instead of two.
    """
    form = 'NFC'
    return unicodedata.normalize(form, text)

# A line consisting only of a page number (1-4 digits), "Page N", or a
# "Table of Contents" heading.
_PDF_ARTIFACT_LINE = re.compile(
    r'^\s*(?:\d{1,4}|Page\s+\d+|Table of Contents)\s*$',
    flags=re.IGNORECASE,
)


def remove_pdf_artifacts(text: str, repeat_threshold: int = 10) -> str:
    """
    Remove common PDF artifacts line by line:
    - lines that contain only a page number (1-4 digits) or "Page N"
    - "Table of Contents" heading lines
    - repeated headers/footers: non-blank lines whose stripped text occurs
      more than ``repeat_threshold`` times

    Artifact lines are dropped entirely. Previously, page-number lines were
    blanked, which left an empty line behind. clean_text() turns empty lines
    into paragraph breaks, so every page number split the sentence around it
    into two "paragraphs". Genuine blank lines are kept unchanged.

    Args:
        text: Raw text with newline-separated lines.
        repeat_threshold: Maximum number of occurrences a line may have before
            it is treated as a header/footer and removed.

    Returns:
        The text with artifact lines removed.
    """
    lines = [line for line in text.split('\n') if not _PDF_ARTIFACT_LINE.match(line)]

    # Count stripped-line frequencies to detect repeated headers/footers
    line_counts = Counter(line.strip() for line in lines if line.strip())

    return '\n'.join(
        line for line in lines
        if not line.strip() or line_counts[line.strip()] <= repeat_threshold
    )

def fix_hyphenation(text: str) -> str:
    """
    Re-join words that the PDF broke across lines with a trailing hyphen.

    The hyphen and the whitespace around the line break are removed: "Ar-" at
    the end of one line followed by "juna" on the next becomes "Arjuna".
    Because the whitespace before and after the break may itself include
    newlines, a hyphen followed by blank lines is joined as well.
    """
    line_break_hyphen = re.compile(r'(\w+)-\s*\n\s*(\w+)')
    return line_break_hyphen.sub(r'\1\2', text)

def clean_text(text: str) -> str:
    """
    Clean raw PDF text while PRESERVING paragraph structure.

    Blank-line paragraph breaks survive as a single blank line; every other
    line break is collapsed into a space.

    Pipeline: Unicode NFC -> PDF artifact removal -> de-hyphenation ->
    whitespace normalization -> typographic quote normalization -> removal of
    control and unsupported characters -> final space/newline cleanup.

    Quote normalization now runs BEFORE the character filter. The filter's
    allow-list contains only straight quotes, so curly quotes that reached it
    were replaced by spaces (a right single quote in "Arjuna's" became a
    space). The old normalization line was also a no-op: its curly quotes had
    been mangled into straight ones.

    Args:
        text: Raw text extracted from the PDF.

    Returns:
        Cleaned text with paragraphs separated by exactly one blank line.
    """
    print("  Step 1: Normalizing Unicode...")
    text = normalize_unicode(text)

    print("  Step 2: Removing PDF artifacts...")
    text = remove_pdf_artifacts(text)

    print("  Step 3: Fixing hyphenation...")
    text = fix_hyphenation(text)

    print("  Step 4: Normalizing whitespace (preserving paragraphs)...")
    # Mark paragraph breaks before collapsing whitespace
    text = re.sub(r'\n\s*\n+', '\n\n<PARA_BREAK>\n\n', text)
    # Collapse runs of spaces/tabs, then turn remaining single newlines into spaces
    text = re.sub(r'[ \t]+', ' ', text)
    text = re.sub(r' ?\n ?', ' ', text)
    # Restore paragraph breaks
    text = text.replace('<PARA_BREAK>', '\n\n')

    print("  Step 5: Normalizing quotes...")
    # Escaped code points so this mapping cannot be mangled by re-encoding the file:
    # left/right double quotes -> '"', left/right single quotes -> "'"
    quote_map = str.maketrans({
        '\u201c': '"',
        '\u201d': '"',
        '\u2018': "'",
        '\u2019': "'",
    })
    text = text.translate(quote_map)

    print("  Step 6: Removing problematic characters...")
    # Remove control characters (tab \x09 and newline \x0A are outside these ranges)
    text = re.sub(r'[\x00-\x08\x0B-\x1F\x7F-\x9F]', '', text)
    # Replace anything but word characters, whitespace and common punctuation with a space
    text = re.sub(r'[^\w\s.,!?;:\-\(\)\[\]\'"—–\n]', ' ', text)

    # Final cleanup
    text = re.sub(r' +', ' ', text)
    text = re.sub(r' *\n *', '\n', text)  # drop spaces hugging paragraph breaks
    text = re.sub(r'\n{3,}', '\n\n', text)  # Max 2 newlines
    return text.strip()

# ============================================================================
# SENTENCE TOKENIZATION
# ============================================================================

# Abbreviations after which a period does not end a sentence (case-sensitive).
_SENTENCE_ABBREVIATIONS = frozenset({
    'Mr', 'Mrs', 'Ms', 'Dr', 'Prof', 'Sr', 'Jr', 'vs', 'etc', 'viz', 'al',
    'eg', 'ie', 'cf', 'Vol', 'Ch', 'Sec', 'Fig', 'No', 'ca', 'Inc', 'Ltd',
    'Corp', 'St', 'Ave', 'Blvd', 'Rd',
})
# Candidate boundary: whitespace after ".", "!" or "?" that precedes a capital letter.
_SENTENCE_BOUNDARY = re.compile(r'(?<=[.!?])\s+(?=[A-Z])')
# The word directly before a trailing period, e.g. "Dr" in "He met Dr."
_WORD_BEFORE_FINAL_PERIOD = re.compile(r'(\w+)\.$')


def split_into_sentences(text: str) -> List[str]:
    """
    Split text into sentences using a regex heuristic.

    A boundary is whitespace that follows ".", "!" or "?" and precedes an
    ASCII capital letter. Boundaries right after a known abbreviation
    ("Dr.", "etc.", "No.", ...) are skipped, so the abbreviation stays with
    the text that follows it.

    The previous version put the abbreviation alternation inside a negative
    look-behind. Python's ``re`` rejects variable-width look-behinds, so every
    call raised ``re.error``. Even if it had compiled, it would never have
    matched, because the look-behind ended before the period.

    Args:
        text: Text to split.

    Returns:
        Stripped, non-empty sentences in order. The whitespace between
        sentences is discarded.
    """
    pieces = []
    start = 0
    for boundary in _SENTENCE_BOUNDARY.finditer(text):
        candidate = text[start:boundary.start()]
        tail = _WORD_BEFORE_FINAL_PERIOD.search(candidate)
        if tail and tail.group(1) in _SENTENCE_ABBREVIATIONS:
            continue  # e.g. "Dr. Rao": not a sentence end
        pieces.append(candidate)
        start = boundary.end()
    pieces.append(text[start:])
    return [s.strip() for s in pieces if s.strip()]

# ============================================================================
# HYBRID SEMANTIC + PARAGRAPH-AWARE CHUNKING
# ============================================================================

@dataclass
class TextChunk:
    """A chunk of source text plus the metadata later copied into a Passage for the vector DB."""
    text: str  # chunk contents
    chunk_id: str  # sequential id such as "chunk_0042" (assigned by create_text_chunks)
    characters: List[str]  # character names detected in `text`
    themes: List[str]  # theme keywords detected in `text`
    chapter: Optional[str] = None  # create_text_chunks always leaves this None
    section: Optional[str] = None  # create_text_chunks always leaves this None
    embedding: Optional[np.ndarray] = field(default=None, repr=False)  # not set in this notebook (embeddings go to Passage directly); hidden from repr

# Chunking configuration
CHUNK_CONFIG = {
    'target_size': 400,      # Size paragraph grouping aims for; chunks over 1.5x this are checked for topic shifts
    'min_size': 100,         # Smallest chunk emitted on its own; also the smallest half a topic-shift split may create
    'max_size': 800,         # Upper bound for recursive splitting and for semantic merges
    'overlap_sentences': 1,  # Trailing sentences of the previous chunk prepended to the next one
    'semantic_threshold': 0.75,  # Cosine similarity above which adjacent chunks merge; a topic shift needs similarity below (this - 0.1)
}

def cosine_similarity(vec1: np.ndarray, vec2: np.ndarray) -> float:
    """
    Return the cosine of the angle between ``vec1`` and ``vec2``.

    Returns 0.0 instead of dividing by zero when either vector is None or has
    zero norm.
    """
    if vec1 is None or vec2 is None:
        return 0.0
    magnitudes = (np.linalg.norm(vec1), np.linalg.norm(vec2))
    if not all(magnitudes):
        return 0.0
    return np.dot(vec1, vec2) / (magnitudes[0] * magnitudes[1])

def split_into_paragraphs(text: str) -> List[str]:
    """Return the non-empty, whitespace-stripped paragraphs of *text*.

    Paragraphs are separated by two or more consecutive newlines; a single
    newline stays inside its paragraph.
    """
    paragraphs = []
    for piece in re.split(r'\n\n+', text):
        piece = piece.strip()
        if piece:
            paragraphs.append(piece)
    return paragraphs

def recursive_split(text: str, max_size: int, separators: List[str] = None) -> List[str]:
    """
    Recursively split text into pieces of at most ``max_size`` characters.

    Separators are tried coarsest first (paragraph break, sentence ends, comma,
    space). The text is cut at the first separator it contains, and the parts
    are packed greedily up to ``max_size``. Any packed piece that is still too
    long is split again with the remaining, finer separators. When no
    separator applies, the text is hard-split every ``max_size`` characters.

    Every returned piece therefore fits in ``max_size``. Previously, a piece
    still too long after the last separator (e.g. one very long word) was
    returned oversized, and whitespace-only packs produced empty pieces.

    Args:
        text: Text to split.
        max_size: Maximum piece length in characters.
        separators: Separators to try, coarsest first. None selects the
            default list (paragraph break, ". ", "! ", "? ", ", ", " ").

    Returns:
        Pieces in order. Packed pieces are stripped and non-empty;
        hard-split fragments are returned verbatim.
    """
    if separators is None:
        separators = ['\n\n', '. ', '! ', '? ', ', ', ' ']

    if len(text) <= max_size:
        return [text]

    # Try each separator in order
    for depth, sep in enumerate(separators):
        if sep not in text:
            continue

        parts = text.split(sep)
        packed = []
        current = ""
        for i, part in enumerate(parts):
            # Re-attach the separator to every part except the last
            part_with_sep = part + sep if i < len(parts) - 1 else part
            if len(current) + len(part_with_sep) <= max_size:
                current += part_with_sep
            else:
                if current.strip():
                    packed.append(current.strip())
                current = part_with_sep
        if current.strip():
            packed.append(current.strip())

        # Still-too-long pieces go to the finer separators. With none left, the
        # recursive call falls through to the hard split below.
        finer = separators[depth + 1:]
        result = []
        for piece in packed:
            if len(piece) > max_size:
                result.extend(recursive_split(piece, max_size, finer))
            else:
                result.append(piece)
        return result

    # Fallback: hard split by character count
    return [text[i:i + max_size] for i in range(0, len(text), max_size)]

def paragraph_aware_chunking(text: str, config: dict = None) -> List[str]:
    """
    Stage 1: group paragraphs into chunks of roughly ``target_size`` characters.

    Paragraphs (blank-line separated) are accumulated, joined by a blank line,
    while the running chunk stays within ``target_size``. A paragraph longer
    than ``max_size`` is emitted as its own recursive_split() pieces, after
    flushing the running chunk.

    When the next paragraph does not fit, the running chunk is emitted if it
    has at least ``min_size`` characters. Otherwise the paragraph is appended
    anyway. If that overshoots ``max_size``, the combination is recursively
    split, and its last piece is carried forward.

    Args:
        text: Cleaned text whose paragraphs are separated by blank lines.
        config: Chunking settings; None means CHUNK_CONFIG.

    Returns:
        Chunk strings in document order.
    """
    cfg = CHUNK_CONFIG if config is None else config
    target, floor, ceiling = cfg['target_size'], cfg['min_size'], cfg['max_size']

    def _append(buffer: str, para: str) -> str:
        # Join with a paragraph break unless the buffer is still empty
        return f"{buffer}\n\n{para}" if buffer else para

    out: List[str] = []
    buffer = ""
    for para in split_into_paragraphs(text):
        if len(para) > ceiling:
            if buffer:
                out.append(buffer.strip())
                buffer = ""
            out.extend(recursive_split(para, ceiling))
            continue

        if len(buffer) + len(para) + 2 <= target:
            buffer = _append(buffer, para)
            continue

        if len(buffer) >= floor:
            out.append(buffer.strip())
            buffer = para
            continue

        # Buffer too small to stand alone: absorb the paragraph even past target
        buffer = _append(buffer, para)
        if len(buffer) > ceiling:
            pieces = recursive_split(buffer, ceiling)
            out.extend(pieces[:-1])
            buffer = pieces[-1] if pieces else ""

    tail = buffer.strip()
    if tail:
        out.append(tail)
    return out

def _merge_similar_adjacent(
    chunks: List[str],
    embeddings,
    embedder,
    threshold: float,
    max_size: int,
) -> List[str]:
    """Greedily merge runs of adjacent chunks whose embeddings are similar."""
    merged: List[str] = []
    i = 0
    while i < len(chunks):
        current_chunk = chunks[i]
        current_emb = embeddings[i]
        # Absorb following chunks while they are similar to the (growing)
        # current chunk and the merged text still fits in max_size.
        while i + 1 < len(chunks):
            next_chunk = chunks[i + 1]
            combined_len = len(current_chunk) + len(next_chunk) + 2  # +2 for "\n\n"
            if combined_len > max_size:
                break
            if not cosine_similarity(current_emb, embeddings[i + 1]) > threshold:
                break
            current_chunk = current_chunk + "\n\n" + next_chunk
            # Re-embed so the next comparison uses the merged content
            current_emb = embedder.embed_single(current_chunk)
            i += 1
        merged.append(current_chunk)
        i += 1
    return merged


def _split_at_topic_shift(
    chunk: str,
    embedder,
    threshold: float,
    min_size: int,
    window_size: int = 2,
) -> List[str]:
    """Split ``chunk`` in two at its sharpest topic shift, or return it whole."""
    sentences = split_into_sentences(chunk)
    if len(sentences) < 2 * window_size:  # need two full windows to compare
        return [chunk]

    # Embed each window of `window_size` consecutive sentences exactly once, in
    # one batch; windows k and k + window_size are adjacent and non-overlapping.
    windows = [
        ' '.join(sentences[k:k + window_size])
        for k in range(len(sentences) - window_size + 1)
    ]
    window_embs = embedder.embed(windows)

    # (index of the first sentence after the boundary, similarity across it)
    shifts = [
        (j + window_size, cosine_similarity(window_embs[j], window_embs[j + window_size]))
        for j in range(len(sentences) - 2 * window_size + 1)
    ]
    split_point, min_sim = min(shifts, key=lambda s: s[1])

    # Only split on a significant drop in similarity
    if not min_sim < threshold - 0.1:
        return [chunk]
    head = ' '.join(sentences[:split_point])
    tail = ' '.join(sentences[split_point:])
    if len(head) < min_size or len(tail) < min_size:
        return [chunk]
    return [head, tail]


def semantic_chunk_refinement(
    chunks: List[str],
    embedder,
    config: dict = None
) -> List[str]:
    """
    Stage 2: Semantic refinement using embeddings.

    1. Embed all chunks in one batch.
    2. Merge adjacent chunks while their cosine similarity exceeds
       ``semantic_threshold`` and the merged text fits in ``max_size``.
    3. For merged chunks longer than 1.5 * ``target_size``, compare
       consecutive 2-sentence windows. Split once at the lowest similarity
       if it is below ``semantic_threshold - 0.1`` and both halves have at
       least ``min_size`` characters. Split halves are re-joined with single
       spaces, so paragraph breaks inside them are not kept.

    Window embeddings are computed with one batched ``embedder.embed`` call
    per chunk instead of two ``embed_single`` calls per window position.

    Args:
        chunks: Chunks from paragraph_aware_chunking().
        embedder: Object providing ``embed(list_of_texts)`` and
            ``embed_single(text)``.
        config: Chunking settings; None means CHUNK_CONFIG.

    Returns:
        The refined chunks in document order.
    """
    if config is None:
        config = CHUNK_CONFIG

    if not chunks:
        return chunks

    threshold = config['semantic_threshold']
    min_size = config['min_size']
    max_size = config['max_size']

    print(f"    Computing embeddings for {len(chunks)} chunks...")
    embeddings = embedder.embed(chunks)

    # Stage 2a: Merge semantically similar adjacent chunks
    print("    Merging semantically similar chunks...")
    merged_chunks = _merge_similar_adjacent(chunks, embeddings, embedder, threshold, max_size)
    print(f"    After merging: {len(merged_chunks)} chunks")

    # Stage 2b: Split chunks with internal topic shifts
    print("    Checking for internal topic shifts...")
    split_candidate_len = config['target_size'] * 1.5
    final_chunks: List[str] = []
    for chunk in merged_chunks:
        if len(chunk) > split_candidate_len:
            final_chunks.extend(_split_at_topic_shift(chunk, embedder, threshold, min_size))
        else:
            final_chunks.append(chunk)

    print(f"    After topic-shift splitting: {len(final_chunks)} chunks")
    return final_chunks

def add_sentence_overlap(chunks: List[str], overlap_sentences: int = 1) -> List[str]:
    """
    Prefix every chunk after the first with the last ``overlap_sentences``
    sentences of the chunk before it, so that neighbouring chunks share context.

    Sentences are taken from the original (un-prefixed) previous chunk.
    ``chunks`` is returned as-is when ``overlap_sentences <= 0`` or there is
    at most one chunk.
    """
    if overlap_sentences <= 0 or len(chunks) <= 1:
        return chunks

    result = [chunks[0]]
    for previous, current in zip(chunks, chunks[1:]):
        overlap = ' '.join(split_into_sentences(previous)[-overlap_sentences:])
        result.append(f"{overlap} {current}" if overlap else current)
    return result

def hybrid_semantic_chunking(
    text: str, 
    embedder,
    config: dict = None,
    add_overlap: bool = True
) -> List[str]:
    """
    Chunk cleaned text in up to three stages:

    1. paragraph_aware_chunking(): group/split paragraphs by size;
    2. semantic_chunk_refinement(): merge similar neighbours, split at topic shifts;
    3. add_sentence_overlap() (only when ``add_overlap``): share boundary sentences.

    ``config`` defaults to CHUNK_CONFIG. Returns the final list of chunk strings.
    """
    cfg = config if config is not None else CHUNK_CONFIG

    print("  Stage 1: Paragraph-aware recursive chunking...")
    pieces = paragraph_aware_chunking(text, cfg)
    print(f"    Created {len(pieces)} initial chunks")

    print("  Stage 2: Semantic refinement...")
    pieces = semantic_chunk_refinement(pieces, embedder, cfg)

    if not add_overlap:
        return pieces

    print("  Stage 3: Adding sentence overlap...")
    pieces = add_sentence_overlap(pieces, cfg['overlap_sentences'])
    print(f"    Final chunks with overlap: {len(pieces)}")
    return pieces

# ============================================================================
# METADATA EXTRACTION
# ============================================================================

# Character names searched for by extract_characters(); results keep this order.
MAHABHARAT_CHARACTERS = [
    "Krishna", "Arjuna", "Yudhishthira", "Bhima", "Nakula", "Sahadeva",
    "Draupadi", "Duryodhana", "Dushasana", "Karna", "Drona", "Bhishma",
    "Kunti", "Pandu", "Dhritarashtra", "Gandhari", "Vidura", "Shakuni",
    "Abhimanyu", "Ghatotkacha", "Kripacharya", "Ashwatthama", "Jayadratha"
]

# Theme keywords searched for by extract_themes(); results keep this order.
MAHABHARAT_THEMES = [
    "dharma", "war", "duty", "honor", "betrayal", "sacrifice",
    "friendship", "loyalty", "revenge", "justice", "wisdom", "courage"
]

def extract_characters(text: str) -> List[str]:
    """
    Return the names from MAHABHARAT_CHARACTERS that occur in *text*.

    Matching is a case-insensitive substring test, so "Bhima" also matches
    inside "Bhimasena". Results follow the order of MAHABHARAT_CHARACTERS.
    """
    haystack = text.lower()
    return [name for name in MAHABHARAT_CHARACTERS if name.lower() in haystack]

def extract_themes(text: str, themes: Optional[List[str]] = None) -> List[str]:
    """
    Return the theme keywords that occur in *text*.

    A keyword must start a word (case-insensitive), so inflections such as
    "wars" or "warrior" still count, but words that only contain it in the
    middle do not. The previous plain substring test tagged "toward", "dwarf"
    and "Dwaraka" with the theme "war".

    Args:
        text: Text to scan.
        themes: Keywords to look for; None means MAHABHARAT_THEMES.

    Returns:
        The matching keywords, in the order of the keyword list.
    """
    vocabulary = MAHABHARAT_THEMES if themes is None else themes
    text_lower = text.lower()
    return [
        theme for theme in vocabulary
        if re.search(r'\b' + re.escape(theme.lower()), text_lower)
    ]

def create_text_chunks(chunks: List[str]) -> List[TextChunk]:
    """
    Wrap each chunk string in a TextChunk.

    Ids are sequential by position ("chunk_0000", "chunk_0001", ...).
    Character and theme metadata come from extract_characters() and
    extract_themes(). Chapter and section are left as None.
    """
    return [
        TextChunk(
            text=body,
            chunk_id=f"chunk_{index:04d}",
            characters=extract_characters(body),
            themes=extract_themes(body),
            chapter=None,
            section=None,
        )
        for index, body in enumerate(chunks)
    ]

# ============================================================================
# VALIDATION AND DEDUPLICATION
# ============================================================================

def validate_chunk(text: str, min_length: int = 50, min_word_count: int = 5) -> bool:
    """
    Decide whether a chunk carries enough real text to be worth indexing.

    After stripping surrounding whitespace, a chunk passes when it has at
    least ``min_length`` characters and at least ``min_word_count``
    whitespace-separated words, and at least half of its characters are
    alphanumeric. Empty, whitespace-only or None input always fails.
    """
    if not text:
        return False
    body = text.strip()
    if not body:
        return False
    if len(body) < min_length or len(body.split()) < min_word_count:
        return False
    alnum_count = sum(ch.isalnum() for ch in body)
    return alnum_count >= len(body) * 0.5

def deduplicate_chunks(chunks: List[TextChunk], similarity_threshold: float = 0.95) -> List[TextChunk]:
    """Remove duplicate and near-duplicate chunks."""
    seen_texts: Set[str] = set()
    unique_chunks = []
    
    for chunk in chunks:
        text = chunk.text.strip().lower()
        
        if text in seen_texts:
            continue
        
        is_duplicate = False
        words1 = set(text.split())
        
        for seen_text in seen_texts:
            words2 = set(seen_text.split())
            if len(words1) > 0 and len(words2) > 0:
                intersection = len(words1 & words2)
                union = len(words1 | words2)
                similarity = intersection / union if union > 0 else 0
                if similarity > similarity_threshold:
                    is_duplicate = True
                    break
        
        if not is_duplicate:
            seen_texts.add(text)
            unique_chunks.append(chunk)
    
    return unique_chunks

def extract_text_from_pdf(pdf_path: Path) -> str:
    """
    Extract the text of every page of a PDF, appending a newline after each page.

    Progress is printed every 50 pages. Page texts are collected in a list and
    joined once, instead of growing one large string with ``+=`` page by page.

    Args:
        pdf_path: Path to the PDF file.

    Returns:
        The concatenated page texts. A page that yields no text contributes
        only its newline.
    """
    from pypdf import PdfReader
    reader = PdfReader(pdf_path)
    total_pages = len(reader.pages)
    print(f"Extracting text from PDF ({total_pages} pages)...")

    page_texts = []
    for i, page in enumerate(reader.pages):
        # Defensive: treat a None/empty extraction result as an empty page
        page_texts.append((page.extract_text() or "") + "\n")
        if (i + 1) % 50 == 0:
            print(f"  Processed {i + 1}/{total_pages} pages...")

    text = "".join(page_texts)
    print(f"Extracted {len(text)} characters from PDF")
    return text

# ============================================================================
# MAIN PIPELINE
# ============================================================================

# Build the vector database from the PDF unless it already holds passages.
# Raises FileNotFoundError if the PDF is missing, and ValueError if the PDF
# yields no text or no usable chunks.
config_path = kaggle_dir / "kaggle_config.yaml"
vector_store = KaggleVectorStore(config_path=config_path)

existing_passages = vector_store.count()
if existing_passages > 0:
    print(f"\n✓ Vector database already exists with {existing_passages} passages")
    print("Skipping PDF processing. If you want to rebuild, delete the vector_db directory first.")
else:
    print("\n" + "="*60)
    print("CREATING VECTOR DATABASE FROM PDF")
    print("="*60)
    
    pdf_path = Path(PDF_PATH)
    if not pdf_path.exists():
        # Try alternative paths
        alt_paths = [
            Path("/kaggle/input/data/raw/Mahabharata.pdf"),
            Path("/kaggle/input/Mahabharata.pdf"),
        ]
        for alt_path in alt_paths:
            if alt_path.exists():
                pdf_path = alt_path
                break
    
    if not pdf_path.exists():
        raise FileNotFoundError(
            f"PDF not found at {PDF_PATH}\n"
            f"Please upload Mahabharata.pdf to /kaggle/input/ first"
        )
    
    print(f"\nProcessing PDF: {pdf_path}")
    
    # Step 1: Extract text from PDF
    print("\n" + "-"*40)
    print("STEP 1: Extracting text from PDF")
    print("-"*40)
    text = extract_text_from_pdf(pdf_path)
    original_length = len(text)
    # Fail early and clearly instead of a ZeroDivisionError / empty-min() error later
    if not text.strip():
        raise ValueError(
            f"No text could be extracted from {pdf_path}. "
            "If it is a scanned PDF without a text layer, run OCR on it first."
        )
    
    # Step 2: Clean text (preserving paragraph structure)
    print("\n" + "-"*40)
    print("STEP 2: Cleaning text (preserving paragraphs)")
    print("-"*40)
    text = clean_text(text)
    cleaned_length = len(text)
    print(f"  Cleaned: {original_length:,} → {cleaned_length:,} chars ({100*cleaned_length/original_length:.1f}% retained)")
    
    # Count paragraphs for info
    paragraph_count = len(split_into_paragraphs(text))
    print(f"  Detected {paragraph_count} paragraphs")
    
    # Step 3: Initialize embedder (needed for semantic chunking)
    print("\n" + "-"*40)
    print("STEP 3: Initializing embedder")
    print("-"*40)
    embedder = KaggleEmbedder(config_path=config_path)
    print(f"  Model: {embedder.model_name}")
    
    # Step 4: Hybrid semantic + paragraph-aware chunking
    print("\n" + "-"*40)
    print("STEP 4: Hybrid semantic chunking")
    print("-"*40)
    print(f"  Config: target={CHUNK_CONFIG['target_size']}, min={CHUNK_CONFIG['min_size']}, max={CHUNK_CONFIG['max_size']}")
    print(f"  Semantic threshold: {CHUNK_CONFIG['semantic_threshold']}")
    
    chunk_texts = hybrid_semantic_chunking(
        text, 
        embedder, 
        config=CHUNK_CONFIG,
        add_overlap=True
    )
    
    # Step 5: Create TextChunk objects with metadata
    print("\n" + "-"*40)
    print("STEP 5: Extracting metadata")
    print("-"*40)
    chunks = create_text_chunks(chunk_texts)
    print(f"  Created {len(chunks)} chunks with metadata")
    
    # Count metadata
    chunks_with_characters = sum(1 for c in chunks if c.characters)
    chunks_with_themes = sum(1 for c in chunks if c.themes)
    print(f"  Chunks with characters: {chunks_with_characters}")
    print(f"  Chunks with themes: {chunks_with_themes}")
    
    # Step 6: Validate chunks
    print("\n" + "-"*40)
    print("STEP 6: Validating chunks")
    print("-"*40)
    valid_chunks = [chunk for chunk in chunks if validate_chunk(chunk.text)]
    invalid_count = len(chunks) - len(valid_chunks)
    print(f"  Valid chunks: {len(valid_chunks)} (removed {invalid_count} invalid)")
    
    # Step 7: Deduplicate chunks
    print("\n" + "-"*40)
    print("STEP 7: Deduplicating chunks")
    print("-"*40)
    unique_chunks = deduplicate_chunks(valid_chunks)
    duplicate_count = len(valid_chunks) - len(unique_chunks)
    print(f"  Unique chunks: {len(unique_chunks)} (removed {duplicate_count} duplicates)")
    
    # Nothing survived validation/dedup: stop before embedding and before the
    # size statistics below, which need at least one chunk.
    if not unique_chunks:
        raise ValueError(
            "No valid chunks were produced from the PDF text; "
            "nothing to add to the vector database."
        )
    
    # Step 8: Generate final embeddings
    print("\n" + "-"*40)
    print("STEP 8: Generating final embeddings")
    print("-"*40)
    texts = [chunk.text for chunk in unique_chunks]
    embeddings = embedder.embed(texts)
    print(f"  Generated {len(embeddings)} embeddings")
    
    # Step 9: Add to vector store
    print("\n" + "-"*40)
    print("STEP 9: Adding passages to vector database")
    print("-"*40)
    passages = [
        Passage(
            chunk_id=chunk.chunk_id,
            text=chunk.text,
            embedding=embeddings[i],
            characters=chunk.characters,
            themes=chunk.themes,
            chapter=chunk.chapter,
            section=chunk.section
        )
        for i, chunk in enumerate(unique_chunks)
    ]
    
    # Add in batches
    batch_size = 100
    total_batches = (len(passages) + batch_size - 1) // batch_size
    for i in range(0, len(passages), batch_size):
        batch = passages[i:i + batch_size]
        vector_store.add_passages(batch)
        print(f"  Added batch {i//batch_size + 1}/{total_batches}")
    
    # Summary
    print("\n" + "="*60)
    print("VECTOR DATABASE CREATION COMPLETE")
    print("="*60)
    print(f"  Source paragraphs: {paragraph_count}")
    print(f"  Initial chunks (after semantic): {len(chunks)}")
    print(f"  Invalid removed: {invalid_count}")
    print(f"  Duplicates removed: {duplicate_count}")
    print(f"  Final passages: {vector_store.count()}")
    
    # Chunk size statistics
    chunk_sizes = [len(c.text) for c in unique_chunks]
    print(f"\n  Chunk size stats:")
    print(f"    Min: {min(chunk_sizes)} chars")
    print(f"    Max: {max(chunk_sizes)} chars")
    print(f"    Avg: {sum(chunk_sizes) // len(chunk_sizes)} chars")


## 4. Generate Story Scripts and Videos


In [None]:
# Generate Story Scripts and Videos
# Retrieves a seed passage plus context, generates story scenes and narration,
# saves both as JSON under /kaggle/working/output, then renders one video per scene.
import sys
import json
from pathlib import Path
from datetime import datetime

sys.path.insert(0, str(kaggle_dir))
from story_pipeline import (
    KaggleVectorStore,
    KaggleRetriever,
    KaggleStoryGenerator,
    KaggleNarrationGenerator
)
from wan21_generator import Wan21KaggleGenerator

config_path = kaggle_dir / "kaggle_config.yaml"

print("\n" + "="*60)
print("STEP 1: LOADING VECTOR DATABASE")
print("="*60)

# Load vector store
vector_store = KaggleVectorStore(config_path=config_path)
print(f"Vector DB loaded: {vector_store.count()} passages")

if vector_store.count() == 0:
    print("\n⚠ WARNING: Vector database is empty!")
    print("Please run the previous cell to create vector DB from PDF first.")
    raise ValueError("Vector database is empty")

print("\n" + "="*60)
print("STEP 2: RETRIEVING PASSAGES")
print("="*60)

# Initialize retriever
retriever = KaggleRetriever(vector_store=vector_store, config_path=config_path)

# Retrieve seed passage
seed_passages = retriever.retrieve_diverse(query=QUERY, n_results=1)
if not seed_passages:
    raise ValueError("Failed to retrieve seed passage")

seed_passage = seed_passages[0]
print(f"\nSeed Passage ID: {seed_passage['id']}")
print(f"Text preview: {seed_passage['text'][:200]}...")

# Retrieve context passages
context_passages = retriever.retrieve_context(seed_passage, n_context=3)
print(f"\nRetrieved {len(context_passages)} context passages")

print("\n" + "="*60)
print("STEP 3: GENERATING STORY SCRIPTS")
print("="*60)

# Generate story
story_generator = KaggleStoryGenerator(config_path=config_path)

story_data = story_generator.generate_story(
    seed_passage=seed_passage['text'],
    context_passages=[p['text'] for p in context_passages]
)

scenes = story_data.get('story_sequence', [])
print(f"\n✓ Generated story with {len(scenes)} scenes")
if not scenes:
    print("⚠ WARNING: Story has no scenes; the video prompt list will be empty.")

# Generate narration
narration_generator = KaggleNarrationGenerator(config_path=config_path)
narration_data = narration_generator.generate_narration(story_data)

# Save generated scripts
output_dir = Path("/kaggle/working/output")
output_dir.mkdir(parents=True, exist_ok=True)

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

# Convert to video_prompts format
video_prompts = []
for scene in scenes:
    clip_prompt = scene.get('clip_prompt', {})
    # Fall back to the 1-based position when the scene carries no number
    scene_number = scene.get('scene_number', len(video_prompts) + 1)
    video_prompts.append({
        "scene_number": scene_number,
        "title": scene.get('title', f"Scene {scene_number}"),
        "video_prompt": clip_prompt.get('visual_description', ''),
        "camera_angles": clip_prompt.get('camera_angles', ''),
        "key_objects": clip_prompt.get('key_objects', ''),
        "ambient_audio": clip_prompt.get('ambient_audio', ''),
        "duration_seconds": 5
    })

# Save video prompts
prompts_file = output_dir / f"video_prompts_{timestamp}.json"
with open(prompts_file, 'w', encoding='utf-8') as f:
    json.dump(video_prompts, f, indent=2, ensure_ascii=False)
print(f"\n✓ Saved video prompts: {prompts_file}")

# Save narration
narration_file = output_dir / f"narration_{timestamp}.json"
with open(narration_file, 'w', encoding='utf-8') as f:
    json.dump(narration_data, f, indent=2, ensure_ascii=False)
print(f"✓ Saved narration: {narration_file}")

# Display the generated scripts
print("\n" + "="*60)
print("GENERATED VIDEO PROMPTS")
print("="*60)
for scene in video_prompts:
    print(f"\nScene {scene['scene_number']}: {scene['title']}")
    print(f"  Prompt: {scene['video_prompt'][:150]}...")

# Display only: a missing or None narration field must not abort the run
# before the (expensive) video generation step.
print("\n" + "="*60)
print("GENERATED NARRATION")
print("="*60)
print(f"Title: {narration_data.get('story_title') or ''}")
print(f"\nHook: {narration_data.get('hook') or ''}")
print(f"\nStory: {(narration_data.get('story') or '')[:200]}...")
print(f"\nCTA: {narration_data.get('cta') or ''}")

print("\n" + "="*60)
print("STEP 4: GENERATING VIDEOS")
print("="*60)

# Initialize video generator
video_generator = Wan21KaggleGenerator(config_path=config_path)

# Generate videos from scenes
video_paths = video_generator.generate_from_scenes(video_prompts)

print("\n" + "="*60)
print("GENERATION COMPLETE")
print("="*60)
print(f"\nGenerated {len(video_paths)} videos:")
for path in video_paths:
    print(f"  - {path}")


In [None]:
# Bundle everything in the output folder into a single zip for download
import shutil

output_dir = Path("/kaggle/working/output")
zip_path = Path("/kaggle/working/videos_output.zip")

if not (output_dir.exists() and any(output_dir.iterdir())):
    print("No output files to zip")
else:
    # make_archive takes the archive name without extension and appends ".zip" itself
    shutil.make_archive(str(zip_path.with_suffix('')), 'zip', output_dir)
    print(f"✓ Created zip file: {zip_path}")
    print(f"  Size: {zip_path.stat().st_size / (1024 * 1024):.2f} MB")
