In [99]:
import sys
import torch
import numpy as np
import pandas as pd

sys.path.append('/usr/local/lib/python3.8/site-packages/')
import torchtext

In [18]:
# Pretrained GloVe vectors: choose the dimensionality, the file name follows from it.
glove_embed_dim = 50  # other options are 100, 200, 300
glove_embed_path = f'../data/glove_embed/glove.6B.{glove_embed_dim}d.txt'

# Tab-separated word/definition file for the reverse dictionary.
dict_data_path = '../data/dictionary/reverse-dict-singleton.tsv'

In [12]:
# Load the tab-separated word/definition file. It has no header row, so
# pandas labels the columns 0 (word) and 1 (definition).
# keep_default_na=False keeps every cell as literal text: with the default NA
# handling, dictionary entries such as "nan" or "null" would be parsed as
# float NaN instead of the words they are.
dict_data = pd.read_csv(dict_data_path, sep='\t', header=None, keep_default_na=False)

In [15]:
# Preview the first five rows of the raw dictionary table.
dict_data.head(n=5)

Unnamed: 0,0,1
0,.22-caliber,of or relating to the bore of a gun (or its am...
1,.22-calibre,of or relating to the bore of a gun (or its am...
2,.22_caliber,of or relating to the bore of a gun (or its am...
3,.22_calibre,of or relating to the bore of a gun (or its am...
4,.38-caliber,of or relating to the bore of a gun (or its am...


In [19]:
# Read the pretrained GloVe vectors from the text file named by glove_embed_path.
glove_embed = torchtext.vocab.Vectors(name=glove_embed_path)

100%|█████████▉| 400000/400001 [00:11<00:00, 35331.98it/s]


In [33]:
# Spot-check the embedding lookup on a few sample tokens.
glove_embed.get_vecs_by_tokens(
    ['22-caliber', 'penguin', 'amortize']
)

tensor([[-1.0735, -0.5600,  0.5466, -0.4803, -0.2448,  0.9606,  0.0671,  0.5215,
          0.6887,  0.2482,  0.4354, -0.8980,  0.5264,  0.8516, -1.0347, -0.9734,
         -0.4571,  0.0364, -0.8753, -1.0254, -0.2804, -0.1455, -0.5470, -0.5499,
          0.1977,  0.1307, -0.1276,  0.3804,  0.4866,  0.2864, -1.0666, -0.3158,
          0.3660,  0.3191, -0.1315,  0.5860,  0.5309,  0.2319, -0.0744,  0.2468,
          1.1335, -0.0796, -0.4461,  0.7638,  0.7927, -1.1476,  0.2690,  0.6993,
          0.2323, -0.4295],
        [ 0.2681,  0.1954, -0.8524,  0.1708,  0.4934,  0.1788, -1.1485, -1.5259,
          0.3272, -0.4143,  0.3262,  1.0239,  0.5542,  0.3874,  1.0458,  0.3105,
          0.0677, -0.1454, -1.2247,  0.8357,  0.0674,  0.4555,  0.4163,  0.3000,
         -0.2766,  0.2840, -0.7809, -0.3603, -0.6686, -0.4047,  0.3373, -0.3918,
          0.1391,  1.0417, -0.8919, -0.2836, -0.6314, -1.0775,  0.1857, -0.9852,
          0.8577, -0.2593,  0.2592,  0.6400, -0.2766,  0.4561,  0.5695, -0.3686,


In [119]:
class DictDataset(torch.utils.data.Dataset):
    """Map-style dataset of (definition embeddings, word embedding) pairs.

    Each item pairs the per-token embeddings of a definition with the
    embedding of the word that the definition describes.

    Args:
        definitions: DataFrame with exactly two columns; column ``0`` holds
            the word and column ``1`` the definition text.
        embeddings: object exposing ``get_vecs_by_tokens(list_of_tokens)``
            that returns one row per token (e.g. ``torchtext.vocab.Vectors``).
            An all-zero row is taken to mean "no embedding for this token".
        embedding_dim: width of the embedding vectors. Stored on the instance;
            not otherwise used by this class.
        tokenizer: callable turning a definition string into a list of
            tokens. Defaults to torchtext's ``basic_english`` tokenizer.
    """

    def __init__(self, definitions, embeddings, embedding_dim, tokenizer=None):
        super().__init__()
        if tokenizer is None:
            tokenizer = torchtext.data.utils.get_tokenizer("basic_english")
        self.tokenizer = tokenizer

        self.embedding_dim = embedding_dim
        self.embeddings = embeddings

        def has_embedding(word):
            # A word counts as known when its vector has any non-zero entry.
            return not torch.all(embeddings.get_vecs_by_tokens([word]) == 0)

        # Filter out words that do not have embeddings. astype(bool) keeps the
        # mask a boolean indexer even when the input frame is empty.
        keep = definitions[0].apply(has_embedding).astype(bool)
        # reset_index returns a fresh frame with a 0..n-1 index, instead of
        # assigning a new index onto a filtered slice of the caller's frame.
        self.definitions = definitions.loc[keep].reset_index(drop=True)

    def __getitem__(self, i):
        """Return ``(definition_vectors, word_vector)`` for row ``i``.

        ``definition_vectors`` has one row per definition token and
        ``word_vector`` has a single row for the defined word.

        Raises:
            IndexError: if ``i`` is out of range. Positional ``.iloc`` is used
                (rather than ``.loc``, which raises ``KeyError``) so that
                plain iteration over the dataset terminates correctly.
        """
        word, def_text = self.definitions.iloc[i]  # definition, in plain text form
        tokens = self.tokenizer(def_text)
        return self.embeddings.get_vecs_by_tokens(tokens), self.embeddings.get_vecs_by_tokens([word])

    def __len__(self):
        """Number of (word, definition) rows that survived the embedding filter."""
        return len(self.definitions)

    @staticmethod
    def collate_fn(batch):
        """Collate ``__getitem__`` results into one variable-length batch.

        Args:
            batch: list of ``(definition_vectors, word_vector)`` tuples.

        Returns:
            ``(packed_definitions, targets)`` where ``packed_definitions`` is
            a ``PackedSequence`` of the definitions ordered longest-first and
            ``targets`` stacks the word vectors in that same order.
        """
        # pack_sequence requires sequences sorted by decreasing length; sort a
        # copy so the caller's list is left untouched.
        ordered = sorted(batch, key=lambda elem: len(elem[0]), reverse=True)
        def_vecs = [elem[0] for elem in ordered]
        word_vecs = [elem[1] for elem in ordered]
        return torch.nn.utils.rnn.pack_sequence(def_vecs), torch.stack(word_vecs)

In [120]:
# Build the dataset from the loaded dictionary and GloVe vectors. The width
# comes from glove_embed_dim rather than a repeated literal, so it stays in
# step with the embedding file that was actually loaded.
data = DictDataset(dict_data, glove_embed, glove_embed_dim)

In [121]:
# Inspect the (word, definition) rows kept after dropping words without embeddings.
data.definitions

Unnamed: 0,0,1
0,1000th,the ordinal number of one thousand in counting...
1,100th,the ordinal number of one hundred in counting ...
2,101,being one more than one hundred
3,101st,the ordinal number of one hundred one in count...
4,105,being five more than one hundred
...,...,...
30821,winterize,prepare for winter
30822,woosh,move with a sibilant sound
30823,wreak,cause to happen or to occur as a consequence
30824,wrest,"obtain by seizing forcibly or violently, also ..."


In [124]:
# Serve one shuffled example per step, collated by the dataset's own collate function.
loader = torch.utils.data.DataLoader(
    dataset=data,
    batch_size=1,
    shuffle=True,
    collate_fn=DictDataset.collate_fn,
)

In [126]:
# Peek at the first three batches; zipping against range(3) ends the loop
# without draining the rest of the loader.
for _, (x, y) in zip(range(3), loader):
    # The loader's collate_fn packs the variable-length definitions, and a
    # PackedSequence has no .shape. Pad it back to a dense
    # (batch, max_len, dim) tensor before printing; dense tensors pass through.
    if isinstance(x, torch.nn.utils.rnn.PackedSequence):
        x, _lengths = torch.nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
    print(x.shape, y.shape)

torch.Size([1, 21, 50]) torch.Size([1, 1, 50])
torch.Size([1, 14, 50]) torch.Size([1, 1, 50])
torch.Size([1, 30, 50]) torch.Size([1, 1, 50])
