# Setup: imports, proxy and configuration

In [1]:
# imports
import os
import argparse
import json
import sys
sys.path.append("..")

import torch
import pytorch_lightning as pl
import torchmetrics
import transformers

from utils import (
    PersonaDataset,
    GenerativeCollator,
    RetrievalCollator,
    aggregate_encoder_output,
    sim_func,
)
from models import BERT_RetrievalModel

# Seed the Python/NumPy/torch RNGs for reproducibility. `pl.seed_everything` is the public
# top-level API; the `pl.utilities.seed.seed_everything` path was deprecated in PL 1.8 and
# removed in PL 2.0. The returned seed is the cell's displayed value.
pl.seed_everything(42)

  from .autonotebook import tqdm as notebook_tqdm
Global seed set to 42


42

In [2]:
# proxy: point the standard *_proxy environment variables at one shared proxy URL so that
# HTTP/HTTPS/FTP clients used from this kernel pick it up.
os.environ.update(
    dict.fromkeys(
        ("http_proxy", "https_proxy", "ftp_proxy"),
        "http://proxy.ad.speechpro.com:3128",
    )
)

In [3]:
# config: start from an empty argparse Namespace, fill it from ../config.json, then apply
# the experiment-specific overrides below (keys in `opt` win over the file).
parser = argparse.ArgumentParser()
bert_args = parser.parse_args([])  # no CLI inside a notebook: just an empty Namespace
# JSON is UTF-8 by spec; be explicit so the read does not depend on the platform locale.
with open("../config.json", "r", encoding="utf-8") as config_file:
    file_config = json.load(config_file)
vars(bert_args).update(file_config)

# Per-experiment overrides.
# NOTE: experiment_name embeds the learning rate by hand; keep it in sync with "lr".
opt = {
    "epochs": 30,
    "lr": 5e-05,
    "gradient_clip_val": 1,
    "batch_size": 100,
    "val_split": -1,
    "num_warmup_steps": 1000,
    "rnd_context": 0,
    "context_len": 128,
    "candidate_len": 32,
    "persona_len": 3,
    "aggregation_mod": "last_hidden_state_cls_left",
    "sim_mod": "CosineSimilarity",
    "project_name": "bi_encoder",
    "experiment_name": "gk_cos(5e-05)",
    "dataset_mod": "get_examples_gk"
}
vars(bert_args).update(opt)

# Pretrained tokenizer and encoders

In [4]:
# bert tokenizer: load the tokenizer that matches the pretrained checkpoint, with the
# padding and truncation sides taken from the config, then register the extra special
# tokens from the config.
bert_tokenizer = transformers.AutoTokenizer.from_pretrained(
    bert_args.bert,
    padding_side=bert_args.padding_side,
    truncation_side=bert_args.truncation_side,
)
# Last expression: the number of newly added tokens is shown as the cell output.
bert_tokenizer.add_special_tokens(bert_args.special_tokens_dict)

7

In [5]:
# bert: two independent encoders loaded from the same checkpoint (one later passed to the
# model as the context encoder, one as the candidate encoder). Each embedding matrix is
# resized to the tokenizer's vocabulary, which now includes the added special tokens.
def _load_resized_bert(checkpoint, vocab_size):
    """Load a pretrained transformers model and resize its token embeddings to vocab_size."""
    encoder = transformers.AutoModel.from_pretrained(checkpoint)
    encoder.resize_token_embeddings(vocab_size)
    return encoder


context_bert = _load_resized_bert(bert_args.bert, len(bert_tokenizer))
candidate_bert = _load_resized_bert(bert_args.bert, len(bert_tokenizer))

Some weights of the model checkpoint at ../pretrain_models/rubert-base-cased-conversational were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at ../pretrain_models/rubert-base-cased-conversation

Embedding(100799, 768)

# Data

In [6]:
# datasets: train and validation splits share the same preprocessing options; their
# sizes are recorded on the config namespace (and therefore in the logged hyperparameters).
train_dataset, val_dataset = (
    PersonaDataset(split_path, mod=bert_args.dataset_mod, rnd_context=bert_args.rnd_context)
    for split_path in (bert_args.train_data_path, bert_args.test_data_path)
)
train_size, val_size = len(train_dataset), len(val_dataset)
bert_args.train_size = train_size
bert_args.val_size = val_size
print(train_size, val_size)

73508 8028


In [7]:
# bert_callator: batch collator for the dataloaders; contexts and candidates get separate
# maximum lengths from the config. (The "callator" spelling is kept because later cells
# refer to this name.)
bert_callator = RetrievalCollator(
    bert_tokenizer,
    max_length_context=bert_args.context_len,
    max_length_candidate=bert_args.candidate_len,
    padding=bert_args.padding,
)

Using eos_token, but it is not set yet.


In [8]:
# dataloader: the same batch size and collator for both splits; only the training split
# is reshuffled every epoch, validation is iterated in a fixed order.
train_dataloader, val_dataloader = (
    torch.utils.data.DataLoader(
        split_dataset,
        batch_size=bert_args.batch_size,
        shuffle=reshuffle,
        collate_fn=bert_callator,
    )
    for split_dataset, reshuffle in ((train_dataset, True), (val_dataset, False))
)

