In [97]:
from copy import deepcopy

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import confusion_matrix, accuracy_score

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [2]:
# Directory holding the twpos-*.tsv corpus files. The trailing slash matters:
# file names are appended to it by plain string concatenation.
DATADIR = 'data/'

In [44]:
def read_corpus(file):
    """
    Read a tagged corpus file.

    The file holds one "token<TAB>label" pair per line; sentences are
    separated by a blank line.

    args:
        file: path of the corpus file (read as UTF-8)
    returns:
        lines: [['hello', 'world'], ...]  one token list per sentence
        labels: ['!', 'N', ...]  flat list, one label per token, in the same
                order as the tokens appear in the file
        vocab: set of every token seen
    raises:
        ValueError: if a non-empty line does not consist of exactly two
                    tab-separated fields
    """
    # Explicit encoding so the result does not depend on the platform default.
    with open(file, 'rt', encoding='utf-8') as f:
        text = f.read()

    ret_lines = []
    labels = []
    vocab = set()
    for sentence in text.split('\n\n'):
        curr_line = []
        for token_label_str in sentence.split('\n'):
            if not token_label_str:
                continue
            token, label = token_label_str.split('\t')
            vocab.add(token)
            labels.append(label)
            curr_line.append(token)
        # Extra blank lines (e.g. at the end of the file) must not produce
        # empty sentences.
        if curr_line:
            ret_lines.append(curr_line)
    return ret_lines, labels, vocab

In [25]:
def encode_lines(lines, word2idx_map, window_size):
    """
    Encode every token of every line as a window of word indices.

    Each line is padded with `window_size` copies of '<s>' on the left and of
    '</s>' on the right, so that every token has a full context window. Words
    missing from `word2idx_map` are mapped to the index of 'UUUNKKK'.

    args:
        lines: [['hello', 'world'], ...]
        word2idx_map: dict word -> index; must contain '<s>', '</s>' and
                      'UUUNKKK'
        window_size: number of context words on each side of the center word
    returns:
        X: LongTensor, (total number of tokens) x (2 * window_size + 1),
           one row per token in corpus order; shape (0, 2 * window_size + 1)
           when there are no tokens
    """
    # Loop-invariant lookups, done once instead of per word / per line.
    unk_idx = word2idx_map['UUUNKKK']
    start = [word2idx_map['<s>']] * window_size
    end = [word2idx_map['</s>']] * window_size
    width = 2 * window_size + 1

    def encode_line(line):
        # numerical representation, padded with start and end tokens
        padded = start + [word2idx_map.get(word, unk_idx) for word in line] + end
        # the window starting at padded[i] is centered on line[i]
        return [padded[i : i + width] for i in range(len(line))]

    res = []
    for line in lines:
        res.extend(encode_line(line))
    # Explicit dtype and shape keep the empty case a long (0, width) matrix
    # rather than a float vector of shape (0,).
    return torch.tensor(res, dtype=torch.long).reshape(len(res), width)

# Network

In [172]:
class FeedForwardTagger(nn.Module):
    """
    Window-based feed-forward tagger.

    The embeddings of all words in a window are concatenated, passed through
    one tanh hidden layer of 128 units and projected to log-probabilities
    over `output_dim` labels.
    """

    def __init__(self, vocab_size, window_size, output_dim,
                 emb_dim=50, pretrained_emb=None, freeze=False):
        """
        args:
            vocab_size: number of rows of the embedding table (ignored when
                        `pretrained_emb` is given)
            window_size: number of context words on each side, so the model
                         expects inputs with 2 * window_size + 1 columns
            output_dim: number of output labels
            emb_dim: embedding size (ignored when `pretrained_emb` is given;
                     the width of the pretrained matrix is used instead)
            pretrained_emb: optional 2-D float matrix (vocab x dim) used to
                            initialise the embedding table
            freeze: if True the embedding weights are not updated by training
        """
        super(FeedForwardTagger, self).__init__()
        # `is not None`: the truth value of a multi-element tensor is ambiguous.
        if pretrained_emb is not None:
            weights = torch.as_tensor(pretrained_emb, dtype=torch.float)
            self.emb = nn.Embedding.from_pretrained(weights, freeze=freeze)
            # The first linear layer must match the pretrained vector size.
            emb_dim = self.emb.embedding_dim
        else:
            self.emb = nn.Embedding(vocab_size, emb_dim)
            torch.nn.init.uniform_(self.emb.weight, -0.01, 0.01)
            self.emb.weight.requires_grad_(not freeze)
        input_dim = (2 * window_size + 1) * emb_dim
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, output_dim)

    def forward(self, inputs):
        """
        args:
            inputs: LongTensor of word indices, batch x (2 * window_size + 1)
        returns:
            log-probabilities, batch x output_dim
        """
        # concatenate the embeddings of each window into one vector per row
        embeds = self.emb(inputs).view((inputs.shape[0], -1))
        out = torch.tanh(self.fc1(embeds))
        out = self.fc2(out)
        log_probs = F.log_softmax(out, dim=1)
        return log_probs

