# 3-Stage Training Document Generation

This notebook generates a training document from a video recording in three model-driven stages, followed by a document-assembly step:
1. Stage 1: Use Gemini to analyze the video and extract knowledge points
2. Stage 2: Use Gemini to select screenshot timestamps (3 separate API calls)
3. Stage 3: Use OpenAI GPT-4o to curate the screenshots and write captions
4. Assemble the final DOCX document

In [None]:
# Import necessary libraries
import os
import json
import time
import cv2
import requests
from google import genai
from openai import OpenAI
from dotenv import load_dotenv
import base64
from docx import Document
from docx.shared import Inches
from docx.enum.text import WD_ALIGN_PARAGRAPH
import re
from IPython.display import display, Image
import numpy as np

# Load environment variables (API keys) from a local .env file, if present
load_dotenv()

# Initialize API clients
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")

# Surface missing keys up front so a failing client constructor or API call is easy to diagnose.
missing_api_keys = [
    name for name, value in (("GEMINI_API_KEY", GEMINI_API_KEY), ("OPENAI_API_KEY", OPENAI_API_KEY))
    if not value
]
if missing_api_keys:
    print(f"Warning: missing API key(s): {', '.join(missing_api_keys)}")

# Reuse the key read above rather than querying the environment a second time.
gemini_client = genai.Client(api_key=GEMINI_API_KEY)
# Alias under the original (misspelled) name so existing references keep working.
gemni_client = gemini_client
openai_client = OpenAI(api_key=OPENAI_API_KEY)

# Import prompts
from three_stage_testing.prompts_Three_Stage import stage_1_prompt100, stage_2_prompt100, stage_3_prompt100

## Configure Paths and Settings

In [None]:
# Configure paths and settings
video_path = "KT Recording/modify table in EDW using git.mp4"  # Update this to your video path
job_id = int(time.time())  # Generate a unique ID for this job

# Every artefact produced by this run lives under a single job folder
base_folder = f"training_job_{job_id}"
os.makedirs(base_folder, exist_ok=True)

# One screenshot folder per Stage 2 Gemini attempt; attempt numbers on disk start at 1
screenshots_folders = []
for attempt_number in (1, 2, 3):
    attempt_folder = os.path.join(base_folder, f"screenshots_attempt_{attempt_number}")
    screenshots_folders.append(attempt_folder)
    os.makedirs(attempt_folder, exist_ok=True)

# Destinations for the final JSON data and the generated Word document
output_json_path = os.path.join(base_folder, "training_data.json")
output_docx_path = os.path.join(base_folder, "training_document.docx")

print(f"Job ID: {job_id}")
print(f"Output folder: {base_folder}")

## Helper Functions

In [None]:
# Helper functions

def _parse_timestamp(timestamp):
    """Convert 'SS', 'MM:SS' or 'HH:MM:SS' (fractional seconds allowed) to seconds.

    Non-string values (e.g. a bare number from the model's JSON) are converted
    with str() first.

    Raises:
        ValueError: if the value is not a non-negative timestamp in one of those forms.
    """
    parts = str(timestamp).strip().split(':')
    if len(parts) > 3:
        raise ValueError(f"Unrecognised timestamp format: {timestamp!r}")
    seconds = 0.0
    for part in parts:
        seconds = seconds * 60 + float(part)
    if not seconds >= 0:  # also rejects NaN
        raise ValueError(f"Negative or invalid timestamp: {timestamp!r}")
    return seconds


