# D&D Session IC/OOC Classification Worker

This notebook runs on Google Colab to provide GPU-accelerated classification for the VideoChunking pipeline.

## Setup Instructions

1. **Open in Colab**: `File → Open in Colab` or upload to Google Drive
2. **Enable GPU**: `Runtime → Change runtime type → Hardware accelerator → GPU → T4`
3. **Mount Google Drive**: Run Cell 1
4. **Install Dependencies**: Run Cell 2
5. **Load Model**: Run Cell 3 (this may take a few minutes)
6. **Define Classification Functions**: Run Cell 4
7. **Start Worker**: Run Cell 5 - this will continuously process jobs

## How It Works

- Your local pipeline uploads classification jobs to `VideoChunking/classification_pending/` in Google Drive
- This notebook watches that folder and processes jobs using a local LLM
- Results are written to `VideoChunking/classification_complete/`
- Your local pipeline polls for results and continues

Keep this notebook running while processing sessions!

In [1]:
# Cell 1: Mount Google Drive
from google.colab import drive
import os

# Attach the user's Google Drive under /content/drive so the job folders
# shared with the local pipeline become visible to this runtime.
drive.mount('/content/drive')

# Job exchange folders: the local pipeline drops jobs into the first one and
# reads results from the second one.
pending_dir = '/content/drive/MyDrive/VideoChunking/classification_pending'
complete_dir = '/content/drive/MyDrive/VideoChunking/classification_complete'

# Create both folders up front; exist_ok keeps re-running this cell harmless.
for _job_dir in (pending_dir, complete_dir):
    os.makedirs(_job_dir, exist_ok=True)

print("‚úì Google Drive mounted")
print("‚úì Pending jobs:", pending_dir)
print("‚úì Completed jobs:", complete_dir)

ValueError: mount failed

In [None]:
# Cell 2: Install Dependencies
# Installs the packages imported by the model-loading cell (transformers, torch)
# plus accelerate and bitsandbytes; -q keeps pip's output short.
!pip install -q transformers torch accelerate bitsandbytes

print("‚úì Dependencies installed")

In [None]:
# Cell 3: Load LLM Model
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Model selection - Qwen2.5-3B-Instruct (fits entirely on free T4 GPU)
# Smaller but still excellent for classification tasks
MODEL_NAME = "Qwen/Qwen2.5-3B-Instruct"

print(f"Loading model: {MODEL_NAME}")
print("This may take 2-5 minutes on first run...")

# Configure 8-bit quantization to fit on free GPU.
# NOTE: BitsAndBytesConfig has no `bnb_8bit_compute_dtype` option (a compute
# dtype exists only for the 4-bit path, as `bnb_4bit_compute_dtype`), so the
# previously passed keyword had no effect and is intentionally not set here.
quantization_config = BitsAndBytesConfig(load_in_8bit=True)

# Load tokenizer and reuse the end-of-sequence token as the padding token.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

# Load model with 8-bit quantization; device_map="auto" lets the library
# decide where the weights are placed.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=quantization_config,
    device_map="auto",
    trust_remote_code=True
)

# Print memory usage
if torch.cuda.is_available():
    print(f"‚úì Model loaded on GPU: {model.device}")
    print(f"‚úì GPU memory: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
else:
    print(f"‚úì Model loaded on CPU: {model.device}")
    print(f"‚ö† Warning: Running on CPU will be slower. Consider enabling GPU in Runtime settings.")

In [None]:
# Cell 4: Classification Functions
import json
import re
from pathlib import Path
from typing import List, Dict

def classify_segment(
    prompt: str,
    max_length: int = 512,
    *,
    max_new_tokens: int = 150,
    temperature: float = 0.7,
    top_p: float = 0.9,
) -> str:
    """
    Classify a single segment using the loaded model.

    Relies on the module-level ``tokenizer`` and ``model`` created by the
    model-loading cell.

    Args:
        prompt: Classification prompt
        max_length: Maximum number of *prompt* tokens; longer prompts are
            truncated by the tokenizer. This does not limit the generated text.
        max_new_tokens: Maximum number of tokens to generate for the answer
        temperature: Sampling temperature. A value <= 0 switches to greedy
            (deterministic) decoding instead of sampling.
        top_p: Nucleus-sampling probability mass (only used when sampling)

    Returns:
        Model response text (generated continuation only, whitespace-stripped)
    """
    # Tokenize input and move the tensors to wherever the model lives
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=max_length
    ).to(model.device)

    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "pad_token_id": tokenizer.eos_token_id,
    }
    if temperature > 0:
        generation_kwargs.update(
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
        )
    else:
        # Sampling with a non-positive temperature is invalid; decode greedily.
        generation_kwargs["do_sample"] = False

    # Generate response (no gradients needed for inference)
    with torch.no_grad():
        outputs = model.generate(**inputs, **generation_kwargs)

    # The output sequence starts with the prompt tokens; decode only what
    # was generated after them.
    prompt_token_count = inputs['input_ids'].shape[1]
    response = tokenizer.decode(
        outputs[0][prompt_token_count:],
        skip_special_tokens=True
    )
    return response.strip()


