# 1 - NER PyTorch


**Sources**:
- Kaggle data source: https://www.kaggle.com/datasets/abhinavwalia95/entity-annotated-corpus
- Abhishek's tutorial: https://www.youtube.com/watch?v=MqQ7rqRllIc

**TODO**
- [x] replace the padding value with another one -> not needed after all.
- [x] set num workers (minitrain, valid and test)
- [x] can we increase the batch size?
- [x] stratify the splits
- [x] record the running loss
- [ ] are we sure that special tokens do not contribute to the loss?
- [ ] check that the code <a href="https://www.kaggle.com/code/abhishek/entity-extraction-model-using-bert-pytorch">here</a> is the right one
- [ ] gradual unfreezing
- [ ] early stopping with Lightning
- [ ] wandb
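The TODO item on special tokens can be settled by construction: relabel padding and special-token positions with the `ignore_index` of `nn.CrossEntropyLoss` (-100 by default) so that they are excluded from the mean. The sketch below is hypothetical: we have not checked how `ner_pytorch.engine` or `EntityModel` compute their loss, and `masked_token_loss` and its `special_tokens_mask` input are names introduced here for illustration. It is worth checking, because in the prediction table at the end of this notebook both [CLS] and [SEP] are predicted as class 0, which `label_enc_NER` maps to `B-org`; that would be consistent with special tokens being trained with target 0.

```python
import torch
from torch import nn

IGNORE_INDEX = -100  # default ignore_index of nn.CrossEntropyLoss


def masked_token_loss(logits, targets, attention_mask, special_tokens_mask):
    """Token-level cross-entropy that ignores padding and special tokens (sketch).

    logits: (batch, seq_len, num_tags); the other tensors: (batch, seq_len).
    special_tokens_mask (hypothetical input) is 1 at [CLS]/[SEP] positions.
    """
    keep = (attention_mask == 1) & (special_tokens_mask == 0)
    # Positions that must not count are relabelled with the ignore index
    masked_targets = torch.where(keep, targets, torch.full_like(targets, IGNORE_INDEX))
    loss_fn = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
    return loss_fn(logits.reshape(-1, logits.size(-1)), masked_targets.reshape(-1))


# Toy batch: 1 sentence, positions = [CLS] w1 w2 [SEP] [PAD]
torch.manual_seed(0)
logits = torch.randn(1, 5, 3)
targets = torch.tensor([[0, 2, 1, 0, 0]])
attention_mask = torch.tensor([[1, 1, 1, 1, 0]])
special_tokens_mask = torch.tensor([[1, 0, 0, 1, 0]])
loss = masked_token_loss(logits, targets, attention_mask, special_tokens_mask)
# Same value as the loss computed on the two real words only
expected = nn.functional.cross_entropy(logits[0, 1:3], targets[0, 1:3])
print(torch.allclose(loss, expected))  # True
```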

In [1]:
%load_ext autoreload
%autoreload 2

import os
import time

import joblib
import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from torch import nn
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from transformers import get_linear_schedule_with_warmup

os.chdir('..')

from ner_pytorch.config.params import PARAMS
from ner_pytorch.dataset import EntityDataset
from ner_pytorch.engine import eval_fn, train_fn
from ner_pytorch.model import EntityModel
from ner_pytorch.preprocessing import process_data
from ner_pytorch.utils import *

I0606 09:28:15.558454 140644682426176 modeling.py:230] Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .


## Loading the data

In [2]:
data = pd.read_csv(PARAMS.PATHS.TRAIN, encoding='latin-1').drop('POS', axis=1)
# Only the first word of each sentence carries "Sentence: n"; propagate it to the following words
data["Sentence #"] = data["Sentence #"].ffill()
data.head(7)

|   | Sentence # | Word | Tag |
|---|---|---|---|
| 0 | Sentence: 1 | Thousands | O |
| 1 | Sentence: 1 | of | O |
| 2 | Sentence: 1 | demonstrators | O |
| 3 | Sentence: 1 | have | O |
| 4 | Sentence: 1 | marched | O |
| 5 | Sentence: 1 | through | O |
| 6 | Sentence: 1 | London | B-geo |


We are actually only interested in predicting ORG entities, for which 3 labels are possible (every other tag is mapped to `O`):
- B-org
- I-org
- O
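In the BIO scheme, `B-org` marks the first word of an organisation name and `I-org` each following word. To read entities back from a tag sequence, consecutive B-/I- tags are grouped into spans. `bio_to_spans` is a hypothetical helper, not part of `ner_pytorch`; the example reuses sentence 8 displayed a few cells below. An `I-org` without a preceding `B-org` is dropped here; treating it as the start of a new entity is another reasonable choice.

```python
def bio_to_spans(words, tags, entity="org"):
    """Group consecutive B-/I- tags of one entity type into text spans (sketch)."""
    spans, current = [], []
    for word, tag in zip(words, tags):
        if tag == f"B-{entity}":
            # A new entity starts: close the previous one, if any
            if current:
                spans.append(" ".join(current))
            current = [word]
        elif tag == f"I-{entity}" and current:
            current.append(word)
        else:
            # "O", or an I- tag without a preceding B-, ends any open span
            if current:
                spans.append(" ".join(current))
            current = []
    if current:
        spans.append(" ".join(current))
    return spans


words = ["The", "International", "Atomic", "Energy", "Agency", "is", "to"]
tags = ["O", "B-org", "I-org", "I-org", "I-org", "O", "O"]
print(bio_to_spans(words, tags))  # ['International Atomic Energy Agency']
```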

In [3]:
data['Tag'] = data.Tag.mask(~data.Tag.isin(['B-org', 'I-org', 'O']), 'O')
data.Tag.value_counts(normalize=True)

O        0.964784
B-org    0.019210
I-org    0.016006
Name: Tag, dtype: float64

In [4]:
num_tag = data.Tag.nunique()
print(f'Tag {num_tag} categories :', data.Tag.unique(), end='\n\n')

data.shape
data[153:160]

Tag 3 categories : ['O' 'B-org' 'I-org']



(1048575, 3)

|   | Sentence # | Word | Tag |
|---|---|---|---|
| 153 | Sentence: 8 | The | O |
| 154 | Sentence: 8 | International | B-org |
| 155 | Sentence: 8 | Atomic | I-org |
| 156 | Sentence: 8 | Energy | I-org |
| 157 | Sentence: 8 | Agency | I-org |
| 158 | Sentence: 8 | is | O |
| 159 | Sentence: 8 | to | O |


In [5]:
sentences, tag, label_enc_NER = process_data(data)
joblib.dump(label_enc_NER, 'data/outputs/label_enc_NER.joblib')

['data/outputs/label_enc_NER.joblib']

In [6]:
# Demo:
i = 10
print(sentences[i])
print(tag[i])
print(label_enc_NER.classes_)

['In', 'Beirut', ',', 'a', 'string', 'of', 'officials', 'voiced', 'their', 'anger', ',', 'while', 'at', 'the', 'United', 'Nations', 'summit', 'in', 'New', 'York', ',', 'Prime', 'Minister', 'Fouad', 'Siniora', 'said', 'the', 'Lebanese', 'people', 'are', 'resolute', 'in', 'preventing', 'such', 'attempts', 'from', 'destroying', 'their', 'spirit', '.']
[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
['B-org' 'I-org' 'O']
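The code of `process_data` is not shown here. Judging from the demo output (one list of words and one list of integer tags per sentence, with classes sorted alphabetically), it presumably does something like the sketch below: group the words by `Sentence #` and encode the tags with scikit-learn's `LabelEncoder`. `process_data_sketch` is a hypothetical stand-in, not the project's function. `sort=False` keeps the sentences in file order; by default `groupby` sorts the keys as strings, so "Sentence: 10" would come before "Sentence: 2".

```python
import pandas as pd
from sklearn.preprocessing import LabelEncoder


def process_data_sketch(df):
    """Hypothetical stand-in for ner_pytorch.preprocessing.process_data."""
    df = df.copy()
    label_enc = LabelEncoder()
    # LabelEncoder sorts classes alphabetically: B-org -> 0, I-org -> 1, O -> 2
    df["Tag"] = label_enc.fit_transform(df["Tag"])
    grouped = df.groupby("Sentence #", sort=False)
    sentences = grouped["Word"].apply(list).to_numpy()
    tags = grouped["Tag"].apply(list).to_numpy()
    return sentences, tags, label_enc


toy = pd.DataFrame({
    "Sentence #": ["Sentence: 1"] * 4 + ["Sentence: 2"] * 2,
    "Word": ["The", "United", "Nations", "met", "Hello", "there"],
    "Tag": ["O", "B-org", "I-org", "O", "O", "O"],
})
sentences_toy, tags_toy, enc_toy = process_data_sketch(toy)
print(sentences_toy[0], tags_toy[0], enc_toy.classes_)
```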


## Splitting the dataset
We split the dataset as follows (percentages of the whole dataset):
- `test` = 20%
- `train` = [`minitrain`, `valid`] = 80%
    - `minitrain` = 60%
    - `valid` = 20%
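Because `len_valid` is computed from the whole dataset and passed to the second `train_test_split` as an integer, scikit-learn treats it as an absolute number of sentences: `valid` is 20% of all sentences, i.e. 25% of `train`. The sketch below reproduces the sizes reported by the split cell, assuming `PARAMS.SAMPLE_SIZES.TEST = PARAMS.SAMPLE_SIZES.VALID = 0.2` and 47959 sentences (28777 + 9591 + 9591); stratification is left out for brevity.

```python
from sklearn.model_selection import train_test_split

n_sentences = 47959          # assumed total: 28777 + 9591 + 9591
test_frac = valid_frac = 0.2  # assumed values of PARAMS.SAMPLE_SIZES.TEST / VALID

len_test = int(test_frac * n_sentences)    # 9591 sentences
len_valid = int(valid_frac * n_sentences)  # 9591 sentences, a share of the WHOLE dataset

idx = list(range(n_sentences))
train_idx, test_idx = train_test_split(idx, test_size=len_test, random_state=0)
# An int test_size is an absolute count, not a fraction of train_idx
minitrain_idx, valid_idx = train_test_split(train_idx, test_size=len_valid, random_state=0)
print(len(minitrain_idx), len(valid_idx), len(test_idx))  # 28777 9591 9591
```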

In [7]:
len_test = int(PARAMS.SAMPLE_SIZES.TEST * len(sentences))
len_valid = int(PARAMS.SAMPLE_SIZES.VALID * len(sentences))
autre = label_enc_NER.transform(['O']).item()

# The dataset is imbalanced, so we stratify on whether a sentence contains at least one ORG tag:
strat_tag = [any(t != autre for t in tags_i) for tags_i in tag]
(
    sentences_train, sentences_test,
    tag_train, tag_test
) = train_test_split(sentences, tag, random_state=PARAMS.SEED,
                     test_size=len_test, stratify=strat_tag, shuffle=True)

# len_valid is an absolute number of sentences: `valid` is a share of the whole dataset, not of `train`
strat_tag_train = [any(t != autre for t in tags_i) for tags_i in tag_train]
(
    sentences_minitrain, sentences_valid,
    tag_minitrain, tag_valid
) = train_test_split(sentences_train, tag_train, random_state=PARAMS.SEED,
                     test_size=len_valid, stratify=strat_tag_train, shuffle=True)

# Check that each class is equally represented in the stratified samples:
pd.Series(np.concatenate(tag_train)).value_counts(normalize=True)
pd.Series(np.concatenate(tag_valid)).value_counts(normalize=True)

len(sentences_minitrain), len(sentences_valid), len(sentences_test)

2    0.964808
0    0.019224
1    0.015968
dtype: float64

2    0.965036
0    0.018949
1    0.016014
dtype: float64

(28777, 9591, 9591)

In [8]:
minitrain_dataset = EntityDataset(
    texts=sentences_minitrain, tags=tag_minitrain
)
minitrain_data_loader = DataLoader(
    minitrain_dataset, batch_size=PARAMS.MODEL.TRAIN_BATCH_SIZE, num_workers=16 
)

valid_dataset = EntityDataset(
    texts=sentences_valid, tags=tag_valid
)
valid_data_loader = DataLoader(
    valid_dataset, batch_size=PARAMS.MODEL.VALID_BATCH_SIZE, num_workers=16
)

test_dataset = EntityDataset(
    texts=sentences_test, tags=tag_test
)
test_data_loader = DataLoader(
    test_dataset, batch_size=PARAMS.MODEL.VALID_BATCH_SIZE, num_workers=16
)

In [9]:
i = 40
print(test_dataset.texts[i])
for key, value in test_dataset[i].items():
    print(key + ':', value)

['The', 'Spanish', 'troops', 'will', 'join', 'the', 'European', 'Union', 'force', 'sent', 'to', 'protect', 'ships', 'against', 'hijackings', 'and', 'attacks', 'by', 'Somali', 'pirates', '.']
ids: tensor([  101,  1996,  3009,  3629,  2097,  3693,  1996,  2647,  2586,  2486,
         2741,  2000,  4047,  3719,  2114,  7632, 17364,  8613,  1998,  4491,
         2011, 16831,  8350,  1012,   102,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,   
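The `ids` above start with 101 ([CLS]), end with 102 ([SEP]) and are then padded with 0 up to a fixed length. A common way to build such inputs for word-level NER is to split each word into wordpieces and let every piece inherit its word's tag. The sketch below illustrates that idea; it is not the code of `EntityDataset`, `encode_sentence` is hypothetical, and the toy tokenizer only knows three words, whose ids are read from the output above ("european" → 2647, "union" → 2586, and judging from the surrounding ids, "hijackings" → 7632, 17364, 8613). `pad_tag=0` is an assumption: in this notebook 0 encodes `B-org`, so padded and special positions must then be excluded from the loss.

```python
import torch

CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # bert-base-uncased ids, as in the output above


def encode_sentence(words, tags, tokenize, max_len, pad_tag=0):
    """Hypothetical sketch: turn words + word-level tags into padded BERT inputs.

    `tokenize` maps one word to its list of wordpiece ids (assumption: any callable).
    Each wordpiece inherits the tag of its word.
    """
    ids, target_tag = [], []
    for word, tag in zip(words, tags):
        pieces = tokenize(word)
        ids.extend(pieces)
        target_tag.extend([tag] * len(pieces))
    # Keep room for [CLS] and [SEP]
    ids, target_tag = ids[: max_len - 2], target_tag[: max_len - 2]
    ids = [CLS_ID] + ids + [SEP_ID]
    target_tag = [pad_tag] + target_tag + [pad_tag]
    mask = [1] * len(ids)
    padding = max_len - len(ids)
    return {
        "ids": torch.tensor(ids + [PAD_ID] * padding, dtype=torch.long),
        "mask": torch.tensor(mask + [0] * padding, dtype=torch.long),
        "token_type_ids": torch.zeros(max_len, dtype=torch.long),
        "target_tag": torch.tensor(target_tag + [pad_tag] * padding, dtype=torch.long),
    }


toy_vocab = {"european": [2647], "union": [2586], "hijackings": [7632, 17364, 8613]}
out = encode_sentence(["European", "Union", "hijackings"], [0, 1, 2],
                      lambda w: toy_vocab[w.lower()], max_len=8)
print(out["ids"].tolist())         # [101, 2647, 2586, 7632, 17364, 8613, 102, 0]
print(out["target_tag"].tolist())  # [0, 0, 1, 2, 2, 2, 0, 0]
```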

## Experiments before training

Before training the model, what _loss_ should we consider satisfactory? Let us compute in advance the losses we would get from:

- a model that predicts at random
- a model that always predicts the modal class ("O")
- a model that predicts reasonably well, with a given accuracy*
- an oracle model (perfect predictions)

*Note: two models with the same fixed accuracy are not necessarily equally good! Their predicted probabilities can be more or less close to the truth even when they yield identical hard predictions (and therefore the same accuracy). Cross-entropy captures this difference; accuracy does not.
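The note can be illustrated on three tokens: both toy models below predict the right class everywhere (same accuracy), but the first assigns 0.9 to the true class and the second only 0.6, so cross-entropy, the mean of -log p(true class), penalises the second more. The probabilities are made up for illustration.

```python
import torch
from torch import nn

toy_loss = nn.CrossEntropyLoss()
toy_true = torch.tensor([2, 2, 0])  # 3 tokens: O, O, B-org

# Logits = log-probabilities, so softmax gives back exactly these probabilities
confident = torch.log(torch.tensor([[0.05, 0.05, 0.90],
                                    [0.05, 0.05, 0.90],
                                    [0.90, 0.05, 0.05]]))
hesitant = torch.log(torch.tensor([[0.30, 0.10, 0.60],
                                   [0.30, 0.10, 0.60],
                                   [0.60, 0.10, 0.30]]))

for name, logits in [("confident", confident), ("hesitant", hesitant)]:
    accuracy = (logits.argmax(dim=1) == toy_true).float().mean().item()
    print(f"{name}: accuracy = {accuracy:.0%}, loss = {toy_loss(logits, toy_true).item():.4f}")
# confident: accuracy = 100%, loss = 0.1054
# hesitant: accuracy = 100%, loss = 0.5108
```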

![image.png](attachment:image.png)

In [10]:
loss_ce = nn.CrossEntropyLoss(reduction='mean')
y_true = torch.tensor(label_enc_NER.transform(data.Tag), dtype=torch.long)

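# Oracle logits: 1 for the true class, -1234 elsewhere (softmax probability ≈ 1 for the true class)
y_true_ohe_parfait = torch.tensor(
    pd.get_dummies(y_true).astype(float).replace(0, -1234).to_numpy(),
    dtype=torch.float
)

# -3 is an arbitrary logit for the wrong classes: with 1 for the true class, softmax gives
# the true class a probability of ≈ 0.965 and each of the two other classes ≈ 0.018.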
y_true_ohe = torch.tensor(
    pd.get_dummies(y_true).astype(float).replace(0, -3).to_numpy(),
    dtype=torch.float
)

In [11]:
# Oracle model:
y_pred = y_true_ohe_parfait.clone()
print('Oracle loss:', loss_ce(y_pred, y_true))

# Model with 98% accuracy.
# Reminder: below ~96.5% accuracy, a model does no better than always predicting "O"!
# → data.Tag.value_counts(normalize=True)
accuracy_souhaitée = 0.98
nb_lignes_à_erroring = int((1 - accuracy_souhaitée) * y_true.shape[0])
# clone() rather than detach(): detach() shares memory, so the in-place shift would also modify y_true_ohe
y_pred = y_true_ohe.clone()
y_pred[0:nb_lignes_à_erroring, :] = shift_tensor(y_pred[0:nb_lignes_à_erroring, :])
print("Loss of a satisfying model:", loss_ce(y_pred, y_true))

# Model that always predicts the modal class:
classe_modale = data.Tag.value_counts().idxmax()
y_pred = np.full((len(data), num_tag), -3)
y_pred[:, label_enc_NER.transform([classe_modale])[0]] = 1
y_pred = torch.tensor(y_pred, dtype=torch.float)
print("Loss of the modal-class model:", loss_ce(y_pred, y_true))

# Random model:
y_pred = torch.tensor(
    pd.get_dummies(torch.randint(0, num_tag, (len(data),))).astype(float).replace(0, -3).to_numpy(),
    dtype=torch.float
)
print("Loss of a random model:", loss_ce(y_pred, y_true))

Oracle loss: tensor(0.)
Loss of a satisfying model: tensor(0.1160)
Loss of the modal-class model: tensor(0.1768)
Loss of a random model: tensor(2.7029)
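These values can be checked by hand. With logits 1 (predicted class) and -3 (other two classes), each row costs either ≈ 0.036 when the true class got logit 1, or ≈ 4.036 when it got -3 (assuming `shift_tensor` moves the 1 to another column). The mean loss is therefore linear in the accuracy: loss ≈ 0.036 + 4 × (1 - accuracy), so these baselines depend on the arbitrary logit gap of 4. The sketch recomputes the three non-oracle baselines; the small gap for the random model (2.7026 vs 2.7029) comes from sampling.

```python
import math

# Logit convention of the cell above: 1 for the predicted class, -3 for the two others
log_z = math.log(math.exp(1) + 2 * math.exp(-3))  # log of the softmax denominator
loss_right = log_z - 1  # -log p(true) when the true class got logit 1  (≈ 0.0360)
loss_wrong = log_z + 3  # -log p(true) when the true class got logit -3 (≈ 4.0360)


def expected_ce(accuracy):
    """Mean cross-entropy when a fraction `accuracy` of rows is predicted correctly."""
    return accuracy * loss_right + (1 - accuracy) * loss_wrong


for name, acc in [("satisfying (98%)", 0.98),
                  ("modal class", 0.964784),  # share of "O" tags
                  ("random", 1 / 3)]:
    print(f"{name}: {expected_ce(acc):.4f}")
# satisfying (98%): 0.1160
# modal class: 0.1768
# random: 2.7026
```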


## Setup

In [12]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cpu')

In [13]:
model = EntityModel(num_tag=num_tag)
model.to(device);

I0606 09:28:21.391445 140644682426176 modeling.py:577] loading archive file data/inputs/models/bert-base-uncased
I0606 09:28:21.393192 140644682426176 modeling.py:598] Model config {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.6.0.dev0",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

I0606 09:28:23.454669 140644682426176 modeling.py:648] Weights of BertForTokenClassification not initialized from pretrained model: ['classifier.weight', 'classifier.bias']
I0606 09:28:23.455593 140644682426176 modeling.py:651] Weights from pretrai

In [14]:
param_optimizer = list(model.named_parameters())
no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]

In [15]:
optimizer_parameters = [
    {
        "params": [
            p for n, p in param_optimizer if not any(nd in n for nd in no_decay)
        ],
        "weight_decay": 0.001,
    },
    {
        "params": [
            p for n, p in param_optimizer if any(nd in n for nd in no_decay)
        ],
        "weight_decay": 0.0,
    },
]
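The two groups are built by substring matching on parameter names: weight matrices get weight decay, while biases and LayerNorm parameters do not, a common practice when fine-tuning BERT. Because matching uses substrings, "bias" alone already covers "LayerNorm.bias". A toy module reusing BERT's `LayerNorm` attribute name (`ToyBlock` is hypothetical, for illustration only) shows the split:

```python
from torch import nn


class ToyBlock(nn.Module):
    """Tiny module reusing BERT's parameter naming (`dense`, `LayerNorm`)."""

    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(4, 4)
        self.LayerNorm = nn.LayerNorm(4)


model_toy = ToyBlock()
no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
decay_names = [n for n, _ in model_toy.named_parameters() if not any(nd in n for nd in no_decay)]
no_decay_names = [n for n, _ in model_toy.named_parameters() if any(nd in n for nd in no_decay)]
print(decay_names)     # ['dense.weight']
print(no_decay_names)  # ['dense.bias', 'LayerNorm.weight', 'LayerNorm.bias']
```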

In [16]:
# Note: the DataLoader yields ceil(N / batch_size) batches per epoch, so this truncated
# estimate can be slightly lower than the number of batches actually fed to the network.
num_train_steps = int(len(sentences_minitrain) / PARAMS.MODEL.TRAIN_BATCH_SIZE * PARAMS.MODEL.EPOCHS)
print(f"{num_train_steps} batches will be fed to the network over {PARAMS.MODEL.EPOCHS} epochs.")

optimizer = torch.optim.AdamW(optimizer_parameters, lr=PARAMS.MODEL.LEARNING_RATE)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=0, num_training_steps=num_train_steps
)

4496 batches will be fed to the network over 5 epochs.
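`get_linear_schedule_with_warmup` multiplies the learning rate by a factor that grows linearly during warmup, then decreases linearly to 0 at `num_training_steps` and stays at 0 afterwards. The function below re-implements that factor for illustration (our own code, written from our understanding of the transformers implementation, not imported from it). With `num_warmup_steps=0` as here, the learning rate decays from `LEARNING_RATE` to 0; any optimizer step beyond `num_training_steps` runs with a learning rate of 0, which is why that count should match the number of batches the DataLoader actually yields.

```python
def linear_warmup_factor(step, num_warmup_steps, num_training_steps):
    """LR multiplier of a linear warmup + linear decay schedule (sketch)."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    remaining = num_training_steps - step
    # Clamped at 0 once step >= num_training_steps
    return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))


total = 4496
for step in [0, total // 2, total, total + 4]:
    print(step, linear_warmup_factor(step, num_warmup_steps=0, num_training_steps=total))
# 0 1.0
# 2248 0.5
# 4496 0.0
# 4500 0.0
```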


## Training

In [17]:
best_loss = np.inf

pbar = tqdm(range(PARAMS.MODEL.EPOCHS))
for num_epoch in pbar:
    train_loss = train_fn(minitrain_data_loader, model, optimizer,
                          device, scheduler, pbar=pbar, num_epoch=num_epoch)

    # Evaluate the model on the validation set at the end of each epoch
    valid_loss = eval_fn(valid_data_loader, model, device)
    print(f"Train Loss = {train_loss} Valid Loss = {valid_loss}")

    # Only keep the checkpoint with the lowest validation loss
    if valid_loss < best_loss:
        torch.save(model.state_dict(), PARAMS.PATHS.MODEL_SAVED)
        best_loss = valid_loss

  0%|          | 0/5 [00:00<?, ?it/s]

Batch #0 : loss = 0.934720
Batch #10 : loss = 0.374666
Batch #20 : loss = 0.352794
Batch #30 : loss = 0.260477
Batch #40 : loss = 0.197657
Batch #50 : loss = 0.184880
Batch #60 : loss = 0.214349
Batch #70 : loss = 0.215716
Batch #80 : loss = 0.175364
Batch #90 : loss = 0.186485
Batch #100 : loss = 0.274800
Batch #110 : loss = 0.204509
Batch #120 : loss = 0.214829
Batch #130 : loss = 0.187493
Batch #140 : loss = 0.235032
Batch #150 : loss = 0.243582
Batch #160 : loss = 0.199930
Batch #170 : loss = 0.175343
Batch #180 : loss = 0.182348
Batch #190 : loss = 0.173486
Batch #200 : loss = 0.219236
Batch #210 : loss = 0.198770
Batch #220 : loss = 0.200684
Batch #230 : loss = 0.203912
Batch #240 : loss = 0.184976
Batch #250 : loss = 0.197620
Batch #260 : loss = 0.214885
Batch #270 : loss = 0.225195
Batch #280 : loss = 0.231820
Batch #290 : loss = 0.173404
Batch #300 : loss = 0.173494
Batch #310 : loss = 0.200523
Batch #320 : loss = 0.225873
Batch #330 : loss = 0.195248
Batch #340 : loss = 0.173

Batch #80 : loss = 0.146031
Batch #90 : loss = 0.134234
Batch #100 : loss = 0.153072
Batch #110 : loss = 0.124656
Batch #120 : loss = 0.163127
Batch #130 : loss = 0.134035
Batch #140 : loss = 0.129163
Batch #150 : loss = 0.163015
Batch #160 : loss = 0.152959
Batch #170 : loss = 0.120873
Batch #180 : loss = 0.137639
Batch #190 : loss = 0.156223
Batch #200 : loss = 0.143010
Batch #210 : loss = 0.113951
Batch #220 : loss = 0.139733
Batch #230 : loss = 0.154226
Batch #240 : loss = 0.147675
Batch #250 : loss = 0.174658
Batch #260 : loss = 0.150045
Batch #270 : loss = 0.159185
Batch #280 : loss = 0.176274
Batch #290 : loss = 0.141631
Batch #300 : loss = 0.154438
Batch #310 : loss = 0.122628
Batch #320 : loss = 0.157396
Batch #330 : loss = 0.134598
Batch #340 : loss = 0.122674
Batch #350 : loss = 0.107843
Batch #360 : loss = 0.142952
Batch #370 : loss = 0.168944
Batch #380 : loss = 0.137220
Batch #390 : loss = 0.139399
Batch #400 : loss = 0.136971
Batch #410 : loss = 0.138847
Batch #420 : los

KeyboardInterrupt: 
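Training was interrupted by hand (the `KeyboardInterrupt` above), and early stopping is on the TODO list. PyTorch Lightning provides an `EarlyStopping` callback for this; as a framework-free alternative, here is a minimal patience-based sketch (our own class, not Lightning's API) whose `improved` flag could replace the `best_loss` comparison in the training loop. The validation losses below are made-up values.

```python
class EarlyStopping:
    """Stop when the validation loss has not improved for `patience` epochs (sketch)."""

    def __init__(self, patience=2, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.bad_epochs = 0

    def step(self, valid_loss):
        """Record one epoch; return (improved, should_stop)."""
        if valid_loss < self.best_loss - self.min_delta:
            self.best_loss = valid_loss
            self.bad_epochs = 0
            return True, False  # improvement: save a checkpoint, keep going
        self.bad_epochs += 1
        return False, self.bad_epochs >= self.patience


stopper = EarlyStopping(patience=2)
for epoch, valid_loss in enumerate([0.20, 0.15, 0.16, 0.17, 0.10]):
    improved, stop = stopper.step(valid_loss)
    print(epoch, valid_loss, improved, stop)
    if stop:
        break  # stops at epoch 3: two epochs in a row without improvement
```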

## Predictions

Let's use the fine-tuned model to predict the tags of a sentence from the test set.

In [18]:
# Load the best saved model (map_location lets a GPU-trained checkpoint load on CPU):
model = EntityModel(num_tag=num_tag)
model.load_state_dict(torch.load('data/models/model_trained.bin', map_location=device))

I0606 12:06:27.818886 140644682426176 modeling.py:577] loading archive file data/inputs/models/bert-base-uncased

I0606 12:06:29.735961 140644682426176 modeling.py:648] Weights of BertForTokenClassification not initialized from pretrained model: ['classifier.weight', 'classifier.bias']
I0606 12:06:29.736854 140644682426176 modeling.py:651] Weights from pretrai

<All keys matched successfully>

In [19]:
i = 2
single_example = test_dataset[i]
print(' '.join(test_dataset.texts[i]))
nb_tokens = single_example['mask'].sum().item()  # real tokens, [CLS] and [SEP] included

# Add a batch dimension of size 1
single_example = {k: single_example[k].unsqueeze(0) for k in single_example}
model.eval()  # disable dropout for inference
with torch.no_grad():
    output, loss = model(**single_example)
predictions_probas = nn.functional.softmax(output, dim=2).squeeze()
predictions_probas, predictions_classes = torch.max(predictions_probas, dim=1)

Automobile sales in the United States fell in June as high gasoline prices kept consumers away from trucks and sports utility vehicles that require a lot of fuel .


In [20]:
ids = single_example['ids'].squeeze()[:nb_tokens]
pd.DataFrame({
    'token_text': [EntityDataset.tokenizer.decode([token]) for token in ids],
    'token_id': ids,
    'class_code': predictions_classes[:nb_tokens],
    'class_label': label_enc_NER.inverse_transform(predictions_classes)[:nb_tokens],
    'class_proba': predictions_probas[:nb_tokens]
}).T

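|   | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| token_text | [CLS] | automobile | sales | in | the | united | states | fell | in | june | as | high | gasoline | prices | kept | consumers | away | from | trucks | and | sports | utility | vehicles | that | require | a | lot | of | fuel | . | [SEP] |
| token_id | 101 | 9935 | 4341 | 1999 | 1996 | 2142 | 2163 | 3062 | 1999 | 2238 | 2004 | 2152 | 13753 | 7597 | 2921 | 10390 | 2185 | 2013 | 9322 | 1998 | 2998 | 9710 | 4683 | 2008 | 5478 | 1037 | 2843 | 1997 | 4762 | 1012 | 102 |
| class_code | 0 | 2 | 2 | 2 | 2 | 2 | 0 | 2 | 2 | 1 | 2 | 1 | 2 | 2 | 2 | 2 | 2 | 2 | 2 | 2 | 2 | 2 | 2 | 2 | 2 | 2 | 2 | 2 | 2 | 0 | 0 |
| class_label | B-org | O | O | O | O | O | B-org | O | O | I-org | O | I-org | O | O | O | O | O | O | O | O | O | O | O | O | O | O | O | O | O | B-org | B-org |
| class_proba | 0.999283 | 0.990896 | 0.999125 | 0.999999 | 0.998966 | 0.980616 | 0.333333 | 0.998936 | 0.999536 | 0.499692 | 0.999999 | 0.499783 | 0.998798 | 0.999391 | 1 | 0.999527 | 1 | 0.999503 | 0.999373 | 0.998251 | 0.99998 | 0.998454 | 0.999294 | 0.999999 | 0.999999 | 0.998281 | 0.999675 | 0.998903 | 0.999703 | 0.499861 | 0.999317 |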
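The table above is at wordpiece level and includes [CLS] and [SEP]. To obtain one tag per word, a common convention (an assumption here; averaging the probabilities of a word's pieces is another option) is to keep the prediction of each word's first wordpiece and to drop the special tokens. In this sentence every word happens to be a single wordpiece, so the toy input below uses a multi-piece word instead. `first_subword_labels` and its `pieces_per_word` input are hypothetical.

```python
def first_subword_labels(words, pieces_per_word, piece_labels):
    """Map wordpiece-level labels back to words with the first-subword rule (sketch).

    piece_labels covers [CLS] + wordpieces + [SEP]; pieces_per_word[i] is the
    number of wordpieces of words[i] (hypothetical input, e.g. from the tokenizer).
    """
    labels, pos = [], 1  # position 0 is [CLS]
    for n_pieces in pieces_per_word:
        labels.append(piece_labels[pos])  # label of the word's first piece
        pos += n_pieces
    return list(zip(words, labels))  # [SEP] and padding are never read


words = ["European", "Union", "hijackings"]
pieces_per_word = [1, 1, 3]
piece_labels = ["B-org", "B-org", "I-org", "O", "I-org", "O", "B-org"]
print(first_subword_labels(words, pieces_per_word, piece_labels))
# [('European', 'B-org'), ('Union', 'I-org'), ('hijackings', 'O')]
```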
