In [1]:
# declare a list tasks whose products you want to use as inputs
upstream = None


In [2]:
# Parameters
epochs = 10
product = {
    "nb": "/Users/mboussarov/_umsi/Capstone/umads_697_data_medics/pipeline/output/nn_classification_analysis.ipynb"
}


## Classification Comparison Using Neural Networks

In [3]:
import os
import pandas as pd
import sys
import spacy
from spacy.language import Language
import time
import torch
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import DataLoader
from torch import nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
from torch.nn import functional as F
from torch.optim import Adam
from sklearn.metrics import accuracy_score, f1_score

In [4]:
upstream = []
nrows = None
epochs = 10

In [5]:
sys.path.insert(0, "..")

# project imports
import locations as loc

from nn_models import TweetClassificationLSTM, TweetClassificationEmbedder

# run model on gpu if possible
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# location of saved lemmatized texts
lemma_sents_file_location = os.path.join(loc.outputs, "train_lemma_sents.csv")
dev_lemma_sents_file_location = os.path.join(loc.outputs, "dev_lemma_sents.csv")
test_lemma_sents_file_location = os.path.join(loc.outputs, "test_lemma_sents.csv")

# if true, lemmatize the raw tweets
# if false, load the saved, processed tweets (assumes the lemmatized tweets already exist)
train_exists = os.path.exists(lemma_sents_file_location)
dev_exists = os.path.exists(dev_lemma_sents_file_location)
test_exists = os.path.exists(test_lemma_sents_file_location)

lemmatize_texts = any(not i for i in [train_exists, dev_exists, test_exists])

In [6]:
# load a spacy language model
nlp = spacy.load("en_core_web_sm")
stopwords = nlp.Defaults.stop_words

In [7]:
# here are the transformations the spacy nlp object will perform on every doc
nlp.pipeline

[('tok2vec', <spacy.pipeline.tok2vec.Tok2Vec at 0x7fbdae69d9a0>),
 ('tagger', <spacy.pipeline.tagger.Tagger at 0x7fbdae8e30a0>),
 ('parser', <spacy.pipeline.dep_parser.DependencyParser at 0x7fbdae58aac0>),
 ('attribute_ruler',
  <spacy.pipeline.attributeruler.AttributeRuler at 0x7fbdae8d1900>),
 ('lemmatizer',
  <spacy.lang.en.lemmatizer.EnglishLemmatizer at 0x7fbdae5af3c0>),
 ('ner', <spacy.pipeline.ner.EntityRecognizer at 0x7fbdae58aa50>)]

In [8]:
# load the data, sample the DF if interested
# mostly just to get the pipeline working, training/test should be done on the full data set
data_path = os.path.join(loc.data, "all_combined", "all_train.tsv")
dev_data_path = os.path.join(loc.data, "all_combined", "all_dev.tsv")
test_data_path = os.path.join(loc.data, "all_combined", "all_test.tsv")

if isinstance(nrows, int):
    df = pd.read_csv(data_path, sep="\t", nrows=nrows)
    dev_df = pd.read_csv(dev_data_path, sep="\t", nrows=nrows)
    test_df = pd.read_csv(test_data_path, sep="\t", nrows=nrows)
else:
    df = pd.read_csv(data_path, sep="\t")
    dev_df = pd.read_csv(dev_data_path, sep="\t")
    test_df = pd.read_csv(test_data_path, sep="\t")

In [9]:
label_encoder_dict = {i: idx for idx, i in enumerate(df["class_label"].unique())}

In [10]:
# apply the lemmatizer to al lthe tweets
# and save the outputs to csv files
if lemmatize_texts:
    # apply the spacy pipeline to the tweets
    docs = df["tweet_text"].apply(lambda x: nlp(x))
    labels = df["class_label"].apply(lambda a: label_encoder_dict[a])
    pd.DataFrame({"labels": labels, "tweet_text": docs}).to_csv(lemma_sents_file_location, index=False)

    # apply the spacy pipeline to the tweets
    dev_docs = dev_df["tweet_text"].apply(lambda x: nlp(x))
    dev_labels = dev_df["class_label"].apply(lambda a: label_encoder_dict[a])
    pd.DataFrame({"labels": dev_labels, "tweet_text": dev_docs}).to_csv(dev_lemma_sents_file_location, index=False)

    # apply the spacy pipeline to the tweets
    test_docs = test_df["tweet_text"].apply(lambda x: nlp(x))
    test_labels = test_df["class_label"].apply(lambda a: label_encoder_dict[a])
    pd.DataFrame({"labels": test_labels, "tweet_text": test_docs}).to_csv(test_lemma_sents_file_location, index=False)

else:
    print("'lemmatize_texts' is set to False - loading saved texts")

'lemmatize_texts' is set to False - loading saved texts


In [11]:
# read the lemmatized sentence
t = pd.read_csv(lemma_sents_file_location)
lemma_docs = list(zip(t.labels.to_list(), t.tweet_text.to_list()))

t = pd.read_csv(dev_lemma_sents_file_location)
dev_lemma_docs = list(zip(t.labels.to_list(), t.tweet_text.to_list()))

t = pd.read_csv(test_lemma_sents_file_location)
test_lemma_docs = list(zip(t.labels.to_list(), t.tweet_text.to_list()))

#### Prepare Data and Build Helper Functions

In [12]:
# pytorch helper functions
# used with `build_vocab_from_iterator` to build the word to token mapping

def yield_tokens(doc_strings):
    # discard the label because it does not need to be tokenized
    for _, text in doc_strings:
        # yield the tokenized text
        yield tokenizer(text)

In [13]:
# use the torchtext tokenizer
# a little redundent because we have spacy, but this allows for the entire pipeline to run
# in torch if we want

