In [1]:
import json

import torch

import numpy as np

from tqdm import tqdm
from transformers import AutoTokenizer
from transformers import GPT2LMHeadModel

In [6]:
splits = ['train', 'val', 'test']


def _load_split(split):
    """Load and parse the JSON file for one dataset split.

    The file is opened in a ``with`` block so the handle is closed as soon as
    parsing finishes, and it is read as UTF-8 (the standard JSON encoding)
    rather than whatever the platform's default locale encoding happens to be.
    """
    path = f'tune/words_dataset_unique.csv_{split}.json'
    with open(path, encoding='utf-8') as fh:
        return json.load(fh)


# Map each split name ('train', 'val', 'test') to its parsed JSON contents.
data = {k: _load_split(k) for k in splits}

In [3]:
# Reload the two fine-tuned GPT-2 checkpoints, forward first and then reverse.
# `modelf` is sampled from in the definition / example sections below.
# `modelr` is used by guess_words to go from a definition to a word; judging by
# the "rev" in its directory name it was presumably trained on reversed pairs
# (TODO confirm).
modelf, modelr = (
    GPT2LMHeadModel.from_pretrained(ckpt_dir)
    for ckpt_dir in ('tune/model_unique_best/', 'tune/model_unique_rev_best/')
)

In [4]:
# Stock 'gpt2' tokenizer, with '<|endoftext|>' registered as the padding token.
# add_special_tokens returns how many tokens were newly added to the vocabulary.
# The displayed 0 means nothing new was added.
tok = AutoTokenizer.from_pretrained('gpt2')
tok.add_special_tokens(dict(pad_token='<|endoftext|>'))

0

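## Word Guessing Game

Given a definition, the reverse model (`modelr`) proposes candidate words with beam search. We report the exact-match reciprocal rank of the true word among the guesses.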

In [13]:
def guess_words(query, target, n_words, maxlen=8, verbose=False):
    """Guess words that match a definition using the reverse model.

    Prompts ``modelr`` with ``'Definition: <query> ; Word:'`` and runs beam
    search, returning one candidate per beam, ranked best first.

    Args:
        query: Definition text to guess a word for.
        target: The true word; used only for the exact-match metric.
        n_words: Number of beams, which is also the number of guesses returned.
        maxlen: Maximum number of tokens generated after the prompt.
        verbose: If True, print the prompt, the ranked guesses and the metrics.

    Returns:
        ``(words, metrics)``:
            ``words`` is the list of guesses in beam order.
            ``metrics`` maps ``'Exact MRR'`` to ``1 / rank`` (1-based) of the
            first guess exactly equal to ``target``, or 0 if no guess matches.
    """
    seed_txt = 'Definition: ' + query + ' ; Word:'
    input_ids = tok.encode(seed_txt, return_tensors='pt')
    prompt_len = input_ids.shape[1]

    beam_outputs = modelr.generate(
        input_ids,
        # For this decoder-only model max_length counts the prompt tokens too.
        max_length=prompt_len + maxlen,
        num_beams=n_words,
        num_return_sequences=n_words,
        early_stopping=True,
        pad_token_id=tok.eos_token_id,
    )

    # Decode only the generated continuation. Removing the prompt with
    # str.replace on the full decode breaks whenever decoding does not
    # reproduce seed_txt exactly (tokenization-space cleanup can, e.g., turn
    # " ," into ","). The prompt would then leak into every guess.
    words = [
        tok.decode(beam_output[prompt_len:], skip_special_tokens=True).strip()
        for beam_output in beam_outputs
    ]

    # Reciprocal rank of the first exact match; 0 when the target is absent.
    exact_mrr = next(
        (1 / rank for rank, word in enumerate(words, start=1) if word == target),
        0,
    )
    metrics = {
        'Exact MRR': exact_mrr,
    }

    if verbose:
        print('Seed:', seed_txt)
        print(f"Guesses (True Answer: {target}):\n" + 100 * '-')
        for i, word in enumerate(words):
            print("{}: {}".format(i, word))
        print('-' * 100)
        for k, v in metrics.items():
            print(f'{k}: {v:.2f}')

    return words, metrics

In [95]:
# (definition, true word, number of guesses) probes for the reverse model.
# Some words appear several times with alternative phrasings of the definition,
# so the guesses can be compared across wordings.
GUESS_EXAMPLES = [
    ('To rapidly move. Do this in a race', 'run', 10),
    ('A plant with petals', 'flower', 20),
    ('A device to control a television', 'remote', 10),
    ('The use of computing to solve scientific and engineering problems, especially by means of simulation, or the construction of mathematical models of physical, chemical or biological processes',
     'scientific computing', 100),
    ('(mathematics) The study of algorithms to solve mathematical problems concerning continuous sets of values (such as the real numbers, complex numbers or vector spaces).',
     'numerical analysis', 100),
    ('the study of algorithms for the problems of continuous mathematics',
     'numerical analysis', 100),
    ('a simple hydrocarbon; a powerful greenhouse gas.', 'methane', 100),
    ('a powerful greenhouse gas', 'methane', 100),
    ('colourless, odourless gas that occurs abundantly in nature and as a product of certain human activities',
     'methane', 100),
    ('(physics) In the Standard Model, an elementary subatomic particle that forms matter. They combine to form hadrons, such as protons and neutrons.',
     'quark', 100),
    ('A large mammal found in arctic regions', 'polar bear', 20),
]

# Pick one probe to display; change the index to try another.
d, w, n = GUESS_EXAMPLES[-1]

_ = guess_words(d, w, n, verbose=True)

Seed: Definition: A large mammal found in arctic regions ; Word:
Guesses (True Answer: polar bear):
----------------------------------------------------------------------------------------------------
0: polar bear
1: penguin
2: bear
3: grizzly
4: fox
5: mammoth
6: iceman
7: raccoon
8: puffer
9: stag
10: skunk
11: wolf
12: skank
13: deer
14: pack
15: polaroid
16: panda
17: siren
18: bosh
19: tule
----------------------------------------------------------------------------------------------------
Exact MRR: 1.00


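## Definition Generation

The forward model (`modelf`) is prompted with `Word: <w> ; Definition:` and samples several candidate definitions using nucleus (top-p) sampling.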

In [71]:
# Translation table for characters with special meaning in LaTeX text mode.
_LATEX_ESCAPES = str.maketrans({
    '\\': r'\textbackslash{}',
    '&': r'\&',
    '%': r'\%',
    '$': r'\$',
    '#': r'\#',
    '_': r'\_',
    '{': r'\{',
    '}': r'\}',
    '~': r'\textasciitilde{}',
    '^': r'\textasciicircum{}',
})


def _latex_escape(s):
    """Escape LaTeX special characters so ``s`` can be placed in a table cell."""
    return s.translate(_LATEX_ESCAPES)


def top_p_sample(text, p, temp=1.0, maxlen=128, minlen=2, num_samples=3, verbose=True, latex=False, label=None, seed=None):
    """Sample continuations of ``text`` from the forward model (nucleus sampling).

    Args:
        text: Prompt, e.g. ``'Word: <w> ; Definition:'``.
        p: Top-p (nucleus) threshold passed to ``generate``.
        temp: Sampling temperature.
        maxlen: Maximum number of tokens generated after the prompt.
        minlen: Minimum number of tokens generated after the prompt.
        num_samples: Number of independent samples to draw.
        verbose: If True, print each decoded sample (prompt included).
        latex: If True, also print each sample as a LaTeX table row (the
            generated continuation inside a 10cm parbox) followed by an
            hline rule.
        label: Text for the first cell of the first LaTeX row; defaults to
            ``text``.
        seed: If given, seed torch's RNG before sampling so the draws are
            reproducible.

    Returns:
        List of decoded samples, each including the prompt text.
    """
    if seed is not None:
        torch.manual_seed(seed)

    prefix = tok.encode(text, return_tensors='pt')
    prompt_len = prefix.shape[1]

    sample_outputs = modelf.generate(
        prefix,
        # Same padding id as guess_words (GPT-2's end-of-text token).
        pad_token_id=tok.eos_token_id,
        do_sample=True,
        temperature=temp,
        # For this decoder-only model the length limits include the prompt.
        max_length=prompt_len + maxlen,
        min_length=prompt_len + minlen,
        top_p=p,
        num_return_sequences=num_samples,
    )

    samples = []
    for i, sample_output in enumerate(sample_outputs):
        ox = tok.decode(sample_output, skip_special_tokens=True)
        samples.append(ox)

        if verbose:
            print("{}: {}".format(i, ox))

        if latex:
            # Only the first row of a group carries the word label; later rows
            # leave that cell empty.
            if i == 0:
                head = _latex_escape(text if label is None else label)
            else:
                head = ''
            # Decode just the generated tokens instead of str.replace-ing the
            # prompt out of the full decode, which fails if decoding does not
            # reproduce `text` exactly.
            continuation = tok.decode(sample_output[prompt_len:], skip_special_tokens=True)
            print(head + ' & \\parbox{10cm}{' + _latex_escape(continuation) + '} \\\\')
            print('\\hline')

    return samples

In [96]:
# Nucleus-sampling threshold: sample from the smallest token set whose
# cumulative probability reaches p.
p = 0.95
# Word to define. It is set explicitly here rather than relying on a `w` left
# over from the word-guessing cell, so this cell works on a fresh kernel.
# Other words tried: 'eigenspectrum', 'matrix', 'scientific computing',
# 'tweet', 'jawn'.
w = 'polar bear'

tx = 'Word: ' + w + ' ; Definition:'
_ = top_p_sample(tx, p, label=w, verbose=True, latex=False)

0: Word: polar bear ; Definition: polar bear (plural polar bears)
1: Word: polar bear ; Definition: (figuratively) Anything similar, similar to (or similar to) a polar bear, such as a polar bear in fur or feathers; an attractive or intimidating looking person, person, or person-animal.
2: Word: polar bear ; Definition: (uncountable, derogatory) A large mammal of similar weight, stature and sexual orientation.


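## Example Usage Generation

This section uses the same forward model and sampler, but the prompt ends in `Example:` instead of `Definition:`, so the model produces example sentences that use the word.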

In [98]:
# Word to generate example sentences for.
# Other words tried: 'scientific computing', 'matrix', 'tweet',
# 'eigenspectrum', 'jawn'.
w = 'polar bear'
# Nucleus-sampling threshold passed to top_p_sample.
p = 0.95

tx = 'Word: ' + w + ' ; Example:'
_ = top_p_sample(tx, p, label=w, latex=False, verbose=True)

0: Word: polar bear ; Example: That's an image of polar bears all over, one of the polar bears looking down towards their captions as they travel through the air in a circular motion.
1: Word: polar bear ; Example: The polar bear is the smallest of the polar bears, but is larger than many a mammal.
2: Word: polar bear ; Example: Polar bears are typically seen in many different species, including polar bears in the Australian and New Zealand polar bear subspecies.
