In [1]:
from collections import Counter
from math import sqrt

import pandas as pd
import numpy as np
import plotly.express as ex
from scipy.spatial.distance import jensenshannon
from sklearn.metrics import confusion_matrix
from joblib import Parallel, delayed
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
import torch

def load_ds(path: str):
    """Lazily yield every line of the UTF-8 text file at ``path``.

    Each yielded string has its trailing newline removed; all other
    whitespace is kept. The file is opened on the first ``next()`` and
    closed once the generator is exhausted.
    """
    with open(path, encoding="utf8") as handle:
        yield from (raw_line.rstrip("\n") for raw_line in handle)

# WiLI-2018 files: x_*.txt hold one sentence per line, y_*.txt the matching
# language label on the same line number.
DATA_DIR = "data/wili-2018"

x_train = pd.DataFrame(load_ds(f"{DATA_DIR}/x_train.txt"), columns=["sentence"])
y_train = pd.DataFrame(load_ds(f"{DATA_DIR}/y_train.txt"), columns=["lang"])
x_test = pd.DataFrame(load_ds(f"{DATA_DIR}/x_test.txt"), columns=["sentence"])
y_test = pd.DataFrame(load_ds(f"{DATA_DIR}/y_test.txt"), columns=["lang"])

# pd.concat(axis=1) below aligns rows on the index and would silently pad with
# NaN if a sentence file and its label file had different line counts.
assert len(x_train) == len(y_train), "x_train / y_train line counts differ"
assert len(x_test) == len(y_test), "x_test / y_test line counts differ"

# Hold out 20% of the official training data as a dev split, stratified by
# language so each language keeps (approximately) the same share in both parts.
x_train, x_dev, y_train, y_dev = train_test_split(
    x_train, y_train, test_size=0.2, random_state=42, stratify=y_train
)

train = pd.concat([x_train, y_train], axis=1)
dev = pd.concat([x_dev, y_dev], axis=1)
test = pd.concat([x_test, y_test], axis=1)
langs = sorted(y_train.lang.unique())
chars = set(c for s in train.sentence for c in s)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [2]:
from transformers import BertTokenizer, BertModel
from tqdm.notebook import tqdm_notebook

# The tokenizer and model must come from the same checkpoint (they share a
# vocabulary), so the checkpoint is named exactly once.
BERT_CHECKPOINT = "bert-base-multilingual-cased"
tokenizer = BertTokenizer.from_pretrained(BERT_CHECKPOINT)
model = BertModel.from_pretrained(BERT_CHECKPOINT).to(device)

def get_bert_embedding(model, tokenizer, sentences, batch_size=4, shrinkage_fact=1, hidden_size=768):
    """Return the final-layer ``[CLS]`` vector for (a prefix of) ``sentences``.

    Args:
        model: encoder whose first output (``output[0]``) is the last hidden
            state, shaped (batch, seq, hidden). It is presumably in eval mode;
            this function does not call ``model.eval()`` itself.
        tokenizer: tokenizer matching ``model``.
        sentences: any iterable of strings (list, pandas Series, ...). It is
            materialised to a list, so rows are always taken in iteration
            order regardless of any index labels.
        batch_size: number of sentences encoded per forward pass.
        shrinkage_fact: only the first ``len(sentences) // shrinkage_fact``
            sentences are embedded (useful for quick experiments).
        hidden_size: width of the model's hidden state (768 for BERT-base).

    Returns:
        A float tensor of shape ``(len(sentences) // shrinkage_fact, hidden_size)``
        on the global ``device``; row ``k`` is the ``[CLS]`` embedding of the
        k-th sentence.
    """
    # Materialise once so slicing below is purely positional, even for a pandas
    # Series whose index was shuffled by train_test_split.
    sentences = list(sentences)
    n_keep = len(sentences) // shrinkage_fact

    with torch.no_grad():
        # Allocate directly on the target device instead of building on CPU and copying.
        embeddings = torch.zeros((n_keep, hidden_size), device=device)

        for start in tqdm_notebook(range(0, n_keep, batch_size)):
            # Clamp the final (possibly partial) batch so nothing beyond the
            # first n_keep sentences is encoded.
            stop = min(start + batch_size, n_keep)
            encoded_input = tokenizer(
                sentences[start:stop], padding=True, truncation=True, return_tensors="pt"
            ).to(device)
            output = model(**encoded_input)

            # Last hidden state, token position 0 (the [CLS] token); the
            # result has exactly stop - start rows.
            embeddings[start:stop] = output[0][:, 0, :]

    return embeddings

# [CLS] embeddings for every training sentence (shrinkage_fact=1 keeps them all).
train_embeddings = get_bert_embedding(
    model=model,
    tokenizer=tokenizer,
    sentences=train.sentence,
    batch_size=24,
    shrinkage_fact=1,
)

Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


  0%|          | 0/3917 [00:00<?, ?it/s]

In [3]:
# [CLS] embeddings for every dev sentence (shrinkage_fact=1 keeps them all).
dev_embeddings = get_bert_embedding(
    model=model,
    tokenizer=tokenizer,
    sentences=dev.sentence,
    batch_size=24,
    shrinkage_fact=1,
)

  0%|          | 0/980 [00:00<?, ?it/s]

