In [1]:
import torch
import numpy as np

SEED = 41

torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters

In [2]:
K = 100
C = 3
NUM_EPOCHS = 2
MAX_VOCAB_SIZE = 30000
BATCH_SIZE = 128
LEARNING_RATE = 0.2
EMBEDDING_SIZE = 100

# Build Vocabulary

In [3]:
from collections import Counter

with open('data/text8/train.txt') as f:
    text = f.read()
text = text.lower().split()

UNK_LABEL = '<unk>'

class Vocabulary:

    def __init__(self, text, vocab_size, freq_pow=0.75):
        d = dict(Counter(text).most_common(vocab_size - 1))
        d[UNK_LABEL] = len(text) - sum(d.values())
        self.itos = list(d.keys())
        self.stoi = {word: i for i, word in enumerate(self.itos)}
        self.counts = np.array(list(d.values()))
        self.freqs = self.counts / len(text)
        self.freqs **= freq_pow
        self.freqs = self.freqs / self.freqs.sum()

vocab = Vocabulary(text, MAX_VOCAB_SIZE)

# Build Dataset and DataLoader

In [4]:
from torch.utils.data import Dataset

class WordEmbeddingDataset(Dataset):

    def __init__(self, text, vocab, window_size, neg_sample_rate):
        super(WordEmbeddingDataset, self).__init__()
        text_encoded = [vocab.stoi.get(t, vocab.stoi[UNK_LABEL]) for t in text]
        self.text_encoded = torch.LongTensor(text_encoded)
        self.word_freqs = vocab.freqs
        self.window_size = window_size
        self.neg_sample_rate = neg_sample_rate

    def __len__(self):
        return len(self.text_encoded)

    def __getitem__(self, idx):
        center_word = self.text_encoded[idx]

        indices = [i % len(self.text_encoded) for i in range(idx - self.window_size, idx + self.window_size + 1) if i != idx]
        pos_words = self.text_encoded[indices]

        neg_freqs = torch.Tensor(self.word_freqs)
        neg_freqs[pos_words] = 0
        neg_words = torch.multinomial(neg_freqs, 2 * self.window_size * self.neg_sample_rate, True)

        return center_word, pos_words, neg_words

In [5]:
from torch.utils.data import DataLoader

dataset = WordEmbeddingDataset(text, vocab, C, K)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [6]:
center, pos, neg = next(iter(dataloader))
print(center.shape)
print(pos.shape)
print(neg.shape)

torch.Size([128])
torch.Size([128, 6])
torch.Size([128, 600])


# Build model

In [7]:
import torch.nn as nn
import torch.nn.functional as F

class EmbeddingModel(nn.Module):

    def __init__(self, vocab_size, embed_size):
        super(EmbeddingModel, self).__init__()
        self.in_embed = nn.Embedding(vocab_size, embed_size)
        self.out_embed = nn.Embedding(vocab_size, embed_size)

        initrange = 0.5 / embed_size
        self.in_embed.weight.data.uniform_(-initrange, initrange)
        self.out_embed.weight.data.uniform_(-initrange, initrange)

    def forward(self, inp_labels, pos_labels, neg_labels):
        '''
        inp_label: batch_size
        pos_labels: batch_size, 2 * C
        neg_labels: batch_size, 2 * C * K
        '''
        input_embedding = self.in_embed(inp_labels).unsqueeze(2)  # batch_size, embed_size, 1
        pos_embedding = self.out_embed(pos_labels)  # batch_size, 2 * C, embed_size
        neg_embedding = self.out_embed(neg_labels)  # batch_size, 2 * C * K, embed_size

        pos_dot = torch.bmm(pos_embedding, input_embedding).squeeze(1)  # batch_size, 2 * C
        neg_dot = torch.bmm(neg_embedding, -input_embedding).squeeze(1) # batch_size, 2 * C * K

        return -(F.logsigmoid(pos_dot).sum(1) + F.logsigmoid(neg_dot).sum(1))

    @property
    def input_embeddings(self):
        return self.in_embed.weight.data.cpu().numpy()

In [8]:
model = EmbeddingModel(MAX_VOCAB_SIZE, EMBEDDING_SIZE)
model.to(device)

EmbeddingModel(
  (in_embed): Embedding(30000, 100)
  (out_embed): Embedding(30000, 100)
)

# Functions for evaluating model

In [9]:
import pandas as pd

men = pd.read_csv('data/men.txt', sep='\t')
simlex = pd.read_csv('data/simlex-999.txt', sep='\t')
wordsim = pd.read_csv('data/wordsim353.csv', sep=',')

In [10]:
import scipy
from sklearn.metrics.pairwise import cosine_similarity

def evaluate(df, vocab, embedding_weights):
    human_similarity, model_similarity = [], []
    for i, row in df.iterrows():
        word1, word2, human_sim = row
        if word1 not in vocab.stoi or word2 not in vocab.stoi:
            continue
        word1_embed = embedding_weights[[vocab.stoi[word1]]]
        word2_embed = embedding_weights[[vocab.stoi[word2]]]
        model_similarity.append(float(cosine_similarity(word1_embed, word2_embed)))
        human_similarity.append(human_sim)
    return scipy.stats.spearmanr(human_similarity, model_similarity)

def find_nearest(word, vocab, embedding_weights):
    idx = vocab.stoi[word]
    embedding = embedding_weights[idx]
    cos_dis = np.array([scipy.spatial.distance.cosine(e, embedding) for e in embedding_weights])
    return [vocab.itos[i] for i in cos_dis.argsort()[1:10]]

