In [1]:
import copy

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils import data

import tensorflow as tf
from tqdm import tqdm,trange
from seqeval.scheme import IOB2
from seqeval.metrics import classification_report,f1_score
from transformers import *
from sadice import SelfAdjDiceLoss

In [2]:
import pandas as pd
import numpy as np
import json
from tqdm import tqdm
import warnings
import pickle
import random

warnings.filterwarnings('ignore')

data_train = pd.read_csv("IOB2_Data/BC_train_IOB2_all.txt",sep = '\t', na_filter=False)
data_dev = pd.read_csv("IOB2_Data/BC_dev_IOB2_all.txt",sep = '\t', na_filter=False)

In [3]:
class SentenceGetter(object):
    
    def __init__(self, data):
        self.n_sent = 1
        self.data = data
        self.empty = False
        agg_func = lambda s: [(w, t) for w, t in zip(s["Word"].values.tolist(),
                                                     s["Tag"].values.tolist())]
        self.grouped = self.data.groupby("Sentence#").apply(agg_func)
        self.sentences = [s for s in self.grouped]
    
    def get_next(self):
        try:
            s = self.grouped[self.n_sent]
            self.n_sent += 1
            return s
        except:
            return None

In [4]:
getter = SentenceGetter(data_train)
dev_getter = SentenceGetter(data_dev)

In [5]:
sentences = [[word[0] for word in sentence] for sentence in getter.sentences]
dev_sentences = [[word[0] for word in sentence] for sentence in dev_getter.sentences]

In [6]:
labels = [[s[1] for s in sent] for sent in getter.sentences]
dev_labels = [[s[1] for s in sent] for sent in dev_getter.sentences]

In [7]:
tag_values = list(set(data_train["Tag"].values))
tag_values.append("PAD")
tag_values

['B-', 'O', 'I-', 'PAD']

In [8]:
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
tokenizer = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False)

In [9]:
def tokenize_and_preserve_labels(sentence, text_labels):
    tokenized_sentence = []
    labels = []

    for word, label in zip(sentence, text_labels):
        tokenized_word = tokenizer.tokenize(word)
        n_subwords = len(tokenized_word)

        tokenized_sentence.extend(tokenized_word)
        if n_subwords > 1 and 'B-' in label:
            labels.extend([label])
            _ = 'I-' + label.split('B-')[1]
            if _ not in tag_values:
                tag_values.append(_)
                print(_)
            labels.extend([_] * (n_subwords-1))
        else:
            labels.extend([label] * n_subwords)
    return tokenized_sentence, labels

In [10]:
tokenized_texts_and_labels = [
    tokenize_and_preserve_labels(sent, labs)
    for sent, labs in zip(sentences, labels)
]
print('done')
dev_tokenized_texts_and_labels = [
    tokenize_and_preserve_labels(sent, labs)
    for sent, labs in zip(dev_sentences, dev_labels)
]

done


In [11]:
tag2idx = {t: i for i, t in enumerate(tag_values)}
idx2tag = {i: t for i, t in enumerate(tag_values)}
idx2tag

{0: 'B-', 1: 'O', 2: 'I-', 3: 'PAD'}

In [12]:
tokenized_texts = [token_label_pair[0] for token_label_pair in tokenized_texts_and_labels]
tokenized_labels = [token_label_pair[1] for token_label_pair in tokenized_texts_and_labels]

dev_tokenized_texts = [token_label_pair[0] for token_label_pair in dev_tokenized_texts_and_labels]
dev_tokenized_labels = [token_label_pair[1] for token_label_pair in dev_tokenized_texts_and_labels]

In [13]:
input_ids = [tokenizer.convert_tokens_to_ids(txt) for txt in tokenized_texts]
dev_input_ids = [tokenizer.convert_tokens_to_ids(txt) for txt in dev_tokenized_texts]
tags = [[tag2idx.get(l) for l in lab] for lab in tokenized_labels]
dev_tags = [[tag2idx.get(l,tag2idx['O']) for l in lab] for lab in dev_tokenized_labels]

