# Imports

In [27]:
import json
from datetime import datetime
import re
import pprint
import random
import numpy as np
import pandas as pd

import comet_ml
import torch
import pytorch_lightning as pl
import torchmetrics
import transformers

import os
os.environ['http_proxy'] = "http://proxy.ad.speechpro.com:3128" 
os.environ['https_proxy'] = "http://proxy.ad.speechpro.com:3128"
os.environ['ftp_proxy'] = "http://proxy.ad.speechpro.com:3128"

In [28]:
class PersonaRetrievalDataset(torch.utils.data.Dataset):
    def __init__(self, path, rnd_context=False, seed=42):
        super().__init__()
        self.data = []
        self.rnd_context = rnd_context
        with open(path, 'r') as file:
            for line in file:
                line = json.loads(line)
                self.data+=list(self.get_examples(**line))

    def get_examples(self, person_1, person_2, dialog):
        for i in range(1, len(dialog)):
            if self.rnd_context:
                start = random.randint(0, i-1)
            else:
                start = 0
            context = dialog[start:i]
            candidate = dialog[i]
            if i % 2 == 0:
                persona = person_2
            else:
                persona = person_1
            label = 1

            yield {'context': context, 'candidate': candidate, 
                   'persona': persona, 'label': label}

    def __len__(self) -> int:
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx]

In [29]:
class RetrievalCollator():
    def __init__(self, tokenizer, padding, max_length, return_tensors="pt"):
        self.tokenizer = tokenizer
        self.P1 = '[P1u]'
        self.P2 = '[P2u]'
        self.Gk = '[Gk]'
        self.cls = tokenizer.cls_token
        self.padding = padding
        self.max_length = max_length
        self.return_tensors = return_tensors

    def __call__(self, batch):
        batch_new = {k:[] for k in batch[0]}
        for example in batch:
            for k in example:
                batch_new[k].append(example[k])
        batch_new['context'] = self.ContextCollator(batch_new['context'])
        batch_new['candidate'] = self.CandidateCollator(batch_new['candidate'])
        batch_new['persona'] = self.PersonaCollator(batch_new['persona'])
        return batch_new


    def ContextCollator(self, batch):
        for i, context in enumerate(batch):
            c_out = self.P2
            for c in context[::-1]:
                if i%2 == 0:
                    P = self.P1
                else:
                    P = self.P2
                c_out = P + c + c_out
            batch[i] = c_out
        return self.tokenizer.batch_encode_plus(batch, padding=self.padding, 
                                                max_length=self.max_length, 
                                                return_tensors=self.return_tensors, 
                                                truncation=True)
    def CandidateCollator(self, batch):
        return self.tokenizer.batch_encode_plus(batch, padding=self.padding, 
                                            max_length=self.max_length, 
                                            return_tensors=self.return_tensors, 
                                            truncation=True)
            
    def PersonaCollator(self, batch):
        for i, persona in enumerate(batch):
            c_out = self.Gk
            for c in persona[::-1]:
                c_out = self.Gk + c + c_out
            batch[i] = c_out
        return self.tokenizer.batch_encode_plus(batch, padding=self.padding, 
                                                max_length=self.max_length, 
                                                return_tensors=self.return_tensors, 
                                                truncation=True)

In [30]:
def aggregate_encoder_output(model_output: transformers.modeling_outputs.BaseModelOutputWithPoolingAndCrossAttentions,
                   mod: str):
    if mod == 'pooler_output':
        out = model_output.pooler_output 
    elif mod == 'last_hidden_state_cls_left':
        out = model_output.last_hidden_state[:,0,:]
    elif mod == 'last_hidden_state_cls_right':
        out = model_output.last_hidden_state[:,-1,:]
    elif mod == 'last_hidden_state_mean':
        # TODO проверить нужно ли маскирование
        out = torch.mean(out.last_hidden_state, dim=1)
    return out
    
def sim_func(x, y, mod):
    if mod == 'DotProduct':
        out = torch.mm(x, y.transpose(0,1))
    elif mod == 'CosineSimilarity':
        x = x / x.norm(dim=1)[:, None]
        y = y / y.norm(dim=1)[:, None]
        out = torch.mm(x, y.transpose(0,1))
    return out

