In [None]:
# Move two directories up (presumably the repository root) so relative paths
# such as "data/..." used below resolve.
%cd ../..
# Enable IPython's autoreload so edited imported modules are reloaded before each cell runs.
%load_ext autoreload

%autoreload 2

In [None]:
import pandas as pd
import numpy as np
import json
import folium
from collections import Counter
import os
import re
import matplotlib.pyplot as plt
import seaborn as sns
from ast import literal_eval

import litellm
from litellm import completion
# Turn off litellm's verbose/debug output for the completion calls made later.
litellm.set_verbose=False

from emv.features.wikidata import get_wikidata_id, get_property, get_wikidata_label
from emv.features.wikidata import process_batch

In [None]:
# Load the RTS sample (tab-separated). These columns hold Python-literal strings
# (e.g. `data` is later used as a dict, `people` as a list), so each is parsed
# back into Python objects with `literal_eval`.
df = pd.read_csv(
    "data/rts_sample.csv",
    sep="\t",
    converters=dict.fromkeys(["data", "locations", "people", "orgs", "misc"], literal_eval),
)

In [None]:
# Keep only videos whose `data` carries a transcript list (full videos with
# speaker-diarization segments). A missing transcript maps to None so that the
# video is dropped by the filter, rather than defaulting to an empty list that
# would pass the `list` check.
df["transcript"] = df["data"].map(lambda x: x.get("transcript"))
df = df[df["transcript"].map(lambda x: isinstance(x, list))].reset_index(drop=True)
print(f"Processed {len(df)} videos.")

# Persons

In [None]:
# Forms of address to drop from the `people` lists (compared case-insensitively).
filter_people = ["messieurs", "monsieur", "madame", "mesdames"]


def _drop_honorifics(names):
    """Return `names` without entries that are one of `filter_people` (case-insensitive)."""
    return [name for name in names if name.lower() not in filter_people]


df["people"] = df["people"].apply(_drop_honorifics)

In [None]:
# One row per PER-labelled entity in the transcripts, paired with the text ("t")
# of the transcript segment it appears in. Each segment's "entities" is assumed
# to be a list of (text, label) pairs — TODO confirm against the upstream NER output.
persons = (
    df["transcript"]
    .to_frame()
    .assign(
        entities=df["transcript"].map(lambda segments: [seg.get("entities") for seg in segments]),
        context=df["transcript"].map(lambda segments: [seg["t"] for seg in segments]),
    )
    .explode(["entities", "context"])  # one row per segment
    .explode("entities")               # one row per entity within a segment
    .reset_index(drop=True)
    .dropna(subset=["entities"])
)
persons["entities"] = persons["entities"].map(lambda ent: ent[0] if ent[1] == "PER" else None)
persons = persons.dropna(subset=["entities"]).reset_index(drop=True)
persons.head()

In [None]:
# Per-person statistics from the video-level `people` lists: total number of
# mentions ("count") and a Counter of mentions per year ("year"), most-mentioned first.
top_persons = (
    df[["people", "year"]]
    .explode("people")
    .dropna()
    .groupby("people")
    .agg(list)
    .reset_index()
)
top_persons["count"] = top_persons["year"].map(len)
top_persons["year"] = top_persons["year"].map(lambda years: Counter(years))
top_persons = top_persons.sort_values("count", ascending=False)
print(f"Found {len(top_persons)} persons in the dataset.")
# Mean mention count with its standard error of the mean.
print(f"Mean number of mentions per person: {top_persons['count'].mean():.2f} +/- {top_persons['count'].sem():.2f}")

In [None]:
# Keep persons mentioned strictly more than `min_count` times (matches the
# "more than" wording of the summary line).
min_count = 50
top_persons = top_persons.loc[lambda frame: frame["count"].gt(min_count)]
print(f"Found {len(top_persons)} persons with more than {min_count} occurrences.")
top_persons.head(10)

In [None]:
# Query Wikidata for up to 10 candidates per frequent person (delay=1 is passed
# through to the helper — presumably request throttling, see its definition).
# `.assign` returns a new frame, so adding the column does not trip pandas'
# SettingWithCopyWarning on the boolean-filtered `top_persons` slice.
top_persons = top_persons.assign(
    wikidata_search=top_persons["people"].map(lambda name: get_wikidata_id(name, top_n=10, delay=1))
)
# Keep only persons with at least one hit; `bool` treats both an empty result
# and a None result as "no candidates" instead of raising on len(None).
top_persons = top_persons[top_persons["wikidata_search"].map(bool)]

In [None]:
# Reduce each Wikidata search hit to an (id, description) pair used in the LLM
# prompt. Persons without hits were already removed when the search ran, so the
# duplicated non-empty filter that used to sit here is dropped. `.assign` avoids
# SettingWithCopyWarning on the filtered frame.
top_persons = top_persons.assign(
    wikidata_candidates=top_persons["wikidata_search"].map(
        lambda hits: [(hit.get("id"), hit.get("description")) for hit in hits]
    )
)

