# 🏋️ Training sentence_transformers using ČTK data
Shared notebook version 1.0

In [1]:
import sys

# Put ../src on the import path so project-local modules can be imported
# (presumably where `datautils` lives — confirm).
sys.path.append("../src")

## 📑 Import Clauses

In [2]:
import json, logging, math, os, pickle, gc
from collections import Counter, OrderedDict
from os.path import join as pjoin

import numpy as np
import sklearn
import torch
from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CESoftmaxAccuracyEvaluator
from sentence_transformers.evaluation import (
    SequentialEvaluator,
)
from sentence_transformers.readers import InputExample
from torch.utils.data import DataLoader

import datautils
from datautils import LABEL_NUM, LABEL_STR

# Without a handler and an explicit level, the logger.info(...) progress
# messages emitted by the training cell are dropped, because the inherited
# default threshold is WARNING.
# basicConfig attaches a stream handler to the root logger only if it has none.
logging.basicConfig(format="%(asctime)s %(levelname)s %(name)s: %(message)s")
logger = logging.getLogger(__name__)
# Raise verbosity for this notebook's logger only; the root level is untouched,
# so third-party libraries do not start flooding the cell output.
logger.setLevel(logging.INFO)

## ⚓ Load a dataset (see [Dataset NB](datasets.ipynb))

In [3]:
# Load the pickled train / test / validation splits produced by the dataset notebook.
(trn_examples, tst_examples, val_examples) = datautils.load_examples_from_pickle(
    "../data/demo_splits/pickle"
)
# Display the label distribution of each split (train, test, validation order).
list(map(datautils.counter, (trn_examples, tst_examples, val_examples)))

[[('NOT ENOUGH INFO', 1021, 0.2815774958632101),
  ('REFUTES', 851, 0.23469387755102042),
  ('SUPPORTS', 1754, 0.48372862658576943)],
 [('NOT ENOUGH INFO', 177, 0.36721991701244816),
  ('REFUTES', 114, 0.23651452282157676),
  ('SUPPORTS', 191, 0.3962655601659751)],
 [('NOT ENOUGH INFO', 183, 0.3279569892473118),
  ('REFUTES', 115, 0.2060931899641577),
  ('SUPPORTS', 260, 0.4659498207885305)]]

In [4]:
# Peek at the first training and the first validation example: texts and label.
(
    trn_examples[0].texts,
    trn_examples[0].label,
    val_examples[0].texts,
    val_examples[0].label,
)

