# Prepare: imports, environment and configuration

In [1]:
# imports
import os
import argparse
import json

import torch
import pytorch_lightning as pl
import torchmetrics
import transformers

from utils import (
    PersonaDataset,
    GenerativeCollator,
    RetrievalCollator,
    aggregate_encoder_output,
    sim_func,
)
from models import RetrievalModel, GenerativeModel

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# proxy
# Route HTTP/HTTPS/FTP traffic (e.g. model downloads, Comet logging) through the proxy.
# The URL can be overridden via NOTEBOOK_PROXY; proxy variables the user already
# exported are respected instead of being silently overwritten.
PROXY_URL = os.environ.get("NOTEBOOK_PROXY", "http://proxy.ad.speechpro.com:3128")
for _proxy_var in ("http_proxy", "https_proxy", "ftp_proxy"):
    os.environ.setdefault(_proxy_var, PROXY_URL)

In [3]:
# config bert
# Build an argparse.Namespace from the JSON config so fields read as bert_args.<name>.
# parse_args([]) yields an empty namespace (no CLI parsing inside a notebook).
parser = argparse.ArgumentParser()
bert_args = parser.parse_args([])
# Explicit UTF-8: relying on the locale default encoding (e.g. cp1251 on Windows)
# would misread any non-ASCII text in the config file.
with open("configs/bert_config.json", "r", encoding="utf-8") as config:
    opt = json.load(config)
vars(bert_args).update(opt)

# Pretrained tokenizer and model

In [4]:
# bert tokenizer
# Extra special tokens are read from the JSON file named in the config and added
# to the pretrained tokenizer.
with open(bert_args.special_tokens_dict, "r", encoding="utf-8") as config:
    special_tokens_dict = json.load(config)

bert_tokenizer = transformers.AutoTokenizer.from_pretrained(
    bert_args.pretrained_bert,
    truncation_side=bert_args.truncation_side,
    padding_side=bert_args.padding_side,
)
# add_special_tokens returns how many tokens were new to the vocabulary; keep it
# in a named variable instead of echoing an unlabeled number as the cell output.
num_added_tokens = bert_tokenizer.add_special_tokens(special_tokens_dict)
num_added_tokens

7

In [5]:
# bert
def load_bert_encoder(name_or_path, vocab_size):
    """Load a pretrained encoder and resize its token embeddings to ``vocab_size``.

    Resizing is needed because special tokens were added to the tokenizer, so the
    vocabulary is larger than the checkpoint's embedding matrix.
    """
    encoder = transformers.AutoModel.from_pretrained(name_or_path)
    encoder.resize_token_embeddings(vocab_size)
    return encoder


# Two independent towers from the same checkpoint: one for contexts, one for candidates.
context_bert = load_bert_encoder(bert_args.pretrained_bert, len(bert_tokenizer))
candidate_bert = load_bert_encoder(bert_args.pretrained_bert, len(bert_tokenizer))

Some weights of the model checkpoint at pretrain_models/rubert-base-cased-conversational were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at pretrain_models/rubert-base-cased-conversational wer

Embedding(100799, 768)

# Data: dataset, train/validation split and dataloaders

In [6]:
# dataset
# NOTE(review): the meaning of mod='get_examples_gk' and rnd_context is defined in
# utils.PersonaDataset (not visible here).
dataset = PersonaDataset(bert_args.data_path, mod='get_examples_gk', rnd_context=True)
# `val_split` acts as a divisor: len(dataset) // val_split examples go to validation.
val_size = len(dataset) // bert_args.val_split
train_size = len(dataset) - val_size
# Seed the split so the train/val partition is reproducible across kernel restarts
# (an unseeded random_split produces a different partition on every run).
split_seed = getattr(bert_args, "seed", 42)
vars(bert_args).update(
    {"train_size": train_size, "val_size": val_size, "split_seed": split_seed}
)
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(split_seed),
)

In [7]:
# bert_callator
# Batch collator for the retrieval model, built on the BERT tokenizer with the
# configured padding strategy and max sequence length (bert_args.context_len).
# NOTE(review): constructing it logs "Using eos_token, but it is not set yet." --
# the BERT tokenizer has no eos_token; confirm RetrievalCollator does not rely on it.
bert_callator = RetrievalCollator(
    bert_tokenizer,
    max_length=bert_args.context_len,
    padding=bert_args.padding,
)

Using eos_token, but it is not set yet.


In [8]:
# dataloader
def _make_dataloader(subset, shuffle):
    """Wrap ``subset`` in a DataLoader that uses the shared batch size and collator."""
    return torch.utils.data.DataLoader(
        subset,
        batch_size=bert_args.batch_size,
        shuffle=shuffle,
        collate_fn=bert_callator,
    )


# Only the training split is shuffled; validation order stays fixed.
# NOTE(review): the last batch may be smaller than batch_size, and RetrievalModel
# also receives batch_size -- confirm it handles a partial final batch.
train_dataloader = _make_dataloader(train_dataset, shuffle=True)
val_dataloader = _make_dataloader(val_dataset, shuffle=False)

