This is the master notebook: the TAs should be able to run it from start to finish (Restart Kernel → Run All).

In [13]:
import pandas as pd
import numpy as np
from src.utils import get_batches, shuffle, train_val_split


# Reload edited modules (e.g. src.utils, src.mlp, src.lstm) before each cell runs;
# mode 2 reloads all modules. Re-running this cell only prints an "already loaded" notice.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [36]:
# Pickled essay table: one row per essay, with the columns listed by df.info() below.
# NOTE: pd.read_pickle unpickles arbitrary objects — only load files you trust.
data_path = './data/essay_df.pkl'
df = pd.read_pickle(data_path)

In [37]:
# Sanity check: column dtypes and non-null counts of the loaded table.
df.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 8870 entries, 0 to 10683
Data columns (total 8 columns):
essay_id         8870 non-null int64
essay_set        8870 non-null int64
essay            8870 non-null object
domain1_score    8870 non-null int64
essays_embed     8870 non-null object
word_count       8870 non-null int64
max_score        8870 non-null float64
norm_score       8870 non-null float64
dtypes: float64(2), int64(4), object(2)
memory usage: 623.7+ KB


In [38]:
# Pull the model inputs out of the DataFrame as NumPy arrays:
#   X    - per-essay embedding matrices (object dtype; stacked into one tensor below)
#   y    - the float `norm_score` column
#   sets - the `essay_set` id, carried alongside X/y through shuffling and splitting
X, y, sets = (np.array(df[col]) for col in ('essays_embed', 'norm_score', 'essay_set'))

In [39]:
# Only the extracted arrays are needed from here on; drop the DataFrame to free memory.
# (The RNN section further down reads the pickle again.)
del df


In [40]:
# Turn the object array of per-essay matrices into a single dense tensor
# of shape (n_essays, 600, 100) in the run shown below.
X = np.stack(list(X))
print(X.shape)

(8870, 600, 100)


In [41]:
# Flattening each essay into one (600 * 100)-long vector is not used; the next
# cell mean-pools over the word axis instead, so X is still 3-D at this point.
print(X.shape)

(8870, 600, 100)


In [42]:
# Instead of flattening, mean-pool each essay's word vectors into one feature
# vector: (n_essays, 600, 100) -> (n_essays, 100).
# The ndim guard keeps the cell safe to re-run: without it a second run would
# also average away the embedding axis and leave X with shape (n_essays,).
# NOTE(review): if shorter essays are zero-padded up to 600 rows (TODO confirm in
# preprocessing), the padding pulls each mean toward 0 in proportion to essay length.
if X.ndim == 3:
    X = np.mean(X, axis=1)
print(X.shape)

(8870, 100)


In [43]:
# Seed the RNG so the shuffle, and therefore the train/validation split below,
# is identical on every Restart & Run All.
# NOTE(review): assumes src.utils.shuffle draws from NumPy's global RNG; if it
# uses `random` or its own generator, seed that instead.
np.random.seed(42)
X, y, sets = shuffle(X, y, sets)

In [44]:
# 80/20 train/validation split; essay-set ids are split alongside X and y.
X_train, y_train, s_train, X_val, y_val, s_val = train_val_split(
    X, y, sets, train_prop=0.8
)
print(*(arr.shape for arr in (X_train, X_val, y_train, y_val)))

[1 6 3 ... 5 5 1]
(7096, 100) (1774, 100) (7096,) (1774,)


In [64]:
from src.mlp import MLP

# Mini-batch size for the MLP and the generator that feeds it.
batch_size = 128
# Input width = feature dimension after mean-pooling.
input_dim = X_train.shape[1]

batch_gen = get_batches(X_train, y_train, s_train, batch_size, net_type='mlp')

# Two hidden layers (256 -> 64) feeding a 12-way output.
# NOTE(review): l2_reg is passed together with reg=False — confirm in src/mlp.py
# whether the L2 penalty is applied at all when reg is False.
my_net = MLP(
    input_dim=input_dim,
    hidden_dims=[256, 64],
    num_classes=12,
    reg=False,
    l2_reg=1e-4,
)

In [65]:
# Train the MLP from the batch generator, reporting validation accuracy as it goes.
# NOTE: batch_gen is created in the previous cell; re-run that cell before
# re-running this one so training starts from a fresh generator and network.
my_net.train(
    gen=batch_gen,
    X_val=X_val,
    y_val=y_val,
    s_val=s_val,
    n_epochs=20,
    lr=1e-4,
)

