In [None]:
import pickle
from genre.fairseq_model import GENRE
from genre.trie import Trie

model_path = "/home/lars/GENRE/models/fairseq_entity_disambiguation_aidayago"
trie_path = "/home/lars/GENRE/data/kilt_titles_trie_dict.pkl"

# Load Wikipedia titles trie
with open(trie_path, "rb") as f:
    trie = Trie.load_from_dict(pickle.load(f))

# Load model
model = GENRE.from_pretrained(model_path).eval()

# Example cryptic methodcodes in context
codes = [
    "The method [START_ENT] ICPMS_TD [END_ENT] is widely used in elemental analysis.",
    "We used [START_ENT] FE_EMP [END_ENT] for microscopic analysis.",
    "Dating was done using [START_ENT] PB207_PB206_AGE [END_ENT].",
    "We measured radiation using [START_ENT] GAMMA [END_ENT].",
    "Isotope analysis used [START_ENT] WBD [END_ENT].",
    "Analysis performed via [START_ENT] ICPOES [END_ENT].",
    "Determined age using [START_ENT] NE21 AGE [END_ENT]."
]

# Run GENRE on each sentence
for sentence in codes:
    results = model.sample(
        [sentence],
        prefix_allowed_tokens_fn=lambda batch_id, sent: trie.get(sent.tolist())
    )
    print(f"\n{sentence}")
    for r in results[0][:3]:  # top 3 candidates
        print(f"  → {r['text']} (score: {r['score'].item():.4f})")


  state = torch.load(f, map_location=torch.device("cpu"))



The method [START_ENT] ICPMS_TD [END_ENT] is widely used in elemental analysis.
  → Inductively coupled plasma mass spectrometry (score: -0.1703)
  → Inductively coupled plasma atomic emission spectroscopy (score: -1.2015)
  → ICPM (score: -2.2986)

We used [START_ENT] FE_EMP [END_ENT] for microscopic analysis.
  → Faraday effect (score: -1.8307)
  → Far-infrared astronomy (score: -3.0039)
  → Far-infrared laser (score: -3.3976)

Dating was done using [START_ENT] PB207_PB206_AGE [END_ENT].
  → Parallel ATA (score: -1.7021)
  → Postal codes in Malaysia (score: -1.9363)
  → Postal codes in Canada (score: -1.9818)

We measured radiation using [START_ENT] GAMMA [END_ENT].
  → Gamma-Aminobutyric acid (score: -0.5029)
  → GAMMA (score: -0.6394)
  → Gamma spectroscopy (score: -1.2691)

Isotope analysis used [START_ENT] WBD [END_ENT].
  → Wavelength-division multiplexing (score: -0.8292)
  → WBD (score: -1.5071)
  → Wavelength-division multiple access (score: -1.6072)

Analysis performed via 

In [32]:
import pickle
from genre.fairseq_model import GENRE
from genre.trie import Trie

trie_path = "/home/lars/GENRE/data/kilt_titles_trie_dict.pkl"

# Load Wikipedia titles trie
with open("/home/lars/GENRE/data/kilt_titles_trie_dict.pkl", "rb") as f:
    trie = Trie.load_from_dict(pickle.load(f))


title = "Inductively coupled plasma atomic emission spectroscopy"

In [22]:
candidate_entities = [
  "TIMS_ID: ISOTOPE-DILUTION THERMAL-IONIZATION MASS SPECTROMETRY",
  "TIMS_CA_ID: CHEMICAL ABRASION ISOTOPE-DILUTION THERMAL-IONIZATION MASS SPECTROMETRY",
  "EMP (EPMA): ELECTRON MICROPROBE ANALYSIS",
  "LA-ICPMS: LASER ABLATION INDUCTIVELY-COUPLED PLASMA MASS SPECTROMETRY",
  "SIMS: SECONDARY IONIZATION MASS SPECTROMETRY",
  "MC-ICPMS: MULTI-COLLECTOR INDUCTIVELY COUPLED PLASMA MASS SPECTROMETRY"
]

from transformers import BartTokenizer

# Load tokenizer and model
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
tokenized_entities = [
    [2] + tokenizer.encode(entity, add_special_tokens=False)
    for entity in candidate_entities
]
for e in tokenized_entities:
    print(e)


# 2. Build trie from candidate entity strings
trie = Trie(sequences=tokenized_entities)

[2, 565, 3755, 104, 1215, 2688, 35, 3703, 3293, 23075, 12, 495, 3063, 34058, 8640, 2076, 43602, 12, 7744, 17045, 6034, 256, 17042, 44921, 6997, 3765, 3935, 16802]
[2, 565, 3755, 104, 1215, 4054, 1215, 2688, 35, 3858, 5330, 24308, 6266, 500, 2336, 7744, 3703, 3293, 23075, 12, 495, 3063, 34058, 8640, 2076, 43602, 12, 7744, 17045, 6034, 256, 17042, 44921, 6997, 3765, 3935, 16802]
[2, 42257, 36, 9662, 5273, 3256, 36630, 37043, 29615, 500, 5733, 500, 7912, 717, 5102, 2118, 14780, 1729]
[2, 8272, 12, 2371, 510, 6222, 35, 226, 2336, 2076, 83, 7976, 6034, 12569, 28120, 6372, 29313, 12, 347, 5061, 7205, 1691, 12901, 2336, 5273, 256, 17042, 44921, 6997, 3765, 3935, 16802]
[2, 37266, 104, 35, 3614, 7054, 10760, 38, 2191, 17045, 6034, 256, 17042, 44921, 6997, 3765, 3935, 16802]
[2, 6018, 12, 2371, 510, 6222, 35, 256, 25938, 100, 12, 18047, 42120, 3411, 12569, 28120, 6372, 29313, 230, 5061, 7205, 1691, 12901, 2336, 5273, 256, 17042, 44921, 6997, 3765, 3935, 16802]


