In [2]:
import os
import json
import gdown
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset
from torchtext import data
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

from tqdm.notebook import tqdm

import spacy

# Data Preprocessing

## Exploring the data

In [123]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [13]:
train_dict = json.load(open('./train.json', 'r'))
print('Number of train data: ', len(train_dict))
print('Example Data: \n', train_dict[0])

valid_dict = json.load(open('./valid.json', 'r'))
print('\nNumber of valid data: ', len(valid_dict))
print('Example Data: \n', valid_dict[0])

test_dict = json.load(open('./test.json', 'r'))
print('\nNumber of test data: ', len(test_dict))
print('Example Data: \n', test_dict[0])

Number of train data:  5156
Example Data: 
 {'user_id': 0, 'sentence': 'one of my top priorities here at is making sure that the world prioritizes scaling clean technology innovation here is the message i delivered to world leaders earlier today'}

Number of valid data:  800
Example Data: 
 {'user_id': 0, 'sentence': 'i got arrested beaten left bloody and unconscious but i havent given up and you can not give up an inspiring read from civil rights legend'}

Number of test data:  800
Example Data: 
 {'sentence': 'i got arrested beaten left bloody and unconscious but i havent given up and you can not give up an inspiring read from civil rights legend'}


In [116]:
print(train_dict[0])

{'user_id': 0, 'sentence': 'one of my top priorities here at is making sure that the world prioritizes scaling clean technology innovation here is the message i delivered to world leaders earlier today'}


In [14]:
ids = []
count_sen = {}
count = 0
for data in train_dict:
    if data['user_id'] not in ids:
        count_sen.update({data['user_id']: 0})
        ids.append(data['user_id'])
    else:
        count_sen[data['user_id']] += 1
print(ids)
print(count_sen)

[0, 1, 2, 3, 4, 5, 6, 7]
{0: 660, 1: 645, 2: 647, 3: 632, 4: 641, 5: 633, 6: 658, 7: 632}


In [15]:
meta = json.load(open('./meta.json', 'r'))
print(meta.keys())
print(len(meta['tokens']))
print(meta['num_user'])

dict_keys(['tokens', 'num_user'])
13369
8


# Pre-processing with spacy

In [16]:
# tokenizer
spacy_en = spacy.load('en_core_web_sm')

In [242]:
def tokenize(en_text):
    return [tok.text for tok in spacy_en.tokenizer(en_text)]

def tokenize_with_filtering(en_text):
    # eliminate stop_words
    stop_words = spacy.lang.en.stop_words.STOP_WORDS
    # stop_words = ['the', 'an', 'a']
    filtered = []
    tokenized = [tok.text for tok in spacy_en.tokenizer(en_text)]
    for word in tokenized:
        if word not in stop_words:
            filtered.append(word)
    return filtered

In [243]:
# example
ex_sen = train_dict[0]['sentence']
print(ex_sen)
print('normal tokenizing: ', tokenize(ex_sen))
print('filtered tokens: ', tokenize_with_filtering(ex_sen))
print(len(tokenize(train_dict[0]['sentence'])))
print(len(train_dict))
print('max sequence length: ', max(len(tokenize(train_dict[i]['sentence'])) for i in range(len(train_dict))))

one of my top priorities here at is making sure that the world prioritizes scaling clean technology innovation here is the message i delivered to world leaders earlier today
normal tokenizing:  ['one', 'of', 'my', 'top', 'priorities', 'here', 'at', 'is', 'making', 'sure', 'that', 'the', 'world', 'prioritizes', 'scaling', 'clean', 'technology', 'innovation', 'here', 'is', 'the', 'message', 'i', 'delivered', 'to', 'world', 'leaders', 'earlier', 'today']
filtered tokens:  ['priorities', 'making', 'sure', 'world', 'prioritizes', 'scaling', 'clean', 'technology', 'innovation', 'message', 'delivered', 'world', 'leaders', 'earlier', 'today']
29
5156
max sequence length:  62


