In [None]:
import pickle
from genre.fairseq_model import GENRE
from genre.trie import Trie

# Prefix tree of Wikipedia titles, used below to constrain generation.
# NOTE: pickle.load can execute arbitrary code embedded in the file; only
# load a copy of this pickle that comes from a trusted source.
with open("data/kilt_titles_trie_dict.pkl", "rb") as trie_file:
    trie = Trie.load_from_dict(pickle.load(trie_file))

# Pretrained entity-disambiguation model, switched to eval mode.
model = GENRE.from_pretrained("models/fairseq_entity_disambiguation_aidayago").eval()

# Cryptic method codes, each wrapped in [START_ENT]/[END_ENT] inside a short
# sentence that supplies some context for the mention.
codes = [
    "The method [START_ENT] ICPMS_TD [END_ENT] is widely used in elemental analysis.",
    "We used [START_ENT] FE_EMP [END_ENT] for microscopic analysis.",
    "Dating was done using [START_ENT] PB207_PB206_AGE [END_ENT].",
    "We measured radiation using [START_ENT] GAMMA [END_ENT].",
    "Isotope analysis used [START_ENT] WBD [END_ENT].",
    "Analysis performed via [START_ENT] ICPOES [END_ENT].",
    "Determined age using [START_ENT] NE21 AGE [END_ENT].",
]


def allowed_next_tokens(batch_id, sent):
    """Return what the trie yields for the token prefix generated so far."""
    return trie.get(sent.tolist())


# Query the model one sentence at a time and show its three best candidates.
for sentence in codes:
    results = model.sample([sentence], prefix_allowed_tokens_fn=allowed_next_tokens)
    top_candidates = results[0][:3]  # results[0] belongs to the single input sentence

    print(f"\n{sentence}")
    for candidate in top_candidates:
        score = candidate["score"].item()
        print(f"  → {candidate['text']} (score: {score:.4f})")


  from .autonotebook import tqdm as notebook_tqdm



The method [START_ENT] ICPMS_TD [END_ENT] is widely used in elemental analysis.
  → Inductively coupled plasma mass spectrometry (score: -0.1703)
  → Inductively coupled plasma atomic emission spectroscopy (score: -1.2015)
  → ICPM (score: -2.2986)

We used [START_ENT] FE_EMP [END_ENT] for microscopic analysis.
  → Faraday effect (score: -1.8307)
  → Far-infrared astronomy (score: -3.0039)
  → Far-infrared laser (score: -3.3976)

Dating was done using [START_ENT] PB207_PB206_AGE [END_ENT].
  → Parallel ATA (score: -1.7021)
  → Postal codes in Malaysia (score: -1.9363)
  → Postal codes in Canada (score: -1.9818)

We measured radiation using [START_ENT] GAMMA [END_ENT].
  → Gamma-Aminobutyric acid (score: -0.5029)
  → GAMMA (score: -0.6394)
  → Gamma spectroscopy (score: -1.2691)

Isotope analysis used [START_ENT] WBD [END_ENT].
  → Wavelength-division multiplexing (score: -0.8292)
  → WBD (score: -1.5071)
  → Wavelength-division multiple access (score: -1.6072)

Analysis performed via 

In [8]:
import json


# 1. Load the GENRE input file. It must be a JSON object with two keys:
#      "candidate_entities" - read into candidate_entities
#      "missing_codes"      - read into missing_codes
#    A missing key raises KeyError here rather than failing later on.
# JSON text is UTF-8; decode it explicitly so non-ASCII entity names are not
# mangled on platforms whose default text encoding is something else.
with open("genre_input.json", "r", encoding="utf-8") as f:
    genre_input = json.load(f)

candidate_entities = genre_input["candidate_entities"]
missing_codes = genre_input["missing_codes"]

In [9]:
from transformers import BartTokenizer
from genre.fairseq_model import GENRE
from genre.trie import Trie

# Load tokenizer and model
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
model = GENRE.from_pretrained("models/fairseq_entity_disambiguation_aidayago").eval()

# Tokenize each candidate into the full ID sequence the decoder has to emit.
# The constraint callback used with this trie passes the *entire* generated
# prefix to trie.get(), so every stored sequence must be framed the same way
# as the decoder output:
#   - leading EOS: the decoder prefix starts with it (fairseq seeds generation
#     with EOS; GENRE's own custom-trie recipe prepends id 2 for that reason).
#     NOTE(review): confirm against the installed GENRE/fairseq version.
#   - trailing EOS: lets a hypothesis finish once a candidate is complete.
# Without this framing the first lookup finds no match, so nothing is allowed.
eos_id = tokenizer.eos_token_id
tokenized_entities = [
    [eos_id] + tokenizer.encode(entity, add_special_tokens=False) + [eos_id]
    for entity in candidate_entities
]


# 2. Build trie from the framed candidate token sequences
trie = Trie(sequences=tokenized_entities)

In [None]:
def allowed_next_tokens(batch_id, sent):
    """Return what the trie yields for the token prefix generated so far."""
    return trie.get(sent.tolist())


# Link every missing method code against the candidate trie, one prompt each.
for missing_code in missing_codes:
    # Put the code in a fixed descriptive sentence so the model sees context.
    sentence = (
        f'[START_ENT] {missing_code} [END_ENT] is a method code describing a technique for '
        f'geochemical element extraction from rock samples in the GEOROC database.'
    )

    results = model.sample([sentence], prefix_allowed_tokens_fn=allowed_next_tokens)
    top_candidates = results[0][:3]  # results[0] belongs to the single input sentence

    print(f"\n{missing_code}")
    for candidate in top_candidates:
        score = candidate["score"].item()
        print(f"  → {candidate['text']} (score: {score:.4f})")