# Train model

In [11]:
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)

LOG_FILE = 'word-embedding.log'

for e in range(NUM_EPOCHS):
    for i, (input_labels, pos_labels, neg_labels) in enumerate(dataloader):
        input_labels = input_labels.to(device)
        pos_labels = pos_labels.to(device)
        neg_labels = neg_labels.to(device)
        
        optimizer.zero_grad()
        loss = model(input_labels, pos_labels, neg_labels).mean()

        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            log = f'epoch: {e}, iter: {i}, loss: {loss.item()}\n'
            print(log)
            with open(LOG_FILE, 'a') as f:
                f.write(log)

        if i % 2000 == 0:
            embedding_weights = model.input_embeddings
            sim_simlex = evaluate(simlex, vocab, embedding_weights)
            sim_men = evaluate(men, vocab, embedding_weights)
            sim_353 = evaluate(wordsim, vocab, embedding_weights)
            nearest_to_monster = find_nearest('monster', vocab, embedding_weights)
            log = f'epoch: {e}, iteration: {i}, simlex-999: {sim_simlex}, men: {sim_men}, sim353: {sim_353}, nearest to monster: {nearest_to_monster}\n'
            print(log)
            with open(LOG_FILE, 'a') as f:
                f.write(log)
    
    np.save(f'embedding-{EMBEDDING_SIZE}', model.input_embeddings)
    torch.save(model.state_dict(), f'embedding-{EMBEDDING_SIZE}.pt')

epoch: 0, iter: 0, loss: 420.0469970703125

epoch: 0, iteration: 0, simlex-999: SpearmanrResult(correlation=-0.011059767251777975, pvalue=0.7329798282448805), men: SpearmanrResult(correlation=0.00039138665541115484, pvalue=0.9841406544389526), sim353: SpearmanrResult(correlation=0.05486145223688137, pvalue=0.3286948425888109), nearest to monster: ['epistles', 'herr', 'clytemnestra', 'paran', 'from', 'aqaba', 'quantifier', 'frobisher', 'product']

epoch: 0, iter: 100, loss: 272.5193786621094

epoch: 0, iter: 200, loss: 202.49777221679688

epoch: 0, iter: 300, loss: 194.82191467285156

epoch: 0, iter: 400, loss: 180.51918029785156

epoch: 0, iter: 500, loss: 150.2568359375

epoch: 0, iter: 600, loss: 117.285400390625

epoch: 0, iter: 700, loss: 100.47791290283203

epoch: 0, iter: 800, loss: 105.71221923828125

epoch: 0, iter: 900, loss: 100.15678405761719

epoch: 0, iter: 1000, loss: 80.7895736694336

epoch: 0, iter: 1100, loss: 77.871826171875

epoch: 0, iter: 1200, loss: 79.36443328857

# Evaluate trained embedding

In [14]:
model.load_state_dict(torch.load(f'embedding-{EMBEDDING_SIZE}.pt'))
embedding_weights = model.input_embeddings

In [15]:
print('simlex-999', evaluate(simlex, vocab, embedding_weights))
print('men', evaluate(men, vocab, embedding_weights))
print('wordsim353', evaluate(wordsim, vocab, embedding_weights))

simlex-999 SpearmanrResult(correlation=0.1681873305716443, pvalue=1.7393186706477882e-07)
men SpearmanrResult(correlation=0.17935776642428256, pvalue=4.156408870655226e-20)
wordsim353 SpearmanrResult(correlation=0.28591694763153264, pvalue=2.0422417518282842e-07)


In [16]:
for word in ['good', 'fresh', 'monster', 'green', 'like', 'america', 'chicago', 'work', 'computer', 'language']:
    print(word, find_nearest(word, vocab, embedding_weights))

good ['bad', 'perfect', 'hard', 'experience', 'money', 'really', 'questions', 'evil', 'something']
fresh ['grain', 'waste', 'raw', 'dense', 'drinking', 'thermal', 'fiber', 'atmospheric', 'warm']
monster ['giant', 'robot', 'clown', 'snake', 'hammer', 'melody', 'demon', 'cube', 'bird']
green ['blue', 'yellow', 'white', 'cross', 'orange', 'red', 'black', 'gold', 'purple']
like ['etc', 'eating', 'drink', 'unlike', 'soft', 'similarly', 'whereas', 'sounds', 'eat']
america ['korea', 'africa', 'india', 'europe', 'turkey', 'carolina', 'australia', 'argentina', 'pakistan']
chicago ['boston', 'illinois', 'texas', 'massachusetts', 'london', 'toronto', 'indiana', 'berkeley', 'harvard']
work ['writing', 'writings', 'vision', 'marx', 'solo', 'aristotle', 'songs', 'recording', 'poetry']
computer ['digital', 'audio', 'video', 'graphics', 'electronic', 'software', 'hardware', 'computers', 'program']
language ['languages', 'alphabet', 'arabic', 'programming', 'pronunciation', 'grammar', 'dialect', 'spoke

In [18]:
man_idx = vocab.stoi['man']
woman_idx = vocab.stoi['woman']
king_idx = vocab.stoi['king']
embedded = embedding_weights[woman_idx] - embedding_weights[man_idx] + embedding_weights[king_idx]
cos_dis = np.array([scipy.spatial.distance.cosine(e, embedded) for e in embedding_weights])
for i in cos_dis.argsort()[:20]:
    print(vocab.itos[i])

king
henry
charles
queen
pope
prince
iii
elizabeth
alexander
edward
constantine
james
son
louis
iv
duke
frederick
mary
francis
emperor