def extract_screenshots(video_path, timestamps, output_folder, knowledge_point_index, api_attempt_index):
    """Extract screenshots from a video at the given timestamps.

    Files are named ``{knowledge_point_index+1}_{screenshot_index+1}_{api_attempt_index+1}.png``
    inside ``output_folder``. Timestamps past the video's duration, unparsable
    timestamps, unreadable frames and failed writes are reported and skipped.

    Args:
        video_path: Path to the video file.
        timestamps: Iterable of timestamps ('SS', 'MM:SS' or 'HH:MM:SS'); a single
            string or number is treated as a one-element list.
        output_folder: Existing folder to write PNG files into.
        knowledge_point_index: 0-based knowledge point index (used in file names).
        api_attempt_index: 0-based Stage 2 attempt index (used in file names).

    Returns:
        List of paths of the screenshots that were actually written.
    """
    screenshot_paths = []

    # A bare string would otherwise be iterated character by character.
    if isinstance(timestamps, (str, int, float)):
        timestamps = [timestamps]

    cap = None
    try:
        # Open the video file
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            print(f"Could not open video file {video_path}")
            return []

        # Get video properties
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        duration = total_frames / fps if fps > 0 else 0

        print(f"Video properties: Duration={duration:.2f}s, FPS={fps:.2f}, Total frames={total_frames}")

        for screenshot_index, timestamp in enumerate(timestamps):
            try:
                time_in_seconds = _parse_timestamp(timestamp)

                # Skip if timestamp is beyond video duration (duration 0 means unknown)
                if duration > 0 and time_in_seconds > duration:
                    print(f"Timestamp {timestamp} exceeds video duration of {duration:.2f}s")
                    continue

                cap.set(cv2.CAP_PROP_POS_MSEC, time_in_seconds * 1000)
                success, frame = cap.read()
                if not success:
                    print(f"Failed to capture screenshot at timestamp {timestamp}")
                    continue

                screenshot_filename = f"{knowledge_point_index+1}_{screenshot_index+1}_{api_attempt_index+1}.png"
                screenshot_path = os.path.join(output_folder, screenshot_filename)

                # imwrite reports failure through its return value; only record files that exist.
                if not cv2.imwrite(screenshot_path, frame):
                    print(f"Failed to write screenshot: {screenshot_path}")
                    continue
                screenshot_paths.append(screenshot_path)
                print(f"Saved screenshot: {screenshot_path}")

            except Exception as e:
                print(f"Error processing timestamp {timestamp}: {str(e)}")

    except Exception as e:
        print(f"Error in extract_screenshots: {str(e)}")
    finally:
        # Release the capture on every path, including early returns and errors.
        if cap is not None:
            cap.release()

    return screenshot_paths

def image_to_base64(image_path):
    """Read an image file and return its bytes as a base64-encoded UTF-8 string.

    Returns None (after printing the error) if the file cannot be read.
    """
    try:
        with open(image_path, "rb") as handle:
            payload = handle.read()
        return base64.b64encode(payload).decode('utf-8')
    except Exception as err:
        print(f"Error encoding image to base64: {str(err)}")
        return None

def parse_gemini_response(response_text):
    """Extract a JSON value from a model response (used for Gemini and GPT-4o alike).

    Strategies, in order:
      1. parse the whole text as JSON;
      2. parse each fenced markdown code block (```json ...``` or ``` ...```);
      3. parse the span from the first '{' to the last '}', and if that fails,
         the first complete JSON object starting at the first '{'.

    Args:
        response_text: Raw model output. None or empty text (e.g. a response
            with no text part) is tolerated.

    Returns:
        The decoded JSON value, or None if nothing could be parsed.
    """
    if not response_text:
        print("Could not extract valid JSON from response: response is empty")
        return None

    try:
        return json.loads(response_text)
    except json.JSONDecodeError:
        pass

    # Models often wrap JSON in markdown fences; the first fence may hold
    # non-JSON (e.g. an example), so try every fenced block in order.
    for block in re.findall(r'```(?:json)?\s*([\s\S]*?)\s*```', response_text):
        try:
            return json.loads(block)
        except json.JSONDecodeError:
            print(f"Failed to parse JSON from code block: {block}")

    start_idx = response_text.find('{')
    end_idx = response_text.rfind('}')
    if start_idx >= 0 and end_idx > start_idx:
        json_str = response_text[start_idx:end_idx + 1]
        try:
            return json.loads(json_str)
        except json.JSONDecodeError:
            # Trailing prose may itself contain '}' — decode only the first
            # complete object and ignore whatever follows it.
            try:
                return json.JSONDecoder().raw_decode(response_text[start_idx:])[0]
            except json.JSONDecodeError:
                print(f"Failed to parse JSON-like structure: {json_str}")

    print("Could not extract valid JSON from response")
    print("Response text:")
    print(response_text)
    return None

