# mBERT Linear-Probe Training for Language Identification

In [30]:
import os
from collections import Counter
from math import sqrt

import random
import pandas as pd
import numpy as np
import plotly.express as ex
import torch
import skorch
from scipy.spatial.distance import jensenshannon
from sklearn.metrics import confusion_matrix
from joblib import Parallel, delayed
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from skorch import NeuralNet
from skorch.helper import predefined_split
from skorch.dataset import Dataset
from sklearn.metrics import classification_report
from transformers import BertTokenizer, BertModel
from tqdm.auto import trange

from meta_collector import metadata_collector


def load_ds(path: str):
    """Lazily yield the lines of a UTF-8 text file, without trailing newlines.

    Args:
        path: Path of the file to read.

    Yields:
        One string per line of the file, with trailing newline characters removed.
        The file is opened on first iteration and closed when the generator finishes.
    """
    with open(path, encoding="utf8") as handle:
        yield from (line.rstrip("\n") for line in handle)

def seed_everything(seed: int):
    """Seed every RNG used by this notebook and request deterministic cuDNN.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch's CPU and CUDA
    generators. ``PYTHONHASHSEED`` is exported for child processes; it does not
    change string hashing in the already-running interpreter.

    Args:
        seed: The seed applied to every generator.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current device.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN autotune and choose algorithms that may differ
    # between runs, which defeats `deterministic = True`; keep it off.
    torch.backends.cudnn.benchmark = False

# Seed Python, NumPy and PyTorch RNGs before anything random happens.
seed_everything(13331)

# Read the WiLI-2018 files line by line (load_ds yields lines lazily).
# NOTE(review): the code assumes line i of x_*.txt is labelled by line i of
# y_*.txt — sentences and labels are paired purely by row position below.
x_train = load_ds("data/wili-2018/x_train.txt")
y_train = load_ds("data/wili-2018/y_train.txt")
x_test = load_ds("data/wili-2018/x_test.txt")
y_test = load_ds("data/wili-2018/y_test.txt")

# Materialise each split as a one-column DataFrame.
x_train = pd.DataFrame(x_train, columns=["sentence"])
y_train = pd.DataFrame(y_train, columns=["lang"])
x_test = pd.DataFrame(x_test, columns=["sentence"])
y_test = pd.DataFrame(y_test, columns=["lang"])

# Create a train/dev split: hold out 20% of the training data as a dev set.
# stratify=y_train keeps per-language proportions equal in both parts;
# random_state=42 fixes the split independently of seed_everything.
x_train, x_dev, y_train, y_dev = train_test_split(x_train, y_train, test_size=0.2, random_state=42, stratify=y_train)

# Re-attach labels to sentences; concat(axis=1) aligns rows on the shared index.
train = pd.concat([x_train, y_train], axis=1)
dev = pd.concat([x_dev, y_dev], axis=1)
test = pd.concat([x_test, y_test], axis=1)
# Sorted language codes of the training split (the label vocabulary).
langs = sorted(y_train.lang.unique())
# Every distinct character occurring in the training sentences.
chars = set(c for s in train.sentence for c in s)

# Use the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


## Choose between the feature-augmented and the BERT-only version

In [31]:
# Toggle the extra features: keep the list below for the feature-augmented
# probe, or comment it out and uncomment the empty list for a BERT-only probe.
# NOTE(review): the codes look like Unicode general-category abbreviations;
# their meaning is defined by metadata_collector — confirm there.
features = [
    "Ll", "Zs", "Lu", "Po", "Pd",
    "Lo", "Mn", "Ps", "Pe", "Mc",
]
# features = []

# Width of the mBERT embedding; the probe input adds one column per feature.
BERT_DIM = 768
h_dim = len(features) + BERT_DIM
print(h_dim)

778


In [32]:
def get_bert_embedding(model, tokenizer, sentences, batch_size=4, shrinkage_fact=1):
    """Encode sentences with BERT and return their last-layer [CLS] vectors.

    Args:
        model: Hugging Face BERT model; presumably already on ``device``.
        tokenizer: Tokenizer matching ``model``.
        sentences: Iterable of strings (a pandas Series or a list).
        batch_size: Number of sentences per forward pass.
        shrinkage_fact: Keep only the first ``len(sentences) // shrinkage_fact``
            sentences (1 keeps all of them).

    Returns:
        Tensor of shape ``(n_kept, BERT_DIM)`` on ``device``; row ``i`` is the
        [CLS] embedding of the i-th kept sentence.
    """
    model.eval()
    # Materialise a plain list: HF tokenizers only accept str / list / tuple
    # batches (a pandas Series slice is rejected), and positional slicing of a
    # Series with a shuffled integer index is fragile.
    sentences = list(sentences)
    sentences = sentences[:len(sentences) // shrinkage_fact]
    with torch.no_grad():
        # Pre-allocate the result directly on the target device.
        embeddings = torch.zeros((len(sentences), BERT_DIM), device=device)

        for start in trange(0, len(sentences), batch_size):
            batch = sentences[start:start + batch_size]
            encoded_input = tokenizer(batch, padding=True, truncation=True, return_tensors="pt").to(device)
            output = model(**encoded_input)
            # output[0] is the last hidden state; token position 0 is [CLS].
            cls_embeddings = output[0][:, 0, :]
            embeddings[start:start + len(cls_embeddings)] = cls_embeddings
    return embeddings


def get_features_embeddings(sentences, features, batch_size=4, shrinkage_fact=1):
    """Build per-sentence feature vectors using ``metadata_collector``.

    Args:
        sentences: Sliceable sequence of sentences (e.g. a pandas Series).
        features: Feature names forwarded to ``metadata_collector``; the result
            has one column per name.
        batch_size: Number of sentences handed to ``metadata_collector`` per call.
        shrinkage_fact: Keep only the first ``len(sentences) // shrinkage_fact``
            sentences.

    Returns:
        Tensor of shape ``(n_kept, len(features))`` moved to ``device``.
    """
    n_kept = len(sentences) // shrinkage_fact
    sentences = sentences[:n_kept]
    # Filled batch by batch in a CPU tensor, then moved to `device` once.
    result = torch.zeros((len(sentences), len(features)))

    for lo in trange(0, len(sentences), batch_size):
        batch_meta = metadata_collector(sentences[lo:lo + batch_size], device, features)
        result[lo:lo + len(batch_meta)] = batch_meta

    return result.to(device)


def extend_embeddings(bert_embeddings, features_embeddings):
    """Append feature columns to BERT embeddings.

    Concatenates along dimension 1, so every other dimension (the number of
    sentences, for 2-D inputs) must match.
    """
    parts = (bert_embeddings, features_embeddings)
    return torch.cat(parts, dim=1)

## Loading (or computing) embeddings

In [33]:
# True: reuse the embeddings cached under embeddings/.
# False: load mBERT and encode every train/dev sentence.
loading_embeddings = True

if loading_embeddings:
    bert_train_embeddings = torch.from_numpy(np.load("embeddings/bert_train_embeddings.npy")).to(device)
    bert_dev_embeddings = torch.from_numpy(np.load("embeddings/bert_dev_embeddings.npy")).to(device)
    print("Loaded BERT embeddings.")

    if features:
        features_train_embeddings = torch.from_numpy(np.load("embeddings/features_train_embeddings.npy")).to(device)
        features_dev_embeddings = torch.from_numpy(np.load("embeddings/features_dev_embeddings.npy")).to(device)
        print("Loaded FEATURES embeddings.")
else:
    tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
    model = BertModel.from_pretrained("bert-base-multilingual-cased").to(device)

    # get_bert_embedding takes no feature list; batch_size goes by keyword.
    bert_train_embeddings = get_bert_embedding(model, tokenizer, train.sentence, batch_size=24)
    bert_dev_embeddings = get_bert_embedding(model, tokenizer, dev.sentence, batch_size=24)

    if features:
        features_train_embeddings = get_features_embeddings(train.sentence, features, batch_size=4)
        features_dev_embeddings = get_features_embeddings(dev.sentence, features, batch_size=4)

# Probe inputs: BERT embeddings, optionally extended with the feature columns.
# Assigned on every path so the BERT-only setup also works when loading from disk.
if features:
    train_embeddings = extend_embeddings(bert_train_embeddings, features_train_embeddings)
    dev_embeddings = extend_embeddings(bert_dev_embeddings, features_dev_embeddings)
else:
    train_embeddings = bert_train_embeddings
    dev_embeddings = bert_dev_embeddings

Loaded BERT embeddings.
Loaded FEATURES embeddings.


## Saving embeddings

In [34]:
saving_embeddings = False  # WARNING: when True, the next cell overwrites the saved embedding files.

In [35]:
if saving_embeddings:
    # Cache the train and dev embeddings as NumPy .npy files under embeddings/
    # (the paths the loading cell reads).
    # exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.
    os.makedirs("embeddings", exist_ok=True)

    # Tensors may live on the GPU: copy to host memory, then convert to NumPy.
    np.save("embeddings/bert_train_embeddings.npy", bert_train_embeddings.cpu().numpy())
    np.save("embeddings/bert_dev_embeddings.npy", bert_dev_embeddings.cpu().numpy())

    if features:
        np.save("embeddings/features_train_embeddings.npy", features_train_embeddings.cpu().numpy())
        np.save("embeddings/features_dev_embeddings.npy", features_dev_embeddings.cpu().numpy())

In [36]:
# Sanity check: the row count of the dev embeddings should equal the number of dev sentences.
(len(dev.sentence), dev_embeddings.shape)

(23500, torch.Size([23500, 778]))

In [37]:
# Sanity check: the row count of the train embeddings should equal the number of train sentences.
(len(train.sentence), train_embeddings.shape)

(94000, torch.Size([94000, 778]))

## Training Classifier

In [38]:
# DIAGNOSTIC CLASSIFIER
class LinearDiagnosticClassifier(torch.nn.Module):
    """Linear probe: a single affine map from an embedding to class scores.

    Args:
        input_dim: Size of each input embedding.
        output_dim: Number of output classes.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # The only trainable parameters: one weight matrix and one bias vector.
        self.layer = torch.nn.Linear(input_dim, output_dim)
        # Remember the dimensions for introspection.
        self.input_dim, self.output_dim = input_dim, output_dim

    def forward(self, x):
        """Return the raw (un-normalised) class scores for a batch of inputs."""
        logits = self.layer(x)
        return logits

