In [101]:
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

from tqdm import tqdm_notebook

In [90]:
class NPLM(nn.Module):
    def __init__(self, vocab_size, embed_dim, window_size, activation_size):
        super(NPLM, self).__init__()
        
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.window_size = window_size
        self.activation_size = activation_size
        
        self.embeddings = nn.Embedding(vocab_size, embed_dim)
        self.fc = nn.Linear(window_size * embed_dim, activation_size)
        self.tanh = nn.Tanh()
        
        self.fc2 = nn.Linear(activation_size + window_size * embed_dim, vocab_size)
        
        self.softmax = nn.Softmax()
        self.init_weights()
        
    def init_weights(self):
        initrange = 0.5
        self.embeddings.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()
        
        self.fc2.weight.data.uniform_(-initrange, initrange)
        self.fc2.bias.data.zero_()
        
    def forward(self, x):
        X = self.embeddings(x).view(-1, self.embed_dim * self.window_size)
        tanh_X = self.tanh(self.fc(X))
        
        X = torch.cat((X, tanh_X), dim=1)
        X = self.fc2(X)
        
        return self.softmax(X)

# Data Preprocessing

In [18]:
from nltk.corpus import brown

In [21]:
import nltk
nltk.download('brown')

[nltk_data] Downloading package brown to /home/ubuntu/nltk_data...
[nltk_data]   Unzipping corpora/brown.zip.


True

In [25]:
brown_texts, brown_cate = {}, brown.categories()
for category in brown_cate:
    brown_texts[category] = brown.words(categories=category)

In [32]:
def generate_w2i(texts_dict):
    w2i, i2w = {}, {}
    i = 0
    for cate, words in brown_texts.items():
        for word in words:
            if word.lower() not in w2i:
                w2i[word.lower()] = i
                i2w[i] = word.lower()
                i += 1
    return w2i, i2w

In [33]:
w2i, i2w = generate_w2i(brown_texts)

In [56]:
brown_texts["fiction"][0:-4]

['Thirty-three', 'Scotty', 'did', 'not', 'go', 'back', ...]

In [85]:
def preprocessing_data_in_ngram(texts_dict, n, w2i):
    train_pairs, validation_pairs, test_pairs = [], [], []
    for category in texts_dict:
        print(category)
        cate_pairs = []
        for i in range(len(texts_dict[category])-n):
            cate_pairs.append(([w2i[w.lower()] for w in texts_dict[category][i:i+n]], 
                               w2i[texts_dict[category][i+n].lower()]))
        train_size = len(cate_pairs) * 7 // 10
        val_size = len(cate_pairs) * 9 // 10
        train_pairs.extend(cate_pairs[:train_size])
        validation_pairs.extend(cate_pairs[train_size:val_size])
        test_pairs.extend(cate_pairs[val_size:])
    return train_pairs, validation_pairs, test_pairs

In [86]:
train_set, validation_set, test_set = preprocessing_data_in_ngram(brown_texts, 3, w2i)

adventure
belles_lettres
editorial
fiction
government
hobbies
humor
learned
lore
mystery
news
religion
reviews
romance
science_fiction


In [87]:
train_set

[([0, 1, 2], 3),
 ([1, 2, 3], 4),
 ([2, 3, 4], 5),
 ([3, 4, 5], 6),
 ([4, 5, 6], 7),
 ([5, 6, 7], 8),
 ([6, 7, 8], 9),
 ([7, 8, 9], 4),
 ([8, 9, 4], 10),
 ([9, 4, 10], 11),
 ([4, 10, 11], 12),
 ([10, 11, 12], 13),
 ([11, 12, 13], 14),
 ([12, 13, 14], 9),
 ([13, 14, 9], 4),
 ([14, 9, 4], 15),
 ([9, 4, 15], 16),
 ([4, 15, 16], 17),
 ([15, 16, 17], 18),
 ([16, 17, 18], 19),
 ([17, 18, 19], 20),
 ([18, 19, 20], 10),
 ([19, 20, 10], 21),
 ([20, 10, 21], 22),
 ([10, 21, 22], 7),
 ([21, 22, 7], 9),
 ([22, 7, 9], 23),
 ([7, 9, 23], 4),
 ([9, 23, 4], 24),
 ([23, 4, 24], 25),
 ([4, 24, 25], 14),
 ([24, 25, 14], 26),
 ([25, 14, 26], 27),
 ([14, 26, 27], 28),
 ([26, 27, 28], 29),
 ([27, 28, 29], 30),
 ([28, 29, 30], 31),
 ([29, 30, 31], 32),
 ([30, 31, 32], 9),
 ([31, 32, 9], 33),
 ([32, 9, 33], 34),
 ([9, 33, 34], 13),
 ([33, 34, 13], 35),
 ([34, 13, 35], 10),
 ([13, 35, 10], 36),
 ([35, 10, 36], 9),
 ([10, 36, 9], 37),
 ([36, 9, 37], 4),
 ([9, 37, 4], 38),
 ([37, 4, 38], 39),
 ([4, 38, 39], 40),

In [91]:
train_X = torch.tensor([pair[0] for pair in train_set])
train_Y = torch.tensor([pair[1] for pair in train_set])

In [98]:
def generate_batch(train_X, train_Y, size):
    a = np.random.choice(len(train_X), size, replace=False)
    return train_X[a], train_Y[a]

In [None]:
vocab_size = len(w2i)
embed_dim = 100
window_size = 3
activation_size = 150
size = 32

nplm = NPLM(vocab_size, embed_dim, window_size, activation_size)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(nplm.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.9)

epoch = 10

for i in range(epoch):
    epoch_loss = 0
    for _ in tqdm_notebook(range(len(train_X) // size)):
        x, y = generate_batch(train_X, train_Y, size)
        pred_y = nplm(x)
        loss = criterion(pred_y, y)
        #print(loss)
        epoch_loss += loss.item()
        optimizer.step()
    scheduler.step()
    print(epoch_loss / len(train_X))

HBox(children=(FloatProgress(value=0.0, max=25399.0), HTML(value='')))