In [257]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import torch.nn.functional as F
import nltk
import random
import numpy as np
from collections import Counter
import os
import math
random.seed(0)

In [258]:
BUFFER_SIZE = 10000
WINDOW = 3
K = 5
EPOCH = 5
BATCH_SIZE = 128
EMBEDDING_SIZE = 15
KERNEL_SIZES = [1, 2, 3, 4, 5, 6, 7]
KERNEL_DIMEN = [50, 100, 150, 200, 200, 200, 200]
LR = 0.0001

In [259]:
USE_CUDA = torch.cuda.is_available()
gpus = [0]
torch.cuda.set_device(gpus[0])
FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor
ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor

# Data Preprocessing

In [262]:
word_vocab = Counter()
char_vocab = Counter()
char_vocab.update(['{', '}'])
text_location = os.path.join(os.getcwd(), 'corpus/')
filenames = os.listdir(text_location)
for filename in filenames:
    filename = os.path.join(text_location, filename)
    with open(filename, 'r', encoding='utf8') as f:
        line = f.read()
        word_vocab.update(line.lower().split())
        char_vocab.update(line)
print(word_vocab.most_common(5))

[('.', 1975658), ('the', 1417475), ('of', 888701), ('\uf8ff', 781685), ('and', 726249)]


In [263]:
# +1 as 0 is the PAD
char_to_index = {e:n+1 for n, e in enumerate(char_vocab)}
index_to_char = {n+1:e for n, e in enumerate(char_vocab)}

In [264]:
num_total_words = sum([num for word, num in word_vocab.items()])
unigram_table = []
Z = 0.001
for word in word_vocab:
    unigram_table.extend([word] * int(((word_vocab[word]/num_total_words)**0.75)/Z))

In [265]:
def get_negative(word):
    neg_samples = []
    word = "".join([index_to_char[ind] for ind in word])
    while len(neg_samples) < K:
        neg = random.choice(unigram_table)
        if neg == word.lower():
            continue
        neg_samples.append(prepare_word(neg, char_to_index))
    return neg_samples

def prepare_files(filenames):
    MIN_COUNT = 2
    for filename in filenames:
        with open(filename, 'r', encoding='utf8') as f:
            for line in f:
                words = line.split()
                max_j = len(words)
                for i, word in enumerate(words):
                    if word_vocab[word.lower()] <= MIN_COUNT:
                        continue
                    frequency = word_vocab[word.lower()] / num_total_words
                    number = 1 - math.sqrt(0.00005/frequency)
                    if random.uniform(0, 1) <= number:
                        continue
                    for j in range(i - WINDOW, i + WINDOW):
                        if (i == j) or (j < 0) or (j >= max_j):
                            continue
                        target = words[j]
                        yield word, target

def prepare_word(word, char_to_index):
    start = [char_to_index['{']]
    finish = [char_to_index['}']]
    return start + [char_to_index[char] for char in word] + finish

In [266]:
def get_buffer(filenames, buffer_size):
    random.shuffle(filenames)
    buffer = []
    for word, target in prepare_files(filenames):
        word = prepare_word(word, char_to_index)
        target = prepare_word(target, char_to_index)
        buffer.append([word, target])
        if len(buffer) == buffer_size:
            yield buffer
            buffer = []
    yield buffer
    
def get_batch(filenames, buffer_size, batch_size):
    for buffer in get_buffer(filenames, buffer_size):
        random.shuffle(buffer)
        sindex = 0
        eindex = batch_size
        while eindex < len(buffer):
            batch = buffer[sindex:eindex]
            temp = eindex
            eindex = eindex + batch_size
            sindex = temp
            yield batch
        if eindex >= len(buffer):
            batch = buffer[sindex:]
            yield batch
            
def pad_to_batch(batch):
    max_length = max([len(e) for e in batch])
    padded_batch = []
    for i in range(len(batch)):
        padded_batch.append(batch[i] + [0] * (max_length - len(batch[i])))
    return Variable(LongTensor(padded_batch))

In [267]:
text_location = os.path.join(os.getcwd(), 'corpus/')
filenames = [os.path.join(text_location, filename) for filename in os.listdir(text_location)]
batches = get_batch(filenames, BUFFER_SIZE, 100)

# Model

