In [10]:
import json
from lxml import etree
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
WIKI_PATH = "cswiki-latest-pages-articles.xml"

In [None]:
# def get_page_by_title(xml_path, search_title):
#     context = etree.iterparse(xml_path, events=('end',), tag='{*}page')
#     for _, elem in context:
#         title = elem.findtext('{*}title')
#         if title == search_title:
#             text_elem = elem.find('.//{*}revision/{*}text')
#             raw_text = text_elem.text if text_elem is not None else ""
#             return title, raw_text.strip()
#         elem.clear()
#         while elem.getprevious() is not None:
#             del elem.getparent()[0]
#     return None, None

In [None]:
# jsonl_path = Path("npfl140/data/wiki_passages_with_embeddings.jsonl")
# with open(jsonl_path, "r", encoding="utf-8") as f:
#     for line in f:
#         item = json.loads(line)
#         raw_title = item["title"]
#
# # === Normalize title ===
# normalized_title = raw_title.replace("_", " ")
#
# # === Search in XML ===
# title, content = get_page_by_title(WIKI_PATH, normalized_title)
#
# if content:
#     print(f"✅ Found: {title}\n---\n{content[:1000]}...")
# else:
#     print(f"❌ Page not found for title: {normalized_title}")

# RAG

In [11]:
model_name = "intfloat/multilingual-e5-large"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

XLMRobertaModel(
  (embeddings): XLMRobertaEmbeddings(
    (word_embeddings): Embedding(250002, 1024, padding_idx=1)
    (position_embeddings): Embedding(514, 1024, padding_idx=1)
    (token_type_embeddings): Embedding(1, 1024)
    (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): XLMRobertaEncoder(
    (layer): ModuleList(
      (0-23): 24 x XLMRobertaLayer(
        (attention): XLMRobertaAttention(
          (self): XLMRobertaSdpaSelfAttention(
            (query): Linear(in_features=1024, out_features=1024, bias=True)
            (key): Linear(in_features=1024, out_features=1024, bias=True)
            (value): Linear(in_features=1024, out_features=1024, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): XLMRobertaSelfOutput(
            (dense): Linear(in_features=1024, out_features=1024, bias=True)
            (LayerNorm): LayerNorm((1024,), eps=1e-05, eleme

In [12]:
def load_embeddings(jsonl_path):
    data = []
    with open(jsonl_path, "r", encoding="utf-8") as f:
        for line in f:
            item = json.loads(line)
            item["embedding"] = np.array(item["embedding"], dtype=np.float32)
            data.append(item)
    return data

In [14]:
embeddings_data = load_embeddings("npfl140/data/wiki_passages_with_embeddings.jsonl")

In [18]:
WIKI_PATH = "cswiki-latest-pages-articles.xml"

In [15]:
def encode_query(text: str) -> np.ndarray:
    with torch.no_grad():
        input_ids = tokenizer(
            "query: " + text,
            return_tensors="pt",
            truncation=True,
            max_length=512
        ).to(device)
        output = model(**input_ids)
        mask = input_ids["attention_mask"].unsqueeze(-1).bool()
        masked = output.last_hidden_state.masked_fill(~mask, 0.0)
        mean_pool = masked.sum(1) / mask.sum(1)
        return F.normalize(mean_pool, p=2, dim=1)[0].cpu().numpy()

In [17]:
def find_top_passages(query_vec, embeddings_data, top_k=3):
    all_embeddings = np.stack([item["embedding"] for item in embeddings_data])
    scores = np.dot(all_embeddings, query_vec)  # cosine similarity
    top_indices = np.argsort(scores)[-top_k:][::-1]
    return [embeddings_data[i] for i in top_indices], scores[top_indices]

In [16]:
def get_page_by_title(xml_path, search_title):
    context = etree.iterparse(xml_path, events=('end',), tag='{*}page')
    for _, elem in context:
        title = elem.findtext('{*}title')
        if title == search_title:
            text_elem = elem.find('.//{*}revision/{*}text')
            raw_text = text_elem.text if text_elem is not None else ""
            return title, raw_text.strip()
        elem.clear()
        while elem.getprevious() is not None:
            del elem.getparent()[0]
    return None, None

In [19]:
query = "Šperky s jakým drahokamem vyrábí Granát Turnov?"
query_vec = encode_query(query)

top_matches, scores = find_top_passages(query_vec, embeddings_data, top_k=3)

for i, match in enumerate(top_matches):
    print(f"\n=== Match #{i+1} | Title: {match['title']} | Score: {scores[i]:.3f} ===")
    print(match['passage'][:300], "...\n")

    # Try to get full content
    normalized_title = match["title"].replace("_", " ")
    title, full_text = get_page_by_title(WIKI_PATH, normalized_title)

    if full_text:
        print(f"\n✅ Full Page Content: {title}\n{'-'*60}\n{full_text[:2000]}...\n")
    else:
        print(f"⚠️ Page not found for title: {normalized_title}")



=== Match #1 | Title: Granát_Turnov | Score: 23.637 ===
Granát Turnov thumb|Hlavní budova Granátu v Turnově Granát, družstvo umělecké výroby Turnov je česká šperkařská firma se sídlem v Turnově. Družstvo je největším výrobcem šperků s české granáty (tzn. pyropy). Je vlastníkem práva těžby českých granátů, které těží v povrchovém dolu u Podsedic. Ukázku r ...


✅ Full Page Content: Granát Turnov
------------------------------------------------------------
{{Infobox - firma
| jméno = Granát, družstvo umělecké výroby, Turnov
| logo = logo_GranatTurnov.jpg
| obrázek = Granát, Výšinka, Turnov.jpg
| popisek = Hlavní budova
| typ = [[družstvo]]
| datum založení = 1953
| zakladatel = 
| sídlo město = Výšinka 1409, [[Turnov]]
| sídlo stát = [[Česko]]
| klíčoví lidé = Pavel Tvrzník (předseda)
| oblast činnosti = [[šperkařství]]
| produkty = šperky
| služby = 
| obrat = 
| výsledek hospodaření = 
| počet zaměstnanců = 220
| mateřská společnost = 
| majitel = 
| web = http://www.granat.cz
| pozná