# Lovli - Index Laws to Qdrant (Colab A100 GPU)

This notebook indexes all Norwegian laws and regulations into Qdrant Cloud using an A100 GPU for fast embedding generation.

**Requirements:**
- Colab A100 GPU runtime (Runtime > Change runtime type > A100)
- `lovli-data.tar.bz2` in your Google Drive (root folder)
- Qdrant Cloud URL and API key

**Estimated time:** ~20-30 minutes for ~4,000 files

## 1. Setup

In [None]:
# Runtime environment tweaks for Hugging Face libraries.
import os

os.environ.update(HF_HUB_DISABLE_TELEMETRY="1")  # opt out of HF Hub telemetry

# Optional: an HF token reduces rate-limit warnings from the Hub.
# os.environ["HF_TOKEN"] = "your_token_here"

In [None]:
# Install runtime dependencies for parsing (bs4), embedding and talking to Qdrant.
!pip install -q sentence-transformers qdrant-client beautifulsoup4
# Clone repo into Colab runtime (fresh session)
%cd /content
# NOTE: removes any previous /content/lovli checkout, including data extracted into it by an earlier run.
!rm -rf lovli
!git clone https://github.com/AndreasRamsli/lovli.git
# Later cells rely on this working directory for relative paths such as data/nl.
%cd /content/lovli

# Install project package so `from lovli.parser import ...` works
# NOTE(review): a package installed with `pip install -e` into a running kernel may need a
# runtime restart before it is importable — confirm if the import below fails.
!pip install -q -e .

# Optional: verify parser source
import lovli.parser as lp
print(lp.__file__)

In [None]:
# Report the accelerator this runtime got; batch sizes below assume an A100.
import torch

has_cuda = torch.cuda.is_available()
print(f"CUDA available: {has_cuda}")
if not has_cuda:
    print("WARNING: No GPU detected. Go to Runtime > Change runtime type > A100 GPU")
else:
    gpu_name = torch.cuda.get_device_name(0)
    vram_gib = torch.cuda.get_device_properties(0).total_memory / (1024**3)
    print(f"GPU: {gpu_name} ({vram_gib:.1f} GB VRAM)")
    if "A100" not in gpu_name:
        print("  Note: Optimized for A100; other GPUs may need smaller batch sizes")

## 2. Configuration

Fill in your Qdrant Cloud credentials:

In [None]:
import os

# --- FILL THESE IN ---
QDRANT_URL = "https://acc5c492-7d2c-4b95-b0c5-2931ff2ecebd.eu-west-1-0.aws.cloud.qdrant.io"
QDRANT_API_KEY = ""  # Paste your Qdrant API key here, set the QDRANT_API_KEY env var, or use getpass below
# ---------------------

# Resolve the API key without requiring it to be pasted into the notebook:
# explicit value above -> QDRANT_API_KEY environment variable -> interactive prompt.
if not QDRANT_API_KEY:
    QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", "")
if not QDRANT_API_KEY:
    import getpass
    QDRANT_API_KEY = getpass.getpass("Qdrant API key: ")
# Strip accidental surrounding whitespace from a pasted key.
QDRANT_API_KEY = QDRANT_API_KEY.strip()

COLLECTION_NAME = "lovli_laws"
EMBEDDING_MODEL = "BAAI/bge-m3"
EMBEDDING_DIMENSION = 1024  # must match the model's output size (collection vector size)
EMBEDDING_BATCH_SIZE = 256  # A100 80GB can handle large batches
INDEX_BATCH_SIZE = 500      # Upsert batch size to Qdrant

# Network/retry tuning for Qdrant Cloud
QDRANT_TIMEOUT_SECONDS = 120
UPSERT_MAX_RETRIES = 5      # total upsert attempts per batch
UPSERT_BACKOFF_SECONDS = 2  # delay before the first retry; doubles on each further retry

# Explicit check rather than `assert`, which is stripped when Python runs with -O.
if not QDRANT_API_KEY:
    raise ValueError("Please set QDRANT_API_KEY above")