In [244]:
train_list = []
valid_list = []
test_list = []
for data in train_dict:
    filtered_sentence = ''
    for word in tokenize_with_filtering(data['sentence']):
        filtered_sentence += word + ' '
    filtered_sentence = filtered_sentence[:-1]
    train_list.append([data['user_id'], filtered_sentence])

for data in valid_dict:
    filtered_sentence = ''
    for word in tokenize_with_filtering(data['sentence']):
        filtered_sentence += word + ' '
    filtered_sentence = filtered_sentence[:-1]
    valid_list.append([data['user_id'], filtered_sentence])

for data in test_dict:
    filtered_sentence = ''
    for word in tokenize_with_filtering(data['sentence']):
        filtered_sentence += word + ' '
    filtered_sentence = filtered_sentence[:-1]
    test_list.append([filtered_sentence])

print(train_list[0])
print(valid_list[0])
print(test_list[0])

[0, 'priorities making sure world prioritizes scaling clean technology innovation message delivered world leaders earlier today']
[0, 'got arrested beaten left bloody unconscious nt given inspiring read civil rights legend']
['got arrested beaten left bloody unconscious nt given inspiring read civil rights legend']


# Tokenize with torchtext

In [246]:
tokenizer = get_tokenizer(tokenizer='spacy', language='en_core_web_sm')

# example
print(tokenizer(train_dict[0]['sentence']))
print(len(tokenizer(train_dict[0]['sentence'])))

['one', 'of', 'my', 'top', 'priorities', 'here', 'at', 'is', 'making', 'sure', 'that', 'the', 'world', 'prioritizes', 'scaling', 'clean', 'technology', 'innovation', 'here', 'is', 'the', 'message', 'i', 'delivered', 'to', 'world', 'leaders', 'earlier', 'today']
29


In [247]:
def yield_tokens(data_iter):
    for text in data_iter:
        yield tokenizer(text)

In [248]:
vocab = build_vocab_from_iterator(yield_tokens(meta["tokens"]), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])

# example
print(vocab(['here', 'is', 'an', 'example']))

[5480, 25, 516, 4057]


In [249]:
text_pipeline = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: int(x)

print(text_pipeline(train_dict[0]['sentence']))
print(label_pipeline(train_dict[0]['user_id']))

[8328, 8243, 7817, 12049, 9223, 5480, 814, 25, 7163, 11509, 31, 11848, 13143, 9226, 10363, 2108, 11752, 6014, 5480, 25, 11848, 7458, 5, 3003, 12005, 13143, 6744, 3584, 12008]
0


In [61]:
word_to_idx = {word[0] : i + 2 for i, word in enumerate(train_dict)} # 0, 1 index will be used for other purposes
word_to_idx['pad'] = 1 # index 1
word_to_idx['unk'] = 0 # index 0

encoded = []
for i in range(len(train_data)):
    token = tokenize_with_filtering(train_data[i]['sentence'])
    temp = []
    for word in token:
        try:
            temp.append(word_to_idx[word])
        except KeyError:
            temp.append(word_to_idx['unk'])
    encoded.append(temp)
print(encoded)

TypeError: enumerate() missing required argument 'iterable' (pos 1)

In [62]:
import torch
from torchtext import data

sentences = data.Field(tokenize = 'spacy', tokenizer_language = 'en')
label = data.LabelField(dtype = torch.float)

ModuleNotFoundError: No module named 'torchtext'

# Define Models

## RNN

In [276]:
def collate_batch(batch):
    label_list, text_list, offsets = [], [], [0]
    for (_label, _text) in batch:
        label_list.append(label_pipeline(_label))
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
        offsets.append(processed_text.size(0))
    label_list = torch.tensor(label_list, dtype=torch.int64)
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    text_list = torch.cat(text_list)
    return label_list.to(device), text_list.to(device), offsets.to(device)

