In [1]:
device = "cpu"

In [2]:
from datasets import  Tokenized_Dataset

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
dataset = Tokenized_Dataset(json_file='negacio_uab_revised_version.json', tokenizer_name='bert-base-multilingual-cased')

do train_test split:

In [4]:
import torch
# Define the indices for train and test subsets
train_indices = range(0,len(dataset)-int(len(dataset)*0.3))
test_indices = range(len(dataset)-int(len(dataset)*0.3), len(dataset))

# Create Subset datasets based on the defined indices
train_data = torch.utils.data.Subset(dataset, train_indices)
val_data= torch.utils.data.Subset(dataset, test_indices)

In [5]:
from torch.utils.data import DataLoader
batch_size = 1
train_loader = DataLoader(train_data,batch_size)
val_loader = DataLoader(val_data,batch_size)

Import pre-trained bert multillingual

In [6]:
from transformers import BertModel
bert = BertModel.from_pretrained('bert-base-multilingual-cased')

Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [7]:
import torch

Test a sample:

In [8]:
sample = next(iter(train_loader))
sample

{'x': tensor([[   164,  75151,  10731,    166,  28630,  13418,    171,  27542,  10425,
             131,    115,    115,    115,    115,    115,    115,    115,    115,
           28630,  19986,  48832,  10703,    131,    115,    115,    115,    115,
             115,    115,    115,    115,  18793,    131,  34010,  11165,  10104,
           76206,  90945,    131,  10719,    119,  10719,    119,  11082,  24695,
             131,  12791,  12627,  82176,  13212,  70982,  34128,    120,  39429,
             119,  10380,  10350,  36384,  96825,  28562,  64576,  11165,    172,
             112,  11600,  11234,  10150,    119,  10831,    119,  10434,  11165,
             172,    112,  14855,  10150,    119,  10831,    119,  10434,  10250,
             131,  11528,    131,  11349,  10160,  10171,  10178,    115,    115,
             115,    115,    115,    115,    115,    115,    115,    115,    115,
             115,    115,    115,    115,    115,    117,    115,    115,    115,
           

In [9]:
predictions = bert(sample["x"])
predictions = predictions["last_hidden_state"]

In [10]:
predictions.shape

torch.Size([1, 512, 768])

In [11]:
tags = sample["y"]
tags.shape

torch.Size([1, 512])

In [12]:
tags.shape

torch.Size([1, 512])

We'll define a tagger model that has a linear layer that helps project the last hidden state into the vocab we have. We'll further have a dropout for regularisation.

In [13]:
import torch.nn as nn

class BERT_Tagger(nn.Module):
    def __init__(self,
                 bert,
                 output_dim, 
                 dropout):
        
        super().__init__()
        
        self.bert = bert
        
        embedding_dim = bert.config.to_dict()['hidden_size']
        
        self.fc = nn.Linear(embedding_dim,output_dim)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, tokens):
        
        bert_out = self.bert(tokens)["last_hidden_state"]
        
        predictions = self.fc(bert_out)
        
        return predictions

In [14]:
bert_tagger = BERT_Tagger(bert,len(dataset.uniq_tags),0.2)

In [15]:
out = bert_tagger(sample["x"])
out.shape

torch.Size([1, 512, 5])

In [16]:
def categorical_accuracy(preds, y, tag_pad_idx):
    """
    Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8
    """
    max_preds = preds.argmax(dim = 1, keepdim = True) # get the index of the max probability
    non_pad_elements = (y != tag_pad_idx).nonzero()
    correct = max_preds[non_pad_elements].squeeze(1).eq(y[non_pad_elements])
    return correct.sum() / torch.FloatTensor([y[non_pad_elements].shape[0]]).to(device)

In [17]:
def train(model, epochs, dataloader, optimizer, criterion, tag_pad_idx):    
    model.train()

    for i in range(epochs):
        for j, batch in enumerate(dataloader):
            tokens = batch["x"].to(device)
            tags = batch["y"].to(device)
            #look if all tags in the batch are none, if so skip
            if torch.equal(tags, torch.tensor([[dataset.uniq_tags.index("NONE")]*tags.shape[1]])) :
                continue #skip batch
                     
            optimizer.zero_grad()
            
            #text = [sent len, batch size]
            
            predictions = model(tokens)
            predictions = predictions.view(-1, predictions.shape[-1]) #merge sent len and batch dimensions

            tags = tags.view(-1)
            #predictions  = [sent len * batch size, output dim]
            #tags = [sent len * batch size]
            
            loss = criterion(predictions, tags)
                    
            acc = categorical_accuracy(predictions, tags, tag_pad_idx)
            
            loss.backward()
            
            optimizer.step()

            print(f"epoch:{i} batch:{j} loss:{loss.item()} acc:{acc.item()}")

In [18]:
#make the criterion not count loss on "NONE" tag
#criterion = nn.CrossEntropyLoss(ignore_index = dataset.uniq_tags.index("NONE"))
criterion = nn.CrossEntropyLoss()

In [19]:
import torch.optim as optim
LEARNING_RATE = 5e-5 #as recomended in BERT paper
optimizer = optim.Adam(bert_tagger.parameters(), lr = LEARNING_RATE)

In [20]:
bert_tagger = bert_tagger.to(device)

In [21]:
tag_pad_idx = dataset.uniq_tags.index("NONE")

In [26]:
train(bert_tagger,100,train_loader,optimizer,criterion,10)

epoch:0 batch:0 loss:0.014694410376250744 acc:0.998046875
epoch:0 batch:1 loss:0.01432236097753048 acc:0.998046875
epoch:0 batch:2 loss:0.013745415024459362 acc:0.998046875
epoch:0 batch:3 loss:0.014352340251207352 acc:0.998046875
epoch:0 batch:4 loss:0.014523501507937908 acc:0.998046875
epoch:0 batch:5 loss:0.014265867881476879 acc:0.998046875
epoch:0 batch:6 loss:0.01462233904749155 acc:0.998046875
epoch:0 batch:7 loss:0.014471763744950294 acc:0.998046875
epoch:0 batch:8 loss:0.013861819170415401 acc:0.998046875
epoch:0 batch:9 loss:0.014427647925913334 acc:0.998046875
epoch:0 batch:10 loss:0.014190863817930222 acc:0.998046875
epoch:0 batch:11 loss:0.013955176807940006 acc:0.998046875
epoch:0 batch:12 loss:0.013599306344985962 acc:0.998046875
epoch:0 batch:13 loss:0.013699935749173164 acc:0.998046875
epoch:0 batch:14 loss:0.014510207809507847 acc:0.998046875
epoch:0 batch:15 loss:0.014030101709067822 acc:0.998046875
epoch:0 batch:16 loss:0.014362625777721405 acc:0.998046875
epoch:0 b

KeyboardInterrupt: 

Test the model qualititavely:

In [23]:
from colorama import Fore, Back, Style

In [27]:
def eval(model,dataset,n_batches,tag_pad_idx,uniq_tags):
    colors = [Fore.BLUE,Fore.RED,Fore.GREEN,Fore.CYAN,Fore.WHITE]
    reference_tag_txt = [colors[i]+uniq_tags[i] for i in range(len(uniq_tags))]
    reference_tag_txt = " ".join(reference_tag_txt)
    
    for i in range(0,n_batches*512,512):
        batch = dataset.__getitem__(i)
        tokens = batch["x"].to(device).unsqueeze(0)
        tags = batch["y"].to(device).unsqueeze(0)
        tokens_txt = batch["x_ref"]
        predictions = model(tokens)
        predictions = predictions.view(-1, predictions.shape[-1]) #merge sent len and batch dimensions
        predictions = torch.argmax(predictions,axis=1).numpy()
        tags = tags.view(-1)
        
        txt = ""
        txt_pred = ""
        for tok,tag,pred in zip(tokens_txt,tags,predictions):
            if tok[0] == "#":
                txt += colors[tag]+tok.replace("#","")
                txt_pred += colors[pred]+tok.replace("#","")

            else:
                txt += " " + colors[tag]+tok.replace("#","")
                txt_pred += " " + colors[pred]+tok.replace("#","")

        print(reference_tag_txt)
        print(Fore.WHITE+"-------------------True-------------------")
        print(txt)
        print(Fore.WHITE+"-------------------True-------------------")
        print(Fore.WHITE+"-------------------Pred-------------------")
        print(txt_pred)
        print(Fore.WHITE+"-------------------Pred-------------------")
        

In [28]:
eval(bert_tagger,val_data,10,tag_pad_idx,dataset.uniq_tags)

[34mUNC [31mNEG [32mUSCO [36mNSCO [37mNONE
[37m-------------------True-------------------
[37mia [37m, [37mleve [37mprotein[37muria [37mse [37mori[37menta [37mcomo [37mins[37muf[37miciencia [37mren[37mal [37mag[37muda [37mcon [37mfeu [37m40 [37m% [37m; [37minicia[37mndo [37msue[37mrot[37mera[37mpia [37my [37mbi[37mcar[37mbona[37mto [37mev [37m. [37mad[37mema[37ms [37m, [37mane[37mmia [37mnor[37mmo[37mciti[37mca [37mhip[37moc[37mrom[37mica [37m, [37mse [37mrealiza [37mtrans[37mfus[37mion [37mde [37m1 [37mcc[37mh[37mh [37m, [37mcon [37mbuena [37mrespuesta [37mposterior [37m. [37mante [37msed[37mimento [37mde [37mori[37mna [37mcon [37mpi[37muria [37my [37mleve [37mle[37muco[37mcito[37msis [37m, [37mse [37minicia [37mcobertura [37mem[37mpir[37mica [37mcon [37mci[37mpro[37mf[37mlo[37mxa[37mcino [37mev [37m, [37mcambia[37mndose [37ma [37mamo[37mxic[37milina [37m- [37mc[37mlav[37mul