loss for counter 100 is 2.1105144023895264
counter 100: valid acc = 0.15276211500167847
[2 2 8 3 3 2 3 3 3 3 3 3 3 8 3 8 2 3 2 8]
[2 1 6 2 3 3 3 1 3 1 2 1 2 7 2 8 2 3 1 8]
loss for counter 200 is 1.7985703945159912
counter 200: valid acc = 0.22604283690452576
[2 2 9 3 3 2 3 3 3 3 3 3 3 9 3 9 2 3 2 9]
[2 1 6 2 3 3 3 1 3 1 2 1 2 7 2 8 2 3 1 8]
loss for counter 300 is 1.867525339126587
counter 300: valid acc = 0.22604283690452576
[2 2 9 3 3 2 3 3 3 3 3 3 3 9 3 9 2 3 2 9]
[2 1 6 2 3 3 3 1 3 1 2 1 2 7 2 8 2 3 1 8]
loss for counter 400 is 1.7991520166397095
counter 400: valid acc = 0.22604283690452576
[2 2 9 3 3 2 3 3 3 3 3 3 3 9 3 9 2 3 2 9]
[2 1 6 2 3 3 3 1 3 1 2 1 2 7 2 8 2 3 1 8]
loss for counter 500 is 1.6400249004364014
counter 500: valid acc = 0.22604283690452576
[2 2 9 3 3 2 3 3 3 3 3 3 3 9 3 9 2 3 2 9]
[2 1 6 2 3 3 3 1 3 1 2 1 2 7 2 8 2 3 1 8]
loss for counter 600 is 1.6365081071853638
counter 600: valid acc = 0.22604283690452576
[2 2 9 3 3 2 3 3 3 3 3 3 3 9 3 9 2 3 2 9]
[2 1 6 2 3 

loss for counter 4900 is 1.8626999855041504
counter 4900: valid acc = 0.24239008128643036
[2 2 9 3 3 2 3 2 3 2 3 2 3 9 3 9 2 3 2 9]
[2 1 6 2 3 3 3 1 3 1 2 1 2 7 2 8 2 3 1 8]
loss for counter 5000 is 2.301772356033325
counter 5000: valid acc = 0.22491544485092163
[2 2 9 3 3 2 3 3 3 3 3 3 3 9 3 9 2 3 2 9]
[2 1 6 2 3 3 3 1 3 1 2 1 2 7 2 8 2 3 1 8]
loss for counter 5100 is 1.918776035308838
counter 5100: valid acc = 0.21815107762813568
[2 1 8 3 3 2 3 1 3 1 3 1 3 8 3 8 2 3 1 8]
[2 1 6 2 3 3 3 1 3 1 2 1 2 7 2 8 2 3 1 8]
loss for counter 5200 is 2.1200835704803467
counter 5200: valid acc = 0.22604283690452576
[2 2 9 3 3 2 3 3 3 3 3 3 3 9 3 9 2 3 2 9]
[2 1 6 2 3 3 3 1 3 1 2 1 2 7 2 8 2 3 1 8]
loss for counter 5300 is 2.159639835357666
counter 5300: valid acc = 0.23562569916248322
[2 2 9 3 3 2 3 3 3 3 3 3 3 9 3 9 2 3 2 9]
[2 1 6 2 3 3 3 1 3 1 2 1 2 7 2 8 2 3 1 8]
loss for counter 5400 is 2.034022569656372
counter 5400: valid acc = 0.23957158625125885
[2 2 6 2 2 2 2 2 2 2 2 2 2 6 2 9 2 2 2 9]
[2

loss for counter 9600 is 2.922919750213623
counter 9600: valid acc = 0.1538895219564438
[ 1  1  3  1  1  1  1  1  1  1  1  1  1  3  1 10  1  1  1 10]
[2 1 6 2 3 3 3 1 3 1 2 1 2 7 2 8 2 3 1 8]
loss for counter 9700 is 3.087693214416504
counter 9700: valid acc = 0.20518602430820465
[2 2 8 3 3 2 3 2 3 2 3 2 3 8 2 8 2 3 2 8]
[2 1 6 2 3 3 3 1 3 1 2 1 2 7 2 8 2 3 1 8]
loss for counter 9800 is 2.8724889755249023
counter 9800: valid acc = 0.2626832127571106
[ 2  2  9  3  3  2  3  2  3  2  3  2  3  9  2 10  2  3  2 10]
[2 1 6 2 3 3 3 1 3 1 2 1 2 7 2 8 2 3 1 8]
loss for counter 9900 is 2.894460439682007
counter 9900: valid acc = 0.20856820046901703
[2 1 6 1 1 2 2 1 2 1 1 1 1 6 1 6 1 2 1 6]
[2 1 6 2 3 3 3 1 3 1 2 1 2 7 2 8 2 3 1 8]
loss for counter 10000 is 4.312363147735596
counter 10000: valid acc = 0.3004509508609772
[ 2  1  9  3  3  2  3  1  3  1  3  1  3  9  3 10  2  3  1 10]
[2 1 6 2 3 3 3 1 3 1 2 1 2 7 2 8 2 3 1 8]
loss for counter 10100 is 3.151651382446289
counter 10100: valid acc = 0.15

loss for counter 14300 is 4.189611434936523
counter 14300: valid acc = 0.22491544485092163
[2 2 9 3 3 2 3 3 3 3 3 3 3 9 3 9 2 3 2 9]
[2 1 6 2 3 3 3 1 3 1 2 1 2 7 2 8 2 3 1 8]
loss for counter 14400 is 5.518614292144775
counter 14400: valid acc = 0.18320180475711823
[2 1 8 3 3 2 3 1 3 3 3 1 3 8 3 8 2 3 1 8]
[2 1 6 2 3 3 3 1 3 1 2 1 2 7 2 8 2 3 1 8]
loss for counter 14500 is 5.027713775634766
counter 14500: valid acc = 0.2057497203350067
[2 2 6 2 2 2 2 2 2 2 2 2 2 6 2 6 2 2 2 6]
[2 1 6 2 3 3 3 1 3 1 2 1 2 7 2 8 2 3 1 8]
loss for counter 14600 is 5.1901421546936035
counter 14600: valid acc = 0.284667432308197
[2 2 9 2 2 2 2 2 2 2 2 2 2 9 2 9 2 2 2 9]
[2 1 6 2 3 3 3 1 3 1 2 1 2 7 2 8 2 3 1 8]
loss for counter 14700 is 7.384910583496094
counter 14700: valid acc = 0.29312288761138916
[2 2 9 3 3 2 3 2 3 2 3 2 3 9 2 9 2 3 2 9]
[2 1 6 2 3 3 3 1 3 1 2 1 2 7 2 8 2 3 1 8]
loss for counter 14800 is 4.410954475402832
counter 14800: valid acc = 0.21533258259296417
[1 1 8 1 1 2 3 1 3 1 1 1 1 8 1 8 1 3

loss for counter 18900 is 7.197197914123535
counter 18900: valid acc = 0.28015783429145813
[1 1 9 1 1 2 3 1 3 1 1 1 1 9 1 9 1 3 1 9]
[2 1 6 2 3 3 3 1 3 1 2 1 2 7 2 8 2 3 1 8]
loss for counter 19000 is 7.355042457580566
counter 19000: valid acc = 0.14712513983249664
[1 1 3 1 1 1 1 1 1 1 1 1 1 3 1 8 1 1 1 8]
[2 1 6 2 3 3 3 1 3 1 2 1 2 7 2 8 2 3 1 8]
loss for counter 19100 is 10.464118957519531
counter 19100: valid acc = 0.22717024385929108
[2 2 9 3 3 2 3 3 3 3 3 3 3 9 3 9 2 3 2 9]
[2 1 6 2 3 3 3 1 3 1 2 1 2 7 2 8 2 3 1 8]
loss for counter 19200 is 6.745813369750977
counter 19200: valid acc = 0.20631341636180878
[2 2 6 2 2 2 2 2 2 2 2 2 2 6 2 6 2 2 2 6]
[2 1 6 2 3 3 3 1 3 1 2 1 2 7 2 8 2 3 1 8]
loss for counter 19300 is 9.172621726989746
counter 19300: valid acc = 0.23224352300167084
[ 2  2  6  2  2  2  2  2  2  2  2  2  2  6  2 10  2  2  2 10]
[2 1 6 2 3 3 3 1 3 1 2 1 2 7 2 8 2 3 1 8]
loss for counter 19400 is 7.102790832519531
counter 19400: valid acc = 0.2282976359128952
[2 2 9 3 3 2 3

## RNN model (GRU cells)

In [66]:
# `df` was deleted after the MLP features were extracted, so read the pickle again.
# NOTE: pd.read_pickle unpickles arbitrary objects — only load files you trust.
data_path = './data/essay_df.pkl'
df = pd.read_pickle(data_path)

In [67]:
# Fresh copies of the raw inputs for the RNN: X keeps the full word sequence
# (no mean-pooling), y is the `norm_score` column, sets the `essay_set` ids.
X, y, sets = (np.array(df[col]) for col in ('essays_embed', 'norm_score', 'essay_set'))

In [68]:
# Stack the per-essay matrices into one (n_essays, 600, 100) tensor; the RNN
# reads the whole word sequence, so no pooling is applied here.
X = np.stack(X, axis=0)
print(X.shape)
# Seed the RNG so this shuffle and the resulting split are reproducible.
# NOTE(review): assumes src.utils.shuffle draws from NumPy's global RNG — confirm.
np.random.seed(42)
X, y, sets = shuffle(X, y, sets)
X_train, y_train, s_train, X_val, y_val, s_val = train_val_split(X, y, sets, train_prop=0.8)
print(X_train.shape, X_val.shape, y_train.shape, y_val.shape )

(8870, 600, 100)
[3 3 6 ... 3 6 1]
(7096, 600, 100) (1774, 600, 100) (7096,) (1774,)


In [88]:
from src.lstm import RNN

# RNN hyper-parameters. seq_length and embed_size are read off X_train
# (shape (n_essays, seq_length, embed_size)) so the network matches the data.
batch_size = 32
# NOTE(review): the validation labels printed during training include the value 12,
# which is out of range if labels are 0-based class ids with 12 classes — confirm
# the label encoding in src/ and whether num_classes should be larger.
num_classes = 12
seq_length = X_train.shape[1]
embed_size = X_train.shape[2]
batch_gen = get_batches(X_train, y_train, s_train, batch_size, net_type='lstm')

# Pass the measured embedding width instead of a literal 100, so a change in the
# embedding size cannot silently desync the model from the data.
my_net = RNN(num_classes, batch_size, seq_length, embed_size=embed_size, cell_type='gru',
                 rnn_size=64, num_layers=2, learning_rate=0.005, train_keep_prob=0.5, sampling=False)

In [89]:
# Validate on the first batch_size samples only. The RNN is constructed with a
# fixed batch_size, so it presumably needs exactly that many validation rows —
# TODO confirm in src/lstm.py.
# Slice into new names so the full X_val / y_val stay available to later cells,
# and tie the size to batch_size instead of repeating the literal 32.
# NOTE: accuracy over one batch is coarse (steps of 1/batch_size).
X_val_batch = X_val[:batch_size]
y_val_batch = y_val[:batch_size]
my_net.train(batch_gen, X_val_batch, y_val_batch)

Initializing training
step: 1  loss: 2.0190  0.4909 sec/batch
step: 2  loss: 2.3222  0.4641 sec/batch
step: 3  loss: 2.1593  0.4651 sec/batch
step: 4  loss: 2.1203  0.4593 sec/batch
step: 5  loss: 2.2270  0.4709 sec/batch
step: 6  loss: 1.7147  0.4604 sec/batch
step: 7  loss: 1.8309  0.4565 sec/batch
step: 8  loss: 3.2342  0.4624 sec/batch
step: 9  loss: 3.0545  0.4599 sec/batch
step: 10  loss: 2.7494  0.4602 sec/batch
step: 10  validation accuracy 0.0 
[5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5]
[ 6  8  9  4  6  6  3  8  9  6  9  9  9  3  3  4  6  3  3  4 12  6 12  9
  6  6  9  8  9  6  9  6]
step: 11  loss: 3.0725  0.4759 sec/batch
step: 12  loss: 2.3494  0.4657 sec/batch
step: 13  loss: 2.6320  0.4639 sec/batch
step: 14  loss: 2.3067  0.4634 sec/batch
step: 15  loss: 2.3900  0.4587 sec/batch
step: 16  loss: 2.0849  0.4640 sec/batch
step: 17  loss: 2.1625  0.4664 sec/batch
step: 18  loss: 2.3277  0.4643 sec/batch
step: 19  loss: 2.1554  0.4596 sec/batch
step: 20

step: 136  loss: 1.9716  0.4585 sec/batch
step: 137  loss: 1.4088  0.4528 sec/batch
step: 138  loss: 1.7613  0.4535 sec/batch
step: 139  loss: 1.5522  0.4601 sec/batch
step: 140  loss: 1.6818  0.4586 sec/batch
step: 140  validation accuracy 0.28125 
[9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9]
[ 6  8  9  4  6  6  3  8  9  6  9  9  9  3  3  4  6  3  3  4 12  6 12  9
  6  6  9  8  9  6  9  6]
step: 141  loss: 1.6355  0.4559 sec/batch
step: 142  loss: 1.7065  0.4548 sec/batch
step: 143  loss: 1.6323  0.4538 sec/batch
step: 144  loss: 1.6927  0.4587 sec/batch
step: 145  loss: 1.4994  0.4579 sec/batch
step: 146  loss: 1.4771  0.4552 sec/batch
step: 147  loss: 1.6474  0.4576 sec/batch
step: 148  loss: 1.8811  0.4574 sec/batch
step: 149  loss: 1.4619  0.4588 sec/batch
step: 150  loss: 1.7792  0.4516 sec/batch
step: 150  validation accuracy 0.28125 
[9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9]
[ 6  8  9  4  6  6  3  8  9  6  9  9  9  3  3  4  6  3  3  

KeyboardInterrupt: 