In [31]:
class RetrievalModel(pl.LightningModule):
    def __init__(self, context_BERT, candidat_BERT, batch_size, scheduler_len, num_warmup_steps):
        super().__init__()
        self.context_BERT = context_BERT
        self.candidat_BERT = candidat_BERT
        self.batch_size = batch_size
        self.scheduler_len = scheduler_len
        self.num_warmup_steps = num_warmup_steps
        self.loss = torch.nn.CrossEntropyLoss()
        self.train_metrics = torchmetrics.MetricCollection(
            {
                "train_r1": torchmetrics.RetrievalRecall(k=1),
                "train_r5": torchmetrics.RetrievalRecall(k=5),
                "train_mrr": torchmetrics.RetrievalMRR(),
            }
        )
        self.val_metrics = torchmetrics.MetricCollection(
            {
                "val_r1": torchmetrics.RetrievalRecall(k=1),
                "val_r5": torchmetrics.RetrievalRecall(k=5),
                "val_mrr": torchmetrics.RetrievalMRR(),
            }
        )


    def training_step(self, batch, batch_idx):
        context = batch["context"]
        candidate = batch["candidate"]
        persona = batch["persona"]
        b_size = context['input_ids'].size()[0]
        #labels = torch.range(0, candidate['input_ids'].size()[0]-1, dtype=torch.long).to(self.device)
        labels = torch.zeros((b_size, b_size), dtype=torch.long).to(self.device)
        labels.fill_diagonal_(1)
        logits = self(context, candidate, torch.argmax(labels, 1))
        loss = self.loss(logits, torch.argmax(labels, 1))
        preds = logits.view(-1)
        targets = labels.view(-1)
        indexes = (
            torch.arange(logits.shape[0])
            .unsqueeze(1)
            .expand_as(logits)
            .reshape(preds.shape)
        )
        train_metrics = self.train_metrics(preds, targets, indexes=indexes)
        self.log_dict(train_metrics, on_epoch=True, on_step=True, batch_size=self.batch_size)
        self.log('train_loss', loss)
        self.log('lr', self.trainer.optimizers[0].param_groups[0]["lr"])
        return loss

    def validation_step(self, val_batch, batch_idx):
        context = val_batch["context"]
        candidate = val_batch["candidate"]
        persona = val_batch["persona"]
        b_size = context['input_ids'].size()[0]
        labels = torch.zeros((b_size, b_size), dtype=torch.long).to(self.device)
        labels.fill_diagonal_(1)
        logits = self(context, candidate, torch.argmax(labels, 1))
        loss = self.loss(logits, torch.argmax(labels, 1))
        preds = logits.view(-1)
        targets = labels.view(-1)
        indexes = (
            torch.arange(logits.shape[0])
            .unsqueeze(1)
            .expand_as(logits)
            .reshape(preds.shape)
        )
        train_metrics = self.train_metrics(preds, targets, indexes=indexes)
        self.log('val_loss', loss)
        val_metrics = self.val_metrics(preds, targets, indexes=indexes)
        self.log_dict(val_metrics, on_epoch=True, on_step=True, batch_size=batch_size)

        return val_metrics, loss


    def configure_optimizers(self):
        self.optimizer = torch.optim.AdamW(self.parameters(), lr=1e-4)
        self.scheduler =transformers.get_cosine_schedule_with_warmup(self.optimizer,
                                                                     num_warmup_steps=5000, 
                                                                     num_training_steps=scheduler_len)
        return [self.optimizer], [{"scheduler": self.scheduler, "interval": "step"}]

    def forward(self, 
                context, 
                candidat, 
                labels):
        context_vec = self.context_BERT(**context)
        candidat_vec = self.candidat_BERT(**candidat)
        context_vec = aggregate_encoder_output(context_vec, mod='pooler_output')
        candidat_vec = aggregate_encoder_output(candidat_vec, mod='pooler_output')
        distance = sim_func(context_vec, candidat_vec, 'DotProduct')
        return distance

In [32]:
epochs=15
lr = 7e-5
batch_size = 6
context_len = 32
candidate_len = 32
persona_len = 32
val_split = 5

pretrained_path = '/home/posokhov@ad.speechpro.com/projects/models/conversational/'
data_path = 'TlkPersonaChatRus/TolokaPersonaChat_v1.jsonl'

tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained_path, truncation_side='left', padding_side='right')
special_tokens_dict = {'additional_special_tokens': ['[P1x]', '[P1x]','[P2y]', '[P2y]', '[P1u]', '[P2u]', '[Gk]']} 
tokenizer.add_special_tokens(special_tokens_dict)
#[P1x] P-turn start, 1-user, 2-model, x-male, y-female, u-unknown
context_bert = transformers.AutoModel.from_pretrained(pretrained_path)
context_bert.resize_token_embeddings(len(tokenizer))
candidate_bert = transformers.AutoModel.from_pretrained(pretrained_path)
candidate_bert.resize_token_embeddings(len(tokenizer))

dataset = PersonaRetrievalDataset(data_path)
train_dataset, val_dataset =  torch.utils.data.random_split(dataset, [len(dataset)-(len(dataset)//val_split), (len(dataset)//val_split)])

callator = RetrievalCollator(tokenizer, padding = 'max_length', max_length = context_len)

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=False, collate_fn=callator)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=callator)

scheduler_len = len(train_dataloader)*epochs
num_warmup_steps = 5000

Some weights of the model checkpoint at /home/posokhov@ad.speechpro.com/projects/models/conversational/ were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at /home/posokhov@ad.speechpro.com/proje

In [33]:
model = RetrievalModel(context_bert, candidate_bert, batch_size, scheduler_len, num_warmup_steps)
logger = pl.loggers.comet.CometLogger(api_key='sEJsZrYjwc0gxxUAUGQNBwTsb', save_dir="logs", project_name='bi_encoder', experiment_name='nopersona_answers')
trainer = pl.Trainer(max_epochs=epochs, accelerator='gpu', devices=1, 
                     gradient_clip_val=1, logger=logger, num_sanity_val_steps=0)
trainer.fit(model, train_dataloader, val_dataloader)

CometLogger will be initialized in online mode


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type             | Params
---------------------------------------------------
0 | context_BERT  | BertModel        | 108 M 
1 | candidat_BERT | BertModel        | 108 M 
2 | loss          | CrossEntropyLoss | 0     
3 | train_metrics | MetricCollection | 0     
4 | val_metrics   | MetricCollection | 0     
---------------------------------------------------
216 M     Trainable params
0         Non-trainable params
216 M     Total params
866.513   Total estimated model params size (MB)
  rank_zero_warn(
  rank_zero_warn(


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Epoch 0:   2%|▏         | 417/26463 [10:49<11:16:24,  1.56s/it, loss=1.82, v_num=2476]
Epoch 0:   0%|          | 0/13232 [11:17<?, ?it/s]


RuntimeError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 10.92 GiB total capacity; 5.02 GiB already allocated; 14.50 MiB free; 5.45 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF