In [2]:
import sys
sys.path.append("..")

from glob import glob
import matplotlib.pyplot as plt
import ipywidgets as ipw
from IPython.display import Audio
import numpy as np 
import pickle
from pathlib import Path

from imitative_agent import ImitativeAgent
from lib.dataset_wrapper import Dataset
from lib.notebooks import show_ema
from external import lpcynet

In [5]:
#agents_path = glob("../out/communicative_vs_imitative/imitative*/")
#agents_path.sort()
agents_path = ['../out/imitative_agent/scaling/bs_8/M0_6000_mn_seed_4']


In [6]:
agents_alias = {}

for agent_path in agents_path:
    agent = ImitativeAgent.reload(agent_path, load_nn=False)
    config = agent.config
        
    agent_i = agent_path[-2]
    agent_alias = " ".join((
        f"{','.join(config['dataset']['names'])}",
        f"synth_art={agent.synthesizer.config['dataset']['art_type']}",
        f"bi={config['model']['inverse_model']['bidirectional']}",
        f"({agent_i})",
    ))
    
    agents_alias[agent_alias] = agent_path

https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


In [15]:
datasets_current_item = {}

def show_agent(agent_alias):
    agent_path = agents_alias[agent_alias]
    agent = ImitativeAgent.reload(agent_path)
    
    sound_type = agent.config["dataset"]["sound_type"]
    art_type = agent.synthesizer.config["dataset"]["art_type"]
    synth_dataset = agent.synthesizer.dataset
    
    def show_dataset(dataset_name, cut_silences=True, audio_only=False):
        dataset = Dataset(dataset_name)
        items_cepstrum = dataset.get_items_data(sound_type, cut_silences=cut_silences)
        items_source = dataset.get_items_data("source", cut_silences=cut_silences)
        sampling_rate = dataset.features_config["wav_sampling_rate"]
        
        items_ema = dataset.get_items_data("ema", cut_silences=cut_silences) if not audio_only else None
        
        if not audio_only:
            items_name = dataset.get_items_list()
        else:
            items_name = dataset.get_items_name(modality='cepstrum')
            
        if dataset_name in datasets_current_item:
            current_item = datasets_current_item[dataset_name]
        else:
            current_item = items_name[0][0] if audio_only else items_name[0]
        
        def resynth_item(item_name=current_item, freeze_source=False):
            datasets_current_item[dataset_name] = item_name
            
            item_cepstrum = items_cepstrum[item_name]
            item_source = items_source[item_name]
            item_wave = dataset.get_item_wave(item_name)
            nb_frames = len(item_cepstrum)
            
            repetition = agent.repeat(item_cepstrum)
            repeated_cepstrum = repetition["sound_repeated"]
            estimated_cepstrum = repetition["sound_estimated"]
            estimated_art = repetition["art_estimated"]
            
            if freeze_source:
                item_source[:] = (1, 0)
            
            repeated_sound = np.concatenate((repeated_cepstrum, item_source), axis=1)
            estimated_sound = np.concatenate((estimated_cepstrum, item_source), axis=1)

            repeated_wave = lpcynet.synthesize_frames(repeated_sound)
            estimated_wave = lpcynet.synthesize_frames(estimated_sound)
            
            print("Original sound:")
            display(Audio(item_wave, rate=sampling_rate))
            print("Repetition (Inverse model → Synthesizer → LPCNet):")
            display(Audio(repeated_wave, rate=sampling_rate))
            print("Estimation (Inverse model → Direct model → LPCNet):")
            display(Audio(estimated_wave, rate=sampling_rate))
            
            plt.figure(figsize=(nb_frames/20, 6), dpi=120)
            
            ax = plt.subplot(311)
            ax.set_title("original %s" % (sound_type))
            ax.imshow(item_cepstrum.T, origin="lower")
            
            ax = plt.subplot(312)
            ax.set_title("Repetition")
            ax.imshow(repeated_cepstrum.T, origin="lower")
            
            ax = plt.subplot(313)
            ax.set_title("Estimation")
            ax.imshow(estimated_cepstrum.T, origin="lower")
            
            plt.tight_layout()
            plt.show()
            
            if art_type == "art_params" and not audio_only:
                estimated_art = dataset.art_to_ema(estimated_art)
            
            if not audio_only:
                item_ema = items_ema[item_name]
                show_ema(estimated_art, reference=item_ema, dataset=synth_dataset)
        
        display(ipw.interactive(resynth_item, item_name=items_name, freeze_source=False))
    display(ipw.interactive(show_dataset, dataset_name=['test'], cut_silences=False, audio_only=True))

display(ipw.interactive(show_agent, agent_alias=sorted(agents_alias.keys())))

interactive(children=(Dropdown(description='agent_alias', options=('M0_6000_mn synth_art=art_params bi=True (_…