In [4]:
# Dev rows vs. embedding rows produced -- these should be equal.
len(dev), dev_embeddings.shape[0]

(23500, 23500)

In [5]:
# Train rows vs. embedding rows produced -- these should be equal.
len(train), train_embeddings.shape[0]

(94000, 94000)

In [6]:
# DIAGNOSTIC CLASSIFIER
from skorch import NeuralNet
from skorch.helper import predefined_split
from skorch.dataset import Dataset
from sklearn.metrics import classification_report
import numpy as np

class LinearDiagnosticClassifier(torch.nn.Module):
    """Linear probe: one affine map from a feature vector to per-class scores.

    Args:
        input_dim: size of each input feature vector.
        output_dim: number of output scores (one per class).
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Remember the dimensions for later inspection.
        self.input_dim, self.output_dim = input_dim, output_dim
        # The only trainable parameters: an (output_dim x input_dim) weight plus a bias.
        self.layer = torch.nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        """Map ``x`` of shape (..., input_dim) to scores of shape (..., output_dim)."""
        scores = self.layer(x)
        return scores

In [7]:
# Class id of a language = its position in `langs`; both directions are kept.
index_to_language = dict(enumerate(langs))
language_to_index = {code: idx for idx, code in index_to_language.items()}

y_dev_id = [language_to_index[code] for code in y_dev.lang]
y_train_id = [language_to_index[code] for code in y_train.lang]
# Fixed validation set for skorch; labels are trimmed to however many dev rows were embedded.
valid_ds = Dataset(dev_embeddings, y_dev_id[: len(dev_embeddings)])

In [12]:
# Seed torch so the probe's weight initialisation (which happens inside fit) is
# reproducible, consistent with random_state=42 used for the train/dev split.
torch.manual_seed(42)

net = NeuralNet(
    module=LinearDiagnosticClassifier,
    # Shape the probe from the data rather than hard-coding it: embedding width
    # in, one score per language out (ids follow the order of `langs`).
    module__input_dim=train_embeddings.shape[1],
    module__output_dim=len(langs),
    criterion=torch.nn.CrossEntropyLoss,
    # Validate on the fixed dev split instead of an internal split.
    train_split=predefined_split(valid_ds),
    max_epochs=50,
    device=device,
    verbose=1,
    lr=0.2,
)

net.fit(train_embeddings, y_train_id[: len(train_embeddings)])

  epoch    train_loss    valid_loss     dur
-------  ------------  ------------  ------
      1        [36m4.1736[0m        [32m3.3023[0m  0.5826
      2        [36m2.7635[0m        [32m2.3757[0m  0.5837
      3        [36m2.0632[0m        [32m1.8770[0m  0.5883
      4        [36m1.6676[0m        [32m1.5785[0m  0.5986
      5        [36m1.4199[0m        [32m1.3825[0m  0.5856
      6        [36m1.2512[0m        [32m1.2440[0m  0.5859
      7        [36m1.1286[0m        [32m1.1407[0m  0.5871
      8        [36m1.0349[0m        [32m1.0603[0m  0.5794
      9        [36m0.9606[0m        [32m0.9957[0m  0.5791
     10        [36m0.9000[0m        [32m0.9425[0m  0.5787
     11        [36m0.8494[0m        [32m0.8978[0m  0.5793
     12        [36m0.8064[0m        [32m0.8596[0m  0.5788
     13        [36m0.7692[0m        [32m0.8265[0m  0.5778
     14        [36m0.7367[0m        [32m0.7975[0m  0.5775
     15        [36m0.7080[0m        [32m0

<class 'skorch.net.NeuralNet'>[initialized](
  module_=LinearDiagnosticClassifier(
    (layer): Linear(in_features=768, out_features=235, bias=True)
  ),
)

In [13]:
# net.predict yields one score per language for each dev row; the predicted
# class id is the arg-max over that language axis.
dev_scores = net.predict(dev_embeddings)
dev_y_pred_id = np.argmax(dev_scores, axis=1)
dev_y_pred = [index_to_language[class_id] for class_id in dev_y_pred_id]

# Gold labels for exactly the rows that were predicted (explicitly positional).
dev_y_true = y_dev.lang.iloc[: len(dev_y_pred)]
# Pin `labels` so the report has one row per language, aligned with
# `target_names`, even when a language is absent from the evaluated subset.
print(classification_report(dev_y_true, dev_y_pred, labels=langs, target_names=langs, zero_division=0))

              precision    recall  f1-score   support

         ace       0.91      0.97      0.94       100
         afr       0.97      0.99      0.98       100
         als       0.76      0.81      0.78       100
         amh       0.82      0.83      0.83       100
         ang       0.95      0.90      0.92       100
         ara       0.92      0.99      0.95       100
         arg       0.99      0.99      0.99       100
         arz       0.99      0.93      0.96       100
         asm       0.96      0.97      0.97       100
         ast       0.93      0.99      0.96       100
         ava       0.84      0.80      0.82       100
         aym       0.90      0.82      0.86       100
         azb       1.00      1.00      1.00       100
         aze       1.00      0.99      0.99       100
         bak       0.98      0.98      0.98       100
         bar       0.88      0.86      0.87       100
         bcl       0.92      0.92      0.92       100
   be-tarask       0.70    