In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
PATH_TO_PROJECT = '/content/drive/My Drive/Serious/'
# path to conll class as well as to conll data
PATH_TO_CONLL = PATH_TO_PROJECT + 'coNLL/'
PATH_TO_TAG2IDX = PATH_TO_CONLL + 'tag2idx.json'
PATH_TO_ONE_TAG2IDX = PATH_TO_CONLL + 'one_tag2idx.json'
PATH_TO_CHECKPOINT = '/content/drive/My Drive/models/'

### Intall requirements

In [5]:
!pip install -r '/content/drive/My Drive/Serious/requirements.txt'

Collecting transformers~=4.3
[?25l  Downloading https://files.pythonhosted.org/packages/f9/54/5ca07ec9569d2f232f3166de5457b63943882f7950ddfcc887732fc7fb23/transformers-4.3.3-py3-none-any.whl (1.9MB)
[K     |████████████████████████████████| 1.9MB 16.9MB/s 
[?25hCollecting allennlp~=2.0
[?25l  Downloading https://files.pythonhosted.org/packages/e7/bd/c75fa01e3deb9322b637fe0be45164b40d43747661aca9195b5fb334947c/allennlp-2.1.0-py3-none-any.whl (585kB)
[K     |████████████████████████████████| 593kB 49.7MB/s 
[?25hCollecting seqeval~=1.2
[?25l  Downloading https://files.pythonhosted.org/packages/9d/2d/233c79d5b4e5ab1dbf111242299153f3caddddbb691219f363ad55ce783d/seqeval-1.2.2.tar.gz (43kB)
[K     |████████████████████████████████| 51kB 8.8MB/s 
[?25hCollecting pytorch-crf~=0.7
  Downloading https://files.pythonhosted.org/packages/96/7d/4c4688e26ea015fc118a0327e5726e6596836abce9182d3738be8ec2e32a/pytorch_crf-0.7.2-py3-none-any.whl
Collecting sacremoses
[?25l  Downloading https://fi

### Loading coNLL

In [6]:
import sys
sys.path.append(PATH_TO_PROJECT)
sys.path.append(PATH_TO_CONLL)

from importlib import reload
import conll as co

In [7]:
conll = co.CoNLL(PATH_TO_CONLL)

In [62]:
# splitting raw data to sentences and labels
for typ in conll.types:
    conll.split_text_label(typ)

# define set of all labels
conll.create_set_of_labels()

for typ in conll.types:
    # for multiple heads of CRF layer
    conll.create_one_labeled_data(typ)

    # creating one_tag2idx dictionary
    conll.create_one_tag2idx(PATH_TO_ONE_TAG2IDX)
    conll.create_idx2one_tag()

In [63]:
# dict of tag2idx mapping for each CRF-head (one head responsible for 'LOC' etc.)
conll.one_tag2idx

{'LOC': {'B-LOC': 0, 'I-LOC': 3, 'O': 2, 'PAD': 1},
 'MISC': {'B-MISC': 0, 'I-MISC': 3, 'O': 2, 'PAD': 1},
 'ORG': {'B-ORG': 2, 'I-ORG': 3, 'O': 1, 'PAD': 0},
 'PER': {'B-PER': 2, 'I-PER': 0, 'O': 3, 'PAD': 1}}

In [64]:
print(f"sen example: {conll.sentences['train'][0]}")
print(f"tags example: {conll.labels['train'][0]}")
print(f"tags example with only 'ORG' tag: {conll.one_tag_dict['train']['ORG'][0]}")
print(f"tags for CRF tags has labels: {conll.one_tag_dict['train'].keys()}")

sen example: ['EU', 'rejects', 'German', 'call', 'to', 'boycott', 'British', 'lamb', '.']
tags example: ['B-ORG', 'O', 'B-MISC', 'O', 'O', 'O', 'B-MISC', 'O', 'O']
tags example with only 'ORG' tag: ['B-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
tags for CRF tags has labels: dict_keys(['LOC', 'PER', 'ORG', 'MISC'])


### Importing packages

In [8]:
import numpy as np
import torch
from torch import nn
from torch.optim import AdamW
from transformers import BertTokenizer, BertModel
from transformers import BertForTokenClassification
from allennlp.modules.elmo import Elmo, batch_to_ids

from torchcrf import CRF

from sklearn.model_selection import KFold, ParameterGrid

from transformers import get_linear_schedule_with_warmup

import matplotlib
from matplotlib import pyplot as plt

%matplotlib inline

### Creating dataloaders

In [152]:
import data_loaders as dalo

In [153]:
reload(dalo)

<module 'data_loaders' from '/content/drive/My Drive/Serious/data_loaders.py'>

In [13]:
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=213450.0, style=ProgressStyle(descripti…




In [91]:
TAG_NAMES = ['ORG', 'LOC','PER']
NUM_OF_HEADS = len(TAG_NAMES)

In [92]:
# in the second argument we pass list of tag names for every head of the model
train_dataset, train_sampler, train_dataloader = dalo.create_dataloader(conll, TAG_NAMES, bert_tokenizer)

# sanity check for output sizes
assert train_dataset[0][0].shape[0] == train_dataset[0][1].shape[0]
if NUM_OF_HEADS > 1:
  assert train_dataset[0][2].shape[0] == NUM_OF_HEADS
  assert train_dataset[0][2].shape[1] == train_dataset[0][0].shape[0]
else:
  assert len(train_dataset[0][2].shape) == NUM_OF_HEADS # == 1
  assert train_dataset[0][2].shape[0] == train_dataset[0][0].shape[0]
assert train_dataset[0][3].shape[0] == train_dataset[0][0].shape[0]

print(f"bert sentence shape: {train_dataset[0][0].shape}")
print(f"elmo sentence shape: {train_dataset[0][1].shape}")
print(f"number of heads: {train_dataset[0][2].shape[0] if NUM_OF_HEADS > 1 else 1}")
print(f"tokens len: {train_dataset[0][2].shape[1] if NUM_OF_HEADS > 1 else train_dataset[0][2].shape[0]}")
print(f"mask shape: {train_dataset[0][3].shape}")

bert sentence shape: torch.Size([173])
elmo sentence shape: torch.Size([173, 50])
number of heads: 3
tokens len: 173
mask shape: torch.Size([173])


In [99]:
valid_dataset, valid_sampler, valid_dataloader = dalo.create_dataloader(conll, TAG_NAMES, bert_tokenizer,
                                                                        'valid', desired_pad=train_dataset[0][0].shape[0])

# sanity check for output sizes
assert valid_dataset[0][0].shape[0] == train_dataset[0][1].shape[0]
if NUM_OF_HEADS > 1:
  assert valid_dataset[0][2].shape[0] == NUM_OF_HEADS
  assert valid_dataset[0][2].shape[1] == train_dataset[0][0].shape[0]
else:
  assert len(valid_dataset[0][2].shape) == NUM_OF_HEADS # == 1
  assert valid_dataset[0][2].shape[0] == train_dataset[0][0].shape[0]
assert valid_dataset[0][3].shape[0] == train_dataset[0][0].shape[0]

print(f"bert sentence shape: {valid_dataset[0][0].shape}")
print(f"elmo sentence shape: {valid_dataset[0][1].shape}")

bert sentence shape: torch.Size([173])
elmo sentence shape: torch.Size([173, 50])


### Creating model

In [16]:
from bert_config import *
from elmo_config import *

In [17]:
class BEbiC(nn.Module):
    """
    BERT+Elmo+biLSTM+CRFs
    """
    def __init__(self, hidden_size=128, num_labels=4, tag_names=TAG_NAMES,
                 elmo_layers=2, bert_layers=1, concat_bert=True,
                 bilstm_layers=1, bilstm_dropout=0):
        """
        Creates model
        
        Parameters
        ----------
        hidden_size: int, default=128
          LSTM parameter
        num_labels: int, defualt=4
          The number of each CRF labels (ex: B-LABEL, I-LABEL, O, PAD for multiple heads case)
        tag_names: list of str
          List of tag names for models heads
        elmo_layers: int, default=2
          Num of ELMo layers to be considered
        bert_layers: int, default=1
          Num of final BERT hidden layers to be used as embedding vector.
        concat_bert: bool, default=True
          Whether to concat (True) or sum (False) last BERT hidden layers.
        bilstm_layers: int, default=1

        """

        super(BEbiC, self).__init__()

        self.hidden_size = hidden_size
        self.num_labels = num_labels
        self.tag_names = tag_names
        self.num_heads = len(self.tag_names)
        self.elmo_layers = elmo_layers
        self.bert_layers = bert_layers
        self.concat_bert = concat_bert
        self.bilstm_layers = bilstm_layers
        self.bilstm_dropout = bilstm_dropout
        
        self.bert = BertForTokenClassification.from_pretrained(
                        BERT_MODEL,
                        output_hidden_states=True)
        
        for pars in self.bert.parameters():
            pars.requires_grad = False
        
        bert_embedding_dim = self.bert.config.to_dict()['hidden_size']

        self.elmo = Elmo(options_file, weight_file, self.elmo_layers, dropout=0, requires_grad=False)
        
        elmo_embedding_dim = 512 # it's always fixed

        if self.concat_bert:
          self.linear1 = nn.Linear(bert_embedding_dim*self.bert_layers+elmo_embedding_dim*self.elmo_layers, 1024)
        else:
          self.linear1 = nn.Linear(bert_embedding_dim+elmo_embedding_dim*self.elmo_layers, 1024)
        
        self.bilstm = nn.LSTM(1024, self.hidden_size, self.bilstm_layers, 
                              bidirectional=True, dropout=self.bilstm_dropout)

        self.heads = {}
        for i, tag in enumerate(self.tag_names):
            lin_crf = nn.ModuleDict({'linear': nn.Linear(self.hidden_size*2, self.num_labels),
                                     'crf': CRF(num_tags=self.num_labels, batch_first=True)})
            self.heads[tag] = lin_crf

        self.heads = nn.ModuleDict(self.heads)
        self.active_heads = {head: True for head in self.heads.keys()}

    def add_head(self, tag_name):
        """
        Adds new head to the model

        """
        self.tag_names.append(tag_name)
        self.num_heads += 1
        lin_crf = nn.ModuleDict({'linear': nn.Linear(self.hidden_size*2, self.num_labels),
                                     'crf': CRF(num_tags=self.num_labels, batch_first=True)})
        self.heads.update({tag_name: lin_crf})


    def shared_forward(self, bert_ids, elmo_ids, attention_mask):
        """
        Forward propogate of model shared layers.
        
        Parameters
        ----------
        bert_ids:
        elmo_ids:
        attention_mask:
        
        Returns
        -------
        Bilstm logits with shape (seq_len, batch, 2*self.hidden_size)
        
        """

        mask = attention_mask.byte()
        bert_hiddens = self.bert(bert_ids, attention_mask=mask)[1]
        elmo_hiddens = self.elmo(elmo_ids)

        if self.concat_bert:
            bert_embedding = torch.cat(bert_hiddens[-self.bert_layers:], dim=2)
        else:
            emb_sum = 0
            for h in bert_hiddens[-self.bert_layers:]:
                emb_sum += h
            bert_embedding = emb_sum

        elmo_bert_embeddings = torch.clone(bert_embedding)
        for el_hi in elmo_hiddens['elmo_representations']:
            elmo_bert_embeddings = torch.cat((elmo_bert_embeddings, el_hi), dim=-1)

        linear1_output = nn.functional.relu(self.linear1(elmo_bert_embeddings))

        bilstm_output, (h_n, c_n) = self.bilstm(linear1_output)

        return bilstm_output
    
    def get_one_head_loss(self, bilstm_logits, head_labels, attention_mask, head_tag):
        """
        Returns negative log-likelihood for one head.
        You should run it after shared forward.

        Parameters
        ----------
        bilstm_logits:
        head_labels:
        attention_mask:
        head_tag: str
          Key of self.heads dictionary.
        
        Returns
        -------
        Loss

        """
        lin_out = nn.functional.relu(self.heads[head_tag]['linear'](bilstm_logits))
        loss = -1*self.heads[head_tag]['crf'].forward(lin_out, head_labels, mask=attention_mask.byte())
        return loss
    
    def get_one_head_seq(self, bilstm_logits, attention_mask, head_tag):
        """
        Returns the most likely sequence of labels for the given head.
        You should run it after shared forward.

        Parameters
        ----------
        bilstm_logits:
        attention_mask:
        head_tag: str
          Key of self.heads dictionary.
        
        Returns
        -------
        List
        """
      
        lin_out = nn.functional.relu(self.heads[head_tag]['linear'](bilstm_logits))
        seq = self.heads[head_tag]['crf'].decode(lin_out, mask=attention_mask.byte())
        return seq
    
    def forward(self, bert_ids, elmo_ids, head_labels, attention_mask):
        """
        Forward model pass.
        
        Parameters
        ----------
        bert_ids:
        elmo_ids:
        head_labels:
        attention_mask:
        
        Returns
        -------
        Total loss for all heads.

        """

        mask = attention_mask.byte()
        bilstm_logits = self.shared_forward(bert_ids, elmo_ids, mask)
        head_loss = 0
        for i, tag in enumerate(self.heads.keys()):
          _one_head_labels = head_labels[:,i,:] if len(self.heads.keys()) > 1 else head_labels
          head_loss += self.get_one_head_loss(bilstm_logits, _one_head_labels, mask, tag)
        return head_loss
    
    def freeze_head(self, head_tag):
        """
        Freezes model's head parameters.

        """
        if head_tag not in self.heads.keys():
            raise ValueError(f"Unknown head tag. Please, give one of {self.heads.keys()}")
        
        for parameter in self.heads[head_tag].parameters():
            parameter.requires_grad = False
        
        self.active_heads[head_tag] = False
    
    def unfreeze_head(self, head_tag):
        """
        Unfreezes model's head parameters.

        """
        if head_tag not in self.heads.keys():
            raise ValueError(f"Unknown head tag. Please, give one of {self.heads.keys()}")
        
        for parameter in self.heads[head_tag].parameters():
            parameter.requires_grad = True
        
        self.active_heads[head_tag] = True

In [18]:
model = BEbiC(hidden_size=512, bert_layers=2, bilstm_layers=2, bilstm_dropout=0.3)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=433.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=435779157.0, style=ProgressStyle(descri…




Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForTokenClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-cas

In [141]:
import model_utils as mu

In [142]:
reload(mu)

<module 'model_utils' from '/content/drive/My Drive/Serious/model_utils.py'>

In [20]:
N_EPOCHS = 10
total_steps = len(train_dataloader) *  N_EPOCHS

In [12]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()

model.to(device)

Two head experiment

In [22]:
model.tag_names

['ORG', 'LOC']

In [23]:
optimizer = AdamW(params=model.parameters(),lr=5e-4)

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps
)

if device.type != 'cpu':
    model.to(device)

loss_value, head_results = mu.train(model, train_dataloader, optimizer, device, conll, scheduler, n_epoch=N_EPOCHS,
                                valid_dataloader=valid_dataloader, path_to_save=PATH_TO_CHECKPOINT)

  0%|          | 0/110 [00:00<?, ?it/s]


Epoch #0


  9%|▉         | 10/110 [00:34<05:45,  3.45s/it]


9: avg loss per batch: 1634.3530409071182



 18%|█▊        | 20/110 [01:10<05:32,  3.70s/it]


19: avg loss per batch: 1204.4291156969573



 27%|██▋       | 30/110 [01:45<04:40,  3.51s/it]


29: avg loss per batch: 1041.4398277545797



 36%|███▋      | 40/110 [02:21<04:09,  3.57s/it]


39: avg loss per batch: 935.4730240259415



 45%|████▌     | 50/110 [02:57<03:33,  3.55s/it]


49: avg loss per batch: 849.6877516143176



 55%|█████▍    | 60/110 [03:32<02:58,  3.57s/it]


59: avg loss per batch: 789.2377645201602



 64%|██████▎   | 70/110 [04:07<02:21,  3.54s/it]


69: avg loss per batch: 740.8934733072916



 73%|███████▎  | 80/110 [04:43<01:47,  3.58s/it]


79: avg loss per batch: 704.5731784482546



 82%|████████▏ | 90/110 [05:18<01:10,  3.55s/it]


89: avg loss per batch: 672.6319806388254



 91%|█████████ | 100/110 [05:54<00:35,  3.57s/it]


99: avg loss per batch: 649.4647287696299



100%|██████████| 110/110 [06:28<00:00,  3.54s/it]


109: avg loss per batch: 628.2318778781716

Average train loss: 622.52067898837



  0%|          | 0/110 [00:00<?, ?it/s]

Mean validation loss: 413.95143921563886
Mean validation accuracy: 0.47381729529335914
Mean validation F1-score: 0.0


Epoch #1


  9%|▉         | 10/110 [00:35<05:54,  3.55s/it]


9: avg loss per batch: 466.7183261447483



 18%|█▊        | 20/110 [01:10<05:16,  3.52s/it]


19: avg loss per batch: 410.6698785079153



 27%|██▋       | 30/110 [01:46<04:49,  3.62s/it]


29: avg loss per batch: 387.460535509833



 36%|███▋      | 40/110 [02:21<04:06,  3.52s/it]


39: avg loss per batch: 372.5221463716947



 45%|████▌     | 50/110 [02:56<03:32,  3.54s/it]


49: avg loss per batch: 362.8121973154496



 55%|█████▍    | 60/110 [03:32<02:57,  3.54s/it]


59: avg loss per batch: 358.7145008151814



 64%|██████▎   | 70/110 [04:07<02:21,  3.53s/it]


69: avg loss per batch: 350.451318713202



 73%|███████▎  | 80/110 [04:43<01:46,  3.55s/it]


79: avg loss per batch: 342.9233520121514



 82%|████████▏ | 90/110 [05:18<01:11,  3.56s/it]


89: avg loss per batch: 340.8356352388189



 91%|█████████ | 100/110 [05:53<00:35,  3.51s/it]


99: avg loss per batch: 335.5155730584655



100%|██████████| 110/110 [06:28<00:00,  3.53s/it]


109: avg loss per batch: 333.22541865077585

Average train loss: 330.19609666304154



  0%|          | 0/110 [00:00<?, ?it/s]

Mean validation loss: 294.5017654682575
Mean validation accuracy: 0.4827732108317215
Mean validation F1-score: 0.20879231123228403


Epoch #2


  9%|▉         | 10/110 [00:35<05:50,  3.50s/it]


9: avg loss per batch: 312.1840074327257



 18%|█▊        | 20/110 [01:10<05:18,  3.54s/it]


19: avg loss per batch: 289.8907189620169



 27%|██▋       | 30/110 [01:45<04:43,  3.55s/it]


29: avg loss per batch: 290.9315069790544



 36%|███▋      | 40/110 [02:20<04:06,  3.52s/it]


39: avg loss per batch: 290.6043697259365



 45%|████▌     | 50/110 [02:56<03:34,  3.58s/it]


49: avg loss per batch: 287.6250753597337



 55%|█████▍    | 60/110 [03:32<03:00,  3.60s/it]


59: avg loss per batch: 292.4013759807005



 64%|██████▎   | 70/110 [04:07<02:21,  3.53s/it]


69: avg loss per batch: 291.3028131015059



 73%|███████▎  | 80/110 [04:42<01:44,  3.50s/it]


79: avg loss per batch: 291.3895501245426



 82%|████████▏ | 90/110 [05:17<01:10,  3.50s/it]


89: avg loss per batch: 290.59957577137465



 91%|█████████ | 100/110 [05:53<00:35,  3.57s/it]


99: avg loss per batch: 289.25368646178583



100%|██████████| 110/110 [06:28<00:00,  3.53s/it]


109: avg loss per batch: 287.5355680973158

Average train loss: 284.9216083873402



  0%|          | 0/110 [00:00<?, ?it/s]

Mean validation loss: 272.6783558393105
Mean validation accuracy: 0.4887874758220503
Mean validation F1-score: 0.32920291282735203


Epoch #3


  9%|▉         | 10/110 [00:35<05:53,  3.53s/it]


9: avg loss per batch: 287.0365159776476



 18%|█▊        | 20/110 [01:10<05:15,  3.51s/it]


19: avg loss per batch: 282.3577479312294



 27%|██▋       | 30/110 [01:45<04:41,  3.52s/it]


29: avg loss per batch: 277.12901621851427



 36%|███▋      | 40/110 [02:21<04:06,  3.52s/it]


39: avg loss per batch: 276.5733173076923



 45%|████▌     | 50/110 [02:56<03:31,  3.52s/it]


49: avg loss per batch: 273.8538712482063



 55%|█████▍    | 60/110 [03:31<02:56,  3.53s/it]


59: avg loss per batch: 273.06025411314886



 64%|██████▎   | 70/110 [04:07<02:23,  3.59s/it]


69: avg loss per batch: 270.09000076072806



 73%|███████▎  | 80/110 [04:43<01:47,  3.58s/it]


79: avg loss per batch: 269.30368351030955



 82%|████████▏ | 90/110 [05:18<01:11,  3.58s/it]


89: avg loss per batch: 268.5756890800562



 91%|█████████ | 100/110 [05:53<00:35,  3.50s/it]


99: avg loss per batch: 269.4731066154711



100%|██████████| 110/110 [06:28<00:00,  3.53s/it]


109: avg loss per batch: 268.4528355029745

Average train loss: 266.0123551802202



  0%|          | 0/110 [00:00<?, ?it/s]

Mean validation loss: 265.7896843229608
Mean validation accuracy: 0.4893986406619385
Mean validation F1-score: 0.337284429891815


Epoch #4


  9%|▉         | 10/110 [00:35<05:55,  3.55s/it]


9: avg loss per batch: 273.5650380452474



 18%|█▊        | 20/110 [01:10<05:17,  3.52s/it]


19: avg loss per batch: 269.816331562243



 27%|██▋       | 30/110 [01:45<04:40,  3.50s/it]


29: avg loss per batch: 268.01663997255525



 36%|███▋      | 40/110 [02:21<04:11,  3.59s/it]


39: avg loss per batch: 263.75291716746796



 45%|████▌     | 50/110 [02:57<03:31,  3.52s/it]


49: avg loss per batch: 263.6200405821508



 55%|█████▍    | 60/110 [03:32<02:58,  3.57s/it]


59: avg loss per batch: 263.5149949930482



 64%|██████▎   | 70/110 [04:08<02:21,  3.53s/it]


69: avg loss per batch: 260.05612890271175



 73%|███████▎  | 80/110 [04:43<01:46,  3.55s/it]


79: avg loss per batch: 258.660400390625



 82%|████████▏ | 90/110 [05:19<01:09,  3.48s/it]


89: avg loss per batch: 256.8822350662746



 91%|█████████ | 100/110 [05:53<00:34,  3.47s/it]


99: avg loss per batch: 255.07086012098523



100%|██████████| 110/110 [06:28<00:00,  3.53s/it]


109: avg loss per batch: 254.65500521878585

Average train loss: 252.33995971679687



  0%|          | 0/110 [00:00<?, ?it/s]

Mean validation loss: 253.3906175299644
Mean validation accuracy: 0.4901273372018053
Mean validation F1-score: 0.3523579569281063


Epoch #5


  9%|▉         | 10/110 [00:35<05:55,  3.55s/it]


9: avg loss per batch: 269.09601338704425



 18%|█▊        | 20/110 [01:10<05:21,  3.57s/it]


19: avg loss per batch: 254.5497364244963



 27%|██▋       | 30/110 [01:46<04:44,  3.56s/it]


29: avg loss per batch: 248.27895539382408



 36%|███▋      | 40/110 [02:21<04:07,  3.53s/it]


39: avg loss per batch: 247.20717210036057



 45%|████▌     | 50/110 [02:57<03:35,  3.60s/it]


49: avg loss per batch: 246.05020079320792



 55%|█████▍    | 60/110 [03:32<02:54,  3.50s/it]


59: avg loss per batch: 243.84449819791115



 64%|██████▎   | 70/110 [04:08<02:22,  3.57s/it]


69: avg loss per batch: 243.42119388303894



 73%|███████▎  | 80/110 [04:44<01:47,  3.59s/it]


79: avg loss per batch: 242.11924241464348



 82%|████████▏ | 90/110 [05:20<01:11,  3.58s/it]


89: avg loss per batch: 242.33003011982092



 91%|█████████ | 100/110 [05:55<00:35,  3.55s/it]


99: avg loss per batch: 243.1880303585168



100%|██████████| 110/110 [06:29<00:00,  3.54s/it]


109: avg loss per batch: 240.68852835838948

Average train loss: 238.50045082785866



  0%|          | 0/110 [00:00<?, ?it/s]

Mean validation loss: 241.51312492094877
Mean validation accuracy: 0.490983639587363
Mean validation F1-score: 0.3605975543119111


Epoch #6


  9%|▉         | 10/110 [00:35<05:52,  3.53s/it]


9: avg loss per batch: 273.61539035373266



 18%|█▊        | 20/110 [01:10<05:17,  3.53s/it]


19: avg loss per batch: 243.31902915553042



 27%|██▋       | 30/110 [01:45<04:42,  3.54s/it]


29: avg loss per batch: 236.80270754057784



 36%|███▋      | 40/110 [02:21<04:08,  3.55s/it]


39: avg loss per batch: 234.37194119966946



 45%|████▌     | 50/110 [02:56<03:33,  3.56s/it]


49: avg loss per batch: 236.96360031439335



 55%|█████▍    | 60/110 [03:31<02:55,  3.51s/it]


59: avg loss per batch: 233.7699116852324



 64%|██████▎   | 70/110 [04:07<02:22,  3.57s/it]


69: avg loss per batch: 233.4279577282892



 73%|███████▎  | 80/110 [04:42<01:46,  3.54s/it]


79: avg loss per batch: 232.73816555361205



 82%|████████▏ | 90/110 [05:18<01:10,  3.54s/it]


89: avg loss per batch: 232.3286784311359



 91%|█████████ | 100/110 [05:53<00:35,  3.52s/it]


99: avg loss per batch: 231.6726259173769



100%|██████████| 110/110 [06:27<00:00,  3.53s/it]


109: avg loss per batch: 230.72493932881486

Average train loss: 228.6274398803711



  0%|          | 0/110 [00:00<?, ?it/s]

Mean validation loss: 236.7181797003066
Mean validation accuracy: 0.4910105039759295
Mean validation F1-score: 0.36152443546735236


Epoch #7


  9%|▉         | 10/110 [00:35<05:57,  3.57s/it]


9: avg loss per batch: 240.56321885850696



 18%|█▊        | 20/110 [01:10<05:15,  3.51s/it]


19: avg loss per batch: 221.65395074141654



 27%|██▋       | 30/110 [01:46<04:43,  3.55s/it]


29: avg loss per batch: 221.50297335920663



 36%|███▋      | 40/110 [02:21<04:06,  3.52s/it]


39: avg loss per batch: 220.5404827411358



 45%|████▌     | 50/110 [02:56<03:31,  3.53s/it]


49: avg loss per batch: 219.57207598005022



 55%|█████▍    | 60/110 [03:31<02:55,  3.50s/it]


59: avg loss per batch: 221.08830985376392



 64%|██████▎   | 70/110 [04:07<02:22,  3.55s/it]


69: avg loss per batch: 221.59434487163156



 73%|███████▎  | 80/110 [04:43<01:47,  3.59s/it]


79: avg loss per batch: 222.33280133597458



 82%|████████▏ | 90/110 [05:18<01:10,  3.50s/it]


89: avg loss per batch: 221.06860540154275



 91%|█████████ | 100/110 [05:53<00:35,  3.58s/it]


99: avg loss per batch: 223.16273251928465



100%|██████████| 110/110 [06:28<00:00,  3.53s/it]


109: avg loss per batch: 222.2333687598552

Average train loss: 220.2130654074929



  0%|          | 0/110 [00:00<?, ?it/s]

Mean validation loss: 238.9468812042509
Mean validation accuracy: 0.49060753814743174
Mean validation F1-score: 0.3569450126014265


Epoch #8


  9%|▉         | 10/110 [00:35<05:54,  3.54s/it]


9: avg loss per batch: 232.38173590766058



 18%|█▊        | 20/110 [01:10<05:16,  3.51s/it]


19: avg loss per batch: 225.296192369963



 27%|██▋       | 30/110 [01:45<04:42,  3.53s/it]


29: avg loss per batch: 217.63945849188443



 36%|███▋      | 40/110 [02:21<04:05,  3.51s/it]


39: avg loss per batch: 216.3811508569962



 45%|████▌     | 50/110 [02:56<03:31,  3.52s/it]


49: avg loss per batch: 218.02829571159518



 55%|█████▍    | 60/110 [03:31<02:56,  3.52s/it]


59: avg loss per batch: 216.1800917285984



 64%|██████▎   | 70/110 [04:06<02:20,  3.51s/it]


69: avg loss per batch: 219.28322081634963



 73%|███████▎  | 80/110 [04:42<01:46,  3.55s/it]


79: avg loss per batch: 219.60334121124654



 82%|████████▏ | 90/110 [05:17<01:10,  3.51s/it]


89: avg loss per batch: 217.26043272554205



 91%|█████████ | 100/110 [05:52<00:34,  3.50s/it]


99: avg loss per batch: 217.2481216276535



100%|██████████| 110/110 [06:26<00:00,  3.51s/it]


109: avg loss per batch: 216.4533736202695

Average train loss: 214.48561567826704



  0%|          | 0/110 [00:00<?, ?it/s]

Mean validation loss: 231.5713075835644
Mean validation accuracy: 0.4911078873844831
Mean validation F1-score: 0.36261410122705573


Epoch #9


  9%|▉         | 10/110 [00:35<05:52,  3.53s/it]


9: avg loss per batch: 242.50179545084634



 18%|█▊        | 20/110 [01:10<05:23,  3.59s/it]


19: avg loss per batch: 233.6527244166324



 27%|██▋       | 30/110 [01:46<04:44,  3.56s/it]


29: avg loss per batch: 223.27904957738417



 36%|███▋      | 40/110 [02:22<04:08,  3.55s/it]


39: avg loss per batch: 217.1456498366136



 45%|████▌     | 50/110 [02:57<03:33,  3.57s/it]


49: avg loss per batch: 214.99374856754224



 55%|█████▍    | 60/110 [03:32<02:56,  3.53s/it]


59: avg loss per batch: 213.96652816513839



 64%|██████▎   | 70/110 [04:08<02:21,  3.54s/it]


69: avg loss per batch: 213.96804632656816



 73%|███████▎  | 80/110 [04:43<01:46,  3.54s/it]


79: avg loss per batch: 215.11833075028431



 82%|████████▏ | 90/110 [05:19<01:11,  3.57s/it]


89: avg loss per batch: 213.3746607062522



 91%|█████████ | 100/110 [05:54<00:35,  3.53s/it]


99: avg loss per batch: 214.39108584625552



100%|██████████| 110/110 [06:28<00:00,  3.53s/it]


109: avg loss per batch: 212.63346085854627

Average train loss: 210.70042939619586





Mean validation loss: 229.4394204326439
Mean validation accuracy: 0.4913127283473028
Mean validation F1-score: 0.36505496439767937



In [24]:
mu.eval_model(model, valid_dataloader, device, conll)



({'LOC': {'acc': 0.9869707715452396, 'f1': 0.8170916609235009},
  'ORG': {'acc': 0.9782801418439716, 'f1': 0.6431281966672167}},
 458.9203822408757,
 0.9826254566946055,
 0.7301099287953587)

### Loading pretrained model

In [None]:
bert_tokenizer, model, opt_state = mu.load_checkpoint(PATH_TO_CHECKPOINT+'ElMo_BERT_biLSTM_oneCRF_19_state_dict.pth',
                                                      PATH_TO_CHECKPOINT+'ElMo_BERT_biLSTM_oneCRF_19_tokenizer.pth')

In [30]:
optimizer = AdamW(params=model.parameters(),lr=3e-4)
optimizer.load_state_dict(opt_state)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()

model.to(device)

In [None]:
valid_dataset, valid_sampler, valid_dataloader = dalo.create_dataloader(conll, TAG_NAMES, bert_tokenizer, 'valid')

  0%|          | 0/110 [22:01<?, ?it/s]


In [None]:
head_result, mean_loss, mean_acc, mean_f1 = mu.eval_model(model, valid_dataloader, device, conll)



In [None]:
head_result

{'LOC': {'acc': 0.9873871695680206, 'f1': 0.8200941046221977},
 'ORG': {'acc': 0.9786562432839029, 'f1': 0.6420724708968684},
 'PER': {'acc': 0.9772995916612938, 'f1': 0.3731228340392761}}

#### Continue to train pretrained model

In [None]:
TAG_NAMES = ['ORG', 'LOC', 'PER']
NUM_OF_HEADS = len(TAG_NAMES)

# in the second argument we pass list of tag names for every head of the model
train_dataset, train_sampler, train_dataloader = dalo.create_dataloader(conll, TAG_NAMES, bert_tokenizer)

In [None]:
N_EPOCHS = 5
total_steps = len(train_dataloader) *  N_EPOCHS

In [None]:
optimizer = AdamW(params=model.parameters(),lr=1e-4)

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps
)

if device.type != 'cpu':
    model.to(device)

loss_value, head_results = mu.train(model, train_dataloader, optimizer, device, conll, scheduler, n_epoch=N_EPOCHS,
                                valid_dataloader=valid_dataloader, path_to_save=PATH_TO_CHECKPOINT)

### Load one-head model to compare with mul-head one

In [193]:
conll_old = co.CoNLL_old(PATH_TO_CONLL)
for typ in conll_old.types:
  conll_old.split_text_label(typ)
conll_old.create_tag2idx(PATH_TO_TAG2IDX)
conll_old.create_idx2tag()

In [107]:
import models

In [108]:
reload(models)

<module 'models' from '/content/drive/My Drive/Serious/models.py'>

In [109]:
from models import *

In [110]:
bert_tokenizer, old_model, opt_state = mu.load_checkpoint(PATH_TO_CHECKPOINT+'ElMo_BERT_biLSTM_oneCRF_19_state_dict.pth',
                                                          PATH_TO_CHECKPOINT+'ElMo_BERT_biLSTM_oneCRF_19_tokenizer.pth')



In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()

old_model.to(device)

In [73]:
train_dataset, train_sampler, train_dataloader = dalo.create_dataloader(conll, ['PER'], bert_tokenizer)

In [194]:
valid_dataset, valid_sampler, valid_dataloader = dalo.create_dataloader(conll, ['PER'], bert_tokenizer, 
                                                                        datatype='valid', desired_pad=train_dataset[0][0].shape[0])

PER

We need an old dataloader here

In [195]:
valid_dataset, valid_sampler, valid_dataloader = dalo.create_dataloader_old(conll.sentences['valid'],
                                                                            conll.one_tag_dict['valid']['PER'], conll_old.tag2idx,
                                                                            bert_tokenizer, datatype='valid', desired_pad=173)

In [197]:
mu.eval_old(old_model, valid_dataloader, device, conll_old.idx2tag)



(2879.9386127178486, 0.8668600902643456, 0.4219006007646095)

ORG

In [198]:
valid_dataset, valid_sampler, valid_dataloader = dalo.create_dataloader_old(conll.sentences['valid'],
                                                                            conll.one_tag_dict['valid']['ORG'], conll_old.tag2idx,
                                                                            bert_tokenizer, datatype='valid', desired_pad=173)

In [199]:
mu.eval_old(old_model, valid_dataloader, device, conll_old.idx2tag)



(4071.6029616135816, 0.8300827423167849, 0.3436274160188289)

LOC

In [200]:
valid_dataset, valid_sampler, valid_dataloader = dalo.create_dataloader_old(conll.sentences['valid'],
                                                                            conll.one_tag_dict['valid']['LOC'], conll_old.tag2idx,
                                                                            bert_tokenizer, datatype='valid', desired_pad=173)

In [201]:
mu.eval_old(old_model, valid_dataloader, device, conll_old.idx2tag)



(4093.646503155048, 0.8329706640876854, 0.44121974053764873)

The last values above are f1-scores.

In the case of multiple-head fitting of this 3 heads at the same time we got:

'PER' - 0.3731228340392761 (vs 0.4219006007646095)

'ORG' - 0.6420724708968684 (vs 0.3436274160188289)

'LOC' - 0.8200941046221977 (vs 0.44121974053764873)

So, results of multiple heads look better!