In [1]:
from collections import Counter
from math import sqrt

import pandas as pd
import numpy as np
import plotly.express as ex
from scipy.spatial.distance import jensenshannon
from sklearn.metrics import confusion_matrix
from joblib import Parallel, delayed
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
import torch

def load_ds(path: str):
    """Lazily yield the lines of a UTF-8 text file, one string per line.

    The trailing newline of each line is removed; all other whitespace is
    kept. The file stays open until the generator is exhausted or closed.
    """
    with open(path, encoding="utf8") as handle:
        yield from (line.rstrip("\n") for line in handle)

# Read each split file straight into a single-column frame
# (sentences and language labels are stored line-aligned in separate files).
x_train = pd.DataFrame(load_ds("data/wili-2018/x_train.txt"), columns=["sentence"])
y_train = pd.DataFrame(load_ds("data/wili-2018/y_train.txt"), columns=["lang"])
x_test = pd.DataFrame(load_ds("data/wili-2018/x_test.txt"), columns=["sentence"])
y_test = pd.DataFrame(load_ds("data/wili-2018/y_test.txt"), columns=["lang"])

# Hold out 20% of the training data as a dev set, stratified by language
x_train, x_dev, y_train, y_dev = train_test_split(
    x_train,
    y_train,
    test_size=0.2,
    random_state=42,
    stratify=y_train,
)

# Sentence and label side by side for each split
train = pd.concat([x_train, y_train], axis=1)
dev = pd.concat([x_dev, y_dev], axis=1)
test = pd.concat([x_test, y_test], axis=1)

# Sorted language codes and the set of characters seen in the training sentences
langs = sorted(y_train["lang"].unique())
chars = {char for sentence in train["sentence"] for char in sentence}

# Use the GPU when one is available
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [2]:
from transformers import BertTokenizer, BertModel
from tqdm.notebook import tqdm_notebook
# Load the multilingual cased BERT tokenizer and the bare encoder from the
# same checkpoint, then move the encoder onto the selected device.
tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-cased")
model = BertModel.from_pretrained("bert-base-multilingual-cased")
model = model.to(device)

def get_bert_embedding(model, tokenizer, sentences, batch_size=4, shrinkage_fact=1):
    """Compute one `[CLS]` embedding per sentence.

    Only the first ``len(sentences) // shrinkage_fact`` sentences are
    embedded, so ``shrinkage_fact`` is a cheap way to work on a prefix of
    the data. Gradients are disabled for the whole computation.

    Args:
        model: Encoder called as ``model(**encoded_input)``; element ``0`` of
            its output is indexed as ``[:, 0, :]`` to pick the first token's
            vector for every sentence in the batch.
        tokenizer: Callable that turns a list of strings into padded,
            truncated PyTorch tensors (``return_tensors="pt"``).
        sentences: Sized, sliceable collection of strings (e.g. a list or a
            pandas Series; integer slices are taken positionally).
        batch_size: Number of sentences encoded per forward pass.
        shrinkage_fact: Integer divisor applied to ``len(sentences)``.

    Returns:
        A tensor of shape ``(len(sentences) // shrinkage_fact, hidden_size)``
        allocated on the notebook-level ``device``. ``hidden_size`` is read
        from ``model.config.hidden_size`` when available and falls back to
        768 otherwise.

    Raises:
        ValueError: If ``batch_size`` or ``shrinkage_fact`` is smaller than 1.
    """
    if batch_size < 1 or shrinkage_fact < 1:
        raise ValueError("batch_size and shrinkage_fact must be positive integers")

    # Number of sentences that are actually embedded
    n_keep = len(sentences) // shrinkage_fact

    # Take the embedding width from the model instead of assuming 768
    hidden_size = getattr(getattr(model, "config", None), "hidden_size", 768)

    with torch.no_grad():
        # Create the tensor to house the CLS embeddings
        embeddings = torch.zeros((n_keep, hidden_size), device=device)

        # Loop over the sentences in batches
        for start in tqdm_notebook(range(0, n_keep, batch_size)):
            # Cap the final batch at n_keep so no surplus sentences are
            # tokenized/encoded and every row is written exactly once.
            end = min(start + batch_size, n_keep)

            encoded_input = tokenizer(
                list(sentences[start:end]),
                padding=True,
                truncation=True,
                return_tensors="pt",
            ).to(device)
            output = model(**encoded_input)

            # Select the last hidden state of the token `[CLS]` and store it
            embeddings[start:end] = output[0][:, 0, :]

    return embeddings

