In [1]:
import gensim.models.word2vec as w2v
import pickle
import multiprocessing
import re
import ast
import nltk
import time

import pandas as pd
import numpy as np

import lasagne
import theano
import theano.tensor as T

In [2]:
# data = open('datasetV2twplustag.txt').read().split('\n')
# data = pd.Series(data)

In [3]:
with open('twitter_df.pickle', 'rb') as dfile:
    raw_data = pickle.load(dfile)
labels = raw_data['tags']
labels = pd.Series([ast.literal_eval(label)[0] for label in labels])
raw_data = raw_data['tokenized_text']

In [4]:
t2v = w2v.Word2Vec.load('tweet2vec.w2v')

In [5]:
sentences = []
for row in raw_data:
    for sentence in row:
        sentences.append(sentence)

In [6]:
data = raw_data.apply(lambda row: [[t2v.wv[word] for word in sentence] for sentence in row])
data = np.array([np.array(batch[0]) for batch in data])

In [7]:
labs, counts = np.unique(labels, return_counts=True)

In [8]:
counts, labs = zip(*sorted(zip(counts, labs), reverse=True))

In [9]:
split = round(0.1*len(labs))
keep = labs[:split]
remove = labs[split:]

In [10]:
labels.shape

(279395,)

In [11]:
labels2 = pd.Series([lab if lab in keep else None for lab in labels])

In [12]:
y = np.array([t2v.wv[word] if word is not None else None for word in labels2])

In [13]:
mask = np.array([el is not None for el in y])

In [14]:
y = y[mask]
# data = data[mask]

In [15]:
y = np.vstack(y)

In [16]:
y.shape

(169086, 600)

In [17]:
data_size = labels.shape[0]

border = round(data_size*0.8)
X_train = data[:border]
X_test = data[border:]
y_train = y[:border]
y_test = y[border:]

In [19]:
with open('data.pickle', 'wb') as dfile:
    pickle.dump(zip(data), dfile)

MemoryError: 

In [158]:
y_train.shape

(169086, 600)

In [159]:
input_var = T.tensor3('inputs')
target_var = T.matrix('targets')

In [160]:
#Lasagne Seed for Reproducibility
lasagne.random.set_rng(np.random.RandomState(1))

# Sequence Length
SEQ_LENGTH = 7

# Number of units in the two hidden (LSTM) layers
N_HIDDEN = 40

# Optimization learning rate
LEARNING_RATE = .1

# All gradients above this will be clipped
GRAD_CLIP = 100

# How often should we check the output?
PRINT_FREQ = 500

# Number of epochs to train the net
NUM_EPOCHS = 50

# Batch Size
BATCH_SIZE = 100

# Number of classes
CLASS_NUMBER = len(keep)

In [26]:
def build_nn(input_var=None):
    print("Building network ...")

    # First, we build the network, starting with an input layer
    # Recurrent layers expect input of shape
    # (batch size, SEQ_LENGTH, num_features)    
    l_in = lasagne.layers.InputLayer(shape=(None, 1, 600), input_var=input_var)
    
#     l_emb = lasagne.layers.EmbeddingLayer(l_in, 600, )

    #Two stacked LSTM layers
    l_forward_1 = lasagne.layers.LSTMLayer(
        l_in, N_HIDDEN, grad_clipping=GRAD_CLIP,
        nonlinearity=lasagne.nonlinearities.tanh)

    l_forward_2 = lasagne.layers.LSTMLayer(
        l_forward_1, N_HIDDEN, grad_clipping=GRAD_CLIP,
        nonlinearity=lasagne.nonlinearities.tanh)

    #Slice the output from the LSTM layers to only take the final prediction
    l_forward_slice = lasagne.layers.SliceLayer(l_forward_2, -1, 1)

    l_out = lasagne.layers.DenseLayer(l_forward_2, num_units=CLASS_NUMBER, 
                                      W = lasagne.init.Normal(), nonlinearity=lasagne.nonlinearities.softmax)

    return l_out

In [96]:
network = build_nn(input_var)

