In [None]:
from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct, VectorParams
from sentence_transformers import SentenceTransformer
import numpy as np
from pathlib import Path
import json
import pandas as pd
import env


# Log in to the Hugging Face Hub. Run this only when you are NOT using the
# Docker-based setup (the original author recommends the Docker setup because
# of memory usage).
from huggingface_hub import login

login(token=env.HF_LOGIN_TOKEN)

# ----------------------------
# Step 1: Connect to Qdrant
# ----------------------------
# Option 1: in-memory instance (no Docker required)
client = QdrantClient(location=":memory:")

# # Option 2: Local Qdrant server
# client = QdrantClient("http://localhost:6333")




In [None]:
# ----------------------------
# Step 2: Initialize model
# ----------------------------
# Downloading the model and saving it rewrites the whole local directory, so
# reuse the local copy when one is already present. This makes the cell safe
# to re-run (e.g. Restart Kernel -> Run All).
local_model_dir = Path("models/multilingual-e5-large")

# `modules.json` is written by SentenceTransformer.save(); checking for it
# (rather than just the directory) avoids loading from a half-written folder.
if (local_model_dir / "modules.json").is_file():
    model = SentenceTransformer(str(local_model_dir))
else:
    model = SentenceTransformer("intfloat/multilingual-e5-large")
    model.save(str(local_model_dir))  # one-time local save

In [None]:
from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct, VectorParams
from sentence_transformers import SentenceTransformer
import numpy as np
from pathlib import Path
import json
import pandas as pd
import torch
# Use this cell when the model has already been saved locally and Qdrant runs in Docker.
# Start the container before running it:
# docker run -d --name qdrant -p 6333:6333 -p 6334:6334 -v /home/Loris/EPFL/MA3/ML/project2/project2Rag/qdrant_storage:/qdrant/storage qdrant/qdrant:latest
# (the host path in the -v option is machine-specific; adapt it to your checkout)
client = QdrantClient(host="localhost", port=6333, prefer_grpc=True)

# Run the encoder on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model = SentenceTransformer("models/multilingual-e5-large", device=device)



In [1]:

# ----------------------------
# Step 3: Load articles from JSON (Polars)
# ----------------------------
from pathlib import Path
import json
import polars as pl

# Input file handed to load_jsonl_to_df below, which parses it one JSON document per line.
path = Path("../wikiextractor/articles_fr_withLinks.json")

ARTICLES_NUMBER = None  # None = load all; set an int to cap when sampling

def _read_jsonl_rows(path: Path, max_rows: int | None = None) -> list:
    """Parse `path` line by line, skipping lines that are not valid JSON.

    Stops as soon as `max_rows` rows have been parsed (no limit when None).
    """
    rows = []
    if max_rows is not None and max_rows <= 0:
        return rows
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            try:
                rows.append(json.loads(line))
            except json.JSONDecodeError:
                continue
            # Count parsed rows, not raw lines, so malformed lines do not eat into the cap.
            if max_rows is not None and len(rows) >= max_rows:
                break
    return rows


def load_jsonl_to_df(path: Path, max_rows: int | None = None) -> pl.DataFrame:
    """Load a line-delimited JSON file into a Polars DataFrame with normalized columns.

    Args:
        path: File with one JSON document per line.
        max_rows: Maximum number of rows to load. None loads the whole file
            through `pl.read_ndjson`, falling back to a tolerant line-by-line
            parse if that reader raises.

    Returns:
        A DataFrame that always has the columns `id` (Int64) and
        `title` / `url` / `text` (Utf8, nulls replaced by ""). Columns missing
        from the input are added (`id` as null, the others as ""). Any other
        input column is kept as loaded.
    """
    if max_rows is None:
        try:
            df = pl.read_ndjson(str(path))
        except Exception:
            # Deliberate best-effort fallback: re-read manually and drop unparsable lines.
            df = pl.DataFrame(_read_jsonl_rows(path))
    else:
        df = pl.DataFrame(_read_jsonl_rows(path, max_rows))

    # Ensure columns exist and normalize types/nulls
    cols = set(df.columns)

    if "id" in cols:
        df = df.with_columns(pl.col("id").cast(pl.Int64))
    else:
        df = df.with_columns(pl.lit(None).cast(pl.Int64).alias("id"))

    for name in ("title", "url", "text"):
        if name in cols:
            df = df.with_columns(pl.col(name).cast(pl.Utf8).fill_null(""))
        else:
            df = df.with_columns(pl.lit("").alias(name))

    return df

