In [None]:
from pathlib import Path
import pandas as pd

# BBC News training split; a relative path, resolved against the current working directory.
DATA_PATH = Path("bbc_news_train.csv")

# Fail early with the absolute path we looked at, instead of pandas' generic error.
if DATA_PATH.exists():
    df = pd.read_csv(DATA_PATH)
else:
    raise FileNotFoundError(f"Impossible de trouver le fichier : {DATA_PATH.resolve()}")

df.head()

Unnamed: 0,ArticleId,Text,Category
0,1833,worldcom ex-boss launches defence lawyers defe...,business
1,154,german business confidence slides german busin...,business
2,1101,bbc poll indicates economic gloom citizens in ...,business
3,1976,lifestyle governs mobile choice faster bett...,tech
4,917,enron bosses in $168m payout eighteen former e...,business


In [4]:
# Restrict the corpus to one category and cap its size so LLM extraction stays tractable.
TARGET_CATEGORY = "tech"
N_ARTICLES = 100

df_tech = df[df["Category"] == TARGET_CATEGORY].head(N_ARTICLES).copy()
# Fail fast: an empty selection would otherwise only surface later (e.g. on results[0]).
if df_tech.empty:
    raise ValueError(f"Aucun article trouvé pour la catégorie {TARGET_CATEGORY!r}")

display(df_tech.head())
print(f"Nombre d'articles dans le dataset filtré : {len(df_tech)}")

Unnamed: 0,ArticleId,Text,Category
3,1976,lifestyle governs mobile choice faster bett...,tech
19,1552,moving mobile improves golf swing a mobile pho...,tech
24,405,bt boosts its broadband packages british telec...,tech
26,702,peer-to-peer nets here to stay peer-to-peer ...,tech
30,1951,pompeii gets digital make-over the old-fashion...,tech


Nombre d'articles dans le dataset filtré : 100


In [5]:
# Small, closed set of relations, kept deliberately narrow to avoid an incoherent graph.
# Each entry: "pred" (predicate CURIE), "subj"/"obj" (allowed subject/object types),
# "note" (human-readable intent).
# NOTE(review): _postprocess only keeps relations whose subj is an extracted entity
# mention, and no EntityType stands for the article itself, so the entries whose subject
# is schema:NewsArticle (about, author, datePublished) appear unable to produce any kept
# relation — confirm this is intended.
REL_SCHEMA = [
    dict(
        pred="schema:about",
        subj=["schema:NewsArticle"],
        obj=["schema:Thing"],
        note="Article -> entités mentionnées/centrales",
    ),
    dict(
        pred="schema:author",
        subj=["schema:NewsArticle"],
        obj=["foaf:Person", "schema:Organization"],
        note="Optionnel (souvent absent) : auteur ou média",
    ),
    dict(
        pred="schema:worksFor",
        subj=["foaf:Person"],
        obj=["schema:Organization"],
        note="Emploi / affiliation forte",
    ),
    dict(
        pred="schema:founder",
        subj=["schema:Organization"],
        obj=["foaf:Person"],
        note="Organisation -> fondateur(s)",
    ),
    dict(
        pred="schema:acquiredBy",
        subj=["schema:Organization"],
        obj=["schema:Organization"],
        note="Convention: acquis -> acquéreur",
    ),
    dict(
        pred="schema:produces",
        subj=["schema:Organization"],
        obj=["schema:Product"],
        note="Entreprise -> produit (hardware/software)",
    ),
    dict(
        pred="schema:location",
        subj=["schema:Organization", "schema:Event"],
        obj=["schema:Place"],
        note="Si une localisation est explicitement mentionnée",
    ),
    dict(
        pred="onto:announced",
        subj=["schema:Organization", "foaf:Person"],
        obj=["schema:Product", "schema:Event"],
        note="Local (si tu veux expliciter les annonces/lancements)",
    ),
    dict(
        pred="schema:datePublished",
        subj=["schema:NewsArticle"],
        obj=["xsd:date", "xsd:dateTime"],
        note="Attribut date article (si dispo / inférée)",
    ),
    dict(
        pred="schema:releaseDate",
        subj=["schema:Product", "schema:Event"],
        obj=["xsd:date", "xsd:dateTime"],
        note="Attribut date sortie/lancement (si mentionné)",
    ),
]


