In [1]:
%%capture

from os import environ

# Set this to `True` to train on Colab
ON_CLOUD = False

if ON_CLOUD:
    from getpass import getpass
    from urllib.parse import quote

    token = getpass('GitHub token: ')
    token = quote(token)
    environ["GITHUB_TOKEN"] = token
    ! if [ -d gulag ]; then rm -Rf gulag; fi
    ! git clone https://$GITHUB_TOKEN@github.com/SpirinEgor/gulag.git
    %cd gulag
    ! pip install -q -r requirements.txt
else:
    %cd ..

In [2]:
# Third-party configuration library; bindings are loaded from a .gin file in the next cell.
import gin

# Project-local modules; they resolve because the first cell changed the
# working directory to the repository root.
from src.data import MultiLanguageClassificationDataModule
from src.main import train
from src.utils import setup_logging

# Configure logging through the project helper (implementation lives in src.utils).
setup_logging()

In [3]:
# Load the debug bindings into gin's global configuration.
config_path = "config/debug.gin"
gin.parse_config_file(config_path)

# Echo the parsed configuration so the run's settings are recorded in the notebook.
config_text = gin.config_str()
print(config_text)

import gin.torch.external_configurables
import src.utils

# Parameters for AdamW:
AdamW.lr = 0.001
AdamW.weight_decay = 0.0

# Parameters for configure_optimizers:
configure_optimizers.optimizer_cls = @AdamW
configure_optimizers.scheduler_cls = @LambdaLR

# Parameters for generate_eval_samples:
generate_eval_samples.n_samples = 100

# Parameters for generate_example:
generate_example.max_langs = 5
generate_example.max_samples_per_lang = 5
generate_example.max_seq_len = 256
generate_example.min_langs = 1

# Parameters for LambdaLR:
LambdaLR.lr_lambda = @rsqrt_with_warmup

# Parameters for MultiLanguageClassificationDataModule:
MultiLanguageClassificationDataModule.batch_size = 8
MultiLanguageClassificationDataModule.languages = ('ru', 'uk', 'be')
MultiLanguageClassificationDataModule.tokenizer_name = \
    'bert-base-multilingual-cased'
MultiLanguageClassificationDataModule.val_batch_size = 16

# Parameters for MultiLanguageClassifier:
MultiLanguageClassifier.embedder_name = 'bert-base-

# Data overview

Some examples from the synthetic dataset.

In [4]:
# Build the data module used for the preview below. `batch_size=1` is passed
# explicitly (the gin config above binds 8), so each training batch holds a
# single example; the remaining arguments are presumably supplied by the gin
# bindings -- TODO confirm.
data_module = MultiLanguageClassificationDataModule(batch_size=1)
# Prepare the datasets; the recorded log shows 'wikiann' being opened for ru, uk, be.
data_module.setup()

INFO:src.data.data_module:Downloading and opening 'wikiann' dataset for ru, uk, be
INFO:datasets.info:Loading Dataset Infos from /Users/Egor.Spirin/.cache/huggingface/modules/datasets_modules/datasets/wikiann/4bfd4fe4468ab78bb6e096968f61fab7a888f44f9d3371c2f3fea7e74a5a354e
INFO:datasets.builder:Overwrite dataset info from restored data version.
INFO:datasets.info:Loading Dataset info from /Users/Egor.Spirin/.cache/huggingface/datasets/wikiann/ru/1.1.0/4bfd4fe4468ab78bb6e096968f61fab7a888f44f9d3371c2f3fea7e74a5a354e
INFO:datasets.info:Loading Dataset info from /Users/Egor.Spirin/.cache/huggingface/datasets/wikiann/ru/1.1.0/4bfd4fe4468ab78bb6e096968f61fab7a888f44f9d3371c2f3fea7e74a5a354e
INFO:datasets.info:Loading Dataset Infos from /Users/Egor.Spirin/.cache/huggingface/modules/datasets_modules/datasets/wikiann/4bfd4fe4468ab78bb6e096968f61fab7a888f44f9d3371c2f3fea7e74a5a354e
INFO:datasets.builder:Overwrite dataset info from restored data version.
INFO:datasets.info:Loading Dataset info f

In [5]:
# Print the first few training batches: raw token ids, the per-token language
# labels, and the text decoded back from the tokens.
N_PREVIEW = 5
separator = '=' * 20

