In [1]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import os

In [2]:
def print_label_distribution(dictionary):
    """
    Print, currency by currency, how often each label value occurs for every labeling method.

    Parameters:
    dictionary (dict): Maps a currency name to a DataFrame holding one column per labeling method
        ('excess_over_mean', 'excess_over_median', 'tbm_label', 'fth_label', 'tail_sets',
        'trend', 'matrix_flag').

    Returns:
    None: Everything is written to standard output.
    """

    # (heading, column) pairs, listed in the order they are reported
    sections = (
        ("Excess over mean:", "excess_over_mean"),
        ("Excess over median:", "excess_over_median"),
        ("Triple Barrier Method:", "tbm_label"),
        ("Fixed Time Horizon:", "fth_label"),
        ("Tail sets:", "tail_sets"),
        ("Trend scanning:", "trend"),
        ("Matrix flag:", "matrix_flag"),
    )
    separator = "-" * 30

    for name in dictionary:
        frame = dictionary[name]
        print(name)

        for heading, column in sections:
            print(heading)
            print(frame[column].value_counts())

        # Visual break between currencies
        print(separator)


In [3]:
def check_label_distribution(dictionary):
    """
    Count how often each label (-1, 0, 1) occurs per labeling method, summed over all currencies.

    Parameters:
    dictionary (dict): A dictionary where keys are currency names and values are DataFrames
        containing one column per labeling method.

    Returns:
    DataFrame: Integer counts with one row per labeling method and one column per label (-1, 0, 1).
        An empty dictionary yields an all-zero table.

    Raises:
    KeyError: If a currency's DataFrame lacks one of the labeling-method columns.
    """

    labels = [-1, 0, 1]
    methods = [
        "excess_over_mean",
        "excess_over_median",
        "tbm_label",
        "fth_label",
        "tail_sets",
        "trend",
        "matrix_flag",
        "next_period",
    ]

    # Accumulate in plain Python ints so the result is integer-typed by construction,
    # rather than starting from an all-NaN frame and relying on fillna() to fix the dtype.
    totals = {method: [0] * len(labels) for method in methods}

    for currency in dictionary:
        frame = dictionary[currency]
        for method in methods:
            # reindex() is purely label-based: labels that never occur are counted as 0
            counts = frame[method].value_counts().reindex(labels, fill_value=0)
            totals[method] = [
                running + int(count)
                for running, count in zip(totals[method], counts)
            ]

    return pd.DataFrame(
        [totals[method] for method in methods], index=methods, columns=labels
    )

In [4]:
def explore_label_overlap(df, label_columns):
    """
    Count, for every pair of label columns, the rows on which both columns hold the same value.

    Parameters:
    df (DataFrame): The DataFrame containing the label columns to be compared.
    label_columns (list): A list of column names in the DataFrame whose label overlap is to be explored.

    Returns:
    DataFrame: A square integer DataFrame indexed and labelled by label_columns, where cell
        (a, b) is the number of rows with df[a] == df[b]. NaN never compares equal, so rows
        with a missing label are not counted, not even on the diagonal.
    """

    # Use the vectorized Series.sum() rather than the builtin sum(), which would walk the
    # boolean Series element by element in Python.
    counts = [
        [int((df[row] == df[col]).sum()) for col in label_columns]
        for row in label_columns
    ]

    return pd.DataFrame(counts, index=label_columns, columns=label_columns, dtype=int)


In [5]:
def visualize_label_overlap(overlaps_df, currency, save_dir, frequency):
    """
    Draw the label-overlap counts for one currency as an annotated heatmap and save it as a PNG.

    Parameters:
    overlaps_df (DataFrame): The DataFrame containing the integer overlap counts between each pair of label columns.
    currency (str): The name of the currency for which the overlap visualization is generated.
    save_dir (str): The directory where the plot image will be saved (created if missing).
    frequency (str): The frequency label to include in the plot's filename.

    Returns:
    None: The plot is written to "<save_dir>/<currency>_label_overlap_<frequency>.png".
    """

    # Human-readable names for the known label columns. Tick labels are looked up from the
    # frame's own column/index names so they always match the plotted data, whatever order
    # (or subset) of label columns the caller used; unknown names are shown as-is.
    display_names = {
        "excess_over_mean": "Excess over Mean",
        "excess_over_median": "Excess over Median",
        "fth_label": "Fixed time horizon",
        "tbm_label": "Triple barrier method",
        "trend": "Trend scanning",
        "tail_sets": "Tail sets",
        "matrix_flag": "Matrix flag",
        "next_period": "Next Period",
    }
    x_labels = [display_names.get(col, str(col)) for col in overlaps_df.columns]
    y_labels = [display_names.get(row, str(row)) for row in overlaps_df.index]

    # Create figure with white background
    fig, ax = plt.subplots(figsize=(10, 8), facecolor='white')
    try:
        # Create a heatmap to visualize the overlap counts
        sns.heatmap(
            overlaps_df,
            annot=True,
            cmap="YlGnBu",
            fmt="d",
            xticklabels=x_labels,
            yticklabels=y_labels,
            ax=ax,
        )
        ax.set_title(f"Label Overlap Correlation Matrix for {currency}")
        ax.set_xlabel("Columns")
        ax.set_ylabel("Columns")

        # Orient the tick labels
        plt.setp(ax.get_xticklabels(), rotation=45)
        plt.setp(ax.get_yticklabels(), rotation=0)

        # Adjust layout to prevent cropping
        fig.tight_layout()

        # Create the save directory if it does not exist (no check-then-create race)
        os.makedirs(save_dir, exist_ok=True)

        # Save the plot with transparent background
        fig.savefig(
            os.path.join(save_dir, f"{currency}_label_overlap_{frequency}.png"),
            transparent=True,
        )
    finally:
        # Always release the figure, even if plotting or saving fails
        plt.close(fig)