(['PRAHA 18. června (ČTK) - Rekordní teploty 19. června (od roku 1775 měřené v pražském Klementinu) byly následující: nejvyšší teplota 31,2 z roku 1917 a 1934, nejnižší teplota 7,3 z roku 1985\\. Dlouhodobý průměrný normál: 17,9 stupně Celsia.',
  'Rekordní teploty se od roku 1775 měří v Praze.'],
 0,
 ['České ministerstvo životního prostředí v říjnu oznámilo, že povolilo průzkum pro trvalé úložiště radioaktivního odpadu. Geologické průzkumné práce a posuzování vhodnosti hlubinných úložišť se týká celkem sedmi lokalit, o něž žádala Správa úložišť radioaktivních odpadů (SÚRAO). „V zásadě jsme neměli žádné zákonné důvody pro neudělení povolení,“ uvedl Brabec s tím, že jde o předběžný průzkum, mapování lokalit. „Vyjádření z rakouské strany beru v úvahu, všichni víme, že Rakušané se staví proti jaderným elektrárnám,“ uvedl Brabec s tím, že pro diskusi bude čas a se svým rakouským protějškem ji už zahájil.',
  'České ministerstvo životního prostředí zakázalo průzkum pro trvalé úložiště radi

## 📂 Prepare an output directory

In [6]:
# Root directory for training artefacts; the training cell creates one
# sub-directory per model name underneath it.
outdir = "../models"

## 📅 Schedule a bunch of training jobs!
Set the parameters of each job in its own `if` branch (remove the branches you do not need) and adjust the iterated `range` to skip jobs.

In [11]:
# Keep results from earlier runs of this cell; start a fresh dict only on the
# first run. Defining it before the loop also guarantees `evals` exists even
# when the range below is empty.
evals = globals().get("evals", {})

for i in range(2, 3):
    # --- per-job hyper-parameters -------------------------------------------
    # Other checkpoints noted alongside these jobs:
    #   "DeepPavlov/bert-base-multilingual-cased-sentence",
    #   "bert-base-multilingual-cased", "deepset/xlm-roberta-large-squad2"
    if i == 0:
        bert_name = bert_name_short = "deepset/xlm-roberta-large-squad2"
        max_length = None
        batch_size = 12
        num_epochs = 30
    elif i == 1:
        bert_name = bert_name_short = "DeepPavlov/bert-base-multilingual-cased-sentence"
        max_length = 512
        batch_size = 7
        num_epochs = 30
    elif i == 2:
        bert_name = bert_name_short = "deepset/xlm-roberta-large-squad2"
        max_length = None
        batch_size = 7
        num_epochs = 30
    else:
        # Fail loudly instead of silently reusing the previous job's settings.
        raise ValueError(f"no training job is configured for i={i}")
    model_name = f"{bert_name_short}_bs{batch_size}"

    # --- output directory and a snapshot of the data used --------------------
    output_path = pjoin(outdir, model_name)
    os.makedirs(output_path, exist_ok=True)
    logger.info(f"output path: {output_path}")
    for split_name, split_examples in (
        ("trn", trn_examples),
        ("tst", tst_examples),
        ("val", val_examples),
    ):
        # `with` closes each file handle; it is not left to the garbage collector.
        with open(pjoin(output_path, f"{split_name}_examples.p"), "wb") as split_file:
            pickle.dump(split_examples, split_file)

    # Store the job configuration next to the model.
    cfg = OrderedDict(
        [
            ("bert_name", bert_name),
            ("bert_name_short", bert_name_short),
            ("batch_size", batch_size),
            ("max_length", max_length),
        ]
    )

    with open(pjoin(output_path, "rteconfig.json"), "w") as outfile:
        outfile.write(json.dumps(cfg, indent=3))

    # --- data loaders and evaluators -----------------------------------------
    trn_dataloader = DataLoader(trn_examples, shuffle=True, batch_size=batch_size)
    # The validation/test loaders are not consumed in this cell; the evaluators
    # below are built from the example lists directly.
    val_dataloader = DataLoader(val_examples, shuffle=False, batch_size=batch_size)
    tst_dataloader = DataLoader(tst_examples, shuffle=False, batch_size=batch_size)

    trn_evaluator = CESoftmaxAccuracyEvaluator.from_input_examples(trn_examples, name="train")
    val_evaluator = CESoftmaxAccuracyEvaluator.from_input_examples(val_examples, name="validation")
    tst_evaluator = CESoftmaxAccuracyEvaluator.from_input_examples(tst_examples, name="test")

    # Warm up over the first 10% of all training steps
    # (batches per epoch * number of epochs).
    warmup_steps = math.ceil(len(trn_dataloader) * num_epochs * 0.1)
    logger.info(f"warmup_steps: {warmup_steps}")

    # --- training -------------------------------------------------------------
    model = CrossEncoder(bert_name, num_labels=3, max_length=max_length)

    def cb(score, epoch, steps):
        """Log the evaluation score and flag it when it beats the model's best so far."""
        logger.info(f"E{epoch}: score: {score}")
        if score > model.best_score:
            logger.info(f"new best model for score: {score}")

    model.fit(
        train_dataloader=trn_dataloader,
        epochs=num_epochs,
        warmup_steps=warmup_steps,
        evaluator=SequentialEvaluator([trn_evaluator, val_evaluator]),
        output_path=output_path,
        callback=cb,
        save_best_model=True,
    )

    # --- test-set evaluation --------------------------------------------------
    # Reload from output_path (where fit() was asked to save the best model) and
    # evaluate on the test split exactly once; the result is kept in `evals`.
    model = CrossEncoder(output_path, max_length=max_length)
    evals[output_path] = tst_evaluator(model, output_path=output_path)

Some weights of the model checkpoint at deepset/xlm-roberta-large-squad2 were not used when initializing XLMRobertaForSequenceClassification: ['qa_outputs.weight', 'qa_outputs.bias', 'roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing XLMRobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLMRobertaForSequenceClassification were not initialized from the model checkpoint at deepset/xlm-roberta-large-squad2 and are newly initialized: ['classifier.dense.weight', 'classifier.dense.bias', 'classifier.out_proj.weight', 'classi

HBox(children=(FloatProgress(value=0.0, description='Epoch', max=30.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=518.0, style=ProgressStyle(description_wi…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=518.0, style=ProgressStyle(description_wi…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=518.0, style=ProgressStyle(description_wi…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=518.0, style=ProgressStyle(description_wi…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=518.0, style=ProgressStyle(description_wi…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=518.0, style=ProgressStyle(description_wi…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=518.0, style=ProgressStyle(description_wi…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=518.0, style=ProgressStyle(description_wi…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=518.0, style=ProgressStyle(description_wi…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=518.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=518.0, style=ProgressStyle(description_wi…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=518.0, style=ProgressStyle(description_wi…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=518.0, style=ProgressStyle(description_wi…

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)






HBox(children=(FloatProgress(value=0.0, description='Iteration', max=518.0, style=ProgressStyle(description_wi…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=518.0, style=ProgressStyle(description_wi…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=518.0, style=ProgressStyle(description_wi…

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



### 🤯 Out of memory? Free some!

In [10]:
# Collect garbage first, so tensors held only by unreachable objects go back to
# the CUDA caching allocator, then release the cached blocks. In the opposite
# order, memory freed by the collector would stay cached.
n_unreachable = gc.collect()
torch.cuda.empty_cache()
# Keep the number of collected objects as the cell's displayed value.
n_unreachable

6805

## 📜 How did your models do?

In [19]:
# Test-set evaluator results keyed by model output path, as filled in by the training loop.
evals