In [277]:
class TextClassificationModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, num_class):
        super(TextClassificationModel, self).__init__()
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.num_class = num_class
        
        self.embedding = nn.EmbeddingBag(self.vocab_size, self.embed_dim, sparse=True)
        self.rnn = nn.RNN(self.embed_dim, self.hidden_dim, self.num_layers, batch_first=True)
        
        self.fc = nn.Linear(self.hidden_dim, self.num_class)
        
        self.dropout = nn.Dropout(0.4)
        self.layernorm1 = nn.LayerNorm([batch_size, 1, embed_dim])
        
        self.init_weights()
    
    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()
        
    def forward(self, text, offsets):
        # print('embedding shape before view: ', self.embedding(text, offsets).shape)
        embedded = self.embedding(text, offsets).view(batch_size, -1, self.embed_dim)
        # print('embedding shape after view:', embedded.shape)
        embedded = self.layernorm1(self.dropout(embedded))
        
        hidden = torch.zeros(
        self.num_layers, embedded.size(0), self.hidden_dim).to(device)
        
        rnn_out, hidden = self.rnn(embedded, hidden)
        
        out = self.fc(rnn_out[:, -1:]).view([-1, self.num_class])
        return out

In [278]:
train_iter = train_list
num_class = len(set([label for (label, text) in train_iter]))
vocab_size = len(vocab)
emsize = 128
hiddim = 32
num_layers = 1

model = TextClassificationModel(vocab_size, emsize, hiddim, num_layers, num_class).to(device)
print(num_class, vocab_size)
print(model)

8 13333
TextClassificationModel(
  (embedding): EmbeddingBag(13333, 128, mode=mean)
  (rnn): RNN(128, 32, batch_first=True)
  (fc): Linear(in_features=32, out_features=8, bias=True)
  (dropout): Dropout(p=0.4, inplace=False)
  (layernorm1): LayerNorm((1, 1, 128), eps=1e-05, elementwise_affine=True)
)


In [279]:
def get_n_params(model):
    pp=0
    for p in list(model.parameters()):
        nn=1
        for s in list(p.size()):
            nn = nn * s
        pp += nn
    return pp

In [280]:
get_n_params(model)

1712328

### Training, evaluation functions

In [180]:
# training function (for RNN)
def train(dataloader):
    model.train()
    total_acc, total_count = 0, 0
    log_interval = 500
    
    for idx, (label, text, offsets) in enumerate(dataloader):
        optimizer.zero_grad()
        predicted_label = model(text, offsets)
        loss = criterion(predicted_label, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()
        total_acc += (predicted_label.argmax(1) == label).sum().item()
        total_count += label.size(0)
        if idx % log_interval == 0 and idx > 0:
            print('| epoch {:3d} | {:5d}/{:5d} batches | accuracy {:8.3f}'
                  .format(epoch, idx, len(dataloader), total_acc/total_count))
            total_acc, total_count = 0, 0

In [181]:
# evaluation function (for RNN)
def evaluate(dataloader):
    model.eval()
    total_acc, total_count = 0, 0
    
    with torch.no_grad():
        for idx, (label, text, offsets) in enumerate(dataloader):
            predicted_label = model(text, offsets)
            loss = criterion(predicted_label, label)
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            total_count += label.size(0)
    return total_acc/total_count

def evaluate_pred(dataloader):
    model.eval()
    total_acc, total_count = 0, 0
    
    with torch.no_grad():
        for idx, (label, text, offsets) in enumerate(dataloader):
            predicted_label = model(text, offsets)
            loss = criterion(predicted_label, label)
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            total_count += label.size(0)

    return total_acc/total_count

## GRU

In [270]:
class TextClassificationGRU(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, num_class):
        super(TextClassificationGRU, self).__init__()
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.num_class = num_class
        
        self.embedding = nn.EmbeddingBag(self.vocab_size, self.embed_dim, sparse=True)
        self.gru = nn.GRU(self.embed_dim, self.hidden_dim, self.num_layers, batch_first=True)
        
        self.fc = nn.Linear(self.hidden_dim, self.num_class)
        
        self.dropout = nn.Dropout(0.4)
        self.layernorm1 = nn.LayerNorm([batch_size, 1, embed_dim])
        
        self.init_weights()
    
    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()
        
    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets).view(batch_size, -1, self.embed_dim)
        embedded = self.dropout(embedded)
        embedded = self.layernorm1(embedded)
        
        hidden = torch.zeros(
        self.num_layers, embedded.size(0), self.hidden_dim).to(device)
        
        gru_out, hidden = self.gru(embedded, hidden)
        
        out = self.fc(gru_out[:, -1:]).view([-1, self.num_class])
        return out

