In [1]:
import pickle, argparse, os, sys
from sklearn.metrics import accuracy_score
import numpy as np
import random
import torch
import torch.nn as nn
import torch.functional as F
from collections import defaultdict
import time
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score
from helpers import *

# Train

In [2]:
def train(training_file, glove_path='./glove.42B.300d.txt', num_epochs=40,
          batch_size=128, lr=1e-3):
    """Train the LSTM tagger and return everything needed to reload it.

    Args:
        training_file: path to the training data file; must exist.
        glove_path: path to the GloVe vectors file handed to ``get_glove``.
        num_epochs: number of passes over the training data.
        batch_size: number of examples per batch in the DataLoader.
        lr: learning rate for the Adam optimizer.

    Returns:
        dict with the keys ``'state_dict'`` (model weights, moved to CPU so
        the checkpoint can be loaded on any machine), ``'word_2_idx'`` and
        ``'tag_mapper'``.

    Raises:
        AssertionError: if ``training_file`` does not exist.
    """
    assert os.path.isfile(training_file), 'Training file does not exist'

    # Your code starts here
    # test() rebuilds the model with these exact sizes (300 and 128), so they
    # are kept as fixed constants here instead of being parameters.
    embedding_dim = 300
    hidden_dim = 128

    # load the data, vocabulary, lookups and other things - all here in this function
    sentences, voc, _idx_2_word, word_2_idx, tag_mapper, target_size = get_training_data(training_file)

    # now load the embeddings
    weights_matrix = get_glove(glove_path, embedding_dim, len(voc) + 1, word_2_idx)  # +1 because of PAD token

    # fall back to the CPU when no GPU is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # init the model
    model = LSTMTagger(weights_matrix, hidden_dim, target_size)
    model.apply(init_weights)
    model = model.to(device)

    # construct the dataset and dataloader
    ds = WSJDataset(sentences, word_2_idx, tag_mapper)
    dl = DataLoader(ds, batch_size=batch_size, collate_fn=collate_examples)

    # cross-entropy and adam
    # ignore_index=0 makes the loss skip targets equal to 0
    # (presumably the PAD tag -- confirm against collate_examples / tag_mapper)
    criterion = torch.nn.CrossEntropyLoss(ignore_index=0)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # train for num_epochs epochs; the per-epoch metrics returned by
    # train_epoch are not used here
    for epoch in range(num_epochs):
        train_epoch(model, dl, optimizer, criterion, epoch)

    # Move the weights to the CPU before handing them back: a checkpoint that
    # holds CUDA tensors cannot be loaded on a CPU-only machine without
    # map_location.
    state_dict = {name: tensor.cpu() for name, tensor in model.state_dict().items()}
    to_serialize = {'state_dict': state_dict, 'word_2_idx': word_2_idx, 'tag_mapper': tag_mapper}
    # Your code ends here

    return to_serialize

In [3]:
# Run training and persist the result to disk.
# NOTE: despite its name, `model` is the dict returned by train()
# (keys 'state_dict', 'word_2_idx', 'tag_mapper'), not an nn.Module.
model = train('data/wsj1-18.training')
torch.save(model, 'model.torch')

Found 1807092 word vectors in glove.
embed_matrix.shape (16927, 300)
11689 words are found in glove
Epoch: 001/010 | Batch 000/299 | Cost: 3.8495 | : Accuracy: 0.0206
Epoch: 001/010 | Batch 100/299 | Cost: 3.4092 | : Accuracy: 0.4920
Epoch: 001/010 | Batch 200/299 | Cost: 3.3003 | : Accuracy: 0.5906
Epoch: 002/010 | Batch 000/299 | Cost: 3.2309 | : Accuracy: 0.6570
Epoch: 002/010 | Batch 100/299 | Cost: 3.1930 | : Accuracy: 0.6928
Epoch: 002/010 | Batch 200/299 | Cost: 3.1803 | : Accuracy: 0.7043
Epoch: 003/010 | Batch 000/299 | Cost: 3.1771 | : Accuracy: 0.7072
Epoch: 003/010 | Batch 100/299 | Cost: 3.1657 | : Accuracy: 0.7188
Epoch: 003/010 | Batch 200/299 | Cost: 3.1553 | : Accuracy: 0.7289
Epoch: 004/010 | Batch 000/299 | Cost: 3.0862 | : Accuracy: 0.8017
Epoch: 004/010 | Batch 100/299 | Cost: 3.1080 | : Accuracy: 0.7774
Epoch: 004/010 | Batch 200/299 | Cost: 3.0785 | : Accuracy: 0.8065
Epoch: 005/010 | Batch 000/299 | Cost: 3.0718 | : Accuracy: 0.8135
Epoch: 005/010 | Batch 100/29

# Predict

In [4]:
def test(model_file, data_file, label_file):
    assert os.path.isfile(model_file), 'Model file does not exist'
    assert os.path.isfile(data_file), 'Data file does not exist'
    assert os.path.isfile(label_file), 'Label file does not exist'
    
    model_and_co = torch.load('model.torch')
    state_dict = model_and_co['state_dict']
    word_2_idx = model_and_co['word_2_idx']
    tag_mapper = model_and_co['tag_mapper']
    
    # Your code starts here
    valid_sentences = make_validation_pairs(data_file, label_file)
    
    valid_ds = WSJDataset(valid_sentences, word_2_idx, tag_mapper)
    valid_dl = DataLoader(valid_ds, batch_size=32, collate_fn=collate_examples)
    
    model = LSTMTagger(np.zeros((len(word_2_idx), 300)), 128, len(tag_mapper)) # doesn't matter zeros because we load later
    model.load_state_dict(state_dict)
    model = model.to(device) 
    
    # for the conviniency I get ground truth through the dataset as well
    prediction, ground_truth = make_predictions(model, valid_dl)
    prediction = np.concatenate(prediction)
    ground_truth = np.concatenate(ground_truth)

    print(f'The accuracy of the model is {100*accuracy_score(prediction, ground_truth):6.2f}%')

In [5]:
# Evaluate the saved checkpoint on the testing data against its ground-truth
# labels; test() prints the resulting accuracy.
test('model.torch', 'data/wsj19-21.testing', 'data/wsj19-21.truth')

The accuracy of the model is  89.04%