In [6]:
import os, json, re, time
from pathlib import Path
from typing import List, Literal, Optional, Dict, Any, Tuple

import requests
from tqdm import tqdm
from pydantic import BaseModel, Field, ValidationError

# Best-effort detection of the text / id columns across common dataset layouts:
# the first candidate present in df_tech wins.
TEXT_COL = next(
    filter(lambda col: col in df_tech.columns, ["Text", "text", "News", "news", "Content", "content"]),
    None,
)
if TEXT_COL is None:
    raise KeyError(f"Colonne texte introuvable. Colonnes dispo: {list(df_tech.columns)}")

# Optional: when no id column is found, callers fall back to the DataFrame index.
ID_COL = next(
    filter(lambda col: col in df_tech.columns, ["id", "Id", "ID", "article_id", "ArticleId"]),
    None,
)

# Predicate whitelist for the LLM, in REL_SCHEMA order.
ALLOWED_PREDS = [rel["pred"] for rel in REL_SCHEMA]

# Directory holding one cached extraction file per article.
CACHE_DIR = Path("cache_extractions")
CACHE_DIR.mkdir(exist_ok=True)


In [7]:
# Coarse entity labels the LLM may emit; any other label fails validation.
EntityType = Literal["PERSON", "ORG", "GPE", "PRODUCT", "EVENT"]


class Entity(BaseModel):
    """One entity mention extracted from an article.

    ``start``/``end`` are character offsets of the mention in the article text,
    with -1 meaning "not found verbatim". They are optional (default -1, ``null``
    accepted) because ``_postprocess`` recomputes them from the text and discards
    whatever the LLM sent: a reply that omits them or sends ``null`` should not
    fail validation and burn a retry.
    """

    mention: str = Field(min_length=1)
    type: EntityType
    start: Optional[int] = -1
    end: Optional[int] = -1


class Relation(BaseModel):
    """A (subj, pred, obj) triple with an optional supporting quote."""

    subj: str = Field(min_length=1)
    pred: str = Field(min_length=1)
    obj: str = Field(min_length=1)
    evidence: str = ""  # meant to be a verbatim excerpt of the article text


class Extraction(BaseModel):
    """Top-level LLM output: entity mentions plus relations between them."""

    entities: List[Entity] = Field(default_factory=list)
    relations: List[Relation] = Field(default_factory=list)


In [8]:
def _strip_code_fences(s: str) -> str:
    s = s.strip()
    s = re.sub(r"^```(?:json)?\s*", "", s, flags=re.IGNORECASE)
    s = re.sub(r"\s*```$", "", s)
    return s.strip()

def _extract_first_json_object(s: str) -> str:
    """Best-effort isolation of the first JSON object in a chatty LLM reply.

    Code fences are stripped first. If the text starting at the first ``{`` begins
    with a complete JSON value, only that value is returned. Trailing prose that
    itself contains braces (e.g. ``{...} Note: {x}``) is therefore no longer glued
    onto it.

    Otherwise the span from the first ``{`` to the last ``}`` is returned, as
    before. The caller's ``json.loads`` then fails and its retry path still runs.
    When there is no ``{...}`` span at all, the fence-stripped text is returned.
    """
    s = _strip_code_fences(s)
    start = s.find("{")
    end = s.rfind("}")
    if start == -1 or end == -1 or end <= start:
        return s
    candidate = s[start:]
    try:
        # raw_decode parses one complete JSON value and reports where it ended,
        # ignoring anything after it.
        _, stop = json.JSONDecoder().raw_decode(candidate)
    except json.JSONDecodeError:
        return s[start:end + 1]
    return candidate[:stop]

def _find_span(text: str, mention: str) -> Tuple[int, int]:
    if not mention:
        return -1, -1
    idx = text.find(mention)
    if idx == -1:
        return -1, -1
    return idx, idx + len(mention)

