In [1]:
import pandas as pd
import numpy as np
import pickle

import torch

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForSequenceClassification

from sklearn.metrics import accuracy_score, matthews_corrcoef, confusion_matrix, f1_score, precision_score, recall_score, roc_auc_score

from matplotlib import pyplot as plt

import random
import math
import time

In [2]:
# The CSV appears to have been saved together with its index (it shows up as an
# extra 'Unnamed: 0' column otherwise); index_col=0 restores it as the index.
# NOTE(review): tweet_id is parsed as float64 (shown as 1.065845e+18 in the
# preview), which cannot represent 19-digit IDs exactly; load it with
# dtype={'tweet_id': str} if exact IDs are needed downstream.
df = pd.read_csv('assets/df_cleaned.csv', index_col=0)

In [3]:
# Preview the first five rows of the loaded tweets.
df.head(n=5)

Unnamed: 0,tweet_id,tweet_text,class_label,file_name,event,year,event_type,data_type,tweet_text_cleaned,hashtags,class_label_id,processed_text_length
0,1.065845e+18,"Camp Fire leaves over 13,000 without homes thi...",displaced_people_and_evacuations,california_wildfires_2018_dev.tsv,california_wildfires_2018,2018,fire,dev,camp fire leaves over without homes this thank...,[],0,123.0
1,1.061321e+18,"So in a truly strange world, we have @RealJame...",not_humanitarian,california_wildfires_2018_dev.tsv,california_wildfires_2018,2018,fire,dev,so in a truly strange world we have playing th...,[],1,148.0
2,1.063536e+18,66 people have died and more than 600 are stil...,injured_or_dead_people,california_wildfires_2018_dev.tsv,california_wildfires_2018,2018,fire,dev,people have died and more than are still missi...,"['californiawildfires', 'cafire', 'campfire', ...",2,61.0
3,1.062711e+18,BBC News - California wildfires: Nine dead and...,injured_or_dead_people,california_wildfires_2018_dev.tsv,california_wildfires_2018,2018,fire,dev,bbc news california wildfires nine dead and mo...,[],2,63.0
4,1.064808e+18,Death toll in California’s #CampFire has climb...,injured_or_dead_people,california_wildfires_2018_dev.tsv,california_wildfires_2018,2018,fire,dev,death toll in californias has climbed to the n...,['campfire'],2,219.0


In [4]:
# Shuffle once with a fixed seed, then cut into 80% train / 10% val / 10% test.
# The test split takes the remainder so no row is lost to int() truncation.
train_size = int(0.8 * len(df))
val_size = int(0.1 * len(df))
test_size = len(df) - train_size - val_size
df_shuffled = df.sample(frac=1, random_state=42)
# Positional slicing instead of np.split: np.split on a DataFrame goes through
# DataFrame.swapaxes, which newer pandas deprecates/removes.
df_train = df_shuffled.iloc[:train_size]
df_val = df_shuffled.iloc[train_size:train_size + val_size]
df_test = df_shuffled.iloc[train_size + val_size:]
assert (len(df_train), len(df_val), len(df_test)) == (train_size, val_size, test_size)

In [5]:
# Row counts per split, in (train, val, test) order.
tuple(len(split) for split in (df_train, df_val, df_test))

(61186, 7648, 7649)