In [39]:
# Two-way mapping between language codes and class indices (sorted order of `langs`).
language_to_index = {code: idx for idx, code in enumerate(langs)}
index_to_language = dict(enumerate(langs))

# Integer targets per split; assumed to be in the same row order as the embeddings.
y_dev_id = [language_to_index[code] for code in y_dev.lang]
y_train_id = [language_to_index[code] for code in y_train.lang]
# skorch validation set; labels truncated to the number of embedded rows
# (relevant when fewer rows were embedded, e.g. shrinkage_fact > 1).
valid_ds = Dataset(dev_embeddings, y_dev_id[:len(dev_embeddings)])

In [40]:
# Linear probe trained with SGD + momentum; the dev split serves as skorch's
# validation set, and EarlyStopping (default settings) watches its loss.
net = NeuralNet(
    module=LinearDiagnosticClassifier,
    module__input_dim=train_embeddings.shape[-1],
    module__output_dim=len(langs),
    criterion=torch.nn.CrossEntropyLoss,
    train_split=predefined_split(valid_ds),
    max_epochs=1000,  # upper bound; early stopping is expected to end training sooner
    device=device,
    verbose=1,
    optimizer=torch.optim.SGD,
    optimizer__momentum=0.9,
    optimizer__lr=0.2,
    # Alternative: optimizer=torch.optim.Adam,
    callbacks=[skorch.callbacks.EarlyStopping()],
)

