<a href="https://colab.research.google.com/github/Rt247/Not_NLP_CW/blob/sentence-level-word-embeddings/sentence_level_word_embeddings_with_ff.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Setup

Download datasets:

In [1]:
from os.path import exists

if not exists('enzh_data.zip'):
    !wget -O enzh_data.zip https://competitions.codalab.org/my/datasets/download/03e23bd7-8084-4542-997b-6a1ca6dd8a5f
    !unzip enzh_data.zip

--2020-02-25 19:42:12--  https://competitions.codalab.org/my/datasets/download/03e23bd7-8084-4542-997b-6a1ca6dd8a5f
Resolving competitions.codalab.org (competitions.codalab.org)... 129.175.22.230
Connecting to competitions.codalab.org (competitions.codalab.org)|129.175.22.230|:443... connected.
HTTP request sent, awaiting response... 302 FOUND
Location: https://newcodalab.lri.fr/prod-private/dataset_data_file/None/630ec/en-zh.zip?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Expires=86400&X-Amz-SignedHeaders=host&X-Amz-Signature=d08c8c6d890c740ad0656172d1a2e0fc713a2e187ed964709a91a20e20d92c96&X-Amz-Date=20200225T194217Z&X-Amz-Credential=AZIAIOSAODNN7EX123LE%2F20200225%2Fnewcodalab%2Fs3%2Faws4_request [following]
--2020-02-25 19:42:18--  https://newcodalab.lri.fr/prod-private/dataset_data_file/None/630ec/en-zh.zip?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Expires=86400&X-Amz-SignedHeaders=host&X-Amz-Signature=d08c8c6d890c740ad0656172d1a2e0fc713a2e187ed964709a91a20e20d92c96&X-Amz-Date=20200225T194

Check data downloaded successfully:

In [2]:
with open("./train.enzh.src", "r") as enzh_src:
  print("Source: ",enzh_src.readline())
with open("./train.enzh.mt", "r") as enzh_mt:
  print("Translation: ",enzh_mt.readline())
with open("./train.enzh.scores", "r") as enzh_scores:
  print("Score: ",enzh_scores.readline())

Source:  The last conquistador then rides on with his sword drawn.

Translation:  最后的征服者骑着他的剑继续前进.

Score:  -1.5284005772625449



### English Models Setup

Download English models:

In [3]:
!spacy download en_core_web_md
!spacy link en_core_web_md en300