In [14]:
from torch.utils.data import Dataset
class BertNerDataset(Dataset):
    def __init__(self,sentences,labels, word_pad_idx, tag_pad_idx, max_len = 500):
        self.sentences = sentences
        self.labels = labels
        self.word_pad_idx = word_pad_idx
        self.tag_pad_idx = tag_pad_idx
        self.max_len = max_len
    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, index):
        return (self.sentences[index],self.labels[index])
        
    def collate_fn(self, datasets):
        sentences = [dataset[0] for dataset in datasets]
        labels = [dataset[1] for dataset in datasets]
        max_sent = max([len(data) for data in sentences])
        max_len = max([min(len(sentence), self.max_len) for sentence in sentences])
        pad_sentence = []
        pad_label = []
        for sentence,label in zip(sentences,labels):
            
            if len(sentence) > max_len:
#                 print('asd')
                pad_sentence.append(sentence[:max_len])
                pad_label.append(label[:max_len])
                attention_masks = [[float(i != 0.0) for i in ii] for ii in pad_sentence]
            else:
#                 print('zxc')
                pad_sentence.append(sentence+[self.word_pad_idx]*(max_len-len(sentence)))
                pad_label.append(label+[self.tag_pad_idx]*(max_len-len(label)))
                attention_masks = [[float(i != 0.0) for i in ii] for ii in pad_sentence]
        return torch.LongTensor(pad_sentence), torch.LongTensor(pad_label),torch.tensor(attention_masks)

In [16]:
from transformers import get_linear_schedule_with_warmup
bs = 32

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
epochs = 15
max_grad_norm = 3.0

tr_dataset = BertNerDataset(input_ids,tags,tokenizer.convert_tokens_to_ids('[PAD]'),tag2idx['PAD'])
va_dataset = BertNerDataset(dev_input_ids,dev_tags,tokenizer.convert_tokens_to_ids('[PAD]'),tag2idx['PAD'])

train_dataloader = DataLoader(tr_dataset, batch_size=bs,
                            collate_fn=tr_dataset.collate_fn)
valid_dataloader = DataLoader(va_dataset, batch_size=bs,
                            collate_fn=va_dataset.collate_fn)

In [17]:
class EmbeddedRnn(nn.Module):
    def __init__(self, hidden_dim, output_vocab, n_layer=1):
        super(EmbeddedRnn, self).__init__()
        self.n_layer = n_layer
        self.embedding_size = 3072
        self.hidden_dim = hidden_dim
        self.bert = BertForTokenClassification.from_pretrained('bert-base-cased',output_hidden_states =True)
#         self.bert.eval()
        self.lstm = nn.LSTM(self.embedding_size, hidden_dim, num_layers=n_layer,batch_first = True, bidirectional=True)
        self.fc1 = nn.Linear(2 * hidden_dim, output_vocab)
        self.softmax = nn.LogSoftmax(dim=-1)
        self.relu = nn.ReLU()
    
    def embedding(self,x,att_mask):
        sent_embeddings = []
        _,hidden_states = self.bert(x,att_mask)
        token_embeddings = torch.stack(hidden_states[:-1],dim = 0)
#         with torch.no_grad():
        token_embeddings = token_embeddings.permute(1,2,0,3)
        for tks in token_embeddings:
            token_vecs = []
            for tk in tks:
                cat_vec = torch.cat((tk[-1] , tk[-2] , tk[-3] , tk[-4]) , dim = 0)
                token_vecs.append(cat_vec)
            token_vecs = torch.stack(token_vecs , 0)
            sent_embeddings.append(token_vecs)
        sent_embeddings = torch.stack(sent_embeddings , 0)
        return sent_embeddings    

    def forward(self, x, x_att):
        embedded = self.embedding(x,x_att)
#         print(embedded.shape)
        output, hidden = self.lstm(embedded)
