In [1]:
import math
from s3prl.nn import S3PRLUpstream
from artprob.model import CubicHermiteSplineInterpolation, \
    LinearInterpolation, NaturalCubicSplineInterpolation, S3PRLWrapper
from artprob.criterion import LinearCriterion
from artprob.dataset import AudioDataset, find_file_paths, Extension, ART_PARAMS
import torch
from pathlib import Path

In [2]:
import matplotlib.pyplot as plt
import numpy as np

# The plotting helper below expects every value in `results` to be a 2-D
# (time, dimension) NumPy array (or CPU tensor), all with the same shape.

# Key of the ground-truth EMA entry in `results`; every other entry is scored
# against it.
EMA_KEY = 'ema'


def plot_arrays(
    results,
    correlations=None,
    to_discard=(),
    utterance=None,
    timings=None,
    phonemes=None,
    fst_index=0,
    lst_index=None,
    titles=None,
    frame_rate=100,
):
    """Print per-dimension error/correlation summaries and plot trajectories.

    Parameters
    ----------
    results : dict[str, array-like]
        Maps a series name to a 2-D array indexed as ``res[frame, dim]``.
        All entries are assumed to share the shape of the first one.
        ``results[EMA_KEY]``, when present, is the reference that every other
        entry is scored against.
    correlations : dict[str, np.ndarray], optional
        Per-dimension correlation values for each series (printed only).
    to_discard : collection of str
        Series names to skip in every printout and plot.
    utterance : str, optional
        Printed, and used (minus its last character) as the figure title.
    timings : iterable of float, optional
        Event times (presumably seconds); a grey vertical line is drawn at
        ``floor(t * frame_rate)`` on every subplot.
    phonemes : sequence of str, optional
        Labels written next to the corresponding ``timings`` lines.
    fst_index, lst_index : int
        Plotted frame window ``[fst_index, lst_index)``; defaults to all frames.
    titles : sequence of str, optional
        One title per subplot; ignored unless its length equals the number of
        dimensions.
    frame_rate : int
        Frames per second used to turn ``timings`` into frame indices.
    """
    nrows, ncols = next(iter(results.values())).shape  # all entries assumed to share this shape
    if lst_index is None:
        lst_index = nrows

    if utterance is not None:
        print(f'Utterance: "{utterance}"')

    if EMA_KEY in results and len(results) > 1:
        reference = np.asarray(results[EMA_KEY])
        print('MSE:')
        # Walk the dict (insertion order) instead of a set so the printout
        # order is identical on every run.
        for name, res in results.items():
            if name == EMA_KEY or name in to_discard:
                continue
            # Per-dimension mean squared error over time (axis 0).
            mse = np.mean((reference - np.asarray(res)) ** 2, axis=0)
            mse_str = ' | '.join(f'{x:.3e}' for x in mse.tolist())
            print(f'\t{name}: {mse_str} -> {mse.mean():.3e}')

    if correlations is not None:
        print('Correlations:')
        for name, corrs in correlations.items():
            if name in to_discard:
                continue
            corrs_str = ' | '.join(f'{x:.3f}' for x in corrs.tolist())
            print(f'\t{name}: {corrs_str} -> {corrs.mean():.3f}')

    # squeeze=False always yields a 2-D grid; take its single column so `axes`
    # is a 1-D sequence of per-dimension axes even when ncols == 1.
    fig, axes = plt.subplots(nrows=ncols, figsize=(10, 2 * ncols), squeeze=False)
    axes = axes[:, 0]

    if utterance is not None:
        # Drops the last character (presumably the transcript's final punctuation).
        fig.suptitle(f'"{utterance[:-1]}"')

    if titles is not None and len(titles) == len(axes):
        for ax, title in zip(axes, titles):
            ax.set_title(title)

    frames = range(fst_index, lst_index)
    for j, ax in enumerate(axes):
        for name, res in results.items():
            if name in to_discard:
                continue
            ax.plot(frames, res[fst_index:lst_index, j], label=name)
    axes[0].legend(loc='upper right', fontsize='small')

    if timings is not None:
        # Convert event times to frame indices. Marks outside the plotted window
        # are skipped: `axvline` would otherwise stretch the x-axis to reach them.
        marks = [math.floor(t * frame_rate) for t in timings]
        for ax in axes:
            for i, frame in enumerate(marks):
                if not fst_index <= frame < lst_index:
                    continue
                ax.axvline(frame, color='grey')
                if phonemes is not None and i < len(phonemes):
                    ax.text(
                        frame,
                        1.,
                        phonemes[i],
                        color='blue',
                        fontsize='x-small',
                        ha='left',
                        va='top',
                        transform=ax.get_xaxis_transform()
                    )

    plt.tight_layout()
    plt.show()

