# Extract Key Points from Textbook Chunks

This notebook loads the chunked textbook data and extracts key points from each section, rotating requests across three LLM providers (OpenAI, DeepSeek, and Gemini).

## 1. Setup and Imports

In [None]:
import os
import json
import sys
from dotenv import load_dotenv
from pathlib import Path
from typing import Dict, List, Any

# Add repo to path
sys.path.append(os.path.abspath(os.path.join("..", "")))

# Load environment variables
load_dotenv()

# Import LLM clients
from src.core.llm_client import ModelProvider, LLMFactory

## 2. Load Chunks Data

In [None]:
# Location of the chunks JSON file to load
chunks_file = Path("../data/hizan/output/pyhton_short-1772218124093/hybrid_auto/chunks.json")

# Read the file as UTF-8 text and decode it in one step
chunks = json.loads(chunks_file.read_text(encoding='utf-8'))

print(f"Loaded {len(chunks)} chunks")
first_chunk_keys = chunks[0].keys() if chunks else 'N/A'
print(f"First chunk keys: {first_chunk_keys}")

# Show the chunk at index 10 as a sanity check, when there are enough chunks
if len(chunks) > 10:
    sample = chunks[10]
    print("\nSample chunk:")
    sample_text = sample.get('content', '')
    print(f"Content length: {len(sample_text)}")
    sample_header = sample.get('metadata', {}).get('header_1', 'N/A')
    print(f"Header 1: {sample_header}")
    print(f"Content preview: {sample_text[:150]}...")

## 3. Create Key Points Extraction Prompt

In [None]:
def create_extraction_prompt(content: str) -> str:
    """Build the key-point extraction prompt for one textbook section.

    Args:
        content: Section text, embedded verbatim under the ``Content:`` label.

    Returns:
        The instructions, the section text, and an example of the expected
        JSON reply shape, terminated by a newline.
    """
    # One entry per output line; the trailing "" yields the final newline.
    prompt_lines = [
        "Extract 3-5 concise key points from the following textbook section. ",
        'Format the output as a JSON object with a "key_points" array.',
        "",
        "Content:",
        f"{content}",
        "",
        "Return ONLY valid JSON in this format:",
        '{"key_points": ["point 1", "point 2", "point 3"]}',
        "",
    ]
    return "\n".join(prompt_lines)

# Smoke-test the prompt builder on chunk 15 when it exists, else on filler text
if len(chunks) > 15:
    sample_content = chunks[15]["content"]
else:
    sample_content = "Sample text"

print("Sample prompt:")
# Only the first 200 characters of content go in, and only 300 characters of
# the resulting prompt are shown, to keep the cell output short.
prompt_preview = create_extraction_prompt(sample_content[:200])
print(prompt_preview[:300])

## 4. Initialize LLM Clients for Provider Rotation

In [None]:
# Create one client per provider.
#
# Both lists are defined before anything can fail, so later cells can always
# reference `clients` and `client_names`. Previously `client_names` was left
# undefined when initialization raised, which caused a NameError downstream.
# Initialization stays all-or-nothing: if any provider fails, both lists
# remain empty.
clients = []
client_names = []

try:
    openai_client = LLMFactory.create_client(ModelProvider.OPENAI, temperature=0.3)
    deepseek_client = LLMFactory.create_client(ModelProvider.DEEPSEEK, temperature=0.3)
    gemini_client = LLMFactory.create_client(ModelProvider.GOOGLE, temperature=0.3)
except Exception as e:
    # Cell-level boundary: report the problem and leave both lists empty.
    print(f"✗ Error initializing clients: {e}")
else:
    # The two lists are index-aligned: client_names[i] labels clients[i].
    clients = [openai_client, deepseek_client, gemini_client]
    client_names = ["OpenAI", "Deepseek", "Gemini"]
    print("✓ All clients initialized successfully")

## 5. Extract Key Points with Progress Tracking