In [160]:
def train_util(model, X_train, Y_train, X_dev, Y_dev, n_epochs, lr, 
              batch_size):
    """
    Train `model` with plain SGD and report accuracy after every epoch.

    Mini-batches are consecutive slices of the training data (no shuffling).
    After each epoch the epoch index, the summed loss and the train / dev
    accuracies are printed.

    args:
        model: module mapping a batch of X rows to per-class scores
        X_train, Y_train: training inputs and integer class targets (tensors)
        X_dev, Y_dev: held-out inputs and targets, used for evaluation only
        n_epochs: number of passes over the training data
        lr: SGD learning rate
        batch_size: number of rows per mini-batch
    returns:
        losses: per epoch, the sum of the mean loss of every mini-batch
        train_accu_list: per epoch accuracy on the full training set
        dev_accu_list: per epoch accuracy on the dev set
    """
    # CrossEntropyLoss normalises the scores itself, so it works both for raw
    # scores and for a model that already returns log-probabilities.
    loss_func = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr)

    losses = []
    train_accu_list, dev_accu_list = [], []
    for epoch in range(n_epochs):
        epoch_loss = 0

        for i in range(0, X_train.shape[0], batch_size):
            optimizer.zero_grad()
            log_probs = model(X_train[i : i + batch_size])
            loss = loss_func(log_probs, Y_train[i : i + batch_size])
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        # Evaluation needs no gradients; without no_grad() the full-dataset
        # forward passes would build and keep autograd graphs.
        with torch.no_grad():
            train_preds = torch.argmax(model(X_train), dim=1)
            # evaluate on dev
            dev_preds = torch.argmax(model(X_dev), dim=1)
        train_accu = accuracy_score(Y_train, train_preds)
        dev_accu = accuracy_score(Y_dev, dev_preds)

        print(epoch, epoch_loss, train_accu, dev_accu)
        losses.append(epoch_loss)
        train_accu_list.append(train_accu)
        dev_accu_list.append(dev_accu)

    return losses, train_accu_list, dev_accu_list

# Load Data

In [45]:
# Each split yields (tokenised sentences, flat label list, token vocabulary).
(
    (train, train_labels, train_vocab),
    (dev, dev_labels, dev_vocab),
    (devtest, devtest_labels, devtest_vocab),
) = (
    read_corpus(DATADIR + fname)
    for fname in ('twpos-train.tsv', 'twpos-dev.tsv', 'twpos-devtest.tsv')
)

In [75]:
# Fit the label <-> integer mapping on the training labels only, then encode
# the labels of every split as a long tensor of class ids.
label_encoder = LabelEncoder()
label_encoder.fit(list(set(train_labels)))
Y_train, Y_dev, Y_devtest = (
    torch.tensor(label_encoder.transform(split_labels), dtype=torch.long)
    for split_labels in (train_labels, dev_labels, devtest_labels)
)

In [55]:
# Full vocabulary: every token that occurs in any of the three splits.
vocab = train_vocab | dev_vocab | devtest_vocab

# 1. Baseline w/ Randomly Initialized Embeddings

In [11]:
# Index layout for the randomly initialised embeddings: corpus words in sorted
# order, followed by the start / end padding symbols and the unknown-word symbol.
idx2word_rand = [*sorted(vocab), '<s>', '</s>', 'UUUNKKK']
word2idx_rand = dict(zip(idx2word_rand, range(len(idx2word_rand))))