## 3. Data

Mount Google Drive and extract `lovli-data.tar.bz2` directly from Drive (no copy step).

In [None]:
# Expose Google Drive under /content/drive so the data archive can be read in place.
from google.colab import drive

DRIVE_MOUNT_POINT = "/content/drive"
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
# Source directories: laws ("nl") and regulations ("sf"). They are anchored to the
# repo checkout rather than the kernel's working directory, so they still resolve if
# the cwd is no longer /content/lovli (e.g. after a runtime restart). Otherwise
# index_files would skip both directories after the collection was already recreated.
REPO_DIR = "/content/lovli"
DATA_DIRS = [f"{REPO_DIR}/data/nl", f"{REPO_DIR}/data/sf"]

In [None]:
# Extract directly into the cloned repo (skip macOS ._ resource fork files).
# Reads the archive straight from the mounted Drive (MyDrive root) without copying it first.
!tar -xjf /content/drive/MyDrive/lovli-data.tar.bz2 -C /content/lovli --exclude='._*'
# Sanity check: print the number of XML files in data/nl, then in data/sf.
!ls /content/lovli/data/nl/*.xml 2>/dev/null | wc -l && ls /content/lovli/data/sf/*.xml 2>/dev/null | wc -l

## 4. Parser (local copy of lovli/parser.py; the next cell replaces it with the packaged version)

In [None]:
import hashlib
import logging
import re
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Iterator

from bs4 import BeautifulSoup, Tag

# Root URL used to build links back to each article on lovdata.no.
LOVDATA_BASE_URL = "https://lovdata.no"

# Notebook-wide logging: INFO and above, timestamped.
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)


@dataclass(slots=True)
class LegalArticle:
    """One article extracted from a Lovdata XML file by the local parser in this cell.

    NOTE: this local version has no ``source_anchor_id`` / ``doc_type`` fields, but
    the indexing code reads them. That is why a later cell re-imports
    ``LegalArticle`` and ``parse_xml_file`` from ``lovli.parser``.
    """

    article_id: str  # the <article> element's id, or "<law_id>_art_<idx>" if it has none
    title: str  # text of the article's <h3>, or "Untitled Article"
    content: str  # article text, pieces joined with newlines
    law_id: str  # source file stem, e.g. "nl-19990326-017"
    law_title: str  # document title, or "Unknown Law"
    law_short_name: str | None = None  # from <dd class="titleShort">, text before the first dash separator
    chapter_id: str | None = None  # "kapittel-N" section id when the document has chapter sections
    chapter_title: str | None = None  # chapter <h2> text with a leading "Kapittel N." label removed
    cross_references: list[str] = field(default_factory=list)  # outbound "lov/..."/"forskrift/..." hrefs, fragment stripped
    url: str | None = None  # "<LOVDATA_BASE_URL>/<law_ref>#<article_id>"


def _extract_law_ref_from_filename(filename: str) -> str:
    """Map a file stem like ``nl-19990326-017`` to a Lovdata path like ``lov/1999-03-26-017``.

    The first dash-separated piece selects the document kind: ``sf`` becomes
    ``forskrift``, anything else becomes ``lov``. The second piece must be an
    8-character ``YYYYMMDD`` date and the third is the number. Extra pieces are
    ignored. Stems that do not fit this shape are returned unchanged.

    NOTE(review): the number is kept verbatim, zero padding included. Confirm that
    this matches the hrefs used inside the documents.
    """
    pieces = filename.split("-")
    if len(pieces) < 3 or len(pieces[1]) != 8:
        return filename
    kind, yyyymmdd, number = pieces[0], pieces[1], pieces[2]
    doc_kind = "forskrift" if kind == "sf" else "lov"
    iso_date = f"{yyyymmdd[:4]}-{yyyymmdd[4:6]}-{yyyymmdd[6:]}"
    return f"{doc_kind}/{iso_date}-{number}"


def _extract_short_name(soup: BeautifulSoup) -> str | None:
    """Return the law's short name from ``<dd class="titleShort">``, or None.

    If the text contains a spaced dash separator, only the part before its first
    occurrence is kept. The separators are checked in priority order: en dash,
    then em dash, then hyphen. Empty text yields None.
    """
    node = soup.find("dd", class_="titleShort")
    if not node:
        return None
    raw = node.get_text(strip=True)
    separator = next((s for s in (" \u2013 ", " \u2014 ", " - ") if s in raw), None)
    if separator is not None:
        return raw.partition(separator)[0].strip()
    return raw.strip() or None


def _extract_cross_references(article_element: "Tag", self_law_ref: str) -> list[str]:
    """Collect unique outbound law/regulation references from an article.

    Only hrefs starting with ``lov/`` or ``forskrift/`` are considered. The
    ``#fragment`` is removed, so links to different anchors of the same target
    collapse into one entry. Links back into the article's own document
    (``self_law_ref``) are skipped. Order of first appearance is preserved.

    Args:
        article_element: element whose ``<a href>`` descendants are scanned.
        self_law_ref: this document's path, e.g. ``lov/1999-03-26-17``.
            If empty, no links are treated as self-references.

    Returns:
        De-duplicated list of referenced paths.
    """
    refs: list[str] = []
    seen: set[str] = set()
    for a_tag in article_element.find_all("a", href=True):
        href = a_tag.get("href", "")
        if not href or not href.startswith(("lov/", "forskrift/")):
            continue
        base_href = href.split("#", 1)[0]
        # Compare on a path-segment boundary. A bare prefix test would also treat
        # e.g. "lov/1999-03-26-170" as a self-reference of "lov/1999-03-26-17".
        if self_law_ref and (
            base_href == self_law_ref or base_href.startswith(self_law_ref + "/")
        ):
            continue
        if base_href not in seen:
            seen.add(base_href)
            refs.append(base_href)
    return refs


def _extract_article(article, idx, law_id, law_ref, law_title_text, law_short_name, chapter_id, chapter_title, xml_path):
    """Convert one ``<article>`` element into a ``LegalArticle``, or return None to skip it.

    An article is skipped when its id contains "-ledd-" or "-punkt-". These are
    presumably nested paragraph/point anchors; confirm against the source data.
    It is also skipped when it has no text. Any exception is logged and turned
    into None, so one malformed article does not abort the whole file.
    """
    try:
        anchor = article.get("id") or f"{law_id}_art_{idx}"
        is_sub_anchor = any(marker in anchor for marker in ("-ledd-", "-punkt-"))
        if is_sub_anchor:
            return None

        heading = article.find("h3")
        heading_text = heading.get_text(strip=True) if heading else "Untitled Article"
        body = article.get_text(separator="\n", strip=True)
        if not body.strip():
            return None

        return LegalArticle(
            article_id=anchor,
            title=heading_text,
            content=body,
            law_id=law_id,
            law_title=law_title_text,
            law_short_name=law_short_name,
            chapter_id=chapter_id,
            chapter_title=chapter_title,
            cross_references=_extract_cross_references(article, law_ref),
            url=f"{LOVDATA_BASE_URL}/{law_ref}#{anchor}",
        )
    except Exception as exc:
        logger.error(f"Error processing article {idx} in {xml_path}: {exc}")
        return None


_CHAPTER_SECTION_ID_RE = re.compile(r"^kapittel-\d+[a-zA-Z]?$")
_CHAPTER_PREFIX_RE = re.compile(r"^Kapittel\s+\d+[A-Za-z]?\.\s*")


def _strip_chapter_prefix(raw_title):
    """Drop a leading "Kapittel N." label from a chapter heading.

    Falsy titles and titles without the label are returned unchanged. If nothing
    would remain after stripping, the original title is kept.
    """
    if not raw_title:
        return raw_title
    label = _CHAPTER_PREFIX_RE.match(raw_title)
    if label is None:
        return raw_title
    return raw_title[label.end():].strip() or raw_title


def _iter_container_articles(container, law_id, law_ref, law_title, law_short_name, chapter_id, chapter_title, xml_path):
    """Yield each LegalArticle extractable from ``container``'s ``<article>`` descendants."""
    for position, node in enumerate(container.find_all("article")):
        parsed = _extract_article(node, position, law_id, law_ref, law_title, law_short_name, chapter_id, chapter_title, xml_path)
        if parsed is not None:
            yield parsed


def parse_xml_file(xml_path: Path) -> Iterator[LegalArticle]:
    """Lazily yield the articles of one Lovdata XML file.

    If the document has ``<section id="kapittel-N">`` chapters, articles are
    collected per chapter and tagged with that chapter's id and cleaned title.
    Otherwise every ``<article>`` in the document is used without chapter info.
    Empty or whitespace-only files yield nothing.

    NOTE(review): when chapters exist, ``<article>`` elements outside every
    matching section are not yielded. Confirm this is intended.
    """
    with open(xml_path, "r", encoding="utf-8") as handle:
        raw_xml = handle.read()
    if not raw_xml.strip():
        return

    soup = BeautifulSoup(raw_xml, "html.parser")
    law_id = xml_path.stem
    law_ref = _extract_law_ref_from_filename(law_id)
    title_node = soup.find("dd", class_="title") or soup.find("title")
    law_title = title_node.get_text(strip=True) if title_node else "Unknown Law"
    law_short_name = _extract_short_name(soup)
    shared = (law_id, law_ref, law_title, law_short_name)

    chapters = soup.find_all("section", id=_CHAPTER_SECTION_ID_RE)
    if not chapters:
        yield from _iter_container_articles(soup, *shared, None, None, xml_path)
        return

    for chapter in chapters:
        heading = chapter.find("h2")
        heading_text = heading.get_text(strip=True) if heading else None
        yield from _iter_container_articles(
            chapter, *shared, chapter.get("id", ""), _strip_chapter_prefix(heading_text), xml_path
        )


print("Parser loaded.")

In [None]:
# Replace the local parser above with the packaged implementation. The indexed
# metadata (including source_anchor_id / doc_type) then comes from lovli/parser.py only.
# parse_law_header is imported here but not used by any cell shown in this notebook.
from lovli.parser import (
    LegalArticle,
    parse_law_header,
    parse_xml_file,
)

print("Parser override loaded from lovli.parser (single source of truth).")

In [None]:
# Pre-index parser sanity check: confirm the production parser is in use and that it
# emits the metadata fields the indexer writes to Qdrant.
from pathlib import Path

import lovli.parser as lp

print(f"Parser module path: {lp.__file__}")

sample = Path("data/nl/nl-19990326-017.xml")
if not sample.exists():
    print("WARNING: sample file data/nl/nl-19990326-017.xml not found; skipping parser sanity parse.")
else:
    sample_articles = list(parse_xml_file(sample))
    assert sample_articles, "Sample parse returned no articles"
    first = sample_articles[0]
    for required in ("source_anchor_id", "doc_type"):
        assert hasattr(first, required), f"{required} missing; parser import is wrong"
    assert first.doc_type in {"provision", "editorial_note"}, "Unexpected doc_type value"
    summary = {
        "sample_articles": len(sample_articles),
        "first_article_id": first.article_id,
        "first_source_anchor_id": first.source_anchor_id,
        "first_doc_type": first.doc_type,
    }
    print("Parser sanity OK:", summary)

## 5. Initialize Qdrant + Embedding Model

In [None]:
import torch
from sentence_transformers import SentenceTransformer
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct

# Connect to Qdrant Cloud
client = QdrantClient(
    url=QDRANT_URL,
    api_key=QDRANT_API_KEY,
    timeout=QDRANT_TIMEOUT_SECONDS,
)
print(f"Connected to Qdrant. Collections: {[c.name for c in client.get_collections().collections]}")

# Load the embedding model on the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading {EMBEDDING_MODEL} on {device}...")
model = SentenceTransformer(EMBEDDING_MODEL, device=device)
# FP16 for ~2x throughput on A100; on CPU the model is left at its default precision.
if device == "cuda":
    model = model.half()
precision = "FP16" if device == "cuda" else "FP32"
model_dim = model.get_sentence_embedding_dimension()
print(f"Model loaded on {device} ({precision}). Embedding dim: {model_dim}")

# Fail here, before the next cell deletes and recreates the collection, if the
# model's vectors would not match the configured collection vector size.
if model_dim != EMBEDDING_DIMENSION:
    raise ValueError(
        f"Model {EMBEDDING_MODEL} produces {model_dim}-dim embeddings, "
        f"but EMBEDDING_DIMENSION is {EMBEDDING_DIMENSION}"
    )

## 6. Create Collection

In [None]:
# Rebuild the collection from scratch.
# WARNING: this permanently deletes every point currently stored in COLLECTION_NAME.
if client.collection_exists(COLLECTION_NAME):
    print(f"Deleting existing collection: {COLLECTION_NAME}")
    client.delete_collection(COLLECTION_NAME)

dense_config = VectorParams(size=EMBEDDING_DIMENSION, distance=Distance.COSINE)
client.create_collection(
    collection_name=COLLECTION_NAME,
    vectors_config=dense_config,
    on_disk_payload=True,
)
print(f"Collection '{COLLECTION_NAME}' created (dense-only, payloads on disk)")

## 7. Index All Files

In [None]:
import json


def generate_id(law_id: str, source_anchor_id: str | None, article_id: str) -> int:
    """Stable 63-bit point ID based on law + stable source identity."""
    stable_source_id = source_anchor_id or article_id
    key = f"{law_id}::{stable_source_id}"
    hash_bytes = hashlib.sha256(key.encode("utf-8")).digest()
    return int.from_bytes(hash_bytes[:8], byteorder="big", signed=False) % (2**63)


def _is_oom_error(exc: BaseException) -> bool:
    """Return True if ``exc``'s message looks like an out-of-memory error."""
    return "out of memory" in str(exc).lower()


def encode_texts_with_fallback(texts, batch_size):
    """Embed ``texts`` with the global ``model``, degrading gracefully on OOM.

    Attempts, in order:
      1. ``batch_size`` on the current device;
      2. a batch of at most 16, after clearing the CUDA cache;
      3. (CUDA only) the model converted to FP32 on the CPU with batch size 32.

    Step 3 moves the shared model in place. It is always moved back to ``device``
    in FP16 afterwards, even if the CPU attempt raises, so one failing file cannot
    leave every later file encoding on the CPU.

    Returns:
        Whatever ``model.encode`` returns (normalized embeddings).

    Raises:
        RuntimeError: immediately for non-OOM errors, or the last OOM when no
            fallback remains (e.g. on CPU).
    """
    encode_kwargs = {"show_progress_bar": False, "normalize_embeddings": True}

    try:
        return model.encode(texts, batch_size=batch_size, **encode_kwargs)
    except RuntimeError as e:
        if not _is_oom_error(e):
            raise

    if device == "cuda":
        torch.cuda.empty_cache()
    try:
        # Never retry with a *larger* batch than the caller asked for.
        return model.encode(texts, batch_size=min(16, batch_size), **encode_kwargs)
    except RuntimeError as e:
        if not _is_oom_error(e) or device != "cuda":
            raise

    torch.cuda.empty_cache()
    cpu_model = model.float().to("cpu")
    try:
        return cpu_model.encode(texts, batch_size=32, **encode_kwargs)
    finally:
        model.to(device).half()


def upsert_with_retry(points, upsert_batch_size=INDEX_BATCH_SIZE):
    """Upsert ``points`` into ``COLLECTION_NAME`` in chunks, retrying each failing chunk.

    Each chunk holds at most ``upsert_batch_size`` points and is sent with
    ``wait=True``. A failing chunk is retried with exponential backoff
    (``UPSERT_BACKOFF_SECONDS * 2**(attempt - 1)`` seconds). There are up to
    ``UPSERT_MAX_RETRIES`` attempts in total, and always at least one. Chunks
    written before a failure stay written.

    Raises:
        ValueError: if ``upsert_batch_size`` is not positive.
        Exception: whatever ``client.upsert`` raised on the final attempt.
    """
    if upsert_batch_size <= 0:
        raise ValueError(f"upsert_batch_size must be positive, got {upsert_batch_size}")
    # A plain range(1, UPSERT_MAX_RETRIES + 1) is empty for values <= 0, which would
    # silently skip writing every chunk; always make at least one attempt.
    max_attempts = max(1, UPSERT_MAX_RETRIES)

    for start in range(0, len(points), upsert_batch_size):
        chunk = points[start:start + upsert_batch_size]
        for attempt in range(1, max_attempts + 1):
            try:
                client.upsert(
                    collection_name=COLLECTION_NAME,
                    points=chunk,
                    wait=True,
                )
            except Exception as exc:  # transport and server errors are not one exception type
                if attempt == max_attempts:
                    raise
                delay = UPSERT_BACKOFF_SECONDS * (2 ** (attempt - 1))
                logger.warning(
                    "Upsert retry %d/%d failed: %s. Sleeping %ss...",
                    attempt, max_attempts, exc, delay,
                )
                time.sleep(delay)
            else:
                break


def _article_payload(art) -> dict:
    """Build the Qdrant payload (page text plus metadata) stored with each point."""
    return {
        "page_content": art.content,
        "metadata": {
            "article_id": art.article_id,
            "title": art.title,
            "law_id": art.law_id,
            "law_title": art.law_title,
            "law_short_name": art.law_short_name,
            "chapter_id": art.chapter_id,
            "chapter_title": art.chapter_title,
            "source_anchor_id": art.source_anchor_id,
            "doc_type": art.doc_type,
            "cross_references": art.cross_references or [],
            "url": art.url,
        },
    }


def index_single_file(file_path: Path, embedding_batch_size=EMBEDDING_BATCH_SIZE, upsert_batch_size=INDEX_BATCH_SIZE):
    """Parse, embed, and upsert one XML file.

    macOS "._" resource-fork files are skipped. Articles that map to the same
    point id (``generate_id``) are de-duplicated, keeping the first occurrence.

    Returns:
        ``(article_count, None)`` on success, where ``article_count`` is the number
        of points upserted (0 for skipped or empty files). On failure, returns
        ``(0, "<ExceptionType>: <message>")``. Errors are returned rather than
        raised so one bad file does not stop a run.
    """
    if file_path.name.startswith("._"):
        return 0, None

    try:
        # Keyed by point id: insertion order is preserved and the first article wins.
        # Each id is hashed exactly once.
        articles_by_id = {}
        for art in parse_xml_file(file_path):
            point_id = generate_id(art.law_id, art.source_anchor_id, art.article_id)
            articles_by_id.setdefault(point_id, art)
        if not articles_by_id:
            return 0, None

        # Generate embeddings with OOM fallback
        articles = list(articles_by_id.values())
        embeddings = encode_texts_with_fallback([a.content for a in articles], embedding_batch_size)

        # strict=True turns an embedding/article count mismatch into an error
        # instead of silently dropping articles.
        points = [
            PointStruct(id=point_id, vector=vector.tolist(), payload=_article_payload(art))
            for (point_id, art), vector in zip(articles_by_id.items(), embeddings, strict=True)
        ]

        upsert_with_retry(points, upsert_batch_size=upsert_batch_size)
        return len(points), None

    except Exception as e:
        # Keep the exception type: some messages (e.g. a KeyError's) are meaningless alone.
        return 0, f"{type(e).__name__}: {e}"

    finally:
        if device == "cuda":
            torch.cuda.empty_cache()


def index_files(data_dir: str):
    """Index every ``*.xml`` file directly inside ``data_dir``, in sorted order.

    Each file goes through ``index_single_file``. Failures are logged and
    collected rather than stopping the run. Progress is printed every 100 files
    and after the last one.

    Returns:
        ``(files_done, total_articles, failed)``, where ``failed`` is a list of
        ``{"path": ..., "error": ...}`` dicts. A missing directory yields ``(0, 0, [])``.
    """
    root = Path(data_dir)
    if not root.is_dir():
        print(f"Skipping {data_dir} (not found)")
        return 0, 0, []

    xml_files = sorted(root.glob("*.xml"))
    n_files = len(xml_files)
    print(f"\nIndexing {n_files} files from {data_dir}")

    ok_files, article_total, failures = 0, 0, []
    t0 = time.time()

    for seen, path in enumerate(xml_files, start=1):
        count, error = index_single_file(path)
        if not error:
            ok_files += 1
            article_total += count
        else:
            logger.error(f"Failed {path.name}: {error}")
            failures.append({"path": str(path), "error": error})

        if seen % 100 == 0 or seen == n_files:
            spent = time.time() - t0
            files_per_s = seen / spent if spent > 0 else 0
            eta_s = (n_files - seen) / files_per_s if files_per_s > 0 else 0
            print(
                f"  [{seen}/{n_files}] "
                f"{article_total} articles | "
                f"{files_per_s:.1f} files/s | "
                f"ETA: {eta_s / 60:.0f}m"
            )

    spent = time.time() - t0
    print(f"\nDone: {ok_files}/{n_files} files, {article_total} articles in {spent:.0f}s")
    if failures:
        names = [Path(rec["path"]).name for rec in failures]
        more = "..." if len(names) > 10 else ""
        print(f"Failed ({len(failures)}): {', '.join(names[:10])}{more}")
    return ok_files, article_total, failures


# Index every configured directory, then print a run summary.
grand_total_files = 0
grand_total_articles = 0
all_failed = []
grand_start = time.time()

for source_dir in DATA_DIRS:
    dir_files, dir_articles, dir_failed = index_files(source_dir)
    grand_total_files += dir_files
    grand_total_articles += dir_articles
    all_failed += dir_failed

grand_elapsed = time.time() - grand_start
banner = "=" * 60
for line in (
    "\n" + banner,
    "INDEXING COMPLETE",
    banner,
    f"  Total files:    {grand_total_files}",
    f"  Total articles: {grand_total_articles}",
    f"  Failed:         {len(all_failed)}",
    f"  Time:           {grand_elapsed / 60:.1f} minutes",
    banner,
):
    print(line)

# Persist failed records so the backfill cell can retry just those files.
FAILED_MANIFEST_PATH = "/content/failed_files.json"
Path(FAILED_MANIFEST_PATH).write_text(
    json.dumps(all_failed, ensure_ascii=False, indent=2), encoding="utf-8"
)
print(f"Saved failed-file manifest: {FAILED_MANIFEST_PATH}")

## 8. Backfill Failed Files (Optional)

Run this after indexing if `/content/failed_files.json` is non-empty. It retries only the failed files, using smaller embedding and upsert batch sizes.

In [None]:
BACKFILL_EMBEDDING_BATCH_SIZE = 64
BACKFILL_UPSERT_BATCH_SIZE = 150


def _resolve_backfill_path(recorded_path: str):
    """Return an existing file for a manifest entry, or None if none can be found.

    Tries the recorded path first. If that fails, it looks for a file with the
    same name in each of DATA_DIRS. ``is_file`` is used rather than ``exists``
    because an empty path becomes ``Path(".")``, which exists as a directory.
    """
    candidate = Path(recorded_path)
    if candidate.is_file():
        return candidate
    if not candidate.name:
        return None
    for data_dir in DATA_DIRS:
        fallback = Path(data_dir) / candidate.name
        if fallback.is_file():
            return fallback
    return None


failed_path = Path("/content/failed_files.json")
if not failed_path.exists():
    print("No failed manifest found. Run the indexing cell first.")
else:
    with open(failed_path, "r", encoding="utf-8") as f:
        failed_records = json.load(f)

    if not failed_records:
        print("No failed files to backfill.")
    else:
        total_records = len(failed_records)
        print(f"Backfilling {total_records} failed files...")
        recovered = 0
        still_failed = []
        backfill_start = time.time()

        for idx, rec in enumerate(failed_records, start=1):
            recorded = rec.get("path", "")
            file_path = _resolve_backfill_path(recorded)
            if file_path is None:
                still_failed.append({
                    "path": recorded,
                    "error": "File not found during backfill",
                })
            else:
                _, err = index_single_file(
                    file_path,
                    embedding_batch_size=BACKFILL_EMBEDDING_BATCH_SIZE,
                    upsert_batch_size=BACKFILL_UPSERT_BATCH_SIZE,
                )
                if err:
                    still_failed.append({"path": str(file_path), "error": err})
                else:
                    recovered += 1

            # Report for every record, including not-found ones, so the final
            # [total/total] progress line is always printed.
            if idx % 10 == 0 or idx == total_records:
                elapsed = time.time() - backfill_start
                print(f"  [{idx}/{total_records}] recovered={recovered}, remaining_failed={len(still_failed)}, elapsed={elapsed:.0f}s")

        remaining_path = Path("/content/failed_files_remaining.json")
        with open(remaining_path, "w", encoding="utf-8") as f:
            json.dump(still_failed, f, ensure_ascii=False, indent=2)

        print("\nBACKFILL COMPLETE")
        print(f"  Recovered: {recovered}")
        print(f"  Still failed: {len(still_failed)}")
        print(f"  Remaining manifest: {remaining_path}")

## 9. Verify

In [None]:
info = client.get_collection(COLLECTION_NAME)
print(f"Collection: {COLLECTION_NAME}")
print(f"Points: {info.points_count}")
print(f"Status: {info.status}")

# Scan every point's payload to check doc_type completeness and distribution.
missing_doc_type = 0
provision_count = 0
editorial_note_count = 0
other_doc_type_count = 0  # values outside the expected set would otherwise go uncounted
offset = None
while True:
    points, offset = client.scroll(
        collection_name=COLLECTION_NAME,
        limit=256,
        offset=offset,
        with_payload=True,
        with_vectors=False,
    )
    if not points:
        break
    for point in points:
        metadata = (point.payload or {}).get("metadata", {}) or {}
        doc_type = metadata.get("doc_type")
        if not doc_type:
            missing_doc_type += 1
        elif doc_type == "provision":
            provision_count += 1
        elif doc_type == "editorial_note":
            editorial_note_count += 1
        else:
            other_doc_type_count += 1
    if offset is None:
        break

print(f"missing_doc_type: {missing_doc_type}")
print(f"provision_count: {provision_count}")
print(f"editorial_note_count: {editorial_note_count}")
print(f"other_doc_type_count: {other_doc_type_count}")

# Quick test search; encode the query the same way the documents were encoded.
test_query = "Hvor mye kan utleier kreve i depositum?"
test_embedding = model.encode(test_query, normalize_embeddings=True).tolist()
results = client.query_points(
    collection_name=COLLECTION_NAME,
    query=test_embedding,
    limit=3,
)
print(f"\nTest query: '{test_query}'")
for i, point in enumerate(results.points, start=1):
    meta = (point.payload or {}).get("metadata", {}) or {}
    print(
        f"  {i}. {meta.get('law_title', '?')} - {meta.get('title', '?')} "
        f"[{meta.get('doc_type', '?')}] (score: {point.score:.3f})"
    )

In [None]:
# (Deprecated duplicate parser import cell intentionally left blank.)