# Load the articles (capped by ARTICLES_NUMBER when it is not None) and report how many rows were read.
df = load_jsonl_to_df(path, max_rows=ARTICLES_NUMBER)
print(f"✅ Loaded {len(df)} articles into a single Polars DataFrame")


✅ Loaded 4498441 articles into a single Polars DataFrame


In [None]:
# Persist the articles DataFrame as Parquet (run when needed); other cells of
# this notebook reopen "articles_fr_withLinks.parquet" with pyarrow.
df.write_parquet("articles_fr_withLinks.parquet")

In [None]:
import polars as pl
import re
import html
from urllib.parse import unquote
#Extract href links from the articles with links and saves them to a df

BASE_URL = "https://fr.wikipedia.org/wiki/"
ANCHOR_RE = re.compile(r'<a\s+[^>]*href="([^"]+)"[^>]*>(.*?)</a>', re.IGNORECASE | re.DOTALL)


def extract_links_with_pos(text: str):
    """Extract the `<a href="...">` links of an article together with their position.

    Anchors are matched on the HTML-unescaped text, then located again in the
    original (escaped) `text` so that `start_idx` is an offset into `text`.

    Args:
        text: Article text in which anchors appear HTML-escaped
            (`&lt;a href="..."&gt;...&lt;/a&gt;`).

    Returns:
        One dict per anchor, in document order, with the keys `full_url`
        (BASE_URL + raw href), `start_idx`, `anchor` (unescaped link text),
        `href_raw` and `href_decoded` (percent-decoded href). `start_idx` is 0
        when the anchor could not be located in `text`. Empty or None input
        gives an empty list.
    """
    if not text:
        return []
    unescaped = html.unescape(text)
    out = []
    search_from = 0
    for m in ANCHOR_RE.finditer(unescaped):
        href_raw = m.group(1)
        # Prefer the exact escaped opening tag; fall back to the next escaped
        # anchor (e.g. when the href itself contained entities).
        pos = text.find(f'&lt;a href="{href_raw}"', search_from)
        if pos == -1:
            pos = text.find("&lt;a ", search_from)
        if pos == -1:
            # Not locatable: report 0 but keep the cursor where it is. Moving
            # it back to the start would let later anchors match earlier
            # occurrences and report a wrong position.
            start_idx = 0
        else:
            start_idx = pos
            search_from = pos + 1
        out.append({
            "full_url": BASE_URL + href_raw,
            "start_idx": start_idx,
            "anchor": html.unescape(m.group(2)),
            "href_raw": href_raw,
            "href_decoded": unquote(href_raw),
        })
    return out

# Polars dtype of one extracted link: the same keys as the dicts returned by
# extract_links_with_pos.
LINK_ITEM = pl.Struct(
    [
        pl.Field(field_name, field_dtype)
        for field_name, field_dtype in (
            ("full_url", pl.Utf8),
            ("start_idx", pl.Int64),
            ("anchor", pl.Utf8),
            ("href_raw", pl.Utf8),
            ("href_decoded", pl.Utf8),
        )
    ]
)
LINK_LIST = pl.List(LINK_ITEM)


