In [None]:
import json
import os
import random
import re
import shutil
import threading
from datetime import datetime

import cv2
import firebase_admin
import numpy as np
import pytesseract
import torch
from firebase_admin import credentials, firestore
from flask import Flask, request, render_template, redirect, url_for, flash, session
from transformers import BertTokenizer, BertForSequenceClassification, BertForTokenClassification, BertForQuestionAnswering
from transformers import pipeline

In [None]:
# Configure Tesseract OCR
# The pytesseract wrapper must already be installed in the kernel's environment,
# e.g. by running `%pip install pytesseract` once.

# Last-resort location: the default Windows install path this notebook was written against.
_DEFAULT_TESSERACT_CMD = r'C:\Program Files\Tesseract-OCR\tesseract.exe'

# Resolve the tesseract binary in order of preference so the notebook also runs
# on machines where Tesseract is installed elsewhere:
#   1. the TESSERACT_CMD environment variable (explicit override),
#   2. a `tesseract` executable found on PATH,
#   3. the Windows default above.
pytesseract.pytesseract.tesseract_cmd = (
    os.environ.get("TESSERACT_CMD")
    or shutil.which("tesseract")
    or _DEFAULT_TESSERACT_CMD
)


## Initialize Flask App

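In a Jupyter notebook environment we still define the Flask app here, but we don't call `app.run()` directly in this cell (it would block the kernel); the final cell starts the server in a background thread instead.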

In [None]:
# Initialize the Flask app
app = Flask(__name__)

# The secret key is needed for flash messages and sessions. Prefer the
# FLASK_SECRET_KEY environment variable so the real key is not committed with
# the notebook; the literal below is only a development fallback.
app.secret_key = os.environ.get("FLASK_SECRET_KEY", 'legal_education_platform_secret')

## Load Pre-trained Models

We'll load the NLP models from Hugging Face's transformers library.

In [None]:
# Load pre-trained models from Hugging Face
# `tokenizer` / `model` are the pair used by classify_text_with_bert().
# NOTE(review): 'bert-base-uncased' looks like a base checkpoint rather than one
# fine-tuned for document classification; if so, its classification head is not
# trained for DOCUMENT_TYPES and the predicted class is not meaningful — confirm
# and swap in a fine-tuned checkpoint if one exists.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
# Tokenizer/model pair for named-entity recognition (wrapped by `ner_pipe` below).
ner_tokenizer = BertTokenizer.from_pretrained('dbmdz/bert-large-cased-finetuned-conll03-english')
ner_model = BertForTokenClassification.from_pretrained('dbmdz/bert-large-cased-finetuned-conll03-english')
# Tokenizer/model pair for extractive question answering (wrapped by `qa_pipe` below).
qa_tokenizer = BertTokenizer.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
qa_model = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')

# Set up pipelines for NER and QA
# `ner_pipe` is called by extract_entities(); `qa_pipe` by answer_question().
ner_pipe = pipeline('ner', model=ner_model, tokenizer=ner_tokenizer)
qa_pipe = pipeline('question-answering', model=qa_model, tokenizer=qa_tokenizer)

## Firebase Initialization

Initialize Firebase for data storage.

In [None]:
# Firebase initialization
cred = credentials.Certificate("firebase_config.json")  # Add your Firebase service account key

# initialize_app() raises ValueError if the default app already exists, which
# happens whenever this cell is re-run in the same kernel. Only initialize when
# get_app() reports that no default app has been created yet.
try:
    firebase_admin.get_app()
except ValueError:
    firebase_admin.initialize_app(cred)

# Firestore client shared by every helper and route below.
db = firestore.client()

## Constants and Data Definitions

Define the patterns, document types, and achievements for our application.

In [None]:
# Enhanced legal clauses patterns
# Maps a clause name to a regex. Each pattern captures (group 1) the text from
# the clause keyword up to, but not including, the first newline or period that
# follows it. The routes apply them with re.search(..., re.IGNORECASE), so only
# the first occurrence of each keyword in a document is captured.
CLAUSES = {
    "termination": r"(termination.*?)(?:\n|\.)",
    "indemnity": r"(indemnity.*?)(?:\n|\.)",
    "governing law": r"(governing law.*?)(?:\n|\.)",
    "confidentiality": r"(confidentiality.*?)(?:\n|\.)",
    "intellectual property": r"(intellectual property.*?)(?:\n|\.)",
    "liability": r"(liability.*?)(?:\n|\.)",
    "force majeure": r"(force majeure.*?)(?:\n|\.)",
    "payment terms": r"(payment terms.*?)(?:\n|\.)",
    "arbitration": r"(arbitration.*?)(?:\n|\.)",
    "warranties": r"(warranties.*?)(?:\n|\.)"
}

# Define legal document types and their characteristics
# Keyed by the integer class index returned by classify_text_with_bert(), which
# reduces the model's prediction modulo len(DOCUMENT_TYPES) so it always lands
# on one of these keys.
DOCUMENT_TYPES = {
    0: "Employment Contract",
    1: "Non-Disclosure Agreement",
    2: "Service Agreement",
    3: "Purchase Agreement",
    4: "License Agreement",
    5: "Lease Agreement"
}

# Achievement definitions
# Keyed by achievement id; each entry carries the display name, a description
# and the XP granted when the achievement is earned.
ACHIEVEMENTS = {
    "first_scan": {"name": "Legal Novice", "description": "Scanned your first document", "xp": 50},
    "scan_milestone_5": {"name": "Legal Apprentice", "description": "Scanned 5 documents", "xp": 100},
    "scan_milestone_10": {"name": "Legal Expert", "description": "Scanned 10 documents", "xp": 200},
    "quiz_perfect": {"name": "Perfect Score", "description": "Got all answers correct in a quiz", "xp": 150},
    "unique_clauses_5": {"name": "Clause Hunter", "description": "Discovered 5 different legal clauses", "xp": 125},
    "all_doc_types": {"name": "Document Master", "description": "Analyzed all document types", "xp": 300},
}

## Core Functions

Now let's define all the core functions for our application.

In [None]:
def classify_text_with_bert(text):
    """Predict a document-type index for ``text`` with the BERT classifier.

    The text is tokenized (truncated to at most 512 tokens) and passed through
    the module-level ``model``. The index of the highest logit is reduced modulo
    ``len(DOCUMENT_TYPES)`` so the result is always a valid ``DOCUMENT_TYPES`` key.

    Args:
        text: The document text to classify.

    Returns:
        An int in ``range(len(DOCUMENT_TYPES))``.
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)

    # Inference only, so skip gradient tracking.
    with torch.no_grad():
        class_scores = model(**encoded).logits

    # Highest-scoring class, wrapped into the range of known document types.
    best_index = class_scores.argmax(dim=-1).item()
    return best_index % len(DOCUMENT_TYPES)

In [None]:
def extract_entities(text):
    """Extract high-confidence named entities from ``text``.

    Runs the module-level ``ner_pipe`` and merges word-piece continuation tokens
    (tokens whose ``word`` starts with ``##``) into the token they continue.
    The score of a merged entity is the mean score of its pieces.

    A continuation token is merged only when it directly follows the previous
    kept token (judged by the pipeline's ``index`` field when present).
    Otherwise it starts its own entity with its own label, instead of being
    glued onto an unrelated earlier entity or getting an empty label.

    Args:
        text: The text to analyse. Empty or whitespace-only text yields ``[]``
            without invoking the pipeline.

    Returns:
        A list of dicts with keys ``'entity'`` (merged surface text),
        ``'label'`` (the pipeline's tag for the first piece) and ``'score'``
        (mean piece score), restricted to entities whose score is above 0.8.
    """
    if not text or not text.strip():
        return []

    grouped = []           # finished groups, in document order
    current = None         # group currently being assembled
    previous_index = None  # token position of the last piece seen, if reported

    for token in ner_pipe(text):
        word = token['word']
        index = token.get('index')
        is_continuation = word.startswith('##')
        # When positions are unavailable, fall back to trusting the '##' marker.
        follows_previous = (
            previous_index is None or index is None or index == previous_index + 1
        )

        if is_continuation and current is not None and follows_previous:
            # Continue the previous entity, dropping the '##' prefix.
            current['entity'] += word[2:]
            current['score_sum'] += token['score']
            current['count'] += 1
        else:
            # Close the previous entity (if any) and start a new one.
            if current is not None:
                grouped.append(current)
            current = {
                'entity': word[2:] if is_continuation else word,
                'label': token['entity'],
                'score_sum': token['score'],
                'count': 1,
            }
        previous_index = index

    # Add the last entity
    if current is not None:
        grouped.append(current)

    named_entities = [
        {
            'entity': group['entity'],
            'label': group['label'],
            'score': group['score_sum'] / group['count'],  # Average score
        }
        for group in grouped
        if group['entity']
    ]

    # Filter entities with high confidence (score > 0.8)
    return [e for e in named_entities if e['score'] > 0.8]

In [None]:
def answer_question(context, question):
    """Answer ``question`` from ``context`` using the module-level ``qa_pipe``.

    Args:
        context: Text to search for the answer (e.g. OCR output).
        question: The natural-language question to ask.

    Returns:
        The answer string reported by the pipeline, or ``""`` when either the
        context or the question is empty or whitespace-only. The pipeline is
        not invoked in that case, so a document with no readable text does not
        raise from here.
    """
    if not context or not context.strip() or not question or not question.strip():
        return ""

    result = qa_pipe(question=question, context=context)
    return result['answer']

In [None]:
def generate_advanced_quiz(text, clauses, max_questions=5):
    """Generate a quiz from the document text and its extracted clauses.

    Three kinds of question are produced and then shuffled together:

    * one multiple-choice question per extracted clause (10 points),
    * up to two true/false questions about clause presence (5 points),
    * up to two short-answer questions answered by ``answer_question`` when the
      text is longer than 100 characters (15 points).

    Args:
        text: Full document text, used as context for short-answer questions.
        clauses: Mapping of clause name -> extracted clause text.
        max_questions: Maximum number of questions returned (default 5).

    Returns:
        A list of at most ``max_questions`` question dicts. Every dict has
        ``type``, ``question``, ``correct_answer``, ``difficulty`` and
        ``points``; multiple-choice questions also have ``options``.
    """
    questions = []

    # Multiple-choice question for each extracted clause.
    for key, clause in clauses.items():
        # Use everything after the FIRST colon as the clause content, so that
        # text containing further colons is not truncated. Fall back to the
        # whole clause when there is no colon or nothing follows it, so the
        # correct answer is never an empty string.
        _, separator, after_colon = clause.partition(":")
        after_colon = after_colon.strip()
        clause_content = after_colon if separator and after_colon else clause

        # The correct option plus three plausible-but-wrong distractors.
        options = [
            clause_content,
            f"The {key} clause allows for immediate termination without notice",
            f"The {key} clause requires written approval from all parties involved",
            f"The {key} clause limits liability to $10,000 for each occurrence",
        ]
        random.shuffle(options)

        questions.append({
            "type": "multiple_choice",
            "question": f"What does the '{key.title()}' clause specify in this document?",
            "options": options,
            "correct_answer": clause_content,
            "difficulty": "medium",
            "points": 10
        })

    # True/false questions for up to two randomly sampled clauses.
    for key, _clause in random.sample(list(clauses.items()), min(2, len(clauses))):
        questions.append({
            "type": "true_false",
            "question": f"This document contains a {key} clause.",
            "correct_answer": "True",
            "difficulty": "easy",
            "points": 5
        })

    # Short-answer questions from the QA model, only if there is enough text.
    if len(text) > 100:
        potential_questions = [
            "Who are the parties involved in this agreement?",
            "What is the effective date of this document?",
            "What happens if one party breaches this agreement?",
            "Is there a notice period specified in the document?"
        ]

        for question in random.sample(potential_questions, min(2, len(potential_questions))):
            try:
                answer = answer_question(text, question)
            except Exception as e:
                # Best effort: a failing QA call only costs us this one question.
                print(f"Error generating question: {e}")
                continue

            if len(answer) > 2:  # Ensure we got a meaningful answer
                questions.append({
                    "type": "short_answer",
                    "question": question,
                    "correct_answer": answer,
                    "difficulty": "hard",
                    "points": 15
                })

    # Shuffle, then cap the quiz length.
    random.shuffle(questions)
    return questions[:max_questions]

In [None]:
def check_and_award_achievements(user_id):
    """Return the achievements the user has newly earned, based on stored stats.

    Reads ``users/<user_id>`` from Firestore and compares its ``scan_count``,
    ``unique_clauses`` and ``doc_types`` against the milestone rules. This
    function does not write anything; the caller persists the result.

    An achievement counts as already held when the stored ``achievements`` list
    contains either its id (e.g. ``"unique_clauses_5"``) or its display name
    (e.g. ``"Clause Hunter"``). Both forms are accepted because stored lists may
    contain display names; checking only ids would re-award the threshold
    achievements on every scan.

    Args:
        user_id: Firestore document id of the user.

    Returns:
        A list of ``ACHIEVEMENTS`` entries (dicts with ``name``,
        ``description``, ``xp``); empty when nothing new was earned.
    """
    user_doc = db.collection("users").document(user_id).get()

    if not user_doc.exists:
        # New user, award first scan achievement
        return [ACHIEVEMENTS["first_scan"]]

    user_data = user_doc.to_dict()

    # Get user stats (Firestore stores the collections as lists).
    scan_count = user_data.get("scan_count", 0)
    held = set(user_data.get("achievements") or [])
    unique_clauses = set(user_data.get("unique_clauses") or [])
    doc_types = set(user_data.get("doc_types") or [])

    def not_yet_held(achievement_id):
        return (
            achievement_id not in held
            and ACHIEVEMENTS[achievement_id]["name"] not in held
        )

    # (achievement id, condition met?) in the order they should be reported.
    rules = [
        ("scan_milestone_5", scan_count == 5),
        ("scan_milestone_10", scan_count == 10),
        ("unique_clauses_5", len(unique_clauses) >= 5),
        ("all_doc_types", len(doc_types) == len(DOCUMENT_TYPES)),
    ]

    return [
        ACHIEVEMENTS[achievement_id]
        for achievement_id, condition_met in rules
        if condition_met and not_yet_held(achievement_id)
    ]

In [None]:
def save_to_firebase(text, quiz, clauses, doc_type, user_id="anonymous"):
    """Save a scan, update the user's stats and grant any earned achievements.

    Steps:
      1. Add a document to the ``scans`` collection.
      2. Create or update ``users/<user_id>`` (XP from the quiz points, scan
         count, unique clause names, document types seen).
      3. Ask ``check_and_award_achievements`` what was earned, add the
         achievement XP and record the achievement ids on the user.

    Achievements are stored by id (the ``ACHIEVEMENTS`` key), matching the
    ``"first_scan"`` id written when the user is created. A user's very first
    scan also reports the ``first_scan`` achievement and grants its XP: the
    user document already exists by the time achievements are checked, so the
    check alone would never return it.

    Args:
        text: Extracted document text.
        quiz: List of question dicts, each with a ``points`` value.
        clauses: Mapping of clause name -> clause text; only the names are stored.
        doc_type: Human-readable document type.
        user_id: Firestore user document id (defaults to ``"anonymous"``).

    Returns:
        The list of ``ACHIEVEMENTS`` entries earned by this scan (possibly empty).
    """
    quiz_xp = sum(q["points"] for q in quiz)
    clause_names = list(clauses.keys())

    # Save the scan
    db.collection("scans").add({
        "user_id": user_id,
        "text": text,
        "quiz": quiz,
        "clauses": clause_names,
        "doc_type": doc_type,
        "xp": quiz_xp,
        "timestamp": datetime.utcnow().isoformat()
    })

    # Update user stats
    user_ref = db.collection("users").document(user_id)
    user_doc = user_ref.get()
    is_new_user = not user_doc.exists

    if is_new_user:
        created_at = datetime.utcnow().isoformat()
        user_ref.set({
            "user_id": user_id,
            "xp": quiz_xp,
            "scan_count": 1,
            "unique_clauses": clause_names,
            "doc_types": [doc_type],
            "achievements": ["first_scan"],
            "created_at": created_at,
            "updated_at": created_at
        })
    else:
        user_data = user_doc.to_dict()

        # Merge this scan's clause names and document type into the running sets.
        unique_clauses = set(user_data.get("unique_clauses", []))
        unique_clauses.update(clause_names)
        doc_types = set(user_data.get("doc_types", []))
        doc_types.add(doc_type)

        user_ref.set({
            "xp": user_data.get("xp", 0) + quiz_xp,
            "scan_count": user_data.get("scan_count", 0) + 1,
            "unique_clauses": list(unique_clauses),
            "doc_types": list(doc_types),
            "achievements": user_data.get("achievements", []),
            "updated_at": datetime.utcnow().isoformat()
        }, merge=True)

    # Check and award achievements
    earned_achievements = list(check_and_award_achievements(user_id))
    first_scan = ACHIEVEMENTS["first_scan"]
    if is_new_user and first_scan not in earned_achievements:
        earned_achievements.insert(0, first_scan)

    # Update user with earned achievements and XP
    if earned_achievements:
        achievement_xp = sum(a["xp"] for a in earned_achievements)
        earned_ids = [_achievement_id(a) for a in earned_achievements]

        # Re-read so the XP written above is included in the total.
        latest = user_ref.get().to_dict()
        stored_ids = latest.get("achievements", [])

        user_ref.set({
            "xp": latest.get("xp", 0) + achievement_xp,
            # dict.fromkeys removes duplicates while keeping the stored order.
            "achievements": list(dict.fromkeys(stored_ids + earned_ids)),
        }, merge=True)

    return earned_achievements


def _achievement_id(achievement):
    """Return the ACHIEVEMENTS key for an achievement entry (its name if unknown)."""
    for achievement_id, definition in ACHIEVEMENTS.items():
        if definition is achievement or definition == achievement:
            return achievement_id
    return achievement["name"]

## Flask Routes

The routes below are registered on `app` as each cell runs. They start serving requests once the final cell launches the Flask server in a background thread.

In [None]:
@app.route('/')
def index():
    """Serve the landing page rendered from ``index.html``."""
    landing_page = render_template("index.html")
    return landing_page

In [None]:
@app.route("/scan", methods=["POST"])
def scan_document():
    """Handle an uploaded document image: OCR it, analyse it, and show results.

    Expects a multipart form with a ``document`` file and an optional
    ``user_id`` field (defaults to ``"anonymous"``). The image is OCR'd, then
    classified, scanned for clauses, turned into a quiz, run through NER and
    the QA model, and saved to Firebase.

    Returns:
        The rendered ``results.html`` page, or a redirect to the index page
        with a flashed error when the upload is empty or not a decodable image.
    """
    file = request.files["document"]
    user_id = request.form.get("user_id", "anonymous")

    # Read and process the image. imdecode() gives None for data it cannot
    # decode, and an empty upload has nothing to decode at all; in both cases
    # report the problem instead of failing inside the OCR call.
    raw_bytes = file.read()
    image = (
        cv2.imdecode(np.frombuffer(raw_bytes, np.uint8), cv2.IMREAD_COLOR)
        if raw_bytes else None
    )
    if image is None:
        flash("The uploaded file could not be read as an image.", "error")
        return redirect(url_for("index"))

    text = pytesseract.image_to_string(image)
    has_text = bool(text.strip())

    # Classify the extracted text using BERT
    classification_result = classify_text_with_bert(text)
    doc_type = DOCUMENT_TYPES[classification_result]

    # Extract clauses using regex (first match per clause type)
    extracted_clauses = {}
    for key, pattern in CLAUSES.items():
        match = re.search(pattern, text, re.IGNORECASE)
        if match:
            extracted_clauses[key] = match.group(1)

    # Generate enhanced quiz based on extracted clauses
    quiz = generate_advanced_quiz(text, extracted_clauses)

    # Extract entities using NER; nothing to extract when OCR found no text.
    entities = extract_entities(text) if has_text else []

    # Answer a sample question using the QA model. The QA pipeline rejects an
    # empty context, so skip it when OCR produced no text.
    sample_question = "What is the main purpose of this document?"
    answer = answer_question(text, sample_question) if has_text else ""

    # Save the scan and quiz data to Firebase and get earned achievements
    earned_achievements = save_to_firebase(text, quiz, extracted_clauses, doc_type, user_id)

    return render_template(
        "results.html",
        doc_type=doc_type,
        clauses=extracted_clauses,
        quiz=quiz,
        classification=classification_result,
        entities=entities,
        answer=answer,
        achievements=earned_achievements
    )

In [None]:
@app.route("/analyze_document", methods=["POST"])
def analyze_document():
    """Analyse an uploaded image or text document and show the results.

    Expects a multipart form with a ``document`` file and an optional
    ``user_id`` field (defaults to ``"anonymous"``). Files with an image
    extension are OCR'd; anything else is decoded as UTF-8 text. The text is
    classified, scanned for clauses, turned into a quiz, run through NER and
    saved to Firebase. The user's five most recent scans are also loaded.

    Returns:
        The rendered ``results.html`` page, or a redirect to the index page
        with a flashed error when an image upload is empty or cannot be decoded.
    """
    file = request.files["document"]
    user_id = request.form.get("user_id", "anonymous")

    # The upload may arrive without a filename; treat that as "not an image".
    filename = (file.filename or "").lower()
    raw_bytes = file.read()

    # Process the file - both image and text documents
    if filename.endswith(('.png', '.jpg', '.jpeg', '.tiff')):
        # Process image file with OCR. imdecode() gives None for data it cannot
        # decode, and an empty upload has nothing to decode at all.
        image = (
            cv2.imdecode(np.frombuffer(raw_bytes, np.uint8), cv2.IMREAD_COLOR)
            if raw_bytes else None
        )
        if image is None:
            flash("The uploaded file could not be read as an image.", "error")
            return redirect(url_for("index"))
        text = pytesseract.image_to_string(image)
    else:
        # Handle text-based documents (PDFs would need additional processing)
        text = raw_bytes.decode('utf-8', errors='ignore')

    # Always use auto-detection for document type
    classification_result = classify_text_with_bert(text)
    doc_type = DOCUMENT_TYPES[classification_result]

    # Extract clauses using regex (first match per clause type)
    extracted_clauses = {}
    for key, pattern in CLAUSES.items():
        match = re.search(pattern, text, re.IGNORECASE)
        if match:
            extracted_clauses[key] = match.group(1)

    # Generate enhanced quiz based on extracted clauses
    quiz = generate_advanced_quiz(text, extracted_clauses)

    # Extract entities using NER; nothing to extract from an empty document.
    entities = extract_entities(text) if text.strip() else []

    # Save the scan and quiz data to Firebase and get earned achievements
    earned_achievements = save_to_firebase(text, quiz, extracted_clauses, doc_type, user_id)

    # Get recent scans for this user to display (newest first, at most five)
    recent_scans = db.collection("scans").where("user_id", "==", user_id).order_by(
        "timestamp", direction=firestore.Query.DESCENDING
    ).limit(5).stream()

    recent_scan_data = []
    for scan in recent_scans:
        scan_data = scan.to_dict()
        scan_data["id"] = scan.id  # expose the document id alongside its fields
        recent_scan_data.append(scan_data)

    return render_template(
        "results.html",
        doc_type=doc_type,
        clauses=extracted_clauses,
        quiz=quiz,
        entities=entities,
        user_id=user_id,
        recent_scans=recent_scan_data,
        achievements=earned_achievements
    )

In [None]:
@app.route("/submit_quiz", methods=["POST"])
def submit_quiz():
    """Score a submitted quiz, record the result and show the outcome.

    The scan id and user id come from the query string (``scan_id``, ``user``);
    answers come from form fields ``q0``, ``q1``, ... in quiz order. The quiz
    is loaded from the stored scan so answers are checked against saved data.

    Answers are compared ignoring case and surrounding or repeated whitespace,
    so a typed short answer is not marked wrong over formatting alone.

    A perfect score requires a non-empty quiz. Otherwise an empty quiz
    (0 of 0 points) would count as perfect and grant the achievement.

    Returns:
        The rendered ``quiz_results.html`` page, or a redirect to the index
        page with a flashed message when the quiz cannot be found.
    """
    def normalize(value):
        # Collapse whitespace and ignore case for answer comparison.
        return " ".join(str(value).split()).casefold()

    scan_id = request.args.get("scan_id", "")
    user_id = request.args.get("user", "anonymous")

    # An empty id cannot name a stored scan, so skip the Firestore lookup.
    if not scan_id:
        flash("Quiz not found.")
        return redirect(url_for("index"))

    # Get the quiz questions from Firebase
    scan_doc = db.collection("scans").document(scan_id).get()
    if not scan_doc.exists:
        flash("Quiz not found.")
        return redirect(url_for("index"))

    quiz = scan_doc.to_dict().get("quiz", [])

    # Score the quiz
    score = 0
    total_possible = sum(q["points"] for q in quiz)
    answers = []

    for i, question in enumerate(quiz):
        user_answer = request.form.get(f"q{i}", "")
        correct = normalize(user_answer) == normalize(question["correct_answer"])
        if correct:
            score += question["points"]

        answers.append({
            "question": question["question"],
            "user_answer": user_answer,
            "correct_answer": question["correct_answer"],
            "correct": correct,
            "points": question["points"] if correct else 0
        })

    perfect_score = bool(quiz) and score == total_possible

    # Check for perfect score achievement (granted at most once per user)
    if perfect_score:
        user_ref = db.collection("users").document(user_id)
        user_doc = user_ref.get()

        if user_doc.exists:
            user_data = user_doc.to_dict()
            achievements = user_data.get("achievements", [])

            if "quiz_perfect" not in achievements:
                achievements.append("quiz_perfect")
                user_ref.update({
                    "achievements": achievements,
                    "xp": user_data.get("xp", 0) + ACHIEVEMENTS["quiz_perfect"]["xp"]
                })

    # Save quiz results
    db.collection("quiz_results").add({
        "user_id": user_id,
        "scan_id": scan_id,
        "score": score,
        "total_possible": total_possible,
        "answers": answers,
        "timestamp": datetime.utcnow().isoformat()
    })

    return render_template(
        "quiz_results.html",
        score=score,
        total_possible=total_possible,
        answers=answers,
        perfect_score=perfect_score
    )

In [None]:
@app.route("/leaderboard")
def leaderboard():
    """Show the top ten users by XP and, optionally, the requester's rank.

    When a ``user`` query parameter is present, every user document is loaded
    in descending XP order and the 1-based position of the first document whose
    id or ``user_id`` field equals that value is reported; otherwise the rank
    is ``None``.
    """
    def summarize(snapshot):
        # Build one leaderboard row from a user document.
        data = snapshot.to_dict()
        total_xp = data.get("xp", 0)
        return {
            "user_id": data.get("user_id", snapshot.id),
            "xp": total_xp,
            "level": total_xp // 100,
            "scan_count": data.get("scan_count", 0),
            "achievement_count": len(data.get("achievements", [])),
        }

    # Top users by XP
    top_ten_query = db.collection("users").order_by(
        "xp", direction=firestore.Query.DESCENDING
    ).limit(10)
    users = [summarize(snapshot) for snapshot in top_ten_query.stream()]

    # Rank lookup for the requesting user, if one was named
    current_user = request.args.get("user")
    current_user_rank = None

    if current_user:
        everyone = list(
            db.collection("users").order_by("xp", direction=firestore.Query.DESCENDING).stream()
        )
        current_user_rank = next(
            (
                position
                for position, snapshot in enumerate(everyone, start=1)
                if snapshot.id == current_user
                or snapshot.to_dict().get("user_id") == current_user
            ),
            None,
        )

    return render_template(
        "leaderboard.html",
        users=users,
        current_user=current_user,
        current_user_rank=current_user_rank
    )

In [None]:
@app.route('/take_quiz')
def take_quiz():
    """Render the quiz stored with a scan.

    Reads ``scan_id`` and ``user`` from the query string, loads the scan from
    the ``scans`` collection and passes its ``quiz`` list to the template. A
    missing scan, or any error while loading it, flashes a message and
    redirects to the index page.
    """
    scan_id = request.args.get('scan_id', '')
    user_id = request.args.get('user', '')

    try:
        snapshot = db.collection('scans').document(scan_id).get()

        # Guard clause: nothing to render when the scan does not exist.
        if not snapshot.exists:
            flash("Scan not found", "error")
            return redirect(url_for('index'))

        quiz_data = snapshot.to_dict().get('quiz', [])
    except Exception as e:
        # Any lookup failure is reported to the user rather than raised.
        flash(f"Error retrieving quiz: {str(e)}", "error")
        return redirect(url_for('index'))

    # NOTE(review): 'qiuz.html' looks like a misspelling of 'quiz.html' —
    # confirm against the actual file name in the templates folder.
    return render_template('qiuz.html', scan_id=scan_id, user_id=user_id, quiz=quiz_data)

In [None]:
@app.route("/dashboard")
def dashboard():
    """Show a user's XP, level and progress towards the next level.

    The user id comes from the ``user`` query parameter (default
    ``"anonymous"``). A level is 100 XP: ``level = xp // 100`` and
    ``progress = xp % 100``. An unknown user, or a user document without an
    ``xp`` field, is shown with 0 XP instead of raising.
    """
    user_id = request.args.get("user", "anonymous")
    # to_dict() yields None for a missing document; treat that as "no stats".
    user_data = db.collection("users").document(user_id).get().to_dict() or {}

    xp = user_data.get("xp", 0)
    level = xp // 100
    progress = xp % 100

    return render_template("dashboard.html", xp=xp, level=level, progress=progress)

In [None]:
# Run Flask in a separate thread so the notebook kernel stays responsive.
def run_flask():
    """Serve the Flask app; blocks until the server stops."""
    app.run(debug=False, use_reloader=False)  # Disable reloader and debug for notebook use


# Re-running this cell must not start a second server on the same port, so
# reuse the thread from a previous run if it is still alive.
_previous_thread = globals().get("flask_thread")

if _previous_thread is not None and _previous_thread.is_alive():
    print("Flask server is already running in the background at http://127.0.0.1:5000")
else:
    # Daemon thread: it does not keep the kernel process alive on shutdown.
    flask_thread = threading.Thread(target=run_flask)
    flask_thread.daemon = True
    flask_thread.start()

    print("Flask server is running in the background. Access it at http://127.0.0.1:5000")