In [3]:
# EMA sample rate (passed to the dataset as `ema_sample_rate` and to every
# interpolation model) and whether silent segments are removed on loading.
sample_rate, no_sil = 100, True

In [4]:
# Articulatory parameter files from the fsew0 training directory, each paired
# with its WAV, TXG and TRANS companion files.
file_paths = find_file_paths(
    Path('/path/to/mocha/timit/fsew0/train/'),
    art_ext=Extension.PARAM,
    other_exts=(
        Extension.WAV,
        Extension.TXG,
        Extension.TRANS,
    ),
)

In [5]:
from dataclasses import dataclass
from typing import Optional

# Module-level cache filled by the first `get_configuration` call: the
# reference transcript, its phoneme labels, its EMA tensor and the phoneme
# timing markers used to annotate plots.
utterance = None
phonemes = None
ema = None
timings = None


@dataclass
class Configuration:
    """An untrained linear probe paired with the batched inputs for one utterance."""

    # Probe template; callers deep-copy it before loading trained weights.
    criterion: LinearCriterion
    # (wav batch, wav lengths, phoneme batch, phoneme lengths) -- all four are
    # tensors, as built by `get_configuration`.
    arguments: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]

def get_configuration(feat_set, lang, unknown_feats=False, hidden_sizes=None, utt_index='095'):
    """Build an untrained linear probe and the batched inputs for one utterance.

    The first call also fills the module-level ``utterance``, ``phonemes``,
    ``ema`` and ``timings`` variables from the selected utterance. Later calls
    leave them untouched, so they always describe the utterance chosen by the
    *first* call. Keep ``utt_index`` identical across calls.

    Parameters
    ----------
    feat_set : str
        Phonological feature set; selects ``phono_features/en_{feat_set}.npz``.
    lang : str
        Language forwarded to ``AudioDataset``.
    unknown_feats : bool
        Forwarded as ``keep_unknown_phon_feats``.
    hidden_sizes : list, optional
        Probe input sizes; defaults to ``[dataset.num_features]``.
    utt_index : str or None
        Utterance id suffix (file stem ``fsew0_{utt_index}``); ``None`` picks
        the first file of the dataset.

    Returns
    -------
    Configuration

    Raises
    ------
    ValueError
        If no file with stem ``fsew0_{utt_index}`` is in the dataset.
    """
    dataset = AudioDataset(
        file_paths,
        phon_feats_path=f'/path/to/this/repo/phono_features/en_{feat_set}.npz',
        ema_sample_rate=sample_rate,
        ignore_phnm_cache=True,
        non_negative_feats=False,
        ema_not_art=False,
        txg_not_phone=True,
        language=lang,
        remove_silence=no_sil,
        biphoneme_mode=False,
        keep_unknown_phon_feats=unknown_feats
    )

    if utt_index is None:
        index = 0
    else:
        target = f'fsew0_{utt_index}'
        stems = [paths[0].stem for paths in dataset.file_paths]
        if target not in stems:
            raise ValueError(
                f'Utterance "{target}" not found among {len(stems)} dataset files'
            )
        index = stems.index(target)

    global utterance, phonemes, ema, timings
    base_path = dataset.file_paths[index][0]

    if utterance is None:
        with open(base_path.with_suffix('.trans'), 'r', encoding='utf-8') as fp:
            utterance = fp.readline().rstrip()
        print(f'The utterance to analyse is "{utterance}"')

    if phonemes is None:
        labels = []
        with open(base_path.with_suffix('.phnm'), 'r', encoding='utf-8') as fp:
            for line in fp:
                if not line.strip():
                    continue  # tolerate blank (e.g. trailing) lines
                # Each line holds three whitespace-separated fields; only the
                # last one (the label) is kept.
                _, _, label = line.split()
                labels.append(label)
        # Drop the first and last entries, mirroring `phn[1:-1]` for timings.
        phonemes = labels[1:-1]

    criterion = LinearCriterion(
        [dataset.num_features] if hidden_sizes is None else hidden_sizes,
        dataset.num_art_dimensions,
        bias=True
    )
    _ema, wav, phn = dataset[index]
    if ema is None:
        ema = _ema
        # Row-wise mean of the first two columns (presumably start/end times).
        timings = phn[1:-1, :2].mean(dim=1).detach().cpu().numpy()
        print('Saved the ema data to the variable "ema".')

    # Add a batch dimension of 1 and pair each sequence with its length.
    arguments = (
        wav.unsqueeze(0),
        torch.as_tensor([wav.size(0)]),
        phn.unsqueeze(0),
        torch.as_tensor([phn.size(0)]),
    )
    return Configuration(criterion=criterion, arguments=arguments)

