## Imports

In [85]:
# utils
import os
import sys
from pathlib import Path

PROJECT_DIR = Path.cwd().parent
sys.path.append(str(PROJECT_DIR))

import string

# viz
import matplotlib.pyplot as plt
# basics
import numpy as np
import pandas as pd
import seaborn as sns
from tqdm import tqdm

# metrics
from utils import config
from utils.reader import read_file_yaml

np.random.seed(0)

alphabet = string.ascii_lowercase

## Parameters

In [86]:
path_outputs = PROJECT_DIR / "outputs"
path_data = PROJECT_DIR / "data"
file_path_parameters = PROJECT_DIR / "conf" / "parameters.yml"

params = read_file_yaml(file_path_parameters)

ext_type = params["outputs"]["extension_type"]
ext_local_img = params["outputs"]["extension_local_img"]
ext_best_img = params["outputs"]["extension_best_img"]

file_path_data = {
    i: (path_data / i / Path(i + ext_type))
    if f"{i}_pca{ext_type}" not in os.listdir(path_data / i)
    else (path_data / i / Path(i + "_pca" + ext_type))
    for i in config.file_names
}

file_path_distribution = path_outputs / Path("dataset_distribution_labels" + ext_best_img)
file_path_distribution_best = path_outputs / Path("dataset_distribution_labels_row" + ext_best_img)

## Read datasets

In [87]:
data = {}
for i_name, i_file_path in file_path_data.items():
    data[i_name] = {"content": pd.read_csv(i_file_path), "filepath": i_file_path}

## Plot

In [88]:
_params = {
    "x": "0",
    "y": "1",
    "palette": sns.color_palette("dark"),
    "s": 50,
    "legend": False,
}

In [89]:
for i_name, i_data in tqdm(list(data.items())):
    if "no_structure" == i_name:
        tmp_params = _params | {"data": data[i_name]["content"]}
    else:
        tmp_params = _params | {"data": data[i_name]["content"], "hue": "labels"}
    tmp_params["palette"] = sns.color_palette("dark", data[i_name]["content"]["labels"].nunique())
    plt.figure(figsize=(12, 8), **params["outputs"]["args"])
    sns.scatterplot(**tmp_params)
    txt = i_name.replace("_", " ").title()
    plt.title(txt, fontsize=30)
    plt.savefig(path_outputs / Path(i_name + "_behavior" + ext_best_img),
                format=ext_best_img[1:],
                **params["outputs"]["args"])
    plt.close()

  sns.scatterplot(**tmp_params)
100%|██████████| 9/9 [00:00<00:00, 11.16it/s]


In [90]:
n_cols = 2
fig, axs = plt.subplots(int(np.ceil(len(data.items()) / n_cols)), n_cols, figsize=(10, 12), **params["outputs"]["args"])

if axs.shape == (2, ):
    axs = np.matrix(axs)
# Loop
row = 0
col = 0
for letter, (key, iter_data) in tqdm(zip(alphabet, data.items()), total=len(data.items())):
    if key == "no_structure":
        iter_data["content"].labels = 1
    unique_labels = iter_data["content"]["labels"].nunique()
    sns.scatterplot(
        x="0",
        y="1",
        data=iter_data["content"],
        hue="labels",
        ax=axs[row, col],
        palette=sns.color_palette("dark", iter_data["content"]["labels"].nunique()),
        legend=False,
    )
    axs[row, col].set_title(f"({letter}) " + key.replace("_", " ").title(), fontsize=20)
    axs[row, col].set_xlabel("")
    axs[row, col].set_ylabel("")
    axs[row, col].set_xticks([])
    axs[row, col].set_yticks([])
    col += 1
    if col == 2:
        col = 0
        row += 1
plt.tight_layout()
plt.close()

100%|██████████| 9/9 [00:00<00:00, 65.80it/s]


In [91]:
n_cols = 2
fig_row, axs = plt.subplots(n_cols, int(np.ceil(len(data.items()) / n_cols)), figsize=(18, 10), **params["outputs"]["args"])

if axs.shape == (2, ):
    axs = np.matrix(axs)
# Loop
row = 0
col = 0
for letter, (key, iter_data) in tqdm(zip(alphabet, data.items()), total=len(data.items())):
    if key == "no_structure":
        iter_data["content"].labels = 1
    unique_labels = iter_data["content"]["labels"].nunique()
    sns.scatterplot(
        x="0",
        y="1",
        data=iter_data["content"],
        hue="labels",
        ax=axs[col, row],
        palette=sns.color_palette("dark", unique_labels),
        legend=False,
    )
    axs[col, row].set_title(f"({letter}) " + key.replace("_", " ").title(), fontsize=20)
    axs[col, row].set_xlabel("")
    axs[col, row].set_ylabel("")
    axs[col, row].set_xticks([])
    axs[col, row].set_yticks([])
    col += 1
    if col == 2:
        col = 0
        row += 1
if len(data.items()) % 2 == 1:
    axs[-1, -1].axis("off")
fig_row.tight_layout()
plt.close()

100%|██████████| 9/9 [00:00<00:00, 65.35it/s]


## Save 

In [92]:
fig.savefig(
    file_path_distribution,
    format=ext_best_img[1:]
)  # best
fig.savefig(
    str(file_path_distribution).replace(ext_best_img, ext_local_img),
    format=ext_local_img[1:],
    **params["outputs"]["args"]
)  # local

In [95]:
file_path_distribution_best

PosixPath('/home/manuel/projects/repos/aaai-claire-clustering/outputs/dataset_distribution_labels_row.eps')

In [96]:
fig_row.savefig(
    file_path_distribution_best,
    format=ext_best_img[1:]
)  # best
fig_row.savefig(
    str(file_path_distribution_best).replace(ext_best_img, ext_local_img),
    format=ext_local_img[1:],
    **params["outputs"]["args"]
)  # local

'/home/manuel/projects/repos/aaai-claire-clustering/outputs/dataset_distribution_labels_row.png'