## Stage 1: Use Gemini to Extract Knowledge Points from Video

In [None]:
# Stage 1: Use Gemini to extract knowledge points from video

def stage1_extract_knowledge_points(video_path, model_name="gemini-1.5-pro"):
    """Stage 1: ask Gemini to summarise the video and extract its knowledge points.

    The video is sent inline together with ``stage_1_prompt100``. The raw reply
    is saved to ``<base_folder>/stage1_raw_response.txt`` and a parsed dict to
    ``<base_folder>/stage1_result.json``.

    Args:
        video_path: Path to the MP4 recording.
        model_name: Gemini model name passed to the google-genai client.

    Returns:
        The parsed dict (expected keys 'Summary' and 'knowledge_points'; it is
        still returned, after a warning, if they are missing), or None when the
        reply could not be parsed into a dict.
    """
    print("=== Stage 1: Extracting Knowledge Points ===")

    # Read the video file
    with open(video_path, 'rb') as f:
        video_data = f.read()

    # `from google import genai` is the google-genai SDK: it has no
    # `genai.GenerativeModel` (that API belongs to the older google-generativeai
    # package), and 'gemini-1.5-pro-vision' is not a Gemini model name.
    # Requests therefore go through the Client created at the top of the notebook.
    from google.genai import types

    print("Sending video to Gemini for analysis...")
    # NOTE(review): the video travels inline in the request body, which the
    # Gemini API caps in size; long recordings may need client.files.upload().
    response = gemni_client.models.generate_content(
        model=model_name,
        contents=[
            stage_1_prompt100,
            types.Part.from_bytes(data=video_data, mime_type="video/mp4"),
        ],
    )
    # .text is None when the reply has no text part (e.g. blocked content).
    response_text = response.text or ""

    # Keep the raw reply so parse failures can be diagnosed later.
    raw_response_path = os.path.join(base_folder, "stage1_raw_response.txt")
    with open(raw_response_path, 'w', encoding='utf-8') as f:
        f.write(response_text)

    print("Parsing response from Gemini...")
    result = parse_gemini_response(response_text)

    if not result or not isinstance(result, dict):
        print("Failed to parse response from Gemini")
        return None

    stage1_result_path = os.path.join(base_folder, "stage1_result.json")
    with open(stage1_result_path, 'w') as f:
        json.dump(result, f, indent=2)

    # Verify required fields (the result is returned either way)
    if 'Summary' in result and 'knowledge_points' in result:
        print(f"Successfully extracted {len(result['knowledge_points'])} knowledge points")
    else:
        print("Warning: Missing required fields in the result")
    return result

# Run Stage 1 and preview what it found
stage1_result = stage1_extract_knowledge_points(video_path)

if stage1_result:
    extracted_points = stage1_result.get('knowledge_points', [])

    print("\nSummary:")
    print(stage1_result.get('Summary', 'No summary provided'))

    # Show only the first three points to keep the output short
    print("\nFirst 3 Knowledge Points:")
    for number, point in enumerate(extracted_points[:3], start=1):
        print(f"{number}. {point}")

    hidden_count = len(extracted_points) - 3
    if hidden_count > 0:
        print(f"...and {hidden_count} more points")

## Stage 2: Use Gemini to Select Timestamps (3 Attempts)

In [None]:
# Stage 2: Use Gemini to select timestamps for knowledge points

def _stage2_as_timestamp_list(value):
    """Coerce a model-supplied timestamp value into a list of strings.

    A bare string would otherwise be iterated character by character, and a bare
    number is not iterable at all.
    """
    if value is None:
        return []
    if isinstance(value, (list, tuple)):
        return [str(item) for item in value]
    return [str(value)]


