In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchtext import data

# assorted QOL things
import random
from tqdm import tqdm
import time

# my classes
from langhelper import BERTHelper
from classifier import *
import modelfitting

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed Python's and PyTorch's RNGs so that PyTorch-side randomness (e.g. weight
# initialisation, dropout) is repeatable across Restart & Run All.
# NOTE: cuDNN kernels (e.g. for GRUs) can still be non-deterministic on GPU.
SEED = 1234
random.seed(SEED)
torch.manual_seed(SEED)

lang_helper = BERTHelper('bert-base-uncased')

In [None]:
# Token limit exposed by the BERT helper (presumably the model's maximum
# sequence length); shown here for reference.
max_input_length = lang_helper.max_tokens
print(f"{max_input_length}")

In [None]:
# Field definitions consumed by TabularDataset in the next cell.
# Text is tokenized by the helper and converted to BERT token ids, so no torchtext
# vocabulary is built (use_vocab=False); the special-token ids come straight from
# the same BERT tokenizer.
bert_tok = lang_helper.bert_tokenizer

TEXT = data.Field(
    batch_first=True,
    use_vocab=False,
    tokenize=lang_helper.tokenize_and_cut,
    preprocessing=bert_tok.convert_tokens_to_ids,
    init_token=bert_tok.cls_token_id,
    eos_token=bert_tok.sep_token_id,
    pad_token=bert_tok.pad_token_id,
    unk_token=bert_tok.unk_token_id,
)

# Labels are used as-is (no vocab) and materialised as float tensors.
LABEL = data.LabelField(dtype=torch.float, use_vocab=False)

In [None]:
# Load the headlines JSON, applying TEXT/LABEL to the 'headline'/'is_sarcastic'
# fields (respectively), then carve out train, validation and test splits.
# Torchtext is pretty good!
#
# Dataset.split expects `random_state` to be a state tuple from random.getstate().
# Seed explicitly and pass that state, rather than relying on the side effect of
# `random_state=random.seed(...)`, which always evaluates to None.
SPLIT_SEED = 1234

headlines_all = data.TabularDataset(
    path='./data/Sarcasm_Headlines_Dataset_v2.json', format='json',
    fields={'headline': ('text', TEXT),
            'is_sarcastic': ('label', LABEL)})

random.seed(SPLIT_SEED)
headlines_train_full, headlines_test = headlines_all.split(
    split_ratio=0.85, random_state=random.getstate())

# Split the 85% training portion again into train + validation (using split's
# default split_ratio).
random.seed(SPLIT_SEED)
headlines_train, headlines_valid = headlines_train_full.split(
    random_state=random.getstate())
print(vars(headlines_train.examples[5]))

In [None]:
# Wrap each split in a BucketIterator, the last preprocessing step for the data.
# BucketIterator groups examples of similar length (as measured by `sort_key`)
# into the same batch to reduce padding.
def _headline_length(example):
    """Sort key for BucketIterator: length of the example's (id-converted) text field."""
    return len(example.text)


train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
    (headlines_train, headlines_valid, headlines_test),
    batch_size=16,
    sort_key=_headline_length,
    sort_within_batch=False,
    device=device,
)

In [None]:

                 # Reference only (kept as a comment so this cell runs): positional
                 # parameters that appear to belong to the model constructor used in
                 # the next cell, in order. TODO confirm against classifier.py.
                 #   bert_helper, hidden_dim, l1_dim, l2_dim, output_dim, n_layers, dropout

In [None]:
# Instantiate our model. Every architecture hyperparameter is a named constant so
# it can be tuned in one place.
HIDDEN_DIM = 256
L1_DIM = 100    # 3rd positional arg; presumably the first dense layer's width — TODO confirm in classifier.py
L2_DIM = 50     # 4th positional arg; presumably the second dense layer's width — TODO confirm
OUTPUT_DIM = 1  # single output, paired with BCEWithLogitsLoss below
N_LAYERS = 3
DROPOUT = 0.40

model = BERTGRUSentimentPerc(lang_helper,
                             HIDDEN_DIM,
                             L1_DIM,
                             L2_DIM,
                             OUTPUT_DIM,
                             N_LAYERS,
                             DROPOUT).to(device)

In [None]:
# Loss: binary cross-entropy computed directly on raw logits (BCEWithLogitsLoss
# applies the sigmoid internally).
criterion = nn.BCEWithLogitsLoss().to(device)

# Optimizer: Adam with PyTorch's default hyperparameters over model.parameters().
optimizer = optim.Adam(params=model.parameters())

In [None]:
# Train for N epochs using the iterators, optimizer and loss defined above.
# NOTE(review): the checkpoint loaded further down is 'GRU 040 drop 3 layer.pt',
# which does not match this model_name — confirm how modelfitting.fit names its output.
N = 40

modelfitting.fit(
    n_epochs=N,
    model=model,
    train_iter=train_iterator,
    valid_iter=valid_iterator,
    optimizer=optimizer,
    criterion=criterion,
    model_name='GRU perceptron 040 drop 3 layer',
)

In [None]:
# Spot-check the model on two headlines published on 5/10/2020:
#   1. sarcastic, from The Onion:
#      https://www.theonion.com/experts-warn-unemployment-rate-could-soon-rise-to-ameri-1843348378
#   2. genuine, from NPR:
#      https://www.npr.org/2020/05/10/852943513/the-people-flying-during-the-pandemic-and-how-airlines-are-trying-to-protect-the
# NOTE(review): this runs before the checkpoint is reloaded in the next cell, so it
# uses whatever weights `model` holds after training.
sample_headlines = (
    'Experts Warn Unemployment Rate Could Soon Rise To America Is The Greatest Country In The World',
    'The People Flying During The Pandemic And How Airlines Are Trying To Protect Them',
)
for headline in sample_headlines:
    print(single_eval(model, headline, lang_helper, device))

In [None]:
# Reload saved weights, mapping tensors onto `device` so that a checkpoint written
# on a GPU can be loaded on a CPU-only machine (and vice versa).
# NOTE(review): training used model_name='GRU perceptron 040 drop 3 layer', but this
# loads 'GRU 040 drop 3 layer.pt'. If modelfitting.fit names checkpoints after
# model_name, this is a different model's file — confirm the intended filename.
model.load_state_dict(torch.load('GRU 040 drop 3 layer.pt', map_location=device))

In [None]:
# How does it look on our held-out test dataset?
test_loss, test_acc, test_precision, test_recall, test_f1 = modelfitting.evaluate(model, test_iterator, criterion)

# Report every classification metric as a percentage so the numbers are directly
# comparable with each other.
# NOTE: assumes evaluate returns acc/precision/recall/F1 as fractions in [0, 1].
print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')
print(f'Test Precision: {test_precision*100:.2f}% | Test Recall: {test_recall*100:.2f}% | Test F1: {test_f1*100:.2f}%')

About 90% accuracy on the test set — not bad, considering that most of the solutions on the Kaggle front page report validation-set accuracies in the mid-80s!