In [9]:
# scheduler len: total number of training batches over the whole run, passed to the model
# (presumably as the LR-schedule length; equals optimizer steps while gradient
# accumulation is off).
scheduler_len = bert_args.epochs * len(train_dataloader)

# PyTorch Lightning training loop

In [10]:
# pl model: Lightning module wrapping the two BERT encoders (context and candidate).
# The first six arguments are positional, so their order must match the
# BERT_RetrievalModel signature (imported from `models`, not visible here).
model = BERT_RetrievalModel(
    context_bert,
    candidate_bert,
    bert_args.batch_size,
    scheduler_len,  # presumably total training steps for the LR schedule — confirm in models
    bert_args.num_warmup_steps,
    bert_args.lr,
    aggregation_mod=bert_args.aggregation_mod,
    sim_mod=bert_args.sim_mod,
    tokenizer=bert_tokenizer,
    collator=bert_callator,
    # NOTE(review): this is the full config namespace, including api_key from config.json;
    # verify the model does not log or checkpoint it verbatim.
    base_config=bert_args,
)

  rank_zero_warn(
  rank_zero_warn(


In [11]:
# logger: Comet experiment tracker; the API key comes from config.json via bert_args.
logger = pl.loggers.comet.CometLogger(
    api_key=bert_args.api_key,
    save_dir=bert_args.save_dir,
    project_name=bert_args.project_name,
    experiment_name=bert_args.experiment_name,
)
# log_hyperparams uploads the values to the remote experiment page, so the credential must
# be excluded; every other config entry is logged unchanged.
logger.log_hyperparams(
    {key: value for key, value in vars(bert_args).items() if key != "api_key"}
)

CometLogger will be initialized in online mode
COMET INFO: Experiment is live on comet.com https://www.comet.com/anpopaicoconat/bi-encoder/36037371cee4404b80aa618268a2e24c



In [12]:
# checkpoint callback: keep only the single best checkpoint as ranked by the validation
# metric 'val_r1' (higher is better; presumably recall@1), saved under save_dir with the
# epoch and score in the file name.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath=bert_args.save_dir,
    filename='bert-{epoch:02d}-{val_r1:.2f}',
    monitor='val_r1',
    mode='max',
    save_top_k=1,
)

In [13]:
# trainer: single-GPU run for the configured number of epochs with gradient clipping,
# Comet logging, best-checkpoint saving, and one sanity validation batch before training.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    max_epochs=bert_args.epochs,
    gradient_clip_val=bert_args.gradient_clip_val,
    num_sanity_val_steps=1,
    logger=logger,
    callbacks=[checkpoint_callback],
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [14]:
# fit: train on the training loader, validating on the validation loader every epoch.
trainer.fit(
    model,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type             | Params
---------------------------------------------------
0 | context_BERT  | BertModel        | 163 M 
1 | candidat_BERT | BertModel        | 163 M 
2 | loss          | CrossEntropyLoss | 0     
3 | train_metrics | MetricCollection | 0     
4 | val_metrics   | MetricCollection | 0     
---------------------------------------------------
326 M     Trainable params
0         Non-trainable params
326 M     Total params
1,307.640 Total estimated model params size (MB)
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(100799, 768)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
         

Sanity Checking DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]

  rank_zero_warn(


                                                                           

  rank_zero_warn(


Epoch 29: 100%|██████████| 817/817 [08:11<00:00,  1.66it/s, loss=0.169, v_num=e24c, lr=2.78e-13, train_loss=0.00237, val_loss_step=2.750, val_loss_epoch=3.660]

`Trainer.fit` stopped: `max_epochs=30` reached.


Epoch 29: 100%|██████████| 817/817 [08:16<00:00,  1.64it/s, loss=0.169, v_num=e24c, lr=2.78e-13, train_loss=0.00237, val_loss_step=2.750, val_loss_epoch=3.660]


COMET INFO: ---------------------------
COMET INFO: Comet.ml Experiment Summary
COMET INFO: ---------------------------
COMET INFO:   Data:
COMET INFO:     display_summary_level : 1
COMET INFO:     url                   : https://www.comet.com/anpopaicoconat/bi-encoder/36037371cee4404b80aa618268a2e24c
COMET INFO:   Metrics [count] (min, max):
COMET INFO:     lr [441]             : (2.668032739183701e-10, 4.999933298677206e-05)
COMET INFO:     train_loss [441]     : (0.0035043286625295877, 4.613931179046631)
COMET INFO:     train_mrr_epoch [30] : (0.12932683527469635, 0.9729980230331421)
COMET INFO:     train_mrr_step [441] : (0.06485883891582489, 1.0)
COMET INFO:     train_r1_epoch [30]  : (0.056997284293174744, 0.9498097896575928)
COMET INFO:     train_r1_step [441]  : (0.019999999552965164, 1.0)
COMET INFO:     train_r5_epoch [30]  : (0.16728940606117249, 0.9980027079582214)
COMET INFO:     train_r5_step [441]  : (0.07999999821186066, 1.0)
COMET INFO:     val_loss_epoch [30]  : (3.52