An implementation of GloVe (Global vectors for word representation)

In [112]:
import nltk
from collections import Counter
from itertools import combinations_with_replacement
from tqdm import tqdm
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

In [113]:
torch.manual_seed(2345678)

<torch._C.Generator at 0x215a1c03710>

# Load data using the NLTK corpus

In [114]:
nltk.corpus.gutenberg.fileids() # Show all texts in the Gutenberg corpus in the NLTK

['austen-emma.txt',
 'austen-persuasion.txt',
 'austen-sense.txt',
 'bible-kjv.txt',
 'blake-poems.txt',
 'bryant-stories.txt',
 'burgess-busterbrown.txt',
 'carroll-alice.txt',
 'chesterton-ball.txt',
 'chesterton-brown.txt',
 'chesterton-thursday.txt',
 'edgeworth-parents.txt',
 'melville-moby_dick.txt',
 'milton-paradise.txt',
 'shakespeare-caesar.txt',
 'shakespeare-hamlet.txt',
 'shakespeare-macbeth.txt',
 'whitman-leaves.txt']

Build the word to index mapping

In [115]:
# Select the melville-moby_dick.txt as the training sample
melville_words = nltk.corpus.gutenberg.words('melville-moby_dick.txt')[:500]

# Lower all words
vocab = list(set([w.lower() for w in melville_words]))

# Show the first five words in the processed list 
print(vocab[:5])
print(f'Total words in the text: {len(vocab)}')

['piggledy', 'worm', 'have', 'this', 'greek']
Total words in the text: 258


# Build co-occurence blobs (the $X_{ij}$ values)

Build a word co-occurrence list for the given corpus, where each element (representing
a cooccurrence pair) is of the form $(i_{id}, j_{id}, X_{ij})$, where 
$i_{id}$   : the ID of the main word in the cooccurrence;
$j_{id}$: the ID of the context word;
$X_{ij}$: the the cooccurrence value

In [116]:
melville_sents = nltk.corpus.gutenberg.sents('melville-moby_dick.txt')[:500]
melville_sents = [[w.lower() for w in s] for s in melville_sents]

for sent in melville_sents[:5]:
    print(sent)

['[', 'moby', 'dick', 'by', 'herman', 'melville', '1851', ']']
['etymology', '.']
['(', 'supplied', 'by', 'a', 'late', 'consumptive', 'usher', 'to', 'a', 'grammar', 'school', ')']
['the', 'pale', 'usher', '--', 'threadbare', 'in', 'coat', ',', 'heart', ',', 'body', ',', 'and', 'brain', ';', 'i', 'see', 'him', 'now', '.']
['he', 'was', 'ever', 'dusting', 'his', 'old', 'lexicons', 'and', 'grammars', ',', 'with', 'a', 'queer', 'handkerchief', ',', 'mockingly', 'embellished', 'with', 'all', 'the', 'gay', 'flags', 'of', 'all', 'the', 'known', 'nations', 'of', 'the', 'world', '.']


In [117]:
flatten = lambda l: [item for sublist in l for item in sublist]
vocab = list(set(flatten(melville_sents)))

# Mark each word in the vocab list by its index
word2index = {word:index for index, word in enumerate(vocab)}
# print(word2index)

# Also build the integer to word list
index2word = {index:word for word, index in word2index.items()}
# print(index2word)

In [118]:
# Insert each sentence with <S> at the begining and append <E> at the end
WINDOW_SIZE = 5
_sents = [['<S>'] * WINDOW_SIZE + s + ['<E>'] * WINDOW_SIZE for s in melville_sents]

for s in _sents[:5]:
    print(s)

