# Training the Phonetic Siamese Model

In [1]:
%load_ext autoreload
%autoreload 2

from finding_mnemo.pairing.dataset.phonetic_pair_dataset import PhoneticPairDataset
from finding_mnemo.pairing.dataset.phonetic_triplet_dataset import PhoneticTripletDataset
from finding_mnemo.pairing.dataset.generative_phonetic_triplet_dataset import GenerativePhoneticTripletDataset
from finding_mnemo.pairing.dataset.generative_phonetic_contrastive_dataset import GenerativePhoneticContrastiveDataset
from finding_mnemo.pairing.model.phonetic_siamese import PhoneticSiamese
from finding_mnemo.pairing.training.config import CONFIG, LossType
from torch.utils.data import Dataset
from pathlib import Path

from pytorch_lightning.loggers import MLFlowLogger
from pytorch_lightning.utilities.seed import seed_everything
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

from jina import DocumentArray, Document
from pathlib import Path

import pandas as pd
import mlflow
import torch
import json

import umap
import umap.plot
# Enable inline rendering of the interactive umap.plot figures used at the end of the notebook.
umap.plot.output_notebook()

# Seed the global RNGs via Lightning; the returned seed is the cell's displayed output.
seed_everything(seed=0)

  from .autonotebook import tqdm as notebook_tqdm
Install h5py to use hdf5 features: http://docs.h5py.org/
  warn(h5py_msg)
Global seed set to 0


0

In [2]:
def get_dataset() -> Dataset:
    """Build the training dataset that matches ``CONFIG.loss_type``.

    Pair, Triplet and Mixed losses read the best/worst pair files from ``CONFIG``;
    the two generative losses sample from the English and Mandarin datasets.

    Raises:
        ValueError: if ``CONFIG.loss_type`` is none of the supported loss types.
    """
    def from_pairs(dataset_cls):
        # Factory for datasets built from the curated best/worst pair files.
        return lambda: dataset_cls(
            best_pairs_path=CONFIG.best_pairs_dataset,
            worst_pairs_path=CONFIG.worst_pairs_dataset,
        )

    def generative(dataset_cls, size):
        # Factory for datasets that generate `size` samples from the word lists.
        return lambda: dataset_cls(
            english_data_path=CONFIG.english_dataset,
            mandarin_data_path=CONFIG.mandarin_dataset,
            size=size,
        )

    # Checked in order with `==` (rather than a dict lookup) to keep the exact
    # comparison semantics of the loss-type enum.
    builders = (
        (LossType.Pair, from_pairs(PhoneticPairDataset)),
        (LossType.Triplet, from_pairs(PhoneticTripletDataset)),
        (LossType.Mixed, from_pairs(PhoneticTripletDataset)),
        (LossType.GenerativeTriplet, generative(GenerativePhoneticTripletDataset, 500000)),
        (LossType.GenerativeContrastive, generative(GenerativePhoneticContrastiveDataset, 50000)),
    )
    for candidate, build in builders:
        if CONFIG.loss_type == candidate:
            return build()
    raise ValueError(f'Unknown loss type given: {CONFIG.loss_type}')

dataset = get_dataset()

In [3]:
# Hold out fixed-size validation and test splits; every remaining sample is used for training.
VAL_SIZE = 300
TEST_SIZE = 300
n_train = len(dataset) - VAL_SIZE - TEST_SIZE
# random_split only checks that the lengths sum to len(dataset); it does not reject a
# negative length, and would then silently return overlapping / empty subsets.
if n_train <= 0:
    raise ValueError(
        f"Dataset has {len(dataset)} samples; need more than {VAL_SIZE + TEST_SIZE} "
        "to carve out the validation and test splits."
    )
train_set, val_set, test_set = torch.utils.data.random_split(
    dataset, [n_train, VAL_SIZE, TEST_SIZE]
)

In [None]:
def instanciate(kwargs):
    """Create the train/validation/test data loaders and an untrained ``PhoneticSiamese``.

    Args:
        kwargs: hyper-parameter dict. It must contain ``batch_size``, ``embedding_dim``,
            ``dim_feedforward``, ``nhead``, ``dropout``, ``weight_decay``, ``lr``,
            ``margin``, ``lambda_triplet``, ``lambda_pos`` and ``lambda_neg``. Other
            keys are ignored. The loss type is taken from ``CONFIG``.

    Returns:
        dict with ``train_dataloader`` (shuffled), ``validation_dataloader``,
        ``test_dataloader`` and ``model``.
    """
    batch_size = kwargs["batch_size"]

    def make_loader(subset, **extra):
        # All loaders share batch size and worker count; only training adds shuffling.
        return DataLoader(subset, batch_size=batch_size, num_workers=4, **extra)

    loaders = {
        "train_dataloader": make_loader(train_set, shuffle=True),
        "validation_dataloader": make_loader(val_set),
        "test_dataloader": make_loader(test_set),
    }

    hyperparameters = {
        name: kwargs[name]
        for name in (
            "embedding_dim",
            "dim_feedforward",
            "nhead",
            "dropout",
            "weight_decay",
            "lr",
            "margin",
            "lambda_triplet",
            "lambda_pos",
            "lambda_neg",
        )
    }
    loaders["model"] = PhoneticSiamese(
        loss_type=CONFIG.loss_type,
        batch_size=batch_size,
        **hyperparameters,
    )
    return loaders

