# Retrieve ToS;DR data

## Load libraries

In [1]:
import os
import json
import asyncio
import re
import random
import hashlib
from typing import List, Dict
from google import genai
from google.genai import types
from dotenv import load_dotenv
from pathlib import Path

from rich.console import Console
from rich.progress import (
    Progress, SpinnerColumn, BarColumn, TextColumn, 
    TimeRemainingColumn, MofNCompleteColumn
)
from rich.panel import Panel
from rich.rule import Rule

## Global variables

In [2]:
# Repository root, expressed relative to the notebook's working directory.
ROOT = Path('../..')
DATA_DIR = ROOT / "data" / "TOSDR"
# JSONL file of scraped documents; read as the input of main_dataset_builder().
MARKDOWN_OUTPUT = DATA_DIR / "tosdr_markdowns_en.jsonl"
# NOTE(review): not referenced by any code visible in this notebook — confirm it is still needed.
BATCH_INPUT_FILE = DATA_DIR / "tosdr_batch_input.jsonl"
# JSONL file the generated summaries are appended to (and re-read on resume).
HIGHLIGHTS_OUTPUT = DATA_DIR / "tosdr_summaries.jsonl"
ENV_FILE = ROOT / ".env"
# Model identifier passed to every generate_content call.
model_name = "gemini-2.0-flash-lite"

# Limit for this specific execution
MAX_TO_PROCESS = 1000
# Size of the asyncio.Semaphore guarding model calls.
CONCURRENCY_LIMIT = 5

console = Console()
# Load the .env file before reading the API key from the environment below.
load_dotenv(ENV_FILE)
client = genai.Client(api_key=os.getenv("GOOGLE_AI_API_KEY"))

In [3]:
# System instruction sent with every model request. The bullet format it
# demands ("- [LABEL]: SHORT TITLE: explanation") is what
# parse_generative_output() parses, so the two must be kept in sync.
SYSTEM_PROMPT = """You are an expert legal analyst for ToS;dr (Terms of Service; Didn't Read). 
Your mission is to analyze legal documents and classify clauses according to ToS;dr standards.

CLASSIFICATION LABELS (ToS;dr Standard):
- [GOOD]: Positive for user rights (e.g., clear refund policy, strong privacy protection, logs deleted quickly).
- [NEUTRAL]: Standard or balanced (e.g., age restrictions, standard governing law, reasonable liability caps).
- [BAD]: Negative for the user (e.g., tracking for ads, waiver of moral rights, no refunds, binding arbitration).
- [BLOCKER]: CRITICAL/DANGEROUS. The "scariest" clauses. (e.g., we sell your personal data, we can read your private messages, broad copyright license on your content).

OUTPUT FORMAT RULES:
1. Provide a bulleted list where each line follows this EXACT pattern:
   - [LABEL]: SHORT TITLE: Detailed explanation of why this is important.
2. Do NOT output JSON. Output plain text.
3. If a document is empty or irrelevant, return "NO_DATA".

Example Output:
- [BLOCKER]: Sale of Data: The service explicitly states they sell your personal data to third parties.
- [BAD]: Binding Arbitration: You waive your right to sue in court or join a class action.
- [GOOD]: Data Portability: You can download all your data in a standard format at any time.
"""

## Utility functions

In [19]:
# One bullet of the form "- [LABEL]: Short title: explanation".
# Whitespace is restricted to spaces/tabs and the title excludes newlines so a
# single match can never straddle two lines (which would swallow the next
# bullet into the title of a malformed one). The explanation must start with a
# non-blank character on the same line.
_POINT_PATTERN = re.compile(
    r"^[ \t]*-[ \t]*\[(BAD|GOOD|NEUTRAL|BLOCKER)\][ \t]*:[ \t]*([^:\n]+):[ \t]*(\S.*)$",
    re.MULTILINE | re.IGNORECASE,
)


def parse_generative_output(text: str) -> List[Dict]:
    """Extract the classified points from the model's plain-text answer.

    Each line shaped like ``- [LABEL]: Title: Explanation`` (LABEL being one of
    BAD, GOOD, NEUTRAL or BLOCKER, in any letter case, optionally indented)
    becomes one dict. Lines that do not follow the pattern are ignored, so a
    ``NO_DATA`` answer or an empty string yields an empty list.

    Args:
        text: Raw text returned by the model.

    Returns:
        A list of ``{"label", "title", "explanation"}`` dicts in order of
        appearance; ``label`` is always upper-case, the other two values are
        stripped of surrounding whitespace.
    """
    if not text:
        return []

    return [
        {
            # IGNORECASE lets "[bad]" through; normalise it for consumers.
            "label": match.group(1).upper(),
            "title": match.group(2).strip(),
            "explanation": match.group(3).strip(),
        }
        for match in _POINT_PATTERN.finditer(text)
    ]

