In [1]:
from pathlib import Path

while Path.cwd().name != 'ambignli':
    %cd ..

/mmfs1/gscratch/xlab/alisaliu/ambignli


In [4]:
from generation.gpt3_generation import request
import random
from tqdm import tqdm
import pandas as pd
import numpy as np
from collections import Counter
import peewee
from conceptnet_lite import edges_for, Label
import conceptnet_lite

In [6]:
conceptnet_lite.connect("/gscratch/scrubbed/alisaliu/conceptnet.db")

In [25]:
instruction = "For each word, provide two disambiguations.\n\n"
def construct_prompt(ambiguous_words_dict):
    prompt = instruction
    in_context_examples = random.sample(ambiguous_words_dict.items(), 3)
    for word, disambiguations in in_context_examples:
        prompt += f'The word "{word}" has two meanings: {disambiguations[0]} and {disambiguations[1]}\n'
    prompt += 'The word'
    return prompt

In [26]:
def in_conceptnet(word):
    word = '_'.join(word.split(' '))
    try:
        edges = edges_for(Label.get(text=word).concepts, same_language=True)
        if len(edges) == 0:
            return False
    except peewee.DoesNotExist:
        return False
    return True

In [27]:
def get_outgoing_edges(concept):    
    edges = edges_for(Label.get(text=concept).concepts, same_language=True)
    outgoing_edges = []
    for e in edges:
        if e.start.text == concept:
            outgoing_edges.append(e)
    return outgoing_edges

In [28]:
def get_incoming_edges(concept):    
    edges = edges_for(Label.get(text=concept).concepts, same_language=True)
    incoming_edges = []
    for e in edges:
        if e.end.text == concept:
            incoming_edges.append(e)
    return incoming_edges

In [29]:
def print_edges(concept):
    edges = edges_for(Label.get(text=concept).concepts, same_language=True)
    for e in edges:
        print(f'{e.start.text} -- {e.relation.name} ({e.etc["weight"]}) --> {e.end.text}')

In [31]:
def trim_generation(word):
    if word.startswith('a '):
        return word[2:]
    else:
        return word

In [47]:
ambiguous_words_dict = {
    'mouse': ['computer mouse', 'house mouse'],
    'bow': ['violin bow', 'hair bow'],
    'trunk': ['elephant trunk', 'car trunk'],
    'bass': ['sea bass', 'double bass'],
    'crane': ['construction crane', 'whooping crane'],
    'jam': ['strawberry jam', 'traffic jam']
}

In [48]:
results = []
for i in tqdm(range(100)):
    prompt = construct_prompt(ambiguous_words_dict)
    gen = request(
        prompt=prompt,
        top_p=0.8,
        model='text-davinci-002',
        return_only_text=True
    )
    word = gen.split('"')[1]
    d1 = gen.split('meanings: ')[-1].split(' and')[0]
    d2 = gen.split('or ')[-1]
    
    # check validity of generations
    d1, d2 = trim_generation(d1), trim_generation(d2)
    if (word not in ambiguous_words_dict 
#         and in_conceptnet(d1) and in_conceptnet(d2) 
        and ' ' in d1 and ' ' in d2
        and np.all([pos not in d1] for pos in ['verb', 'noun', 'adjective', 'adverb'])
        and np.all([pos not in d2] for pos in ['verb', 'noun', 'adjective', 'adverb'])
        and word in d1.split(' ') and word in d2.split(' ')):
        ambiguous_words_dict[word] = [d1, d2]
    
    results.append({
        'prompt': prompt,
        'gen': gen
    })

since Python 3.9 and will be removed in a subsequent version.
  in_context_examples = random.sample(ambiguous_words_dict.items(), 3)
  4%|▍         | 4/100 [00:04<01:44,  1.09s/it]


KeyboardInterrupt: 

In [46]:
ambiguous_words_dict

