In [114]:
import torch, pickle, sys
from torch import nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import pandas as pd

In [115]:
class RNN(nn.Module):
    """Single-step recurrent cell producing log-probabilities over output classes.

    Each call to ``forward`` consumes one input vector together with the
    previous hidden state and returns the log-probabilities for the next
    symbol plus the updated hidden state.

    Args:
        input_size: width of the input vector fed at every step.
        hidden_size: width of the recurrent hidden state.
        output_size: number of output classes.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size

        # Both i2h and i2o read the concatenation [input, hidden].
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)
        # o2o mixes the new hidden state with the preliminary output.
        self.o2o = nn.Linear(hidden_size + output_size, output_size)
        self.dropout = nn.Dropout(0.1)
        # Normalise explicitly over the class dimension (dim 1 of the
        # (batch, classes) output) instead of relying on the deprecated
        # implicit-dimension choice of LogSoftmax().
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        """Run one time step.

        Args:
            input: 2-D tensor of shape (batch, input_size).
            hidden: 2-D tensor of shape (batch, hidden_size).

        Returns:
            Tuple ``(output, hidden)`` where ``output`` holds log-probabilities
            of shape (batch, output_size) and ``hidden`` is the new state of
            shape (batch, hidden_size).
        """
        input_combined = torch.cat((input, hidden), 1)
        hidden = self.i2h(input_combined)
        output = self.i2o(input_combined)
        output_combined = torch.cat((hidden, output), 1)
        output = self.o2o(output_combined)
        output = self.dropout(output)
        output = self.softmax(output)
        return output, hidden

    def initHidden(self, value=0.001):
        """Return a fresh (1, hidden_size) float hidden state.

        Entries are drawn uniformly from [-value, value) with NumPy's global
        random generator, so seed ``np.random`` for reproducible runs.
        """
        initial = np.random.uniform(low=-value, high=value, size=(1, self.hidden_size))
        return Variable(torch.from_numpy(initial).float())


In [160]:
def getNorm(i2o):
    """Return the Frobenius norm of the matrix ``i2o``.

    Despite the parameter name, any array accepted by ``np.linalg.norm``
    with ``ord='fro'`` can be passed.
    """
    frobenius = np.linalg.norm(i2o, 'fro')
    return frobenius

In [116]:
# -- import embeddings
# Loads the pickled payload into `data`; the recorded output of the next cell
# shows it is a dict with the keys 'vocab' and 'embeddings'.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file --
# only load this pickle if it comes from a trusted source.
with open('../../data/lemmas_embeddings.pickle', 'rb') as f:
    data = pickle.load(f)

In [117]:
# Inspect which top-level entries the loaded pickle provides.
data.keys()

dict_keys(['vocab', 'embeddings'])

In [118]:
# Pull the two entries out of the loaded pickle. `vocab` is indexed by word
# and its values are used as row indices into `embeddings` further down.
vocab = data['vocab']
# Cast to single precision, matching the .float() hidden state of the model.
embeddings = data['embeddings'].astype(np.float32)
# Inverse lookup (index -> word) for turning predicted indices back into words.
reverse = dict((index, word) for word, index in vocab.items())
# Vocabulary size, displayed as the cell output.
len(vocab)

17663

In [119]:
# -- load lemma text
# `lemma` is iterated sentence by sentence in the training loop, and each
# sentence is indexed by position to obtain its words.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file --
# only load this pickle if it comes from a trusted source.
with open('../../data/lemmas.pickle', 'rb') as g:
    lemma = pickle.load(g)

In [None]:
# Model dimensions: width of the input embedding and of the recurrent state.
embedding_size = 100
hidden_dim = 100

# One output class per vocabulary entry; the model is moved to the GPU.
rnn = RNN(
    input_size=embedding_size,
    hidden_size=hidden_dim,
    output_size=len(vocab),
).cuda()

# Classification cross-entropy loss, optimised with SGD.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    params=rnn.parameters(),
    lr=0.0001,
    momentum=0.02,
    weight_decay=0.2,
)


In [None]:
# Train on one sentence at a time: feed each word's embedding through the RNN,
# accumulate the loss of predicting the following word, then take a single
# optimizer step per sentence.
for idx, sentence in enumerate(lemma):
    # Gradients accumulate in .grad across backward() calls, so clear them
    # before every sentence; otherwise each step would apply the running sum
    # of all previous sentences' gradients.
    optimizer.zero_grad()
    hidden = rnn.initHidden().cuda()
    loss = 0.0
    n_pairs = 0  # number of (current word, next word) pairs that contributed to loss
    for i in range(len(sentence) - 1):
        try:  # -- get rid of digits in vocab
            current_index = vocab[sentence[i]]
            next_index = vocab[sentence[i + 1]]
        except KeyError:
            # One of the two tokens is not in the vocabulary: skip this pair
            # and leave the hidden state untouched.
            continue

        current_embedding = torch.from_numpy(embeddings[current_index]).unsqueeze(0)
        vembedding = Variable(current_embedding).cuda()
        # Class targets must be 64-bit integers for the loss.
        next_word = Variable(torch.from_numpy(np.array([next_index], dtype=np.int64))).cuda()
        output, hidden = rnn(vembedding, hidden)
        loss += criterion(output, next_word)
        n_pairs += 1

    if n_pairs == 0:
        # Sentence too short or entirely out of vocabulary: `loss` is still the
        # plain float 0.0 and has no backward(), so there is nothing to learn from.
        continue

    loss.backward()
    optimizer.step()
    print(idx)

In [None]:
# Greedy generation: start from the seed word 'b' and repeatedly feed the most
# probable next word back into the network, printing each input word.
was_training = rnn.training
rnn.eval()  # disable dropout so the predictions are deterministic

# Initialise the hidden state once; it must be carried from step to step,
# otherwise every prediction would ignore all previously generated words.
_hidden = rnn.initHidden().cuda()
word = 'b'
for i in range(10):
    print(word)
    next_word = torch.from_numpy(embeddings[vocab[word]]).unsqueeze(0)
    _output, _hidden = rnn(Variable(next_word).cuda(), _hidden)
    # The model returns log-probabilities; exponentiate to get probabilities.
    probs = np.exp(_output.data.cpu().numpy().ravel())
    amax = int(np.argmax(probs))
    word = reverse[amax]

# Put the model back into whatever mode it was in before sampling.
rnn.train(was_training)

#### Compute norm of gradients

In [156]:
# Small stand-alone model for the gradient-norm experiment: the number of
# output classes is fixed at 100 instead of being taken from the vocabulary.
embedding_size = 100
hidden_dim = 100
vocab_size = 100

rnn = RNN(
    input_size=embedding_size,
    hidden_size=hidden_dim,
    output_size=vocab_size,
).cuda()

# Classification cross-entropy loss, optimised with SGD.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    params=rnn.parameters(),
    lr=0.0001,
    momentum=0.02,
    weight_decay=0.2,
)


In [121]:
# Explicitly mark every model parameter as requiring gradients.
# NOTE(review): parameters created by nn.Linear normally already have
# requires_grad=True, so this loop is likely a no-op -- confirm before removing.
for parm in rnn.parameters():
    parm.requires_grad = True

In [163]:
# Gradient-norm experiment: run 100 updates on random 9-step sequences and
# print the Frobenius norm of the weight gradients of the three linear layers.
for k in range(100):
    # Start every iteration from zero gradients so the reported norms belong
    # to this iteration's loss only and do not include earlier iterations.
    optimizer.zero_grad()
    hidden = rnn.initHidden().cuda()
    loss = 0.0
    for i in range(10 - 1):
        # Random input embedding (one of the first 10 rows) and random target class in [0, 10).
        current_embedding = torch.from_numpy(embeddings[np.random.randint(10)]).unsqueeze(0)
        vembedding = Variable(current_embedding, requires_grad=True).cuda()
        next_word = torch.from_numpy(np.array([np.random.randint(10)]))
        output, hidden = rnn(vembedding, hidden)
        loss += criterion(output, Variable(next_word).cuda())

    loss.backward()

    # Copy the raw gradients out right after backward(), before clipping and
    # before the optimizer step, so the printed norms are the unclipped ones.
    i2o = rnn.i2o.weight.grad.data.cpu().numpy()
    i2h = rnn.i2h.weight.grad.data.cpu().numpy()
    o2o = rnn.o2o.weight.grad.data.cpu().numpy()

    # Clipping only has an effect once gradients exist, i.e. after backward()
    # and before step(); it rescales them in place to a total norm <= 0.25.
    torch.nn.utils.clip_grad_norm_(rnn.parameters(), 0.25)
    optimizer.step()

    norm_i2o = getNorm(i2o)
    norm_i2h = getNorm(i2h)
    norm_o2o = getNorm(o2o)
    print(k, norm_i2o, norm_i2h, norm_o2o)

0 6.00102 6.94652 28.4438
1 27.665 111.153 48.0943
2 8.68394 27.0407 23.2355
3 10.0539 18.525 43.0136
4 5.4819 16.115 21.4013
5 7.08332 15.76 34.4673
6 15.5867 38.8045 22.886
7 17.3505 53.2685 37.4421
8 10.4344 22.9897 32.6837
9 4.19784 9.43011 17.7511
10 4.09357 4.68125 19.1784
11 8.97172 15.6166 26.7831
12 6.66302 8.04506 22.9977
13 29.7504 81.2312 33.1098
14 7.03675 16.544 31.3246
15 13.7566 48.2204 38.1774
16 2.60717 4.0136 14.1946
17 4.66888 5.19463 15.014
18 10.4952 24.9476 29.7376
19 4.39865 9.23239 20.6538
20 7.30202 10.6385 30.5975
21 14.2109 40.6298 18.3761
22 27.3421 93.8539 37.5096
23 15.1732 41.9079 31.9011
24 14.2663 28.1796 24.7966
25 11.8236 31.0986 24.2477
26 6.39691 10.7796 21.2292
27 8.34457 23.9766 22.1174
28 7.4809 15.1798 40.5001
29 18.5032 51.8982 43.5337
30 18.4574 48.4326 37.3832
31 8.83204 26.1046 22.6469
32 9.43912 28.6126 28.8431
33 13.6185 38.2131 41.5015
34 7.22373 12.5876 24.0256
35 12.3383 31.8123 30.5087
36 5.98024 8.58123 22.2111
37 4.60865 6.46872 19.

In [150]:

# Display the i2o weight-gradient array assigned in the gradient-norm loop.
i2o

array([[-0.56467068,  0.34446806,  0.16331168, ..., -0.07057101,
        -0.04895362,  0.55580968],
       [ 0.01528262, -0.0856595 ,  0.11427888, ..., -0.07430661,
        -0.0210433 ,  0.12422596],
       [ 0.24313246,  0.07197672,  0.35094213, ..., -0.16292627,
        -0.17816915,  0.45773843],
       ..., 
       [-0.42307332,  0.29978159, -0.28576756, ...,  0.02942524,
         0.03117448,  0.02862721],
       [ 0.26053941, -0.13403273,  0.25640178, ...,  0.03677401,
        -0.00440274,  0.41172504],
       [-0.4059085 ,  0.0619502 , -0.37544531, ..., -0.10164777,
        -0.07685494, -0.02930087]], dtype=float32)

In [152]:
# Frobenius norm of the i2o gradient array; the same computation as getNorm(i2o).
np.linalg.norm(i2o, ord='fro')

37.699448