['<S>', '<S>', '<S>', '<S>', '<S>', '[', 'moby', 'dick', 'by', 'herman', 'melville', '1851', ']', '<E>', '<E>', '<E>', '<E>', '<E>']
['<S>', '<S>', '<S>', '<S>', '<S>', 'etymology', '.', '<E>', '<E>', '<E>', '<E>', '<E>']
['<S>', '<S>', '<S>', '<S>', '<S>', '(', 'supplied', 'by', 'a', 'late', 'consumptive', 'usher', 'to', 'a', 'grammar', 'school', ')', '<E>', '<E>', '<E>', '<E>', '<E>']
['<S>', '<S>', '<S>', '<S>', '<S>', 'the', 'pale', 'usher', '--', 'threadbare', 'in', 'coat', ',', 'heart', ',', 'body', ',', 'and', 'brain', ';', 'i', 'see', 'him', 'now', '.', '<E>', '<E>', '<E>', '<E>', '<E>']
['<S>', '<S>', '<S>', '<S>', '<S>', 'he', 'was', 'ever', 'dusting', 'his', 'old', 'lexicons', 'and', 'grammars', ',', 'with', 'a', 'queer', 'handkerchief', ',', 'mockingly', 'embellished', 'with', 'all', 'the', 'gay', 'flags', 'of', 'all', 'the', 'known', 'nations', 'of', 'the', 'world', '.', '<E>', '<E>', '<E>', '<E>', '<E>']


In [119]:
# Construct the ngrams for building word pairs (i, j) 
windows = flatten([list(nltk.ngrams(s, WINDOW_SIZE * 2 + 1)) for s in _sents])

for s in windows[:20]:
    print(s)