In [6]:
# One probe template per phonological feature set; every call uses the default
# reference utterance ('095').
phnm_config = get_configuration('phoneme', 'uk', unknown_feats=False)
ap_scalar_config = get_configuration('ap_scalar', 'us', unknown_feats=False)
ap_1hot_config = get_configuration('ap_1hot', 'us', unknown_feats=False)
ipa_bin_config = get_configuration('ipa', 'uk', unknown_feats=False)
ipa_unk_config = get_configuration('ipa', 'uk', unknown_feats=True)

In [7]:
# Ground-truth articulatory trajectories of the reference utterance, annotated
# with the phoneme timing markers.
plot_arrays(
    {'ema': ema.cpu()},
    titles=ART_PARAMS,
    utterance=utterance,
    phonemes=phonemes,
    timings=timings,
)

In [8]:
# Interpolation models applied to phonological feature sequences.
linear = LinearInterpolation(sample_rate, 'mid')

# Cubic Hermite variants that differ only in the third (boolean) argument; the
# `unk_`/`bin_` prefixes suggest unknown-feature vs binary mode -- TODO confirm.
# NOTE(review): `bin_cubic` is not used by `run_linear_probing`.
unk_cubic = CubicHermiteSplineInterpolation(sample_rate, no_sil, True)
bin_cubic = CubicHermiteSplineInterpolation(sample_rate, no_sil, False)

nat_cubic_spline = NaturalCubicSplineInterpolation(sample_rate, 'mid')

# HuBERT-base upstream from S3PRL; its probe is sized from the upstream's
# `hidden_sizes`.
hubert_base = S3PRLWrapper(S3PRLUpstream('hubert_base'), sample_rate, 'sinc')
hubert_config = get_configuration(
    feat_set='ipa',
    lang='uk',
    unknown_feats=True,
    hidden_sizes=hubert_base.model.hidden_sizes,
)

In [9]:
import copy

def probe_model(feat_config: Configuration, model: torch.nn.Module, checkpoint: str):
    """Run a trained linear probe on the reference utterance and keep its best layer.

    Parameters
    ----------
    feat_config : Configuration
        Untrained probe template plus the batched inputs for ``model``.
    model : torch.nn.Module
        Produces the sequence fed to the probe. Temporarily switched to eval
        mode; its previous mode is restored afterwards.
    checkpoint : str
        Path to a checkpoint whose ``'best'`` entry is the probe's state dict.

    Returns
    -------
    probe_seq : np.ndarray
        Predicted EMA trajectories of the layer with the highest mean
        correlation, shape ``(seq_length, ema_dims)``.
    corrs : np.ndarray
        Per-dimension correlations of that layer.
    """
    # Load onto the CPU: results are converted with `.numpy()`, which only
    # works for CPU tensors.
    state_dict = torch.load(checkpoint, map_location='cpu')
    criterion = copy.deepcopy(feat_config.criterion)
    criterion.load_state_dict(state_dict['best'])
    criterion.enable_ema_store()

    # Inference only: disable train-time behaviour such as dropout (nn.Modules
    # start in training mode) and skip autograd bookkeeping.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            interp_seq = model(*feat_config.arguments)
            _ = criterion(interp_seq, ema.unsqueeze(0), None)
    finally:
        model.train(was_training)

    pred_emas = torch.stack(
        [p_emas[0].T for p_emas in criterion._pred_emas],
        dim=0
    )  # num_layers seq_length ema_dims
    corrs = criterion.compute_correlations().T.detach().numpy()  # num_layers ema_dims
    best_idx = int(corrs.mean(axis=1).argmax())

    probe_seq = pred_emas[best_idx].detach().numpy()
    return probe_seq, corrs[best_idx]