In [6]:
def explore_and_visualize_overlap(dictionary, label_columns, save_dir, frequency):
    """
    Compute the pairwise label overlap for every currency and save each result as a heatmap image.

    Parameters:
    dictionary (dict): Maps a currency name to the DataFrame holding that currency's label columns.
    label_columns (list): Names of the label columns to compare with each other.
    save_dir (str): Directory the heatmap images are written to.
    frequency (str): Frequency tag that becomes part of each image's filename.

    Returns:
    None: The only effect is the image files written by visualize_label_overlap.
    """

    for name, frame in dictionary.items():
        # Overlap table first, then hand it straight to the plotting routine
        overlap_table = explore_label_overlap(frame, label_columns)
        visualize_label_overlap(overlap_table, name, save_dir, frequency)

In [7]:
def check_label_distribution(dictionary):
    """
    Count how often each label (-1, 0, 1) occurs per labeling method, summed over all currencies,
    and return the table with human-readable row and column names.

    Parameters:
    dictionary (dict): A dictionary where keys are currency names and values are DataFrames
        containing one column per labeling method.

    Returns:
    DataFrame: Integer counts with one row per labeling method and the columns
        'Short signal' (-1), 'No trade signal' (0) and 'Long signal' (1).
        An empty dictionary yields an all-zero table.

    Raises:
    KeyError: If a currency's DataFrame lacks one of the labeling-method columns.
    """

    # Each source value/column is paired with its display name so the renaming cannot
    # drift out of step with the order in which the counts are collected.
    labels = [
        (-1, 'Short signal'),
        (0, 'No trade signal'),
        (1, 'Long signal'),
    ]
    methods = [
        ("excess_over_mean", 'Excess over Mean'),
        ("excess_over_median", 'Excess over Median'),
        ("tbm_label", 'Triple Barrier'),
        ("fth_label", 'Fixed Time Horizon'),
        ("tail_sets", 'Tail Sets'),
        ("trend", 'Trend Scanning'),
        ("matrix_flag", 'Matrix Flag'),
        ("next_period", 'Next Period Labeling'),
    ]
    label_values = [value for value, _ in labels]

    # Accumulate in plain Python ints so the result is integer-typed by construction,
    # rather than starting from an all-NaN frame and relying on fillna() to fix the dtype.
    totals = {column: [0] * len(labels) for column, _ in methods}

    for currency in dictionary:
        frame = dictionary[currency]
        for column, _ in methods:
            # reindex() is purely label-based: labels that never occur are counted as 0
            counts = frame[column].value_counts().reindex(label_values, fill_value=0)
            totals[column] = [
                running + int(count)
                for running, count in zip(totals[column], counts)
            ]

    return pd.DataFrame(
        [totals[column] for column, _ in methods],
        index=[display for _, display in methods],
        columns=[display for _, display in labels],
    )

In [8]:
def visualize_label_overlap(overlaps_df, currency, save_dir, frequency):
    """
    Draw the label overlap for one currency as a percentage heatmap and save it as a PNG.

    Each row of overlaps_df is divided by that row's maximum and multiplied by 100, so the
    plotted value is the overlap relative to the largest count in the row.

    Parameters:
    overlaps_df (DataFrame): The DataFrame containing the overlap counts between each pair of label columns.
    currency (str): The name of the currency for which the overlap visualization is generated.
    save_dir (str): The directory where the plot image will be saved (created if missing).
    frequency (str): The frequency label to include in the plot's filename.

    Returns:
    None: The plot is written to "<save_dir>/<currency>_label_overlap_<frequency>.png".
    """

    # Human-readable names for the known label columns. Tick labels are looked up from the
    # frame's own column/index names so they always match the plotted data, whatever order
    # (or subset) of label columns the caller used; unknown names are shown as-is.
    display_names = {
        "excess_over_mean": "Excess over Mean",
        "excess_over_median": "Excess over Median",
        "fth_label": "Fixed time horizon",
        "tbm_label": "Triple barrier method",
        "trend": "Trend scanning",
        "tail_sets": "Tail sets",
        "matrix_flag": "Matrix flag",
        "next_period": "Next Period",
    }
    x_labels = [display_names.get(col, str(col)) for col in overlaps_df.columns]
    y_labels = [display_names.get(row, str(row)) for row in overlaps_df.index]

    # Calculate the relative values (percentage of each row's maximum)
    overlaps_df_relative = overlaps_df.div(overlaps_df.max(axis=1), axis=0) * 100

    # Create figure with white background
    fig, ax = plt.subplots(figsize=(10, 8), facecolor='white')
    try:
        # Create a heatmap to visualize the relative overlap counts
        sns.heatmap(
            overlaps_df_relative,
            annot=True,
            cmap="YlGnBu",
            fmt=".2f",
            cbar_kws={'label': 'Percentage'},
            xticklabels=x_labels,
            yticklabels=y_labels,
            ax=ax,
        )
        ax.set_title(f"Label Overlap Correlation Matrix for {currency}")

        # Orient the tick labels
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
        plt.setp(ax.get_yticklabels(), rotation=0)

        # Remove the axis labels
        ax.set_xlabel("")
        ax.set_ylabel("")

        # Adjust layout to prevent cropping
        fig.tight_layout()

        # Create the save directory if it does not exist (no check-then-create race)
        os.makedirs(save_dir, exist_ok=True)

        # Save the plot with transparent background
        fig.savefig(
            os.path.join(save_dir, f"{currency}_label_overlap_{frequency}.png"),
            transparent=True,
        )
    finally:
        # Always release the figure, even if plotting or saving fails
        plt.close(fig)