Building network ...


In [97]:
prediction = lasagne.layers.get_output(network)
loss = lasagne.objectives.categorical_crossentropy(prediction, target_var)
loss = loss.mean()



In [98]:
params = lasagne.layers.get_all_params(network, trainable=True)
updates = lasagne.updates.nesterov_momentum(
        loss, params, learning_rate=0.01, momentum=0.9)

In [99]:
test_prediction = lasagne.layers.get_output(network, deterministic=True)
test_loss = lasagne.objectives.categorical_crossentropy(test_prediction,
                                                        target_var)
test_loss = test_loss.mean()



In [100]:
test_acc = T.mean(T.eq(T.argmax(test_prediction, axis=1), target_var),
                  dtype=theano.config.floatX)

In [101]:
train_fn = theano.function([input_var, target_var], loss, updates=updates)

In [102]:
val_fn = theano.function([input_var, target_var], [test_loss, test_acc])

In [103]:
# def iterate_minibatches(inputs, targets, batchsize, shuffle=False):
#     assert len(inputs) == len(targets)
#     if shuffle:
#             indices = np.arange(len(inputs))
#             np.random.shuffle(indices)
#     for start_idx in range(0, len(inputs) - batchsize + 1, batchsize):
#         if shuffle:
#             excerpt = indices[start_idx:start_idx + batchsize]
#         else:
#             excerpt = slice(start_idx, start_idx + batchsize)
#         inp = inputs[excert]
#         yield inputs[excerpt], targets[excerpt]

In [161]:
def iterate_minibatches(inputs, targets, batchsize, shuffle=False):
    assert len(inputs) == len(targets)
    if shuffle:
            indices = np.arange(len(inputs))
            np.random.shuffle(indices)
    for start_idx in range(len(inputs)):
        if shuffle:
            excerpt = indices[start_idx]
            inp = inputs[excerpt]
        else:
            inp = inputs[start_idx]
        inp = np.expand_dims(inp, 1)
#         inp = np.array(inp, dtype=np.float32)
        yield inp, targets[excerpt]

In [162]:
for epoch in range(NUM_EPOCHS):
    # In each epoch, we do a full pass over the training data:
    train_err = 0
    train_batches = 0
    start_time = time.time()
    for batch in iterate_minibatches(X_train, y_train, 500, shuffle=True):
        inputs, targets = batch
        print(inputs.shape)
        train_err += train_fn(inputs, targets, allow_input_downcast=True)
        train_batches += 1

    # And a full pass over the validation data:
    val_err = 0
    val_acc = 0
    val_batches = 0
    for batch in iterate_minibatches(X_val, y_val, 500, shuffle=False):
        inputs, targets = batch
        err, acc = val_fn(inputs, targets)
        val_err += err
        val_acc += acc
        val_batches += 1

    # Then we print the results for this epoch:
    print("Epoch {} of {} took {:.3f}s".format(
        epoch + 1, NUM_EPOCHS, time.time() - start_time))
    print("  training loss:\t\t{:.6f}".format(train_err / train_batches))
    print("  validation loss:\t\t{:.6f}".format(val_err / val_batches))
    print("  validation accuracy:\t\t{:.2f} %".format(
        val_acc / val_batches * 100))

(5, 1, 600)


