# mBERT Linear Probe Training for Language Identification

In [2]:
from collections import Counter
from math import sqrt

import pandas as pd
import numpy as np
import plotly.express as ex
from scipy.spatial.distance import jensenshannon
from sklearn.metrics import confusion_matrix
from joblib import Parallel, delayed
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
import torch

from meta_collector import metadata_collector

def load_ds(path: str):
    """Lazily yield the lines of the UTF-8 text file at *path*.

    Each yielded line has its trailing newline removed; all other characters,
    including other trailing whitespace, are kept. The file is opened on the
    first ``next()`` call and closed once the generator is exhausted.
    """
    with open(path, encoding="utf8") as handle:
        yield from (line.rstrip("\n") for line in handle)

# Read the WiLI-2018 files line by line. The x_* files hold sentences and the
# y_* files hold labels; they are paired purely by line position.
x_train = load_ds("data/wili-2018/x_train.txt")
y_train = load_ds("data/wili-2018/y_train.txt")
x_test = load_ds("data/wili-2018/x_test.txt")
y_test = load_ds("data/wili-2018/y_test.txt")

# Materialise the lazy generators into one-column DataFrames.
x_train = pd.DataFrame(x_train, columns=["sentence"])
y_train = pd.DataFrame(y_train, columns=["lang"])
x_test = pd.DataFrame(x_test, columns=["sentence"])
y_test = pd.DataFrame(y_test, columns=["lang"])

# Hold out 20% of the training data as a dev set, stratified by language so
# every language keeps the same share; random_state makes the split repeatable.
x_train, x_dev, y_train, y_dev = train_test_split(x_train, y_train, test_size=0.2, random_state=42, stratify=y_train)

# axis=1 joins on the index. train_test_split applies the same permutation to
# x and y, so their (shuffled) indices match and rows stay paired.
train = pd.concat([x_train, y_train], axis=1)
dev = pd.concat([x_dev, y_dev], axis=1)
test = pd.concat([x_test, y_test], axis=1)
# Sorted language labels of the training split.
langs = sorted(y_train.lang.unique())
# Character inventory of the training sentences.
# NOTE(review): not used in this part of the notebook; confirm it is still needed.
chars = set(c for s in train.sentence for c in s)

# Run on the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


## Choose between the feature-augmented and the plain-embedding version

In [3]:
# Switch here by uncommenting and commenting
features = ["Ll", "Zs", "Lu", "Po", "Pd", "Lo", "Mn", "Ps", "Pe", "Mc"]
# features = []

h_dim = 768 + len(features)
print(h_dim)

778


In [None]:
from transformers import BertTokenizer, BertModel
from tqdm.notebook import tqdm_notebook
# Cased multilingual BERT; the model is moved to the selected device so its
# inputs (moved in get_bert_embedding) live on the same device.
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
model = BertModel.from_pretrained("bert-base-multilingual-cased").to(device)


def get_bert_embedding(model, tokenizer, sentences, features, batch_size=4, shrinkage_fact=1):
    """Embed ``sentences`` with the ``[CLS]`` output of ``model``.

    The vector is optionally concatenated with ``metadata_collector`` features.

    Args:
        model: BERT model; called as ``model(**encoded_input)``. The vector at
            position 0 of its first output (the ``[CLS]`` token) is kept.
        tokenizer: tokenizer matching ``model``.
        sentences: sequence of strings (a pandas Series or a plain list).
            Rows are taken positionally, regardless of any pandas index.
        features: feature names passed to ``metadata_collector``. When empty,
            only the ``[CLS]`` vector is stored.
        batch_size: number of sentences encoded per forward pass.
        shrinkage_fact: only the first ``len(sentences) // shrinkage_fact``
            sentences are embedded.

    Returns:
        A tensor on ``device`` with ``len(sentences) // shrinkage_fact`` rows and
        ``768 + len(features)`` columns.
    """
    n_rows = len(sentences) // shrinkage_fact
    # After train_test_split the Series index is shuffled. Go through .iloc so
    # slicing is unambiguously positional. metadata_collector still receives
    # the same kind of object (a Series slice) as before.
    positional = sentences.iloc if hasattr(sentences, "iloc") else sentences

    with torch.no_grad():
        # Pre-allocate the output; each batch fills rows [start, end).
        embeddings = torch.zeros((n_rows, 768 + len(features)), device=device)

        for start in tqdm_notebook(range(0, n_rows, batch_size)):
            # Clamp the batch so sentences beyond n_rows are never encoded.
            end = min(start + batch_size, n_rows)
            batch = positional[start:end]

            encoded_input = tokenizer(list(batch), padding=True, truncation=True, return_tensors="pt").to(device)
            output = model(**encoded_input)

            # Last hidden state of the `[CLS]` token (position 0).
            cls_states = output[0][:, 0, :]
            if features:
                meta = metadata_collector(batch, device, features)
                cls_states = torch.cat((cls_states, meta), 1)

            embeddings[start:end] = cls_states

    return embeddings