# Embed the first quarter of the training sentences (shrinkage_fact=4),
# 24 sentences per forward pass.
train_embeddings = get_bert_embedding(
    model,
    tokenizer,
    train.sentence,
    batch_size=24,
    shrinkage_fact=4,
)

Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


  0%|          | 0/980 [00:00<?, ?it/s]

In [3]:
# Same settings as for the training split: first quarter of the dev
# sentences, 24 sentences per forward pass.
dev_embeddings = get_bert_embedding(
    model,
    tokenizer,
    dev.sentence,
    batch_size=24,
    shrinkage_fact=4,
)

  0%|          | 0/245 [00:00<?, ?it/s]

In [4]:
# Sanity check: dev sentences available vs. embeddings actually computed
(len(dev["sentence"]), len(dev_embeddings))

(23500, 5875)

In [5]:
# Sanity check: training sentences available vs. embeddings actually computed
(len(train["sentence"]), len(train_embeddings))

(94000, 23500)

In [6]:
# DIAGNOSTIC CLASSIFIER
from skorch import NeuralNet
from skorch.helper import predefined_split
from skorch.dataset import Dataset
from sklearn.metrics import classification_report
import numpy as np

class LinearDiagnosticClassifier(torch.nn.Module):
    """A single linear layer mapping input vectors to per-class scores.

    No activation or softmax is applied: `forward` returns the raw output of
    the linear layer.
    """

    def __init__(self, input_dim, output_dim):
        """Build the layer for `input_dim` features and `output_dim` classes."""
        super().__init__()

        # Keep the dimensions around for later inspection
        self.input_dim, self.output_dim = input_dim, output_dim

        # The only trainable component of the classifier
        self.layer = torch.nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        """Apply the linear layer to `x` and return the result."""
        scores = self.layer(x)
        return scores

In [7]:
# Two-way mapping between language codes and integer class ids
# (ids follow the sorted order of `langs`).
index_to_language = dict(enumerate(langs))
language_to_index = {lang: idx for idx, lang in index_to_language.items()}

# Integer targets for the train and dev splits
y_train_id = list(map(language_to_index.__getitem__, y_train.lang))
y_dev_id = list(map(language_to_index.__getitem__, y_dev.lang))

# Only a prefix of the dev set was embedded, so trim the targets to match
n_dev_embedded = len(dev_embeddings)
valid_ds = Dataset(dev_embeddings, y_dev_id[:n_dev_embedded])

In [8]:
# Seed PyTorch before the module is created so the linear layer's random
# initialisation (and hence the training curve below) is reproducible on
# "Restart & Run All". 42 matches the seed used for the train/dev split.
torch.manual_seed(42)

net = NeuralNet(
    module=LinearDiagnosticClassifier,
    # Layer sizes are derived from the data instead of being hard-coded:
    # one input per embedding dimension, one output per language.
    module__input_dim=train_embeddings.shape[1],
    module__output_dim=len(langs),
    criterion=torch.nn.CrossEntropyLoss,
    # Validate on the pre-computed dev embeddings rather than an internal split
    train_split=predefined_split(valid_ds),
    max_epochs=200,
    device=device,
    verbose=1,
    lr=0.2,
)

# Only a prefix of the training set was embedded, so trim the targets to match
net.fit(train_embeddings, y_train_id[:len(train_embeddings)])

  epoch    train_loss    valid_loss     dur