{'jam': ['strawberry jam', 'traffic jam'],
 'bow': ['violin bow', 'hair bow'],
 'mouse': ['computer mouse', 'house mouse'],
 'trunk': ['elephant trunk', 'car trunk'],
 'bass': ['sea bass', 'double bass'],
 'crane': ['construction crane', 'whooping crane'],
 'rose': ['rose bush', 'rose color'],
 'hole': ['donut hole', 'golf hole'],
 'rod': ['fishing rod', 'metal rod'],
 'suit': ['business suit', 'playing card suit'],
 'bat': ['baseball bat', 'bat the mammal'],
 'run': ['to run as in to move quickly on foot',
  'run as in a streak of success'],
 'map': ['physical map', 'mental map'],
 'lead': ['lead poisoning', 'lead in a pencil'],
 'bear': ['teddy bear', 'grizzly bear'],
 'fly': ['fly the insect', 'fly the verb'],
 'spruce': ['Norway spruce', 'spruce up'],
 'fair': ['fair the event', 'fair the adjective'],
 'screen': ['computer screen', 'movie screen'],
 'duck': ['duck the bird', 'duck the verb'],
 'break': ['break the verb', 'break the noun'],
 'sink': ['kitchen sink', 'sink the ship']

In [241]:
pd.DataFrame(ambiguous_words_dict.items(), columns=['word', 'disambiguations']).to_json('dalle/ambiguous_words.jsonl', lines=True, orient='records')

In [429]:
def get_good_neighbors(edges, ambiguous_word, num=2, edge_direction='incoming'):
    neighbor_names = []
    words = []
    for e in edges:
        neighbor_name = e.start.text if edge_direction == 'incoming' else e.end.text
        for word in neighbor_name.split('_'):
            words.extend([word]*int(np.ceil(e.etc['weight'])))
        neighbor_names.append(neighbor_name)
    counter = Counter(words)
    neighbor_common_words = [k for k,v in dict(sorted(counter.items(), key=lambda x: x[1], reverse=True)).items() if k != ambiguous_word][:num]
    
    neighbor_nodes = []
    for word in neighbor_common_words:
        if word not in neighbor_names:
            # find a neighbor with the desired word
            neighbor_nodes.append(random.sample([n for n in neighbor_names if word in n], k=1)[0])
        else:
            neighbor_nodes.append(word)
    return neighbor_nodes

In [433]:
def get_lexical_set(ambiguous_word, disambiguated_word, num_end_nodes=2):
    words = []
    disambiguated_word = '_'.join(disambiguated_word.split(' '))
    lexical_set = set()

    outgoing_edges = get_outgoing_edges(disambiguated_word)
    neighbor_nodes = get_good_neighbors(outgoing_edges, ambiguous_word=ambiguous_word, edge_direction='outgoing')
    lexical_set.update(neighbor_nodes)
    
    for neighbor_node in neighbor_nodes:
        incoming_nodes = get_good_neighbors(get_incoming_edges(neighbor_node), ambiguous_word=ambiguous_word, edge_direction='incoming')
        new_nodes = random.sample(incoming_nodes, k=min(len(incoming_nodes), 2))
        lexical_set.update(new_nodes)
    
    lexical_set = [y for y in lexical_set if not any(x in y for x in lexical_set if x != y)]
    return [' '.join(w.split('_')) for w in lexical_set]

In [253]:
def get_constraints(lexical_set, num_constraints=2):
    return random.sample(lexical_set, k=min(len(lexical_set), num_constraints))

In [438]:
constraints = []
for word, disambiguations in tqdm(ambiguous_words_dict.items()):
    d1, d2 = disambiguations
    constraints.append([d2] + get_constraints(get_lexical_set(word, d1)))
    constraints.append([d1] + get_constraints(get_lexical_set(word, d2)))

100%|██████████| 7/7 [00:22<00:00,  3.23s/it]


In [434]:
get_lexical_set('jam', 'traffic_jam')

['fast', 'slow', 'traffic', 'sign']

In [355]:
def construct_caption_prompt(lexical_subset):
    return f'Write an image caption with the words {", ".join(lexical_subset[:-1])} and {lexical_subset[-1]}.'

In [444]:
def generate_caption(constraints, max_tries=5):
    caption_prompt = construct_caption_prompt(constraints)
    prompt = ''
    num_tries = 0
    while not np.all([w.lower() in prompt.lower() for w in constraints]):  # ensure constraints are met
        prompt = request(
            prompt=caption_prompt,
            model='text-davinci-002',
            top_p=0.8,
            return_only_text=True,
            stop=None
        )
        num_tries += 1
        if num_tries > max_tries:
            return None
            break
    return prompt

In [446]:
for c in constraints:
    print(f'{c}: {generate_caption(c)}')

['traffic jam', 'jar', 'tomato chutney']: A traffic jam of cars, all bumper to bumper, extends as far as the eye can see. In the backseat of one of the cars, a jar of tomato chutney sits unopened.
['strawberry jam', 'sign', 'slow']: A jar of strawberry jam with a hand-written sign that says "Slow down and enjoy life."
['bow tie', 'fiddlestick', 'violin bow']: A bow tie, fiddlestick, and violin bow on a table.
['violin bow', 'knot', 'tie']: A violin bow tied with a knot, ready to be used.
['house mouse', 'computer', 'desktop']: The house mouse scurries across the computer desktop, looking for a crumb or two.
['computer mouse', 'opossum', 'mammal']: A computer mouse sits atop an opossum, both creatures classified as mammals.
['car trunk', 'planet', 'star']: A car trunk, a planet and a star.
['elephant trunk']: None
['double bass', 'shark', 'ocean']: The double bassist on the shark-infested ocean floor.
['sea bass', 'playing in jazz band', 'music']: Sea bass, playing in jazz band and musi