In [268]:
class Word2CNN(nn.Module):
    
    def __init__(self, vocab_size, embedding_dim, kernel_dims, kernel_sizes,
                 highway_layers=2):
        super(Word2CNN, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.convs = nn.ModuleList([nn.Conv2d(1, dim, (size, embedding_dim)) for dim, size in zip(kernel_dims, kernel_sizes)])
        self.internal_dim = sum(kernel_dims)
        self.hw_num_layers = highway_layers
        self.hw_nonlinear = nn.ModuleList([nn.Linear(self.internal_dim, self.internal_dim) for _ in range(highway_layers)])
        self.hw_linear = nn.ModuleList([nn.Linear(self.internal_dim, self.internal_dim) for _ in range(highway_layers)])
        self.hw_gate = nn.ModuleList([nn.Linear(self.internal_dim, self.internal_dim) for _ in range(highway_layers)])
        self.final_layer = nn.Linear(self.internal_dim * 2, 2)
        self.logsigmoid = nn.LogSigmoid()
        
    def char_cnn(self, inputs):
        inputs = self.embeddings(inputs).unsqueeze(1) # [BATCH, 1, MAX_LENGTH, EM_SIZE]
        inputs = [F.tanh(conv(inputs)).squeeze(3) for conv in self.convs] # [BATCH, K_DIM, MAX_LENGTH]*len(Ks)
        inputs = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in inputs] # [BATCH, K_DIM]*len(Ks)
        inputs = torch.cat(inputs, 1) # [BATCH, K_DIM*len(Ks)]
        for layer in range(self.hw_num_layers):
            gate = F.sigmoid(self.hw_gate[layer](inputs))
            nonlinear = F.relu(self.hw_nonlinear[layer](inputs))
            linear = self.hw_linear[layer](inputs)
            inputs = gate * nonlinear + (1 - gate) * linear
        return inputs
    
    def forward(self, center_words, target_words, negati_words, is_training=False):      
        center_embeds = self.char_cnn(center_words).unsqueeze(1) # [B, 1, D]
        target_embeds = self.char_cnn(target_words).unsqueeze(1) # [B, 1, D]
        size = negati_words.size()
        batch_size = size[0]
        K = size[1]
        wlen = size[2]
        negati_words = negati_words.view(batch_size*K, wlen)
        negati_embeds = -self.char_cnn(negati_words).view(batch_size, K, self.internal_dim) # [B, K, D]
        
        positive_score = target_embeds.bmm(center_embeds.transpose(1, 2)).squeeze(2) # Bx1
        negative_score = torch.sum(negati_embeds.bmm(center_embeds.transpose(1, 2)).squeeze(2), 1).view(batch_size, -1)
        loss = self.logsigmoid(positive_score) + self.logsigmoid(negative_score)
        return -torch.mean(loss)
    
    def prediction(self, inputs):
        return self.char_cnn(inputs)

In [269]:
vocab_size = len(char_to_index) + 1
model = Word2CNN(vocab_size, EMBEDDING_SIZE, KERNEL_DIMEN, KERNEL_SIZES)
if USE_CUDA:
    model = model.cuda()
optimizer = optim.Adam(model.parameters(), lr=LR)

In [270]:
flatten = lambda l: [item for sublist in l for item in sublist]
for epoch in range(EPOCH):
    losses = []
    for i, batch in enumerate(get_batch(filenames, BUFFER_SIZE, BATCH_SIZE)):
        inputs, targets = zip(*batch)
        negatives = []
        for inpt in inputs:
            negatives.extend(get_negative(inpt))
        inputs = pad_to_batch(inputs)
        targets = pad_to_batch(targets)
        negatives = pad_to_batch(negatives)
        negatives = negatives.view(len(inputs), K, -1)
        model.zero_grad()
        loss = model(inputs, targets, negatives, True)
        loss.backward()
        optimizer.step()
        losses.append(loss.data.tolist()[0])
        if i % 100 == 0:
            print("[%d/%d] mean_loss : %0.2f" %(epoch, EPOCH, np.mean(losses)))
            losses = []

[0/5] mean_loss : 29.49
[0/5] mean_loss : 1.59
[0/5] mean_loss : 1.29
[0/5] mean_loss : 1.20
[0/5] mean_loss : 1.09
[0/5] mean_loss : 0.90
[0/5] mean_loss : 1.07
[0/5] mean_loss : 1.13
[0/5] mean_loss : 1.06
[0/5] mean_loss : 1.07
[0/5] mean_loss : 1.06
[0/5] mean_loss : 1.00
[0/5] mean_loss : 0.98
[0/5] mean_loss : 0.84
[0/5] mean_loss : 0.41
[0/5] mean_loss : 0.75
[0/5] mean_loss : 0.95
[0/5] mean_loss : 0.85
[0/5] mean_loss : 0.81
[0/5] mean_loss : 0.73
[0/5] mean_loss : 0.77
[0/5] mean_loss : 0.67
[0/5] mean_loss : 0.82
[0/5] mean_loss : 0.81
[0/5] mean_loss : 0.79
[0/5] mean_loss : 0.78
[0/5] mean_loss : 0.82
[0/5] mean_loss : 0.83
[0/5] mean_loss : 0.85
[0/5] mean_loss : 0.81
[0/5] mean_loss : 0.76
[0/5] mean_loss : 0.84
[0/5] mean_loss : 0.73
[0/5] mean_loss : 0.80
[0/5] mean_loss : 0.84
[0/5] mean_loss : 0.89
[0/5] mean_loss : 0.81
[0/5] mean_loss : 0.75
[0/5] mean_loss : 0.81
[0/5] mean_loss : 0.73
[0/5] mean_loss : 0.70
[0/5] mean_loss : 1.76
[0/5] mean_loss : 1.11


KeyboardInterrupt: 

## Save the model

In [271]:
torch.save(model.state_dict(), "the_model.model")