def _postprocess(extr: Extraction, text: str) -> Extraction:
    """Ground an LLM extraction in the source text.

    - Entity offsets are recomputed with ``_find_span`` (first exact occurrence,
      ``(-1, -1)`` when absent); entities are kept either way.
    - A relation is kept only if ``pred`` is in ``ALLOWED_PREDS`` and ``subj`` is
      an extracted mention. ``obj`` must also be an extracted mention, except for
      ``schema:datePublished`` / ``schema:releaseDate``, whose ``obj`` may be a
      literal such as an ISO date.
    - An ``evidence`` string that is not a verbatim substring of ``text`` is
      blanked; the relation itself is kept.

    ``extr`` is modified in place and also returned.

    NOTE(review): since ``subj`` must be an extracted mention and no EntityType
    represents the article itself, relations whose subject is the article
    (schema:about / schema:author / schema:datePublished) seem never to be kept —
    confirm this is intended.
    """
    date_preds = ("schema:datePublished", "schema:releaseDate")

    def relocate(ent: Entity) -> Entity:
        begin, finish = _find_span(text, ent.mention)
        return Entity(mention=ent.mention, type=ent.type, start=begin, end=finish)

    extr.entities = [relocate(ent) for ent in extr.entities]
    known_mentions = {ent.mention for ent in extr.entities}

    kept: List[Relation] = []
    for rel in extr.relations:
        if rel.pred not in ALLOWED_PREDS:
            continue
        # Date predicates may point at a literal instead of an extracted entity.
        obj_ok = rel.obj in known_mentions or rel.pred in date_preds
        if rel.subj not in known_mentions or not obj_ok:
            continue
        if rel.evidence and rel.evidence not in text:
            rel.evidence = ""
        kept.append(rel)

    extr.relations = kept
    return extr


In [9]:
# Ollama chat endpoint and model. Override via environment variables instead of
# editing the notebook; the defaults are the previous hard-coded values.
# NOTE: the extraction cache is keyed only by article id, so clear the cache
# directory after switching models.
OLLAMA_URL = os.environ.get("OLLAMA_CHAT_URL", "http://localhost:11434/api/chat")
MODEL = os.environ.get("OLLAMA_MODEL", "llama3.1:8b")

def _ollama_chat(messages: List[Dict[str, str]], temperature: float = 0.0, timeout: int = 120) -> str:
    """Send ``messages`` to the Ollama chat endpoint and return the reply text.

    Args:
        messages: chat history as ``{"role": ..., "content": ...}`` dicts.
        temperature: sampling temperature, forwarded in ``options``.
        timeout: HTTP timeout in seconds passed to ``requests.post``.

    Returns:
        ``["message"]["content"]`` of the (non-streaming) JSON response.

    Raises:
        requests.HTTPError: on a non-2xx status, via ``raise_for_status``.
    """
    request_body = dict(
        model=MODEL,
        messages=messages,
        stream=False,
        options=dict(temperature=temperature),
    )
    response = requests.post(OLLAMA_URL, json=request_body, timeout=timeout)
    response.raise_for_status()
    return response.json()["message"]["content"]

def _build_prompt(article_text: str) -> List[Dict[str, str]]:
    """Build the chat messages asking the LLM for a strict-JSON entity/relation extraction.

    The system message spells out the expected JSON shape, restricts ``pred`` to
    ``ALLOWED_PREDS`` and lists the extraction rules. The user message carries the
    article text.

    Args:
        article_text: article body, inserted verbatim into the user message.

    Returns:
        ``[system_message, user_message]`` as ``{"role": ..., "content": ...}`` dicts.
    """
    json_shape = (
        '{"entities":[{"mention":str,"type":"PERSON|ORG|GPE|PRODUCT|EVENT","start":int,"end":int}],'
        '"relations":[{"subj":str,"pred":str,"obj":str,"evidence":str}]}. '
    )
    rules = (
        "1) Les champs entities/relations existent toujours (liste vide ok). ",
        "2) subj et obj doivent correspondre à des 'mention' présentes dans entities (sauf pred datePublished/releaseDate où obj peut être une date ISO). ",
        "3) evidence doit être un extrait COPIÉ du texte source (30-200 chars si possible). ",
        "4) N'invente rien: si doute -> n'extrais pas.",
    )
    system_content = "".join(
        (
            "Tu es un extracteur d'information. ",
            "Tu DOIS répondre en JSON strict, sans aucun texte autour, sans markdown. ",
            "Le JSON doit respecter EXACTEMENT ce schéma: ",
            json_shape,
            f"pred doit être dans cette liste exacte: {ALLOWED_PREDS}. ",
            "Règles: ",
            *rules,
        )
    )
    user_content = f"Texte:\n{article_text}\n\nRetourne UNIQUEMENT le JSON."
    return [
        {"role": "system", "content": system_content},
        {"role": "user", "content": user_content},
    ]