In [None]:
# Map each frequent person's name to its Wikidata (id, description) candidates,
# preserving the most-mentioned-first order of `top_persons`.
persons_candidates = dict(zip(top_persons["people"], top_persons["wikidata_candidates"]))
persons_candidates

# Disambiguation

In [None]:
def build_query(named_entity, context, candidates):
    """Build the prompt asking the LLM to pick the best candidate for one entity.

    Args:
        named_entity: The entity string as extracted from the transcript.
        context: The transcript segment text the entity was extracted from.
        candidates: Candidate entities, interpolated with their ``str()`` form
            (in this notebook, a list of ``(wikidata_id, description)`` tuples).

    Returns:
        The prompt text. Its "Named entity:" and "Context:" lines are parsed back
        out of saved prompts with regular expressions later in this notebook, so
        their wording and line layout must stay stable.
    """
    # The literal's lines keep their 3-space indent on purpose: the indent is part
    # of the prompt text sent to the model and stored in the results file.
    return f"""
   I have a list of named entities and the context they have been extracted from. 
   I will provide you with pairs (named entity, context) as well as a list of possible candidates for those named entities. 
   Given the context, choose the best candidate. 
   Return ONLY the best candidate, without any additional information or context. DO NOT add any other words.
   I repeat, return ONLY THE BEST CANDIDATE.

   Named entity: {named_entity}
   Context: {context}
   Candidates: {candidates}
   """

In [None]:
# Reproducible evaluation sample drawn from mentions of the 20 most-mentioned
# persons (`persons_candidates` keeps the count-descending order of `top_persons`).
np.random.seed(42)
n_test_pairs = 500

test_persons = list(persons_candidates.keys())[:20]
candidate_pairs = (
    persons[persons["entities"].isin(test_persons)]
    .drop_duplicates("context")[["entities", "context"]]
)
# Cap the sample size: `.sample(n)` raises ValueError when n exceeds the number
# of rows. random_state=42 yields the same draws as sampling right after
# np.random.seed(42), but does not depend on global RNG state.
test_pairs = candidate_pairs.sample(
    min(n_test_pairs, len(candidate_pairs)), random_state=42
).to_dict("records")
print(f"Testing on {len(test_pairs)} pairs.")

In [None]:
# One prompt per sampled (entity, context) pair, using that entity's Wikidata candidates.
queries = [
    build_query(named_entity, context, persons_candidates[named_entity])
    for named_entity, context in ((pair["entities"], pair["context"]) for pair in test_pairs)
]

print(f"Built {len(queries)} queries.")

In [None]:
# Send every prompt to the Mixtral model served by Ollama (via litellm's
# `completion`) and keep the first Wikidata QID (Q followed by digits) found
# in each answer, or None when there is none.
model = "ollama/mixtral:8x7b-instruct-v0.1-fp16"
# Overridable so the notebook is not tied to one LAN machine; the default is the
# server used originally.
api_base = os.environ.get("EMV_LLM_API_BASE", "http://192.168.1.42:11434")

pattern = r'\bQ\d+\b'
qid_regex = re.compile(pattern)
responses = []

for query in queries:
    response = completion(
        model=model,
        messages=[{"content": query, "role": "user"}],
        api_base=api_base,
    )
    # The message content may be None (e.g. an empty completion); treat that as
    # "no answer" instead of letting the regex raise a TypeError mid-run.
    answer = response.choices[0].message.content or ""
    qid_match = qid_regex.search(answer)
    responses.append(qid_match.group(0) if qid_match else None)

In [None]:
# Persist the prompts and the extracted QIDs side by side (same order).
with open("data/disambiguation_results.json", mode="w") as results_file:
    json.dump(dict(queries=queries, responses=responses), results_file, indent=4)

In [None]:
import json

In [None]:
# Reload the saved run: a dict with parallel "queries" and "responses" lists.
with open("data/disambiguation_results.json") as results_file:
    responses = json.load(results_file)

In [None]:
# One row per query: the prompt text and the QID (or None) extracted from the answer.
responses = pd.DataFrame.from_dict(responses)
responses.head()

In [None]:
# Recover the entity and its context from each stored prompt (built by `build_query`).
# `str.extract` gives NaN instead of raising IndexError when a prompt does not
# match. The context pattern runs up to the "Candidates:" line (DOTALL, lazy),
# so a context containing newlines is not truncated at its first line break.
responses["entities"] = responses["queries"].str.extract(r"Named entity: (.*)\n", expand=False)
responses["context"] = responses["queries"].str.extract(
    r"Context: (.*?)\n\s*Candidates: ", flags=re.DOTALL, expand=False
)

In [None]:
# Inspect the parsed results without the long raw prompt text held in `queries`.
responses[["entities", "context", "responses"]].head(10)