In [6]:
class BertModel(nn.Module):
    """Sentence classifier: a pretrained BERT with a sequence-classification head.

    Wraps ``BertForSequenceClassification`` together with the matching
    ``BertTokenizer`` (both loaded from ``bert_config``) so that ``forward``
    can take a batch of raw strings.

    NOTE: this class name shadows ``BertModel`` imported from
    ``pytorch_pretrained_bert`` in the imports cell; once this cell runs, the
    name ``BertModel`` refers to this wrapper.

    Args:
        n_class: number of output classes of the classification head.
        bert_config: model name or local directory passed to
            ``from_pretrained`` for both the weights and the vocabulary.
    """

    def __init__(self, n_class, bert_config='bert-base-uncased'):
        super().__init__()

        self.n_class = n_class
        self.bert_config = bert_config
        self.bert = BertForSequenceClassification.from_pretrained(self.bert_config, num_labels=self.n_class)
        self.tokenizer = BertTokenizer.from_pretrained(self.bert_config)

    def forward(self, sents):
        """Classify a batch of raw sentences.

        Args:
            sents: list of untokenized strings.

        Returns:
            Whatever ``self.bert`` returns for ``input_ids``/``attention_mask``
            without labels -- presumably pre-softmax logits of shape
            (batch, n_class); TODO confirm against pytorch_pretrained_bert.
        """
        # sents2tensor already wraps every sentence in [CLS] ... [SEP], so no
        # separate prepend step is needed here.
        sents_tensor, masks_tensor = self.sents2tensor(sents)
        pre_softmax = self.bert(input_ids=sents_tensor, attention_mask=masks_tensor)
        return pre_softmax

    def add_special_token(self, sents_token, cls_token='[CLS]', sep_token='[SEP]'):
        """Return each token list wrapped as ``[cls_token] + tokens + [sep_token]``."""
        return [[cls_token] + sent + [sep_token] for sent in sents_token]

    def pad_sents(self, sents_token, pad_token):
        """Right-pad every token list with ``pad_token`` to the longest list's length.

        Returns new lists (inputs are not modified); an empty batch yields ``[]``.
        """
        max_len = max((len(s) for s in sents_token), default=0)
        return [list(s) + [pad_token] * (max_len - len(s)) for s in sents_token]

    def sents2tensor(self, sents):
        """Tokenize, add [CLS]/[SEP], pad, and convert a batch of sentences to tensors.

        Args:
            sents: list of untokenized strings.

        Returns:
            ``(sents_tensor, masks_tensor)``: two ``torch.long`` tensors of shape
            (batch, max_len) on the device of the model's parameters -- the
            token ids, and an attention mask (1 = real token, 0 = padding).
        """
        tokens_list = [self.tokenizer.tokenize(sent) for sent in sents]
        tokens_list_added = self.add_special_token(tokens_list)
        tokens_list_padded = self.pad_sents(tokens_list_added, '[PAD]')

        # Build the mask from the true lengths rather than by comparing tokens
        # to '[PAD]', so a literal '[PAD]' token from the text is still attended to.
        max_len = len(tokens_list_padded[0]) if tokens_list_padded else 0
        masks = [[1] * len(tokens) + [0] * (max_len - len(tokens)) for tokens in tokens_list_added]
        tokens_id_list = [self.tokenizer.convert_tokens_to_ids(tokens) for tokens in tokens_list_padded]

        # Put inputs where the weights are (follows model.to(...)) instead of
        # relying on an undefined global `device`.
        device = self._device()
        masks_tensor = torch.tensor(masks, dtype=torch.long, device=device)
        sents_tensor = torch.tensor(tokens_id_list, dtype=torch.long, device=device)
        return sents_tensor, masks_tensor

    def _device(self):
        """Device holding the model's parameters (CPU if it has none)."""
        param = next(self.parameters(), None)
        return param.device if param is not None else torch.device('cpu')

    @staticmethod
    def load(model_path, map_location=None):
        """Rebuild a model from a checkpoint written by ``save``.

        Args:
            model_path: path of the checkpoint file.
            map_location: forwarded to ``torch.load`` (e.g. ``'cpu'`` to open a
                GPU-saved checkpoint on a CPU-only machine).

        Returns:
            A new ``BertModel`` constructed with the saved args and loaded with
            the saved ``state_dict``.
        """
        # torch.load unpickles the file -- only load checkpoints you trust.
        params = torch.load(model_path, map_location=map_location)
        args = params['args']
        model = BertModel(**args)
        model.load_state_dict(params['state_dict'])
        return model

    def save(self, model_path):
        """Write constructor args and ``state_dict`` to ``model_path`` for ``load``."""
        print('save model parameters to [%s]' % model_path)

        params = {
            'args': dict(bert_config=self.bert_config, n_class=self.n_class),
            'state_dict': self.state_dict()
        }

        torch.save(params, model_path)

In [7]:
# Local copy of the pretrained weights/vocab; point this at your own download
# (or pass the class default 'bert-base-uncased').
BERT_CONFIG = 'C:\\bert\\bert-base-uncased'
# Size the classification head from the data rather than hard-coding it, so
# every class_label_id present in df gets its own logit.
N_CLASS = int(df['class_label_id'].max()) + 1
model = BertModel(n_class=N_CLASS, bert_config=BERT_CONFIG)