In [0]:
# NOTE(review): no cell in this notebook draws a plot, so this magic appears unnecessary.
%matplotlib inline

In [2]:
# %pip (unlike !pip) installs into the environment of the running kernel.
%pip install torch



In [3]:
# %pip (unlike !pip) installs into the environment of the running kernel.
# NOTE(review): torchvision is never imported in this notebook; this install looks removable.
%pip install torchvision



In [4]:
# %pip (unlike !pip) installs into the environment of the running kernel.
# The outputs in this notebook were produced with torchtext 0.2.3 (see the install log);
# consider pinning a version so the GloVe API used below stays stable.
%pip install torchtext

Collecting torchtext
[?25l  Downloading https://files.pythonhosted.org/packages/78/90/474d5944d43001a6e72b9aaed5c3e4f77516fbef2317002da2096fd8b5ea/torchtext-0.2.3.tar.gz (42kB)
[K    100% |████████████████████████████████| 51kB 2.2MB/s 
[?25hCollecting tqdm (from torchtext)
[?25l  Downloading https://files.pythonhosted.org/packages/c7/e0/52b2faaef4fd87f86eb8a8f1afa2cd6eb11146822033e29c04ac48ada32c/tqdm-4.25.0-py2.py3-none-any.whl (43kB)
[K    100% |████████████████████████████████| 51kB 3.6MB/s 
Building wheels for collected packages: torchtext
  Running setup.py bdist_wheel for torchtext ... [?25l- \ done
[?25h  Stored in directory: /root/.cache/pip/wheels/42/a6/f4/b267328bde6bb680094a0c173e8e5627ccc99543abded97204
Successfully built torchtext
Installing collected packages: tqdm, torchtext
Successfully installed torchtext-0.2.3 tqdm-4.25.0


**Restart the runtime now** (Runtime → Restart runtime) so the newly installed packages are picked up, then continue with the cells below.

In [1]:
import torchtext.vocab

# Load the 100-dimensional vectors from the '6B' GloVe release; the first call downloads
# and caches them under .vector_cache/ (see the progress output below).
glove = torchtext.vocab.GloVe(name='6B', dim=100)

# itos is the index -> word list, so its length is the vocabulary size.
print(f'There are {len(glove.itos)} words in the vocabulary')

.vector_cache/glove.6B.zip: 862MB [00:59, 14.4MB/s]                           
100%|██████████| 400000/400000 [00:20<00:00, 19193.33it/s]


There are 400000 words in the vocabulary


In [2]:
# One row per vocabulary word, one column per embedding dimension.
glove.vectors.shape

torch.Size([400000, 100])

In [3]:
# The first ten words of the vocabulary, in index order.
glove.itos[:10]

['the', ',', '.', 'of', 'to', 'and', 'in', 'a', '"', "'s"]

In [4]:
# stoi maps a word back to its row index in glove.vectors.
glove.stoi['the']

0

In [5]:
# Indexing the embedding matrix with a word's stoi index gives that word's vector.
glove.vectors[glove.stoi['the']].shape

torch.Size([100])

In [0]:
def get_vector(embeddings, word):
    """Return the embedding row for ``word``.

    ``embeddings`` must expose ``stoi`` (word -> row index) and ``vectors``
    (the embedding matrix indexed by that row).

    Raises:
        AssertionError: if ``word`` is missing from ``embeddings.stoi``
            (only while assertions are enabled, i.e. not under ``python -O``).
    """
    lookup = embeddings.stoi
    assert word in lookup, f'*{word}* is not in the vocab!'
    row = lookup[word]
    return embeddings.vectors[row]

In [7]:
# Same lookup as the previous cell, done through the helper.
get_vector(glove, 'the').shape

torch.Size([100])

In [0]:
import heapq

import torch

def closest_words(embeddings, vector, n=10, batch_size=100_000):
    """Return the ``n`` vocabulary words whose embeddings are closest to ``vector``.

    Distances are Euclidean (L2), matching ``torch.dist``'s default. They are computed
    with vectorised tensor ops over slices of ``batch_size`` rows instead of one
    ``torch.dist`` call per word. This keeps large vocabularies fast and bounds the
    size of the temporary difference tensor.

    Args:
        embeddings: object exposing ``itos`` (index -> word list) and ``vectors``
            (one embedding per row); row ``i`` is assumed to belong to ``itos[i]``.
        vector: query embedding; assumed 1-D with the same length as a row of
            ``embeddings.vectors`` so it broadcasts against each slice.
        n: number of ``(word, distance)`` pairs to return.
        batch_size: number of rows whose distances are computed per slice.

    Returns:
        List of ``(word, distance)`` tuples sorted by ascending distance, with
        distances as Python floats; ties keep vocabulary order.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f'batch_size must be at least 1, got {batch_size}')

    vectors = embeddings.vectors
    distances = []
    for start in range(0, vectors.shape[0], batch_size):
        rows = vectors[start:start + batch_size]
        # Row-wise L2 norm of the differences == torch.dist(vector, row) for each row.
        distances.extend((rows - vector).norm(p=2, dim=1).tolist())

    # nsmallest is documented as equivalent to sorted(...)[:n] (stable on ties)
    # but avoids sorting the whole vocabulary.
    return heapq.nsmallest(n, zip(embeddings.itos, distances), key=lambda pair: pair[1])

In [9]:
# Nearest neighbours of 'korea' by Euclidean distance; the query word itself comes first at distance 0.
closest_words(glove, get_vector(glove, 'korea'))

[('korea', 0.0),
 ('pyongyang', 3.9039554595947266),
 ('korean', 4.068886756896973),
 ('dprk', 4.2631049156188965),
 ('seoul', 4.340494155883789),
 ('japan', 4.551243782043457),
 ('koreans', 4.615609169006348),
 ('south', 4.65822696685791),
 ('china', 4.839518070220947),
 ('north', 4.986356735229492)]

In [10]:
# Nearest neighbours of 'india'.
closest_words(glove, get_vector(glove, 'india'))

[('india', 0.0),
 ('pakistan', 3.6954822540283203),
 ('indian', 4.114313125610352),
 ('delhi', 4.155975818634033),
 ('bangladesh', 4.261017799377441),
 ('lanka', 4.435845851898193),
 ('sri', 4.515716552734375),
 ('australia', 4.806082725524902),
 ('thailand', 4.994781017303467),
 ('malaysia', 5.009334087371826)]

In [0]:
def print_tuples(tuples):
    """Print each ``(word, distance)`` pair on its own line as ``(distance) word``.

    The distance is shown with four digits after the decimal point.
    """
    for pair in tuples:
        word, distance = pair
        print(f'({distance:02.04f}) {word}')

In [12]:
# Neighbours of 'sports', printed as "(distance) word" lines.
print_tuples(closest_words(glove, get_vector(glove, 'sports')))

(0.0000) sports
(3.5875) sport
(4.4590) soccer
(4.6508) basketball
(4.6561) baseball
(4.8028) sporting
(4.8763) football
(4.9624) professional
(4.9824) entertainment
(5.0975) media


In [0]:
def analogy(embeddings, word1, word2, word3, n=5):
    """Answer "word1 is to word2 as word3 is to ?" with embedding arithmetic.

    The target vector is ``word2 - word1 + word3``; its nearest vocabulary words,
    excluding the three query words, are the candidate answers. A header line
    describing the analogy is printed before returning.

    Returns:
        Up to ``n`` ``(word, distance)`` tuples, closest first.
    """
    # Look the vectors up in the same order as word2 - word1 + word3 would.
    vec2 = get_vector(embeddings, word2)
    vec1 = get_vector(embeddings, word1)
    vec3 = get_vector(embeddings, word3)
    target = vec2 - vec1 + vec3

    # Over-fetch by 3 so that, after dropping the (up to three) query words, n remain.
    excluded = {word1, word2, word3}
    nearest = closest_words(embeddings, target, n + 3)
    answers = [pair for pair in nearest if pair[0] not in excluded][:n]

    print(f'{word1} is to {word2} as {word3} is to...')
    return answers

In [14]:
# Gender analogy: king - man + woman.
print_tuples(analogy(glove, 'man', 'king', 'woman'))

man is to king as woman is to...
(4.0811) queen
(4.6429) monarch
(4.9055) throne
(4.9216) elizabeth
(4.9811) prince


In [15]:
# Gender analogy on an occupation: actor - man + woman.
print_tuples(analogy(glove, 'man', 'actor', 'woman'))

man is to actor as woman is to...
(2.8133) actress
(5.0039) comedian
(5.1399) actresses
(5.2773) starred
(5.3085) screenwriter


In [16]:
# Adult animal -> young animal: kitten - cat + dog.
print_tuples(analogy(glove, 'cat', 'kitten', 'dog'))

cat is to kitten as dog is to...
(3.8146) puppy
(4.2944) rottweiler
(4.5888) puppies
(4.6086) pooch
(4.6520) pug


In [17]:
# Country -> capital: paris - france + england.
print_tuples(analogy(glove, 'france', 'paris', 'england'))

france is to paris as england is to...
(4.1426) london
(4.4938) melbourne
(4.7087) sydney
(4.7630) perth
(4.7952) birmingham


In [18]:
# Artist -> music genre: rock - elvis + eminem.
print_tuples(analogy(glove, 'elvis', 'rock', 'eminem'))

elvis is to rock as eminem is to...
(5.6597) rap
(6.2057) rappers
(6.2161) rapper
(6.2444) punk
(6.2690) hop


In [19]:
# Drink -> main ingredient: barley - beer + wine.
print_tuples(analogy(glove, 'beer', 'barley', 'wine'))

beer is to barley as wine is to...
(5.6021) grape
(5.6760) beans
(5.8174) grapes
(5.9035) lentils
(5.9454) figs


In [20]:
# Currency -> country: japan - yen + ruble.
print_tuples(analogy(glove, 'yen', 'japan', 'ruble'))

yen is to japan as ruble is to...
(7.0202) russia
(7.0706) greece
(7.1550) republic
(7.1591) romania
(7.1675) kazakhstan


In [21]:
# Capital -> country: canada - ottawa + nairobi.
print_tuples(analogy(glove, 'ottawa', 'canada', 'nairobi'))

ottawa is to canada as nairobi is to...
(4.6967) kenya
(5.3135) africa
(5.5256) tanzania
(5.6443) uganda
(5.8666) asia


In [22]:
# Product -> company: apple - mac + windows. 'window' and 'glass' in the output suggest
# the ordinary-word sense of "windows" leaks into the result.
print_tuples(analogy(glove, 'mac', 'apple', 'windows'))

mac is to apple as windows is to...
(7.0108) microsoft
(7.0934) window
(7.0987) glass
(7.2496) adobe
(7.4983) desktop


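**WARNING: The remaining cells load the much larger 840B/300d GloVe embedding (a 2.18 GB download that expands to a 2,196,017 × 300 matrix). It needs a lot of memory, so you probably can't run the remaining code on Colab.**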

In [0]:
# 300-dimensional vectors from the much larger '840B' GloVe release: a 2.18 GB download
# (see the progress output) covering ~2.2M words, so loading it needs a lot of memory.
glove_300d = torchtext.vocab.GloVe(name='840B', dim=300)

.vector_cache/glove.840B.300d.zip: 2.18GB [02:25, 15.0MB/s]                            
 56%|█████▌    | 1220026/2196017 [02:47<02:13, 7299.36it/s]

In [0]:
# Shape check: (vocabulary size, 300).
glove_300d.vectors.shape

torch.Size([2196017, 300])

In [0]:
# Neighbours of 'korea' under the larger embedding, for comparison with the 6B results.
print_tuples(closest_words(glove_300d, get_vector(glove_300d, 'korea')))

(0.0000) korea
(3.9857) taiwan
(4.4022) korean
(4.9016) asia
(4.9593) japan
(5.0721) seoul
(5.4058) thailand
(5.6025) singapore
(5.7010) russia
(5.7240) hong


In [0]:
# Misspellings cluster together: in the output, the neighbours of 'relieable' are mostly
# other misspellings of "reliable".
print_tuples(closest_words(glove_300d, get_vector(glove_300d, 'relieable')))

(0.0000) relieable
(5.0366) relyable
(5.2610) realible
(5.4719) realiable
(5.5402) relable
(5.5917) relaible
(5.6412) reliabe
(5.8802) relaiable
(5.9593) stabel
(5.9981) consitant


In [0]:
# Vector for the correct spelling.
reliable_vector = get_vector(glove_300d, 'reliable')

# The eight nearest entries to 'relieable' found above (including 'relieable' itself).
reliable_misspellings = ['relieable', 'relyable', 'realible', 'realiable', 'relable', 'relaible', 'reliabe', 'relaiable']

# Offset from each misspelling to the correct spelling; unsqueeze(0) makes each a single
# row so the list can be concatenated along dim 0 in the next cell.
diff_reliable = [(reliable_vector - get_vector(glove_300d, s)).unsqueeze(0) for s in reliable_misspellings]

In [0]:
# Average "misspelled -> correct" offset. The idea is that adding it to a misspelled
# word's vector moves it toward the correct spelling (tested in the following cells).
misspelling_vector = torch.cat(diff_reliable, dim=0).mean(dim=0)

In [0]:
# Apply the offset to 'becuase'; 'because' comes out on top.
print_tuples(closest_words(glove_300d, get_vector(glove_300d, 'becuase') + misspelling_vector))

(6.1090) because
(6.4250) even
(6.4358) fact
(6.4914) sure
(6.5094) though
(6.5601) obviously
(6.5682) reason
(6.5856) if
(6.6099) but
(6.6415) why


In [0]:
# Apply the offset to 'defintiely'; 'definitely' comes out on top.
print_tuples(closest_words(glove_300d, get_vector(glove_300d, 'defintiely') + misspelling_vector))

(5.4070) definitely
(5.5643) certainly
(5.7192) sure
(5.8152) well
(5.8588) always
(5.8812) also
(5.9557) simply
(5.9667) consider
(5.9821) probably
(5.9948) definately


In [0]:
# Apply the offset to 'consistant'. 'reliable' also ranks high, likely because the offset
# was built only from misspellings of 'reliable'.
print_tuples(closest_words(glove_300d, get_vector(glove_300d, 'consistant') + misspelling_vector))

(5.9641) consistent
(6.3674) reliable
(7.0195) consistant
(7.0299) consistently
(7.1605) accurate
(7.2737) fairly
(7.3037) good
(7.3520) reasonable
(7.3801) dependable
(7.4027) ensure


In [0]:
# Apply the offset to 'pakage'; 'package' comes out on top.
print_tuples(closest_words(glove_300d, get_vector(glove_300d, 'pakage') + misspelling_vector))

(6.6117) package
(6.9315) packages
(7.0195) pakage
(7.0911) comes
(7.1241) provide
(7.1469) offer
(7.1861) reliable
(7.2431) well
(7.2434) choice
(7.2453) offering