In [None]:
def train(
        dropout,
        lr,
        weight_decay,
        dim_feedforward,
        batch_size,
        nhead,
        embedding_dim,
        margin,
        lambda_triplet=0.5,
        lambda_pos=0.5,
        lambda_neg=0,
):
    """Train a ``PhoneticSiamese`` model, evaluate it on the test split and log to MLflow.

    The first eight arguments are the model/optimiser hyper-parameters passed to
    ``instanciate``. ``lambda_triplet``, ``lambda_pos`` and ``lambda_neg`` are the
    loss weights; their defaults are the values this notebook previously hard-coded.

    Side effects:
        - writes ``model_dict`` and ``model_config.json`` to the working directory;
        - copies the weights to ``../finding_mnemo/pairing/model/model_dict``;
        - logs both files and the model to MLflow.

    Returns:
        ``(model, test_loss)``, where ``test_loss`` is the ``"test_loss"`` entry of the
        first result dict returned by the test run.
    """
    mlf_logger = MLFlowLogger(
        experiment_name=CONFIG.experiment_name, tracking_uri=CONFIG.log_folder
    )
    trainer = Trainer(
        max_epochs=CONFIG.max_epochs,
        logger=mlf_logger,
        # Early stopping is currently disabled; re-enable with:
        # callbacks=[EarlyStopping(monitor="validation_loss", mode="min")],
        accelerator="gpu",
        devices=1,
    )

    model_config = {
        "dropout": dropout,
        "lr": lr,
        "weight_decay": weight_decay,
        "dim_feedforward": dim_feedforward,
        "batch_size": batch_size,
        "nhead": nhead,
        "embedding_dim": embedding_dim,
        "margin": margin,
        "model": "phonetic_siamese",
        "lambda_triplet": lambda_triplet,
        "lambda_pos": lambda_pos,
        "lambda_neg": lambda_neg,
    }
    instance = instanciate(model_config)

    mlflow.pytorch.autolog()

    with mlflow.start_run():
        model = fit_model(
            instance["model"],
            instance["train_dataloader"],
            instance["validation_dataloader"],
            trainer,
        )

        test_loss = test_model(model, instance["test_dataloader"], trainer)[0]["test_loss"]

        # Weights and config are written together, and only after training finished.
        # An interrupted run therefore never leaves a model_config.json that does not
        # match model_dict, which the evaluation cells load as a pair.
        torch.save(model.state_dict(), "model_dict")
        torch.save(model.state_dict(), "../finding_mnemo/pairing/model/model_dict")
        mlflow.log_artifact("model_dict", "model_dict")

        mlflow.pytorch.log_model(model, "model")

        with open("model_config.json", "w") as f:
            json.dump(model_config, f)

        mlflow.log_artifact("model_config.json", "model_config.json")

    return model, test_loss

def fit_model(model, train_dataloader, validation_dataloader, trainer):
    """Fit ``model`` with ``trainer`` on the given loaders and return the same model object.

    Lightning trains the module in place, so the returned object is ``model`` itself.
    """
    loaders = (train_dataloader, validation_dataloader)
    trainer.fit(model, *loaders)
    return model

def test_model(model, test_dataloader, trainer):
    return trainer.test(model, test_dataloader, verbose=False)

In [None]:
# Hyper-parameters for this training run.
hyperparameters = dict(
    dropout=0.2,
    lr=1e-3,
    weight_decay=1e-3,
    dim_feedforward=16,
    batch_size=16,
    nhead=1,
    embedding_dim=32,
    margin=0.2,
)
model, test_loss = train(**hyperparameters)

print(f'Final test loss: {test_loss}')

## Ranking Evaluation

In [4]:
# Rebuild the model from the config and weights saved by the last completed `train` run.
with open("model_config.json", "r") as f:
    model_config = json.load(f)