def _stage2_point_indexes(raw_keys, num_points):
    """Map the model's JSON keys to 0-based knowledge-point indexes.

    Keys that are not plain non-negative integers map to None. The prompt numbers
    knowledge points from 1. If the keys contain ``num_points`` (never a valid
    0-based index) and no 0, they can only be 1-based, so they are shifted down
    by one. Otherwise they are read as 0-based.
    """
    parsed = {}
    for raw_key in raw_keys:
        key_text = str(raw_key).strip()
        parsed[raw_key] = int(key_text) if key_text.isdecimal() else None
    present = {value for value in parsed.values() if value is not None}
    offset = 1 if num_points in present and 0 not in present else 0
    return {
        raw_key: (value - offset if value is not None else None)
        for raw_key, value in parsed.items()
    }


def stage2_select_timestamps(video_path, stage1_result, model_name="gemini-1.5-pro"):
    """Stage 2: ask Gemini (3 independent calls) for screenshot timestamps per knowledge point.

    For each attempt, the raw and parsed replies are saved under ``base_folder``.
    Screenshots are extracted into ``screenshots_folders[attempt]``. A failed API
    call or unparsable reply skips only that attempt.

    Args:
        video_path: Path to the MP4 recording.
        stage1_result: Stage 1 dict; must contain 'knowledge_points'.
        model_name: Gemini model name passed to the google-genai client.

    Returns:
        Dict with 'attempt_results' (one list per successful attempt of
        {'knowledge_point_index', 'timestamps', 'screenshot_paths'}) and
        'all_screenshot_paths', or None if Stage 1 data is missing.
    """
    print("=== Stage 2: Selecting Timestamps for Screenshots ===")

    # Check if we have the required data from Stage 1
    if not stage1_result or 'knowledge_points' not in stage1_result:
        print("Error: Missing required data from Stage 1")
        return None

    with open(video_path, 'rb') as f:
        video_data = f.read()

    knowledge_points = stage1_result['knowledge_points']
    prompt = stage_2_prompt100.replace("{{summary_from_stage_1}}", json.dumps(stage1_result.get('Summary', '')))

    # google-genai SDK: there is no genai.GenerativeModel; call through the shared Client.
    from google.genai import types

    contents = [
        prompt,
        types.Part.from_bytes(data=video_data, mime_type="video/mp4"),
        "Knowledge Points:\n" + "\n".join(f"{i+1}. {point}" for i, point in enumerate(knowledge_points)),
    ]

    all_attempt_results = []
    all_screenshot_paths = []

    for attempt in range(3):
        print(f"\nAttempt {attempt+1}/3: Calling Gemini API for timestamp selection...")

        # One failed call should not discard the remaining attempts.
        try:
            response = gemni_client.models.generate_content(model=model_name, contents=contents)
        except Exception as e:
            print(f"Attempt {attempt+1}: Gemini API call failed: {str(e)}")
            continue
        # .text is None when the reply has no text part (e.g. blocked content).
        response_text = response.text or ""

        raw_response_path = os.path.join(base_folder, f"stage2_raw_response_attempt_{attempt+1}.txt")
        with open(raw_response_path, 'w', encoding='utf-8') as f:
            f.write(response_text)

        timestamps_data = parse_gemini_response(response_text)
        if not timestamps_data or not isinstance(timestamps_data, dict):
            print(f"Attempt {attempt+1}: Failed to parse response from Gemini")
            continue

        parsed_response_path = os.path.join(base_folder, f"stage2_parsed_response_attempt_{attempt+1}.json")
        with open(parsed_response_path, 'w') as f:
            json.dump(timestamps_data, f, indent=2)

        print(f"Attempt {attempt+1}: Successfully parsed response")

        point_indexes = _stage2_point_indexes(timestamps_data.keys(), len(knowledge_points))
        attempt_screenshots = []

        for raw_key, raw_timestamps in timestamps_data.items():
            knowledge_point_index = point_indexes[raw_key]
            if knowledge_point_index is None or not 0 <= knowledge_point_index < len(knowledge_points):
                print(f"Attempt {attempt+1}: Skipping unrecognised knowledge point key {raw_key!r}")
                continue

            timestamps = _stage2_as_timestamp_list(raw_timestamps)
            try:
                print(f"Extracting screenshots for knowledge point {knowledge_point_index+1}")
                screenshot_paths = extract_screenshots(
                    video_path, timestamps, screenshots_folders[attempt],
                    knowledge_point_index, attempt
                )
            except Exception as e:
                print(f"Error processing knowledge point {knowledge_point_index}: {str(e)}")
                continue

            attempt_screenshots.append({
                "knowledge_point_index": knowledge_point_index,
                "timestamps": timestamps,
                "screenshot_paths": screenshot_paths
            })

        all_attempt_results.append(attempt_screenshots)

        attempt_paths = [path for item in attempt_screenshots for path in item["screenshot_paths"]]
        all_screenshot_paths.extend(attempt_paths)
        print(f"Attempt {attempt+1}: Extracted {len(attempt_paths)} screenshots")

    stage2_result = {
        "attempt_results": all_attempt_results,
        "all_screenshot_paths": all_screenshot_paths
    }

    stage2_result_path = os.path.join(base_folder, "stage2_result.json")
    with open(stage2_result_path, 'w') as f:
        json.dump(stage2_result, f, indent=2)

    print(f"\nStage 2 completed with {len(all_screenshot_paths)} total screenshots across 3 attempts")
    return stage2_result