In [271]:
train_iter = train_list
num_class = len(set([label for (label, text) in train_iter]))
vocab_size = len(vocab)
emsize = 128
hiddim = 32
num_layers = 1

model = TextClassificationGRU(vocab_size, emsize, hiddim, num_layers, num_class).to(device)
print(num_class, vocab_size)
print(model)

8 13333
TextClassificationGRU(
  (embedding): EmbeddingBag(13333, 128, mode=mean)
  (gru): GRU(128, 32, batch_first=True)
  (fc): Linear(in_features=32, out_features=8, bias=True)
  (dropout): Dropout(p=0.4, inplace=False)
  (layernorm1): LayerNorm((1, 1, 128), eps=1e-05, elementwise_affine=True)
)


In [272]:
get_n_params(model)

1722696

In [238]:
class TextClassificationLSTM(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, num_class):
        super(TextClassificationLSTM, self).__init__()
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.num_class = num_class
        
        self.embedding = nn.EmbeddingBag(self.vocab_size, self.embed_dim, sparse=True)
        self.lstm = nn.LSTM(self.embed_dim, self.hidden_dim, self.num_layers, batch_first=True)
        
        self.fc = nn.Linear(self.hidden_dim, self.num_class)
        
        self.dropout = nn.Dropout(0.4)
        self.layernorm1 = nn.LayerNorm([batch_size, 1, embed_dim])
        
        self.init_weights()
    
    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()
        
    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets).view(batch_size, -1, self.embed_dim)
        embedded = self.dropout(embedded)
        embedded = self.layernorm1(embedded)
        
        hidden = torch.zeros(
        self.num_layers, embedded.size(0), self.hidden_dim).to(device)
        
        lstm_out = self.lstm(embedded, hidden)
        
        out = self.fc(rnn_out[:, -1:]).view([-1, self.num_class])
        return out

In [239]:
train_iter = train_list
num_class = len(set([label for (label, text) in train_iter]))
vocab_size = len(vocab)
emsize = 128
hiddim = 32
num_layers = 1

model = TextClassificationLSTM(vocab_size, emsize, hiddim, num_layers, num_class).to(device)
print(num_class, vocab_size)
print(model)

8 13370
TextClassificationLSTM(
  (embedding): EmbeddingBag(13370, 128, mode=mean)
  (lstm): LSTM(128, 32, batch_first=True)
  (fc): Linear(in_features=32, out_features=8, bias=True)
  (dropout): Dropout(p=0.4, inplace=False)
  (layernorm1): LayerNorm((1, 1, 128), eps=1e-05, elementwise_affine=True)
)


In [240]:
get_n_params(model)

1732616

# Train, Eval

In [281]:
# Hyperparameters
epochs = 20
lr = 5
batch_size = 1

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)

total_accu = None

# split train, valid, test
train_iter, test_iter = train_list, valid_list

train_dataset = to_map_style_dataset(train_iter)
test_dataset = to_map_style_dataset(test_iter)

num_train = int(len(train_dataset) * 0.95)

split_train_, split_valid_ = \
    random_split(train_dataset, [num_train, len(train_dataset) - num_train])

train_dataloader = DataLoader(split_train_, batch_size=batch_size,
                              shuffle=True, collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid_, batch_size=batch_size,
                              shuffle=True, collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size,
                             shuffle=True, collate_fn=collate_batch)

# split train, valid
# train_iter = train_list
# valid_iter = valid_list

# train_dataset = to_map_style_dataset(train_iter)
# valid_dataset = to_map_style_dataset(valid_iter)

# train_dataloader = DataLoader(train_dataset, batch_size=batch_size,
#                              shuffle=True, collate_fn=collate_batch)
# valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size,
#                              shuffle=True, collate_fn=collate_batch)

