## Imports

In [None]:
# utils
import os
import sys
from pathlib import Path
PROJECT_DIR = Path.cwd().parent
sys.path.append(str(PROJECT_DIR))

# basics
import numpy as np
import pandas as pd
from tqdm import tqdm

# viz
import matplotlib.pyplot as plt
import seaborn as sns

# metrics
from utils import config
from utils.reader import read_file_yaml
from utils.plot.plot import GeneratePlots
import string

# Lower-case letters used to tag the subplot panels "(a)", "(b)", ...
alphabet = string.ascii_lowercase

# Fix NumPy's global RNG so any random draws made in this notebook are reproducible.
np.random.seed(0)

## Parameters

In [None]:
# Project-relative locations of outputs, input data and configuration.
path_outputs = PROJECT_DIR / "outputs"
path_data = PROJECT_DIR / "data"
file_path_parameters = PROJECT_DIR / "conf" / "parameters.yml"

# Read only once ``file_path_parameters`` exists; reading it earlier raises
# NameError on a fresh kernel.
params = read_file_yaml(file_path_parameters)

# One CSV per dataset, laid out as data/<name>/<name>.csv.
file_path_data = {
    name: path_data / name / f"{name}.csv" for name in config.file_names
}
# Destination of the combined dataset-distribution figure.
file_path_distribution = path_outputs / "dataset_distribution_labels.eps"


## Read datasets

In [None]:
# Load every dataset CSV, keeping its source path alongside the DataFrame.
data = {
    name: {"content": pd.read_csv(csv_path), "filepath": csv_path}
    for name, csv_path in file_path_data.items()
}

## Plot

In [None]:
# Keyword arguments shared by every per-dataset seaborn scatter plot below;
# columns "0" and "1" are used as the x/y coordinates.
_params = dict(
    x="0",
    y="1",
    palette=sns.color_palette("dark"),
    s=50,
    legend=False,
)

In [None]:
# Inspect which keyword arguments are shared by every per-dataset scatter plot.
_params.keys()

In [None]:
# One standalone scatter figure per dataset, saved as <name>_behavior.eps.
for name, entry in data.items():
    # "no_structure" is drawn without a hue; every other dataset is
    # coloured by its "labels" column.
    per_plot = {"data": entry["content"]}
    if name != "no_structure":
        per_plot["hue"] = "labels"

    plt.figure(figsize=(12, 8))
    sns.scatterplot(**(_params | per_plot))
    plt.title(name.replace("_", " ").title(), fontsize=30)
    plt.savefig(path_outputs / Path(name + "_behavior.eps"), format="eps")

In [None]:
# Rows needed to lay out all datasets two per row (ceiling division).
(len(data) + 1) // 2

In [None]:
# Grid of all datasets, n_cols per row, each panel tagged "(a)", "(b)", ...
n_cols = 2
n_rows = int(np.ceil(len(data) / n_cols))
# squeeze=False keeps ``axs`` two-dimensional even with a single row, so
# flattening it works for any dataset count (``axs[row, col]`` would fail
# on the 1-D array matplotlib returns otherwise).
fig, axs = plt.subplots(n_rows, n_cols, figsize=(10, 12), squeeze=False)
flat_axes = axs.ravel()

for ax, letter, (key, iter_data) in zip(flat_axes, alphabet, data.items()):
    content = iter_data["content"]
    if key == "no_structure":
        # Give every point the same label so the shared hue="labels" call works.
        # ``assign`` returns a new frame: ``df.labels = 1`` does not create a
        # missing column (pandas only sets an attribute and warns), and an
        # in-place write would silently modify ``data`` for later cells.
        content = content.assign(labels=1)
    sns.scatterplot(
        x="0",
        y="1",
        data=content,
        hue="labels",
        ax=ax,
        palette=sns.color_palette("dark"),
        legend=False,
    )
    ax.set_title(f"({letter}) " + key.replace("_", " ").title(), fontsize=20)
    ax.set_xlabel("")
    ax.set_ylabel("")
    ax.set_xticks([])
    ax.set_yticks([])

# Hide grid cells left empty when the dataset count is not a multiple of n_cols.
for ax in flat_axes[len(data):]:
    ax.set_visible(False)

fig.tight_layout()
fig.savefig(file_path_distribution, format="eps")