#         print(output.shape)
        output = self.fc1(output)
        output = self.softmax(output)
        return output,hidden
    
    def initHidden(self, batch_size):
        hidden = Variable(torch.zeros(2 * self.n_layer, batch_size, self.hidden_dim))
        cell = Variable(torch.zeros(2 * self.n_layer, batch_size, self.hidden_dim))
        return [hidden, cell]
#         return hidden

model = EmbeddedRnn(300 , len(tag2idx))

model = model.to(device)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForTokenClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-cas

In [18]:
class FocalLoss(nn.Module):
    def __init__(self, gamma=5, alpha=None, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        if isinstance(alpha,(float,int)): self.alpha = torch.Tensor([alpha,1-alpha])
        if isinstance(alpha,list): self.alpha = torch.Tensor(alpha)
        self.size_average = size_average

    def forward(self, input, target):
        if input.dim()>2:
            input = input.view(input.size(0),input.size(1),-1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1,2)    # N,C,H*W => N,H*W,C
            input = input.contiguous().view(-1,input.size(2))   # N,H*W,C => N*H*W,C
        target = target.view(-1,1)

        logpt = F.log_softmax(input)
        logpt = logpt.gather(1,target)
        logpt = logpt.view(-1)
        pt = Variable(logpt.data.exp())

        if self.alpha is not None:
            if self.alpha.type()!=input.data.type():
                self.alpha = self.alpha.type_as(input.data)
            at = self.alpha.gather(0,target.data.view(-1))
            logpt = logpt * Variable(at)

        loss = -1 * (1-pt)**self.gamma * logpt
        if self.size_average: return loss.mean()
        else: return loss.sum()

In [24]:
optimizer = AdamW(
    model.parameters(),
    lr=3e-5,
    eps=1e-8
)
total_steps = len(train_dataloader) * epochs
scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=total_steps
    )

all_dataloader = {
    'train' : train_dataloader,
    'valid' : valid_dataloader,
    }
for i in all_dataloader:
    print(i)

# criterion = SelfAdjDiceLoss(reduction="none")
# criterion = nn.CrossEntropyLoss()
criterion = FocalLoss(gamma = 5)

train
valid


In [25]:
from transformers import get_linear_schedule_with_warmup,get_cosine_schedule_with_warmup

epochs = 20
TAG_PAD_IDX = tag2idx['PAD']

In [None]:
records = {
    'loss':[],
    'F1':[],
}
model.train(True)
f_slist = []
for epoch in tqdm(range(epochs)):
    for loader in all_dataloader:
        train_loss,valid_loss = [],[]
        print('')
        predictions , true_labels , x_  = [],[],[]
        for x, y ,x_attn in all_dataloader[loader]:
            optimizer.zero_grad()
            x = x.to(device)
            y = y.to(device)
            x_attn = x_attn.to(device)
            hidden = model.initHidden(x.size(0))
            hidden[0] = hidden[0].to(device)
            hidden[1] = hidden[1].to(device)
            asd = y.shape[1]
            output, hidden = model(x, x_attn)
            predictions.extend(np.argmax(output.detach().cpu().numpy(), axis=2))
            for i in y.detach().cpu().numpy():
                _ = []
                for j in i:
                    if j != TAG_PAD_IDX:
                        _.append(idx2tag[j])
                true_labels.append(_)

            output = output.reshape(-1,output.shape[-1])
            y = y.reshape(-1)
    #             print(output.shape,y.shape)
            
            loss = criterion(output,y)
#             loss = loss.reshape(-1, asd).mean(-1).mean()
#             if loader == 'train':
            loss.backward()
            optimizer.step()
            scheduler.step()
            train_loss.append(loss.cpu().item()) 
        print(f'loss : {np.mean(np.array(train_loss))}')
        pred_tags = [[idx2tag[p_i] for p_i, l_i in zip(p, l) if l_i != "PAD"] for p, l in zip(predictions, true_labels)]
        f_ = f1_score(true_labels,pred_tags, scheme = IOB2)
        print(f_)
        if loader == 'valid':
            f_slist.append(f_)
            print(classification_report(true_labels,pred_tags , scheme = IOB2 ))

  0%|                                                                                           | 0/20 [00:00<?, ?it/s]