Collecting en_core_web_md==2.1.0
[?25l  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_md-2.1.0/en_core_web_md-2.1.0.tar.gz (95.4MB)
[K     |████████████████████████████████| 95.4MB 5.0MB/s 
[?25hBuilding wheels for collected packages: en-core-web-md
  Building wheel for en-core-web-md (setup.py) ... [?25l[?25hdone
  Created wheel for en-core-web-md: filename=en_core_web_md-2.1.0-cp36-none-any.whl size=97126236 sha256=b9271e98eb9cb4e3cf2c32d0a6fc22ecc5f6fbb169bf0b728be2e4b85e8217c8
  Stored in directory: /tmp/pip-ephem-wheel-cache-9c2rqkog/wheels/c1/2c/5f/fd7f3ec336bf97b0809c86264d2831c5dfb00fc2e239d1bb01
Successfully built en-core-web-md
Installing collected packages: en-core-web-md
Successfully installed en-core-web-md-2.1.0
[38;5;2m✔ Download and installation successful[0m
You can now load the model via spacy.load('en_core_web_md')
[38;5;2m✔ Linking successful[0m
/usr/local/lib/python3.6/dist-packages/en_core_web_md -->
/usr/local/lib/py

Load a GloVe English model with dim 100.

Some Chinese models only have **dim 100**, so we will need to **tokenize with spaCy, then embed with GloVe**.

In [4]:
import torchtext
import spacy

# Embedding for English when dim 100
glove = torchtext.vocab.GloVe(name='6B', dim=100)

# Tokenizer for English when dim 100, Tokenizer and Embedding when dim 300
nlp_en = spacy.load('en300')


.vector_cache/glove.6B.zip: 862MB [06:39, 2.16MB/s]                           
 99%|█████████▉| 397721/400000 [00:31<00:00, 23835.67it/s]

Functions for processing English dataset:

In [5]:
import numpy as np
import torch
from nltk import download
from nltk.corpus import stopwords

#downloading stopwords from the nltk package
download('stopwords') #stopwords dictionary, run once
stop_words_en = set(stopwords.words('english'))


def preprocess_en(sentence, nlp):
    text = sentence.lower()
    doc = [token.lemma_ for token in  nlp.tokenizer(text)]
    doc = [word for word in doc if word not in stop_words_en]
    doc = [word for word in doc if word.isalpha()] #restricts string to alphabetic characters only
    return doc

def get_word_vector_en(embeddings, word):
    try:
      vec = embeddings.vectors[embeddings.stoi[word]]
      return vec
    except KeyError:
      #print(f"Word {word} does not exist")
      pass
      

def get_sentence_emb_en(line, nlp):
  text = line.lower()
  l = [token.lemma_ for token in nlp.tokenizer(text)]
  l = ' '.join([word for word in l if word not in stop_words_en])

  sen = nlp(l)
  return sen.vector


[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.


### Chinese Models Setup

Download Chinese stopwords:

In [6]:
!wget -c https://github.com/Tony607/Chinese_sentiment_analysis/blob/master/data/chinese_stop_words.txt

--2020-02-25 19:50:39--  https://github.com/Tony607/Chinese_sentiment_analysis/blob/master/data/chinese_stop_words.txt
Resolving github.com (github.com)... 192.30.255.113
Connecting to github.com (github.com)|192.30.255.113|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: unspecified [text/html]
Saving to: ‘chinese_stop_words.txt’

chinese_stop_words.     [ <=>                ] 417.14K  --.-KB/s    in 0.1s    

2020-02-25 19:50:40 (3.21 MB/s) - ‘chinese_stop_words.txt’ saved [427150]



Download and load Chinese model with **dim 100** (University of Oslo):

In [7]:
if not exists('zh_100.zip'):
  !wget -O zh_100.zip http://vectors.nlpl.eu/repository/20/35.zip
  !unzip zh_100.zip -d ./zh_100

from gensim.models import KeyedVectors

wv_from_bin_100 = KeyedVectors.load_word2vec_format("./zh_100/model.bin", binary=True) 

--2020-02-25 19:50:41--  http://vectors.nlpl.eu/repository/20/35.zip
Resolving vectors.nlpl.eu (vectors.nlpl.eu)... 129.240.189.225
Connecting to vectors.nlpl.eu (vectors.nlpl.eu)|129.240.189.225|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1458485917 (1.4G) [application/zip]
Saving to: ‘zh_100.zip’


2020-02-25 19:52:10 (15.9 MB/s) - ‘zh_100.zip’ saved [1458485917/1458485917]

Archive:  zh_100.zip
  inflating: ./zh_100/LIST           
  inflating: ./zh_100/meta.json      
  inflating: ./zh_100/model.bin      
  inflating: ./zh_100/model.txt      
  inflating: ./zh_100/README         


  'See the migration notes for details: %s' % _MIGRATION_NOTES_URL


Functions for processing Chinese dataset:

In [0]:
import string
import jieba
import gensim 
import spacy
import numpy as np

stop_words = [ line.rstrip() for line in open('./chinese_stop_words.txt',"r", encoding="utf-8") ]

def preprocess_zh(sentence):
  seg_list = jieba.lcut(sentence,cut_all=True)
  doc = [word for word in seg_list if word not in stop_words]
  docs = [e for e in doc if e.isalpha()]
  return docs


## Process Scores

In [0]:
import spacy
import torchtext
from torchtext import data

f_train_scores = open("./train.enzh.scores", 'r')
zh_train_scores = f_train_scores.readlines()

f_val_scores = open("./dev.enzh.scores", 'r')
zh_val_scores = f_val_scores.readlines()

train_scores = np.array(zh_train_scores).astype(float)
y_train_zh = train_scores

val_scores = np.array(zh_val_scores).astype(float)
y_val_zh = val_scores

## Word Embedding Variants

Calculating various values as an aggregation of a sentence's word embeddings.

In [0]:
def get_embeddings_zh(f, word_vectors):
  file = open(f)
  lines = file.readlines()
  embeddings = []
  for l in lines:
    sent = preprocess_zh(l)
    embeddings_sent = []
    for w in sent:
      try:
        emb = word_vectors[w]
        embeddings_sent.append(emb)
      except:
        pass
    if not embeddings_sent:
      embeddings_sent = [[0] * 100]
    embeddings.append(np.array(embeddings_sent))
  return embeddings

def get_embeddings_en(f, word_vectors, nlp):
  file = open(f) 
  lines = file.readlines() 
  embeddings = []
  for l in lines:
    sent = preprocess_en(l, nlp)
    embeddings_sent = []
    for w in sent:
      emb = get_word_vector_en(word_vectors, w)
      if emb is not None:
        embeddings_sent.append(emb.numpy())
    if not embeddings_sent:
      embeddings_sent = [[0] * 100]
    embeddings.append(np.array(embeddings_sent))
  return embeddings


In [0]:
zh_train_mt_100_emb = get_embeddings_zh("./train.enzh.mt", wv_from_bin_100)
zh_train_src_100_emb = get_embeddings_en("./train.enzh.src", glove, nlp_en)

zh_val_mt_100_emb = get_embeddings_zh("./dev.enzh.mt", wv_from_bin_100)
zh_val_src_100_emb = get_embeddings_en("./dev.enzh.src", glove, nlp_en)

zh_test_mt_100_emb = get_embeddings_zh("./test.enzh.mt", wv_from_bin_100)
zh_test_src_100_emb = get_embeddings_en("./test.enzh.src", glove, nlp_en)

### Average Word Embedding

In [0]:
def average_word_embeddings(src, mt):
  src_m = np.array([np.mean(e, axis=0) for e in src])
  mt_m = np.array([np.mean(e, axis=0) for e in mt])
  return np.concatenate((src_m, mt_m), axis=1)


In [0]:
X_train_zh_100_a = average_word_embeddings(zh_train_src_100_emb, zh_train_mt_100_emb)
X_val_zh_100_a = average_word_embeddings(zh_val_src_100_emb, zh_val_mt_100_emb)


### Sum Word Embedding

In [0]:
def sum_word_embeddings(src, mt):
  src_m = np.array([np.sum(e, axis=0) for e in src])
  mt_m = np.array([np.sum(e, axis=0) for e in mt])
  return np.concatenate((src_m, mt_m), axis=1)

In [0]:
X_train_zh_100_s = sum_word_embeddings(zh_train_src_100_emb, zh_train_mt_100_emb)
X_val_zh_100_s = sum_word_embeddings(zh_val_src_100_emb, zh_val_mt_100_emb)

### Min/Max Word Embedding

In [0]:
def min_word_embeddings(src, mt):
  src_m = np.array([np.amin(e, axis=0) for e in src])
  mt_m = np.array([np.amin(e, axis=0) for e in mt])
  return np.concatenate((src_m, mt_m), axis=1)

def max_word_embeddings(src, mt):
  src_m = np.array([np.amax(e, axis=0) for e in src])
  mt_m = np.array([np.amax(e, axis=0) for e in mt])
  return np.concatenate((src_m, mt_m), axis=1)

In [0]:
X_train_zh_100_min = min_word_embeddings(zh_train_src_100_emb, zh_train_mt_100_emb)
X_val_zh_100_min = min_word_embeddings(zh_val_src_100_emb, zh_val_mt_100_emb)

X_train_zh_100_max = max_word_embeddings(zh_train_src_100_emb, zh_train_mt_100_emb)
X_val_zh_100_max = max_word_embeddings(zh_val_src_100_emb, zh_val_mt_100_emb)

## Support Vector Regression

In [0]:
# Setup
def rmse(predictions, targets):
    return np.sqrt(((predictions - targets) ** 2).mean())

from sklearn.svm import SVR
from scipy.stats.stats import pearsonr


### Average

In [136]:
for k in ['linear','poly','rbf','sigmoid']:
    clf_t = SVR(kernel=k)
    clf_t.fit(X_train_zh_100_a, y_train_zh)
    print(k)
    predictions = clf_t.predict(X_val_zh_100_a)
    pearson = pearsonr(y_val_zh, predictions)
    print(f'RMSE: {rmse(predictions,y_val_zh)} Pearson {pearson[0]}')
    print()

"""
linear
RMSE: 0.9023060200743632 Pearson 0.3072337367525771

poly
RMSE: 0.8993484622463825 Pearson 0.3044765902787262

rbf
RMSE: 0.8895045629309669 Pearson 0.3424205188758487

sigmoid
RMSE: 7.04363364235969 Pearson -0.03875077289090174
"""

linear
RMSE: 0.9023060200743632 Pearson 0.3072337367525771

poly
RMSE: 0.8993484622463825 Pearson 0.3044765902787262

rbf
RMSE: 0.8895045629309669 Pearson 0.3424205188758487

sigmoid
RMSE: 7.04363364235969 Pearson -0.03875077289090174



'\nlinear\nRMSE: 0.9023060200743632 Pearson 0.3072337367525771\n\npoly\nRMSE: 0.8993484622463825 Pearson 0.3044765902787262\n\nrbf\nRMSE: 0.8895045629309669 Pearson 0.3424205188758487\n\nsigmoid\nRMSE: 7.04363364235969 Pearson -0.03875077289090174\n'

### Sum



In [139]:
for k in ['linear','poly','rbf','sigmoid']:
    clf_t = SVR(kernel=k)
    clf_t.fit(X_train_zh_100_s, y_train_zh)
    print(k)
    predictions = clf_t.predict(X_val_zh_100_s)
    pearson = pearsonr(y_val_zh, predictions)
    print(f'RMSE: {rmse(predictions,y_val_zh)} Pearson {pearson[0]}')
    print()

"""
linear
RMSE: 0.9227835782009712 Pearson 0.2425756512117508

poly
RMSE: 0.9516944090249123 Pearson 0.1856020876506383

rbf
RMSE: 0.9054660793535703 Pearson 0.29358229219039883

sigmoid
RMSE: 33.540682964794215 Pearson -0.013899488606520101
"""

linear
RMSE: 0.9227835782009712 Pearson 0.2425756512117508

poly
RMSE: 0.9516944090249123 Pearson 0.1856020876506383

rbf
RMSE: 0.9054660793535703 Pearson 0.29358229219039883

sigmoid
RMSE: 33.540682964794215 Pearson -0.013899488606520101



'\nlinear\nRMSE: 0.9227835782009712 Pearson 0.2425756512117508\n\npoly\nRMSE: 0.9516944090249123 Pearson 0.1856020876506383\n\nrbf\nRMSE: 0.9054660793535703 Pearson 0.29358229219039883\n\nsigmoid\nRMSE: 33.540682964794215 Pearson -0.013899488606520101\n'

### Min/Max



In [0]:
print("min")
for k in ['linear','poly','rbf','sigmoid']:
    clf_t = SVR(kernel=k)
    clf_t.fit(X_train_zh_100_min, y_train_zh)
    print(k)
    predictions = clf_t.predict(X_val_zh_100_min)
    pearson = pearsonr(y_val_zh, predictions)
    print(f'RMSE: {rmse(predictions,y_val_zh)} Pearson {pearson[0]}')
    print()

print("max")
for k in ['linear','poly','rbf','sigmoid']:
    clf_t = SVR(kernel=k)
    clf_t.fit(X_train_zh_100_max, y_train_zh)
    print(k)
    predictions = clf_t.predict(X_val_zh_100_max)
    pearson = pearsonr(y_val_zh, predictions)
    print(f'RMSE: {rmse(predictions,y_val_zh)} Pearson {pearson[0]}')
    print()

"""
min
linear
RMSE: 0.9143167236911504 Pearson 0.27536526116188104

poly
RMSE: 0.970271685744294 Pearson 0.23242310658976342

rbf
RMSE: 0.9025641754217032 Pearson 0.2973480407106239

sigmoid
RMSE: 28.280921526236867 Pearson 0.009983214811094767

max
linear
RMSE: 0.9284620494004884 Pearson 0.2359812007567792

poly
RMSE: 1.0235373606067586 Pearson 0.17306578845193377

rbf
RMSE: 0.9079366185041642 Pearson 0.27905090847445657

sigmoid
RMSE: 28.406444719807993 Pearson -0.004419994996657859
"""


min
linear
RMSE: 0.9143167236911504 Pearson 0.27536526116188104

poly
RMSE: 0.970271685744294 Pearson 0.23242310658976342

rbf
RMSE: 0.9025641754217032 Pearson 0.2973480407106239

sigmoid
RMSE: 28.280921526236867 Pearson 0.009983214811094767

max
linear
RMSE: 0.9284620494004884 Pearson 0.2359812007567792



### Combinations

In [0]:
# min + max
X_train_100_mm = [sum(t, []) for t in zip(zh_train_src_100_min, zh_train_src_100_max, zh_train_mt_100_min, zh_train_mt_100_max)]
X_train_zh_100_mm = np.array(X_train_100_mm)
X_val_100_mm = [sum(t, []) for t in zip(zh_val_src_100_min, zh_val_src_100_max, zh_val_mt_100_min, zh_val_mt_100_max)]
X_val_zh_100_mm = np.array(X_val_100_mm)

# min + avg + max
X_train_100_mam = [sum(t, []) for t in zip(zh_train_src_100_min, zh_train_src_100_a, zh_train_src_100_max, zh_train_mt_100_min, zh_train_mt_100_a, zh_train_mt_100_max)]
X_train_zh_100_mam = np.array(X_train_100_mam)
X_val_100_mam = [sum(t, []) for t in zip(zh_val_src_100_min, zh_val_src_100_a, zh_val_src_100_max, zh_val_mt_100_min, zh_val_mt_100_a, zh_val_mt_100_max)]
X_val_zh_100_mam = np.array(X_val_100_mam)

# avg + sum
X_train_100_as = [sum(t, []) for t in zip(zh_train_src_100_a, zh_train_src_100_s, zh_train_mt_100_a, zh_train_mt_100_s)]
X_train_zh_100_as = np.array(X_train_100_mam)
X_val_100_as = [sum(t, []) for t in zip(zh_val_src_100_a, zh_val_src_100_s, zh_val_mt_100_a, zh_val_mt_100_s)]
X_val_zh_100_as = np.array(X_val_100_mam)

# min + avg + max + sum
X_train_100_mams = [sum(t, []) for t in zip(zh_train_src_100_min, zh_train_src_100_a, zh_train_src_100_max, zh_train_src_100_s, zh_train_mt_100_min, zh_train_mt_100_a, zh_train_mt_100_max, zh_train_src_100_s)]
X_train_zh_100_mams = np.array(X_train_100_mam)
X_val_100_mams = [sum(t, []) for t in zip(zh_val_src_100_min, zh_val_src_100_a, zh_val_src_100_max, zh_val_src_100_s, zh_val_mt_100_min, zh_val_mt_100_a, zh_val_mt_100_max, zh_val_mt_100_s)]
X_val_zh_100_mams = np.array(X_val_100_mam)

In [0]:
print("min + max")
for k in ['linear','poly','rbf','sigmoid']:
    clf_t = SVR(kernel=k)
    clf_t.fit(X_train_zh_100_mm, y_train_zh)
    print(k)
    predictions = clf_t.predict(X_val_zh_100_mm)
    pearson = pearsonr(y_val_zh, predictions)
    print(f'RMSE: {rmse(predictions,y_val_zh)} Pearson {pearson[0]}')
    print()

print("min + avg + max")
for k in ['linear','poly','rbf','sigmoid']:
    clf_t = SVR(kernel=k)
    clf_t.fit(X_train_zh_100_mam, y_train_zh)
    print(k)
    predictions = clf_t.predict(X_val_zh_100_mam)
    pearson = pearsonr(y_val_zh, predictions)
    print(f'RMSE: {rmse(predictions,y_val_zh)} Pearson {pearson[0]}')
    print()

print("avg + sum")
for k in ['linear','poly','rbf','sigmoid']:
    clf_t = SVR(kernel=k)
    clf_t.fit(X_train_zh_100_as, y_train_zh)
    print(k)
    predictions = clf_t.predict(X_val_zh_100_as)
    pearson = pearsonr(y_val_zh, predictions)
    print(f'RMSE: {rmse(predictions,y_val_zh)} Pearson {pearson[0]}')
    print()

print("min + avg + max + sum")
for k in ['linear','poly','rbf','sigmoid']:
    clf_t = SVR(kernel=k)
    clf_t.fit(X_train_zh_100_mams, y_train_zh)
    print(k)
    predictions = clf_t.predict(X_val_zh_100_mams)
    pearson = pearsonr(y_val_zh, predictions)
    print(f'RMSE: {rmse(predictions,y_val_zh)} Pearson {pearson[0]}')
    print()

"""
min + max
linear
RMSE: 0.9141733794206062 Pearson 0.27337556488854026

poly
RMSE: 0.9063297287258996 Pearson 0.31037822893487643

rbf
RMSE: 0.9050521266690811 Pearson 0.3130501459492051

sigmoid
RMSE: 3.584827555910887 Pearson -0.015479223808303745

min + avg + max
linear
RMSE: 0.9081466259557258 Pearson 0.28795604452909246

poly
RMSE: 0.8962147090765858 Pearson 0.334701838648443

rbf
RMSE: 0.8964166238974365 Pearson 0.3355584468671853

sigmoid
RMSE: 2.647039384739339 Pearson 0.006553679583206628


"""

min + max
linear
RMSE: 0.9141733794206062 Pearson 0.27337556488854026

poly
RMSE: 0.9063297287258996 Pearson 0.31037822893487643

rbf
RMSE: 0.9050521266690811 Pearson 0.3130501459492051

sigmoid
RMSE: 3.584827555910887 Pearson -0.015479223808303745

min + avg + max
linear
RMSE: 0.9081466259557258 Pearson 0.28795604452909246

poly
RMSE: 0.8962147090765858 Pearson 0.334701838648443

rbf
RMSE: 0.8964166238974365 Pearson 0.3355584468671853

sigmoid
RMSE: 2.647039384739339 Pearson 0.006553679583206628

avg + sum
linear
RMSE: 0.9081466259557258 Pearson 0.28795604452909246

poly
RMSE: 0.8962147090765858 Pearson 0.334701838648443

rbf
RMSE: 0.8964166238974365 Pearson 0.3355584468671853

sigmoid
RMSE: 2.647039384739339 Pearson 0.006553679583206628

min + avg + max + sum
linear
RMSE: 0.9081466259557258 Pearson 0.28795604452909246

poly
RMSE: 0.8962147090765858 Pearson 0.334701838648443

rbf
RMSE: 0.8964166238974365 Pearson 0.3355584468671853

sigmoid
RMSE: 2.647039384739339 Pearson 0.00655367958

'\nmin + max\nlinear\nRMSE: 0.9165262212864648 Pearson 0.26668369681069803\n\npoly\nRMSE: 0.9081786905379852 Pearson 0.3058041114768354\n\nrbf\nRMSE: 0.9084429289156349 Pearson 0.30464644770878846\n\nsigmoid\nRMSE: 3.5248833190245237 Pearson -0.006792767476776504\n\nmin + avg + max\nlinear\nRMSE: 0.9081244309157551 Pearson 0.28851234087420935\n\npoly\nRMSE: 0.8982943971746072 Pearson 0.3296975280939544\n\nrbf\nRMSE: 0.8992479975159785 Pearson 0.3292233306043897\n\nsigmoid\nRMSE: 2.600000956750181 Pearson 0.014142349847365135\n\navg + sum\nlinear\nRMSE: 0.9081244309157551 Pearson 0.28851234087420935\n\npoly\nRMSE: 0.8982943971746072 Pearson 0.3296975280939544\n\nrbf\nRMSE: 0.8992479975159785 Pearson 0.3292233306043897\n\nsigmoid\nRMSE: 2.600000956750181 Pearson 0.014142349847365135\n\nmin + avg + max + sum\nlinear\nRMSE: 0.9081244309157551 Pearson 0.28851234087420935\n\npoly\nRMSE: 0.8982943971746072 Pearson 0.3296975280939544\n\nrbf\nRMSE: 0.8992479975159785 Pearson 0.3292233306043897\

## FFNN

### Setup Environment

In [0]:
import torch
from torch import nn
import time
import math

###############
# Torch setup #
###############
print('Torch version: {}, CUDA: {}'.format(torch.__version__, torch.version.cuda))
cuda_available = torch.cuda.is_available()
if not torch.cuda.is_available():
  print('WARNING: You may want to change the runtime to GPU for Neural LM experiments!')
  DEVICE = 'cpu'
else:
  DEVICE = 'cuda:0'


Torch version: 1.4.0, CUDA: 10.1


### Feed Forward Neural Network

In [0]:
import torch.nn.functional as F
import torch.utils.data as Data
from torch.autograd import Variable
from sklearn.metrics import mean_squared_error
from scipy.stats.stats import pearsonr


def ffln(train, valid, hidden_sizes=[64], batch_size=64, epochs=100, weight_decay=0, verbose=2, early_stop=True):
  torch.manual_seed(42)

  # Setup NN
  sizes = [train[0].size] + hidden_sizes
  prev_s = None
  layers = []
  for s in sizes:
    if prev_s:
      layers.append(nn.Linear(prev_s, s).cuda())
      layers.append(nn.LeakyReLU().cuda())
    prev_s = s
  layers.append(nn.Linear(prev_s, 1).cuda())

  net = nn.Sequential(*layers)
  
  optimizer = torch.optim.Adam(net.parameters(), lr=0.0001, weight_decay=weight_decay)
  loss_func = nn.MSELoss()

  ## Setup inputs
  net_X = Variable(torch.from_numpy(train))
  net_y = Variable(torch.from_numpy(y_train_zh))
  torch_dataset = Data.TensorDataset(net_X, net_y)
  
  loader = Data.DataLoader(
      dataset=torch_dataset,
      batch_size=batch_size,
      shuffle=True,
      num_workers=2,
  )
  
  net = net.float()

  ## Training
  final_epoch = 0
  last_pearson = None
  for epoch in range(epochs):
    training_loss = 0
    for step, (batch_x, batch_y) in enumerate(loader):
      b_x = Variable(batch_x.float().to(DEVICE))
      b_y = Variable(batch_y.float().to(DEVICE))
      prediction = torch.flatten(net(b_x))
      loss = loss_func(prediction, b_y)
      training_loss += mean_squared_error(b_y.cpu().detach().numpy(), prediction.cpu().detach().numpy())
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

    # Evaluate validation loss every 10 epochs
    if epoch % 10 == 0:
      training_loss = training_loss / step
      with torch.no_grad():
        net.eval()
        net_val = Variable(torch.from_numpy(valid).float().to(DEVICE))
        pred = torch.flatten(net.forward(net_val)).cpu().detach().numpy()
        net.train()
        pearson = pearsonr(y_val_zh, pred)
        # If using early stopping, then stop when validation loss has increased
        # since the last time it was checked.
        if last_pearson and early_stop and last_pearson > pearson[0]:
          final_epoch = epoch
          break
        else:
          last_pearson = pearson[0]
        if verbose >= 2:
          print(f"Epoch {epoch} Training Loss: {training_loss}, Pearson Score: {pearson[0]}, MSE: {mean_squared_error(y_val_zh, pred)}")
    final_epoch = epoch
  
  # Return the final validation Pearson score
  with torch.no_grad():
    net.eval()
    net_val = Variable(torch.from_numpy(valid).float().to(DEVICE))
    pred = torch.flatten(net.forward(net_val)).cpu().detach().numpy()
    net_train = Variable(torch.from_numpy(train).float().to(DEVICE))
    pred_t = torch.flatten(net.forward(net_train)).cpu().detach().numpy()
    net.train()
    pearson = pearsonr(y_val_zh, pred)
    pearson_t = pearsonr(y_train_zh, pred_t)
    if verbose >= 1:
      print(f"Final Validation Pearson Score: {pearson[0]}")
    return pearson[0], pearson_t[0], final_epoch, net


### Grid Search with One Layer

In [0]:
hidden_sizes = [32, 64, 128, 256, 512]
batch_sizes = [32, 64, 128, 256, 512]
weight_decay = 0.001

print("| Hidden Sizes | Batch Size | Final Epoch | Training Pearson Score | Validation Pearson Score |")
print("|---|---|---|---|---|")
for h in hidden_sizes:
  for b in batch_sizes:
    p, p_t, e, _ = ffln(X_train_zh_100_a, X_val_zh_100_a, hidden_sizes=[h], batch_size=b, epochs=500, verbose=0, weight_decay=weight_decay)
    print(f"| {h} | {b} | {e} | {p_t} | {p} |")



| Hidden Sizes | Batch Size | Final Epoch | Training Pearson Score | Validation Pearson Score |
|---|---|---|---|---|
| 32 | 32 | 90 | 0.28910033875627006 | 0.26650870390463255 |
| 32 | 64 | 60 | 0.28783177387349757 | 0.26429830239885116 |
| 32 | 128 | 110 | 0.28867307595844666 | 0.2655793457472889 |
| 32 | 256 | 110 | 0.28750555067685357 | 0.2636813566968737 |
| 32 | 512 | 170 | 0.2884581717795725 | 0.26508557152058276 |
| 64 | 32 | 40 | 0.2867930311266828 | 0.26372049701310835 |
| 64 | 64 | 40 | 0.2867451608821309 | 0.26402893717198245 |


Traceback (most recent call last):
  File "/usr/lib/python3.6/multiprocessing/queues.py", line 240, in _feed
    send_bytes(obj)
  File "/usr/lib/python3.6/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/usr/lib/python3.6/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)


KeyboardInterrupt: ignored

  File "/usr/lib/python3.6/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe
Traceback (most recent call last):
  File "/usr/lib/python3.6/multiprocessing/queues.py", line 230, in _feed
    close()
  File "/usr/lib/python3.6/multiprocessing/connection.py", line 177, in close
    self._close()
  File "/usr/lib/python3.6/multiprocessing/connection.py", line 361, in _close
    _close(self._handle)
OSError: [Errno 9] Bad file descriptor


#### Results from One Layer Grid Search

##### Weight Decay = 0

| Hidden Sizes | Batch Size | Final Epoch | Training Pearson Score | Validation Pearson Score |
|---|---|---|---|---|
| 32 | 32 | 70 | 0.48299684617124355 | 0.35568926872569034 |
| 32 | 64 | 90 | 0.47402861701221255 | 0.35484624967389417 |
| 32 | 128 | 120 | 0.46769460605137325 | 0.354001807956961 |
| 32 | 256 | 140 | 0.4451696316549598 | 0.35419410833769244 |
| 32 | 512 | 200 | 0.4439794759089259 | 0.35376695722948476 |
| 64 | 32 | 50 | 0.5050977337865824 | 0.35642318172718346 |
| 64 | 64 | 60 | 0.48461355581325777 | 0.3576238432523523 |
| 64 | 128 | 90 | 0.4899753302682104 | 0.3561842136832902 |
| 64 | 256 | 120 | 0.48020191069491663 | 0.35734185721004297 |
| 64 | 512 | 200 | 0.496231710595813 | 0.35519182285016304 |
| 128 | 32 | 40 | 0.5300551965029648 | 0.35997446559039575 |
| 128 | 64 | 50 | 0.5153791051222808 | 0.3611079819831347 |
| 128 | 128 | 80 | 0.5344236286008841 | 0.3616800834614517 |
| 128 | 256 | 100 | 0.5108236652821491 | 0.36388654741407084 |
| 128 | 512 | 130 | 0.4962281367272859 | **0.3639496293878147** |
| 256 | 32 | 30 | 0.5584957328071897 | 0.36131444904084803 |
| 256 | 64 | 40 | 0.5495577458050304 | 0.3600462018518822 |
| 256 | 128 | 50 | 0.5320562994631689 | 0.36105712427084846 |
| 256 | 256 | 70 | 0.5245011760276505 | 0.36195596951857956 |
| 256 | 512 | 90 | 0.5054828308375421 | 0.3613732695075446 |
| 512 | 32 | 20 | 0.556784413338219 | 0.359580644331573 |
| 512 | 64 | 40 | 0.6384220275150658 | 0.3584139600716416 |
| 512 | 128 | 50 | 0.6129006760987404 | 0.35740668383572766 |
| 512 | 256 | 60 | 0.5723914156685365 | 0.3580794536925124 |
| 512 | 512 | 70 | 0.5298419125443378 | 0.35974505667758333 |

##### Weight Decay = 0.01

| Hidden Sizes | Batch Size | Final Epoch | Training Pearson Score | Validation Pearson Score |
|---|---|---|---|---|
| 32 | 32 | 90 | 0.45100997469770016 | 0.3582803012815051 |
| 32 | 64 | 120 | 0.4503648686536401 | 0.3612327906812081 |
| 32 | 128 | 140 | 0.43456235490520323 | 0.35732681839123887 |
| 32 | 256 | 140 | 0.40911696781469237 | 0.35051562078207565 |
| 32 | 512 | 380 | 0.4573539398883728 | 0.36226482868230936 |
| 64 | 32 | 110 | 0.5150871707999816 | 0.3638623578357723 |
| 64 | 64 | 140 | 0.5101560100532213 | 0.36533322301784016 |
| 64 | 128 | 190 | 0.501001080294028 | 0.364107813010426 |
| 64 | 256 | 170 | 0.45934007507494534 | 0.3602363082137179 |
| 64 | 512 | 310 | 0.4794312797086712 | 0.36314950087155257 |
| 128 | 32 | 80 | 0.5254772408655755 | 0.36413656838985187 |
| 128 | 64 | 80 | 0.4924648511059825 | **0.36534145933484247** |
| 128 | 128 | 110 | 0.4905548920330267 | 0.3652737593370598 |
| 128 | 256 | 120 | 0.4596690298995526 | 0.3607400995139613 |
| 128 | 512 | 190 | 0.4703827088565318 | 0.35990875219965446 |
| 256 | 32 | 50 | 0.5078747999445999 | 0.35920905100898687 |
| 256 | 64 | 70 | 0.512184065421008 | 0.3620287430543616 |
| 256 | 128 | 70 | 0.47870476273285056 | 0.3620495830468129 |
| 256 | 256 | 120 | 0.49626925570140756 | 0.36270427046399956 |
| 256 | 512 | 150 | 0.47698570786025907 | 0.35977281719496734 |
| 512 | 32 | 50 | 0.5401908906558154 | 0.3570021396696617 |
| 512 | 64 | 50 | 0.5072977378281833 | 0.3605957816283808 |
| 512 | 128 | 90 | 0.5415603401013815 | 0.36270088337099327 |
| 512 | 256 | 90 | 0.4955292147759141 | 0.3635272873351458 |
| 512 | 512 | 120 | 0.48188509952735753 | 0.36111544420801234 |

##### Weight Decay = 0.001

| Hidden Sizes | Batch Size | Final Epoch | Training Pearson Score | Validation Pearson Score |
|---|---|---|---|---|
| 32 | 32 | 80 | 0.49336796221998863 | 0.3595207129989028 |
| 32 | 64 | 90 | 0.4686487002414254 | 0.3582774027958467 |
| 32 | 128 | 150 | 0.489292727653548 | 0.3600779802129854 |
| 32 | 256 | 180 | 0.4674018276179893 | 0.35810249130930877 |
| 32 | 512 | 280 | 0.47452774933243747 | 0.36091831430719745 |
| 64 | 32 | 50 | 0.49944192717692193 | 0.35990356184421063 |
| 64 | 64 | 80 | 0.5136684848041514 | 0.35919112510262313 |
| 64 | 128 | 90 | 0.4849341517291386 | 0.3586888914929235 |
| 64 | 256 | 150 | 0.5024117906414646 | 0.36029578906596255 |
| 64 | 512 | 200 | 0.49155204667787317 | 0.3579592555397539 |
| 128 | 32 | 40 | 0.5212065808627342 | **0.3655701207861885** |
| 128 | 64 | 60 | 0.5320505549107768 | 0.3629479167009998 |
| 128 | 128 | 80 | 0.5274899085367852 | 0.36497695862740553 |
| 128 | 256 | 100 | 0.5040941793383894 | 0.36454298354300246 |
| 128 | 512 | 120 | 0.4780463011632069 | 0.36431483254699737 |
| 256 | 32 | 40 | 0.5933778761141248 | 0.36153553996319515 |
| 256 | 64 | 40 | 0.5345114685466408 | 0.36418110838174944 |
| 256 | 128 | 70 | 0.576063105124408 | 0.3638639717583646 |
| 256 | 256 | 70 | 0.5120956157535899 | 0.3650715612822175 |
| 256 | 512 | 110 | 0.5253553772479195 | 0.36316481818471735 |
| 512 | 32 | 30 | 0.6222180693364254 | 0.36158233314745136 |
| 512 | 64 | 50 | 0.6632538171675477 | 0.35826857180175015 |
| 512 | 128 | 50 | 0.591131995809352 | 0.36331944190464277 |
| 512 | 256 | 70 | 0.5824858191970835 | 0.3622091989794235 |
| 512 | 512 | 90 | 0.5600560401158342 | 0.36213104755775855 |

##### Weight Decay = 0.0001

| Hidden Sizes | Batch Size | Final Epoch | Training Pearson Score | Validation Pearson Score |
|---|---|---|---|---|
| 32 | 32 | 70 | 0.4830532067296478 | 0.3555060634157582 |
| 32 | 64 | 90 | 0.47275643227640973 | 0.3543644805032278 |
| 32 | 128 | 120 | 0.46714670011179266 | 0.3544942159923127 |
| 32 | 256 | 180 | 0.47071218804547743 | 0.3566361717224038 |
| 32 | 512 | 240 | 0.46171948937276924 | 0.35494854249653307 |
| 64 | 32 | 50 | 0.5046060631003529 | 0.3561336258710358 |
| 64 | 64 | 70 | 0.5006609401752118 | 0.35764026845550645 |
| 64 | 128 | 90 | 0.48977853618588074 | 0.35660579119065317 |
| 64 | 256 | 130 | 0.4884750707106939 | 0.35750823890033995 |
| 64 | 512 | 160 | 0.47038291546523764 | 0.3553940025455295 |
| 128 | 32 | 40 | 0.5296528806458696 | 0.36051136999747474 |
| 128 | 64 | 50 | 0.5138629835356658 | 0.3614184783217593 |
| 128 | 128 | 80 | 0.5340628954126573 | 0.36152075923793214 |
| 128 | 256 | 100 | 0.5101477581738477 | 0.3628146586490763 |
| 128 | 512 | 130 | 0.4954820923170049 | **0.3635680289756267** |
| 256 | 32 | 30 | 0.5569831868147893 | 0.36280076970257435 |
| 256 | 64 | 40 | 0.5483943339029308 | 0.36122493157130764 |
| 256 | 128 | 60 | 0.5616400250894735 | 0.3620758934990978 |
| 256 | 256 | 70 | 0.5231049330654992 | 0.3630111659792139 |
| 256 | 512 | 110 | 0.5381738236831985 | 0.36019270269875187 |
| 512 | 32 | 20 | 0.554446011862658 | 0.3598765906754196 |
| 512 | 64 | 40 | 0.6356313286771624 | 0.361338152367777 |
| 512 | 128 | 50 | 0.6101331943464575 | 0.3589041646949778 |
| 512 | 256 | 60 | 0.5699815807507783 | 0.3602698468993568 |
| 512 | 512 | 70 | 0.5273997386224675 | 0.3608117223004388 |

### Grid Search with Two Layers

In [0]:
import itertools

sizes = [32, 64, 128, 256, 512]
batch_sizes = [32, 64, 128, 256, 512]
weight_decay=0.01
sizes = list(list(t) for t in itertools.product(sizes, sizes))
search = list(itertools.product(sizes, batch_sizes))

print("| Hidden Sizes | Batch Size | Final Epoch | Training Pearson Score | Validation Pearson Score |")
print("|---|---|---|---|---|")
for s, b in search:
  p, p_t, e, _ = ffln(X_train_zh_100_a, X_val_zh_100_a, hidden_sizes=s, batch_size=b, epochs=500, verbose=0, weight_decay=weight_decay)
  print(f"| {s} | {b} | {e} | {p_t} | {p} |")

| Hidden Sizes | Batch Size | Final Epoch | Training Pearson Score | Validation Pearson Score |
|---|---|---|---|---|
| [64, 64] | 512 | 210 | 0.5227163099320284 | 0.36775365203777477 |


#### Results from Two layer Grid Search

##### Weight Decay = 0



| Hidden Sizes | Batch Size | Final Epoch | Training Pearson Score | Validation Pearson Score |
|---|---|---|---|---|
| [32, 32] | 32 | 40 | 0.5081770770385975 | 0.34801201441426727 |
| [32, 32] | 64 | 50 | 0.4897875190023316 | 0.35180442781202537 |
| [32, 32] | 128 | 60 | 0.471319249596159 | 0.35216127932418095 |
| [32, 32] | 256 | 80 | 0.4592414925340225 | 0.3519866572138941 |
| [32, 32] | 512 | 120 | 0.46425772682932503 | 0.35151247242958056 |
| [32, 64] | 32 | 30 | 0.4796861792859908 | 0.3498461575667321 |
| [32, 64] | 64 | 50 | 0.5138079761615646 | 0.35160007685387956 |
| [32, 64] | 128 | 60 | 0.49111954140700154 | 0.35501585078311043 |
| [32, 64] | 256 | 80 | 0.48100806849344646 | 0.3527689110013705 |
| [32, 64] | 512 | 110 | 0.47345423786359253 | 0.35092546076354725 |
| [32, 128] | 32 | 30 | 0.5008741717910612 | 0.33964560616524064 |
| [32, 128] | 64 | 40 | 0.5011052842994936 | 0.33933330300888603 |
| [32, 128] | 128 | 40 | 0.4582418901142382 | 0.34739491249119014 |
| [32, 128] | 256 | 60 | 0.4621263640810455 | 0.3465269388711895 |
| [32, 128] | 512 | 90 | 0.465597542007843 | 0.3457915977785479 |
| [32, 256] | 32 | 30 | 0.5329860116715605 | 0.34357907783261743 |
| [32, 256] | 64 | 30 | 0.4865428087678733 | 0.35572479528131656 |
| [32, 256] | 128 | 40 | 0.48052476891754525 | 0.3543011238452829 |
| [32, 256] | 256 | 60 | 0.4860260578215678 | 0.3525535644582506 |
| [32, 256] | 512 | 80 | 0.46862449033538783 | 0.3529549622986453 |
| [32, 512] | 32 | 20 | 0.5055240087029065 | 0.32601374416241413 |
| [32, 512] | 64 | 30 | 0.5188875280925833 | 0.3179975788422171 |
| [32, 512] | 128 | 30 | 0.4699261060747212 | 0.33679649307565207 |
| [32, 512] | 256 | 40 | 0.455737890436249 | 0.3344219385019655 |
| [32, 512] | 512 | 60 | 0.4599190009601511 | 0.3328937301128072 |
| [64, 32] | 32 | 30 | 0.5355110287248864 | 0.3478778050987554 |
| [64, 32] | 64 | 40 | 0.5264094792582278 | 0.34376367455800805 |
| [64, 32] | 128 | 40 | 0.4779046113722663 | 0.35275295871734835 |
| [64, 32] | 256 | 60 | 0.4822133899414166 | 0.3495061148312315 |
| [64, 32] | 512 | 80 | 0.46967093639034196 | 0.35023971412168714 |
| [64, 64] | 32 | 30 | 0.548199615084581 | 0.34246616954252085 |
| [64, 64] | 64 | 40 | 0.5384632016469784 | 0.34812512840948795 |
| [64, 64] | 128 | 50 | 0.5161190781765183 | 0.355902634793197 |
| [64, 64] | 256 | 60 | 0.4835348501398026 | 0.3598794758880293 |
| [64, 64] | 512 | 90 | 0.489188526573068 | **0.3620220512957468** |
| [64, 128] | 32 | 20 | 0.5238038214951334 | 0.3414280655711996 |
| [64, 128] | 64 | 30 | 0.5380299500555019 | 0.3414050223679859 |
| [64, 128] | 128 | 40 | 0.5335752007421125 | 0.34065157380623134 |
| [64, 128] | 256 | 50 | 0.5066219234904855 | 0.3460378003673324 |
| [64, 128] | 512 | 60 | 0.4742384537082817 | 0.35069956841809286 |
| [64, 256] | 32 | 20 | 0.5211886039936586 | 0.352351231939332 |
| [64, 256] | 64 | 30 | 0.5370382293087058 | 0.3510801946890659 |
| [64, 256] | 128 | 40 | 0.53734716468934 | 0.3455473606403491 |
| [64, 256] | 256 | 50 | 0.5112961739048506 | 0.3503522025895193 |
| [64, 256] | 512 | 60 | 0.4819179802638056 | 0.35460296958205234 |
| [64, 512] | 32 | 20 | 0.5621780968739095 | 0.3268183990223234 |
| [64, 512] | 64 | 20 | 0.5067892731089352 | 0.34589821245923974 |
| [64, 512] | 128 | 30 | 0.5261678086444983 | 0.34052843629794427 |
| [64, 512] | 256 | 40 | 0.511546069960251 | 0.34396647797791835 |
| [64, 512] | 512 | 50 | 0.48457831123432826 | 0.3525369834181648 |
| [128, 32] | 32 | 30 | 0.6039961285897729 | 0.34359048868395253 |
| [128, 32] | 64 | 40 | 0.5985050322190164 | 0.3402627269843407 |
| [128, 32] | 128 | 50 | 0.5794019327922569 | 0.33478221115521123 |
| [128, 32] | 256 | 70 | 0.5739033109847669 | 0.3387409348176428 |
| [128, 32] | 512 | 80 | 0.5175772339746831 | 0.3421137656130536 |
| [128, 64] | 32 | 20 | 0.5542290956403353 | 0.33649332324188325 |
| [128, 64] | 64 | 30 | 0.5780841307041461 | 0.3313689178064872 |
| [128, 64] | 128 | 30 | 0.5060397782798773 | 0.3463739872278888 |
| [128, 64] | 256 | 50 | 0.5388824546976178 | 0.34355298558608943 |
| [128, 64] | 512 | 60 | 0.49846153608787913 | 0.3498538755132095 |
| [128, 128] | 32 | 20 | 0.5969323088625645 | 0.3383213238043019 |
| [128, 128] | 64 | 30 | 0.62310606700902 | 0.3313764363749702 |
| [128, 128] | 128 | 30 | 0.549741282163187 | 0.3430542067484068 |
| [128, 128] | 256 | 40 | 0.5309624665649803 | 0.35104187354270583 |
| [128, 128] | 512 | 60 | 0.5417449906898956 | 0.3491107934108335 |
| [128, 256] | 32 | 20 | 0.6181378488856449 | 0.32070855561961215 |
| [128, 256] | 64 | 20 | 0.5530574362825813 | 0.3378616066308183 |
| [128, 256] | 128 | 30 | 0.5681812442064847 | 0.33457707263754705 |
| [128, 256] | 256 | 40 | 0.5519286224378426 | 0.34162024055503426 |
| [128, 256] | 512 | 50 | 0.519250675395476 | 0.3451956541699298 |
| [128, 512] | 32 | 20 | 0.6552514574466162 | 0.3172032401036353 |
| [128, 512] | 64 | 20 | 0.582892693334241 | 0.32710717987408827 |
| [128, 512] | 128 | 30 | 0.6026291473016793 | 0.33431984845220336 |
| [128, 512] | 256 | 40 | 0.5819287870427984 | 0.3409575726043858 |
| [128, 512] | 512 | 50 | 0.5503226802663136 | 0.34572438556683316 |
| [256, 32] | 32 | 20 | 0.6178103322183633 | 0.3477664372973361 |
| [256, 32] | 64 | 30 | 0.6453541666408443 | 0.34417797447377696 |
| [256, 32] | 128 | 40 | 0.6381036307641387 | 0.35172147458672287 |
| [256, 32] | 256 | 50 | 0.6060682963068668 | 0.35192347999913787 |
| [256, 32] | 512 | 70 | 0.5941857295209115 | 0.35653716958127984 |
| [256, 64] | 32 | 20 | 0.6508446412326873 | 0.3354972318221595 |
| [256, 64] | 64 | 20 | 0.5695339984364156 | 0.3537229823938767 |
| [256, 64] | 128 | 30 | 0.5900300449032189 | 0.3504856280265393 |
| [256, 64] | 256 | 40 | 0.5647273719453317 | 0.3531712153192139 |
| [256, 64] | 512 | 50 | 0.5255412586214809 | 0.36046624092422386 |
| [256, 128] | 32 | 20 | 0.7121633395472036 | 0.3263041137119077 |
| [256, 128] | 64 | 20 | 0.6226772535504698 | 0.34062914405298883 |
| [256, 128] | 128 | 30 | 0.6488284208723305 | 0.3331483268104295 |
| [256, 128] | 256 | 40 | 0.6260847823034719 | 0.3418160957687894 |
| [256, 128] | 512 | 50 | 0.5782757714923286 | 0.34883211590181995 |
| [256, 256] | 32 | 20 | 0.7422068061006223 | 0.3139725271543865 |
| [256, 256] | 64 | 20 | 0.6493067998299227 | 0.3300058666775696 |
| [256, 256] | 128 | 30 | 0.6875151887576874 | 0.32048344192303646 |
| [256, 256] | 256 | 30 | 0.5732933370245789 | 0.34541468268612213 |
| [256, 256] | 512 | 40 | 0.5495634419072722 | 0.34588723505781543 |
| [256, 512] | 32 | 20 | 0.782500496555452 | 0.2732122634388016 |
| [256, 512] | 64 | 20 | 0.6925983481652366 | 0.31086376353851425 |
| [256, 512] | 128 | 20 | 0.6053920409623933 | 0.33602055791937624 |
| [256, 512] | 256 | 30 | 0.6129793417745335 | 0.3304361808171883 |
| [256, 512] | 512 | 40 | 0.5906397087884229 | 0.3383302899338092 |
| [512, 32] | 32 | 20 | 0.7473526096724578 | 0.3210849760341188 |
| [512, 32] | 64 | 20 | 0.6463424365514489 | 0.34960037100969105 |
| [512, 32] | 128 | 30 | 0.6736019797203484 | 0.3456837061484702 |
| [512, 32] | 256 | 40 | 0.6471314103492297 | 0.351719074447624 |
| [512, 32] | 512 | 50 | 0.5940306228931375 | 0.3585976167016458 |
| [512, 64] | 32 | 20 | 0.7801629048283105 | 0.3212583291037159 |
| [512, 64] | 64 | 20 | 0.6815954287895202 | 0.33231715046311944 |
| [512, 64] | 128 | 30 | 0.714064295599162 | 0.3284033861906422 |
| [512, 64] | 256 | 30 | 0.5955503211525276 | 0.34967776659967015 |
| [512, 64] | 512 | 40 | 0.5677400561504948 | 0.35028135079944817 |
| [512, 128] | 32 | 20 | 0.8260028614447924 | 0.2953251952259495 |
| [512, 128] | 64 | 20 | 0.7254102927038344 | 0.3249371086660587 |
| [512, 128] | 128 | 20 | 0.624912277963159 | 0.34015650139658354 |
| [512, 128] | 256 | 30 | 0.6423050217988907 | 0.3433596816806615 |
| [512, 128] | 512 | 40 | 0.6151948846882693 | 0.34232623684709085 |
| [512, 256] | 32 | 20 | 0.8792761671458663 | 0.25363890459228017 |
| [512, 256] | 64 | 20 | 0.7913154856750211 | 0.2780646408203931 |
| [512, 256] | 128 | 20 | 0.6754520532232722 | 0.32230029450889075 |
| [512, 256] | 256 | 20 | 0.5636523483204313 | 0.34393162091158713 |
| [512, 256] | 512 | 30 | 0.5734633023001962 | 0.34435823604452725 |
| [512, 512] | 32 | 20 | 0.8998848099071178 | 0.23402208140250982 |
| [512, 512] | 64 | 20 | 0.8224416560399235 | 0.2723810253666723 |
| [512, 512] | 128 | 20 | 0.7184820066660945 | 0.30443170370453576 |
| [512, 512] | 256 | 20 | 0.5961999576004962 | 0.34077235290976854 |
| [512, 512] | 512 | 30 | 0.6104589519968133 | 0.3364864410301728 |



##### Weight Decay = 0.01



| Hidden Sizes | Batch Size | Final Epoch | Training Pearson Score | Validation Pearson Score |
|---|---|---|---|---|
| [32, 32] | 32 | 80 | 0.4768371156481136 | 0.3472105421814501 |
| [32, 32] | 64 | 80 | 0.4412990560016324 | 0.3490910295584378 |
| [32, 32] | 128 | 130 | 0.46320460333932084 | 0.35604428366231206 |
| [32, 32] | 256 | 200 | 0.47043516351786646 | 0.3544688086342531 |
| [32, 32] | 512 | 250 | 0.46393373094097495 | 0.358283666133919 |
| [32, 64] | 32 | 70 | 0.47664527627112285 | 0.3511953535231789 |
| [32, 64] | 64 | 90 | 0.4629009798187991 | 0.3427573148099104 |
| [32, 64] | 128 | 120 | 0.4530688238469185 | 0.34261743780722276 |
| [32, 64] | 256 | 160 | 0.4549851002315548 | 0.3471089321762645 |
| [32, 64] | 512 | 300 | 0.4918908005658059 | 0.3550438246242377 |
| [32, 128] | 32 | 40 | 0.45640237977557246 | 0.35260709946727226 |
| [32, 128] | 64 | 70 | 0.48555743851648625 | 0.35259276745571405 |
| [32, 128] | 128 | 80 | 0.4608224613616029 | 0.3532852517865351 |
| [32, 128] | 256 | 120 | 0.46475468189265484 | 0.35626144013915956 |
| [32, 128] | 512 | 200 | 0.4852847485078734 | 0.3530608122086552 |
| [32, 256] | 32 | 50 | 0.4879702041143264 | 0.3517854687301486 |
| [32, 256] | 64 | 60 | 0.47734473091428953 | 0.35245407907606846 |
| [32, 256] | 128 | 80 | 0.47116296786118445 | 0.35266136512639207 |
| [32, 256] | 256 | 130 | 0.49059173060902855 | 0.35291393811985433 |
| [32, 256] | 512 | 160 | 0.45808960359670453 | 0.35097953315558644 |
| [32, 512] | 32 | 50 | 0.5076814100559679 | 0.3521317795810667 |
| [32, 512] | 64 | 50 | 0.4625676369439403 | 0.3542414364806836 |
| [32, 512] | 128 | 70 | 0.46512525025000084 | 0.3527059014289503 |
| [32, 512] | 256 | 90 | 0.4507865858215482 | 0.3523437141842701 |
| [32, 512] | 512 | 140 | 0.4624769089335927 | 0.35345885559957096 |
| [64, 32] | 32 | 50 | 0.5017733043285105 | 0.35872979444239467 |
| [64, 32] | 64 | 70 | 0.5082508006545847 | 0.3600456375143 |
| [64, 32] | 128 | 80 | 0.4818669121282903 | 0.3602922105788539 |
| [64, 32] | 256 | 120 | 0.48945038015407494 | 0.3593044090521456 |
| [64, 32] | 512 | 140 | 0.4519026977967834 | 0.35649467730619844 |
| [64, 64] | 32 | 60 | 0.4999454526370554 | 0.36515000522865493 |
| [64, 64] | 64 | 90 | 0.5288894659888791 | 0.36626392298706 |
| [64, 64] | 128 | 110 | 0.5167730798544086 | 0.3664094297590882 |
| [64, 64] | 256 | 120 | 0.47549464815809706 | 0.3651224997268548 |
| [64, 64] | 512 | 210 | 0.5228911069973523 | **0.3688918486855366** |
| [64, 128] | 32 | 50 | 0.5256763827090827 | 0.3436840438702426 |
| [64, 128] | 64 | 50 | 0.47621071051056635 | 0.34935614980525853 |
| [64, 128] | 128 | 70 | 0.492985054394737 | 0.3541966554711201 |
| [64, 128] | 256 | 100 | 0.49597662265953607 | 0.35313526037712695 |
| [64, 128] | 512 | 150 | 0.5069918503234064 | 0.3526837630207277 |
| [64, 256] | 32 | 50 | 0.5494891168785351 | 0.35883686890825656 |
| [64, 256] | 64 | 60 | 0.5269415293432682 | 0.36249765190193034 |
| [64, 256] | 128 | 80 | 0.5300505203434721 | 0.3589246457736425 |
| [64, 256] | 256 | 100 | 0.5080241887029228 | 0.36016269843549736 |
| [64, 256] | 512 | 140 | 0.5059537522589852 | 0.3587068710399753 |
| [64, 512] | 32 | 40 | 0.5401983496880267 | 0.35321207201271876 |
| [64, 512] | 64 | 50 | 0.5318604001907415 | 0.36099390689766553 |
| [64, 512] | 128 | 70 | 0.5421689778085229 | 0.36105933507373916 |
| [64, 512] | 256 | 80 | 0.4944382209176507 | 0.3626388856407587 |
| [64, 512] | 512 | 130 | 0.5196652036143773 | 0.36225251211290205 |
| [128, 32] | 32 | 50 | 0.5249271832296621 | 0.35760718583215756 |
| [128, 32] | 64 | 80 | 0.5685025740689781 | 0.36377939633090767 |
| [128, 32] | 128 | 80 | 0.5139759118893513 | 0.3648175248658738 |
| [128, 32] | 256 | 100 | 0.4952233549899021 | 0.3564903714865462 |
| [128, 32] | 512 | 160 | 0.5147025480042984 | 0.36135181855850185 |
| [128, 64] | 32 | 50 | 0.5738080318348487 | 0.36239898747388705 |
| [128, 64] | 64 | 50 | 0.5196744338346704 | 0.356533390774149 |
| [128, 64] | 128 | 70 | 0.5317510804320469 | 0.3572402072946805 |
| [128, 64] | 256 | 80 | 0.48812725659154754 | 0.35517634016323213 |
| [128, 64] | 512 | 120 | 0.4928091463886728 | 0.35535589307344057 |
| [128, 128] | 32 | 30 | 0.5180119796291096 | 0.35591540698998897 |
| [128, 128] | 64 | 40 | 0.5253773911651628 | 0.3574734787149585 |
| [128, 128] | 128 | 60 | 0.5477979642672689 | 0.3590288563409101 |
| [128, 128] | 256 | 70 | 0.5092802438075198 | 0.36074254358301516 |
| [128, 128] | 512 | 120 | 0.5499484279862112 | 0.3588146512241456 |
| [128, 256] | 32 | 30 | 0.5105203622305265 | 0.3587707117844619 |
| [128, 256] | 64 | 60 | 0.5979701179675816 | 0.3483910760126731 |
| [128, 256] | 128 | 60 | 0.5375379425989347 | 0.3571144401828996 |
| [128, 256] | 256 | 80 | 0.5255759939044361 | 0.35982609689586265 |
| [128, 256] | 512 | 110 | 0.5099185190234068 | 0.35963302396225366 |
| [128, 512] | 32 | 40 | 0.6087175732336804 | 0.34509126397210504 |
| [128, 512] | 64 | 40 | 0.5490352041120409 | 0.35840229732302004 |
| [128, 512] | 128 | 60 | 0.5692428465652248 | 0.36039270279129104 |
| [128, 512] | 256 | 80 | 0.5593548952789219 | 0.3578720909181322 |
| [128, 512] | 512 | 100 | 0.5232604125350693 | 0.3589824505411314 |
| [256, 32] | 32 | 50 | 0.6195472376576506 | 0.3557453465832519 |
| [256, 32] | 64 | 50 | 0.5563804056392209 | 0.36367792457280357 |
| [256, 32] | 128 | 60 | 0.5393136872926895 | 0.3610146053650587 |
| [256, 32] | 256 | 80 | 0.5335812928400498 | 0.363278246196464 |
| [256, 32] | 512 | 140 | 0.5734618054297648 | 0.3633220253579395 |
| [256, 64] | 32 | 40 | 0.5461654804024745 | 0.3514624499793854 |
| [256, 64] | 64 | 40 | 0.5014152491520321 | 0.35704700757568464 |
| [256, 64] | 128 | 70 | 0.5685993638069473 | 0.3579603834979342 |
| [256, 64] | 256 | 90 | 0.5532983304285045 | 0.3612728027167668 |
| [256, 64] | 512 | 120 | 0.5397913602071595 | 0.3625542689527807 |
| [256, 128] | 32 | 30 | 0.5751101278909201 | 0.35622548590492054 |
| [256, 128] | 64 | 40 | 0.5808126596396925 | 0.3576858146716209 |
| [256, 128] | 128 | 60 | 0.6107183590079956 | 0.36312025863448655 |
| [256, 128] | 256 | 80 | 0.6002998835218126 | 0.36162464810100253 |
| [256, 128] | 512 | 100 | 0.5563979841428376 | 0.3644687677651771 |
| [256, 256] | 32 | 40 | 0.6291331839861412 | 0.3506029149198077 |
| [256, 256] | 64 | 40 | 0.5695509042109238 | 0.3572348127585443 |
| [256, 256] | 128 | 40 | 0.5090532832286437 | 0.36040476288879725 |
| [256, 256] | 256 | 60 | 0.5204361648474637 | 0.3645906716991325 |
| [256, 256] | 512 | 100 | 0.5615250273811208 | 0.3618892360575652 |
| [256, 512] | 32 | 30 | 0.5907409393875985 | 0.3564003969888556 |
| [256, 512] | 64 | 40 | 0.6067050252263743 | 0.3556896616542178 |
| [256, 512] | 128 | 50 | 0.5954272195571649 | 0.3582021660717044 |
| [256, 512] | 256 | 70 | 0.5976306911296411 | 0.35736860541759585 |
| [256, 512] | 512 | 90 | 0.5705900509506002 | 0.3576595848614088 |
| [512, 32] | 32 | 30 | 0.5419074818681398 | 0.36196331560976197 |
| [512, 32] | 64 | 40 | 0.5508422389751494 | 0.3542479201862902 |
| [512, 32] | 128 | 50 | 0.5491332492787844 | 0.3599954699406033 |
| [512, 32] | 256 | 70 | 0.5477602713071558 | 0.3598743729389123 |
| [512, 32] | 512 | 90 | 0.5207274821547373 | 0.36096356027225746 |
| [512, 64] | 32 | 30 | 0.5618102403318733 | 0.3564562706132819 |
| [512, 64] | 64 | 40 | 0.5828353571211595 | 0.3555416806816426 |
| [512, 64] | 128 | 50 | 0.5723584849504204 | 0.3570558600224812 |
| [512, 64] | 256 | 70 | 0.5862492012751988 | 0.3556519771525367 |
| [512, 64] | 512 | 90 | 0.5576413872758721 | 0.3585821559465128 |
| [512, 128] | 32 | 30 | 0.5812244600138176 | 0.3471274911395109 |
| [512, 128] | 64 | 40 | 0.6106870282302498 | 0.3557103429810736 |
| [512, 128] | 128 | 40 | 0.5501834512912124 | 0.36250493526584904 |
| [512, 128] | 256 | 70 | 0.6187979699160502 | 0.3515848222457987 |
| [512, 128] | 512 | 80 | 0.549300334035669 | 0.3563584770996534 |
| [512, 256] | 32 | 30 | 0.6293914190652328 | 0.3506981350615361 |
| [512, 256] | 64 | 30 | 0.5865646576172406 | 0.355044666676185 |
| [512, 256] | 128 | 40 | 0.5922234856934344 | 0.3543297974235625 |
| [512, 256] | 256 | 50 | 0.5732103080202352 | 0.3536901261250561 |
| [512, 256] | 512 | 80 | 0.6074728435139047 | 0.36008120848473274 |
| [512, 512] | 32 | 20 | 0.5646885592017371 | 0.3530281135406622 |
| [512, 512] | 64 | 30 | 0.6114552447196876 | 0.3491200624617726 |
| [512, 512] | 128 | 30 | 0.5498582854139467 | 0.35810458536940887 |
| [512, 512] | 256 | 50 | 0.6017678539851714 | 0.3554978351483823 |
| [512, 512] | 512 | 70 | 0.5903664204005237 | 0.36072800770364966 |




##### Weight Decay = 0.001



| Hidden Sizes | Batch Size | Final Epoch | Training Pearson Score | Validation Pearson Score |
|---|---|---|---|---|
| [32, 32] | 32 | 40 | 0.49220829899622603 | 0.35092664847688443 |
| [32, 32] | 64 | 50 | 0.4814242237420692 | 0.35336728521759164 |
| [32, 32] | 128 | 60 | 0.4623328591987379 | 0.35481357280737624 |
| [32, 32] | 256 | 100 | 0.482055365879759 | 0.35386552652642705 |
| [32, 32] | 512 | 120 | 0.45368229268398846 | 0.3548117063226977 |
| [32, 64] | 32 | 30 | 0.46970603986050047 | 0.3464953631119434 |
| [32, 64] | 64 | 50 | 0.5038283753937575 | 0.3470060639477771 |
| [32, 64] | 128 | 50 | 0.45252451205433186 | 0.34808923194702457 |
| [32, 64] | 256 | 80 | 0.4683281268716328 | 0.3456602235825528 |
| [32, 64] | 512 | 100 | 0.4456185036973398 | 0.3460722868375535 |
| [32, 128] | 32 | 30 | 0.4953499398255297 | 0.34413403184735825 |
| [32, 128] | 64 | 40 | 0.4947972288941642 | 0.3413263285240755 |
| [32, 128] | 128 | 50 | 0.4792860605329388 | 0.34806116532695974 |
| [32, 128] | 256 | 70 | 0.47263481440104377 | 0.3492924112688629 |
| [32, 128] | 512 | 90 | 0.4542999750522939 | 0.3511657741553443 |
| [32, 256] | 32 | 30 | 0.5154147483830246 | 0.3500116616793944 |
| [32, 256] | 64 | 40 | 0.5137122847842815 | 0.34964711606728 |
| [32, 256] | 128 | 50 | 0.5015494438727124 | 0.3489774500246439 |
| [32, 256] | 256 | 70 | 0.49118554924565666 | 0.35263849491641047 |
| [32, 256] | 512 | 100 | 0.4806364035907918 | 0.3535470748567105 |
| [32, 512] | 32 | 20 | 0.4899783924771183 | 0.33251553841525444 |
| [32, 512] | 64 | 30 | 0.5030222264397366 | 0.3282021573221569 |
| [32, 512] | 128 | 40 | 0.48568620480059965 | 0.3400918083674615 |
| [32, 512] | 256 | 50 | 0.4601295939510631 | 0.3444868489755288 |
| [32, 512] | 512 | 70 | 0.45618263488675787 | 0.34464840817562464 |
| [64, 32] | 32 | 30 | 0.5290327464336355 | 0.34605905936722636 |
| [64, 32] | 64 | 40 | 0.5182915939129374 | 0.35046352087961224 |
| [64, 32] | 128 | 50 | 0.5050843654030664 | 0.35073774823884163 |
| [64, 32] | 256 | 70 | 0.49857955862397163 | 0.34960669830376256 |
| [64, 32] | 512 | 100 | 0.4953961234223967 | 0.35419207760715293 |
| [64, 64] | 32 | 30 | 0.5321059797841452 | 0.35547005530468156 |
| [64, 64] | 64 | 40 | 0.5284725730513365 | 0.35404423423542636 |
| [64, 64] | 128 | 50 | 0.5048643945562385 | 0.358154874791565 |
| [64, 64] | 256 | 70 | 0.49787378262584414 | 0.36029881588153717 |
| [64, 64] | 512 | 90 | 0.47286944789678603 | **0.36298889594266526** |
| [64, 128] | 32 | 30 | 0.5849697371780138 | 0.3376878616818794 |
| [64, 128] | 64 | 30 | 0.523408372075128 | 0.35156295476869703 |
| [64, 128] | 128 | 40 | 0.5170019344951011 | 0.3500527626734647 |
| [64, 128] | 256 | 50 | 0.49076980473963105 | 0.3510334547570223 |
| [64, 128] | 512 | 70 | 0.4842119321957344 | 0.35209782009341606 |
| [64, 256] | 32 | 30 | 0.5760793991752905 | 0.34452705640663517 |
| [64, 256] | 64 | 30 | 0.5247109534565682 | 0.35333280221079816 |
| [64, 256] | 128 | 40 | 0.5197438836500577 | 0.3519232860652139 |
| [64, 256] | 256 | 50 | 0.4968100124543012 | 0.3570168121520941 |
| [64, 256] | 512 | 70 | 0.48983288608978964 | 0.35522034837044697 |
| [64, 512] | 32 | 20 | 0.5499122951333788 | 0.3357365033779437 |
| [64, 512] | 64 | 30 | 0.5696225814471075 | 0.3289441979020294 |
| [64, 512] | 128 | 30 | 0.508183132914768 | 0.34696702320618816 |
| [64, 512] | 256 | 40 | 0.4905154730174879 | 0.3489190869677597 |
| [64, 512] | 512 | 60 | 0.49505689936540503 | 0.34690359118534525 |
| [128, 32] | 32 | 30 | 0.5938475535687953 | 0.3485381212752656 |
| [128, 32] | 64 | 40 | 0.5903090587785399 | 0.3517772231106364 |
| [128, 32] | 128 | 50 | 0.5735979655201162 | 0.34789932880035207 |
| [128, 32] | 256 | 70 | 0.5678678191746906 | 0.3501491228491053 |
| [128, 32] | 512 | 100 | 0.5632691108254277 | 0.35057831983657595 |
| [128, 64] | 32 | 20 | 0.5478902855067767 | 0.3388997395245118 |
| [128, 64] | 64 | 30 | 0.5735098273523843 | 0.33923056102149035 |
| [128, 64] | 128 | 40 | 0.5616435416161261 | 0.3445626224582357 |
| [128, 64] | 256 | 50 | 0.5304185912115482 | 0.351796889344657 |
| [128, 64] | 512 | 70 | 0.5199774552048672 | 0.3546881384160965 |
| [128, 128] | 32 | 20 | 0.5874930888107026 | 0.33997881213258596 |
| [128, 128] | 64 | 30 | 0.6112999933747997 | 0.33453939705729663 |
| [128, 128] | 128 | 30 | 0.539633970929768 | 0.34941903605026153 |
| [128, 128] | 256 | 40 | 0.517999772562526 | 0.35354188255091884 |
| [128, 128] | 512 | 60 | 0.5287819260513144 | 0.35165268215658513 |
| [128, 256] | 32 | 20 | 0.6091529901556677 | 0.3306415455622169 |
| [128, 256] | 64 | 20 | 0.5402706432932518 | 0.3419752371965145 |
| [128, 256] | 128 | 30 | 0.5552441643062992 | 0.3366510548731301 |
| [128, 256] | 256 | 40 | 0.5368663593034899 | 0.34442229905038785 |
| [128, 256] | 512 | 50 | 0.4985317306133952 | 0.34930500306054935 |
| [128, 512] | 32 | 20 | 0.6386974606845096 | 0.32422169378605586 |
| [128, 512] | 64 | 20 | 0.5700352681804732 | 0.34042993050294845 |
| [128, 512] | 128 | 30 | 0.590284266220035 | 0.3452747812983865 |
| [128, 512] | 256 | 40 | 0.5647640831349545 | 0.3521847830492433 |
| [128, 512] | 512 | 50 | 0.5262224463127279 | 0.35286066214579265 |
| [256, 32] | 32 | 20 | 0.5968803709145881 | 0.3536098532589031 |
| [256, 32] | 64 | 30 | 0.6278749308232062 | 0.3537787449596557 |
| [256, 32] | 128 | 40 | 0.6233092332777945 | 0.3562697194398134 |
| [256, 32] | 256 | 50 | 0.5920866234844513 | 0.35414694313192724 |
| [256, 32] | 512 | 70 | 0.5799499721085484 | 0.3571348064759131 |
| [256, 64] | 32 | 20 | 0.6300591609181392 | 0.34987408296108075 |
| [256, 64] | 64 | 30 | 0.6551040742691451 | 0.3442128183735948 |
| [256, 64] | 128 | 30 | 0.5747788209351343 | 0.3603645835260601 |
| [256, 64] | 256 | 40 | 0.5481061640217578 | 0.36095926865265526 |
| [256, 64] | 512 | 60 | 0.5590083704496889 | 0.3616885754411477 |
| [256, 128] | 32 | 20 | 0.6966667066745962 | 0.32754671592942164 |
| [256, 128] | 64 | 20 | 0.6067113168138177 | 0.34352184795938756 |
| [256, 128] | 128 | 30 | 0.6338465784152287 | 0.3351081999844846 |
| [256, 128] | 256 | 40 | 0.6121486143571581 | 0.3446088140511725 |
| [256, 128] | 512 | 50 | 0.5634122316012783 | 0.3539056753699886 |
| [256, 256] | 32 | 20 | 0.7255063078806151 | 0.3191919452436102 |
| [256, 256] | 64 | 20 | 0.6331412202718765 | 0.3394758163869767 |
| [256, 256] | 128 | 30 | 0.6712839529303881 | 0.33389214980261506 |
| [256, 256] | 256 | 30 | 0.5540993893818433 | 0.351866096767011 |
| [256, 256] | 512 | 50 | 0.5935589399428641 | 0.34611086832794663 |
| [256, 512] | 32 | 20 | 0.7667756411383152 | 0.28951877056564734 |
| [256, 512] | 64 | 20 | 0.6723740383135323 | 0.3231751521695537 |
| [256, 512] | 128 | 20 | 0.5835852779358826 | 0.34141275392319975 |
| [256, 512] | 256 | 30 | 0.589538591181283 | 0.3395809766998332 |
| [256, 512] | 512 | 40 | 0.5600888716127227 | 0.348120032840422 |
| [512, 32] | 32 | 20 | 0.724972855458449 | 0.3331882021343027 |
| [512, 32] | 64 | 20 | 0.6290953279435735 | 0.35320821518874235 |
| [512, 32] | 128 | 30 | 0.6578192164077843 | 0.3524640437088937 |
| [512, 32] | 256 | 40 | 0.633450302530175 | 0.3549363667600653 |
| [512, 32] | 512 | 50 | 0.5802491756886101 | 0.3621849431968058 |
| [512, 64] | 32 | 20 | 0.760050790762964 | 0.3258258091224857 |
| [512, 64] | 64 | 20 | 0.664849173125861 | 0.3368138381066152 |
| [512, 64] | 128 | 30 | 0.6985525721730933 | 0.3323136912577013 |
| [512, 64] | 256 | 30 | 0.5830843282883214 | 0.3495385625985171 |
| [512, 64] | 512 | 40 | 0.5525641062363739 | 0.35032644015852515 |
| [512, 128] | 32 | 20 | 0.8035276579904501 | 0.30421017152988805 |
| [512, 128] | 64 | 20 | 0.7077495836273445 | 0.3306045319715496 |
| [512, 128] | 128 | 20 | 0.6076135253060287 | 0.34387969258578266 |
| [512, 128] | 256 | 30 | 0.6239506935317523 | 0.34745580098039736 |
| [512, 128] | 512 | 40 | 0.5929474225987593 | 0.3474017983968875 |
| [512, 256] | 32 | 20 | 0.8629656049246619 | 0.27932241879825775 |
| [512, 256] | 64 | 20 | 0.7757517901360209 | 0.30086575690165407 |
| [512, 256] | 128 | 20 | 0.6550498778702409 | 0.3371205237782683 |
| [512, 256] | 256 | 30 | 0.6763416074242495 | 0.33425606157790966 |
| [512, 256] | 512 | 30 | 0.5533762703123488 | 0.35313271179873107 |
| [512, 512] | 32 | 20 | 0.887546556023914 | 0.2543909206381692 |
| [512, 512] | 64 | 20 | 0.8027367334752653 | 0.2935067540314058 |
| [512, 512] | 128 | 20 | 0.6972190916364553 | 0.31879209461900665 |
| [512, 512] | 256 | 30 | 0.7095946775338411 | 0.3246259588251714 |
| [512, 512] | 512 | 30 | 0.5796865137498517 | 0.34465830068644226 |



##### Weight Decay = 0.0001



| Hidden Sizes | Batch Size | Final Epoch | Training Pearson Score | Validation Pearson Score |
|---|---|---|---|---|
| [32, 32] | 32 | 40 | 0.505823422288478 | 0.35208704782400835 |
| [32, 32] | 64 | 50 | 0.49037698762052784 | 0.3518732695797378 |
| [32, 32] | 128 | 60 | 0.470264244455386 | 0.3521522512598751 |
| [32, 32] | 256 | 80 | 0.4592030937751869 | 0.3525995328811597 |
| [32, 32] | 512 | 120 | 0.4630796156424533 | 0.35408842483464453 |
| [32, 64] | 32 | 30 | 0.47709959975912886 | 0.3487666771085576 |
| [32, 64] | 64 | 40 | 0.4752297883881018 | 0.35049327884776027 |
| [32, 64] | 128 | 60 | 0.4879883541884056 | 0.35114715253843126 |
| [32, 64] | 256 | 70 | 0.45569665492945394 | 0.34818244420907407 |
| [32, 64] | 512 | 100 | 0.4554815156572567 | 0.3470677014110536 |
| [32, 128] | 32 | 30 | 0.5014691207632432 | 0.3419827451680263 |
| [32, 128] | 64 | 40 | 0.4949133688062435 | 0.3378985771953629 |
| [32, 128] | 128 | 40 | 0.4589822981030443 | 0.34776470134502113 |
| [32, 128] | 256 | 60 | 0.4611340980141366 | 0.34707911882342335 |
| [32, 128] | 512 | 90 | 0.46329305222778044 | 0.3463444973224244 |
| [32, 256] | 32 | 30 | 0.5285825236050112 | 0.34715347906720057 |
| [32, 256] | 64 | 30 | 0.48400379781324576 | 0.3569206194454497 |
| [32, 256] | 128 | 40 | 0.47660329005023033 | 0.35543008908011187 |
| [32, 256] | 256 | 60 | 0.48061403986874496 | 0.35310393310741733 |
| [32, 256] | 512 | 80 | 0.46383492480735516 | 0.35260414686113595 |
| [32, 512] | 32 | 20 | 0.5017216794418631 | 0.3234148204139296 |
| [32, 512] | 64 | 30 | 0.5170884655707663 | 0.3190219832875727 |
| [32, 512] | 128 | 40 | 0.5092333244831441 | 0.324760617247735 |
| [32, 512] | 256 | 50 | 0.48458488154870993 | 0.32990910140845187 |
| [32, 512] | 512 | 60 | 0.45695857130165946 | 0.33466880065703186 |
| [64, 32] | 32 | 30 | 0.5378038673854566 | 0.34321315927271673 |
| [64, 32] | 64 | 40 | 0.5280764677738512 | 0.34520544682154486 |
| [64, 32] | 128 | 50 | 0.5143859485372471 | 0.34680882767087834 |
| [64, 32] | 256 | 60 | 0.4822791433762219 | 0.34945289402224206 |
| [64, 32] | 512 | 80 | 0.4692163868413653 | 0.3509332187376666 |
| [64, 64] | 32 | 30 | 0.5436982980817271 | 0.3431338406224313 |
| [64, 64] | 64 | 40 | 0.5339999811709546 | 0.3519879363889759 |
| [64, 64] | 128 | 50 | 0.5149696707764282 | 0.3549081252467525 |
| [64, 64] | 256 | 60 | 0.48188652386026043 | 0.35923471844303356 |
| [64, 64] | 512 | 90 | 0.4871774318230596 | **0.36183749626325623** |
| [64, 128] | 32 | 20 | 0.5243103131735792 | 0.3455945671763587 |
| [64, 128] | 64 | 30 | 0.5346341781338338 | 0.3436971162953363 |
| [64, 128] | 128 | 40 | 0.5284548595250704 | 0.3446523690665402 |
| [64, 128] | 256 | 50 | 0.5021094530552034 | 0.3462620745361478 |
| [64, 128] | 512 | 70 | 0.5002772332530007 | 0.34718163621904746 |
| [64, 256] | 32 | 30 | 0.5929278814458947 | 0.3394668484428862 |
| [64, 256] | 64 | 30 | 0.535846710060695 | 0.3507148278797935 |
| [64, 256] | 128 | 40 | 0.5350204736291667 | 0.3457623304760307 |
| [64, 256] | 256 | 50 | 0.5094068574318379 | 0.35025242085051933 |
| [64, 256] | 512 | 70 | 0.5065187444431162 | 0.3543220305945619 |
| [64, 512] | 32 | 20 | 0.5595211165049472 | 0.32981630754678 |
| [64, 512] | 64 | 30 | 0.5834017053607854 | 0.3139179290484565 |
| [64, 512] | 128 | 30 | 0.5238281858468297 | 0.34077616381414644 |
| [64, 512] | 256 | 40 | 0.5080456320927217 | 0.34481656429461244 |
| [64, 512] | 512 | 50 | 0.48191985814141547 | 0.35232497637352694 |
| [128, 32] | 32 | 30 | 0.6044275215821936 | 0.3399519333804215 |
| [128, 32] | 64 | 40 | 0.5997000720634801 | 0.3406443143505809 |
| [128, 32] | 128 | 50 | 0.5829624394096283 | 0.3383670620542998 |
| [128, 32] | 256 | 60 | 0.5405331625738093 | 0.3438474835792038 |
| [128, 32] | 512 | 90 | 0.5481624258120793 | 0.3446329544838711 |
| [128, 64] | 32 | 20 | 0.5556817407754253 | 0.33618165147655726 |
| [128, 64] | 64 | 30 | 0.5801433100392112 | 0.33273598840422564 |
| [128, 64] | 128 | 30 | 0.5078993628051661 | 0.34571555410261684 |
| [128, 64] | 256 | 50 | 0.5391542890050969 | 0.34325174249332846 |
| [128, 64] | 512 | 60 | 0.498659265686346 | 0.3495958844519912 |
| [128, 128] | 32 | 20 | 0.6001861859642373 | 0.3368470765369504 |
| [128, 128] | 64 | 30 | 0.624043189387397 | 0.3313504087841685 |
| [128, 128] | 128 | 30 | 0.5475175866017413 | 0.34400055129757573 |
| [128, 128] | 256 | 40 | 0.5293474984155755 | 0.3513873473385241 |
| [128, 128] | 512 | 60 | 0.5421546959775038 | 0.3494712335975795 |
| [128, 256] | 32 | 20 | 0.6151176142804811 | 0.3251729306628442 |
| [128, 256] | 64 | 20 | 0.5491104008118843 | 0.3400793673387446 |
| [128, 256] | 128 | 30 | 0.5672375124967568 | 0.33468287465104046 |
| [128, 256] | 256 | 40 | 0.5504070160997051 | 0.34082059111788776 |
| [128, 256] | 512 | 50 | 0.517177591108635 | 0.3454033249030222 |
| [128, 512] | 32 | 20 | 0.6517414917381634 | 0.3183286374878705 |
| [128, 512] | 64 | 20 | 0.580639127025956 | 0.32650362023209517 |
| [128, 512] | 128 | 30 | 0.601829808285818 | 0.335544632487819 |
| [128, 512] | 256 | 40 | 0.5794158134166872 | 0.3396870438533251 |
| [128, 512] | 512 | 50 | 0.5478086681364609 | 0.3457338496297844 |
| [256, 32] | 32 | 20 | 0.6197104444753478 | 0.3457144986438508 |
| [256, 32] | 64 | 30 | 0.6479059767131523 | 0.34410697675324825 |
| [256, 32] | 128 | 40 | 0.6384299078439426 | 0.35142831228043114 |
| [256, 32] | 256 | 50 | 0.6065169461964632 | 0.3513422641505353 |
| [256, 32] | 512 | 60 | 0.5514046360549553 | 0.35899293053904735 |
| [256, 64] | 32 | 20 | 0.6469998813323583 | 0.339130371223262 |
| [256, 64] | 64 | 20 | 0.5691579735056576 | 0.3534750219601805 |
| [256, 64] | 128 | 30 | 0.5881508686069298 | 0.35202843809757633 |
| [256, 64] | 256 | 40 | 0.5615996261288695 | 0.354230696573085 |
| [256, 64] | 512 | 50 | 0.5252788338517802 | 0.35990926277337176 |
| [256, 128] | 32 | 20 | 0.7113982433313375 | 0.3279882677784439 |
| [256, 128] | 64 | 20 | 0.6223169795241147 | 0.34297085467706845 |
| [256, 128] | 128 | 30 | 0.6491941516330276 | 0.33177493082814763 |
| [256, 128] | 256 | 40 | 0.62604762689492 | 0.3411842514899606 |
| [256, 128] | 512 | 50 | 0.5780201059615827 | 0.3500117003196546 |
| [256, 256] | 32 | 20 | 0.7424329893375285 | 0.31506377945604574 |
| [256, 256] | 64 | 20 | 0.649250723005728 | 0.3305695927538402 |
| [256, 256] | 128 | 30 | 0.6852883806076323 | 0.3219414669690295 |
| [256, 256] | 256 | 30 | 0.5724555604125487 | 0.34732668237950215 |
| [256, 256] | 512 | 40 | 0.5483990312076282 | 0.34773191770685885 |
| [256, 512] | 32 | 20 | 0.7824189213621108 | 0.279793776364684 |
| [256, 512] | 64 | 20 | 0.6896720593389389 | 0.3101808975229983 |
| [256, 512] | 128 | 20 | 0.60338926650279 | 0.3356336502917889 |
| [256, 512] | 256 | 30 | 0.6097007683116058 | 0.3326968722130576 |
| [256, 512] | 512 | 40 | 0.587149730240542 | 0.33950993423713 |
| [512, 32] | 32 | 20 | 0.7461743042275069 | 0.321446278168305 |
| [512, 32] | 64 | 20 | 0.644913366212597 | 0.3483031615777772 |
| [512, 32] | 128 | 30 | 0.6724302084444705 | 0.34533608780189556 |
| [512, 32] | 256 | 40 | 0.6477513266681929 | 0.352426726791974 |
| [512, 32] | 512 | 50 | 0.5935091284770635 | 0.35900178315333175 |
| [512, 64] | 32 | 20 | 0.7820541398530194 | 0.3166445748926883 |
| [512, 64] | 64 | 20 | 0.6818484009843438 | 0.32909293533898454 |
| [512, 64] | 128 | 30 | 0.7149314461115528 | 0.33076436041895546 |
| [512, 64] | 256 | 30 | 0.5962775215284128 | 0.3492206727353766 |
| [512, 64] | 512 | 40 | 0.5676347891241046 | 0.3508315956025261 |
| [512, 128] | 32 | 20 | 0.8229149134376538 | 0.3001021965171695 |
| [512, 128] | 64 | 20 | 0.7267707987734563 | 0.32392133197165146 |
| [512, 128] | 128 | 20 | 0.624037726777915 | 0.342007016836002 |
| [512, 128] | 256 | 30 | 0.6413814614930899 | 0.3446103049980825 |
| [512, 128] | 512 | 40 | 0.6138212565348751 | 0.3426753157705122 |
| [512, 256] | 32 | 20 | 0.879604461892583 | 0.25754215369093464 |
| [512, 256] | 64 | 20 | 0.7908816543354489 | 0.2827369562956576 |
| [512, 256] | 128 | 20 | 0.6747287650493594 | 0.32346635099623494 |
| [512, 256] | 256 | 30 | 0.6949852212846929 | 0.3155644450786335 |
| [512, 256] | 512 | 30 | 0.5723789201320868 | 0.3444996081916982 |
| [512, 512] | 32 | 20 | 0.8996420723531632 | 0.23109622597671117 |
| [512, 512] | 64 | 20 | 0.8183769336481277 | 0.26663308474633246 |
| [512, 512] | 128 | 20 | 0.7161010860517311 | 0.30606238277378073 |
| [512, 512] | 256 | 20 | 0.5947669945351728 | 0.34053336811274176 |
| [512, 512] | 512 | 30 | 0.6081618050928611 | 0.3376470592289059 |



## Results

In [0]:
import os
from google.colab import files
from zipfile import ZipFile

def writeScores(scores):
    fn = "predictions.txt"
    print("")
    with open(fn, 'w') as output_file:
        for idx,x in enumerate(scores):
            #out =  metrics[idx]+":"+str("{0:.2f}".format(x))+"\n"
            #print(out)
            output_file.write(f"{x}\n")


def downloadScores(method_name, scores):
  writeScores(scores)
  with ZipFile(f"en-zh_{method_name}.zip", "w") as newzip:
    newzip.write("predictions.txt")
  
  files.download(f"en-zh_{method_name}.zip")