In [None]:
def _parse_key_points(response_text: str) -> list:
    """Pull the ``key_points`` value out of a raw model reply.

    The reply is first parsed as JSON directly. If that fails, the text
    inside the first Markdown code fence (a ``json``-tagged fence is
    preferred over a bare one) is parsed instead.

    Args:
        response_text: Raw text returned by the LLM client.

    Returns:
        The value stored under ``"key_points"``, or an empty list when the
        key is absent.

    Raises:
        ValueError: If the reply is not JSON and contains no code fence, if
            the fenced text is not valid JSON (``json.JSONDecodeError`` is a
            ``ValueError`` subclass), or if the parsed JSON is not an object.
    """
    try:
        parsed = json.loads(response_text)
    except json.JSONDecodeError:
        if '```json' in response_text:
            fenced = response_text.split('```json')[1].split('```')[0]
        elif '```' in response_text:
            # Text between the first and second fence markers.
            fenced = response_text.split('```')[1]
        else:
            raise ValueError(f"Could not parse response: {response_text}")
        parsed = json.loads(fenced)

    if not isinstance(parsed, dict):
        raise ValueError(
            f"Expected a JSON object with 'key_points', got {type(parsed).__name__}"
        )
    return parsed.get('key_points', [])


results = []
failed_chunks = []

# Process only chunks with meaningful content (more than 100 characters)
processing_chunks = [c for c in chunks if len(c.get('content') or '') > 100]

# Fail with a clear message instead of a ZeroDivisionError from the rotation
# arithmetic below when no clients were initialized.
if processing_chunks and not clients:
    raise RuntimeError(
        "No LLM clients are available; fix the client initialization cell "
        "before running the extraction."
    )

print(f"Processing {len(processing_chunks)} chunks with meaningful content...\n")

for idx, chunk in enumerate(processing_chunks):
    # Rotate through clients round-robin; names are index-aligned with clients
    slot = idx % len(clients)
    client = clients[slot]
    client_name = client_names[slot]

    content = chunk.get('content', '')
    metadata = chunk.get('metadata', {})
    header_1 = metadata.get('header_1', 'Unknown')

    try:
        # Truncate long content to avoid token limits (use first 2000 chars)
        prompt = create_extraction_prompt(content[:2000])

        # NOTE(review): presumably max_tokens caps the reply length; if so, a
        # long reply may be cut off mid-JSON and recorded as a failure - confirm.
        response_text = client.generate_text(
            prompt,
            max_tokens=200,
            system_prompt="You are a helpful assistant that extracts key points from educational content."
        )
        key_points = _parse_key_points(response_text)
    except Exception as e:
        # Best-effort per chunk: record the failure and keep going.
        failed_chunks.append({
            'index': idx,
            'header': header_1,
            'error': str(e)
        })
        print(f"✗ Failed chunk {idx}: {header_1[:40]}... - {str(e)[:60]}")
        continue

    # 'chunk_index' is the position within processing_chunks (the filtered
    # list), not within the original chunks list.
    results.append({
        'header': header_1,
        'source': metadata.get('source', ''),
        'provider': client_name,
        'key_points': key_points,
        'chunk_index': idx
    })

    # Report progress every fifth chunk
    if (idx + 1) % 5 == 0:
        print(f"✓ Processed {idx + 1}/{len(processing_chunks)} - {header_1[:40]}... ({client_name})")

print(f"\n✓ Extraction complete!")
print(f"  Successfully processed: {len(results)}")
print(f"  Failed: {len(failed_chunks)}")

## 6. Save Results to JSON

In [None]:
# Save results
output_file = Path("../data/hizan/output/keypoints_extracted.json")
output_file.parent.mkdir(parents=True, exist_ok=True)

# Build and serialize the report BEFORE opening the file. Opening in 'w' mode
# truncates immediately, so a failure while assembling or encoding the payload
# would otherwise leave an empty or partial file in place of earlier output.
report = {
    'metadata': {
        'total_chunks_processed': len(processing_chunks),
        'total_successful': len(results),
        'total_failed': len(failed_chunks),
        'providers_used': client_names
    },
    'keypoints': results,
    'failed': failed_chunks
}
serialized_report = json.dumps(report, indent=2, ensure_ascii=False)

with open(output_file, 'w', encoding='utf-8') as f:
    f.write(serialized_report)

print(f"✓ Results saved to: {output_file}")
print(f"  File size: {output_file.stat().st_size / 1024:.1f} KB")

## 7. Preview Results

In [None]:
# Print the first five extraction results, one bullet per key point
print("Sample extracted key points:\n")
for position, entry in enumerate(results[:5], start=1):
    print(f"{position}. {entry['header']} ({entry['provider']})")
    for key_point in entry['key_points']:
        print(f"   • {key_point}")
    print()