In [9]:
# scheduler len
# Total number of training batches over the run (epochs x batches per epoch);
# handed to RetrievalModel as the scheduler length. Assumes one optimizer step
# per batch (the Trainer sets no gradient accumulation).
scheduler_len = bert_args.epochs * len(train_dataloader)

# PyTorch Lightning training loop

In [10]:
# pl model
# LightningModule wrapping the separate context and candidate BERT towers.
# Positional order: context encoder, candidate encoder, batch size,
# scheduler length, warmup steps, learning rate.
model = RetrievalModel(
    context_bert, candidate_bert,
    bert_args.batch_size, scheduler_len,
    bert_args.num_warmup_steps, bert_args.lr,
    aggregation_mod=bert_args.aggregation_mod,
    sim_mod=bert_args.sim_mod,
)

  rank_zero_warn(
  rank_zero_warn(


In [11]:
# logger
# Prefer the COMET_API_KEY environment variable so the key need not live in the
# config file; fall back to the config value if present.
comet_api_key = os.environ.get("COMET_API_KEY") or getattr(bert_args, "api_key", None)
logger = pl.loggers.comet.CometLogger(
    api_key=comet_api_key,
    save_dir=bert_args.save_dir,
    project_name=bert_args.project_name,
    experiment_name=bert_args.experiment_name,
)
# Log all hyperparameters except the secret: logging the raw namespace would
# upload the Comet API key to the experiment page as a "hyperparameter".
logger.log_hyperparams(
    {name: value for name, value in vars(bert_args).items() if name != "api_key"}
)

CometLogger will be initialized in online mode
COMET INFO: Experiment is live on comet.com https://www.comet.com/anpopaicoconat/bi-encoder/ba9a2503126b46e2b2ec8049c669b0f1



In [12]:
# checkpoint callback
# Keep only the single best checkpoint ranked by the logged "val_r1" metric
# (presumably validation recall@1 from RetrievalModel); higher is better.
# NOTE: a ModelCheckpoint only takes effect when passed via Trainer(callbacks=[...]).
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath=bert_args.save_dir,
    filename="bert-{epoch:02d}-{val_r1:.2f}",
    monitor="val_r1",
    mode="max",
    save_top_k=1,
)

In [13]:
# trainer
# Single-GPU trainer with gradient clipping and Comet logging.
trainer = pl.Trainer(
    max_epochs=bert_args.epochs,
    accelerator="gpu",
    devices=1,
    gradient_clip_val=bert_args.gradient_clip_val,
    logger=logger,
    # Register the val_r1-based checkpoint callback. Without this, that ModelCheckpoint
    # is never used and Lightning falls back to its default callback, which just keeps
    # the most recent epoch rather than the best one.
    callbacks=[checkpoint_callback],
    num_sanity_val_steps=1,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [14]:
# fit
# Train on the training split, validating on the held-out split each epoch.
trainer.fit(
    model,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type             | Params
---------------------------------------------------
0 | context_BERT  | BertModel        | 163 M 
1 | candidat_BERT | BertModel        | 163 M 
2 | loss          | CrossEntropyLoss | 0     
3 | train_metrics | MetricCollection | 0     
4 | val_metrics   | MetricCollection | 0     
---------------------------------------------------
326 M     Trainable params
0         Non-trainable params
326 M     Total params
1,307.640 Total estimated model params size (MB)
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(100799, 768)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
         

Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


                                                                           

  rank_zero_warn(


Epoch 29: 100%|██████████| 949/949 [10:42<00:00,  1.48it/s, loss=0.134, v_num=b0f1, lr=2.6e-13, train_loss=0.0411, val_loss_step=5.170, val_loss_epoch=6.270]  

`Trainer.fit` stopped: `max_epochs=30` reached.


Epoch 29: 100%|██████████| 949/949 [10:48<00:00,  1.46it/s, loss=0.134, v_num=b0f1, lr=2.6e-13, train_loss=0.0411, val_loss_step=5.170, val_loss_epoch=6.270]


COMET INFO: ---------------------------
COMET INFO: Comet.ml Experiment Summary
COMET INFO: ---------------------------
COMET INFO:   Data:
COMET INFO:     display_summary_level : 1
COMET INFO:     url                   : https://www.comet.com/anpopaicoconat/bi-encoder/ba9a2503126b46e2b2ec8049c669b0f1
COMET INFO:   Metrics [count] (min, max):
COMET INFO:     lr [455]             : (1.1479716482964974e-10, 4.9999376642517745e-05)
COMET INFO:     train_loss [455]     : (0.019493410363793373, 10.632709503173828)
COMET INFO:     train_mrr_epoch [30] : (0.09993503987789154, 0.9744289517402649)
COMET INFO:     train_mrr_step [455] : (0.05357958376407623, 1.0)
COMET INFO:     train_r1_epoch [30]  : (0.03455415368080139, 0.9533183574676514)
COMET INFO:     train_r1_step [455]  : (0.0, 1.0)
COMET INFO:     train_r5_epoch [30]  : (0.12358634918928146, 0.997671365737915)
COMET INFO:     train_r5_step [455]  : (0.058139532804489136, 1.0)
COMET INFO:     val_loss_epoch [30]  : (3.7376554012298584, 