# Run Stage 2, then show one sample per attempt
stage2_result = stage2_select_timestamps(video_path, stage1_result)

if stage2_result and stage2_result['attempt_results']:
    print("\nSample screenshots from each attempt:")
    for attempt_number, attempt_items in enumerate(stage2_result['attempt_results'], start=1):
        if not attempt_items:
            continue

        # The first knowledge point handled in this attempt serves as the sample
        first_item = attempt_items[0]
        shot_paths = first_item['screenshot_paths']

        print(f"\nAttempt {attempt_number} - Knowledge Point {first_item['knowledge_point_index']+1}:")
        print(f"Timestamps: {first_item['timestamps']}")
        print(f"Screenshots: {len(shot_paths)} extracted")

        if shot_paths:
            display(Image(filename=shot_paths[0], width=400))

## Stage 3: Use GPT-4o to Curate Screenshots

In [None]:
# Stage 3: Use GPT-4o to curate screenshots

def _stage3_empty_result(knowledge_point_index, knowledge_point):
    """Return a curation entry with no selections, using the same keys as a successful entry."""
    return {
        "knowledge_point_index": knowledge_point_index,
        "knowledge_point": knowledge_point,
        "screenshot_groups": [],
        "selected_screenshots": [],
        "selected_ids": [],
        "captions": []
    }


def _stage3_collect_screenshots(stage2_result, knowledge_point_index):
    """Gather (paths, ids) of every Stage 2 screenshot for one knowledge point, across all attempts.

    The ID is the file name without its extension (e.g. '3_1_2').
    """
    paths, ids = [], []
    for attempt_result in stage2_result['attempt_results']:
        for item in attempt_result:
            if item["knowledge_point_index"] != knowledge_point_index:
                continue
            for path in item.get("screenshot_paths", []):
                paths.append(path)
                ids.append(os.path.splitext(os.path.basename(path))[0])
    return paths, ids


def _stage3_build_user_content(knowledge_point, screenshot_paths, screenshot_ids):
    """Build the GPT-4o user message: the knowledge point, then each image followed by its ID label."""
    content = [
        {"type": "text", "text": f"Knowledge point: {knowledge_point}\n\nBelow are screenshots to curate:"}
    ]
    for path, screenshot_id in zip(screenshot_paths, screenshot_ids):
        base64_image = image_to_base64(path)
        if base64_image:  # unreadable images are left out of the request
            content.append({
                "type": "image_url",
                "image_url": {"url": f"data:image/png;base64,{base64_image}"}
            })
            content.append({"type": "text", "text": f"Image ID: {screenshot_id}"})
    return content


