This notebook serves as the master notebook, which the TAs can run from start to finish.

In [2]:
import pandas as pd
import numpy as np
from src.utils import get_batches, shuffle, train_val_split


%load_ext autoreload
%autoreload 2

In [3]:
# Path to the pickled essay DataFrame, relative to the notebook's working directory.
data_path = './data/essay_df.pkl'
# NOTE: unpickling can execute arbitrary code, so only load this file from a trusted source.
df = pd.read_pickle(data_path)

In [4]:
# Summarize the loaded frame: column names, non-null counts, dtypes and memory usage.
df.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 8764 entries, 0 to 10683
Data columns (total 8 columns):
essay_id         8764 non-null int64
essay_set        8764 non-null int64
essay            8764 non-null object
domain1_score    8764 non-null int64
essays_embed     8764 non-null object
word_count       8764 non-null int64
max_score        8764 non-null float64
norm_score       8764 non-null float64
dtypes: float64(2), int64(4), object(2)
memory usage: 616.2+ KB


In [5]:
# Pull the embedded essays (X), the `norm_score` targets (y) and the essay-set
# ids (sets) out of the frame, each as its own NumPy array.
X, y, sets = (
    np.array(df[column])
    for column in ('essays_embed', 'norm_score', 'essay_set')
)

In [6]:
# Join the per-essay embedding arrays along a new leading axis (np.stack's
# default, axis 0) and report the resulting dimensions.
X_stacked = np.stack(X)
print(np.shape(X_stacked))

(8764, 600, 100)


In [7]:
# Keep the first axis (one row per essay) and collapse all remaining axes into
# a single feature axis; -1 lets NumPy infer that length.
X_flat = X_stacked.reshape(len(X_stacked), -1)
print(X_flat.shape)

(8764, 60000)


In [128]:
# Alternative to flattening: average over axis 1 of X_stacked (the 600-long axis), leaving one 100-dimensional vector per essay
#X_flat = np.mean(X_stacked, axis = 1)

In [8]:
# Seed NumPy's global RNG before shuffling so the shuffled order (and hence
# the train/validation split built from it) is repeatable on Restart & Run All.
# NOTE(review): this only takes effect if src.utils.shuffle draws from
# np.random — confirm against its implementation.
np.random.seed(42)
# Shuffle features, targets and essay-set ids with a single call to shuffle.
X_shuffled, y_shuffled, sets_shuffled = shuffle(X_flat, y, sets)

In [9]:
# Sanity check: the shuffled feature matrix should have the same dimensions as X_flat.
print(X_shuffled.shape)

(8764, 60000)


In [10]:
# Split the shuffled features, targets and set ids into training and
# validation parts, with train_prop=0.8 requesting 80% for training.
(X_train, y_train, s_train,
 X_val, y_val, s_val) = train_val_split(
    X_shuffled, y_shuffled, sets_shuffled, train_prop=0.8
)
# Show the feature and target dimensions of both parts.
print(*(part.shape for part in (X_train, X_val, y_train, y_val)))

[5 6 5 ... 4 1 5]
(7012, 60000) (1752, 60000) (7012,) (1752,)


In [11]:
from src.mlp import MLP

# Mini-batch size for the batch generator, and the width of one flattened input row.
batch_size = 10
input_dim = X_train.shape[1]

# Generator over the training data; net_type='mlp' is passed through to get_batches.
batch_gen = get_batches(X_train, y_train, s_train, batch_size, net_type='mlp')

# Two hidden layers of 100 units each and 12 output classes.
# NOTE(review): reg=False is passed together with l2_reg=1e-4 — presumably the
# L2 penalty is switched off here; confirm against MLP's implementation.
my_net = MLP(
    input_dim=input_dim,
    hidden_dims=[100, 100],
    num_classes=12,
    reg=False,
    l2_reg=1e-4,
)

In [12]:
# Train the network from the batch generator, passing the validation split
# (features, targets, set ids) alongside; 20 epochs at a learning rate of 1e-4.
my_net.train(
    gen=batch_gen,
    X_val=X_val,
    y_val=y_val,
    s_val=s_val,
    n_epochs=20,
    lr=1e-4,
)

