# (XL)MBERT Model Finetuning for Language Identification
This notebook contains code to unfreeze the layers of BERT family models, in conjunction with training a linear probe on top for language identification (as has been done in our accompanying paper.

We have chosen to include this code in a notebook rather than implement it in the framework of this repo, as this would require a major rewriting of the code.

In [1]:
from collections import Counter
from math import sqrt
import unicodedata

import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Code to load the wili-2018 dataset
def load_ds(path: str):
    with open(path, encoding="utf8") as f:
        for l in f:
            yield l.rstrip("\n")


# Load the data
x_train = load_ds("../data/wili-2018/x_train.txt")
y_train = load_ds("../data/wili-2018/y_train.txt")
x_test = load_ds("../data/wili-2018/x_test.txt")
y_test = load_ds("../data/wili-2018/y_test.txt")

x_train = pd.DataFrame(x_train, columns=["sentence"])
y_train = pd.DataFrame(y_train, columns=["lang"])
x_test = pd.DataFrame(x_test, columns=["sentence"])
y_test = pd.DataFrame(y_test, columns=["lang"])

# Preperation work for classification
langs = sorted(y_train.lang.unique())

language_to_index = {lang: i for i, lang in enumerate(langs)}
index_to_language = {i: lang for i, lang in enumerate(langs)}

y_train_id = [language_to_index[lang] for lang in y_train.lang]

# [!] Select your model
Please select the model you want to finetune here. In the paper we have used "bert-base-multilingual-cased" (MBERT) and "xlm-roberta-base" as our DL approaches.

In [None]:
model_name = "xlm-roberta-base"
# model_name = "bert-base-multilingual-cased"

In [2]:
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).to(device)

Some weights of the model checkpoint at xlm-roberta-base were not used when initializing XLMRobertaModel: ['lm_head.decoder.weight', 'lm_head.layer_norm.bias', 'lm_head.dense.bias', 'lm_head.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.weight']
- This IS expected if you are initializing XLMRobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## [1] Turn this to empty list to run without features (meta-data)

In [3]:
BERT_DIM = 768

UNICODE_CATEGORIES = ["Zs", "Po", "Lu", "Ll", "Pd", "Ps", "Pe", "Lo", "Mn", "Pf"]
# UNICODE_CATEGORIES = []

DIM = len(UNICODE_CATEGORIES) + BERT_DIM

## Add features to the classifier

In [4]:
df = pd.read_csv('../data/train_pre.csv')

# Remove the "useless" features, decided to be useless in "Feature Selection.ipynb"
df = df.drop(columns=['Nd', 'Cc', 'No', 'Nl', 'Co', 'Cn'])
df = df.reset_index()

In [5]:
data = df.drop('Language', axis=1).drop('index', axis=1)

# Drop all but interesting features.
data = data[data.columns.intersection(UNICODE_CATEGORIES)]
print(data)

        Lu   Ll   Zs  Po  Pd  Ps  Pe   Lo  Mn  Pf
0        9  283   45   6   2   0   0    0   0   0
1       17  120   31  11   3   3   3    0   0   0
2        0    0   63   8   0   0   0  237  67   0
3       42  750  159  37   1  16  15    0   0   4
4        3   15   13   1   0   1   1  179  41   0
...     ..  ...  ...  ..  ..  ..  ..  ...  ..  ..
117495  26  792  187  27   0   3   3    0   0   0
117496  14  200   45   9   1   0   0    0   0   0
117497   0    0    0  34   0   1   1  195   0   0
117498  19  495  106  16   0   0   0    0   0   0
117499   6  238   46   4   0   0   0    0   0   0

[117500 rows x 10 columns]


In [6]:
norm_data = data.div(data.sum(axis=1), axis=0)
print(norm_data)

              Lu        Ll        Zs        Po        Pd        Ps        Pe  \
