In [1]:
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import pickle as pk

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# input
checkpoint_dir = "./models/"

model_name = 'distilbert-base-multilingual-cased'
tokenizer = AutoTokenizer.from_pretrained(model_name)

data_dir = "./data/"
output_dir = "./data/output/"
device_name = "cuda" if torch.cuda.is_available() else "cpu"
num_labels = 2

print(device_name)

LANGS = ['en', 'de', 'ja', 'es', 'fr', 'zh']


cuda


In [3]:
from sklearn.metrics import classification_report, roc_auc_score
import scipy.stats as stats

def batch(iterable, n=1):
    l = len(iterable)
    for ndx in range(0, l, n):
        yield iterable[ndx:min(ndx + n, l)]

In [4]:
# next step: could try to implement using pipeline: https://huggingface.co/docs/transformers/v4.25.1/en/main_classes/pipelines#transformers.pipeline
class Evaluator:
    def __init__(self, model, tokenizer, eval_lang):
        self.model = model
        self.model.eval() # set to evaluation mode
        self.tokenizer = tokenizer
        self.eval_lang = eval_lang
        self.text_input = None
        self.true_labels = None
        
        self.load_data()
        self.logits = None
        
    def load_data(self):
        data = pk.load(open(data_dir + f"clean_{self.eval_lang}_test.pk", "rb"))
        self.text_input = data['text']
        self.true_labels = data['binary_labels']
        
    def make_prediction(self):
        logits = None
        with torch.no_grad():
            for batched_inputs in batch(self.text_input, 20):
                encoding = self.tokenizer(batched_inputs, 
                                          truncation=True, 
                                          padding=True, 
                                          return_tensors='pt').to(device_name)
                batched_logits = self.model(**encoding)[0]
                if logits is None:
                    logits = batched_logits
                else:
                    logits = torch.cat((logits, batched_logits), dim=0)
        self.logits = logits
        return logits
    
    def evaluate(self):
        if self.logits is None:
            raise Exception("Run prediction first!")
        _, predicted_labels = torch.max(self.logits, dim=1)
        predicted_labels = predicted_labels.tolist()
        report = classification_report(predicted_labels, self.true_labels, output_dict=True)
        report['kendall-tau'], report['kendall-tau-p-value'] = stats.kendalltau(self.true_labels, self.logits[:, 1].tolist())
        report['auc'] = roc_auc_score(self.true_labels, self.logits[:, 1].tolist())
        return report

In [5]:
model_lang = 'zh'
checkpoint_name = f"mdistilbert-{model_lang}-epoch-6" 
print(f"init model {checkpoint_name}...")
model = AutoModelForSequenceClassification.from_pretrained(checkpoint_dir + checkpoint_name, 
                                                           num_labels=num_labels).to(device_name)
pred_scores = dict()
eval_reports = dict()
for lang in LANGS:
    ev = Evaluator(model, tokenizer, lang)
    pred_scores[lang] = ev.make_prediction()
    eval_reports[lang] = ev.evaluate()
torch.save(pred_scores, output_dir + f"{checkpoint_name}_scores.pt")
pk.dump(eval_reports, open(output_dir + f"{checkpoint_name}_eval_reports.pk", "wb"))

init model mdistilbert-zh-epoch-6...


In [5]:
for model_lang in [
    'en', 'fr', 'zh', 'de', 'ja', "es"
]:
    checkpoint_name = f"mdistilbert-{model_lang}" 
    print(f"init model {checkpoint_name}...")
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint_dir + checkpoint_name, 
                                                               num_labels=num_labels).to(device_name)
    pred_scores = dict()
    eval_reports = dict()
    for lang in LANGS:
        ev = Evaluator(model, tokenizer, lang)
        pred_scores[lang] = ev.make_prediction()
        eval_reports[lang] = ev.evaluate()
    torch.save(pred_scores, output_dir + f"{checkpoint_name}_scores.pt")
    pk.dump(eval_reports, open(output_dir + f"{checkpoint_name}_eval_reports.pk", "wb"))

init model mdistilbert-es...


## study

In [9]:
import pandas as pd


metrics_df = None
metrics = ['accuracy', 'auc', 'kendall-tau']