# Embed every training sentence (shrinkage_fact defaults to 1), 24 per forward pass.
train_embeddings = get_bert_embedding(model, tokenizer, train.sentence, features, batch_size=24)

In [None]:
# Embed every dev sentence with the same settings as the training split.
dev_embeddings = get_bert_embedding(model, tokenizer, dev.sentence, features, batch_size=24)

## SAVING

In [None]:
import os

# Set to True to write the embeddings computed above to embeddings/*.npy.
saving_embeddings = False  # Be careful! Will overwrite the embeddings.

In [None]:
if saving_embeddings:
    # Write the embeddings as NumPy .npy arrays (not pickles). The "_f" suffix
    # marks the variant that includes the metadata features.
    # makedirs(exist_ok=True) replaces the racy "exists, then mkdir" check.
    os.makedirs("embeddings", exist_ok=True)

    suffix = "_f" if features else ""
    np.save(f"embeddings/train_embeddings{suffix}.npy", train_embeddings.cpu().numpy())
    np.save(f"embeddings/dev_embeddings{suffix}.npy", dev_embeddings.cpu().numpy())


In [None]:
# Sanity check: expect one embedding row per dev sentence.
len(dev.sentence), len(dev_embeddings)

In [None]:
# Sanity check: expect one embedding row per training sentence.
len(train.sentence), len(train_embeddings)

## LOADING

In [11]:
# Set to True to load previously saved embeddings instead of recomputing them.
# This rebinds train_embeddings / dev_embeddings in memory.
loading_embeddings = True

if loading_embeddings:
    
    # Load the file variant matching the current `features` setting ("_f" =
    # with metadata features) and move the arrays to `device` as tensors.
    if features:
        train_embeddings = torch.from_numpy(np.load("embeddings/train_embeddings_f.npy")).to(device)
        dev_embeddings = torch.from_numpy(np.load("embeddings/dev_embeddings_f.npy")).to(device)
        print("Model with features is loaded.")
    else:
        train_embeddings = torch.from_numpy(np.load("embeddings/train_embeddings.npy")).to(device)
        dev_embeddings = torch.from_numpy(np.load("embeddings/dev_embeddings.npy")).to(device)
        print("Model without features is loaded.")

Model with features is loaded.


In [12]:
# Sanity check after loading: the saved dev embeddings must match the dev split.
len(dev.sentence), len(dev_embeddings)

(23500, 23500)

In [13]:
# Sanity check after loading: the saved train embeddings must match the train split.
len(train.sentence), len(train_embeddings)

(94000, 94000)

## Training Classifier

In [14]:
# DIAGNOSTIC CLASSIFIER
from skorch import NeuralNet
from skorch.helper import predefined_split
from skorch.dataset import Dataset
from sklearn.metrics import classification_report
import numpy as np