## Encode Train, Dev, DevTest

In [29]:
# w = 0: each token is encoded on its own, with no context words
X_train_w0, X_dev_w0, X_devtest_w0 = (
    encode_lines(split, word2idx_rand, window_size=0)
    for split in (train, dev, devtest)
)

In [137]:
# w = 1: each token plus one context word on either side
X_train_w1, X_dev_w1, X_devtest_w1 = (
    encode_lines(split, word2idx_rand, window_size=1)
    for split in (train, dev, devtest)
)

## Train Model

### w = 0

In [154]:
# w = 0, per-example SGD (batch_size=1).
# The output layer gets one unit per class of the fitted label encoder, i.e.
# exactly the classes the Y_* targets were encoded with.
model = FeedForwardTagger(vocab_size=len(word2idx_rand), 
                          window_size=0,
                          output_dim=len(label_encoder.classes_))
_ = train_util(model, X_train_w0, Y_train, X_dev_w0, Y_dev, n_epochs=50,
              lr=0.02, batch_size=1)

0 26318.854907561443 0.6308231173380034 0.6191661481020535
1 19847.687172139587 0.6945709281961471 0.6672889442024477
2 17140.754573532566 0.7364273204903677 0.6913503422526447
3 15147.710087355106 0.7639813193228254 0.705247873885086
4 13384.72553572192 0.7875656742556918 0.7118855009334163
5 11717.211856766744 0.8145942790426153 0.7251607550300767
6 10178.92038353162 0.8334500875656743 0.7334577888404895
7 8734.392380004963 0.8512551079976649 0.7373988799004356
8 7502.108096102407 0.8659077641564507 0.7440365069487658
9 6518.994200649857 0.8785755983654407 0.7463181912466293
10 5772.929299781988 0.8897256275539989 0.7498444306160548


KeyboardInterrupt: 

In [169]:
# w = 0, mini-batches of 1000 rows.
# The output layer gets one unit per class of the fitted label encoder, i.e.
# exactly the classes the Y_* targets were encoded with.
model = FeedForwardTagger(vocab_size=len(word2idx_rand), 
                          window_size=0,
                          output_dim=len(label_encoder.classes_))
_ = train_util(model, X_train_w0, Y_train, X_dev_w0, Y_dev, n_epochs=100,
              lr=2, batch_size=1000)

0 37.45856440067291 0.47542323409223586 0.4665007259904584
1 28.97168219089508 0.5270286047869235 0.5179423356150177
2 26.855947971343994 0.553298307063631 0.5390997718315702
3 25.546590089797974 0.5666666666666667 0.5471893797967227
4 24.535447239875793 0.5816112084063048 0.5648205766438498
5 23.65640115737915 0.5928196147110333 0.5770587015142087
6 22.863015830516815 0.6021599532983071 0.587637419622485
7 22.13068777322769 0.6159369527145359 0.6015349512549264
8 21.44322830438614 0.626444833625219 0.6100394109105994
9 20.789773643016815 0.6359603035610041 0.6183364447210122
10 20.163319408893585 0.6489784004670169 0.6295374403650695
11 19.56081885099411 0.6612375948628137 0.6448869529143332
12 18.982388079166412 0.6704611792177466 0.6511097282721428
13 18.42963135242462 0.6810858143607705 0.6575399294752126
14 17.904225826263428 0.6890251021599533 0.6637627048330222
15 17.407002925872803 0.6974314068884997 0.6699854801908318
16 16.93762093782425 0.7078225335668418 0.6747562746318191


In [173]:
# w = 1, mini-batches of 1000 rows, lr=1.
# The output layer gets one unit per class of the fitted label encoder, i.e.
# exactly the classes the Y_* targets were encoded with.
model = FeedForwardTagger(vocab_size=len(word2idx_rand), 
                          window_size=1,
                          output_dim=len(label_encoder.classes_))
_ = train_util(model, X_train_w1, Y_train, X_dev_w1, Y_dev, n_epochs=100,
              lr=1, batch_size=1000)