def _stage3_curate_point(knowledge_point_index, knowledge_point, prompt_template,
                         screenshot_paths, screenshot_ids, model_name):
    """Ask GPT-4o to select and caption screenshots for one knowledge point; return its entry."""
    user_content = _stage3_build_user_content(knowledge_point, screenshot_paths, screenshot_ids)

    print("Calling GPT-4o API for curation...")
    response = openai_client.chat.completions.create(
        model=model_name,
        messages=[
            {"role": "system", "content": prompt_template},
            {"role": "user", "content": user_content}
        ],
        max_tokens=4000
    )
    # message.content is None when the model refuses or returns no text.
    response_text = response.choices[0].message.content or ""

    raw_response_path = os.path.join(base_folder, f"stage3_raw_response_point_{knowledge_point_index+1}.txt")
    with open(raw_response_path, 'w', encoding='utf-8') as f:
        f.write(response_text)

    curated_data = parse_gemini_response(response_text)  # the JSON extractor is model-agnostic
    if not curated_data or not isinstance(curated_data, dict):
        print(f"Failed to parse GPT-4o response for point {knowledge_point_index+1}")
        return _stage3_empty_result(knowledge_point_index, knowledge_point)

    # NOTE(review): despite the key name, these are matched against the image IDs
    # sent above, not positional indexes — confirm against stage_3_prompt100.
    selected_ids = curated_data.get("selected_indexes") or []
    if isinstance(selected_ids, str):
        selected_ids = [selected_ids]
    captions = curated_data.get("captions") or []
    groups = curated_data.get("groups") or []

    # ID -> path lookup; keep the first path for an ID, as a linear search would.
    path_by_id = {}
    for screenshot_id, path in zip(screenshot_ids, screenshot_paths):
        path_by_id.setdefault(screenshot_id, path)

    # When captions line up one-to-one with the selected IDs, drop the caption of
    # any ID that cannot be resolved so captions stay parallel to the selected screenshots.
    captions_parallel = isinstance(captions, list) and len(captions) == len(selected_ids)
    selected_paths, kept_ids, kept_captions = [], [], []
    for position, selected_id in enumerate(selected_ids):
        key = str(selected_id).strip()
        path = path_by_id.get(key)
        if path is None:
            print(f"Warning: Selected ID {selected_id} not found in screenshots")
            continue
        selected_paths.append(path)
        kept_ids.append(key)
        if captions_parallel:
            kept_captions.append(captions[position])
    if captions_parallel:
        captions = kept_captions

    print(f"GPT-4o selected {len(selected_paths)} out of {len(screenshot_paths)} screenshots")

    return {
        "knowledge_point_index": knowledge_point_index,
        "knowledge_point": knowledge_point,
        "screenshot_groups": groups,
        "selected_screenshots": selected_paths,
        "selected_ids": kept_ids,
        "captions": captions
    }