tokenizer = get_tokenizer('basic_english')

In [14]:
# build the torch encodings
# add a special character for out of bag words
vocab = build_vocab_from_iterator(yield_tokens(lemma_docs), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])

In [15]:
# convenience functions to transform data
text_transform = lambda x: vocab(tokenizer(x))
label_transform = lambda x: int(x)

In [16]:
def collate_batch_embedder(batch):
    label_list, text_list, offsets = [], [], [0]
    for (_label, _text) in batch:
        label_list.append(label_transform(_label))
        processed_text = torch.tensor(text_transform(_text), dtype=torch.int64)
        text_list.append(processed_text)
        offsets.append(processed_text.size(0))
    label_list = torch.tensor(label_list, dtype=torch.int64)
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    text_list = torch.cat(text_list)
    return label_list.to(device), text_list.to(device), offsets.to(device)

# function will run on the batches of data BEFORE they are passed to the model
def collate_batch_lstm(batch):
    label_list, text_list = [], []
    for (_label, _text) in batch:
        label_list.append(label_transform(_label))
        # t_text = text_transform(_text)[:max_words]
        processed_text = torch.tensor(text_transform(_text)[:max_words])
        text_list.append(processed_text)
    return torch.tensor(label_list), pad_sequence(text_list, padding_value=3.0).T

In [17]:
# number of output layers
# should be 10 because there are 10 classes
output_dim = len(label_encoder_dict)

### Embedding Bag Model

In [18]:
# model run variables
vocab_size = len(vocab)
emsize = 64
batch_log_freq = 25


# instantiate model
tweet_embedding_classifier = TweetClassificationEmbedder(vocab_size, emsize, output_dim).to(device)

