# LLM Service for Lecture AI

This notebook implements the LLM service endpoints required by the Flask backend:
- `/process` - Merges OCR + transcript and generates structured notes
- `/chat` - Answers questions about lectures

In [None]:
# ==============================
# INSTALL PACKAGES
# ==============================

# Install PyTorch with CUDA
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

# Install other libraries
!pip install -q fastapi uvicorn pyngrok nest_asyncio transformers

In [None]:
# ==============================
# IMPORTS
# ==============================
import json
import os
import re
import threading
import time
from typing import Any, Dict, List, Optional

import nest_asyncio
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from pyngrok import ngrok
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Patch asyncio so an event loop can be started while the notebook's own loop is running.
# NOTE(review): uvicorn is launched in a separate thread further down (threading.Thread
# with target=run), which gets its own loop, so this may not be strictly required — confirm.
nest_asyncio.apply()

In [None]:
# ==============================
# LOAD MODEL
# ==============================
print("Loading model...")

# Seq2seq checkpoint used by every endpoint below (via generate_summary).
# NOTE(review): this appears to be the pretrained LongT5 checkpoint rather than one
# fine-tuned for summarization/QA — confirm output quality is acceptable.
model_name = "google/long-t5-tglobal-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Probe for CUDA once; move the weights to the GPU only when one is visible.
gpu_available = torch.cuda.is_available()
if gpu_available:
    model = model.cuda()
print("Using GPU" if gpu_available else "Using CPU")

print("Model loaded successfully")

In [None]:
# ==============================
# HELPER FUNCTIONS
# ==============================