0       0.026087  0.820290  0.130435  0.017391  0.005797  0.000000  0.000000   
1       0.090426  0.638298  0.164894  0.058511  0.015957  0.015957  0.015957   
2       0.000000  0.000000  0.168000  0.021333  0.000000  0.000000  0.000000   
3       0.041016  0.732422  0.155273  0.036133  0.000977  0.015625  0.014648   
4       0.011811  0.059055  0.051181  0.003937  0.000000  0.003937  0.003937   
...          ...       ...       ...       ...       ...       ...       ...   
117495  0.025048  0.763006  0.180154  0.026012  0.000000  0.002890  0.002890   
117496  0.052045  0.743494  0.167286  0.033457  0.003717  0.000000  0.000000   
117497  0.000000  0.000000  0.000000  0.147186  0.000000  0.004329  0.004329   
117498  0.029874  0.778302  0.166667  0.025157  0.000000  0.000000  0.000000   
117499  0.020408  0.809524  0.156463  0.013605  0.000000  0.000000  0.000000   

              Lo        Mn        Pf  


In [8]:
data['features'] = data.values.tolist()

In [9]:
print(data)

        Lu   Ll   Zs  Po  Pd  Ps  Pe   Lo  Mn  Pf  \
0        9  283   45   6   2   0   0    0   0   0   
1       17  120   31  11   3   3   3    0   0   0   
2        0    0   63   8   0   0   0  237  67   0   
3       42  750  159  37   1  16  15    0   0   4   
4        3   15   13   1   0   1   1  179  41   0   
...     ..  ...  ...  ..  ..  ..  ..  ...  ..  ..   
117495  26  792  187  27   0   3   3    0   0   0   
117496  14  200   45   9   1   0   0    0   0   0   
117497   0    0    0  34   0   1   1  195   0   0   
117498  19  495  106  16   0   0   0    0   0   0   
117499   6  238   46   4   0   0   0    0   0   0   

                                      features  
