In [1]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import os

In [2]:
# check the distribution of the target variable
def print_label_distribution(dictionary):
    """Print the value counts of every labeling column, currency by currency.

    For each key of ``dictionary`` this prints the key, then for each
    labeling column a short heading followed by that column's
    ``value_counts()``, and finally a 30-character dashed separator.

    Parameters
    ----------
    dictionary : dict
        Maps a currency name to a DataFrame that must contain the columns
        ``excess_over_mean``, ``excess_over_median``, ``tbm_label``,
        ``fth_label``, ``tail_sets``, ``trend`` and ``matrix_flag``
        (a missing column raises ``KeyError`` after its heading is printed).
    """
    # (printed heading, DataFrame column) pairs, in output order.
    sections = (
        ("excess over mean", "excess_over_mean"),
        ("excess over median", "excess_over_median"),
        ("triple Barrier", "tbm_label"),
        ("fixed time horizon", "fth_label"),
        ("tail sets", "tail_sets"),
        ("trend scanning", "trend"),
        ("matrix flag", "matrix_flag"),
    )
    separator = "-" * 30

    for name, frame in dictionary.items():
        print(name)
        for heading, column in sections:
            print(heading)
            print(frame[column].value_counts())
        print(separator)

In [3]:
def check_label_distribution(dictionary, methods=None, labels=(-1, 0, 1)):
    """Aggregate label counts over all currencies into a single integer table.

    Parameters
    ----------
    dictionary : dict
        Maps a currency name to a DataFrame holding one column per
        labeling method.
    methods : sequence of str, optional
        Labeling-method columns to count; they become the rows of the
        result. Defaults to ``excess_over_mean``, ``excess_over_median``,
        ``tbm_label``, ``fth_label``, ``tail_sets``, ``trend`` and
        ``matrix_flag``.
    labels : sequence, default (-1, 0, 1)
        Label values to count; they become the columns of the result.
        Values outside this set, and NaN, are ignored.

    Returns
    -------
    pandas.DataFrame
        ``int64`` table indexed by method with one column per label; each
        cell is the number of occurrences of that label summed over every
        currency in ``dictionary``.

    Raises
    ------
    KeyError
        If a currency's DataFrame lacks one of ``methods``.
    """
    if methods is None:
        methods = [
            "excess_over_mean",
            "excess_over_median",
            "tbm_label",
            "fth_label",
            "tail_sets",
            "trend",
            "matrix_flag",
        ]
    methods = list(methods)
    labels = list(labels)

    # Seed the table with real integer zeros. A data-less
    # ``DataFrame(index=..., columns=..., dtype=int)`` is filled with NaN and
    # pandas silently falls back to a non-integer dtype, so counts came out
    # as floats.
    counts = pd.DataFrame(0, index=methods, columns=labels, dtype="int64")

    for currency_df in dictionary.values():
        for method in methods:
            label_counts = currency_df[method].value_counts()
            for label in labels:
                # Membership + .loc keeps the lookup label-based; labels that
                # never occur for this currency simply contribute nothing.
                if label in label_counts.index:
                    counts.at[method, label] += int(label_counts.loc[label])

    return counts

In [4]:
def explore_label_overlap(df, label_columns):
    """Count pairwise agreement between labeling columns of one DataFrame.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain every column named in ``label_columns``.
    label_columns : iterable of str
        Columns to compare; used as both the rows and the columns of the
        result, in the given order.

    Returns
    -------
    pandas.DataFrame
        Square integer matrix where cell ``(a, b)`` is the number of rows in
        which ``df[a] == df[b]``. NaN never equals anything, so rows with a
        missing label in either column are not counted and the diagonal is
        each column's non-NaN count.
    """
    # Materialise once so a one-shot iterable can serve as rows and columns.
    label_columns = list(label_columns)
    # Vectorised Series.sum() rather than the builtin sum(), which walks the
    # boolean Series one element at a time in Python.
    matrix = [
        [int((df[col1] == df[col2]).sum()) for col2 in label_columns]
        for col1 in label_columns
    ]
    return pd.DataFrame(matrix, index=label_columns, columns=label_columns, dtype=int)

def visualize_label_overlap(overlaps_df, currency, save_dir, frequency):
    """Render a label-overlap matrix as an annotated heatmap and save it as PNG.

    The file is written to ``<save_dir>/<currency>_label_overlap_<frequency>.png``
    with a transparent background.

    Parameters
    ----------
    overlaps_df : pandas.DataFrame
        Square matrix of integer counts (``fmt="d"`` requires integers) whose
        index/columns are labeling-column names.
    currency : str
        Used in the plot title and the output file name.
    save_dir : str
        Output directory; created if it does not exist.
    frequency : str
        Tag (presumably the data frequency) appended to the file name so that
        plots for the same currency at different frequencies do not overwrite
        each other.
    """
    # Display names for the labeling columns used in this notebook; any other
    # column name is shown as-is. Deriving the tick labels from the matrix
    # itself keeps them aligned with the data whatever the column order or
    # count (a fixed 7-item list mislabels reordered matrices and fails for
    # other sizes).
    pretty_names = {
        "excess_over_mean": "Excess over Mean",
        "excess_over_median": "Excess over Median",
        "fth_label": "Fixed time horizon",
        "tbm_label": "Triple barrier method",
        "trend": "Trend scanning",
        "tail_sets": "Tail sets",
        "matrix_flag": "Matrix flag",
    }
    x_labels = [pretty_names.get(col, str(col)) for col in overlaps_df.columns]
    y_labels = [pretty_names.get(row, str(row)) for row in overlaps_df.index]

    fig, ax = plt.subplots(figsize=(10, 8), facecolor="white")
    try:
        sns.heatmap(
            overlaps_df,
            annot=True,
            cmap="YlGnBu",
            fmt="d",
            xticklabels=x_labels,
            yticklabels=y_labels,
            ax=ax,
        )
        ax.set_title(f"Label Overlap Correlation Matrix for {currency}")
        ax.set_xlabel("Columns")
        ax.set_ylabel("Columns")
        ax.tick_params(axis="x", labelrotation=45)
        ax.tick_params(axis="y", labelrotation=0)
        fig.tight_layout()  # prevent rotated tick labels from being cropped

        # exist_ok avoids the race between an existence check and creation.
        os.makedirs(save_dir, exist_ok=True)
        out_path = os.path.join(save_dir, f"{currency}_label_overlap_{frequency}.png")
        # transparent=True drops the white facecolor in the saved file.
        fig.savefig(out_path, transparent=True)
    finally:
        # Always release the figure; this runs once per currency in a loop.
        plt.close(fig)

def explore_and_visualize_overlap(dictionary, label_columns, save_dir, frequency):
    """Compute and save a label-overlap heatmap for every currency.

    Parameters
    ----------
    dictionary : dict
        Maps a currency name to its labelled DataFrame.
    label_columns : sequence of str
        Labeling columns compared pairwise by ``explore_label_overlap``.
    save_dir : str
        Passed through to ``visualize_label_overlap`` as the output directory.
    frequency : str
        Passed through unchanged to ``visualize_label_overlap``.
    """
    for name in dictionary:
        overlap_matrix = explore_label_overlap(dictionary[name], label_columns)
        visualize_label_overlap(overlap_matrix, name, save_dir, frequency)