def build_per_article_links(df: pl.DataFrame, max_rows: int | None = None) -> pl.DataFrame:
    """Build one row per article holding all of its extracted links.

    Args:
        df: Articles frame with at least the columns `id`, `title` and `text`.
        max_rows: When given, only the first `max_rows` rows of `df` are processed.

    Returns:
        A frame with the columns `id`, `title`, `links` (list of link structs,
        an empty list instead of null) and `link_count`.
    """
    if max_rows is not None:
        df = df.slice(0, min(max_rows, df.height))

    # Expressions are declared first and applied in the same order as before.
    text_as_str = pl.col("text").cast(pl.Utf8).fill_null("")
    extracted_links = pl.col("text").map_elements(extract_links_with_pos, return_dtype=LINK_LIST)
    links_or_empty = (
        pl.when(pl.col("links").is_null())
        .then(pl.lit([]).cast(LINK_LIST))
        .otherwise(pl.col("links"))
    )

    articles = df.select(["id", "title", "text"]).with_columns(text_as_str)
    with_links = articles.with_columns(links=extracted_links).drop("text")
    return with_links.with_columns(
        links=links_or_empty,
        link_count=pl.col("links").list.len(),
    )

# Example: extract the links of the first 10000 articles of the Polars df.
per_article = build_per_article_links(df, max_rows=10000)
# print(per_article.select(["id", "title", "link_count"]).head(10))

# Inspect a single article: every link with its position, taken from its one row.
article_id = 3
row = per_article.filter(pl.col("id") == article_id)
if row.height:
    links = row["links"].to_list()[0]
else:
    links = []
print(links)
# for l in links[:20]:
#     print(f"{l['full_url']} - {l['start_idx']}")