In [19]:
# define the training loop
def train(dataloader):
    tweet_embedding_classifier.train()
    total_acc, total_count = 0, 0

    for idx, (label, text, offsets) in enumerate(dataloader):
        # zero out the gradient for a new run
        optimizer.zero_grad()
        # create a prediction
        predicted_label = tweet_embedding_classifier(text, offsets)
        # calculate loss and run backprop
        loss = criterion(predicted_label, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(tweet_embedding_classifier.parameters(), 0.1)
        # update weights
        optimizer.step()
        total_acc += (predicted_label.argmax(1) == label).sum().item()
        total_count += label.size(0)
        if idx % batch_log_freq == 0 and idx > 0:
            accuracy = total_acc / total_count
            print(f"Epoch: {epoch} Batch: {idx} of {total_batches_per_epoch}.  Accuracy: {accuracy:.2f}\n")
            total_acc, total_count = 0, 0


def evaluate(dataloader):
    tweet_embedding_classifier.eval()
    total_acc, total_count = 0, 0

    with torch.no_grad():
        for idx, (label, text, offsets) in enumerate(dataloader):
            predicted_label = tweet_embedding_classifier(text, offsets)
            loss = criterion(predicted_label, label)
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            total_count += label.size(0)
    return total_acc / total_count

In [20]:
# Hyperparameters
# set from pipeline.yaml
# EPOCHS = 15
LR = 5  # learning rate
BATCH_SIZE = 64  # batch size for training

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(tweet_embedding_classifier.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
total_accu = None

train_loader = DataLoader(lemma_docs, batch_size=BATCH_SIZE,
                              shuffle=True, collate_fn=collate_batch_embedder)
val_loader = DataLoader(dev_lemma_docs, batch_size=BATCH_SIZE,
                              shuffle=True, collate_fn=collate_batch_embedder)
test_loader = DataLoader(test_lemma_docs, batch_size=BATCH_SIZE,
                             shuffle=True, collate_fn=collate_batch_embedder)

total_batches_per_epoch = len(train_loader)

In [21]:
# model training loop
for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train(train_loader)
    accu_val = evaluate(val_loader)
    if total_accu is not None and total_accu > accu_val:
        scheduler.step()
    else:
        total_accu = accu_val
    print('#' * 25)

    print(f"\nVALIDATION SET: Epoch: {epoch} Accuracy: {accu_val:.2f}\n")

    print('#' * 25)

Epoch: 1 Batch: 25 of 837.  Accuracy: 0.27

Epoch: 1 Batch: 50 of 837.  Accuracy: 0.31

Epoch: 1 Batch: 75 of 837.  Accuracy: 0.36



Epoch: 1 Batch: 100 of 837.  Accuracy: 0.43

Epoch: 1 Batch: 125 of 837.  Accuracy: 0.47

Epoch: 1 Batch: 150 of 837.  Accuracy: 0.49



Epoch: 1 Batch: 175 of 837.  Accuracy: 0.50

Epoch: 1 Batch: 200 of 837.  Accuracy: 0.50

Epoch: 1 Batch: 225 of 837.  Accuracy: 0.54



Epoch: 1 Batch: 250 of 837.  Accuracy: 0.57

Epoch: 1 Batch: 275 of 837.  Accuracy: 0.55

Epoch: 1 Batch: 300 of 837.  Accuracy: 0.60



Epoch: 1 Batch: 325 of 837.  Accuracy: 0.59

Epoch: 1 Batch: 350 of 837.  Accuracy: 0.62

Epoch: 1 Batch: 375 of 837.  Accuracy: 0.62



Epoch: 1 Batch: 400 of 837.  Accuracy: 0.62

Epoch: 1 Batch: 425 of 837.  Accuracy: 0.63

Epoch: 1 Batch: 450 of 837.  Accuracy: 0.64



Epoch: 1 Batch: 475 of 837.  Accuracy: 0.65

Epoch: 1 Batch: 500 of 837.  Accuracy: 0.65

Epoch: 1 Batch: 525 of 837.  Accuracy: 0.61



Epoch: 1 Batch: 550 of 837.  Accuracy: 0.64

Epoch: 1 Batch: 575 of 837.  Accuracy: 0.66

Epoch: 1 Batch: 600 of 837.  Accuracy: 0.67



Epoch: 1 Batch: 625 of 837.  Accuracy: 0.67

Epoch: 1 Batch: 650 of 837.  Accuracy: 0.66

Epoch: 1 Batch: 675 of 837.  Accuracy: 0.66



Epoch: 1 Batch: 700 of 837.  Accuracy: 0.67

Epoch: 1 Batch: 725 of 837.  Accuracy: 0.68

Epoch: 1 Batch: 750 of 837.  Accuracy: 0.66



Epoch: 1 Batch: 775 of 837.  Accuracy: 0.65

Epoch: 1 Batch: 800 of 837.  Accuracy: 0.69

Epoch: 1 Batch: 825 of 837.  Accuracy: 0.68



#########################

VALIDATION SET: Epoch: 1 Accuracy: 0.68

#########################
Epoch: 2 Batch: 25 of 837.  Accuracy: 0.70

Epoch: 2 Batch: 50 of 837.  Accuracy: 0.68



Epoch: 2 Batch: 75 of 837.  Accuracy: 0.69

Epoch: 2 Batch: 100 of 837.  Accuracy: 0.70

Epoch: 2 Batch: 125 of 837.  Accuracy: 0.68



Epoch: 2 Batch: 150 of 837.  Accuracy: 0.69

Epoch: 2 Batch: 175 of 837.  Accuracy: 0.71

Epoch: 2 Batch: 200 of 837.  Accuracy: 0.71



Epoch: 2 Batch: 225 of 837.  Accuracy: 0.67

Epoch: 2 Batch: 250 of 837.  Accuracy: 0.68

Epoch: 2 Batch: 275 of 837.  Accuracy: 0.68



Epoch: 2 Batch: 300 of 837.  Accuracy: 0.71

Epoch: 2 Batch: 325 of 837.  Accuracy: 0.70

Epoch: 2 Batch: 350 of 837.  Accuracy: 0.71



Epoch: 2 Batch: 375 of 837.  Accuracy: 0.71

Epoch: 2 Batch: 400 of 837.  Accuracy: 0.71

Epoch: 2 Batch: 425 of 837.  Accuracy: 0.70



Epoch: 2 Batch: 450 of 837.  Accuracy: 0.71

Epoch: 2 Batch: 475 of 837.  Accuracy: 0.71

Epoch: 2 Batch: 500 of 837.  Accuracy: 0.71



Epoch: 2 Batch: 525 of 837.  Accuracy: 0.71

Epoch: 2 Batch: 550 of 837.  Accuracy: 0.72

Epoch: 2 Batch: 575 of 837.  Accuracy: 0.71



Epoch: 2 Batch: 600 of 837.  Accuracy: 0.71

Epoch: 2 Batch: 625 of 837.  Accuracy: 0.71

Epoch: 2 Batch: 650 of 837.  Accuracy: 0.70



Epoch: 2 Batch: 675 of 837.  Accuracy: 0.72

Epoch: 2 Batch: 700 of 837.  Accuracy: 0.71

Epoch: 2 Batch: 725 of 837.  Accuracy: 0.72



Epoch: 2 Batch: 750 of 837.  Accuracy: 0.71

Epoch: 2 Batch: 775 of 837.  Accuracy: 0.72

Epoch: 2 Batch: 800 of 837.  Accuracy: 0.72



Epoch: 2 Batch: 825 of 837.  Accuracy: 0.70



#########################

VALIDATION SET: Epoch: 2 Accuracy: 0.70

#########################
Epoch: 3 Batch: 25 of 837.  Accuracy: 0.73

Epoch: 3 Batch: 50 of 837.  Accuracy: 0.74



Epoch: 3 Batch: 75 of 837.  Accuracy: 0.74

Epoch: 3 Batch: 100 of 837.  Accuracy: 0.73

Epoch: 3 Batch: 125 of 837.  Accuracy: 0.73



Epoch: 3 Batch: 150 of 837.  Accuracy: 0.75

Epoch: 3 Batch: 175 of 837.  Accuracy: 0.72

Epoch: 3 Batch: 200 of 837.  Accuracy: 0.74



Epoch: 3 Batch: 225 of 837.  Accuracy: 0.72

Epoch: 3 Batch: 250 of 837.  Accuracy: 0.74

Epoch: 3 Batch: 275 of 837.  Accuracy: 0.73



Epoch: 3 Batch: 300 of 837.  Accuracy: 0.73

Epoch: 3 Batch: 325 of 837.  Accuracy: 0.74

Epoch: 3 Batch: 350 of 837.  Accuracy: 0.72



Epoch: 3 Batch: 375 of 837.  Accuracy: 0.74

Epoch: 3 Batch: 400 of 837.  Accuracy: 0.72

Epoch: 3 Batch: 425 of 837.  Accuracy: 0.73



Epoch: 3 Batch: 450 of 837.  Accuracy: 0.74

Epoch: 3 Batch: 475 of 837.  Accuracy: 0.74

Epoch: 3 Batch: 500 of 837.  Accuracy: 0.73



Epoch: 3 Batch: 525 of 837.  Accuracy: 0.73

Epoch: 3 Batch: 550 of 837.  Accuracy: 0.74

Epoch: 3 Batch: 575 of 837.  Accuracy: 0.74



Epoch: 3 Batch: 600 of 837.  Accuracy: 0.74

Epoch: 3 Batch: 625 of 837.  Accuracy: 0.74

Epoch: 3 Batch: 650 of 837.  Accuracy: 0.73



Epoch: 3 Batch: 675 of 837.  Accuracy: 0.73

Epoch: 3 Batch: 700 of 837.  Accuracy: 0.73

Epoch: 3 Batch: 725 of 837.  Accuracy: 0.74



Epoch: 3 Batch: 750 of 837.  Accuracy: 0.74

Epoch: 3 Batch: 775 of 837.  Accuracy: 0.74

Epoch: 3 Batch: 800 of 837.  Accuracy: 0.74



Epoch: 3 Batch: 825 of 837.  Accuracy: 0.72



#########################

VALIDATION SET: Epoch: 3 Accuracy: 0.72

#########################
Epoch: 4 Batch: 25 of 837.  Accuracy: 0.75

Epoch: 4 Batch: 50 of 837.  Accuracy: 0.76



Epoch: 4 Batch: 75 of 837.  Accuracy: 0.77

Epoch: 4 Batch: 100 of 837.  Accuracy: 0.74

Epoch: 4 Batch: 125 of 837.  Accuracy: 0.75



Epoch: 4 Batch: 150 of 837.  Accuracy: 0.75

Epoch: 4 Batch: 175 of 837.  Accuracy: 0.75

Epoch: 4 Batch: 200 of 837.  Accuracy: 0.77



Epoch: 4 Batch: 225 of 837.  Accuracy: 0.75

Epoch: 4 Batch: 250 of 837.  Accuracy: 0.73

Epoch: 4 Batch: 275 of 837.  Accuracy: 0.77



Epoch: 4 Batch: 300 of 837.  Accuracy: 0.75

Epoch: 4 Batch: 325 of 837.  Accuracy: 0.76

Epoch: 4 Batch: 350 of 837.  Accuracy: 0.74



Epoch: 4 Batch: 375 of 837.  Accuracy: 0.74

Epoch: 4 Batch: 400 of 837.  Accuracy: 0.75

Epoch: 4 Batch: 425 of 837.  Accuracy: 0.77



Epoch: 4 Batch: 450 of 837.  Accuracy: 0.75

Epoch: 4 Batch: 475 of 837.  Accuracy: 0.76

Epoch: 4 Batch: 500 of 837.  Accuracy: 0.73



Epoch: 4 Batch: 525 of 837.  Accuracy: 0.77

Epoch: 4 Batch: 550 of 837.  Accuracy: 0.74

Epoch: 4 Batch: 575 of 837.  Accuracy: 0.76



Epoch: 4 Batch: 600 of 837.  Accuracy: 0.75

Epoch: 4 Batch: 625 of 837.  Accuracy: 0.74

Epoch: 4 Batch: 650 of 837.  Accuracy: 0.75



Epoch: 4 Batch: 675 of 837.  Accuracy: 0.74

Epoch: 4 Batch: 700 of 837.  Accuracy: 0.76

Epoch: 4 Batch: 725 of 837.  Accuracy: 0.76



Epoch: 4 Batch: 750 of 837.  Accuracy: 0.74

Epoch: 4 Batch: 775 of 837.  Accuracy: 0.74

Epoch: 4 Batch: 800 of 837.  Accuracy: 0.77



Epoch: 4 Batch: 825 of 837.  Accuracy: 0.75



#########################

VALIDATION SET: Epoch: 4 Accuracy: 0.72

#########################
Epoch: 5 Batch: 25 of 837.  Accuracy: 0.77

Epoch: 5 Batch: 50 of 837.  Accuracy: 0.75



Epoch: 5 Batch: 75 of 837.  Accuracy: 0.78

Epoch: 5 Batch: 100 of 837.  Accuracy: 0.76

Epoch: 5 Batch: 125 of 837.  Accuracy: 0.77



Epoch: 5 Batch: 150 of 837.  Accuracy: 0.77

Epoch: 5 Batch: 175 of 837.  Accuracy: 0.76

Epoch: 5 Batch: 200 of 837.  Accuracy: 0.77



Epoch: 5 Batch: 225 of 837.  Accuracy: 0.77

Epoch: 5 Batch: 250 of 837.  Accuracy: 0.78

Epoch: 5 Batch: 275 of 837.  Accuracy: 0.77



Epoch: 5 Batch: 300 of 837.  Accuracy: 0.77

Epoch: 5 Batch: 325 of 837.  Accuracy: 0.76

Epoch: 5 Batch: 350 of 837.  Accuracy: 0.75



Epoch: 5 Batch: 375 of 837.  Accuracy: 0.76

Epoch: 5 Batch: 400 of 837.  Accuracy: 0.78

Epoch: 5 Batch: 425 of 837.  Accuracy: 0.78



Epoch: 5 Batch: 450 of 837.  Accuracy: 0.77

Epoch: 5 Batch: 475 of 837.  Accuracy: 0.75

Epoch: 5 Batch: 500 of 837.  Accuracy: 0.76



Epoch: 5 Batch: 525 of 837.  Accuracy: 0.77

Epoch: 5 Batch: 550 of 837.  Accuracy: 0.78

Epoch: 5 Batch: 575 of 837.  Accuracy: 0.76



Epoch: 5 Batch: 600 of 837.  Accuracy: 0.77

Epoch: 5 Batch: 625 of 837.  Accuracy: 0.76

Epoch: 5 Batch: 650 of 837.  Accuracy: 0.76



Epoch: 5 Batch: 675 of 837.  Accuracy: 0.76

Epoch: 5 Batch: 700 of 837.  Accuracy: 0.75

Epoch: 5 Batch: 725 of 837.  Accuracy: 0.77



Epoch: 5 Batch: 750 of 837.  Accuracy: 0.77

Epoch: 5 Batch: 775 of 837.  Accuracy: 0.74

Epoch: 5 Batch: 800 of 837.  Accuracy: 0.77



Epoch: 5 Batch: 825 of 837.  Accuracy: 0.77



#########################

VALIDATION SET: Epoch: 5 Accuracy: 0.72

#########################
Epoch: 6 Batch: 25 of 837.  Accuracy: 0.80

Epoch: 6 Batch: 50 of 837.  Accuracy: 0.79



Epoch: 6 Batch: 75 of 837.  Accuracy: 0.79

Epoch: 6 Batch: 100 of 837.  Accuracy: 0.79

Epoch: 6 Batch: 125 of 837.  Accuracy: 0.78



Epoch: 6 Batch: 150 of 837.  Accuracy: 0.79

Epoch: 6 Batch: 175 of 837.  Accuracy: 0.79

Epoch: 6 Batch: 200 of 837.  Accuracy: 0.77



Epoch: 6 Batch: 225 of 837.  Accuracy: 0.77

Epoch: 6 Batch: 250 of 837.  Accuracy: 0.79

Epoch: 6 Batch: 275 of 837.  Accuracy: 0.79



Epoch: 6 Batch: 300 of 837.  Accuracy: 0.77

Epoch: 6 Batch: 325 of 837.  Accuracy: 0.77

Epoch: 6 Batch: 350 of 837.  Accuracy: 0.77



Epoch: 6 Batch: 375 of 837.  Accuracy: 0.77

Epoch: 6 Batch: 400 of 837.  Accuracy: 0.78

Epoch: 6 Batch: 425 of 837.  Accuracy: 0.75



Epoch: 6 Batch: 450 of 837.  Accuracy: 0.79

Epoch: 6 Batch: 475 of 837.  Accuracy: 0.76

Epoch: 6 Batch: 500 of 837.  Accuracy: 0.76



Epoch: 6 Batch: 525 of 837.  Accuracy: 0.78

Epoch: 6 Batch: 550 of 837.  Accuracy: 0.76

Epoch: 6 Batch: 575 of 837.  Accuracy: 0.77



Epoch: 6 Batch: 600 of 837.  Accuracy: 0.78

Epoch: 6 Batch: 625 of 837.  Accuracy: 0.77

Epoch: 6 Batch: 650 of 837.  Accuracy: 0.76



Epoch: 6 Batch: 675 of 837.  Accuracy: 0.78

Epoch: 6 Batch: 700 of 837.  Accuracy: 0.77

Epoch: 6 Batch: 725 of 837.  Accuracy: 0.76



Epoch: 6 Batch: 750 of 837.  Accuracy: 0.78

Epoch: 6 Batch: 775 of 837.  Accuracy: 0.78

Epoch: 6 Batch: 800 of 837.  Accuracy: 0.79



Epoch: 6 Batch: 825 of 837.  Accuracy: 0.79



#########################

VALIDATION SET: Epoch: 6 Accuracy: 0.72

#########################
Epoch: 7 Batch: 25 of 837.  Accuracy: 0.78

Epoch: 7 Batch: 50 of 837.  Accuracy: 0.81



Epoch: 7 Batch: 75 of 837.  Accuracy: 0.81

Epoch: 7 Batch: 100 of 837.  Accuracy: 0.79

Epoch: 7 Batch: 125 of 837.  Accuracy: 0.79



Epoch: 7 Batch: 150 of 837.  Accuracy: 0.78

Epoch: 7 Batch: 175 of 837.  Accuracy: 0.80

Epoch: 7 Batch: 200 of 837.  Accuracy: 0.80



Epoch: 7 Batch: 225 of 837.  Accuracy: 0.79

Epoch: 7 Batch: 250 of 837.  Accuracy: 0.79

Epoch: 7 Batch: 275 of 837.  Accuracy: 0.78



Epoch: 7 Batch: 300 of 837.  Accuracy: 0.77

Epoch: 7 Batch: 325 of 837.  Accuracy: 0.79

Epoch: 7 Batch: 350 of 837.  Accuracy: 0.80



Epoch: 7 Batch: 375 of 837.  Accuracy: 0.81

Epoch: 7 Batch: 400 of 837.  Accuracy: 0.80

Epoch: 7 Batch: 425 of 837.  Accuracy: 0.79



Epoch: 7 Batch: 450 of 837.  Accuracy: 0.79

Epoch: 7 Batch: 475 of 837.  Accuracy: 0.79

Epoch: 7 Batch: 500 of 837.  Accuracy: 0.80



Epoch: 7 Batch: 525 of 837.  Accuracy: 0.78

Epoch: 7 Batch: 550 of 837.  Accuracy: 0.79

Epoch: 7 Batch: 575 of 837.  Accuracy: 0.77



Epoch: 7 Batch: 600 of 837.  Accuracy: 0.78

Epoch: 7 Batch: 625 of 837.  Accuracy: 0.79

Epoch: 7 Batch: 650 of 837.  Accuracy: 0.77



Epoch: 7 Batch: 675 of 837.  Accuracy: 0.80

Epoch: 7 Batch: 700 of 837.  Accuracy: 0.77

Epoch: 7 Batch: 725 of 837.  Accuracy: 0.79



Epoch: 7 Batch: 750 of 837.  Accuracy: 0.79

Epoch: 7 Batch: 775 of 837.  Accuracy: 0.78

Epoch: 7 Batch: 800 of 837.  Accuracy: 0.80



Epoch: 7 Batch: 825 of 837.  Accuracy: 0.80



#########################

VALIDATION SET: Epoch: 7 Accuracy: 0.72

#########################
Epoch: 8 Batch: 25 of 837.  Accuracy: 0.81

Epoch: 8 Batch: 50 of 837.  Accuracy: 0.81



Epoch: 8 Batch: 75 of 837.  Accuracy: 0.81

Epoch: 8 Batch: 100 of 837.  Accuracy: 0.81

Epoch: 8 Batch: 125 of 837.  Accuracy: 0.81



Epoch: 8 Batch: 150 of 837.  Accuracy: 0.83

Epoch: 8 Batch: 175 of 837.  Accuracy: 0.81

Epoch: 8 Batch: 200 of 837.  Accuracy: 0.80



Epoch: 8 Batch: 225 of 837.  Accuracy: 0.82

Epoch: 8 Batch: 250 of 837.  Accuracy: 0.81

Epoch: 8 Batch: 275 of 837.  Accuracy: 0.81



Epoch: 8 Batch: 300 of 837.  Accuracy: 0.79

Epoch: 8 Batch: 325 of 837.  Accuracy: 0.82

Epoch: 8 Batch: 350 of 837.  Accuracy: 0.83



Epoch: 8 Batch: 375 of 837.  Accuracy: 0.82

Epoch: 8 Batch: 400 of 837.  Accuracy: 0.81

Epoch: 8 Batch: 425 of 837.  Accuracy: 0.80



Epoch: 8 Batch: 450 of 837.  Accuracy: 0.80

Epoch: 8 Batch: 475 of 837.  Accuracy: 0.82

Epoch: 8 Batch: 500 of 837.  Accuracy: 0.82



Epoch: 8 Batch: 525 of 837.  Accuracy: 0.81

Epoch: 8 Batch: 550 of 837.  Accuracy: 0.81

Epoch: 8 Batch: 575 of 837.  Accuracy: 0.82



Epoch: 8 Batch: 600 of 837.  Accuracy: 0.82

Epoch: 8 Batch: 625 of 837.  Accuracy: 0.81

Epoch: 8 Batch: 650 of 837.  Accuracy: 0.80



Epoch: 8 Batch: 675 of 837.  Accuracy: 0.82

Epoch: 8 Batch: 700 of 837.  Accuracy: 0.81

Epoch: 8 Batch: 725 of 837.  Accuracy: 0.80



Epoch: 8 Batch: 750 of 837.  Accuracy: 0.82

Epoch: 8 Batch: 775 of 837.  Accuracy: 0.80



Epoch: 8 Batch: 800 of 837.  Accuracy: 0.80

Epoch: 8 Batch: 825 of 837.  Accuracy: 0.82



#########################

VALIDATION SET: Epoch: 8 Accuracy: 0.73

#########################
Epoch: 9 Batch: 25 of 837.  Accuracy: 0.81

Epoch: 9 Batch: 50 of 837.  Accuracy: 0.80



Epoch: 9 Batch: 75 of 837.  Accuracy: 0.83

Epoch: 9 Batch: 100 of 837.  Accuracy: 0.82

Epoch: 9 Batch: 125 of 837.  Accuracy: 0.82



Epoch: 9 Batch: 150 of 837.  Accuracy: 0.82

Epoch: 9 Batch: 175 of 837.  Accuracy: 0.80

Epoch: 9 Batch: 200 of 837.  Accuracy: 0.81



Epoch: 9 Batch: 225 of 837.  Accuracy: 0.82

Epoch: 9 Batch: 250 of 837.  Accuracy: 0.81

Epoch: 9 Batch: 275 of 837.  Accuracy: 0.82



Epoch: 9 Batch: 300 of 837.  Accuracy: 0.82

Epoch: 9 Batch: 325 of 837.  Accuracy: 0.83

Epoch: 9 Batch: 350 of 837.  Accuracy: 0.80



Epoch: 9 Batch: 375 of 837.  Accuracy: 0.80

Epoch: 9 Batch: 400 of 837.  Accuracy: 0.81

Epoch: 9 Batch: 425 of 837.  Accuracy: 0.81



Epoch: 9 Batch: 450 of 837.  Accuracy: 0.83

Epoch: 9 Batch: 475 of 837.  Accuracy: 0.83

Epoch: 9 Batch: 500 of 837.  Accuracy: 0.81



Epoch: 9 Batch: 525 of 837.  Accuracy: 0.82

Epoch: 9 Batch: 550 of 837.  Accuracy: 0.81

Epoch: 9 Batch: 575 of 837.  Accuracy: 0.80



Epoch: 9 Batch: 600 of 837.  Accuracy: 0.82

Epoch: 9 Batch: 625 of 837.  Accuracy: 0.82

Epoch: 9 Batch: 650 of 837.  Accuracy: 0.83



Epoch: 9 Batch: 675 of 837.  Accuracy: 0.80

Epoch: 9 Batch: 700 of 837.  Accuracy: 0.79

Epoch: 9 Batch: 725 of 837.  Accuracy: 0.81



Epoch: 9 Batch: 750 of 837.  Accuracy: 0.82

Epoch: 9 Batch: 775 of 837.  Accuracy: 0.82

Epoch: 9 Batch: 800 of 837.  Accuracy: 0.82



Epoch: 9 Batch: 825 of 837.  Accuracy: 0.82



#########################

VALIDATION SET: Epoch: 9 Accuracy: 0.73

#########################
Epoch: 10 Batch: 25 of 837.  Accuracy: 0.81

Epoch: 10 Batch: 50 of 837.  Accuracy: 0.81



Epoch: 10 Batch: 75 of 837.  Accuracy: 0.81

Epoch: 10 Batch: 100 of 837.  Accuracy: 0.81

Epoch: 10 Batch: 125 of 837.  Accuracy: 0.81



Epoch: 10 Batch: 150 of 837.  Accuracy: 0.81

Epoch: 10 Batch: 175 of 837.  Accuracy: 0.82

Epoch: 10 Batch: 200 of 837.  Accuracy: 0.81



Epoch: 10 Batch: 225 of 837.  Accuracy: 0.81

Epoch: 10 Batch: 250 of 837.  Accuracy: 0.83

Epoch: 10 Batch: 275 of 837.  Accuracy: 0.82



Epoch: 10 Batch: 300 of 837.  Accuracy: 0.81

Epoch: 10 Batch: 325 of 837.  Accuracy: 0.81

Epoch: 10 Batch: 350 of 837.  Accuracy: 0.81



Epoch: 10 Batch: 375 of 837.  Accuracy: 0.82

Epoch: 10 Batch: 400 of 837.  Accuracy: 0.81

Epoch: 10 Batch: 425 of 837.  Accuracy: 0.80



Epoch: 10 Batch: 450 of 837.  Accuracy: 0.83

Epoch: 10 Batch: 475 of 837.  Accuracy: 0.80

Epoch: 10 Batch: 500 of 837.  Accuracy: 0.82



Epoch: 10 Batch: 525 of 837.  Accuracy: 0.81

Epoch: 10 Batch: 550 of 837.  Accuracy: 0.82

Epoch: 10 Batch: 575 of 837.  Accuracy: 0.81



Epoch: 10 Batch: 600 of 837.  Accuracy: 0.82

Epoch: 10 Batch: 625 of 837.  Accuracy: 0.81

Epoch: 10 Batch: 650 of 837.  Accuracy: 0.83



Epoch: 10 Batch: 675 of 837.  Accuracy: 0.83

Epoch: 10 Batch: 700 of 837.  Accuracy: 0.82

Epoch: 10 Batch: 725 of 837.  Accuracy: 0.83



Epoch: 10 Batch: 750 of 837.  Accuracy: 0.83

Epoch: 10 Batch: 775 of 837.  Accuracy: 0.81

Epoch: 10 Batch: 800 of 837.  Accuracy: 0.82



Epoch: 10 Batch: 825 of 837.  Accuracy: 0.83



#########################

VALIDATION SET: Epoch: 10 Accuracy: 0.73

#########################


In [22]:
# create predictions on the test set
full_preds = []
with torch.no_grad():
    for idx, (label, text, offsets) in enumerate(test_loader):
        predicted_label = tweet_embedding_classifier(text, offsets)
        preds = predicted_label.argmax(1)
        full_preds.extend(preds)
y_pred = [a.item() for a in full_preds]

In [23]:
# metrics on test set
y_true = [a[0] for a in test_lemma_docs]
accuracy_score(y_true, y_pred)

0.15977307210238142

In [24]:
# more metrics on test set
f1_score(y_true, y_pred, average='macro')

0.10009860585237265

### LSTM

In [25]:
max_words = 50
train_batch_size = 512
val_batch_size = 1024
test_batch_size = 1024

# instantiate the data loaders
train_loader = DataLoader(lemma_docs, batch_size=train_batch_size, collate_fn=collate_batch_lstm, shuffle=True)
val_loader = DataLoader(dev_lemma_docs, batch_size=val_batch_size, collate_fn=collate_batch_lstm)
test_loader = DataLoader(test_lemma_docs, batch_size=test_batch_size, collate_fn=collate_batch_lstm)

In [26]:
# neural net architecture settings
embedded_len = 64
hidden_dim = 128
n_layers=1

In [27]:
# instantiate the the LSTM
tweet_lstm_classifier = TweetClassificationLSTM(vocab_size=vocab_size, 
                                                embedded_len=embedded_len, 
                                                hidden_dim=hidden_dim, 
                                                n_layers=n_layers, 
                                                output_len=output_dim)

# view the model architecture
tweet_lstm_classifier

TweetClassificationLSTM(
  (embedding_layer): Embedding(77853, 64)
  (lstm): LSTM(64, 128, batch_first=True)
  (fc_1): Linear(in_features=128, out_features=128, bias=True)
  (fc_2): Linear(in_features=128, out_features=10, bias=True)
)

In [28]:
# check that all the model dimensions fit correctly
# and the output is of the desired shape
output_tensor = tweet_lstm_classifier(torch.randint(0, len(vocab), (1024, max_words)), n_layers, hidden_dim)

assert output_tensor.shape[0] == test_batch_size
assert output_tensor.shape[1] == output_dim

In [29]:
# settings for model training run
# set from pipeline.yaml
# epochs = 25
learning_rate = 1e-3

criterion = nn.CrossEntropyLoss()
# tweet_classifier = TweetClassificationLSTM()
optimizer = Adam(tweet_lstm_classifier.parameters(), lr=learning_rate)
# optimizer = torch.optim.SGD(tweet_classifier.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)

total_accu = None
full_acc_score = 0
acc_score_counter = []
total_batches_per_epoch = len(train_loader)
batch_log_freq = 25

In [30]:
# training loop - RUN AND EVALUATE THE MODEL

for epoch in range(1, epochs + 1):
    tweet_lstm_classifier.train()
    total_acc, total_count = 0, 0
    for idx, (label, text) in enumerate(train_loader):
        #         zero out the gradient for a new run
        optimizer.zero_grad()
        # create a prediction
        predicted_label = tweet_lstm_classifier(text, n_layers, hidden_dim)
        # calculate loss and run backprop
        # print(label)
        loss = criterion(predicted_label, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(tweet_lstm_classifier.parameters(), 0.1)
        # update weights
        optimizer.step()
        total_acc += (predicted_label.argmax(1) == label).sum().item()
        total_count += label.size(0)
        if idx % batch_log_freq == 0 and idx > 0:
            accuracy = total_acc / total_count
            print(f"Epoch: {epoch} Batch: {idx} of {total_batches_per_epoch}.  Accuracy: {accuracy:.2f}\n")
            total_acc, total_count = 0, 0
            
    tweet_lstm_classifier.eval()
    total_acc, total_count = 0, 0

    with torch.no_grad():
        for idx, (label, text) in enumerate(val_loader):
            predicted_label = tweet_lstm_classifier(text, n_layers, hidden_dim)
            loss = criterion(predicted_label, label)
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            total_count += label.size(0)
    acc_score = total_acc / total_count
    
    if total_accu is not None and total_accu > acc_score:
        scheduler.step()
    else:
        total_accu = acc_score
    
    print(f"\nVALIDATION SET: Epoch: {epoch} Accuracy: {acc_score:.2f}\n")
    acc_improvement = acc_score - full_acc_score
    full_acc_score = acc_score
    if acc_improvement < 0.02:
        acc_score_counter.append(1)
    else:
        acc_score_counter.append(0)
    if len(acc_score_counter) > 5 and sum(acc_score_counter[-3:]) > 2:
        print("No validation accuracy improvement in last 3 epochs, terminating training loop")

Epoch: 1 Batch: 25 of 105.  Accuracy: 0.26



Epoch: 1 Batch: 50 of 105.  Accuracy: 0.28



Epoch: 1 Batch: 75 of 105.  Accuracy: 0.28



Epoch: 1 Batch: 100 of 105.  Accuracy: 0.28




VALIDATION SET: Epoch: 1 Accuracy: 0.28



Epoch: 2 Batch: 25 of 105.  Accuracy: 0.28



Epoch: 2 Batch: 50 of 105.  Accuracy: 0.30



Epoch: 2 Batch: 75 of 105.  Accuracy: 0.33



Epoch: 2 Batch: 100 of 105.  Accuracy: 0.35




VALIDATION SET: Epoch: 2 Accuracy: 0.35



Epoch: 3 Batch: 25 of 105.  Accuracy: 0.36



Epoch: 3 Batch: 50 of 105.  Accuracy: 0.39



Epoch: 3 Batch: 75 of 105.  Accuracy: 0.42



Epoch: 3 Batch: 100 of 105.  Accuracy: 0.43




VALIDATION SET: Epoch: 3 Accuracy: 0.44



Epoch: 4 Batch: 25 of 105.  Accuracy: 0.46



Epoch: 4 Batch: 50 of 105.  Accuracy: 0.48



Epoch: 4 Batch: 75 of 105.  Accuracy: 0.51



Epoch: 4 Batch: 100 of 105.  Accuracy: 0.54




VALIDATION SET: Epoch: 4 Accuracy: 0.54



Epoch: 5 Batch: 25 of 105.  Accuracy: 0.57



Epoch: 5 Batch: 50 of 105.  Accuracy: 0.59



Epoch: 5 Batch: 75 of 105.  Accuracy: 0.57



Epoch: 5 Batch: 100 of 105.  Accuracy: 0.60




VALIDATION SET: Epoch: 5 Accuracy: 0.57



Epoch: 6 Batch: 25 of 105.  Accuracy: 0.62



Epoch: 6 Batch: 50 of 105.  Accuracy: 0.63



Epoch: 6 Batch: 75 of 105.  Accuracy: 0.64



Epoch: 6 Batch: 100 of 105.  Accuracy: 0.64




VALIDATION SET: Epoch: 6 Accuracy: 0.61



Epoch: 7 Batch: 25 of 105.  Accuracy: 0.66



Epoch: 7 Batch: 50 of 105.  Accuracy: 0.68



Epoch: 7 Batch: 75 of 105.  Accuracy: 0.68



Epoch: 7 Batch: 100 of 105.  Accuracy: 0.68




VALIDATION SET: Epoch: 7 Accuracy: 0.64



Epoch: 8 Batch: 25 of 105.  Accuracy: 0.70



Epoch: 8 Batch: 50 of 105.  Accuracy: 0.71



Epoch: 8 Batch: 75 of 105.  Accuracy: 0.71



Epoch: 8 Batch: 100 of 105.  Accuracy: 0.71




VALIDATION SET: Epoch: 8 Accuracy: 0.66



Epoch: 9 Batch: 25 of 105.  Accuracy: 0.74



Epoch: 9 Batch: 50 of 105.  Accuracy: 0.74



Epoch: 9 Batch: 75 of 105.  Accuracy: 0.73



Epoch: 9 Batch: 100 of 105.  Accuracy: 0.74




VALIDATION SET: Epoch: 9 Accuracy: 0.67



Epoch: 10 Batch: 25 of 105.  Accuracy: 0.76



Epoch: 10 Batch: 50 of 105.  Accuracy: 0.76



Epoch: 10 Batch: 75 of 105.  Accuracy: 0.76



Epoch: 10 Batch: 100 of 105.  Accuracy: 0.76




VALIDATION SET: Epoch: 10 Accuracy: 0.68

No validation accuracy improvement in last 3 epochs, terminating training loop


In [31]:
# create predictions on the test set
full_preds = []
with torch.no_grad():
    for idx, (label, text) in enumerate(test_loader):
        predicted_label = tweet_lstm_classifier(text, n_layers, hidden_dim)
        preds = predicted_label.argmax(1)
        full_preds.extend(preds)
y_pred = [a.item() for a in full_preds]

In [32]:
# metrics on test set
y_true = [a[0] for a in test_lemma_docs]
accuracy_score(y_true, y_pred)

0.6941750775117093

In [33]:
# more metrics on test set
f1_score(y_true, y_pred, average='macro')

0.6476137532005054

Notes and resources:
* https://pytorch.org/tutorials/beginner/text_sentiment_ngrams_tutorial.html
* https://coderzcolumn.com/tutorials/artificial-intelligence/pytorch-lstm-for-text-classification-tasks
* https://github.com/pytorch/text/blob/master/examples/legacy_tutorial/migration_tutorial.ipynb