In [1]:
device = "cpu"

In [2]:
from datasets import  Tokenized_Dataset

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
dataset = Tokenized_Dataset(json_file='negacio_uab_revised_version.json', tokenizer_name='bert-base-multilingual-cased')

do train_test split:

In [93]:
from torch.utils.data import random_split
train_ratio = 0.7
train_size = int(train_ratio * len(dataset))
test_size = len(dataset) - train_size
train_data, val_data = random_split(dataset, [train_size, test_size])

In [94]:
from torch.utils.data import DataLoader
batch_size = 1
train_loader = DataLoader(train_data,batch_size)
val_loader = DataLoader(val_data,batch_size)

Import pre-trained bert multillingual

In [6]:
from transformers import BertModel
bert = BertModel.from_pretrained('bert-base-multilingual-cased')

Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [7]:
import torch

Test a sample:

In [8]:
sample = next(iter(train_loader))
sample

{'x': tensor([[ 12479,  37241,  65921,  10112,  10104,  11782,  12563,    182,  10237,
           10350,    115,    115,    115,    115,    115,    115,    115,    115,
             113,    115,    115,    115,    115,    115,    115,    115,    115,
             114,  12089,    118,    190,    118,    182,  71843,    123,    120,
             125,  10406,  10410,  10162,  12642,    119,  11052,  15603,  86153,
           12458,    119,  10907,  10196,  24154,  11129,  63256,  10415,    177,
           48602,  12429,  48832,  37253,  11130,  10104,  10109,  41807,  13584,
             119,  48602,  10104,  10109,  10104,  58771,  13584,  63256,  10415,
           12074,    119,  11417,  71560,  11669, 100527,    172,    112,  83360,
           11231,  55391,  60304,    113,    123,    114,  12074,    119,  12428,
           71560,  11669, 100527,    172,    112,  94614,    177,  62893,  30698,
          101493,  12926,  10138, 107126,    177,  11639,  76454,  61804,  15880,
           

In [9]:
predictions = bert(sample["x"])
predictions = predictions["last_hidden_state"]

In [10]:
predictions.shape

torch.Size([1, 512, 768])

In [11]:
tags = sample["y"]
tags.shape

torch.Size([1, 512])

In [12]:
tags.shape

torch.Size([1, 512])

We'll define a tagger model that has a linear layer that helps project the last hidden state into the vocab we have. We'll further have a dropout for regularisation.

In [13]:
import torch.nn as nn

class BERT_Tagger(nn.Module):
    def __init__(self,
                 bert,
                 output_dim, 
                 dropout):
        
        super().__init__()
        
        self.bert = bert
        
        embedding_dim = bert.config.to_dict()['hidden_size']
        
        self.fc = nn.Linear(embedding_dim,output_dim)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, tokens):
        
        bert_out = self.bert(tokens)["last_hidden_state"]
        
        predictions = self.fc(bert_out)
        
        return predictions

In [14]:
bert_tagger = BERT_Tagger(bert,len(dataset.uniq_tags),0.2)

In [15]:
out = bert_tagger(sample["x"])
out.shape

torch.Size([1, 512, 5])

In [16]:
def categorical_accuracy(preds, y, tag_pad_idx):
    """
    Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8
    """
    max_preds = preds.argmax(dim = 1, keepdim = True) # get the index of the max probability
    non_pad_elements = (y != tag_pad_idx).nonzero()
    correct = max_preds[non_pad_elements].squeeze(1).eq(y[non_pad_elements])
    return correct.sum() / torch.FloatTensor([y[non_pad_elements].shape[0]]).to(device)

In [17]:
def train(model, epochs, dataloader, optimizer, criterion, tag_pad_idx):    
    model.train()

    for i in range(epochs):
        for j, batch in enumerate(dataloader):
            tokens = batch["x"].to(device)
            tags = batch["y"].to(device)
            #look if all tags in the batch are none, if so skip
            if torch.equal(tags, torch.tensor([[dataset.uniq_tags.index("NONE")]*tags.shape[1]])) :
                continue #skip batch
                     
            optimizer.zero_grad()
            
            #text = [sent len, batch size]
            
            predictions = model(tokens)
            predictions = predictions.view(-1, predictions.shape[-1]) #merge sent len and batch dimensions

            tags = tags.view(-1)
            #predictions  = [sent len * batch size, output dim]
            #tags = [sent len * batch size]
            
            loss = criterion(predictions, tags)
                    
            acc = categorical_accuracy(predictions, tags, tag_pad_idx)
            
            loss.backward()
            
            optimizer.step()

            print(f"epoch:{i} batch:{j} loss:{loss.item()} acc:{acc.item()}")

In [18]:
#make the criterion not count loss on "NONE" tag
criterion = nn.CrossEntropyLoss(ignore_index = dataset.uniq_tags.index("NONE"))

In [19]:
import torch.optim as optim
LEARNING_RATE = 5e-5 #as recomended in BERT paper
optimizer = optim.Adam(bert_tagger.parameters(), lr = LEARNING_RATE)

In [20]:
bert_tagger = bert_tagger.to(device)

In [22]:
tag_pad_idx = dataset.uniq_tags.index("NONE")

In [23]:
train(bert_tagger,100,train_loader,optimizer,criterion,tag_pad_idx)

epoch:0 batch:3 loss:0.024001633748412132 acc:1.0
epoch:0 batch:4 loss:0.023329535499215126 acc:1.0


KeyboardInterrupt: 

Test the model qualititavely:

In [24]:
from colorama import Fore, Back, Style

In [118]:
def eval(model,dataloader,n_batches,tag_pad_idx):
    loader = iter(dataloader)

    colors = [Fore.BLUE,Fore.RED,Fore.GREEN,Fore.CYAN,Fore.WHITE]
    reference_tag_txt = [colors[i]+dataset.uniq_tags[i] for i in range(len(dataset.uniq_tags))]
    reference_tag_txt = " ".join(reference_tag_txt)
    
    txt = ""
    txt_pred = ""
    for i in range(n_batches):
        batch = next(loader)
        tokens = batch["x"].to(device)
        tags = batch["y"].to(device)
        tokens_txt = batch["x_ref"]
        predictions = model(tokens)
        predictions = predictions.view(-1, predictions.shape[-1]) #merge sent len and batch dimensions
        predictions = torch.argmax(predictions,axis=1).numpy()
        print(predictions.shape)
        tags = tags.view(-1)
        
        for tok,tag,pred in zip(tokens_txt,tags,predictions):
            if tok[0][0] == "#":
                txt += colors[tag]+str(tok[0]).replace("#","")
                txt_pred += colors[pred]+str(tok[0]).replace("#","")

            else:
                txt += " " + colors[tag]+str(tok[0]).replace("#","")
                txt_pred += " " + colors[pred]+str(tok[0]).replace("#","")

        print(reference_tag_txt)
        print(Fore.WHITE+"-------------------True-------------------")
        print(txt)
        print(Fore.WHITE+"-------------------True-------------------")
        print(Fore.WHITE+"-------------------Pred-------------------")
        print(txt_pred)
        print(Fore.WHITE+"-------------------Pred-------------------")
        

In [119]:
eval(bert_tagger,train_loader,20,tag_pad_idx)

(512,)
[34mUNC [31mNEG [32mNSCO [36mUSCO [37mNONE
[37m-------------------True-------------------
 [37mesti[37mmula[37mcion [37mo[37mx[37mcito[37mcini[37mca [37m. [37mse [37mins[37mtau[37mra [37mane[37mstes[37mia [37mper[37midu[37mral [37m. [37mse [37mrealiza [37mprofil[37max[37mis [37manti[37mbio[37mtica [37mcon [37mpen[37mile[37mvel [37m5[37mm [37mu[37mi [37mdurante [37mam[37mnio[37mrre[37mxis [37m+ [37m2 [37m, [37m5 [37mm [37mu[37mi [37m/ [37m4[37mh [37mhasta [37mel [37mex[37mpuls[37mivo [37mpor [37msg[37mb [37mdes[37mcono[37mcido [37men [37mrpm [37mpret[37merm[37mino [37m. [37mpro[37mgres[37mion [37mad[37mecu[37mada [37mde [37mla [37mdil[37mata[37mcion [37mhasta [37mllegar [37ma [37mdil[37mata[37mcion [37mcompleta [37m. [37mel [37mdia [37m22 [37m/ [37m10 [37m/ [37m18 [37ma [37mlas [37m14 [37m: [37m15 [37mhoras [37m, [37mse [37masi[37mste [37mparte [37meut[37moci[37mco [

KeyboardInterrupt: 