[{'full_url': 'https://fr.wikipedia.org/wiki/Moulins%20%28Allier%29', 'start_idx': 25, 'anchor': 'Moulins', 'href_raw': 'Moulins%20%28Allier%29', 'href_decoded': 'Moulins (Allier)'}, {'full_url': 'https://fr.wikipedia.org/wiki/Allier%20%28d%C3%A9partement%29', 'start_idx': 83, 'anchor': 'Allier', 'href_raw': 'Allier%20%28d%C3%A9partement%29', 'href_decoded': 'Allier (d√©partement)'}, {'full_url': 'https://fr.wikipedia.org/wiki/Ch%C3%A2teaumeillant', 'start_idx': 162, 'anchor': 'Ch√¢teaumeillant', 'href_raw': 'Ch%C3%A2teaumeillant', 'href_decoded': 'Ch√¢teaumeillant'}, {'full_url': 'https://fr.wikipedia.org/wiki/Cher%20%28d%C3%A9partement%29', 'start_idx': 226, 'anchor': 'Cher', 'href_raw': 'Cher%20%28d%C3%A9partement%29', 'href_decoded': 'Cher (d√©partement)'}, {'full_url': 'https://fr.wikipedia.org/wiki/Philologie', 'start_idx': 296, 'anchor': 'philologue', 'href_raw': 'Philologie', 'href_decoded': 'Philologie'}, {'full_url': 'https://fr.wikipedia.org/wiki/liste%20de%20linguistes', 's

In [None]:
# Embedding dimension reported by the loaded model. It is referenced as the
# vector size in the (commented-out) collection-creation cell.
vector_size = model.get_sentence_embedding_dimension()



In [None]:
# Warning: only create the collection once! Otherwise you will lose its content.

# collection_name = "wikipedia_fr"
# if not client.collection_exists(collection_name):
#     client.create_collection(
#         collection_name=collection_name,
#         vectors_config=VectorParams(size=vector_size, distance="Cosine"),
#     )
# else:
#     print(f"Collection '{collection_name}' already exists; skipping creation.")

In [None]:
import pyarrow.parquet as pq

# Open the Parquet file (it must have been saved once previously).
pf = pq.ParquetFile("articles_fr_withLinks.parquet")
rowgroup_batch = 512  # rows handled per window (existence check + encode + upsert) in the ingestion loop
encode_batch = 64  # passed as batch_size to model.encode in the ingestion loop


In [None]:
from qdrant_client.http import models as rest
# Count the points currently stored in the Qdrant collection and print the total.

collection_name = "wikipedia_fr"
total = client.count(collection_name=collection_name, exact=True).count
print(f"Total articles stored: {total}")

Total articles stored: 26112


In [None]:
# Recover the "wikipedia_fr" collection from a snapshot (should not be needed).
# NOTE(review): `<snapshot_name>` in the URL looks like a placeholder; presumably
# it has to be replaced with a real snapshot name before running this cell.
client.recover_snapshot(collection_name="wikipedia_fr", location="http://localhost:6333/collections/wikipedia_fr/snapshots/<snapshot_name>")

In [None]:
import polars as pl

# Process articles into the Qdrant DB (warning: takes a very long time and a lot of resources)

# Maximum number of NEW articles to embed and upsert in this run (None = no cap).
ARTICLES_NUMBER = 2000

new_processed = 0
max_rows = ARTICLES_NUMBER

for rg in range(pf.num_row_groups):
    if max_rows is not None and new_processed >= max_rows:
        break
    table = pf.read_row_group(rg, columns=["id", "title", "url", "text"])
    # NOTE: this rebinds the notebook-level `df`; cells that read `df` afterwards
    # only see the last row group loaded here.
    df = pl.from_arrow(table).with_columns([
        pl.col("id").cast(pl.Int64),
        pl.col("title").fill_null(""),
        pl.col("url").fill_null(""),
        pl.col("text").fill_null("")
    ])
    n = df.height
    for start in range(0, n, rowgroup_batch):
        # Always scan a full window; the cap is applied to the NEW articles
        # below so that no row of the window is silently skipped.
        length = min(rowgroup_batch, n - start)
        sub = df.slice(start, length)
        ids = [int(i) for i in sub["id"].to_list()]
        texts = sub["text"].to_list()
        titles = sub["title"].to_list()
        urls = sub["url"].to_list()

        # Only embed articles that are not stored yet (embedding is the expensive part).
        existing = client.retrieve(
            collection_name="wikipedia_fr",
            ids=ids,
            with_payload=False,
            with_vectors=False
        )
        existing_ids = {p.id for p in existing}
        missing_idx = [i for i, pid in enumerate(ids) if pid not in existing_ids]
        if max_rows is not None:
            # Recompute the budget on every window so the cap is never exceeded.
            missing_idx = missing_idx[: max_rows - new_processed]
        if not missing_idx:
            continue

        ids_m = [ids[i] for i in missing_idx]
        texts_m = [texts[i] for i in missing_idx]
        titles_m = [titles[i] for i in missing_idx]
        urls_m = [urls[i] for i in missing_idx]

        # NOTE(review): texts are encoded as-is, while the query cell prepends
        # "query: "; confirm whether documents should carry a "passage: " prefix.
        vectors = model.encode(texts_m, batch_size=encode_batch, convert_to_numpy=True, normalize_embeddings=True)
        points = [
            PointStruct(
                id=ids_m[i],
                vector=vectors[i].tolist(),
                payload={"id": ids_m[i], "title": titles_m[i], "url": urls_m[i], "text": texts_m[i]},
            )
            for i in range(len(ids_m))
        ]
        client.upsert(collection_name="wikipedia_fr", points=points, wait=True)
        new_processed += len(ids_m)
        if max_rows is not None and new_processed >= max_rows:
            break

In [None]:
from qdrant_client import QdrantClient
from qdrant_client.http import models as rest
import polars as pl
from typing import Optional

#Query Qdrant DB to rebuild the polars df (warning: untested)

def rebuild_df_from_qdrant(
    client: QdrantClient,
    collection_name: str,
    max_rows: Optional[int] = 1000,
    batch_limit: int = 1000,
) -> pl.DataFrame:
    """Rebuild the articles DataFrame from the payloads stored in Qdrant.

    Args:
        client: Connected Qdrant client.
        collection_name: Collection to scroll through.
        max_rows: Maximum number of rows to return (None = every point).
        batch_limit: Maximum number of points requested per scroll call.

    Returns:
        A DataFrame with the columns `id` (Int64), `title`, `url` and `text`
        (Utf8, nulls replaced by ""). It has zero rows, but the same columns,
        when nothing was fetched.
    """
    rows = []
    next_page = None

    # Only fetch the payload fields that are actually used.
    selector = rest.PayloadSelectorInclude(include=["id", "title", "url", "text"])

    while max_rows is None or len(rows) < max_rows:
        # Never request more points than are still needed.
        limit = batch_limit if max_rows is None else min(batch_limit, max_rows - len(rows))
        points, next_page = client.scroll(
            collection_name=collection_name,
            limit=limit,
            # The selector is passed through `with_payload`; `scroll` has no
            # `payload_selector` parameter.
            with_payload=selector,
            with_vectors=False,
            offset=next_page,
        )
        if not points:
            break

        for p in points:
            payload = p.payload or {}
            rows.append({
                "id": payload.get("id", p.id),
                "title": payload.get("title", ""),
                "url": payload.get("url", ""),
                "text": payload.get("text", ""),
            })

        if next_page is None:
            break

    if not rows:
        # Without rows there are no columns to cast; return the expected empty schema.
        return pl.DataFrame(schema={"id": pl.Int64, "title": pl.Utf8, "url": pl.Utf8, "text": pl.Utf8})

    df = pl.DataFrame(rows).with_columns([
        pl.col("id").cast(pl.Int64),
        pl.col("title").cast(pl.Utf8).fill_null(""),
        pl.col("url").cast(pl.Utf8).fill_null(""),
        pl.col("text").cast(pl.Utf8).fill_null(""),
    ])
    return df


In [None]:
import polars as pl

# Query the DB for the articles closest (cosine similarity) to one keyword.
# Warning from the original author: NaN issues can produce a score of 1.0.

query_title = "Math"

# Use the full text of the article whose title matches the keyword
# (case-insensitive); fall back to the keyword itself when `df` has no such article.
qdf = df.filter(pl.col("title").str.to_lowercase() == query_title.lower())
query_text = qdf.select("text").to_series()[0] if qdf.height > 0 else query_title

# E5 requires the 'query:' prefix for queries.
# NOTE(review): the ingestion cells of this notebook encode the raw text without
# a 'passage:' prefix; confirm how the stored vectors were actually produced.
query_vector = model.encode([f"query: {query_text}"], normalize_embeddings=True)[0]

results = client.search(
    collection_name="wikipedia_fr",
    query_vector=query_vector,
    limit=5
)

print(f"\n🔍 Top matches for '{query_title}':")
for r in results:
    print(f"• {r.payload.get('title')} (score={r.score:.3f})")


🔍 Top matches for 'Math':
• Nombre entier (score=0.824)
• Série (mathématiques) (score=0.818)
• Distance (mathématiques) (score=0.813)
• Nombre (score=0.812)
• Géométrie arithmétique (score=0.811)


  results = client.search(


In [None]:
import polars as pl
import numpy as np

# Function to encode whole article text into vector 
# Then query the db with the article embedding to get the k most similar articles

def encode_article_text(text):
    """Embed a whole article as the mean of its chunk embeddings.

    The text is split on whitespace into windows of 256 words; each window is
    encoded with the "passage: " prefix and the resulting vectors are averaged.
    Text without any word yields a float32 zero vector of the model's
    embedding dimension.
    """
    tokens = text.split()
    if len(tokens) == 0:
        dim = model.get_sentence_embedding_dimension()
        return np.zeros(dim, dtype=np.float32)

    passages = []
    for offset in range(0, len(tokens), 256):
        window = tokens[offset:offset + 256]
        passages.append("passage: " + " ".join(window))

    embeddings = model.encode(passages, normalize_embeddings=True)
    return np.mean(embeddings, axis=0)

mode = "ids"  # "all" or "ids"
ids_to_process = [189, 205]
k = 5

ids_all = df.select(pl.col("id").cast(pl.Int64)).to_series().to_list()
if mode == "all":
    ids = ids_all
else:
    # Build the lookup set once instead of once per candidate id.
    known_ids = set(ids_all)
    ids = [i for i in ids_to_process if i in known_ids]

for article_id in ids:
    article_id = int(article_id)
    row = df.filter(pl.col("id") == article_id)
    title = row["title"][0] if row.height > 0 else "(unknown)"

    # Preferred path: the article is already stored, so let Qdrant recommend
    # neighbours from its stored vector.
    results = None
    try:
        points = client.retrieve(collection_name="wikipedia_fr", ids=[article_id], with_payload=True, with_vectors=False)
        if points:
            results = client.recommend(collection_name="wikipedia_fr", positive=[article_id], limit=k)
    except Exception as exc:
        # Best effort: report the failure instead of hiding it, then fall back.
        print(f"  (stored-vector lookup failed for id={article_id}: {exc!r}; falling back to text search)")

    # Fallback: embed the article text locally and search with that vector.
    if results is None:
        if row.height == 0:
            continue
        text = row.select(pl.col("text").fill_null("")).to_series()[0]
        vec = encode_article_text(text)
        results = client.search(collection_name="wikipedia_fr", query_vector=vec, limit=k)

    print(f"\n🔎 Similar articles for: {title} (id={article_id})")
    for r in results:
        print(f"  • {r.payload.get('title')} (id={r.payload.get('id')}, score={r.score:.3f})")


🔎 Similar articles for: Atome (id=189)
  • Proton (id=2414, score=0.939)
  • Liaison chimique (id=22131, score=0.939)
  • Élément chimique (id=15349, score=0.936)
  • Électron (id=6716, score=0.935)
  • Particule élémentaire (id=23547, score=0.934)


  results = client.recommend(collection_name="wikipedia_fr", positive=[int(article_id)], limit=k)


: 

In [None]:
import polars as pl

# Export the id/title/url/text columns of the in-memory `df` to "articles.csv" for further use.

cols = ["id", "title", "url", "text"]
df.select(cols).write_csv("articles.csv")

In [None]:
import polars as pl
import pyarrow.parquet as pq
import pyarrow.csv as pacsv

# Stream the Parquet file to CSV one row group at a time, so the whole data
# set never has to be held in memory. At most ARTICLES_NUMBER rows are written
# (None = no cap).
pf = pq.ParquetFile("articles_fr_withLinks.parquet")
cols = ["id", "title", "url", "text"]
ARTICLES_NUMBER = 100000
processed = 0
first = True
output_path = "articles_stream.csv"

for rg in range(pf.num_row_groups):
    if ARTICLES_NUMBER is not None and processed >= ARTICLES_NUMBER:
        break

    table = pf.read_row_group(rg, columns=cols)
    df_rg = pl.from_arrow(table).with_columns([
        pl.col("id").cast(pl.Int64),
        pl.col("title").fill_null(""),
        pl.col("url").fill_null(""),
        pl.col("text").fill_null("")
    ])

    # Trim the last chunk so the cap is not exceeded.
    if ARTICLES_NUMBER is None:
        remaining = None
    else:
        remaining = ARTICLES_NUMBER - processed
        if df_rg.height > remaining:
            df_rg = df_rg.slice(0, remaining)

    # The first chunk creates/truncates the file and writes the header; later chunks append.
    # NOTE: `mode` is also the name used by the similar-articles cell ("all"/"ids").
    if first:
        mode = "wb"
    else:
        mode = "ab"
    with open(output_path, mode) as f:
        pacsv.write_csv(
            df_rg.to_arrow(),
            f,
            write_options=pacsv.WriteOptions(include_header=first),
        )

    processed += df_rg.height
    first = False

In [None]:
import polars as pl
import pyarrow.parquet as pq
from qdrant_client.models import PointStruct

def export_points_df(df: pl.DataFrame, client, collection_name: str, model, encode_batch: int = 64, upsert_batch: int = 512, max_rows: int | None = None, text_prefix: str = ""):
    """Embed the articles of `df` and upsert them into a Qdrant collection.

    Args:
        df: Frame with the columns `id`, `title`, `url` and `text`.
        client: Qdrant client used for the upserts.
        collection_name: Target collection.
        model: Encoder exposing `encode(texts, batch_size=..., ...)`.
        encode_batch: Batch size passed to `model.encode`.
        upsert_batch: Number of rows embedded and upserted per round trip.
        max_rows: Only the first `max_rows` rows are exported (None = all).
        text_prefix: Prepended to every text before encoding only (the stored
            payload keeps the original text). The default "" keeps the
            previous behaviour; the query cell of this notebook uses "query: "
            on its side, so "passage: " is the matching document prefix.

    Raises:
        ValueError: If `upsert_batch` is not positive (the loop could never advance).
    """
    if upsert_batch <= 0:
        raise ValueError("upsert_batch must be a positive integer")

    n = df.height if max_rows is None else min(df.height, max_rows)
    processed = 0
    while processed < n:
        length = min(upsert_batch, n - processed)
        sub = df.slice(processed, length)
        ids = sub["id"].cast(pl.Int64).to_list()
        titles = sub["title"].fill_null("").to_list()
        urls = sub["url"].fill_null("").to_list()
        texts = sub["text"].fill_null("").to_list()

        to_encode = [text_prefix + t for t in texts] if text_prefix else texts
        vectors = model.encode(to_encode, batch_size=encode_batch, convert_to_numpy=True, normalize_embeddings=True)
        points = [
            PointStruct(
                id=int(pid),
                vector=vec.tolist(),
                payload={"id": int(pid), "title": title, "url": url, "text": text},
            )
            for pid, vec, title, url, text in zip(ids, vectors, titles, urls, texts)
        ]
        client.upsert(collection_name=collection_name, points=points, wait=True)
        processed += length

NameError: name 'get_similar_articles_by_id' is not defined

In [None]:
import polars as pl
import pyarrow.parquet as pq
from qdrant_client.models import PointStruct

def export_points_parquet(parquet_path: str, client, collection_name: str, model, encode_batch: int = 64, upsert_batch: int = 512, max_rows: int | None = None, text_prefix: str = ""):
    """Embed the articles of a Parquet file and upsert them into a Qdrant collection.

    The file is read one row group at a time, so only one row group is held in
    memory at once.

    Args:
        parquet_path: Parquet file with the columns `id`, `title`, `url` and `text`.
        client: Qdrant client used for the upserts.
        collection_name: Target collection.
        model: Encoder exposing `encode(texts, batch_size=..., ...)`.
        encode_batch: Batch size passed to `model.encode`.
        upsert_batch: Number of rows embedded and upserted per round trip.
        max_rows: Total number of rows to export across row groups (None = all).
        text_prefix: Prepended to every text before encoding only (the stored
            payload keeps the original text). The default "" keeps the
            previous behaviour; the query cell of this notebook uses "query: "
            on its side, so "passage: " is the matching document prefix.

    Raises:
        ValueError: If `upsert_batch` is not positive (the loop could never advance).
    """
    if upsert_batch <= 0:
        raise ValueError("upsert_batch must be a positive integer")

    pf = pq.ParquetFile(parquet_path)
    processed = 0
    for rg in range(pf.num_row_groups):
        if max_rows is not None and processed >= max_rows:
            break
        table = pf.read_row_group(rg, columns=["id", "title", "url", "text"])
        df = pl.from_arrow(table).with_columns([
            pl.col("id").cast(pl.Int64),
            pl.col("title").fill_null(""),
            pl.col("url").fill_null(""),
            pl.col("text").fill_null("")
        ])
        # Rows of this row group that still fit under the global cap.
        remaining = None if max_rows is None else max_rows - processed
        n = df.height if remaining is None else min(df.height, remaining)
        inner = 0
        while inner < n:
            length = min(upsert_batch, n - inner)
            sub = df.slice(inner, length)
            ids = sub["id"].to_list()
            titles = sub["title"].to_list()
            urls = sub["url"].to_list()
            texts = sub["text"].to_list()

            to_encode = [text_prefix + t for t in texts] if text_prefix else texts
            vectors = model.encode(to_encode, batch_size=encode_batch, convert_to_numpy=True, normalize_embeddings=True)
            points = [
                PointStruct(
                    id=int(pid),
                    vector=vec.tolist(),
                    payload={"id": int(pid), "title": title, "url": url, "text": text},
                )
                for pid, vec, title, url, text in zip(ids, vectors, titles, urls, texts)
            ]
            client.upsert(collection_name=collection_name, points=points, wait=True)
            processed += length
            inner += length