def stage3_curate_screenshots(stage1_result, stage2_result, model_name="gpt-4o"):
    """Stage 3: have GPT-4o select and caption the best screenshots for each knowledge point.

    Every knowledge point gets an entry with the keys 'knowledge_point_index',
    'knowledge_point', 'screenshot_groups', 'selected_screenshots',
    'selected_ids' and 'captions'. A point with no screenshots, or whose API
    call or parse fails, gets empty lists. 'selected_ids' holds only IDs that
    resolved to a screenshot, in the same order as 'selected_screenshots'.

    Args:
        stage1_result: Stage 1 dict; must contain 'knowledge_points'.
        stage2_result: Stage 2 dict; must contain 'attempt_results'.
        model_name: OpenAI chat model to call.

    Returns:
        {'curated_knowledge_points': [...]} (also saved to
        ``<base_folder>/stage3_result.json``), or None if input data is missing.
    """
    print("=== Stage 3: Curating Screenshots with GPT-4o ===")

    if not stage1_result or 'knowledge_points' not in stage1_result:
        print("Error: Missing required data from Stage 1")
        return None

    if not stage2_result or 'attempt_results' not in stage2_result:
        print("Error: Missing required data from Stage 2")
        return None

    knowledge_points = stage1_result['knowledge_points']
    curated_results = []

    prompt_template = stage_3_prompt100.replace("{{summary_from_stage_1}}", json.dumps(stage1_result.get('Summary', '')))

    for knowledge_point_index, knowledge_point in enumerate(knowledge_points):
        print(f"\nProcessing knowledge point {knowledge_point_index+1}/{len(knowledge_points)}")
        point_text = str(knowledge_point)
        preview = point_text if len(point_text) <= 100 else f"{point_text[:100]}..."
        print(f"Knowledge point: {preview}")

        point_screenshots, point_screenshot_ids = _stage3_collect_screenshots(stage2_result, knowledge_point_index)

        if not point_screenshots:
            print(f"No screenshots found for knowledge point {knowledge_point_index+1}")
            curated_results.append(_stage3_empty_result(knowledge_point_index, knowledge_point))
            continue

        print(f"Found {len(point_screenshots)} screenshots for curation")
        if len(point_screenshots) > 20:
            print(f"Warning: Large number of screenshots ({len(point_screenshots)}). Processing may take time.")

        # One failing point (API error, I/O error) must not abort the whole stage.
        try:
            curated_results.append(_stage3_curate_point(
                knowledge_point_index, knowledge_point, prompt_template,
                point_screenshots, point_screenshot_ids, model_name
            ))
        except Exception as e:
            print(f"Error curating knowledge point {knowledge_point_index+1} with GPT-4o: {str(e)}")
            curated_results.append(_stage3_empty_result(knowledge_point_index, knowledge_point))

    stage3_result = {
        "curated_knowledge_points": curated_results,
    }

    stage3_result_path = os.path.join(base_folder, "stage3_result.json")
    with open(stage3_result_path, 'w') as f:
        json.dump(stage3_result, f, indent=2)

    total_selected = sum(len(item.get("selected_screenshots", [])) for item in curated_results)
    print(f"\nStage 3 completed with {total_selected} selected screenshots across {len(curated_results)} knowledge points")
    return stage3_result

# Run Stage 3
stage3_result = stage3_curate_screenshots(stage1_result, stage2_result)

# Show the first curated knowledge point that has at least one selected screenshot
if stage3_result and 'curated_knowledge_points' in stage3_result:
    for point in stage3_result['curated_knowledge_points']:
        selected = point.get('selected_screenshots')
        if not selected:
            continue

        point_text = str(point['knowledge_point'])
        preview = point_text if len(point_text) <= 100 else f"{point_text[:100]}..."
        print(f"\nSample Curated Point - Knowledge Point {point['knowledge_point_index']+1}:")
        print(f"Knowledge point: {preview}")
        print(f"Selected {len(selected)} screenshots")

        # Display the first screenshot even when no captions came back; previously
        # the "No caption" fallback was unreachable and nothing was shown.
        captions = point.get('captions') or []
        caption = captions[0] if isinstance(captions, list) and captions else "No caption"
        print(f"Caption: {caption}")
        display(Image(filename=selected[0], width=400))
        break

## Generate Final DOCX Document

In [None]:
# Generate final DOCX document

def generate_final_document(stage1_result, stage3_result):
    print("=== Generating Final DOCX Document ===")
    
    # Check if we have the required data
    if not stage1_result or 'Summary' not in stage1_result:
        print("Error: Missing required data from Stage 1")
        return None
    
    if not stage3_result or 'curated_knowledge_points' not in stage3_result:
        print("Error: Missing required data from Stage 3")
        return None
    
    curated_points = stage3_result['curated_knowledge_points']
    
    # Create document
    document = Document()
    
    # Add title
    title = document.add_heading('Training Document', 0)
    title.alignment = WD_ALIGN_PARAGRAPH.CENTER
    
    # Add Summary section
    document.add_heading('Summary', level=1)
    document.add_paragraph(stage1_result['Summary'])
    document.add_paragraph('')  # Add some space
    
    # Add knowledge points with screenshots
    document.add_heading('Knowledge Points', level=1)
    
    for point in curated_points:
        point_index = point['knowledge_point_index']
        knowledge_point = point['knowledge_point']
        
        # Add knowledge point as heading
        document.add_heading(f"{point_index+1}. {knowledge_point}", level=2)
        
        # Add screenshots with captions
        selected_screenshots = point.get('selected_screenshots', [])
