In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import os
%matplotlib widget
import matplotlib.pyplot as plt

from sklearn.manifold import TSNE
from pprint import pprint
from utils.dataset import all_datasets, update_config
from utils.utils import load_features, read_lines

In [None]:
def plot_TSS_TSNE(config, perplexity=40, iters=1000, random_state=None, run=1):
    """Plot a 2-D t-SNE embedding of all accents with query/selection overlays.

    For every accent in ``config["all_accents"]`` the ``all.json`` manifest is
    read and its features are loaded; the stacked features are embedded with
    t-SNE and drawn as a scatter plot coloured by accent. The first
    ``config["target"]`` entries of the target accent's ``seed.json`` (the
    query set) and every entry of the selected ``train.json`` are overlaid as
    separate markers. The figure is saved under ``../../Results/TSS/INDIC/``
    and then shown.

    Parameters
    ----------
    config : dict
        Keys read here: ``all_accents``, ``FULL_DATASET_PATH``, ``dataset``,
        ``feature_type``, ``target_directory_path``, ``target_accent``,
        ``target``, ``budget``, ``fxn``, ``feature``, ``sim``, ``eta``.
    perplexity : float, default 40
        Perplexity passed to ``TSNE``.
    iters : int, default 1000
        Number of t-SNE optimisation iterations.
    random_state : int or None, default None
        Seed passed to ``TSNE``; set it to get a reproducible embedding.
    run : int, default 1
        Run index used for the ``run_<run>`` directory holding ``train.json``.

    Raises
    ------
    ValueError
        If ``load_features`` returns a different number of rows than there
        are manifest lines for an accent (markers would then be misplaced).
    KeyError
        If a query/selection entry's ``audio_filepath`` is not present in any
        of the ``all.json`` manifests.
    """
    features = []
    labels = []
    lines = []
    for accent in config["all_accents"]:
        all_JSON_PATH = os.path.join(config["FULL_DATASET_PATH"], accent, "all.json")
        # Read the manifest once and reuse it for both bookkeeping and
        # feature loading (it used to be read from disk twice).
        accent_lines = read_lines(all_JSON_PATH)
        n_accent_lines = len(accent_lines)
        lines.extend(accent_lines)
        curr_features = load_features(
            accent_lines,
            config["dataset"],
            config["FULL_DATASET_PATH"],
            config["feature_type"],
        )
        # Rows of the embedding are looked up by manifest position below, so
        # features and manifest lines must stay aligned one-to-one.
        if len(curr_features) != n_accent_lines:
            raise ValueError(
                f"{all_JSON_PATH}: got {len(curr_features)} feature rows "
                f"for {n_accent_lines} manifest lines"
            )
        features.append(curr_features)
        labels.extend([accent] * len(curr_features))

    # Maps each audio file to its row in the stacked feature matrix.
    line_indices = {line["audio_filepath"]: ind for ind, line in enumerate(lines)}

    QUERY_JSON_PATH = os.path.join(
        config["FULL_DATASET_PATH"],
        config["target_directory_path"],
        config["target_accent"],
        "seed.json",
    )
    query_list = read_lines(QUERY_JSON_PATH)[: config["target"]]

    SETTING_PATH = os.path.join(
        config["target_directory_path"],
        config["target_accent"],
        "results",
        f"budget_{config['budget']}",
        "global-TSS",
        "target_{}".format(config["target"]),
        "fxn_{}".format(config["fxn"]),
        "feature_{}".format(config["feature"]),
        "sim_{}".format(config["sim"]),
        "eta_{}".format(config["eta"]),
        f"run_{run}",
    )
    SELECTION_JSON_PATH = os.path.join(
        config["FULL_DATASET_PATH"],
        SETTING_PATH,
        "train.json",
    )

    selection_list = read_lines(SELECTION_JSON_PATH)

    features = np.concatenate(features, axis=0)

    feature_cols = [f"dim_{i}" for i in range(features.shape[1])]
    df = pd.DataFrame(features, columns=feature_cols)
    df["label"] = labels

    tsne_kwargs = dict(
        n_components=2,
        verbose=1,
        perplexity=perplexity,
        random_state=random_state,
    )
    try:
        # scikit-learn >= 1.5 names the iteration budget ``max_iter``.
        tsne = TSNE(max_iter=iters, **tsne_kwargs)
    except TypeError:
        # Older scikit-learn releases only accept ``n_iter``.
        tsne = TSNE(n_iter=iters, **tsne_kwargs)
    tsne_results = tsne.fit_transform(df[feature_cols].values)
    df["tsne-2d-one"] = tsne_results[:, 0]
    df["tsne-2d-two"] = tsne_results[:, 1]

    with plt.style.context("default"):
        fig = plt.figure(figsize=(10, 10))
        _ax = fig.add_subplot(1, 1, 1)

        sns.scatterplot(
            x="tsne-2d-one",
            y="tsne-2d-two",
            hue="label",
            palette=sns.color_palette("hls", len(config["all_accents"])),
            data=df,
            legend="full",
            alpha=0.6,
            ax=_ax,
        )

        # Add query_points
        query_inds = [line_indices[line["audio_filepath"]] for line in query_list]
        _ax.scatter(
            [tsne_results[ind, 0] for ind in query_inds],
            [tsne_results[ind, 1] for ind in query_inds],
            label="query",
            marker="*",
            c="darkred",
        )

        # Add selection points
        selection_inds = [line_indices[line["audio_filepath"]] for line in selection_list]
        _ax.scatter(
            [tsne_results[ind, 0] for ind in selection_inds],
            [tsne_results[ind, 1] for ind in selection_inds],
            label="selection",
            marker="+",
            c="darkgreen",
        )

        # Build the legend once, after all artists exist, so the requested
        # location is not overridden by a second legend call.
        _ax.legend(loc="upper right")
        # Title goes on before tight_layout so the layout reserves room for it.
        _ax.set_title("TSNE-{}-{}".format(config["dataset"], config["feature_type"]))
        fig.tight_layout()

        out_path = f"../../Results/TSS/INDIC/{config['target_accent']}" + "-TSNE-{}-{}".format(config["dataset"], config["feature_type"])
        # savefig does not create missing directories.
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        fig.savefig(out_path)

    plt.show()


In [None]:
# Experiment settings for the t-SNE visualisation. The values below select
# which dataset, feature type, target accent mix and selection run
# (budget / fxn / sim / eta) plot_TSS_TSNE reads from disk.
config = dict(
    dataset="INDIC",
    server="SWARA",
    feature_type="MFCC",
    target=20,
    target_accent="manipuri-rajasthani::1-1",
    target_directory_path="mixed",
    budget=500,
    fxn="FL2MI",
    eta=1.0,
    sim="euclidean",
    feature="MFCC",
)

# update_config presumably fills in derived entries such as "all_accents" and
# "FULL_DATASET_PATH", which plot_TSS_TSNE reads but are not set above
# -- TODO confirm against utils.dataset.
config = update_config(config)
# pprint(config)  # uncomment to inspect the expanded configuration

plot_TSS_TSNE(config)