0 51.440529108047485 0.19573847051955634 0.194150591163659
1 48.023465394973755 0.2266199649737303 0.2175897116780751
2 46.34287214279175 0.35288966725043786 0.342874922215308
3 41.87749195098877 0.4299474605954466 0.42086704003318814
4 35.92560696601868 0.5002335084646818 0.49118440157643645
5 30.338786840438843 0.5723876240513719 0.562331466500726
6 26.115234375 0.6063631056625802 0.5924082140634723
7 23.78149378299713 0.6359019264448337 0.6243517942335615
8 22.21266460418701 0.6507880910683013 0.6357602157228791
9 20.879288971424103 0.6665499124343257 0.6486206181290188
10 19.59943276643753 0.6854640980735551 0.6635552789877619
11 18.327219784259796 0.7091068301225919 0.685749844430616
12 17.091582477092743 0.7260945709281962 0.6986102468367559
13 15.919070184230804 0.7434909515469936 0.7164488695291433
14 14.836877554655075 0.7580268534734385 0.7241236258037751
15 13.847228109836578 0.773029772329247 0.7332503629952292
16 12.93113088607788 0.7892586106246351 0.7446587844845468
17 1

In [170]:
# w = 1, mini-batches of 1000 rows, lr=2.
# The output layer gets one unit per class of the fitted label encoder, i.e.
# exactly the classes the Y_* targets were encoded with.
model = FeedForwardTagger(vocab_size=len(word2idx_rand), 
                          window_size=1,
                          output_dim=len(label_encoder.classes_))
_ = train_util(model, X_train_w1, Y_train, X_dev_w1, Y_dev, n_epochs=50,
              lr=2, batch_size=1000)

0 36.602572202682495 0.5162288382953882 0.5131715411740303
1 26.249264121055603 0.5722124927028605 0.5517527483924497
2 23.507112979888916 0.6068301225919439 0.5830740510267579
3 21.791509926319122 0.626970227670753 0.6052686164696122
4 20.53831309080124 0.6461179217746643 0.6185438705662726
5 19.536734640598297 0.6614711033274956 0.6274631819124663
6 18.686891198158264 0.6754816112084063 0.6353453640323584
7 17.93681228160858 0.688733216579101 0.6465463596764157
8 17.257445007562637 0.699883245767659 0.6538062642605269
9 16.629733592271805 0.7128429655575015 0.6631404272972412
10 16.038014620542526 0.7234676007005254 0.6708151835718731
11 15.508488088846207 0.7300058377116171 0.6797344949180668
12 15.255434840917587 0.7430239346176299 0.6842978635137938
13 14.544320568442345 0.7536485697606539 0.6840904376685335
14 14.012437254190445 0.7559252772913018 0.6826384567517113
15 13.895058989524841 0.7646234676007005 0.6884463804190002
16 13.057566717267036 0.7789842381786339 0.697158265919

In [163]:
# w = 2: each token plus two context words on either side
X_train_w2, X_dev_w2, X_devtest_w2 = (
    encode_lines(split, word2idx_rand, window_size=2)
    for split in (train, dev, devtest)
)

In [171]:
# w = 2, full-batch gradient descent (one update per epoch).
# The output layer gets one unit per class of the fitted label encoder, i.e.
# exactly the classes the Y_* targets were encoded with.
model = FeedForwardTagger(vocab_size=len(word2idx_rand), 
                          window_size=2,
                          output_dim=len(label_encoder.classes_))
_ = train_util(model, X_train_w2, Y_train, X_dev_w2, Y_dev, n_epochs=500,
              lr=1, batch_size=X_train_w2.shape[0])

0 3.2408864498138428 0.1950379451255108 0.18481642812694463
1 3.0854387283325195 0.29422066549912435 0.2750466708151836
2 2.9423320293426514 0.34098073555166375 0.3194358017008919
3 2.8048031330108643 0.36065382370110916 0.34204521883426675
4 2.674838066101074 0.3745475773496789 0.35573532462144786
5 2.5570180416107178 0.39136018680677176 0.37295166977805433
6 2.4515624046325684 0.40461179217746646 0.3887160340178386
7 2.356186866760254 0.4204319906596614 0.4048952499481435
8 2.2698733806610107 0.4332165791009924 0.4206596141879278
9 2.192101001739502 0.4447752481027437 0.43808338518979467
10 2.1223580837249756 0.4558085230589609 0.45218834266749636
11 2.060042142868042 0.4649737302977233 0.4627670607757727
12 2.0045011043548584 0.4740805604203152 0.4704418170504045
13 1.9550763368606567 0.4807355516637478 0.4764571665629537
14 1.9110753536224365 0.4871570344424985 0.4833022194565443
15 1.8718129396438599 0.49281961471103325 0.4866210329807094
16 1.8366429805755615 0.4995329830706363 0