In [23]:
trie.trie_dict

{2: {565: {3755: {104: {1215: {2688: {35: {3703: {3293: {23075: {12: {495: {3063: {34058: {8640: {2076: {43602: {12: {7744: {17045: {6034: {256: {17042: {44921: {6997: {3765: {3935: {16802: {}}}}}}}}}}}}}}}}}}}}}}},
      4054: {1215: {2688: {35: {3858: {5330: {24308: {6266: {500: {2336: {7744: {3703: {3293: {23075: {12: {495: {3063: {34058: {8640: {2076: {43602: {12: {7744: {17045: {6034: {256: {17042: {44921: {6997: {3765: {3935: {16802: {}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}},
  42257: {36: {9662: {5273: {3256: {36630: {37043: {29615: {500: {5733: {500: {7912: {717: {5102: {2118: {14780: {1729: {}}}}}}}}}}}}}}}}},
  8272: {12: {2371: {510: {6222: {35: {226: {2336: {2076: {83: {7976: {6034: {12569: {28120: {6372: {29313: {12: {347: {5061: {7205: {1691: {12901: {2336: {5273: {256: {17042: {44921: {6997: {3765: {3935: {16802: {}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}},
  37266: {104: {35: {3614: {7054: {10760: {38: {2191: {17045: {6034: {256: {17042: {44921: {6997: {3765: {3935: {16802: {}}}}}}}}}

In [33]:
from transformers import BartTokenizer

# Load tokenizer and model
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
tokenized_entities = tokenizer.encode(title, add_special_tokens=False)

for e in tokenized_entities:
    print(e)

15248
21491
6608
11531
29051
21495
22679
19416
39655
16572


In [34]:
tokenized_entities = [2] + tokenized_entities
for e in tokenized_entities:
    print(e)



2
15248
21491
6608
11531
29051
21495
22679
19416
39655
16572


In [35]:
for e in range(len(tokenized_entities)):
    token_sequence = tokenized_entities[:e]
    print(token_sequence)
    possible_next_entities = trie.get(token_sequence)
    print(possible_next_entities)
    if e in possible_next_entities:
        print(e)

[]
[2]
[2]
[250, 7083, 26145, 26880, 42820, 4688, 29743, 3684, 37434, 36583, 45680, 19897, 38150, 41622, 39021, 104, 36977, 15724, 25127, 23055, 25089, 40211, 14484, 24185, 28216, 41415, 46953, 22611, 17858, 8138, 35242, 40230, 673, 43827, 510, 133, 28535, 387, 48761, 13112, 18935, 597, 27298, 495, 37388, 725, 29390, 40948, 21169, 23031, 5320, 975, 1301, 771, 24476, 42594, 27845, 24877, 30811, 448, 12271, 36675, 45628, 25447, 10169, 15827, 31371, 3609, 7539, 25869, 38448, 38517, 28151, 6407, 39482, 40080, 11770, 24007, 30121, 29217, 10567, 100, 17521, 534, 44287, 47967, 32686, 41192, 39323, 5096, 574, 15426, 26519, 26222, 42734, 4771, 15117, 32405, 44559, 13368, 35689, 23952, 33153, 33531, 717, 30019, 47181, 35695, 347, 17297, 35136, 32743, 45093, 10350, 32884, 1366, 40683, 16551, 6517, 45558, 40529, 4154, 35792, 4741, 15791, 48080, 12645, 28678, 5054, 31680, 863, 30888, 25395, 40169, 38178, 33295, 36728, 3506, 25093, 500, 9518, 9325, 37475, 23791, 16215, 42765, 13365, 39845, 20556, 27

In [31]:
import json


# 1. Load your candidate entities JSON, assumed format: list of strings or dict keys
with open("genre_input.json", "r") as f:
    genre_input = json.load(f)

candidate_entities = genre_input["candidate_entities"]
missing_codes = genre_input["missing_codes"]

In [9]:
from transformers import BartTokenizer
from genre.fairseq_model import GENRE
from genre.trie import Trie

# Load tokenizer and model
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
model = GENRE.from_pretrained("models/fairseq_entity_disambiguation_aidayago").eval()

# Tokenize them into ID sequences
tokenized_entities = [
    tokenizer.encode(entity, add_special_tokens=False)
    for entity in candidate_entities
]


# 2. Build trie from candidate entity strings
trie = Trie(sequences=tokenized_entities)

In [None]:
# Run GENRE on each sentence
for missing_code in missing_codes:

    sentence = (
        f'[START_ENT] {missing_code} [END_ENT] is a method code describing a technique for '
        f'geochemical element extraction from rock samples in the GEOROC database.'
    )
    
    results = model.sample(
        [sentence],
        prefix_allowed_tokens_fn=lambda batch_id, sent: trie.get(sent.tolist())
    )
    print(f"\n{missing_code}")
    for r in results[0][:3]:  # top 3 candidates
        print(f"  → {r['text']} (score: {r['score'].item():.4f})")