('<S>', '<S>', '<S>', '<S>', '<S>', '[', 'moby', 'dick', 'by', 'herman', 'melville')
('<S>', '<S>', '<S>', '<S>', '[', 'moby', 'dick', 'by', 'herman', 'melville', '1851')
('<S>', '<S>', '<S>', '[', 'moby', 'dick', 'by', 'herman', 'melville', '1851', ']')
('<S>', '<S>', '[', 'moby', 'dick', 'by', 'herman', 'melville', '1851', ']', '<E>')
('<S>', '[', 'moby', 'dick', 'by', 'herman', 'melville', '1851', ']', '<E>', '<E>')
('[', 'moby', 'dick', 'by', 'herman', 'melville', '1851', ']', '<E>', '<E>', '<E>')
('moby', 'dick', 'by', 'herman', 'melville', '1851', ']', '<E>', '<E>', '<E>', '<E>')
('dick', 'by', 'herman', 'melville', '1851', ']', '<E>', '<E>', '<E>', '<E>', '<E>')
('<S>', '<S>', '<S>', '<S>', '<S>', 'etymology', '.', '<E>', '<E>', '<E>', '<E>')
('<S>', '<S>', '<S>', '<S>', 'etymology', '.', '<E>', '<E>', '<E>', '<E>', '<E>')
('<S>', '<S>', '<S>', '<S>', '<S>', '(', 'supplied', 'by', 'a', 'late', 'consumptive')
('<S>', '<S>', '<S>', '<S>', '(', 'supplied', 'by', 'a', 'late', 'consu

In [120]:
# Building word pairs
word_pairs = []

for window in windows:
    for i in range(WINDOW_SIZE * 2 + 1):
        # Center at the word at index WINDOW_SIZE (which is at index 5), denoted as i
        # Ignore <S> and <E> and the word i itself, all other words in the windows are context, denoted as j 
        # Then make the pair in the form as (i, j)
        if i == WINDOW_SIZE or window[i] == '<S>' or window[i] == '<E>': 
            continue
        word = window[WINDOW_SIZE]
        context = window[i]
        word_pairs.append((word, context))    

In [121]:
print(word_pairs[:10])

[('[', 'moby'), ('[', 'dick'), ('[', 'by'), ('[', 'herman'), ('[', 'melville'), ('moby', '['), ('moby', 'dick'), ('moby', 'by'), ('moby', 'herman'), ('moby', 'melville')]


Build co-occurence matrix $X$ 

In [122]:
# Record the occurence time of each word in the corpus, denoted as X_i
X_i = Counter(flatten(melville_sents))

# Record the co-occurence time of each word within window_size using onstructed word pairs 
X_ik_window = Counter(word_pairs)

In [227]:
# X_i

In [123]:
X_ik_window

Counter({('[', 'moby'): 1,
         ('[', 'dick'): 1,
         ('[', 'by'): 1,
         ('[', 'herman'): 1,
         ('[', 'melville'): 1,
         ('moby', '['): 1,
         ('moby', 'dick'): 1,
         ('moby', 'by'): 1,
         ('moby', 'herman'): 1,
         ('moby', 'melville'): 1,
         ('moby', '1851'): 1,
         ('dick', '['): 1,
         ('dick', 'moby'): 1,
         ('dick', 'by'): 1,
         ('dick', 'herman'): 1,
         ('dick', 'melville'): 1,
         ('dick', '1851'): 1,
         ('dick', ']'): 1,
         ('by', '['): 1,
         ('by', 'moby'): 1,
         ('by', 'dick'): 1,
         ('by', 'herman'): 1,
         ('by', 'melville'): 1,
         ('by', '1851'): 1,
         ('by', ']'): 1,
         ('herman', '['): 1,
         ('herman', 'moby'): 1,
         ('herman', 'dick'): 1,
         ('herman', 'by'): 1,
         ('herman', 'melville'): 1,
         ('herman', '1851'): 1,
         ('herman', ']'): 1,
         ('melville', '['): 1,
         ('melville', 'mo

In [124]:
X_ik = {}
weighting_dict = {}

In [125]:
# All the possible subsets or arrangements of the iterator
# Elements are allowed to repeat in a subset because of the word 'replacement'
# In fact, combinations_with_replacement(vocab, 2) is a matrix with size len(vocab) x len(vocab)
# The following line list the first five samples  
vocab_mat = list(combinations_with_replacement(vocab, 2))

In [126]:
print(vocab_mat[:5])

[('return', 'return'), ('return', 'elements'), ('return', 'exercise'), ('return', 'savages'), ('return', 'holding')]


Compute the weighting function

In [127]:
def weighting(w_i, w_j):
    try:
        x_ij = X_ik[(w_i, w_j)]
    except:
        x_ij = 1
        
    x_max = 100 #100 # fixed in paper
    alpha = 0.75
    
    if x_ij < x_max:
        result = (x_ij/x_max)**alpha
    else:
        result = 1
    
    return result

In [128]:
for i in tqdm(range(len(vocab_mat))):
    pair = vocab_mat[i]
    if X_ik_window.get(pair) is not None: # If the pair exists in text (sentence)
        co_occur = X_ik_window[pair]      # Get its co-occurence time
        X_ik[(pair[0], pair[1])] = co_occur + 1 # log(Xik) -> log(Xik+1) to prevent divergence
        X_ik[(pair[1], pair[0])] = co_occur + 1 # Symmetry 
    else: # If the pair not exist, then do nothing
        pass

    # Compute the weight of the pair, symmetrically
    weighting_dict[(pair[0], pair[1])] = weighting(pair[0], pair[1])
    weighting_dict[(pair[1], pair[0])] = weighting(pair[1], pair[0])

100%|██████████| 3399528/3399528 [00:11<00:00, 283462.07it/s]


# Prepare dataset

In [131]:
word = []
context = []
co_occur = []
weights = []

for pair in word_pairs:
    # Convert the word into a tensor after integer representation
    word.append(torch.tensor(word2index[pair[0]], device='cuda').view(1, -1))
    context.append(torch.tensor(word2index[pair[1]], device='cuda').view(1, -1))

    try:
        co_occur_count = torch.tensor([X_ik[pair]], device='cuda').view(1, -1)
    except:
        co_occur_count = torch.tensor([1.0], device='cuda').view(1, -1)
    
    # Record the log of the co-occurence and the weight of the pair
    co_occur.append(torch.log(co_occur_count))
    weights.append(torch.tensor([weighting_dict[pair]], device='cuda').view(1, -1))

# Gather all four list together as the training data
train_data = list(zip(word, context, co_occur, weights))

In [132]:
print(word_pairs[0])
print(train_data[0])

('[', 'moby')
(tensor([[2515]], device='cuda:0'), tensor([[1006]], device='cuda:0'), tensor([[0.6931]], device='cuda:0'), tensor([[0.0532]], device='cuda:0'))


# Build the model

In [133]:
class GloVe(nn.Module):
    
    def __init__(self, vocab_size, projection_dim):
        super(GloVe, self).__init__()
        self.embedding_v = nn.Embedding(vocab_size, projection_dim) # center embedding
        self.embedding_u = nn.Embedding(vocab_size, projection_dim) # out embedding
        
        self.v_bias = nn.Embedding(vocab_size, 1)
        self.u_bias = nn.Embedding(vocab_size, 1)
        
        initrange = (2.0 / (vocab_size + projection_dim))**0.5 # Xavier init
        self.embedding_v.weight.data.uniform_(-initrange, initrange) # init
        self.embedding_u.weight.data.uniform_(-initrange, initrange) # init
        self.v_bias.weight.data.uniform_(-initrange, initrange) # init
        self.u_bias.weight.data.uniform_(-initrange, initrange) # init
        
    def forward(self, center_words, target_words, coocs, weights):
        center_embeds = self.embedding_v(center_words) # B x 1 x D
        target_embeds = self.embedding_u(target_words) # B x 1 x D
        
        center_bias = self.v_bias(center_words).squeeze(1)
        target_bias = self.u_bias(target_words).squeeze(1)
        
        inner_product = target_embeds.bmm(center_embeds.transpose(1, 2)).squeeze(2) # Bx1
        
        loss = weights * torch.pow(inner_product + center_bias + target_bias - coocs, 2)
        
        return torch.sum(loss)
    
    def prediction(self, inputs):
        v_embeds = self.embedding_v(inputs) # B x 1 x D
        u_embeds = self.embedding_u(inputs) # B x 1 x D
                
        return v_embeds + u_embeds # final embed

Train the model

In [148]:
EMBEDDING_SIZE = 50
BATCH_SIZE = 256
EPOCH = 50

In [135]:
losses = []
model = GloVe(len(word2index), EMBEDDING_SIZE)
if torch.cuda.is_available():
    model = model.cuda()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [154]:
def get_batch(batch_size, train_data):
    random.shuffle(train_data) # Shuffle all samples in the dataset
    start_index = 0     
    end_index = batch_size

    while end_index < len(train_data):
        # Slide samples from start_index to end_index
        batch = train_data[start_index:end_index]
        # Update start and end index to point to the next batch
        temp = end_index
        end_index += batch_size
        start_index = temp
        yield batch # MUse yield to remember the index in the next calling
    
    if end_index >= len(train_data):
        batch = train_data[start_index:] # Get all rest of samples to be the batch
        yield batch

In [156]:
for epoch in range(EPOCH):
    for i, batch in enumerate(get_batch(BATCH_SIZE, train_data)):
        
        inputs, targets, coocs, weights = zip(*batch)
        
        inputs = torch.cat(inputs) # B x 1
        targets = torch.cat(targets) # B x 1
        coocs = torch.cat(coocs)
        weights = torch.cat(weights)
        model.zero_grad()

        loss = model(inputs, targets, coocs, weights)
        
        loss.backward()
        optimizer.step()
    
        losses.append(loss.data.tolist())
    if epoch % 10 == 0:
        print("Epoch : %d, mean_loss : %.02f" % (epoch, np.mean(losses)))
        losses = []

Epoch : 0, mean_loss : 215.86
Epoch : 10, mean_loss : 2.37
Epoch : 20, mean_loss : 0.52
Epoch : 30, mean_loss : 0.12
Epoch : 40, mean_loss : 0.04


# Test

In [213]:
def word_similarity(target, vocab, n_words):
    # Encode the target words
    target_V = model.prediction(torch.tensor(word2index[target], device='cuda'))

    similarities = []
    for i in range(len(vocab)):
        if vocab[i] == target: # Skip the word itself
            continue
        
        # Use other words for comparison and distance computation using cosine similarity
        vector = model.prediction(torch.tensor(word2index[vocab[i]], device='cuda'))
        cosine_sim = F.cosine_similarity(target_V, vector, 0).data.tolist()

        similarities.append([vocab[i], cosine_sim])
    
    # Return the most similar n_words according to the cosine similarity
    return sorted(similarities, key=lambda x: x[1], reverse=True)[:n_words]

In [223]:
test = random.choice(vocab)
print(f'Random choose: {test}')
print(f'10 most related words:')
for pred in word_similarity(test, vocab, 10):
    print(pred)

Random choose: attempt
10 most related words:
['rude', 0.8951979279518127]
['sprat', 0.8603125810623169]
['demanded', 0.8229526281356812]
['barques', 0.8207601308822632]
['captains', 0.816744327545166]
['beast', 0.8139459490776062]
['pulpit', 0.8119369745254517]
['schoolmaster', 0.8063094615936279]
['gateway', 0.8044277429580688]
['very', 0.7918776869773865]