138 1.0479636192321777 0.6842381786339755 0.6388716034017838
139 1.0458006858825684 0.6851138353765324 0.6386641775565235
140 1.043658971786499 0.6856392294220666 0.6386641775565235
141 1.0415271520614624 0.6869819030939872 0.6386641775565235
142 1.039417028427124 0.687215411558669 0.6384567517112633
143 1.0373133420944214 0.6877991827203737 0.6388716034017838
144 1.0352345705032349 0.6884997081144192 0.6386641775565235
145 1.0331683158874512 0.6890251021599533 0.6386641775565235
146 1.031112551689148 0.6892586106246351 0.6388716034017838
147 1.0290756225585938 0.6899591360186806 0.6394938809375649
148 1.0270510911941528 0.6905429071803852 0.6394938809375649
149 1.0250391960144043 0.6911266783420899 0.6397013067828251
150 1.0230484008789062 0.691768826619965 0.6394938809375649
151 1.0210626125335693 0.6922358435493287 0.6394938809375649
152 1.019087791442871 0.6926444833625219 0.6399087326280855
153 1.017130970954895 0.6931115002918856 0.6399087326280855
154 1.0151886940002441 0.693461

274 0.830418586730957 0.7461762988908348 0.6618958722256794
275 0.8290773630142212 0.7464681844716871 0.6618958722256794
276 0.8277413249015808 0.746584938704028 0.6623107239161999
277 0.8264058232307434 0.7470519556333917 0.6623107239161999
278 0.8250706791877747 0.7472854640980735 0.6625181497614603
279 0.8237396478652954 0.7476357267950963 0.6625181497614603
280 0.8224065899848938 0.7480443666082895 0.6633478531425016
281 0.8210771083831787 0.7482778750729714 0.6637627048330222
282 0.8197477459907532 0.7487448920023351 0.6639701306782825
283 0.8184197545051575 0.7491535318155284 0.6639701306782825
284 0.8170944452285767 0.749737302977233 0.6635552789877619
285 0.8157685399055481 0.7503210741389376 0.6635552789877619
286 0.8144422173500061 0.7507880910683012 0.6637627048330222
287 0.8131142854690552 0.7512551079976649 0.6643849823688032
288 0.8117927312850952 0.7514302393461763 0.6645924082140635
289 0.8104709982872009 0.7518972562755399 0.6645924082140635
290 0.8091483116149902 0.75

410 0.6512418985366821 0.8082895504962054 0.6776602364654636
411 0.6499039530754089 0.8087565674255692 0.6778676623107239
412 0.6485701203346252 0.8092819614711033 0.6780750881559843
413 0.647233247756958 0.8101576182136603 0.6780750881559843
414 0.6458925604820251 0.8106830122591944 0.677245384774943
415 0.644554078578949 0.8109165207238762 0.6776602364654636
416 0.6432172060012817 0.8113251605370695 0.6774528106202032
417 0.6418792009353638 0.8117338003502627 0.6778676623107239
418 0.6405388116836548 0.8119089316987741 0.6778676623107239
419 0.639200747013092 0.8124927028604787 0.6776602364654636
420 0.6378592252731323 0.8129597197898424 0.6776602364654636
421 0.6365196704864502 0.8132516053706947 0.6778676623107239
422 0.6351797580718994 0.8136018680677175 0.6789047915370255
423 0.6338367462158203 0.8141856392294221 0.6789047915370255
424 0.6324960589408875 0.8145942790426153 0.6782825140012445
425 0.6311555504798889 0.8153531815528313 0.6782825140012445
426 0.6298137307167053 0.816

In [None]:
# construct maps for pretrained word embs