loss : 0.0001313879938827759
0.09859154929577464

loss : 0.00014161245200378631
0.1732283464566929


  5%|███▉                                                                           | 1/20 [19:45<6:15:21, 1185.37s/it]

              precision    recall  f1-score   support

          AD       0.00      0.00      0.00         0
           _       0.21      0.21      0.21       105

   micro avg       0.15      0.21      0.17       105
   macro avg       0.10      0.10      0.10       105
weighted avg       0.21      0.21      0.21       105


loss : 1.7610060100560087e-05
0.6313645621181263

loss : 0.00012007343481244592
0.4545454545454546


 10%|███████▉                                                                       | 2/20 [39:35<5:56:25, 1188.10s/it]

              precision    recall  f1-score   support

          AD       0.00      0.00      0.00         0
           _       0.54      0.62      0.58       105

   micro avg       0.36      0.62      0.45       105
   macro avg       0.27      0.31      0.29       105
weighted avg       0.54      0.62      0.58       105


loss : 8.77713249665112e-06
0.7250996015936255

loss : 3.269035245436711e-05
0.6437768240343347


 15%|███████████▊                                                                   | 3/20 [59:27<5:37:07, 1189.86s/it]

              precision    recall  f1-score   support

          AD       0.00      0.00      0.00         0
           _       0.64      0.71      0.67       105

   micro avg       0.59      0.71      0.64       105
   macro avg       0.32      0.36      0.34       105
weighted avg       0.64      0.71      0.67       105


loss : 5.260741506245309e-06
0.7857142857142856

loss : 4.52675972896867e-05
0.75


 20%|███████████████▍                                                             | 4/20 [1:19:17<5:17:20, 1190.01s/it]

              precision    recall  f1-score   support

          AD       0.00      0.00      0.00         0
           _       0.72      0.83      0.77       105

   micro avg       0.69      0.83      0.75       105
   macro avg       0.36      0.41      0.39       105
weighted avg       0.72      0.83      0.77       105


loss : 8.587155102479045e-06
0.7229862475442043

loss : 5.335020009655144e-06
0.7906976744186046


 25%|███████████████████▎                                                         | 5/20 [1:39:01<4:56:57, 1187.83s/it]

              precision    recall  f1-score   support

          AD       0.00      0.00      0.00         0
           _       0.78      0.81      0.79       105

   micro avg       0.77      0.81      0.79       105
   macro avg       0.39      0.40      0.40       105
weighted avg       0.78      0.81      0.79       105


loss : 3.141350787784917e-06
0.886128364389234



In [31]:
np.round(np.array(f_slist) ,2)

array([0.  , 0.51, 0.47, 0.44, 0.6 , 0.58, 0.06, 0.55, 0.52, 0.61, 0.58,
       0.66, 0.62, 0.59, 0.21, 0.59, 0.55, 0.54, 0.55, 0.54])

In [23]:
set([p for pred in pred_tags for p in pred])

{'B-', 'I-', 'O', 'PAD'}

In [32]:
import time
start_time = time.time()
maxe = 0
for i in range(1000):
    x = torch.rand(12800,2)*random.randint(1,10)
    x = Variable(x.cuda())
    l = torch.rand(12800).ge(0.1).long()
    l = Variable(l.cuda())
    print()
    output0 = FocalLoss(gamma=0)(x,l)
    output1 = nn.CrossEntropyLoss()(x,l)
    a = output0.data
    b = output1.data
    if abs(a-b)>maxe: maxe = abs(a-b)
print('time:',time.time()-start_time,'max_error:',maxe)

time: 0.7968442440032959 max_error: tensor(4.7684e-07, device='cuda:0')
torch.Size([128, 1000, 8, 4]) torch.Size([128, 8, 4])
time: 0.023962020874023438 max_error: tensor(9.5367e-07, device='cuda:0')