-------  ------------  ------------  ------
      1        [36m5.0040[0m        [32m4.6239[0m  0.1776
      2        [36m4.3192[0m        [32m4.0913[0m  0.1495
      3        [36m3.8154[0m        [32m3.6684[0m  0.1485
      4        [36m3.4069[0m        [32m3.3208[0m  0.1565
      5        [36m3.0695[0m        [32m3.0312[0m  0.1512
      6        [36m2.7881[0m        [32m2.7879[0m  0.1495
      7        [36m2.5517[0m        [32m2.5820[0m  0.1508
      8        [36m2.3517[0m        [32m2.4067[0m  0.1566
      9        [36m2.1811[0m        [32m2.2563[0m  0.1519
     10        [36m2.0346[0m        [32m2.1265[0m  0.1479
     11        [36m1.9078[0m        [32m2.0136[0m  0.1487
     12        [36m1.7974[0m        [32m1.9148[0m  0.1527
     13        [36m1.7004[0m        [32m1.8278[0m  0.1489
     14        [36m1.6148[0m        [32m1.7507[0m  0.1618
     15        [36m1.5386[0m        [32m1

    132        [36m0.3269[0m        [32m0.6239[0m  0.1480
    133        [36m0.3250[0m        [32m0.6225[0m  0.1467
    134        [36m0.3230[0m        [32m0.6212[0m  0.1452
    135        [36m0.3211[0m        [32m0.6198[0m  0.1461
    136        [36m0.3192[0m        [32m0.6185[0m  0.1465
    137        [36m0.3174[0m        [32m0.6172[0m  0.1454
    138        [36m0.3155[0m        [32m0.6159[0m  0.1451
    139        [36m0.3137[0m        [32m0.6146[0m  0.1481
    140        [36m0.3119[0m        [32m0.6133[0m  0.1484
    141        [36m0.3101[0m        [32m0.6121[0m  0.1484
    142        [36m0.3084[0m        [32m0.6109[0m  0.1493
    143        [36m0.3066[0m        [32m0.6097[0m  0.1489
    144        [36m0.3049[0m        [32m0.6085[0m  0.1486
    145        [36m0.3032[0m        [32m0.6073[0m  0.1510
    146        [36m0.3015[0m        [32m0.6061[0m  0.1532
    147        [36m0.2999[0m        [32m0.6050[0m  0.1452
    148 

<class 'skorch.net.NeuralNet'>[initialized](
  module_=LinearDiagnosticClassifier(
    (layer): Linear(in_features=768, out_features=235, bias=True)
  ),
)

In [9]:
# Highest-scoring class per dev embedding, mapped back to its language code
dev_y_pred_id = np.argmax(net.predict(dev_embeddings), axis=1)
dev_y_pred = [index_to_language[idx] for idx in dev_y_pred_id]

# Gold labels as a 1-D series, trimmed to the embedded prefix of the dev set
dev_y_true = y_dev.lang.iloc[:len(dev_y_pred)]

# Passing `labels=langs` pins the report rows to the full language list, so
# `target_names` stays aligned even if a language is missing from this subset.
print(
    classification_report(
        dev_y_true,
        dev_y_pred,
        labels=langs,
        target_names=langs,
        zero_division=0,
    )
)

              precision    recall  f1-score   support

         ace       0.83      0.89      0.86        28
         afr       0.95      1.00      0.98        20
         als       0.83      0.76      0.79        25
         amh       0.90      0.97      0.93        29
         ang       0.89      0.83      0.86        30
         ara       0.89      0.89      0.89        28
         arg       1.00      0.93      0.96        27
         arz       0.94      0.88      0.91        33
         asm       0.92      0.96      0.94        25
         ast       0.96      0.96      0.96        27
         ava       0.84      0.78      0.81        27
         aym       0.95      0.95      0.95        22
         azb       1.00      1.00      1.00        25
         aze       1.00      0.92      0.96        25
         bak       1.00      1.00      1.00        29
         bar       0.75      0.84      0.79        25
         bcl       0.84      0.93      0.88        28
   be-tarask       0.59    