def _read_cache(cache_path: Path) -> Optional[Dict[str, Any]]:
    """Load a cached extraction from ``cache_path``; return None on a cache miss.

    Besides a missing file, an empty, truncated or otherwise unparsable file also
    counts as a miss. A run interrupted mid-write therefore cannot make every
    later run crash on ``json.loads``.
    """
    if not cache_path.exists():
        return None
    try:
        with cache_path.open("r", encoding="utf-8") as f:
            cached = json.loads(f.readline())
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError (a cut multi-byte character)
        # are both ValueError subclasses.
        return None
    return cached if isinstance(cached, dict) else None


def _write_cache(cache_path: Path, payload: Dict[str, Any]) -> None:
    """Atomically write ``payload`` as a single JSON line to ``cache_path``.

    The data goes to a sibling ``.tmp`` file, which is then moved into place with
    ``os.replace``. The cache file is thus either the old one or the complete new one.
    """
    tmp_path = cache_path.with_name(cache_path.name + ".tmp")
    with tmp_path.open("w", encoding="utf-8") as f:
        f.write(json.dumps(payload, ensure_ascii=False) + "\n")
    os.replace(tmp_path, cache_path)


def extract(article_text: str, article_id: str | int, max_retries: int = 3) -> Dict[str, Any]:
    """Extract entities and relations from one article with the local LLM, with caching.

    Args:
        article_text: raw article text sent to the model. It is also used to
            recompute entity offsets and to check evidence strings.
        article_id: names the per-article cache file ``CACHE_DIR/<article_id>.jsonl``.
        max_retries: maximum number of model calls. After a parse/validation failure,
            the next call asks the model to repair its previous output.

    Returns:
        On success, ``Extraction.model_dump()`` (keys ``entities`` and ``relations``).
        If every attempt fails, ``{"entities": [], "relations": [], "error": <last
        error>, "raw": <first 1000 chars of the last reply>}``. This fallback is
        cached too, so the article is not retried until its cache file is deleted.

    Raises:
        Exceptions from ``_ollama_chat`` (e.g. HTTP or connection errors) propagate
        and nothing is cached for that article.
    """
    cache_path = CACHE_DIR / f"{article_id}.jsonl"
    cached = _read_cache(cache_path)
    if cached is not None:
        return cached

    messages = _build_prompt(article_text)

    last_err: Optional[str] = None
    raw = ""
    for attempt in range(1, max_retries + 1):
        raw = _ollama_chat(messages, temperature=0.0)
        raw_json = _extract_first_json_object(raw)

        try:
            parsed = json.loads(raw_json)
            extr = _postprocess(Extraction.model_validate(parsed), article_text)
            out = extr.model_dump()
        except (json.JSONDecodeError, ValidationError) as e:
            last_err = str(e)
            # Next attempt: ask the model to repair its own output into strict JSON.
            messages = [
                {"role": "system", "content": "Corrige la sortie pour qu'elle soit du JSON strict conforme au schéma. Réponds en JSON uniquement."},
                {"role": "user", "content": f"ERREUR:\n{last_err}\n\nSORTIE A CORRIGER:\n{raw}\n\nRappel: pred doit être dans {ALLOWED_PREDS}."}
            ]
            if attempt < max_retries:
                time.sleep(0.2)  # short pause only when another attempt follows
            continue

        _write_cache(cache_path, out)
        return out

    # Every attempt failed: cache an empty result so reruns do not loop on this article.
    fallback = {"entities": [], "relations": [], "error": last_err, "raw": raw[:1000]}
    _write_cache(cache_path, fallback)
    return fallback

In [10]:
# Run the extraction over the filtered articles. `extract` caches one file per
# article, so an interrupted run resumes from where it stopped.
results = []

for i, row in tqdm(df_tech.iterrows(), total=len(df_tech), desc="extraction"):
    article_id = row[ID_COL] if ID_COL else i
    text = str(row[TEXT_COL])

    res = extract(text, article_id=article_id)
    results.append({"article_id": article_id, **res})

# Only the fallback result of `extract` (output never parsed/validated) carries an "error" key.
n_failed = sum("error" in r for r in results)

(results[0].keys() if results else None), len(results), n_failed

  0%|          | 0/100 [00:18<?, ?it/s]


KeyboardInterrupt: 