In [10]:
def run_linear_probing():
    """Probe every (features, model, checkpoint) combination on the reference utterance.

    Returns
    -------
    results : dict[str, np.ndarray]
        ``'ema'`` holds the ground truth; every other key holds a probe's
        prediction from its best layer.
    correlations : dict[str, np.ndarray]
        Per-dimension correlations of each probe's best layer.
    """
    results = {'ema': ema.cpu().numpy()}
    correlations = {}

    # (series name, feature configuration, model, checkpoint, message or None)
    # NOTE(review): the AP configurations were built with unknown_feats=False
    # but are interpolated with `unk_cubic`, while `bin_cubic` is never used --
    # confirm this matches how those checkpoints were trained.
    probes = (
        ('hubert_base', hubert_config, hubert_base,
         '/path/to/this/repo/checkpoints/hubert-base/checkpoint_85.pt',
         'Added HuBERT-base!'),
        ('linear-phnm', phnm_config, linear,
         '/path/to/this/repo/checkpoints/linear-phnm/checkpoint_99.pt',
         'Added one-hot phoneme!'),
        ('cubic_hermite-ipa-unk', ipa_unk_config, unk_cubic,
         '/path/to/this/repo/checkpoints/cubic-hermite-ipa/checkpoint_99.pt',
         None),
        ('nat_cubic-ipa-unk', ipa_unk_config, nat_cubic_spline,
         '/path/to/this/repo/checkpoints/nat-cubic-spline-ipa/checkpoint_99.pt',
         'Added unknown IPA!'),
        ('cubic_hermite-ap-sca', ap_scalar_config, unk_cubic,
         '/path/to/this/repo/checkpoints/cubic-hermite-ap-sca/checkpoint_99.pt',
         None),
        ('nat_cubic-ap-sca', ap_scalar_config, nat_cubic_spline,
         '/path/to/this/repo/checkpoints/nat-cubic-spline-ap-sca/checkpoint_99.pt',
         'Added scalar AP!'),
        ('cubic_hermite-ap-1hot', ap_1hot_config, unk_cubic,
         '/path/to/this/repo/checkpoints/cubic-hermite-ap-1hot/checkpoint_99.pt',
         None),
        ('nat_cubic-ap-1hot', ap_1hot_config, nat_cubic_spline,
         '/path/to/this/repo/checkpoints/nat-cubic-spline-ap-1hot/checkpoint_99.pt',
         'Added 1-hot AP!'),
    )

    for name, config, model, checkpoint, message in probes:
        results[name], correlations[name] = probe_model(config, model, checkpoint)
        if message is not None:
            print(message)

    return results, correlations

In [11]:
# Predicted trajectories and per-dimension correlations for every probed model.
(results, correlations) = run_linear_probing()

In [12]:
# Every probed model against the ground truth.
plot_arrays(results, correlations, titles=ART_PARAMS, utterance=utterance)

In [13]:
# Cubic Hermite models and HuBERT only (natural cubic splines and the linear
# phoneme baseline hidden).
plot_arrays(
    results,
    correlations,
    utterance=utterance,
    titles=ART_PARAMS,
    to_discard=(
        'linear-phnm',
        'nat_cubic-ipa-unk',
        'nat_cubic-ap-sca',
        'nat_cubic-ap-1hot',
    ),
)

In [14]:
# One-hot AP models (both interpolators) against HuBERT.
plot_arrays(
    results,
    correlations,
    utterance=utterance,
    titles=ART_PARAMS,
    to_discard=(
        'linear-phnm',
        'cubic_hermite-ipa-unk',
        'nat_cubic-ipa-unk',
        'cubic_hermite-ap-sca',
        'nat_cubic-ap-sca',
    ),
)

In [15]:
# Linear one-hot-phoneme baseline vs cubic Hermite IPA (unknown features),
# annotated with phoneme markers. To view the baseline alone, add
# 'cubic_hermite-ipa-unk' to `to_discard`.
plot_arrays(
    results,
    correlations,
    utterance=utterance,
    titles=ART_PARAMS,
    timings=timings,
    phonemes=phonemes,
    to_discard=(
        'hubert_base',
        'nat_cubic-ipa-unk',
        'cubic_hermite-ap-sca',
        'nat_cubic-ap-sca',
        'cubic_hermite-ap-1hot',
        'nat_cubic-ap-1hot',
    ),
)