model = PhoneticSiamese(
    embedding_dim=model_config["embedding_dim"],
    dim_feedforward=model_config["dim_feedforward"],
    nhead=model_config["nhead"],
    dropout=model_config["dropout"],
    # Same loss type `instanciate` trains with; model_config.json does not store it.
    loss_type=CONFIG.loss_type,
    batch_size=model_config["batch_size"],
    weight_decay=model_config["weight_decay"],
    lr=model_config["lr"],
    margin=model_config["margin"],
    lambda_triplet=model_config["lambda_triplet"],
    lambda_pos=model_config["lambda_pos"],
    lambda_neg=model_config["lambda_neg"],
)
# Load tensors onto the CPU first so the checkpoint does not depend on the device it was
# saved from; the whole model is moved to `device` just below.
model.load_state_dict(torch.load("model_dict", map_location="cpu"))

device = torch.device('cuda:0')
model = model.eval().to(device)

<All keys matched successfully>

In [None]:
model = model.to("cuda:0")

def load_documents(
    csv_path: Path = Path("../finding_mnemo/pairing/dataset/data/english.csv"),
) -> DocumentArray:
    """Embed every English word in ``csv_path`` with the global ``model``.

    Each row's ``word`` becomes the Document text, and its ``ipa`` is stored as the
    ``ipa`` tag, which is read back via ``tags__ipa``. Embeddings are computed in
    batches of 32 and kept on the CPU.

    Args:
        csv_path: CSV file with at least ``word`` and ``ipa`` columns.

    Returns:
        A ``DocumentArray`` with one embedded Document per CSV row.
    """
    words = pd.read_csv(csv_path)[['word', 'ipa']].astype(str)

    # zip over the two columns avoids building a Series per row as iterrows() does.
    local_da = DocumentArray(
        [Document(text=word, ipa=ipa) for word, ipa in zip(words['word'], words['ipa'])]
    )

    def embed(da: DocumentArray) -> DocumentArray:
        # Pure inference: no_grad avoids building an autograd graph for every batch.
        with torch.no_grad():
            da.embeddings = model.encode(da[:, 'tags__ipa']).detach().cpu()
        return da

    local_da.apply_batch(embed, batch_size=32)

    with DocumentArray() as da:
        da += local_da
    return da

In [None]:
def evaluate(da, n_limit=10, n_samples=100, random_state=None):
    """Measure how often a known good pair is recovered by nearest-neighbour search.

    The function samples ``n_samples`` rows from ``dataset.best_pairs`` and embeds each
    row's ``ipa_a`` with the global ``model``. Each embedding is matched against ``da``
    using euclidean distance, keeping the top ``n_limit`` results. A query counts as a
    hit when its ``word_b`` appears among the matched texts.

    NOTE(review): ``best_pairs`` is sampled from the whole dataset rather than only the
    held-out ``test_set``, so training pairs may be included. TODO: confirm whether
    this is intended.

    Args:
        da: embedded ``DocumentArray`` to search (e.g. from ``load_documents``).
        n_limit: number of neighbours retrieved per query.
        n_samples: number of pairs sampled from ``dataset.best_pairs``.
        random_state: forwarded to ``DataFrame.sample`` for a reproducible sample.

    Returns:
        ``(scores, query_text, targets, matches)``: per-query hit booleans, query words,
        expected partner words and the retrieved texts.
    """
    sample_data = dataset.best_pairs.sample(n_samples, random_state=random_state)

    queries = sample_data['ipa_a']
    query_text = sample_data['word_a']
    targets = sample_data['word_b']

    # 1. Embed every query (inference only, so no autograd graph is needed).
    with torch.no_grad():
        query_embeddings = model.encode(queries).detach().cpu()
    query_docs = DocumentArray([
        Document(text=text, ipa=ipa, embedding=embedding)
        for ipa, text, embedding in zip(queries, query_text, query_embeddings)
    ])

    # 2. Nearest-neighbour search; `match` fills each doc's `.matches` in place.
    for doc in query_docs:
        doc.match(da, metric='euclidean', limit=n_limit)
    matches = [doc.matches[:, 'text'] for doc in query_docs]

    # 3. A hit means the paired word is among the retrieved texts.
    scores = [target in match for target, match in zip(targets, matches)]
    return scores, query_text, targets, matches


In [None]:
da = load_documents()

# Recall@50 of the original pair partner, as a percentage.
scores, queries, targets, matches = evaluate(da, n_limit=50)
retrieved_pct = sum(scores) / len(scores) * 100
print(f"Percentage of the original pair that have been retrieved: {retrieved_pct} %")

In [None]:
# One line per query: query word, expected partner word, retrieved neighbours.
for query_word, expected_word, retrieved in zip(queries, targets, matches):
    print(query_word, expected_word, retrieved)

## Inference examples

In [None]:
device = torch.device('cuda', 0)  # first CUDA GPU, same as 'cuda:0'

In [None]:
import numpy as np
model.eval().to(device)