TypeError: ('Bad input argument to theano function with name "<ipython-input-101-a9a5249efc94>:1"  at index 1(0-based)', 'TensorType(int32, vector) cannot store a value of dtype float32 without risking loss of precision. If you do not mind this loss, you can: 1) explicitly cast your data to int32, or 2) set "allow_input_downcast=True" when calling "function".', array([ -5.93071512e-04,   6.51585637e-04,   6.13727025e-04,
         8.17847787e-04,   8.39702043e-05,   7.50090345e-04,
        -3.38226731e-04,  -7.38658928e-05,   8.30631936e-04,
        -1.04247360e-04,   7.12678535e-04,  -6.22358813e-04,
        -1.48864056e-04,   3.36624740e-04,   6.42150408e-04,
        -2.52210099e-04,   1.92125517e-04,  -6.03689230e-04,
         3.27086018e-04,   2.16319531e-04,  -1.31946654e-04,
        -2.04107389e-04,  -5.46766212e-04,   9.32203402e-05,
         3.04789661e-04,   1.21234392e-04,  -8.76966806e-05,
        -5.90709737e-04,   6.31688803e-04,   7.49788422e-04,
        -5.90504205e-04,  -4.89021302e-04,  -3.91051464e-04,
         1.40076372e-04,   2.62788060e-04,   1.78742834e-04,
        -8.24576884e-04,   5.60953165e-04,   2.79683270e-04,
         1.48518797e-04,  -6.02615473e-04,  -3.23135173e-04,
        -6.01884327e-04,  -3.78239289e-04,  -3.68017878e-04,
         4.59601724e-04,   7.51008571e-04,   7.60906143e-04,
        -4.23393329e-04,  -6.02487999e-04,   1.16527335e-05,
         4.42400342e-04,   3.26727924e-04,   1.65696285e-04,
        -7.01913319e-04,  -5.84205263e-04,   6.48779678e-04,
         1.92950305e-04,  -2.59088323e-04,   1.77397931e-04,
         4.62925847e-04,   5.32488455e-04,  -6.33837306e-04,
         6.40665123e-04,   6.31423667e-04,   4.28656422e-05,
        -6.26290042e-04,   3.68068315e-04,   3.94959556e-04,
         5.88301809e-05,  -8.01854010e-04,  -2.89515912e-04,
         7.63234333e-04,  -1.95336383e-04,   7.43453158e-04,
        -8.08637415e-04,  -1.40723365e-04,  -7.47784041e-04,
         8.18699948e-04,  -1.47183702e-04,   3.83782317e-04,
        -6.29168469e-04,   2.79930362e-04,   1.21222598e-04,
        -7.44459452e-04,  -6.33256626e-04,  -1.59701140e-05,
        -2.08427919e-05,   4.02735197e-04,  -1.50768145e-04,
         5.26952208e-04,  -7.72824045e-04,  -9.25935747e-05,
         6.57697907e-04,   5.30949212e-04,  -8.07468081e-04,
        -5.37136511e-04,  -2.64260831e-04,   5.39364584e-04,
        -6.37324003e-04,   4.08371736e-04,   7.97667133e-04,
        -6.11755648e-04,   9.13010881e-05,   6.87439984e-04,
        -6.07724942e-04,  -7.49100815e-04,   9.11651023e-06,
         4.17898293e-04,   8.07912787e-04,  -2.87273026e-04,
        -6.63674378e-04,  -8.12184880e-05,   5.89477480e-04,
         3.08095652e-04,  -4.60068317e-04,   6.23230648e-04,
         7.34410482e-04,   1.04129562e-04,  -4.14012116e-04,
        -3.69918213e-04,  -5.47299802e-04,   8.09232297e-04,
        -2.95126578e-04,  -1.11022877e-04,   3.01359600e-04,
        -1.67617647e-04,   4.54203982e-05,   6.96651638e-04,
        -4.39130032e-04,  -5.51907870e-04,   8.40659704e-05,
        -7.35917201e-05,   7.01475830e-04,   2.74137972e-04,
         6.17405691e-04,   8.44509050e-05,  -2.23866344e-04,
        -7.40984629e-04,   7.68631697e-04,  -1.53603687e-04,
        -7.75838853e-04,   1.79205643e-04,  -4.29917593e-04,
        -5.11282356e-04,  -1.51059794e-04,  -5.81332482e-04,
         1.13498740e-04,   7.28890707e-04,  -5.87427523e-04,
         4.88674792e-04,   7.22385594e-04,   4.55909118e-04,
         1.52744862e-04,  -2.10317507e-04,   3.84051877e-04,
         6.97783427e-04,  -7.65004137e-04,  -7.36687158e-04,
        -6.65528059e-04,   5.17337874e-04,  -9.49818277e-05,
        -1.54197682e-04,   1.74750632e-04,  -7.90711259e-04,
        -1.28484855e-04,   1.47550207e-04,   8.01376184e-04,
        -8.16057654e-05,   7.69008941e-04,   7.43513985e-04,
         1.46954990e-04,   5.51904668e-04,  -4.05719911e-04,
        -6.46523782e-04,  -6.02198183e-04,  -3.03544803e-04,
        -5.09515339e-05,  -4.14105103e-04,   1.71726715e-04,
        -5.05660893e-04,   4.98418347e-04,   5.79211919e-04,
        -7.35922134e-04,   6.30855153e-04,   4.29697364e-04,
         1.70005922e-04,  -2.50986137e-04,  -4.03152459e-04,
        -3.71721952e-04,   7.00740376e-04,   2.28500154e-04,
         6.66533713e-04,   1.47083570e-04,   1.52552384e-05,
        -7.82866671e-04,  -3.30163457e-04,  -2.66725518e-04,
         4.45723454e-05,   7.12555775e-04,   6.77068718e-04,
         4.62288561e-04,   5.22762304e-04,  -3.45389242e-04,
        -8.21077789e-04,   2.65314797e-04,   5.21514099e-04,
        -5.21232840e-04,   2.12528787e-04,   5.55540319e-04,
         2.86903269e-05,   3.09718103e-04,   4.44250210e-04,
         1.12569614e-04,  -5.88786264e-04,   2.94330792e-04,
        -7.82909221e-04,   3.47263413e-04,   1.13626927e-04,
         6.97172596e-04,   4.03677725e-04,   6.33042015e-04,
         6.97984302e-04,   4.59453411e-04,   6.43218518e-04,
        -3.48618632e-04,  -6.93947717e-04,   3.04759538e-04,
        -6.87727879e-04,  -7.48737057e-06,   6.17567915e-04,
        -3.09254887e-04,   6.76818774e-04,   8.73765457e-05,
        -2.18817746e-04,   3.92522576e-04,   3.94261733e-04,
        -6.85245555e-04,   3.08866671e-04,  -1.46844293e-04,
         8.10684636e-04,   4.61519638e-04,  -3.30317052e-05,
         9.76889060e-05,  -1.70235639e-04,  -2.04291238e-04,
         4.94898995e-04,   6.21366315e-04,   5.12193656e-04,
         3.48531234e-04,   2.78790336e-04,  -2.29491547e-04,
        -1.69841296e-04,   4.60751413e-04,  -6.60153397e-04,
         6.20359060e-05,  -3.74687981e-04,   4.92113002e-04,
         3.82195372e-04,   2.91994511e-04,   5.17335247e-05,
        -1.66881713e-04,  -6.74011011e-04,   7.86630204e-04,
         4.50609194e-04,  -5.29798446e-04,   4.85681958e-04,
        -6.53330935e-04,  -1.38944248e-04,   5.33316925e-04,
        -5.60374348e-04,  -5.85049333e-04,  -6.76932454e-04,
        -6.66189822e-04,   6.07573369e-04,  -4.26794490e-04,
         5.26920951e-04,   1.82903721e-04,  -4.42308810e-04,
         1.35167633e-04,  -5.20843023e-04,  -1.90217121e-04,
         7.29731633e-04,  -2.63781694e-04,   7.41553376e-04,
         5.47357660e-04,  -8.01699061e-04,  -6.18975086e-04,
         9.02306056e-05,  -7.91593557e-05,   1.04806684e-04,
         7.60288909e-04,  -2.51167832e-04,   1.30133252e-04,
        -5.66033239e-04,   4.53779357e-04,  -1.17735319e-04,
         5.43756003e-04,  -3.57830490e-04,  -7.89521378e-04,
         3.71446280e-04,   3.48506321e-04,  -7.59187242e-05,
         7.04335456e-04,   1.49069267e-04,   6.54479954e-04,
        -1.31253692e-04,   1.61772157e-04,  -1.25927822e-04,
         5.67479408e-04,   6.65553671e-04,  -1.01062025e-04,
         7.01200857e-04,   2.33441780e-04,  -5.49271062e-04,
        -6.85029489e-04,  -5.70030592e-04,  -2.85639719e-04,
         7.21555145e-04,   9.74668728e-05,  -5.47162956e-04,
        -8.18895234e-04,   2.34294057e-05,   1.33216497e-04,
         5.82332490e-04,   4.39334020e-04,   4.00353689e-04,
         3.78119206e-04,  -2.17703113e-04,  -1.62688448e-04,
        -3.29285453e-04,   6.56152086e-04,   2.87993695e-04,
        -4.58713359e-04,  -8.15036881e-04,  -4.92247636e-04,
        -1.56379727e-04,  -7.51712767e-04,   5.04496100e-04,
        -5.20541798e-05,  -5.04141441e-04,   4.48098959e-04,
        -6.46446540e-04,  -7.98197056e-04,   1.26529209e-04,
         3.69696965e-04,   5.25369367e-04,   3.62511928e-04,
        -5.81139582e-04,  -3.48197587e-04,  -1.83066179e-04,
         7.80537783e-04,   7.25522754e-04,   6.18905236e-04,
        -3.29212402e-04,  -6.22547959e-05,  -5.93280129e-04,
         2.88171665e-04,   5.45165967e-04,  -1.46308143e-04,
         6.33926596e-04,  -5.18482702e-04,   5.37175045e-04,
         7.89777783e-04,   5.79748128e-04,   5.71088574e-04,
        -3.78990720e-04,   4.98284295e-04,  -4.47885453e-04,
        -7.41796626e-04,  -7.54249981e-04,  -4.68114304e-04,
         7.27764796e-04,  -3.56830016e-04,   6.41756691e-04,
         4.77724359e-04,   5.74130856e-04,   8.17811873e-04,
         3.85474210e-04,   2.52757891e-04,  -1.97131536e-04,
        -8.19618755e-04,   2.53034814e-04,   4.58743307e-06,
        -3.18703766e-04,  -6.89402106e-04,   7.00159115e-04,
        -3.90774629e-04,   4.45555546e-04,  -6.81063568e-04,
        -3.04608926e-04,   3.80430080e-04,   4.13821079e-04,
        -7.37241760e-04,   4.22740297e-04,  -7.60201874e-05,
         5.05024800e-04,   1.52759822e-04,   1.21739715e-04,
        -7.59264745e-04,  -2.06294018e-04,  -1.43882629e-04,
        -4.05269966e-04,  -6.52428425e-04,  -5.93357545e-04,
         7.38192175e-04,  -3.06758768e-04,   1.86563411e-04,
         1.90153351e-05,   6.91113994e-04,  -4.53243061e-04,
         4.98032663e-04,   7.80194998e-04,   7.43840123e-04,
         4.67209291e-04,   6.41185208e-04,   6.92477159e-04,
        -7.88703546e-05,  -2.19786059e-04,  -2.14519343e-04,
         7.65284873e-04,  -1.20627257e-04,   5.81165776e-04,
        -5.21226611e-05,   1.79074865e-04,  -8.16412852e-04,
        -4.08373162e-04,  -6.14334480e-04,   9.22988766e-05,
         1.46124425e-04,   5.27402328e-04,  -2.14802785e-04,
        -3.42507119e-04,   4.49145795e-04,   3.29270464e-04,
         1.60454565e-05,   6.47726876e-04,  -2.62784859e-04,
        -2.22449758e-04,  -5.40938810e-04,   4.12660651e-04,
        -2.01690782e-04,  -1.49345375e-04,   6.66666834e-04,
        -2.23873765e-04,  -6.66555658e-04,  -7.91008060e-04,
         2.86109775e-04,  -3.74661322e-04,  -3.33917909e-04,
         4.07560350e-04,   4.98899426e-05,  -4.87202749e-04,
         6.93333801e-04,   4.47051643e-05,  -3.68415785e-04,
         2.73735175e-04,   7.58710448e-05,   7.54949928e-04,
         6.76523312e-04,   3.95542767e-04,  -2.60204659e-04,
        -2.80607666e-04,  -7.42372475e-04,  -4.18171054e-04,
        -3.09507479e-04,   5.28832956e-04,  -6.46698638e-04,
         2.27894008e-04,   8.28957884e-04,  -6.07318711e-04,
         7.61700969e-04,  -6.42352912e-04,  -3.75312840e-04,
        -1.40672782e-04,   7.03952901e-05,   4.13720118e-04,
         3.03478242e-04,   6.64031424e-04,   1.06194071e-04,
        -6.87821710e-04,  -9.65245927e-05,   1.11120928e-04,
         8.18175555e-04,  -7.49000581e-04,  -7.34822475e-04,
         3.21381027e-04,   5.22675458e-04,   5.32209524e-04,
        -4.47934726e-05,   7.73555948e-04,  -7.00422272e-04,
         2.77828745e-04,  -3.83046747e-04,   7.65279674e-06,
        -4.98918467e-04,  -3.84346029e-04,  -1.92921798e-04,
         1.23125341e-04,  -7.12997920e-04,   7.04575970e-04,
        -2.59563298e-04,   3.00374140e-05,  -6.07788970e-04,
         1.58420415e-04,   1.41490993e-04,  -7.58446578e-04,
        -2.07775709e-04,   5.55641484e-04,   5.75613987e-04,
         3.60766047e-04,   3.21754866e-04,   1.44255711e-04,
         2.74255784e-04,  -9.31240793e-05,  -7.65065721e-04,
        -4.53322806e-04,   5.37150307e-04,  -6.82592799e-04,
         1.06610845e-04,  -5.99223422e-04,  -3.19529849e-04,
        -6.44852291e-04,  -5.28086035e-04,   4.43903875e-04,
        -4.29233012e-04,  -7.74502987e-04,   3.40445287e-04,
         1.80612624e-04,  -3.30389594e-04,   4.57357237e-04,
         2.39470231e-04,  -4.73392021e-04,  -6.84068771e-04,
        -5.05923643e-04,  -3.30606796e-04,  -3.37650679e-04,
        -8.17542605e-04,  -2.22238334e-04,   3.29041999e-04,
        -4.58517665e-04,   1.28378364e-04,   7.51682819e-05,
        -1.18623095e-04,  -4.86501609e-04,  -6.05423935e-04,
        -2.44871277e-04,   5.07494842e-04,   3.23682529e-04,
         6.42572122e-04,  -7.04590173e-04,   7.83122377e-05,
        -2.63612834e-04,  -7.97940826e-04,   2.06567653e-04,
         4.13778936e-04,   8.21053749e-04,  -7.66645768e-04,
        -4.12367488e-04,  -5.67862706e-04,  -3.60579434e-05,
        -4.19578369e-04,  -4.30853455e-04,   2.95838166e-04,
        -6.79642180e-05,  -7.63674791e-04,  -5.16910281e-04,
         7.65874051e-04,  -2.36434455e-04,  -6.88928005e-04,
        -4.53685818e-04,  -3.45462962e-04,  -6.96038594e-04,
         7.08565232e-04,   2.51020625e-04,  -7.35030568e-04,
        -5.24864183e-04,  -6.70706562e-04,   3.43386811e-04,
        -6.02863031e-04,   1.42131583e-04,  -5.72334393e-04,
        -6.86692540e-04,  -6.86338230e-04,   7.10251566e-04,
        -1.84933320e-04,  -4.93268424e-04,  -2.49760742e-05,
        -1.90427436e-05,  -5.85754053e-04,   4.34103713e-04,
         3.27537913e-04,  -7.46294216e-04,  -5.38122142e-04,
         7.88384568e-05,   6.81603444e-04,  -3.86899017e-04,
         2.34360239e-04,  -4.60492709e-04,  -3.89606110e-04], dtype=float32))

In [95]:
X_train[0].shape

(7, 600)