In [6]:
for model_lang in LANGS:
    checkpoint_name = f'mdistilbert-{model_lang}'
    eval_reports = pk.load(open(output_dir + f"{checkpoint_name}_eval_reports.pk", "rb"))
    df = pd.DataFrame.from_dict(eval_reports, orient='index')
    df['model_lang'] = model_lang
    df = df.reset_index().rename({"index": "test_lang"}, axis=1)
    if metrics_df is None:
        metrics_df = df[['model_lang', 'test_lang', *metrics]].copy()
    else:
        metrics_df = pd.concat(
            (
                metrics_df, 
                df[['model_lang', 'test_lang', *metrics]]
            )
        )
        
        

In [7]:
pd.pivot(metrics_df, index='model_lang', 
         columns = 'test_lang', 
         values = metrics)

Unnamed: 0_level_0,accuracy,accuracy,accuracy,accuracy,accuracy,accuracy,auc,auc,auc,auc,auc,auc,kendall-tau,kendall-tau,kendall-tau,kendall-tau,kendall-tau,kendall-tau
test_lang,de,en,es,fr,ja,zh,de,en,es,fr,ja,zh,de,en,es,fr,ja,zh
model_lang,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2
de,0.888,0.7936,0.7964,0.7424,0.5832,0.6838,0.954805,0.858742,0.862021,0.813616,0.702889,0.744877,0.630259,0.497137,0.501681,0.434602,0.281159,0.339346
en,0.7116,0.8838,0.7268,0.762,0.586,0.7026,0.757209,0.951686,0.850235,0.844972,0.731073,0.756772,0.356435,0.625938,0.485349,0.478055,0.320217,0.35583
es,0.7334,0.7044,0.8824,0.7726,0.493,0.6892,0.803325,0.80857,0.949472,0.860608,0.671722,0.728899,0.420342,0.42761,0.622869,0.499723,0.237969,0.317204
fr,0.7394,0.7226,0.7808,0.8856,0.5962,0.684,0.808079,0.83705,0.873063,0.947676,0.675161,0.735325,0.426929,0.467076,0.516983,0.62038,0.242735,0.326108
ja,0.6472,0.7132,0.628,0.6242,0.8676,0.7066,0.719193,0.789165,0.724507,0.711201,0.935959,0.749267,0.303753,0.400719,0.311117,0.292678,0.604143,0.34543
zh,0.5848,0.6946,0.5976,0.5094,0.6572,0.8412,0.73422,0.812382,0.808288,0.760685,0.730695,0.912239,0.324577,0.432892,0.427219,0.361253,0.319692,0.571272


In [10]:
checkpoint_name = f'mdistilbert-{model_lang}'
for checkpoint_name in ['mdistilbert-zh', 'mdistilbert-zh-epoch-6']:
    eval_reports = pk.load(open(output_dir + f"{checkpoint_name}_eval_reports.pk", "rb"))
    df = pd.DataFrame.from_dict(eval_reports, orient='index')
    df['checkpoint'] = checkpoint_name
    df = df.reset_index().rename({"index": "test_lang"}, axis=1)
    if metrics_df is None:
        metrics_df = df[['checkpoint', 'test_lang', *metrics]].copy()
    else:
        metrics_df = pd.concat(
            (
                metrics_df, 
                df[['checkpoint', 'test_lang', *metrics]]
            )
        )

In [12]:
pd.pivot(metrics_df, index='checkpoint', 
         columns = 'test_lang', 
         values = metrics)

Unnamed: 0_level_0,accuracy,accuracy,accuracy,accuracy,accuracy,accuracy,auc,auc,auc,auc,auc,auc,kendall-tau,kendall-tau,kendall-tau,kendall-tau,kendall-tau,kendall-tau
test_lang,de,en,es,fr,ja,zh,de,en,es,fr,ja,zh,de,en,es,fr,ja,zh
checkpoint,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2
mdistilbert-zh,0.5848,0.6946,0.5976,0.5094,0.6572,0.8412,0.73422,0.812382,0.808288,0.760685,0.730695,0.912239,0.324577,0.432892,0.427219,0.361253,0.319692,0.571272
mdistilbert-zh-epoch-6,0.5642,0.7026,0.607,0.5444,0.6208,0.8344,0.651526,0.763,0.74525,0.722617,0.655869,0.903355,0.209982,0.364459,0.339862,0.308498,0.216,0.55896