def generate_summary(
    text: str,
    max_length: int = 200,
    min_length: int = 50,
    max_input_length: int = 2048,
) -> str:
    """Generate text from ``text`` with the module-level ``model``/``tokenizer``.

    Args:
        text: Input text (a document to condense, or a full prompt).
        max_length: Maximum number of tokens in the generated output.
        min_length: Minimum number of tokens in the generated output. It is
            clamped to ``max_length`` so ``generate`` never receives min > max.
        max_input_length: Number of input tokens kept. Longer inputs are
            truncated by the tokenizer (from the right with the default
            ``truncation_side``), so content at the END of ``text`` is what
            gets dropped.

    Returns:
        The decoded output of greedy decoding, with special tokens removed.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=max_input_length,
    )
    # Put the input tensors on whatever device the model was actually moved to,
    # rather than re-probing CUDA independently of the model's placement.
    inputs = {name: tensor.to(model.device) for name, tensor in inputs.items()}

    outputs = model.generate(
        **inputs,
        max_length=max_length,
        min_length=min(min_length, max_length),
        do_sample=False,  # deterministic (greedy / beam) decoding
    )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)

def extract_key_points(
    text: str, max_points: int = 5, min_chars: int = 20
) -> List[str]:
    """Return up to ``max_points`` sentences from ``text`` as naive key points.

    Sentences end at a run of ``.``, ``!`` or ``?`` that is followed by
    whitespace or the end of the text. Decimals ("3.14") and other dots inside
    a token therefore no longer split a sentence. The terminating punctuation
    is dropped and surrounding whitespace stripped. Only sentences longer than
    ``min_chars`` characters are kept, in their original order.

    This is a positional heuristic (the first qualifying sentences), not real
    key-point extraction.

    Args:
        text: Source text; ``None`` or empty yields ``[]``.
        max_points: Maximum number of sentences returned.
        min_chars: Sentences with ``len <= min_chars`` are discarded.

    Returns:
        The selected sentences.
    """
    if not text:
        return []
    sentences = (s.strip() for s in re.split(r"[.!?]+(?=\s|$)", text))
    return [s for s in sentences if len(s) > min_chars][:max_points]

def _payload_to_text(data: Any, keys: tuple) -> str:
    """Render one OCR/transcript payload as plain text.

    For a dict, return the first truthy value among ``keys``. If none is
    truthy, fall back to the whole dict as JSON; non-ASCII is kept as-is so
    accented or non-Latin text reaches the model unescaped. Any other value
    is converted with ``str``.
    """
    if isinstance(data, dict):
        for key in keys:
            value = data.get(key)
            if value:
                return str(value)
        return json.dumps(data, ensure_ascii=False)
    return str(data)


def merge_ocr_transcript(ocr_data: Dict, transcript_data: Dict) -> str:
    """Merge OCR and transcript payloads into a single labelled text.

    Each non-empty payload contributes one section:
    ``"Board/Slide Content:\\n<ocr>"`` and/or ``"Transcript:\\n<transcript>"``.
    Sections are joined by a blank line. OCR text comes from its ``text`` key;
    transcript text comes from ``text`` or else ``transcript``. Either falls
    back to a JSON dump of the dict. Falsy payloads (``None``, ``{}``) are
    skipped, so two empty inputs yield ``""``.
    """
    sections = []

    if ocr_data:
        sections.append(f"Board/Slide Content:\n{_payload_to_text(ocr_data, ('text',))}")

    if transcript_data:
        transcript_text = _payload_to_text(transcript_data, ('text', 'transcript'))
        sections.append(f"Transcript:\n{transcript_text}")

    return "\n\n".join(sections)

In [None]:
# ==============================
# FASTAPI APP & MODELS
# ==============================
# FastAPI application; the route handlers defined below are registered on it
# and it is what uvicorn serves.
app = FastAPI()

class ProcessRequest(BaseModel):
    """Request body for ``POST /process``."""

    # Identifier of the processing job. NOTE(review): not read by process_lecture in this file.
    job_id: str
    # OCR result for the board/slides; merge_ocr_transcript uses its "text" key if
    # truthy, otherwise the JSON dump of the whole dict.
    ocr_output: Dict[str, Any]
    # Speech transcript; merge_ocr_transcript uses its "text" or "transcript" key if
    # truthy, otherwise the JSON dump of the whole dict.
    transcript: Dict[str, Any]

class ChatRequest(BaseModel):
    """Request body for ``POST /chat``."""

    # Identifier of the lecture. NOTE(review): not read by chat() in this file.
    lecture_id: str
    # User question inserted into the prompt after the context.
    question: str
    # Lecture context; chat() reads its "summary", "notes" and "transcript" keys.
    context: Dict[str, Any]
    # Prior conversation turns. Pydantic copies mutable defaults per instance, so the
    # [] default is not shared between requests.
    # NOTE(review): currently ignored by chat() — confirm whether it should feed the prompt.
    history: List[Dict[str, str]] = []

In [None]:
# ==============================
# API ENDPOINTS
# ==============================

@app.get("/")
def home():
    """Liveness endpoint: report that the service is up."""
    status = {"message": "Lecture AI LLM Service Running"}
    return status

@app.post("/process")
def process_lecture(request: ProcessRequest):
    """
    Process OCR and transcript data to generate structured notes.
    Expected by Flask backend: POST /process

    Returns a dict with:
        summary: model-generated summary of the merged text.
        key_points: sentences picked by ``extract_key_points``.
        notes: the raw ``ocr_output`` and ``transcript`` plus ``merged_text``,
            a preview cut to 1000 characters with "..." appended when longer.

    Raises:
        HTTPException(400): both ``ocr_output`` and ``transcript`` are empty,
            so there is nothing to summarise.
        HTTPException(500): any unexpected failure while processing.
    """
    try:
        merged_text = merge_ocr_transcript(request.ocr_output, request.transcript)
        if not merged_text.strip():
            raise HTTPException(
                status_code=400,
                detail="Both ocr_output and transcript are empty",
            )

        summary = generate_summary(merged_text, max_length=300, min_length=100)
        key_points = extract_key_points(merged_text)

        preview = (
            merged_text[:1000] + "..." if len(merged_text) > 1000 else merged_text
        )

        return {
            "summary": summary,
            "key_points": key_points,
            "notes": {
                "ocr_content": request.ocr_output,
                "transcript_content": request.transcript,
                "merged_text": preview,
            },
        }
    except HTTPException:
        # Deliberate client errors must keep their status code, not become 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}") from e

def _build_chat_context(context: Dict[str, Any], max_chars: int = 500) -> str:
    """Flatten the lecture ``context`` into the text placed before the question.

    ``notes`` and ``transcript`` are each capped at ``max_chars`` characters.
    The prompt ends with the question and the tokenizer truncates from the
    right, so an uncapped context would push the question out of the model
    input. Dict values are rendered as before. Plain strings, previously
    ignored, are now included as well.
    """
    parts = [f"Summary: {context.get('summary', '')}\n\n"]

    notes = context.get('notes', {})
    if isinstance(notes, dict):
        parts.append(f"Notes: {json.dumps(notes, ensure_ascii=False)[:max_chars]}\n\n")
    elif isinstance(notes, str) and notes:
        parts.append(f"Notes: {notes[:max_chars]}\n\n")

    transcript = context.get('transcript', {})
    if isinstance(transcript, dict):
        # Parenthesise so the cap applies to the "text" value too, not only the
        # JSON fallback (the original `a or b[:500]` left "text" unbounded).
        text = transcript.get('text', '') or json.dumps(transcript, ensure_ascii=False)
        parts.append(f"Transcript: {str(text)[:max_chars]}\n\n")
    elif isinstance(transcript, str) and transcript:
        parts.append(f"Transcript: {transcript[:max_chars]}\n\n")

    return "".join(parts)


@app.post("/chat")
def chat(request: ChatRequest):
    """
    Answer questions about a lecture using context.
    Expected by Flask backend: POST /chat

    The prompt is the flattened context followed by the question and an
    "Answer:" cue, passed to ``generate_summary``.
    NOTE(review): ``request.history`` is not used — confirm whether prior turns
    should be included in the prompt.

    Returns:
        ``{"answer": <generated text>}``

    Raises:
        HTTPException(400): the question is empty or whitespace only.
        HTTPException(500): any unexpected failure while answering.
    """
    if not request.question.strip():
        raise HTTPException(status_code=400, detail="Question must not be empty")

    try:
        context_text = _build_chat_context(request.context)
        prompt = f"{context_text}\n\nQuestion: {request.question}\n\nAnswer:"

        answer = generate_summary(prompt, max_length=200, min_length=30)

        return {"answer": answer}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Chat error: {str(e)}") from e

In [None]:
# ==============================
# START SERVER
# ==============================
def run():
    """Serve ``app`` with uvicorn on all interfaces, port 8000 (blocks until stopped)."""
    uvicorn.run(app, port=8000, host="0.0.0.0")

# Run uvicorn in a daemon thread so the notebook cell returns and the server
# stops with the kernel.
thread = threading.Thread(target=run, daemon=True)
thread.start()

# Give uvicorn a moment to bind port 8000 before the tunnel is opened.
time.sleep(5)

# ==============================
# START NGROK
# ==============================
# The auth token is read from the NGROK_AUTHTOKEN environment variable when set,
# so it need not be pasted into the notebook; otherwise replace the placeholder.
ngrok.set_auth_token(os.environ.get("NGROK_AUTHTOKEN", "YOUR_NGROK_AUTH_TOKEN_HERE"))
public_url = ngrok.connect(8000)

# pyngrok >= 5 returns an NgrokTunnel whose str() looks like
# 'NgrokTunnel: "https://..." -> "http://localhost:8000"'. Extract the bare URL
# so the .env line below is directly usable. Older pyngrok returned a str.
service_url = getattr(public_url, "public_url", public_url)

print("\n🚀 LLM Service LIVE at:")
print(service_url)
print("\n📋 Update your Flask .env file:")
print(f"LLM_SERVICE_URL={service_url}")