In [3]:
# For full reproducibility, install additional packages
# mentioned in ../requirement-optional.txt

# basic Python imports
import os
import sys
sys.path.append('../scripts/')
from tqdm.notebook import tqdm
from pathlib import Path
import numpy as np
import pickle

# NLTK
import nltk
#nltk.download('wordnet')
from nltk.corpus import wordnet as wn

# Visualization
from matplotlib import pyplot as plt
from IPython import display
from PIL import Image
%config InlineBackend.figure_format = 'retina'

# PyTorch
import torch
from torch.utils.data import Dataset, DataLoader

# 3rd party and custom scripts 
from coco_utils import load_coco_data, sample_coco_minibatch, decode_captions
from pytorch_pretrained_biggan.utils import IMAGENET, one_hot_from_names, save_as_images
from transformers import AutoTokenizer, AutoModel
from pytorch_pretrained_biggan import BigGAN, truncated_noise_sample
import catalyst
from catalyst.dl import SupervisedRunner
from catalyst.dl.callbacks import EarlyStoppingCallback, CriterionCallback
from catalyst.dl.utils import set_global_seed, prepare_cudnn, plot_metrics

In [5]:
# Map each ImageNet class index to its WordNet noun synset.
# IMAGENET's keys are used as WordNet noun offsets and its values as class
# indices; the mapping is inverted here so the class index becomes the key.
class_to_synset = {
    class_idx: wn.synset_from_pos_and_offset('n', offset)
    for offset, class_idx in IMAGENET.items()
}

In [7]:
# Initialize the DistilBERT (uncased) tokenizer. Its vocabulary is used to check
# whether candidate words are known tokens (i.e. do not map to the unknown id).
# The checkpoint name lives in a constant so any later model load can reuse it.
LM_CHECKPOINT = 'distilbert-base-uncased'
lm_tokenizer = AutoTokenizer.from_pretrained(LM_CHECKPOINT)

In [8]:
# Assign each ImageNet class one word that (a) is a known token in the
# tokenizer vocabulary and (b) has not been given to any earlier class.
# Candidates are the synset's own lemmas first, then those of its first
# hypernym, then that hypernym's first hypernym, and so on up the chain.
# The assignment is greedy and depends on the iteration order of
# class_to_synset; classes for which no free word is found are simply
# absent from words_dataset.


def _first_free_vocab_lemma(synset, taken_words, tokenizer):
    """Return the first lemma name of `synset` that is in-vocabulary and unused.

    Lemma names are normalised by replacing '_' with ' ' and lower-casing.
    A name counts as in-vocabulary when `tokenizer.convert_tokens_to_ids`
    does not return `tokenizer.unk_token_id`. Names already present in
    `taken_words` are skipped. Returns None when no lemma qualifies.
    NOTE(review): multi-word lemmas (containing a space) presumably never
    match a single vocabulary token — confirm against the tokenizer vocab.
    """
    for lemma in synset.lemmas():
        name = lemma.name().replace('_', ' ').lower()
        if name in taken_words:
            continue  # word is already assigned to another class
        if tokenizer.convert_tokens_to_ids(name) != tokenizer.unk_token_id:
            return name
    return None


def _assign_word(synset, taken_words, tokenizer):
    """Walk up the first-hypernym chain from `synset` until a free word is found.

    Returns the word, or None when the chain ends (a synset with no
    hypernyms) without any lemma qualifying.
    """
    current = synset
    while current is not None:
        word = _first_free_vocab_lemma(current, taken_words, tokenizer)
        if word is not None:
            return word
        parents = current.hypernyms()  # compute once per step
        current = parents[0] if parents else None
    return None


words_dataset = {}  # class index -> assigned word
all_words = set()   # words already assigned, keeps assignments unique
for class_idx, synset in tqdm(class_to_synset.items()):
    word = _assign_word(synset, all_words, lm_tokenizer)
    if word is not None:
        words_dataset[class_idx] = word
        all_words.add(word)

# Invariant: every assigned word is unique across classes.
assert len(all_words) == len(words_dataset)

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




In [11]:
# Spot check: has the word 'buss' been assigned to any ImageNet class?
{'buss'} <= all_words

False