In [1]:
import os

import numpy as np
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

import metrics
from dialog import Dialog, DialogTriplet, dialog_from_file

In [2]:
def get_dialog_filepath(root_dir: str, dialog_id: str):
    """Build the path of the JSON file that stores a single dialog.

    Args:
        root_dir: Directory that contains the dialog files.
        dialog_id: Identifier of the dialog; used as the file stem.

    Returns:
        ``root_dir`` joined with ``<dialog_id>.json``.
    """
    filename = f'{dialog_id}.json'
    return os.path.join(root_dir, filename)

In [3]:
def load_dialog_triplets(
    metadata: pd.DataFrame, root_dir: str,
) -> list[DialogTriplet]:
    """Load one ``DialogTriplet`` per row of ``metadata``.

    Each row must provide the columns ``anchor_conv``, ``conv_1``, ``conv_2``,
    ``more_similar_conv`` and ``more_similar_conv_confidence``. The three
    dialog ids are resolved to ``<root_dir>/<id>.json`` and read with
    ``dialog_from_file``.

    Args:
        metadata: Table describing the triplets, one triplet per row.
        root_dir: Directory that contains the dialog JSON files.

    Returns:
        The loaded triplets, in the row order of ``metadata``.
    """
    id_columns = ('anchor_conv', 'conv_1', 'conv_2')

    loaded = []
    for row in metadata.to_dict('records'):
        # The stored value is shifted down by one.
        # NOTE(review): presumably converts a 1-based choice (1 or 2) into a
        # 0-based label — confirm against the dataset description.
        shifted_label = row['more_similar_conv'] - 1
        confidence = row['more_similar_conv_confidence']

        # Resolve all three paths before reading any file.
        anchor_path, first_path, second_path = (
            get_dialog_filepath(root_dir, row[column]) for column in id_columns
        )

        loaded.append(
            DialogTriplet(
                anchor_dialog=dialog_from_file(anchor_path),
                dialog_1=dialog_from_file(first_path),
                dialog_2=dialog_from_file(second_path),
                label=shifted_label,
                confidence_score=confidence,
            ),
        )
    return loaded

In [4]:
def compute_embeddings(
    dialog: Dialog,
    cache_dir: str,
    model: SentenceTransformer,
    model_name: str,
) -> Dialog:
    """Attach an utterance embedding to every turn of ``dialog``, with a disk cache.

    For each turn the embedding is looked up in
    ``<cache_dir>/<model_name>_<dialog_id>_<turn index>.npy``. When the file
    exists it is loaded; otherwise the utterance is encoded with ``model`` and
    the result is saved to that path.

    Note that the cache key is built from the model name, the dialog id and the
    turn index only, not from the utterance text, so cached files must be
    removed by hand if a dialog's content changes.

    Args:
        dialog: Dialog whose turns receive an ``embedding`` attribute. It is
            modified in place.
        cache_dir: Directory that holds the cached ``.npy`` files. It is
            created on demand if it does not exist yet.
        model: Encoder; ``model.encode([utterance])[0]`` is used as the
            embedding of a turn.
        model_name: Prefix of the cache file names. It may contain path
            separators (e.g. ``org/model``); the resulting sub-directory is
            created on demand.

    Returns:
        The same ``dialog`` object that was passed in.
    """
    for idx, turn in enumerate(dialog.turns):
        embedding_filepath = os.path.join(
            cache_dir,
            '{0}_{1}_{2}.npy'.format(
                model_name, dialog.dialog_id, idx,
            ),
        )
        if os.path.isfile(embedding_filepath):
            embedding = np.load(embedding_filepath)
        else:
            embedding = model.encode([turn.utterance])[0]
            # np.save does not create missing directories, so make sure the
            # parent exists. The parent is derived from the final path so a
            # model name containing a separator is covered as well.
            parent_dir = os.path.dirname(embedding_filepath)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)
            np.save(embedding_filepath, embedding)
        turn.embedding = embedding
    return dialog

In [5]:
# Input and cache locations, given as relative paths (resolved against the
# notebook's working directory).
# CSV with one triplet per row; read with pd.read_csv in a later cell.
metadata_filepath = 'conversation-similarity/conved.csv'
# Directory passed to load_dialog_triplets as `root_dir` (the dialog JSON files).
dialogs_dir = 'conversation-similarity/dialogs'
# Directory passed to compute_embeddings as `cache_dir` (cached .npy embeddings).
embeddings_dir = 'conversation-similarity/cache'