for epoch in range(1, epochs + 1):
    train(train_dataloader)
    accu_val = evaluate(valid_dataloader)
    if total_accu is not None and total_accu > accu_val:
        scheduler.step()
    else:
        total_accu = accu_val
    print('-' * 59)
    print('| end of epoch {:3d} | valid accuracy {:8.3f} '
          .format(epoch, accu_val))
    print('-' * 59)
    model_path = './saved_models/221221_FT_RNN_SGD_DO_LN_5_v1/'
    if not os.path.isdir(model_path):
        os.mkdir(model_path)
    torch.save(model.state_dict(), os.path.join(model_path, 'epoch_{}.pth'.format(epoch)))

| epoch   1 |   500/ 4898 batches | accuracy    0.194
| epoch   1 |  1000/ 4898 batches | accuracy    0.232
| epoch   1 |  1500/ 4898 batches | accuracy    0.324
| epoch   1 |  2000/ 4898 batches | accuracy    0.342
| epoch   1 |  2500/ 4898 batches | accuracy    0.362
| epoch   1 |  3000/ 4898 batches | accuracy    0.360
| epoch   1 |  3500/ 4898 batches | accuracy    0.454
| epoch   1 |  4000/ 4898 batches | accuracy    0.454
| epoch   1 |  4500/ 4898 batches | accuracy    0.464
-----------------------------------------------------------
| end of epoch   1 | valid accuracy    0.527 
-----------------------------------------------------------
| epoch   2 |   500/ 4898 batches | accuracy    0.499
| epoch   2 |  1000/ 4898 batches | accuracy    0.540
| epoch   2 |  1500/ 4898 batches | accuracy    0.520
| epoch   2 |  2000/ 4898 batches | accuracy    0.628
| epoch   2 |  2500/ 4898 batches | accuracy    0.578
| epoch   2 |  3000/ 4898 batches | accuracy    0.572
| epoch   2 |  3500/ 489

In [282]:
accu_test = evaluate(test_dataloader)
print('test accuracy {:8.3f}'.format(accu_test))

test accuracy    0.675


In [275]:
pth_list = os.listdir(model_path)
accu_test_list = []
for file_path in pth_list:
    path = os.path.join(model_path, file_path)
    model.load_state_dict(torch.load(path))
    accu_test_list.append(evaluate(test_dataloader))
print(accu_test_list)

[0.6325, 0.69625, 0.68625, 0.67, 0.6825, 0.65875, 0.69125, 0.48125, 0.685, 0.65125, 0.59875, 0.69125, 0.65375, 0.6975, 0.6975, 0.69, 0.675, 0.57625, 0.63125, 0.6925]


In [268]:
for idx, accu in enumerate(accu_test_list):
    if max(accu_test_list) == accu:
        print('best epoch: ', idx + 1)
        print('best acc_test: ', max(accu_test_list))

best epoch:  7
best acc_test:  0.685


In [144]:
pred = evaluate_pred(test_dataloader)
print('test accuracy {:8.3f}'.format(pred))

test accuracy    0.724


In [284]:
for i in range(len(valid_dict)):
    print(valid_dict[i])