for batch_idx, (seq, _, labels) in enumerate(data_module.train_dataloader()):
    if batch_idx == N_PREVIEW:
        break

    # Batches hold a single example here, so index 0 is the whole batch.
    token_ids = seq[0]
    n_tokens = seq.shape[-1]
    class_names = data_module.decode_languages(labels[0])
    orig_str = data_module.tokenizer.decode(token_ids)

    print(f"""{separator}
Input tokens (len = {n_tokens}):
{token_ids}
Target classes:
{class_names}
Original string:
{orig_str}
{separator}
""")

Input tokens (len = 44):
tensor([   101,    528,  44148,  73899,  15966,  61381,  33191,  16183,  11384,
         10122,  12634,  15535,  18971,  43514,    547,  12861,  92596,  21979,
           587,  79524,  10823,  12601,  10593,  40705,  10234,  88215,  28385,
         96195,  43067,  16848,  44392,  96195,  67922, 106072,  16481,    552,
         75238,  23444,  20060,  19147,  40705,  33018, 101470,    102])
Target classes:
['[NOT LANG]', 'be', 'be', 'be', 'be', 'be', 'be', 'be', 'be', 'be', 'be', 'be', 'be', 'be', 'be', 'be', 'be', 'be', 'be', 'be', 'be', 'be', 'be', 'be', 'be', 'be', 'be', 'be', 'be', 'be', 'be', 'be', 'be', 'be', 'be', 'be', 'be', 'be', 'be', 'be', 'be', 'be', 'be', '[NOT LANG]']
Original string:
[CLS] У тым жа годзе пераехаў на пастаяннае жыхарства ў Англію дзе займаўся выкладчыцкай выдавецкай літаратурнай дзейнасцю [SEP]

Input tokens (len = 44):
tensor([   101,    524,  45224,  18005,  45224,  25298,  16847,  16027,  18005,
         16847,  25956,  32476,  

# Train model

In [6]:
# Run the project's training entry point (src.main.train), called here with no
# arguments. Per the recorded output below, the run sets a global seed, logs to
# Weights & Biases, alternates training with validation passes, and ends with a
# test pass that reports test/f1 and test/loss.
train()

Global seed set to 7
[34m[1mwandb[0m: Currently logged in as: [33mvoudy[0m (use `wandb login --relogin` to force relogin)


INFO:torch.distributed.nn.jit.instantiator:Created a temporary directory at /var/folders/vw/cn7lrm9j7bvd1rdyfymttkrh0000kt/T/tmptctdy_b1
INFO:torch.distributed.nn.jit.instantiator:Writing /var/folders/vw/cn7lrm9j7bvd1rdyfymttkrh0000kt/T/tmptctdy_b1/_remote_module_non_sriptable.py
loading configuration file https://huggingface.co/bert-base-multilingual-cased/resolve/main/config.json from cache at /Users/Egor.Spirin/.cache/huggingface/transformers/6c4a5d81a58c9791cdf76a09bce1b5abfb9cf958aebada51200f4515403e5d08.0fe59f3f4f1335dadeb4bce8b8146199d9083512b50d07323c1c319f96df450c
Model config BertConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "directionality": "bidi",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
 

Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  rank_zero_warn(
INFO:src.data.data_module:Downloading and opening 'wikiann' dataset for ru, uk, be
INFO:datasets.info:Loading Dataset Infos from /Users/Egor.Spirin/.cache/huggingface/modules/datasets_modules/datasets/wikiann/4bfd4fe4468ab78bb6e096968f61fab7a888f44f9d3371c2f3fea7e74a5a354e
INFO:datasets.builder:Overwrite dataset info from restored data version.
INFO:datasets.info:Loading Dataset info from /Users/Egor.Spirin/.cache/huggingface/datasets/wikiann/ru/1.1.0/4bfd4fe4468ab78bb6e096968f61fab7a888f44f9d3371c2f3fea7e74a5a354e
INFO:datasets.info:Loading Dataset info from /Users/Egor.Spirin/.cache/huggingface/datasets/wikiann/ru/1.1.0/4bfd4fe4468ab78bb6e096968f61fab7a888f44f9d3371c2f3fea7e74a5a354e
INFO:datasets.info:Loading Dataset Infos from /Users/Egor.Spirin/.cache/huggingface/modules/datasets_modules/datasets/wikiann/4bfd4fe4468ab78bb6e096968f61fab7a888f44f9d3371c2f3fea7e74a5a354e
INFO:datasets.builder:Overwrite dataset info from restored data version.
INFO:datasets.info:Load

Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
         test/f1            0.8426086902618408
        test/loss           0.4306808114051819
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
