In [1]:
import numpy as np
import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

from dialog import dialog_from_dict
from metrics import ConversationalEditDistance

In [None]:
# Load the 'train' split of the 'multi_woz_v22' dataset via `load_dataset`.
# NOTE(review): this cell has no execution count (`In [None]`) — re-run it on
# a fresh kernel to confirm the later cells do not depend on stale state.
dataset = load_dataset('multi_woz_v22', split='train')

In [3]:
# Directory handed to `compute_embeddings` further down.
# Presumably where computed embeddings are cached — confirm in `dialog`.
embeddings_dir = 'conversation-similarity/cache'

In [4]:
# Run on the GPU when CUDA is usable, otherwise stay on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Sentence encoder; the name is kept in its own variable because it is also
# passed to `compute_embeddings` later on.
model_name = 'all-MiniLM-L12-v2'
model = SentenceTransformer(model_name, device=device)

In [5]:
# Seed NumPy's legacy global RNG so the `np.random.choice` sampling in the
# next cell is repeatable across runs.
np.random.seed(1234)

In [6]:
# Sample 500 random pairs of records from `dataset`.
#
# Row *indices* are drawn instead of passing `dataset` itself to
# `np.random.choice`: given the dataset, NumPy first has to convert every
# record into an array just to keep 1000 of them. Here only the sampled rows
# are fetched.
# NOTE: sampling is with replacement (the `np.random.choice` default), so a
# record may appear in several pairs, or twice within the same pair.
pair_indices = np.random.choice(len(dataset), size=(500, 2))

# Keep the original container shape/type: a (500, 2) object array holding one
# raw record per cell, so the following cells can index and overwrite it.
pairs = np.empty(pair_indices.shape, dtype=object)
for position, row_index in np.ndenumerate(pair_indices):
    # Cast NumPy integers to plain `int` before indexing the dataset.
    # NOTE(review): assumes `dataset[i]` with an int returns a single record
    # (the same element NumPy would have seen) — confirm.
    pairs[position] = dataset[int(row_index)]

In [7]:
# Step 1: swap each sampled raw record for the object returned by
# `dialog_from_dict`, overwriting the cells of `pairs` in place.
# NOTE: because of the in-place overwrite this cell is not meant to be run
# twice without re-running the sampling cell first.
for row in pairs:
    for slot in (0, 1):
        row[slot] = dialog_from_dict(row[slot])

# Step 2: compute embeddings for both dialogs of every pair, with a progress
# bar over the pairs.
for row in tqdm(pairs):
    for slot in (0, 1):
        row[slot].compute_embeddings(embeddings_dir, model, model_name)

100%|██████████| 500/500 [00:01<00:00, 283.34it/s]


In [8]:
# Score every sampled pair with the (non-inverted) conversational edit
# distance; `distances[i]` corresponds to `pairs[i]`.
metric = ConversationalEditDistance(is_inverted=False)
distances = [metric(first, second) for first, second in pairs]