async def extract_summaries(item, semaphore, max_chars=100000):
    """Ask the model to classify one document and return the parsed points.

    The request goes through the SDK's asynchronous client (``client.aio``) so
    that waiting for the response yields to the event loop; a synchronous call
    here would block the loop and make the semaphore pointless.

    Args:
        item: Mapping providing at least ``"markdown"`` (the document text) and
            ``"service_name"`` (used only in the error message).
        semaphore: ``asyncio.Semaphore`` limiting concurrent model calls.
        max_chars: Maximum number of characters of the document sent to the
            model; the text is truncated beyond this length.

    Returns:
        The list produced by ``parse_generative_output``; an empty list when
        the model returns no text, no parsable point, or when any error occurs
        (the error is printed, not raised).
    """
    async with semaphore:
        try:
            document = item["markdown"][:max_chars]
            response = await client.aio.models.generate_content(
                model=model_name,
                config=types.GenerateContentConfig(
                    system_instruction=SYSTEM_PROMPT,
                    temperature=0.1,
                ),
                contents=f"Analyze this document:\n\n{document}",
            )
            # response.text may be empty/None; treat that as "nothing found".
            text_out = response.text.strip() if response.text else ""
            return parse_generative_output(text_out)
        except Exception as e:
            # Best-effort: one failing document must not abort the whole run.
            console.print(f"[red]Error on {item['service_name']}: {e}[/red]")
            return []

def _load_processed_ids(path):
    """Return the set of ``service_id`` values already present in the JSONL file at *path*."""
    seen = set()
    if not path.exists():
        return seen
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            try:
                seen.add(json.loads(line).get("service_id"))
            except (json.JSONDecodeError, AttributeError):
                # Malformed or non-object line: skip it.
                continue
    return seen


def _load_pending_documents(path, processed_ids):
    """Return successful records from *path* not yet processed, each tagged with a ``content_hash``."""
    pending = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            try:
                d = json.loads(line)
                if d.get("status") == "success" and d["service_id"] not in processed_ids:
                    d["content_hash"] = hashlib.md5(d["markdown"].encode('utf-8')).hexdigest()
                    pending.append(d)
            except (json.JSONDecodeError, KeyError, AttributeError, TypeError):
                # Malformed line or record missing an expected field: skip it.
                continue
    return pending


async def main_dataset_builder():
    """Summarise every unprocessed document and append the results to HIGHLIGHTS_OUTPUT.

    Reads MARKDOWN_OUTPUT, skips records whose ``service_id`` already appears
    in HIGHLIGHTS_OUTPUT (so the function can be re-run to resume), processes
    at most MAX_TO_PROCESS records, and writes one JSON line per record.
    Documents with identical text (same MD5 of the markdown) reuse the first
    non-empty result obtained during this run instead of calling the model
    again.

    Returns:
        None. Progress and status messages are printed to the console.
    """
    if not MARKDOWN_OUTPUT.exists():
        console.print("[red]Source file missing.[/red]")
        return

    processed_ids = _load_processed_ids(HIGHLIGHTS_OUTPUT)
    all_data = _load_pending_documents(MARKDOWN_OUTPUT, processed_ids)

    if not all_data:
        console.print("[bold green]✔ Tout est déjà traité ![/bold green]")
        return

    # Cap the amount of work done by this execution.
    if MAX_TO_PROCESS:
        all_data = all_data[:MAX_TO_PROCESS]

    # Documents are awaited one at a time below, so this semaphore is never
    # contended; it is still required by extract_summaries' signature.
    semaphore = asyncio.Semaphore(CONCURRENCY_LIMIT)
    # content hash -> points, valid for this run only.
    content_cache = {}

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        MofNCompleteColumn(),
        TimeRemainingColumn(),
        console=console,
    ) as progress:

        main_task = progress.add_task("[cyan]Analyse en cours...", total=len(all_data))

        with open(HIGHLIGHTS_OUTPUT, "a", encoding="utf-8") as f_out:
            for item in all_data:
                h = item["content_hash"]

                if h in content_cache:
                    progress.console.print(f"[yellow]⚡ Doublon de contenu détecté pour {item['service_name']} (utilisation du cache)[/yellow]")
                    summaries = content_cache[h]
                else:
                    progress.update(main_task, description=f"[cyan]Analyse {item['service_name']}...")
                    summaries = await extract_summaries(item, semaphore)
                    # Only cache real results so an identical document gets
                    # another chance after an empty answer.
                    if summaries:
                        content_cache[h] = summaries

                # NOTE(review): a record is written even when `summaries` is
                # empty, which marks the service as processed on the next run;
                # documents that failed with an API error are therefore not
                # retried — confirm this is intended.
                output_payload = {
                    "service_id": item["service_id"],
                    "service_name": item["service_name"],
                    "doc_name": item.get("doc_name", ""),
                    "url": item["url"],
                    "points": summaries
                }

                f_out.write(json.dumps(output_payload, ensure_ascii=False) + "\n")
                # Flush per record so results reach the file as soon as they exist.
                f_out.flush()
                progress.advance(main_task)

    console.print("\n[bold green]Terminé ![/bold green]")