def build_prompt(segment_data: Dict, job_data: Dict) -> str:
    """
    Fill the job's prompt template for one segment.

    The template receives the segment's own text plus the text of its direct
    neighbours in the job (empty string at either end of the list) and
    comma-separated character/player name lists ("Unknown" when a list is empty).

    Args:
        segment_data: Segment to classify; must provide 'index' and 'text'
        job_data: Job providing 'segments', 'character_names', 'player_names'
            and 'prompt_template'

    Returns:
        Formatted prompt string
    """
    all_segments = job_data['segments']
    position = segment_data['index']

    # Context window: one segment before and one after, where available.
    if position > 0:
        before = all_segments[position - 1]['text']
    else:
        before = ""

    current = segment_data['text']

    if position < len(all_segments) - 1:
        after = all_segments[position + 1]['text']
    else:
        after = ""

    characters = job_data['character_names']
    character_csv = ", ".join(characters) if characters else "Unknown"

    players = job_data['player_names']
    player_csv = ", ".join(players) if players else "Unknown"

    template_fields = {
        'char_list': character_csv,
        'player_list': player_csv,
        'prev_text': before,
        'current_text': current,
        'next_text': after,
    }
    return job_data['prompt_template'].format(**template_fields)


# Labels the classifier is allowed to return; anything else is treated as
# an unparsable answer and replaced by the default label.
_VALID_CLASSIFICATIONS = frozenset({"IC", "OOC", "MIXED"})


def parse_classification_response(response: str, index: int) -> Dict:
    """
    Parse model response into classification result.

    Expected format:
    Classificatie: IC|OOC|MIXED
    Reden: <reasoning>
    Vertrouwen: <0.0-1.0>
    Personage: <name or N/A>

    Every field is optional; a field that is missing or cannot be parsed keeps
    its default (classification "IC", confidence 0.7, reasoning
    "Could not parse response", character None).

    Args:
        response: Model response text
        index: Segment index

    Returns:
        Classification result dictionary with the keys 'segment_index',
        'classification', 'confidence', 'reasoning' and 'character'
    """
    # Default values
    classification = "IC"
    confidence = 0.7
    reasoning = "Could not parse response"
    character = None

    # Extract classification. \W* skips whitespace and decoration such as
    # markdown emphasis ("**IC**"); only known labels are accepted so an
    # arbitrary word from the model never leaks into the result.
    class_match = re.search(r'Classificatie:\W*(\w+)', response, re.IGNORECASE)
    if class_match:
        label = class_match.group(1).upper()
        if label in _VALID_CLASSIFICATIONS:
            classification = label

    # Extract reasoning: everything up to the next field or end of text
    reden_match = re.search(
        r'Reden:\s*(.+?)(?=(?:Vertrouwen:|Personage:|$))',
        response,
        re.DOTALL | re.IGNORECASE
    )
    if reden_match:
        reasoning = reden_match.group(1).strip()

    # Extract confidence. Accepts a decimal point or a (Dutch) decimal comma,
    # ignores trailing punctuation ("0.9."), and reads "85%" as 0.85.
    conf_match = re.search(
        r'Vertrouwen:[\s*]*(\d+(?:[.,]\d+)?|[.,]\d+)\s*(%)?',
        response,
        re.IGNORECASE
    )
    if conf_match:
        value = float(conf_match.group(1).replace(',', '.'))
        if conf_match.group(2):
            value /= 100.0
        confidence = max(0.0, min(1.0, value))  # Clamp to [0, 1]

    # Extract character (rest of the line); "N/A" means no character
    char_match = re.search(r'Personage:\s*(.+?)(?:\n|$)', response, re.IGNORECASE)
    if char_match:
        char_text = char_match.group(1).strip()
        if char_text.upper() != "N/A":
            character = char_text

    return {
        "segment_index": index,
        "classification": classification,
        "confidence": confidence,
        "reasoning": reasoning,
        "character": character
    }