# Create the cache directory up front; exist_ok=True keeps the cell re-runnable.
os.makedirs(embeddings_dir, exist_ok=True)

In [6]:
# Read the triplet metadata CSV and load the three dialogs referenced by each
# row from `dialogs_dir`.
triplets = load_dialog_triplets(
    pd.read_csv(metadata_filepath),
    dialogs_dir,
)

In [7]:
# Run the encoder on the GPU when CUDA is available, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# The model name is also used as the prefix of the embedding cache file names.
model_name = 'all-MiniLM-L12-v2'
model = SentenceTransformer(model_name, device=device)

In [8]:
# Embed every turn of the three dialogs in each triplet; compute_embeddings
# reuses embeddings already stored in `embeddings_dir`.
for triplet in tqdm(triplets):
    for dialog_field in ('anchor_dialog', 'dialog_1', 'dialog_2'):
        embedded_dialog = compute_embeddings(
            getattr(triplet, dialog_field), embeddings_dir, model, model_name,
        )
        setattr(triplet, dialog_field, embedded_dialog)

100%|██████████| 502/502 [00:03<00:00, 135.22it/s]


In [9]:
# Value passed as `confidence_threshold` to metrics.get_metric_agreement in
# every evaluation cell below.
# NOTE(review): presumably triplets whose annotation confidence is below this
# value are left out of the agreement score — confirm in metrics.py.
CONFIDENCE_THRESHOLD = 0.8

### Constant distance

In [10]:
# Agreement between metrics.ExampleMetric (the "Constant distance" of the
# heading above) and the triplet labels.
# NOTE(review): presumably a reference point for the metrics evaluated in the
# following cells — confirm what ExampleMetric returns in metrics.py.
metrics.get_metric_agreement(
    dialog_triplets=triplets,
    metric=metrics.ExampleMetric(is_inverted=False),
    confidence_threshold=CONFIDENCE_THRESHOLD,
)

0.5125

### ConvED (Conversational Edit Distance)

In [11]:
# Agreement between metrics.ConversationalEditDistance and the triplet labels,
# using the same triplets and confidence threshold as the other evaluation cells.
metrics.get_metric_agreement(
    dialog_triplets=triplets,
    metric=metrics.ConversationalEditDistance(is_inverted=False),
    confidence_threshold=CONFIDENCE_THRESHOLD,
)

0.625

### Cosine distance (average embedding)

In [12]:
# Agreement between metrics.CosineDistance and the triplet labels.
# NOTE(review): per the heading above, the metric is expected to compare the
# averaged turn embeddings of two dialogs — confirm in metrics.py.
metrics.get_metric_agreement(
    dialog_triplets=triplets,
    metric=metrics.CosineDistance(is_inverted=False),
    confidence_threshold=CONFIDENCE_THRESHOLD,
)

0.475

### Lp distance (average embedding)

#### p = 1

In [13]:
# Agreement between metrics.LpDistance with p=1 and the triplet labels.
metrics.get_metric_agreement(
    dialog_triplets=triplets,
    metric=metrics.LpDistance(is_inverted=False, p=1),
    confidence_threshold=CONFIDENCE_THRESHOLD,
)

0.50625

#### p = 2

In [14]:
# Agreement between metrics.LpDistance with p=2 and the triplet labels.
metrics.get_metric_agreement(
    dialog_triplets=triplets,
    metric=metrics.LpDistance(is_inverted=False, p=2),
    confidence_threshold=CONFIDENCE_THRESHOLD,
)

0.53125

### Dot product similarity (average embedding)

In [15]:
# Agreement between metrics.DotProductSimilarity and the triplet labels.
# This is the only evaluation cell that passes is_inverted=True.
# NOTE(review): presumably because a larger dot product means "more similar",
# whereas the distance metrics above treat smaller values as closer — confirm
# the meaning of `is_inverted` in metrics.py.
metrics.get_metric_agreement(
    dialog_triplets=triplets,
    metric=metrics.DotProductSimilarity(is_inverted=True),
    confidence_threshold=CONFIDENCE_THRESHOLD,
)

0.50625