class LinearDiagnosticClassifier(torch.nn.Module):
    """Linear probe: a single affine map from embedding vectors to class scores.

    Args:
        input_dim: size of each input vector.
        output_dim: number of classes, i.e. size of each output vector.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Remember the dimensions for later inspection.
        self.input_dim, self.output_dim = input_dim, output_dim
        # The attribute name `layer` fixes the parameter names
        # (layer.weight / layer.bias) of saved probes, so keep it.
        self.layer = torch.nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        """Return the unnormalised class scores ``layer(x)``."""
        scores = self.layer(x)
        return scores

In [15]:
# Map each language label to an integer class id and back, following the
# sorted order of `langs`.
language_to_index = {lang: i for i, lang in enumerate(langs)}
index_to_language = {i: lang for i, lang in enumerate(langs)}

# Integer targets for the probe, in the same row order as the embeddings.
y_dev_id = [language_to_index[lang] for lang in y_dev.lang]
y_train_id = [language_to_index[lang] for lang in y_train.lang]
# Fixed validation set for skorch. Targets are cut to the number of
# embeddings, presumably to support shrinkage_fact > 1 subsets.
valid_ds = Dataset(dev_embeddings, y_dev_id[:len(dev_embeddings)])

In [16]:
# Train the linear probe on the frozen embeddings with SGD + momentum,
# scoring the fixed dev set (valid_ds) instead of an internal split.
net = NeuralNet(
    module=LinearDiagnosticClassifier,
    module__input_dim = h_dim,
    module__output_dim = len(set(y_train.lang)),  # number of languages, same as len(langs)
    criterion=torch.nn.CrossEntropyLoss,
    train_split=predefined_split(valid_ds),
    max_epochs=100,
    device=device,
    verbose=1,
    optimizer = torch.optim.SGD,
    optimizer__momentum=0.9,
    optimizer__lr=0.2,
    # Alternative optimiser that was tried (disabled):
#     optimizer = torch.optim.Adam,
)

# Targets are cut to the number of embeddings, mirroring valid_ds above.
net.fit(train_embeddings, y_train_id[:len(train_embeddings)])

  epoch    train_loss    valid_loss     dur
-------  ------------  ------------  ------
      1        [36m2.0511[0m        [32m1.1200[0m  0.6182
      2        [36m0.8343[0m        [32m0.7946[0m  0.6189
      3        [36m0.6256[0m        [32m0.6619[0m  0.6193
      4        [36m0.5213[0m        [32m0.5926[0m  0.6432
      5        [36m0.4548[0m        [32m0.5494[0m  0.6310
      6        [36m0.4076[0m        [32m0.5188[0m  0.6485
      7        [36m0.3718[0m        [32m0.4952[0m  0.6359
      8        [36m0.3435[0m        [32m0.4763[0m  0.6254
      9        [36m0.3201[0m        [32m0.4610[0m  0.6286
     10        [36m0.3005[0m        [32m0.4485[0m  0.6181
     11        [36m0.2836[0m        [32m0.4382[0m  0.6226
     12        [36m0.2688[0m        [32m0.4296[0m  0.6215
     13        [36m0.2557[0m        [32m0.4224[0m  0.6253
     14        [36m0.2440[0m        [32m0.4162[0m  0.6358
     15        [36m0.2335[0m        [32m0

<class 'skorch.net.NeuralNet'>[initialized](
  module_=LinearDiagnosticClassifier(
    (layer): Linear(in_features=778, out_features=235, bias=True)
  ),
)

## Save Linear Probe

In [None]:
import os

# Set to True to write the trained probe to networks/*.pt.
saving_classifier = False # Be careful! Will overwrite the saved probe.

In [None]:
if saving_classifier:
    # Persist the whole skorch NeuralNet with torch.save. This is a full
    # pickle of the object, not only its weights, so it must be loaded with
    # weights_only=False. The "_f" file holds the feature-augmented probe.
    # makedirs(exist_ok=True) replaces the racy "exists, then mkdir" check.
    os.makedirs("networks", exist_ok=True)

    probe_path = "networks/linear_probe_f.pt" if features else "networks/linear_probe.pt"
    torch.save(net, probe_path)

## Load Linear Probe

In [None]:
import os

# Set to True to replace `net` with the probe stored in networks/*.pt.
loading_classifier = True # Be careful! Will overwrite the trained probe in memory.

In [None]:
if loading_classifier:
    probe_path = "networks/linear_probe_f.pt" if features else "networks/linear_probe.pt"
    # The probe was written with torch.save(net), i.e. as a pickle of the
    # whole skorch object. PyTorch >= 2.6 defaults torch.load to
    # weights_only=True, which refuses such pickles, so opt out explicitly.
    # Only do this for files this notebook wrote itself: unpickling can run
    # arbitrary code.
    try:
        net = torch.load(probe_path, weights_only=False)
    except TypeError:
        # torch < 1.13 has no `weights_only` parameter and always unpickles fully.
        net = torch.load(probe_path)

In [17]:
# net.predict returns one score per language for every embedding; the
# predicted class id is the arg-max over the language axis.
dev_y_pred_id = np.argmax(net.predict(dev_embeddings), axis=1)
dev_y_pred = [index_to_language[idx] for idx in dev_y_pred_id]
# Compare against the 1-D `lang` column (rows taken positionally, matching the
# embeddings). Pin labels=langs so every target name lines up with its class
# even when a language is absent from both y_true and y_pred.
print(classification_report(
    y_dev.lang.iloc[:len(dev_y_pred)],
    dev_y_pred,
    labels=langs,
    target_names=langs,
    zero_division=0,
))

              precision    recall  f1-score   support

         ace       0.96      0.97      0.97       100
         afr       0.98      0.99      0.99       100
         als       0.68      0.86      0.76       100
         amh       0.98      0.94      0.96       100
         ang       0.92      0.94      0.93       100
         ara       0.89      0.97      0.93       100
         arg       0.99      0.99      0.99       100
         arz       0.97      0.88      0.92       100
         asm       0.93      0.98      0.96       100
         ast       0.92      0.98      0.95       100
         ava       0.87      0.79      0.83       100
         aym       0.92      0.89      0.90       100
         azb       1.00      1.00      1.00       100
         aze       0.99      0.98      0.98       100
         bak       0.97      0.98      0.98       100
         bar       0.85      0.85      0.85       100
         bcl       0.94      0.93      0.93       100
   be-tarask       0.69    