0            [9, 283, 45, 6, 2, 0, 0, 0, 0, 0]  
1          [17, 120, 31, 11, 3, 3, 3, 0, 0, 0]  
2           [0, 0, 63, 8, 0, 0, 0, 237, 67, 0]  
3       [42, 750, 159, 37, 1, 16, 15, 0, 0, 4]  
4          [3, 15, 13, 1, 0, 1, 1, 179, 41, 0]  
...                                        ...  
117495    [26, 792, 

In [10]:
# Create a train dev split
x_train, x_dev, y_train, y_dev = train_test_split(x_train, y_train_id, test_size=0.2, random_state=42, stratify=y_train)

train_set = list(zip(x_train["sentence"], torch.Tensor(y_train).to(device).long()))
dev_set = list(zip(x_dev["sentence"], torch.Tensor(y_dev).to(device).long()))

y_test_id = [language_to_index[lang] for lang in y_test.lang]
test = list(zip(x_test["sentence"], torch.Tensor(y_test_id).to(device).long()))

# Model Definitions

In [11]:
# Diagnostic Classifier (/ Linear Probe)
from sklearn.metrics import classification_report
import numpy as np

class LinearDiagnosticClassifier(torch.nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()

        # Save dims
        self.input_dim = input_dim
        self.output_dim = output_dim

        # Construct layer
        self.layer = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.layer(x)

In [12]:
def metadata_collector(sentences, device, features=list()):
    """
    This function collects the metadata per sentence.
    """
    data = torch.zeros((len(sentences), len(features))).to(device)
    
    for i, paragraph in enumerate(sentences):
        for char in paragraph:
            
            cat = unicodedata.category(char)
            
            for idx in range(len(features)):
                if cat in features[idx]:
                    data[i][idx] += 1            
    
    # normalize the data to percentage of the sentence exists of
    return torch.div(data.T, torch.sum(data, 1)).T

In [13]:
from tqdm.notebook import tqdm_notebook

class TransformerLMwithClassifier(torch.nn.Module):
    def __init__(self, transformer, classifier):
        super().__init__()
        self.transformer = transformer
        self.classifier = classifier

    def forward(self, x):
        x = self.transformer(**tokenizer(list(x), padding=True, truncation=True, return_tensors="pt").to(device))[0][:, 0, :]
        x = self.classifier(x)
        return x
    
class TransformerLMwithClassifier_feature(torch.nn.Module):
    def __init__(self, transformer, classifier):
        super().__init__()
        self.transformer = transformer
        self.classifier = classifier

    def forward(self, x):
        xfeat = metadata_collector(list(x), device, UNICODE_CATEGORIES)
        x = self.transformer(**tokenizer(list(x), padding=True, truncation=True, return_tensors="pt").to(device))[0][:, 0, :]
        x = self.classifier(torch.cat((x, xfeat),1))
        return x

## PyTorch Training code

In [14]:
def train(model, train_dl, optimizer, loss_fn):
    # Train
    model.train()
    train_loss = 0

    for batch in tqdm_notebook(train_dl):
        # Get batch
        x, y = batch
        
        # Get predictions
        y_hat = model.forward(x)

        # Compute loss
        loss = loss_fn(y_hat, y)

        # Backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Update loss
        train_loss += loss.item()


def evaluate(model, dl, loss_fn):
    # Evaluate
    model.eval()
    data_loss = 0
    y_true = []
    y_pred = []

    with torch.no_grad():
        for batch in tqdm_notebook(dl):
            # Get batch
            x, y = batch

            # Get predictions
            y_hat = model.forward(x)

            # Compute loss
            loss = loss_fn(y_hat, y)

            # Update loss
            data_loss += loss.item()

            # Update predictions
            y_true.extend(y.tolist())
            y_pred.extend(y_hat.argmax(dim=1).tolist())

    # Calculate accuracy
    acc = np.mean(np.array(y_true) == np.array(y_pred))

    return data_loss, acc, y_pred, y_true

def train_model(model, train_set, dev_set, epochs=10, batch_size=32, lr=0.001, weight_decay=0.0, unfreeze_layers=1):
    # Create dataloaders
    train_dl = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
    dev_dl = torch.utils.data.DataLoader(dev_set, batch_size=batch_size, shuffle=False)

    # Create optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    # Create loss function
    loss_fn = torch.nn.CrossEntropyLoss()

    # Create scheduler
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2, verbose=True)

    # Unfreeze the last layer of the model
    for param in model.parameters():
        param.requires_grad = False
    for param in model.transformer.encoder.layer[-unfreeze_layers:].parameters():
        param.requires_grad = True

    for epoch in range(epochs):
        # Train
        train(model, train_dl, optimizer, loss_fn)

        # Evaluate
        # train_loss, train_accuracy, _, _ = evaluate(model, train_dl, loss_fn)
        dev_loss, dev_accuracy, _, _ = evaluate(model, dev_dl, loss_fn)

        # Print results
        print(f"Epoch {epoch + 1}")
        #print(f"Train loss: {train_loss / len(train_dl):.4f}, accuracy: {train_accuracy:.4f}")
        print(f"Dev loss: {dev_loss / len(dev_dl):.4f}, accuracy: {dev_accuracy:.4f}")
        print()

        # Update scheduler
        scheduler.step(dev_loss)

## Loading the (previously trained) classifier
Once you've trained the classifier in the cell below and save it, you can reload it later to do inference.

In [15]:
import os

loading_classifier = False # Be careful! Will overwrite the trained probe in memory.

In [16]:
if loading_classifier:
    if UNICODE_CATEGORIES:
        LMclassifier = torch.load("networks/LMclassifier-Mbert-uni-norm.pt")
    else:
        LMclassifier = torch.load("networks/LMclassifier-Mbert.pt")
else:
    classifier = LinearDiagnosticClassifier(DIM, len(langs)).to(device)
    
    # Create the new classification model
    if UNICODE_CATEGORIES:
        LMclassifier = TransformerLMwithClassifier_feature(model, classifier)
    else:
        LMclassifier = TransformerLMwithClassifier(model, classifier)

## Training Classifier

In [17]:
# Train the model (this was trained for 12 epochs)
# if UNICODE_CATEGORIES :
#     train_feature_model(LMclassifier, train_set, dev_set, epochs=12, batch_size=32, lr=0.001, weight_decay=0.0, unfreeze_layers=1)
# else:
train_model(LMclassifier, train_set, dev_set, epochs=12, batch_size=32, lr=0.001, weight_decay=0.0, unfreeze_layers=1)

  0%|          | 0/2938 [00:00<?, ?it/s]

  0%|          | 0/735 [00:00<?, ?it/s]

Epoch 1
Dev loss: 0.3178, accuracy: 0.9168



  0%|          | 0/2938 [00:00<?, ?it/s]

  0%|          | 0/735 [00:00<?, ?it/s]

Epoch 2
Dev loss: 0.2650, accuracy: 0.9320



  0%|          | 0/2938 [00:00<?, ?it/s]

  0%|          | 0/735 [00:00<?, ?it/s]

Epoch 3
Dev loss: 0.2481, accuracy: 0.9372



  0%|          | 0/2938 [00:00<?, ?it/s]

  0%|          | 0/735 [00:00<?, ?it/s]

Epoch 4
Dev loss: 0.2366, accuracy: 0.9379



  0%|          | 0/2938 [00:00<?, ?it/s]

  0%|          | 0/735 [00:00<?, ?it/s]

Epoch 5
Dev loss: 0.2296, accuracy: 0.9406



  0%|          | 0/2938 [00:00<?, ?it/s]

  0%|          | 0/735 [00:00<?, ?it/s]

Epoch 6
Dev loss: 0.2263, accuracy: 0.9415



  0%|          | 0/2938 [00:00<?, ?it/s]

  0%|          | 0/735 [00:00<?, ?it/s]

Epoch 7
Dev loss: 0.2191, accuracy: 0.9451



  0%|          | 0/2938 [00:00<?, ?it/s]

  0%|          | 0/735 [00:00<?, ?it/s]

Epoch 8
Dev loss: 0.2231, accuracy: 0.9433



  0%|          | 0/2938 [00:00<?, ?it/s]

  0%|          | 0/735 [00:00<?, ?it/s]

Epoch 9
Dev loss: 0.2244, accuracy: 0.9436



  0%|          | 0/2938 [00:00<?, ?it/s]

  0%|          | 0/735 [00:00<?, ?it/s]

Epoch 10
Dev loss: 0.2195, accuracy: 0.9443

Epoch 00010: reducing learning rate of group 0 to 1.0000e-04.


  0%|          | 0/2938 [00:00<?, ?it/s]

  0%|          | 0/735 [00:00<?, ?it/s]

Epoch 11
Dev loss: 0.1852, accuracy: 0.9550



  0%|          | 0/2938 [00:00<?, ?it/s]

  0%|          | 0/735 [00:00<?, ?it/s]

Epoch 12
Dev loss: 0.1835, accuracy: 0.9568



In [18]:
# Classification report on test set
test_dl = torch.utils.data.DataLoader(test, batch_size=32, shuffle=False)

# if UNICODE_CATEGORIES :
#     test_loss, test_accuracy, y_pred, y_true = evaluate_feature(LMclassifier, test_dl, torch.nn.CrossEntropyLoss())
# else:
test_loss, test_accuracy, y_pred, y_true = evaluate(LMclassifier, test_dl, torch.nn.CrossEntropyLoss())

  0%|          | 0/3672 [00:00<?, ?it/s]

In [19]:
print(classification_report(y_true, y_pred, target_names=langs, digits=4))

              precision    recall  f1-score   support

         ace     0.9820    0.9820    0.9820       500
         afr     0.9920    0.9880    0.9900       500
         als     0.8697    0.9080    0.8885       500
         amh     0.9940    0.9920    0.9930       500
         ang     0.9695    0.9520    0.9606       500
         ara     0.9368    0.9780    0.9569       500
         arg     0.9741    0.9780    0.9760       500
         arz     0.9734    0.9500    0.9615       500
         asm     0.9759    0.9700    0.9729       500
         ast     0.9439    0.9420    0.9429       500
         ava     0.8124    0.8920    0.8503       500
         aym     0.9520    0.9520    0.9520       500
         azb     0.9900    0.9940    0.9920       500
         aze     0.9800    0.9800    0.9800       500
         bak     0.9583    0.9640    0.9611       500
         bar     0.8669    0.9120    0.8889       500
         bcl     0.9834    0.9480    0.9654       500
   be-tarask     0.9336    

## Saving the full model

In [None]:
import os

saving_classifier = True # Be careful! Will overwrite the saved model.

In [None]:
if saving_classifier:
    # Save train embeddings to disk as a python pickle
    # Create a directory for the embeddings if it does not exist yet
    if not os.path.exists("networks"):
        os.mkdir("networks")
    torch.save(LMclassifier, "networks/LMclassifier-XLMbert-uni.pt")

## Print the model's failure modes

In [None]:
# Print the paragraphs that were classified wrongly of the test set
for i, (pred, true) in enumerate(zip(y_pred, y_true)):
    if pred != true:
        pred_lang, true_lang = index_to_language[pred], index_to_language[true]
        print(f"Paragraph {i} was classified as '{pred_lang}' but is actually '{true_lang}'")
        print(x_test.iloc[i].sentence)
        print()