def visualize_results(num_samples=3):
    """Print a random sample of summarised services as coloured panels.

    Reads HIGHLIGHTS_OUTPUT, keeps the records that have at least one point,
    picks up to *num_samples* of them at random and prints every point in a
    Rich panel, most severe label first (BLOCKER, BAD, NEUTRAL, GOOD, then
    anything else).

    Args:
        num_samples: Maximum number of services to display.

    Returns:
        None. Output goes to the console.
    """
    # Local import: escapes "[...]" in model-generated or scraped text so Rich
    # does not interpret it as markup (which would drop the text or raise
    # MarkupError on a stray closing tag).
    from rich.markup import escape

    console.print(Rule("[bold magenta]Visualisation ToS;dr[/bold magenta]"))

    if not HIGHLIGHTS_OUTPUT.exists():
        console.print("[red]Pas de fichier trouvé.[/red]")
        return

    data = []
    with open(HIGHLIGHTS_OUTPUT, "r", encoding="utf-8") as f:
        for line in f:
            try:
                data.append(json.loads(line))
            except json.JSONDecodeError:
                # Skip malformed lines.
                continue

    data = [d for d in data if d.get("points")]
    if not data:
        return

    selected = random.sample(data, min(num_samples, len(data)))

    # Display order and (text colour, icon, border colour) per label.
    weight_map = {'BLOCKER': 0, 'BAD': 1, 'NEUTRAL': 2, 'GOOD': 3}
    label_styles = {
        "BLOCKER": ("white on red", "⛔", "red"),
        "BAD": ("red", "❌", "red"),
        "GOOD": ("green", "✅", "green"),
    }
    # NEUTRAL and any unexpected label share the same look.
    fallback_style = ("yellow", "ℹ️", "yellow")

    for item in selected:
        service_name = escape(str(item['service_name']))
        doc_name = escape(str(item['doc_name']))
        console.print(f"\n[bold underline white on blue] SERVICE: {service_name} [/bold underline white on blue] [italic]({doc_name})[/italic]")

        points = sorted(item['points'], key=lambda x: weight_map.get(x['label'], 99))

        for p in points:
            label = p['label']
            color, icon, border = label_styles.get(label, fallback_style)

            title = escape(str(p['title']))
            expl = escape(str(p['explanation']))
            p_content = f"[{color}][bold]{title}[/bold][/{color}]\n[white]{expl}[/white]"

            console.print(Panel(
                p_content,
                title=f"{icon} {escape(str(label))}",
                title_align="left",
                border_style=border,
                width=100
            ))

In [None]:
# Run the extraction pipeline. A bare top-level `await` is only valid in an
# environment that supports it, such as an IPython/Jupyter cell.
await main_dataset_builder()

Output()

In [None]:
# Display a random sample of the generated summaries (default: 3 services).
visualize_results()

## Train/validation split

In [6]:
import json
import hashlib
from pathlib import Path
from sklearn.model_selection import train_test_split

# Paths used by the train/test split below (ROOT and DATA_DIR are redefined
# here with the same values as in the earlier cell).
ROOT = Path('../..')
DATA_DIR = ROOT / "data" / "TOSDR"
# Scraped policy texts (one JSON object per line).
MARKDOWN_SOURCE = DATA_DIR / "tosdr_markdowns_en.jsonl" 
# Generated summaries (one JSON object per line).
SUMMARIES_SOURCE = DATA_DIR / "tosdr_summaries.jsonl"  
# Destination of train.jsonl / test.jsonl; created on the spot if missing.
OUTPUT_DIR = ROOT / "data" / "EULAI"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

