In [1]:
!pip install pytorch-pretrained-bert
!pip install seqeval

Collecting pytorch-pretrained-bert
[?25l  Downloading https://files.pythonhosted.org/packages/d7/e0/c08d5553b89973d9a240605b9c12404bcf8227590de62bae27acbcfe076b/pytorch_pretrained_bert-0.6.2-py3-none-any.whl (123kB)
[K     |██▋                             | 10kB 26.3MB/s eta 0:00:01[K     |█████▎                          | 20kB 5.8MB/s eta 0:00:01[K     |████████                        | 30kB 8.2MB/s eta 0:00:01[K     |██████████▋                     | 40kB 5.7MB/s eta 0:00:01[K     |█████████████▎                  | 51kB 6.9MB/s eta 0:00:01[K     |███████████████▉                | 61kB 8.1MB/s eta 0:00:01[K     |██████████████████▌             | 71kB 9.2MB/s eta 0:00:01[K     |█████████████████████▏          | 81kB 10.2MB/s eta 0:00:01[K     |███████████████████████▉        | 92kB 11.3MB/s eta 0:00:01[K     |██████████████████████████▌     | 102kB 9.4MB/s eta 0:00:01[K     |█████████████████████████████▏  | 112kB 9.4MB/s eta 0:00:01[K     |████████████████████

In [0]:
import numpy as np
import pandas as pd
import nltk

In [0]:
class SentenceGetter(object):
    
    def __init__(self, dataset):
        self.n_sent = 1
        self.dataset = dataset
        self.empty = False
        agg_func = lambda s: [(w,p, t) for w,p, t in zip(s["word"].values.tolist(),
                                                       s['pos'].values.tolist(),
                                                        s["tag"].values.tolist())]
        self.grouped = self.dataset.groupby("sentence_idx").apply(agg_func)
        self.sentences = [s for s in self.grouped]
    
    def get_next(self):
        try:
            s = self.grouped["Sentence: {}".format(self.n_sent)]
            self.n_sent += 1
            return s
        except:
            return None

In [0]:
dframe = pd.read_csv("train_with_pos_table.csv", error_bad_lines=False)
test_dframe = dframe = pd.read_csv("test_with_pos_table.csv", error_bad_lines=False)

In [19]:
dframe.head()

Unnamed: 0,word,pos,tag,sentence_idx
0,Phonegap,NN,B-Fram,1
1,cordova,NN,I-Fram,1
2,android,NN,B-Plat,1
3,4.4,CD,I-Plat,1
4,FileTransfer,NN,B-API,1


In [0]:
getter = SentenceGetter(dframe)

In [22]:
sentences = [" ".join([str(s[0]) for s in sent]) for sent in getter.sentences]
sentences[0]

'Phonegap cordova android 4.4 FileTransfer upload SSL not working'

In [23]:
labels = [[s[2] for s in sent] for sent in getter.sentences]
print(labels[0])

['B-Fram', 'I-Fram', 'B-Plat', 'I-Plat', 'B-API', 'O', 'B-Stan', 'O', 'O']


In [0]:
tags_vals = list(set(dframe["tag"].values))
tag2idx = {t: i for i, t in enumerate(tags_vals)}

In [26]:
import torch
from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from keras.preprocessing.sequence import pad_sequences
from sklearn.model_selection import train_test_split
from pytorch_pretrained_bert import BertTokenizer, BertConfig
from pytorch_pretrained_bert import BertForTokenClassification, BertAdam
from tqdm import tqdm, trange

Using TensorFlow backend.


In [0]:
MAX_LEN = 40
bs = 64

In [29]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()
torch.cuda.get_device_name(0) 

'Tesla T4'

In [0]:
tokenized_texts = [tokenizer.tokenize(sent) for sent in sentences]
input_ids = pad_sequences([tokenizer.convert_tokens_to_ids(txt) for txt in tokenized_texts],
                          maxlen=MAX_LEN, dtype="long", truncating="post", padding="post")

In [0]:
tags = pad_sequences([[tag2idx.get(l) for l in lab] for lab in labels],
                     maxlen=MAX_LEN, value=tag2idx["O"], padding="post",
                     dtype="long", truncating="post")

In [0]:
attention_masks = [[float(i>0) for i in ii] for ii in input_ids]

In [0]:
tr_inputs, val_inputs, tr_tags, val_tags = train_test_split(input_ids, tags, 
                                                            random_state=2018, test_size=0.1)
tr_masks, val_masks, _, _ = train_test_split(attention_masks, input_ids,
                                             random_state=2018, test_size=0.1)

tr_inputs = torch.tensor(tr_inputs)
val_inputs = torch.tensor(val_inputs)
tr_tags = torch.tensor(tr_tags)
val_tags = torch.tensor(val_tags)
tr_masks = torch.tensor(tr_masks)
val_masks = torch.tensor(val_masks)

In [0]:
train_data = TensorDataset(tr_inputs, tr_masks, tr_tags)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=bs)

valid_data = TensorDataset(val_inputs, val_masks, val_tags)
valid_sampler = SequentialSampler(valid_data)
valid_dataloader = DataLoader(valid_data, sampler=valid_sampler, batch_size=bs)

In [45]:
model = BertForTokenClassification.from_pretrained("bert-base-uncased", num_labels=len(tag2idx))

100%|██████████| 407873900/407873900 [00:15<00:00, 26131349.65B/s]


In [46]:
model.cuda()

BertForTokenClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): BertLayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): BertLayerNorm()
              (dropout): Dropout(p=0.1, inplace=False)
      

In [0]:
FULL_FINETUNING = True
if FULL_FINETUNING:
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay_rate': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay_rate': 0.0}
    ]
else:
    param_optimizer = list(model.classifier.named_parameters()) 
    optimizer_grouped_parameters = [{"params": [p for n, p in param_optimizer]}]
optimizer = Adam(optimizer_grouped_parameters, lr=3e-5)

In [0]:
from seqeval.metrics import f1_score

def flat_accuracy(preds, labels):
    pred_flat = np.argmax(preds, axis=2).flatten()
    labels_flat = labels.flatten()
    return np.sum(pred_flat == labels_flat) / len(labels_flat)

In [51]:
epochs = 20
max_grad_norm = 1.0

for _ in trange(epochs, desc="Epoch"):
    # TRAIN loop
    model.train()
    tr_loss = 0
    nb_tr_examples, nb_tr_steps = 0, 0
    for step, batch in enumerate(train_dataloader):
        # add batch to gpu
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch
        # forward pass
        loss = model(b_input_ids, token_type_ids=None,
                     attention_mask=b_input_mask, labels=b_labels)
        # backward pass
        loss.backward()
        # track train loss
        tr_loss += loss.item()
        nb_tr_examples += b_input_ids.size(0)
        nb_tr_steps += 1
        # gradient clipping
        torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=max_grad_norm)
        # update parameters
        optimizer.step()
        model.zero_grad()
    # print train loss per epoch
    print("Train loss: {}".format(tr_loss/nb_tr_steps))
    # VALIDATION on validation set
    model.eval()
    eval_loss, eval_accuracy = 0, 0
    nb_eval_steps, nb_eval_examples = 0, 0
    predictions , true_labels = [], []
    for batch in valid_dataloader:
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch
        
        with torch.no_grad():
            tmp_eval_loss = model(b_input_ids, token_type_ids=None,
                                  attention_mask=b_input_mask, labels=b_labels)
            logits = model(b_input_ids, token_type_ids=None,
                           attention_mask=b_input_mask)
        logits = logits.detach().cpu().numpy()
        label_ids = b_labels.to('cpu').numpy()
        predictions.extend([list(p) for p in np.argmax(logits, axis=2)])
        true_labels.append(label_ids)
        
        tmp_eval_accuracy = flat_accuracy(logits, label_ids)
        
        eval_loss += tmp_eval_loss.mean().item()
        eval_accuracy += tmp_eval_accuracy
        
        nb_eval_examples += b_input_ids.size(0)
        nb_eval_steps += 1
    eval_loss = eval_loss/nb_eval_steps
    print("Validation loss: {}".format(eval_loss))
    print("Validation Accuracy: {}".format(eval_accuracy/nb_eval_steps))
    pred_tags = [tags_vals[p_i] for p in predictions for p_i in p]
    valid_tags = [tags_vals[l_ii] for l in true_labels for l_i in l for l_ii in l_i]
    print("F1-Score: {}".format(f1_score(pred_tags, valid_tags)))

Epoch:   0%|          | 0/20 [00:00<?, ?it/s]

Train loss: 0.08027242045059349


Epoch:   5%|▌         | 1/20 [00:38<12:07, 38.28s/it]

Validation loss: 0.10240356810390949
Validation Accuracy: 0.9818675321691175
F1-Score: 0.30339321357285437
Train loss: 0.0718140850464503


Epoch:  10%|█         | 2/20 [01:16<11:27, 38.21s/it]

Validation loss: 0.10199398268014193
Validation Accuracy: 0.9770191865808824
F1-Score: 0.29559748427672955
Train loss: 0.06339871079068292


Epoch:  15%|█▌        | 3/20 [01:53<10:46, 38.03s/it]

Validation loss: 0.10031555267050862
Validation Accuracy: 0.9810862821691178
F1-Score: 0.33273056057866185
Train loss: 0.05612682175794334


Epoch:  20%|██        | 4/20 [02:31<10:08, 38.01s/it]

Validation loss: 0.11208972986787558
Validation Accuracy: 0.9773868336397058
F1-Score: 0.3211009174311927
Train loss: 0.050211004433081005


Epoch:  25%|██▌       | 5/20 [03:09<09:29, 37.95s/it]

Validation loss: 0.11054865457117558
Validation Accuracy: 0.9790584788602941
F1-Score: 0.33993399339933994
Train loss: 0.04436129582763621


Epoch:  30%|███       | 6/20 [03:47<08:51, 37.95s/it]

Validation loss: 0.11924935597926378
Validation Accuracy: 0.973196231617647
F1-Score: 0.31016042780748665
Train loss: 0.04178440288612337


Epoch:  35%|███▌      | 7/20 [04:25<08:13, 37.94s/it]

Validation loss: 0.11156334076076746
Validation Accuracy: 0.9782398897058824
F1-Score: 0.3290322580645161
Train loss: 0.03703034087789781


Epoch:  40%|████      | 8/20 [05:03<07:34, 37.91s/it]

Validation loss: 0.12341293506324291
Validation Accuracy: 0.9770306755514705
F1-Score: 0.3458646616541354
Train loss: 0.03262105174927098


Epoch:  45%|████▌     | 9/20 [05:41<06:56, 37.87s/it]

Validation loss: 0.11081618815660477
Validation Accuracy: 0.9815860523897059
F1-Score: 0.36363636363636365
Train loss: 0.028577924162770312


Epoch:  50%|█████     | 10/20 [06:18<06:18, 37.83s/it]

Validation loss: 0.1263092579320073
Validation Accuracy: 0.9779957490808823
F1-Score: 0.3338582677165355
Train loss: 0.02424321219212178


Epoch:  55%|█████▌    | 11/20 [06:56<05:40, 37.82s/it]

Validation loss: 0.13179936632514
Validation Accuracy: 0.9763499540441176
F1-Score: 0.31578947368421056
Train loss: 0.021586414454787067


Epoch:  60%|██████    | 12/20 [07:34<05:02, 37.81s/it]

Validation loss: 0.14338948670774698
Validation Accuracy: 0.9658059512867647
F1-Score: 0.270996640537514
Train loss: 0.019826238403435458


Epoch:  65%|██████▌   | 13/20 [08:12<04:24, 37.81s/it]

Validation loss: 0.1392065705731511
Validation Accuracy: 0.973675896139706
F1-Score: 0.3125827814569536
Train loss: 0.01913055473458812


Epoch:  70%|███████   | 14/20 [08:50<03:46, 37.80s/it]

Validation loss: 0.14692657999694347
Validation Accuracy: 0.9752240349264705
F1-Score: 0.3218390804597701
Train loss: 0.015995726500863606


Epoch:  75%|███████▌  | 15/20 [09:27<03:08, 37.79s/it]

Validation loss: 0.15106365270912647
Validation Accuracy: 0.974224494485294
F1-Score: 0.3047619047619048
Train loss: 0.013570954017764465


Epoch:  80%|████████  | 16/20 [10:05<02:31, 37.79s/it]

Validation loss: 0.1545984959229827
Validation Accuracy: 0.9754308363970589
F1-Score: 0.3033033033033033
Train loss: 0.012828326854629047


Epoch:  85%|████████▌ | 17/20 [10:43<01:53, 37.79s/it]

Validation loss: 0.1536646643653512
Validation Accuracy: 0.9712201286764705
F1-Score: 0.27828348504551365
Train loss: 0.011305597924740252


Epoch:  90%|█████████ | 18/20 [11:21<01:15, 37.79s/it]

Validation loss: 0.15842574555426836
Validation Accuracy: 0.9747213924632352
F1-Score: 0.30837004405286345
Train loss: 0.010358952377414838


Epoch:  95%|█████████▌| 19/20 [11:58<00:37, 37.77s/it]

Validation loss: 0.1718474105000496
Validation Accuracy: 0.9729406020220588
F1-Score: 0.3095238095238095
Train loss: 0.01050201521226854


Epoch: 100%|██████████| 20/20 [12:36<00:00, 37.79s/it]

Validation loss: 0.1559152901172638
Validation Accuracy: 0.9714642693014706
F1-Score: 0.2947903430749682





In [52]:
model.eval()
predictions = []
true_labels = []
eval_loss, eval_accuracy = 0, 0
nb_eval_steps, nb_eval_examples = 0, 0
for batch in valid_dataloader:
    batch = tuple(t.to(device) for t in batch)
    b_input_ids, b_input_mask, b_labels = batch

    with torch.no_grad():
        tmp_eval_loss = model(b_input_ids, token_type_ids=None,
                              attention_mask=b_input_mask, labels=b_labels)
        logits = model(b_input_ids, token_type_ids=None,
                       attention_mask=b_input_mask)
        
    logits = logits.detach().cpu().numpy()
    predictions.extend([list(p) for p in np.argmax(logits, axis=2)])
    label_ids = b_labels.to('cpu').numpy()
    true_labels.append(label_ids)
    tmp_eval_accuracy = flat_accuracy(logits, label_ids)

    eval_loss += tmp_eval_loss.mean().item()
    eval_accuracy += tmp_eval_accuracy

    nb_eval_examples += b_input_ids.size(0)
    nb_eval_steps += 1

pred_tags = [[tags_vals[p_i] for p_i in p] for p in predictions]
valid_tags = [[tags_vals[l_ii] for l_ii in l_i] for l in true_labels for l_i in l ]
print("Validation loss: {}".format(eval_loss/nb_eval_steps))
print("Validation Accuracy: {}".format(eval_accuracy/nb_eval_steps))
print("Validation F1-Score: {}".format(f1_score(pred_tags, valid_tags)))

Validation loss: 0.1559152901172638
Validation Accuracy: 0.9714642693014706
Validation F1-Score: 0.2947903430749682