loss for counter 10 is 2.4922592639923096
counter 10: valid acc = 0.2083333283662796
[3 3 3 8 8 8 2 3 3 3 3 2 3 3 3 2 3 2 8 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 20 is 2.4893555641174316
counter 20: valid acc = 0.21289955079555511
[3 3 3 8 8 8 2 3 3 3 3 2 3 3 1 2 3 1 8 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 30 is 1.9858654737472534
counter 30: valid acc = 0.2083333283662796
[3 3 3 8 8 8 2 3 3 3 3 2 3 3 3 2 3 2 8 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 40 is 1.711927056312561
counter 40: valid acc = 0.1997716873884201
[1 1 1 8 8 4 1 1 1 1 1 1 1 1 1 1 1 1 8 1]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 50 is 1.880250096321106
counter 50: valid acc = 0.17922374606132507
[1 1 1 8 8 4 1 1 1 1 1 1 1 1 1 1 1 1 8 1]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 60 is 1.9276753664016724
counter 60: valid acc = 0.2083333283662796
[3 3 3 8 8 8 2 3 3 3 3 2 3 3 3 2 3 2 8 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 70 is 1.5337598323822021
counter 70: valid acc = 0.2083333283662796
[3 3 3 8 8 8 2 3 3 3 3 2 3 3 3 2 3 2 8 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 80 is 1.6081604957580566
counter 80: valid acc = 0.2083333283662796
[3 3 3 8 8 8 2 3 3 3 3 2 3 3 3 2 3 2 8 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 90 is 1.6711384057998657
counter 90: valid acc = 0.2083333283662796
[3 3 3 8 8 8 2 3 3 3 3 2 3 3 3 2 3 2 8 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 100 is 1.5579503774642944
counter 100: valid acc = 0.2083333283662796
[3 3 3 8 8 8 2 3 3 3 3 2 3 3 3 2 3 2 8 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 110 is 1.5981481075286865
counter 110: valid acc = 0.2083333283662796
[3 3 3 8 8 8 2 3 3 3 3 2 3 3 3 2 3 2 8 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 120 is 1.9655112028121948
counter 120: valid acc = 0.2614155113697052
[3 3 3 9 9 8 2 3 3 3 3 2 3 3 3 2 3 2 9 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 130 is 2.1824564933776855
counter 130: valid acc = 0.20890410244464874
[3 3 3 8 8 8 2 3 3 3 3 2 3 3 3 2 3 2 8 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 140 is 1.7378581762313843
counter 140: valid acc = 0.2083333283662796
[3 3 3 8 8 8 2 3 3 3 3 2 3 3 3 2 3 2 8 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 150 is 1.4493087530136108
counter 150: valid acc = 0.2083333283662796
[3 3 3 8 8 8 2 3 3 3 3 2 3 3 3 2 3 2 8 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 160 is 1.998924970626831
counter 160: valid acc = 0.20890410244464874
[3 3 3 8 8 8 2 3 3 3 3 2 3 3 3 2 3 2 8 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 170 is 1.5681707859039307
counter 170: valid acc = 0.20947489142417908
[3 3 3 8 8 8 2 3 3 3 3 2 3 3 3 2 3 2 8 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 180 is 2.1195425987243652
counter 180: valid acc = 0.20890410244464874
[3 3 3 8 8 8 2 3 3 3 3 2 3 3 3 2 3 2 8 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 190 is 1.4394320249557495
counter 190: valid acc = 0.20947489142417908
[3 3 3 8 8 8 2 3 3 3 3 2 3 3 3 2 3 2 8 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 200 is 1.9869011640548706
counter 200: valid acc = 0.2083333283662796
[3 3 3 8 8 8 2 3 3 3 3 2 3 3 3 2 3 2 8 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 210 is 1.7892197370529175
counter 210: valid acc = 0.20890410244464874
[3 3 3 8 8 8 2 3 3 3 3 2 3 3 3 2 3 2 8 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 220 is 1.8454251289367676
counter 220: valid acc = 0.2083333283662796
[3 3 3 8 8 8 2 3 3 3 3 2 3 3 3 2 3 2 8 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 230 is 1.6450084447860718
counter 230: valid acc = 0.2083333283662796
[3 3 3 8 8 8 2 3 3 3 3 2 3 3 3 2 3 2 8 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 240 is 1.585709810256958
counter 240: valid acc = 0.2083333283662796
[3 3 3 8 8 8 2 3 3 3 3 2 3 3 3 2 3 2 8 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 250 is 1.6652820110321045
counter 250: valid acc = 0.2083333283662796
[3 3 3 8 8 8 2 3 3 3 3 2 3 3 3 2 3 2 8 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 260 is 1.288562297821045
counter 260: valid acc = 0.2083333283662796
[3 3 3 8 8 8 2 3 3 3 3 2 3 3 3 2 3 2 8 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 270 is 1.5014675855636597
counter 270: valid acc = 0.21004566550254822
[3 3 3 8 8 8 2 3 3 3 3 2 3 3 3 2 3 2 8 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 280 is 1.2381091117858887
counter 280: valid acc = 0.2083333283662796
[3 3 3 8 8 8 2 3 3 3 3 2 3 3 3 2 3 2 8 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 290 is 1.671948790550232
counter 290: valid acc = 0.2083333283662796
[3 3 3 8 8 8 2 3 3 3 3 2 3 3 3 2 3 2 8 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 300 is 1.5001879930496216
counter 300: valid acc = 0.20947489142417908
[3 3 3 8 8 8 2 3 3 3 3 2 3 3 3 2 3 2 8 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 310 is 1.7944539785385132
counter 310: valid acc = 0.3173516094684601
[3 1 1 9 9 8 2 3 3 3 1 2 3 3 1 1 1 1 9 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 320 is 1.5396727323532104
counter 320: valid acc = 0.3076483905315399
[3 1 1 8 9 8 2 3 3 3 1 1 1 1 1 1 1 1 8 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 330 is 1.2894526720046997
counter 330: valid acc = 0.28367578983306885
[3 1 1 8 8 8 2 3 3 3 1 1 1 1 1 1 1 1 8 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 340 is 1.5152039527893066
counter 340: valid acc = 0.2859589159488678
[3 1 1 8 8 8 2 3 3 3 1 2 3 3 1 1 1 1 8 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 350 is 0.9962880611419678
counter 350: valid acc = 0.2699771821498871
[3 3 1 8 8 8 2 3 3 3 1 2 3 3 1 1 3 1 8 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 360 is 1.452153205871582
counter 360: valid acc = 0.32363012433052063
[3 1 1 9 9 8 2 3 3 3 1 2 3 3 1 1 3 1 9 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 370 is 1.363992691040039
counter 370: valid acc = 0.33961185812950134
[3 1 1 9 9 8 2 3 3 3 1 2 3 3 1 1 1 1 9 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 380 is 1.4899107217788696
counter 380: valid acc = 0.33390411734580994
[3 1 1 9 9 8 2 3 3 3 1 1 3 1 1 1 1 1 9 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 390 is 1.7613537311553955
counter 390: valid acc = 0.30821916460990906
[3 1 1 9 9 8 2 3 3 3 1 1 1 1 1 1 1 1 9 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 400 is 1.3467143774032593
counter 400: valid acc = 0.327625572681427
[3 1 1 9 9 8 2 3 3 3 1 1 1 1 1 1 1 1 9 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 410 is 1.190451979637146
counter 410: valid acc = 0.33675798773765564
[3 1 1 9 9 8 2 3 3 3 1 1 1 1 1 1 1 1 9 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 420 is 1.4907236099243164
counter 420: valid acc = 0.33390411734580994
[3 1 1 9 9 8 2 3 3 3 1 2 3 3 1 1 1 1 9 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 430 is 1.3996771574020386
counter 430: valid acc = 0.31221461296081543
[3 1 1 8 9 8 2 3 3 3 1 1 3 1 1 1 1 1 8 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 440 is 1.0973596572875977
counter 440: valid acc = 0.327625572681427
[3 1 1 8 9 8 2 3 3 3 1 1 3 1 1 1 1 1 8 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 450 is 0.7682636380195618
counter 450: valid acc = 0.3361872136592865
[3 1 1 9 9 8 2 3 3 3 1 1 3 1 1 1 1 1 9 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 460 is 1.4762300252914429
counter 460: valid acc = 0.3321917951107025
[3 1 1 9 9 8 2 3 3 3 1 1 1 1 1 1 1 1 9 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 470 is 0.8406582474708557
counter 470: valid acc = 0.3173516094684601
[3 1 1 9 9 8 2 3 3 3 1 1 3 1 1 1 1 1 9 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 480 is 1.2321497201919556
counter 480: valid acc = 0.3321917951107025
[3 1 1 9 9 8 2 3 3 3 1 1 3 1 1 1 1 1 9 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 490 is 1.2480964660644531
counter 490: valid acc = 0.3390410840511322
[3 1 1 9 9 8 2 3 3 3 1 1 1 1 1 1 1 1 9 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 500 is 1.2429252862930298
counter 500: valid acc = 0.33675798773765564
[3 1 1 9 9 8 2 3 3 3 1 1 3 1 1 1 1 1 9 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 510 is 2.0172605514526367
counter 510: valid acc = 0.32534247636795044
[3 1 1 9 9 6 2 3 3 3 1 1 1 1 1 1 1 1 9 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 520 is 1.0887621641159058
counter 520: valid acc = 0.31963470578193665
[3 1 1 9 9 6 2 2 3 2 1 1 2 1 1 1 1 1 9 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 530 is 1.1476936340332031
counter 530: valid acc = 0.2956621050834656
[3 1 1 8 9 8 2 3 3 3 1 1 3 1 1 1 1 1 8 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 540 is 1.3759324550628662
counter 540: valid acc = 0.3133561611175537
[3 1 1 8 9 8 2 3 3 3 1 1 3 1 1 1 1 1 8 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 550 is 1.4562139511108398
counter 550: valid acc = 0.3430365324020386
[3 1 1 9 9 8 2 3 3 3 1 2 2 2 1 1 1 1 9 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 560 is 1.3627638816833496
counter 560: valid acc = 0.2871004641056061
[3 1 1 8 9 8 2 3 3 3 1 1 1 1 1 1 1 1 8 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


loss for counter 570 is 1.1372416019439697
counter 570: valid acc = 0.31107306480407715
[3 1 1 9 9 8 2 3 3 3 1 1 1 1 1 1 1 1 9 3]
[4 1 1 8 8 5 1 2 3 3 1 2 2 3 1 1 0 1 8 3]


KeyboardInterrupt: 