# Spot-check the triplet loss and distances on 10 random test triplets (drawn with
# replacement). Inference only, so no autograd graph is built.
with torch.no_grad():
    for i in np.random.randint(0, len(test_set), 10):
        sample = test_set[i]

        anchor_embedding = model.encode([sample['anchor_phonetic']])
        positive_embedding = model.encode([sample['similar_phonetic']])
        negative_embedding = model.encode([sample['distant_phonetic']])

        loss = model.triplet_loss(anchor_embedding, positive_embedding, negative_embedding)

        positive_dist = torch.cdist(anchor_embedding, positive_embedding, p=2)
        negative_dist = torch.cdist(anchor_embedding, negative_embedding, p=2)

        print(f"Loss: {loss}")
        print(f"Positive: {positive_dist}")
        print(f"Negative: {negative_dist}")

In [None]:
from eng_to_ipa import convert
words = ['dog', 'parade', 'cascade', "palace", "table"]
ipas = [convert(x) for x in words]

device = torch.device('cuda:0')
model.eval().to(device)
with torch.no_grad():
    embeddings = model.encode(ipas)

# Euclidean distance from 'parade' (index 1) to 'dog', 'cascade', 'palace' and 'table'.
anchor = embeddings[1].view(1, -1)
a, b, c, d = (
    torch.cdist(anchor, embeddings[j].view(1, -1), p=2) for j in (0, 2, 3, 4)
)

a, b, c, d

In [None]:
device = torch.device('cuda:0')
model.eval().to(device)

sample = test_set[4]
print("Sample: ", sample)

# Inference only: no autograd graph needed.
with torch.no_grad():
    embeddings = model.encode(
        [sample['anchor_phonetic'], sample['similar_phonetic'], sample['distant_phonetic']]
    )

anchor = embeddings[0].view(1, -1)
a = torch.cdist(anchor, embeddings[1].view(1, -1), p=2)
b = torch.cdist(anchor, embeddings[2].view(1, -1), p=2)
print("Good distance: ", a)
print("Bad distance: ", b)

In [None]:
# Distance from the first IPA string to each of the two variants (inference only).
with torch.no_grad():
    embeddings = model.encode(['eɪdriætɪk', 'ɪdriætɪkl', 'edritɪk'])

anchor = embeddings[0].view(1, -1)
a = torch.cdist(anchor, embeddings[1].view(1, -1), p=2)
b = torch.cdist(anchor, embeddings[2].view(1, -1), p=2)
print(a, b)

In [None]:
from eng_to_ipa import convert
words = ['dog', 'parade', 'cascade', "palace", "table"]
# IPA transcription of each word, in the same order.
ipas = list(map(convert, words))
print(ipas)

In [None]:
a, b = "sɪmɪlə", "prɪˈdɒmɪnəntli"  # two IPA strings to compare

In [None]:
from finding_mnemo.pairing.utils.distance import panphon_dtw, levenshtein_distance

In [None]:
(panphon_dtw(a, b), levenshtein_distance(a, b))  # both distance helpers on the same pair

In [None]:
a, b = "sɪmɪ", "lar"  # two short IPA fragments to compare

In [None]:
(panphon_dtw(a, b), levenshtein_distance(a, b))  # both distance helpers on the short pair

## Embedding visualization

In [7]:
N_ROWS_PER_LANGUAGE = 5000  # rows read from each language's CSV
N_POINTS = 1000             # points shown in the UMAP plot

english_df = pd.read_csv(CONFIG.english_dataset, nrows=N_ROWS_PER_LANGUAGE)
mandarin_df = pd.read_csv(CONFIG.mandarin_dataset, nrows=N_ROWS_PER_LANGUAGE)

# Mandarin rows are labelled by pinyin so that both languages share a `word` column.
# reset_index(drop=True) discards the overlapping per-file indices instead of keeping
# them as a stray `index` column; a fixed random_state makes the plotted sample reproducible.
df = (
    pd.concat((
        mandarin_df[['ipa', 'pinyin']].rename(columns={"pinyin": "word"}),
        english_df[['ipa', 'word']],
    ))
    .reset_index(drop=True)
    .sample(N_POINTS, random_state=0)
)

In [8]:
words = df['ipa'].astype(str)
# Inference only: no autograd graph is needed for the embeddings.
with torch.no_grad():
    embeddings = model.encode(words)

mapper = umap.UMAP().fit(embeddings.detach().cpu())

In [9]:
# Colour each point by the length of its word; hovering shows the word and its IPA.
word_lengths = df.word.str.len()
p = umap.plot.interactive(
    mapper,
    labels=word_lengths,
    hover_data=df[["word", "ipa"]],
    point_size=10,
)
umap.plot.show(p)