def process_job(job_file: Path) -> None:
    """
    Classify every segment of one job file and write the result file.

    The job JSON must provide 'job_id' and 'segments' (plus whatever
    build_prompt reads). Results are written to
    ``<complete_dir>/<job_id>_result.json`` as a dict with the job id and one
    classification entry per segment, in segment order.

    Args:
        job_file: Path to job JSON file
    """
    divider = '=' * 60
    print(f"\n{divider}")
    print(f"Processing: {job_file.name}")

    # Read the job description
    with job_file.open('r', encoding='utf-8') as handle:
        job_data = json.load(handle)

    job_id = job_data['job_id']
    segments = job_data['segments']
    total = len(segments)

    print(f"Job ID: {job_id}")
    print(f"Segments to classify: {total}")

    # Run the model once per segment, in order
    classifications = []
    for done, segment in enumerate(segments, start=1):
        position = done - 1

        # build_prompt needs the segment's position to find its neighbours
        indexed_segment = dict(segment)
        indexed_segment['index'] = position

        reply = classify_segment(build_prompt(indexed_segment, job_data))
        classifications.append(parse_classification_response(reply, position))

        # Report every 10th segment and the final one
        if done % 10 == 0 or done == total:
            print(f"  Progress: {done}/{total} segments classified")

    # Persist the results next to the other completed jobs
    result_file = Path(complete_dir) / f"{job_id}_result.json"
    with result_file.open('w', encoding='utf-8') as handle:
        json.dump(
            {"job_id": job_id, "classifications": classifications},
            handle,
            indent=2,
            ensure_ascii=False,
        )

    print(f"‚úì Results written: {result_file.name}")
    print(f"{divider}\n")


# Confirmation that the cell ran and the classification helpers are defined.
print("‚úì Classification functions ready")

In [None]:
# Cell 5: Start Classification Worker
import time
from datetime import datetime

import traceback

print("üöÄ Starting classification worker...")
print("üìÅ Watching:", pending_dir)
print("üì§ Results to:", complete_dir)
print("\nPress Ctrl+C (or interrupt kernel) to stop\n")
print("="*60)

# Names of job files that were classified successfully. Kept so that a job
# whose file could not be deleted afterwards is not classified a second time.
processed_jobs = set()

# A job that always fails (for example malformed JSON or missing keys) would
# otherwise be retried on every poll forever. Failures are counted per file
# name and the job is skipped after MAX_JOB_ATTEMPTS; a few retries are kept
# because a failure can be transient (e.g. a file read before it is complete).
MAX_JOB_ATTEMPTS = 3
failed_attempts = {}
abandoned_jobs = set()

try:
    while True:
        # Find pending jobs
        pending_path = Path(pending_dir)
        job_files = list(pending_path.glob("job_*.json"))

        # Filter out jobs that are already done or were given up on
        new_jobs = [
            f for f in job_files
            if f.name not in processed_jobs and f.name not in abandoned_jobs
        ]

        if new_jobs:
            print(f"[{datetime.now().strftime('%H:%M:%S')}] Found {len(new_jobs)} new job(s)")

            for job_file in new_jobs:
                try:
                    process_job(job_file)
                    processed_jobs.add(job_file.name)

                    # Delete processed job file
                    job_file.unlink()

                except Exception as e:
                    # Top-level boundary: one bad job must not stop the worker
                    print(f"‚ùå Error processing {job_file.name}: {e}")
                    traceback.print_exc()

                    if job_file.name in processed_jobs:
                        # Classification succeeded; only the cleanup failed.
                        continue

                    attempts = failed_attempts.get(job_file.name, 0) + 1
                    failed_attempts[job_file.name] = attempts
                    if attempts >= MAX_JOB_ATTEMPTS:
                        abandoned_jobs.add(job_file.name)
                        print(
                            f"Giving up on {job_file.name} after {attempts} failed attempts; "
                            "it will be skipped until the worker is restarted"
                        )
        else:
            # No new jobs, wait
            print(f"[{datetime.now().strftime('%H:%M:%S')}] No pending jobs, waiting...")

        # Sleep before next check
        time.sleep(5)

except KeyboardInterrupt:
    print("\n\nüõë Worker stopped by user")
    print(f"Processed {len(processed_jobs)} jobs total")
    if abandoned_jobs:
        print(f"Abandoned {len(abandoned_jobs)} job(s) after repeated failures")