In [5]:
import sys
sys.path.append("..")

from pathlib import Path
from glob import glob
import matplotlib.pyplot as plt
import ipywidgets as ipw
import numpy as np 
from tqdm.notebook import tqdm

from inverse_model import InverseModel
from lib.dataset_wrapper import Dataset
from lib import utils
from lib import abx_utils

# Sample-count argument forwarded to abx_utils.get_abx_matrix further down.
ABX_NB_SAMPLES = 200

# Distance configuration handed to abx_utils (signature + ABX matrix).
# One entry per feature name, each carrying a metric name and a weight.
distance = {"art_estimated": {"metric": "cosine", "weight": 1}}

# Directory holding the saved agents, and the sub-directories to evaluate.
model_path = Path("../out/inverse_model")
basenames = ["full_pb"]

# Build the alias -> agent directory mapping consumed by the ABX loop.
agents_alias = {}
for basename in basenames:
    agent_path = model_path / basename
    # NOTE(review): the reloaded agent is not read when building the alias
    # below; presumably kept so config fields can be appended to the alias.
    agent = InverseModel.reload(str(agent_path), load_nn=False)

    # The alias currently consists of the agent path only.
    agent_alias = f"path={agent_path}"
    agents_alias[agent_alias] = agent_path

# Compute ABX
# The distance configuration is the same for every agent, so its signature and
# the cache location are resolved once instead of once per loop iteration.
abx_cache_path = model_path / 'abx_cache.pickle'
distance_signature = abx_utils.get_distance_signature(distance)

# Cache layout: {agent_path: {distance_signature: abx_matrix}}.
agents_abx_matrices = utils.pickle_load(abx_cache_path, {})

# Phone metadata of each agent's own dataset, keyed by agent path, so that the
# report below scores every agent against its own consonants rather than
# against whatever the last loop iteration happened to leave behind.
agents_phones_infos = {}

for agent_path in tqdm(agents_alias.values()):
    agent_abx_matrices = agents_abx_matrices.setdefault(agent_path, {})

    agent = InverseModel.reload(str(agent_path))
    main_dataset = agent.get_main_dataset()
    agent_lab = agent.get_datasplit_lab(0)
    agent_features = agent.predict_datasplit(0)
    agents_phones_infos[agent_path] = main_dataset.phones_infos
    consonants = main_dataset.phones_infos["consonants"]
    vowels = main_dataset.phones_infos["vowels"]
    consonants_indexes = abx_utils.get_datasets_phones_indexes(agent_lab, consonants, vowels)

    # NOTE(review): the matrix is recomputed and the cached entry overwritten on
    # every run; the cache is never used to skip work. Left as is because the
    # signature only encodes `distance`, not ABX_NB_SAMPLES or the seed.
    abx_matrix = abx_utils.get_abx_matrix(consonants, consonants_indexes, agent_features, distance, ABX_NB_SAMPLES, seed=43)
    agent_abx_matrices[distance_signature] = abx_matrix
    # Persist after each agent so completed work survives an interruption.
    utils.pickle_dump(abx_cache_path, agents_abx_matrices)

# Report the per-group and global ABX scores of every agent.
for basename in basenames:
    agent_path = model_path / basename
    agent_abx_matrix = agents_abx_matrices[agent_path][distance_signature]
    agent_phones_infos = agents_phones_infos[agent_path]
    groups_score = abx_utils.get_groups_score(
        agent_phones_infos["consonants"],
        agent_abx_matrix,
        agent_phones_infos["consonant_groups"],
    )
    global_score = abx_utils.get_global_score(agent_abx_matrix)
    print(agent_path, groups_score)
    print(agent_path, global_score)

  0%|          | 0/1 [00:00<?, ?it/s]

../out/inverse_model/full_pb {'manner': 62.391666666666666, 'place': 81.22916666666666}
../out/inverse_model/full_pb 77.98684210526315