def generate_unique_id(text):
    """Return the hexadecimal MD5 digest of *text* encoded as UTF-8."""
    digest = hashlib.md5()
    digest.update(text.encode("utf-8"))
    return digest.hexdigest()

def process_dataset():
    """Pair policies with their summaries, deduplicate, and write train/test JSONL files.

    Steps:
      1. Load the markdown of every ``status == "success"`` record of
         MARKDOWN_SOURCE, keyed by ``service_id``.
      2. For every record of SUMMARIES_SOURCE that has points and a known
         policy text, rebuild the bullet-list summary and keep the first
         record seen for each distinct policy text (MD5 of the text).
      3. Split the result with ``train_test_split`` (5% test when there are
         more than 5 samples, otherwise exactly 1 test sample) and write
         ``train.jsonl`` and ``test.jsonl`` into OUTPUT_DIR.

    Returns:
        None. Status messages are printed; nothing is written when an input
        file is missing or no valid sample is found.
    """
    if not MARKDOWN_SOURCE.exists():
        print(f"Erreur: {MARKDOWN_SOURCE} introuvable.")
        return

    # NOTE(review): policies are keyed by service_id only; if a service can
    # have several documents in the source file, later ones overwrite earlier
    # ones and summaries may be paired with the wrong text — verify the schema.
    markdowns = {}
    with open(MARKDOWN_SOURCE, "r", encoding="utf-8") as f:
        for line in f:
            try:
                data = json.loads(line)
                if data.get("status") == "success":
                    markdowns[data["service_id"]] = data.get("markdown", "")
            except (json.JSONDecodeError, KeyError, AttributeError, TypeError):
                # Malformed line or record missing an expected field: skip it.
                continue

    if not SUMMARIES_SOURCE.exists():
        print(f"Erreur: {SUMMARIES_SOURCE} introuvable.")
        return

    # policy hash -> dataset entry; the first summary seen for a text wins.
    unique_policies = {}
    with open(SUMMARIES_SOURCE, "r", encoding="utf-8") as f:
        for line in f:
            try:
                item = json.loads(line)
                s_id = item["service_id"]

                policy_text = markdowns.get(s_id, "")

                if not policy_text or not item.get("points"):
                    continue

                # Same bullet format as the one requested from the model.
                summary_text = "\n".join(
                    f"- [{p['label']}]: {p['title']}: {p['explanation']}"
                    for p in item["points"]
                )

                policy_hash = generate_unique_id(policy_text)

                if policy_hash not in unique_policies:
                    unique_policies[policy_hash] = {
                        "id": policy_hash,
                        "service_id": s_id,
                        "service_name": item["service_name"],
                        "url": item["url"],
                        "policy": policy_text,
                        "summary": summary_text
                    }
            except (json.JSONDecodeError, KeyError, AttributeError, TypeError):
                # Malformed line or record missing an expected field: skip it.
                continue

    dataset = list(unique_policies.values())
    n_samples = len(dataset)
    print(f"Échantillons uniques trouvés : {n_samples}")

    if n_samples == 0:
        print("Erreur : Aucun échantillon valide trouvé. Vérifie les IDs dans tes fichiers.")
        return

    if n_samples < 2:
        # train_test_split cannot carve a test sample out of a single record
        # (it raises ValueError); keep the lone sample for training.
        train_data, test_data = dataset, []
    else:
        # A float is a proportion, an int an absolute number of test samples.
        test_size = 0.05 if n_samples > 5 else 1
        train_data, test_data = train_test_split(dataset, test_size=test_size, random_state=42)

    for name, data in [("train", train_data), ("test", test_data)]:
        output_path = OUTPUT_DIR / f"{name}.jsonl"
        with open(output_path, "w", encoding="utf-8") as f:
            for entry in data:
                f.write(json.dumps(entry, ensure_ascii=False) + "\n")
        print(f"Sauvegardé : {output_path} ({len(data)} lignes)")

# Build the deduplicated dataset and write train.jsonl / test.jsonl.
process_dataset()

Échantillons uniques trouvés : 2956
Sauvegardé : ../../data/EULAI/train.jsonl (2808 lignes)
Sauvegardé : ../../data/EULAI/test.jsonl (148 lignes)