{'user_id': 0, 'sentence': 'i got arrested beaten left bloody and unconscious but i havent given up and you can not give up an inspiring read from civil rights legend'}
{'user_id': 0, 'sentence': 'as weve seen this past year new variants of disease can emerge over time in order to develop new tools to fight the disease we need to identify those variants quickly dr senjuti saha is one expert working to sequence sars cov'}
{'user_id': 0, 'sentence': 'this book written by young surgeon with terminal cancer earned my admirationand tears'}
{'user_id': 0, 'sentence': 'without access to energy the poor are denied all of the benefits that come with power'}
{'user_id': 0, 'sentence': 'be catalyst is financing the foundation of clean economy through four critical technologies in these areas catalyst'}
{'user_id': 0, 'sentence': 'every spring i visit omaha nebraska to catch up with my friend'}
{'user_id': 0, 'sentence': 'my family loved talking about this book at the dinner table and i think your

In [9]:
import torch.nn as nn
from transformers import BertTokenizer, BertForSequenceClassification
import torch.optim as optim

In [10]:
class BERT(nn.Module):
    
    def __init__(self):
        super(BERT, self).__init__()
        
        options_name = "bert_base-uncased"
        self.encoder = BertForSequenceClassification.from_pretrained(options_name)
    
    def forward(self, text, label):
        loss, text_fea = self.encoder(text, labels=label)[:2]
        
        return loss, text_fea

In [None]:
# Save and Load Functions

def save_checkpoint(save_path, model, valid_loss):

    if save_path == None:
        return
    
    state_dict = {'model_state_dict': model.state_dict(),
                  'valid_loss': valid_loss}
    
    torch.save(state_dict, save_path)
    print(f'Model saved to ==> {save_path}')

def load_checkpoint(load_path, model):
    
    if load_path==None:
        return
    
    state_dict = torch.load(load_path, map_location=device)
    print(f'Model loaded from <== {load_path}')
    
    model.load_state_dict(state_dict['model_state_dict'])
    return state_dict['valid_loss']


def save_metrics(save_path, train_loss_list, valid_loss_list, global_steps_list):

    if save_path == None:
        return
    
    state_dict = {'train_loss_list': train_loss_list,
                  'valid_loss_list': valid_loss_list,
                  'global_steps_list': global_steps_list}
    
    torch.save(state_dict, save_path)
    print(f'Model saved to ==> {save_path}')


def load_metrics(load_path):

    if load_path==None:
        return
    
    state_dict = torch.load(load_path, map_location=device)
    print(f'Model loaded from <== {load_path}')
    
    return state_dict['train_loss_list'], state_dict['valid_loss_list'], state_dict['global_steps_list']

In [None]:
# Training Function

def train(model,
          optimizer,
          criterion = nn.BCELoss(),
          train_loader = train_iter,
          valid_loader = valid_iter,
          num_epochs = 5,
          eval_every = len(train_iter) // 2,
          file_path = destination_folder,
          best_valid_loss = float("Inf")):
    
    # initialize running values
    running_loss = 0.0
    valid_running_loss = 0.0
    global_step = 0
    train_loss_list = []
    valid_loss_list = []
    global_steps_list = []

    # training loop
    model.train()
    for epoch in range(num_epochs):
        for (labels, title, text, titletext), _ in train_loader:
            labels = labels.type(torch.LongTensor)           
            labels = labels.to(device)
            titletext = titletext.type(torch.LongTensor)  
            titletext = titletext.to(device)
            output = model(titletext, labels)
            loss, _ = output

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # update running values
            running_loss += loss.item()
            global_step += 1

            # evaluation step
            if global_step % eval_every == 0:
                model.eval()
                with torch.no_grad():                    

                    # validation loop
                    for (labels, title, text, titletext), _ in valid_loader:
                        labels = labels.type(torch.LongTensor)           
                        labels = labels.to(device)
                        titletext = titletext.type(torch.LongTensor)  
                        titletext = titletext.to(device)
                        output = model(titletext, labels)
                        loss, _ = output
                        
                        valid_running_loss += loss.item()

                # evaluation
                average_train_loss = running_loss / eval_every
                average_valid_loss = valid_running_loss / len(valid_loader)
                train_loss_list.append(average_train_loss)
                valid_loss_list.append(average_valid_loss)
                global_steps_list.append(global_step)

                # resetting running values
                running_loss = 0.0                
                valid_running_loss = 0.0
                model.train()

                # print progress
                print('Epoch [{}/{}], Step [{}/{}], Train Loss: {:.4f}, Valid Loss: {:.4f}'
                      .format(epoch+1, num_epochs, global_step, num_epochs*len(train_loader),
                              average_train_loss, average_valid_loss))
                
                # checkpoint
                if best_valid_loss > average_valid_loss:
                    best_valid_loss = average_valid_loss
                    save_checkpoint(file_path + '/' + 'model.pt', model, best_valid_loss)
                    save_metrics(file_path + '/' + 'metrics.pt', train_loss_list, valid_loss_list, global_steps_list)
    
    save_metrics(file_path + '/' + 'metrics.pt', train_loss_list, valid_loss_list, global_steps_list)
    print('Finished Training!')

model = BERT().to(device)
optimizer = optim.Adam(model.parameters(), lr=2e-5)

train(model=model, optimizer=optimizer)