net.fit(train_embeddings, y_train_id[:len(train_embeddings)])

  epoch    train_loss    valid_loss     dur
-------  ------------  ------------  ------
      1        [36m2.0506[0m        [32m1.1182[0m  2.3840
      2        [36m0.8340[0m        [32m0.7943[0m  2.5227
      3        [36m0.6255[0m        [32m0.6619[0m  2.7145
      4        [36m0.5213[0m        [32m0.5926[0m  2.2552
      5        [36m0.4548[0m        [32m0.5493[0m  2.2774
      6        [36m0.4076[0m        [32m0.5187[0m  2.2728
      7        [36m0.3718[0m        [32m0.4951[0m  2.4219
      8        [36m0.3434[0m        [32m0.4763[0m  2.2622
      9        [36m0.3201[0m        [32m0.4610[0m  2.2458
     10        [36m0.3004[0m        [32m0.4485[0m  2.2264
     11        [36m0.2836[0m        [32m0.4382[0m  2.2596
     12        [36m0.2688[0m        [32m0.4296[0m  2.2469
     13        [36m0.2557[0m        [32m0.4223[0m  2.3556
     14        [36m0.2440[0m        [32m0.4161[0m  2.2917
     15        [36m0.2335[0m        [32m0

<class 'skorch.net.NeuralNet'>[initialized](
  module_=LinearDiagnosticClassifier(
    (layer): Linear(in_features=778, out_features=235, bias=True)
  ),
)

## Save Linear Probe

In [None]:
import os

saving_classifier = False  # WARNING: when True, the next cell overwrites the saved probe file.

In [None]:
if saving_classifier:
    # Persist the whole trained skorch net (pickled by torch.save) under networks/.
    # The "_f" suffix marks the feature-augmented probe.
    # exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.
    os.makedirs("networks", exist_ok=True)

    probe_path = "networks/linear_probe_f.pt" if features else "networks/linear_probe.pt"
    torch.save(net, probe_path)

## Load Linear Probe

In [None]:
import os

loading_classifier = False  # WARNING: when True, the next cell replaces the in-memory `net` with the saved probe.

In [None]:
if loading_classifier:
    import inspect

    # "_f" suffix marks the feature-augmented probe.
    probe_path = "networks/linear_probe_f.pt" if features else "networks/linear_probe.pt"
    # The file holds a fully pickled skorch NeuralNet, not a state dict.
    # PyTorch >= 2.6 defaults torch.load to weights_only=True, which refuses such
    # objects, so opt out explicitly where the parameter exists; older releases
    # without it always unpickle fully.
    # SECURITY: unpickling can execute arbitrary code — only load trusted files.
    if "weights_only" in inspect.signature(torch.load).parameters:
        net = torch.load(probe_path, weights_only=False)
    else:
        net = torch.load(probe_path)

In [17]:
# Evaluate on the dev split: predict() of a plain skorch NeuralNet returns the
# module's raw outputs (one score per class), so take the arg-max per row.
dev_y_pred_id = np.argmax(net.predict(dev_embeddings), axis=1)
dev_y_pred = [index_to_language[idx] for idx in dev_y_pred_id]
# Positional slice so y_true lines up with the (possibly truncated) predictions.
dev_y_true = y_dev.lang.iloc[:len(dev_y_pred)]
# Pass labels explicitly so report rows stay aligned with target_names even if
# some language is absent from both y_true and y_pred.
print(classification_report(dev_y_true, dev_y_pred, labels=langs, target_names=langs, zero_division=0))

              precision    recall  f1-score   support

         ace       0.96      0.97      0.97       100
         afr       0.98      0.99      0.99       100
         als       0.68      0.86      0.76       100
         amh       0.98      0.94      0.96       100
         ang       0.92      0.94      0.93       100
         ara       0.89      0.97      0.93       100
         arg       0.99      0.99      0.99       100
         arz       0.97      0.88      0.92       100
         asm       0.93      0.98      0.96       100
         ast       0.92      0.98      0.95       100
         ava       0.87      0.79      0.83       100
         aym       0.92      0.89      0.90       100
         azb       1.00      1.00      1.00       100
         aze       0.99      0.98      0.98       100
         bak       0.97      0.98      0.98       100
         bar       0.85      0.85      0.85       100
         bcl       0.94      0.93      0.93       100
   be-tarask       0.69    