In [1]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import matplotlib.cm as cm
import os

In [2]:
def plot_trend_labels(df, label_columns, slice_size=30,
                      save_path='visualizations/trend_labeling.png', seed=None):
    """
    Plot a random window of prices with each trend labeling and save it as an image.

    One subplot is drawn per label column. Prices go on the left y-axis and the
    labels on a twin right y-axis fixed to the -1/0/1 range. The figure is
    written to ``save_path`` and closed; nothing is displayed or returned.

    Parameters:
    df (pd.DataFrame): DataFrame containing 'Price' column and trend label columns.
    label_columns (list): List of column names containing trend labels.
        A single column name given as a string is also accepted.
    slice_size (int): Number of consecutive rows to plot. Defaults to 30.
    save_path (str): Destination of the image file. Missing parent
        directories are created.
    seed (int, optional): Seed for picking the window. When None, NumPy's
        global random state is used, so ``np.random.seed`` still applies.

    Raises:
    ValueError: If label_columns is empty, slice_size is not positive, or the
        DataFrame has fewer than ``slice_size + 3`` rows.
    KeyError: If 'Price' or any label column is missing from the DataFrame.
    """
    # --- Validate everything before a figure exists, so failures leak nothing ---
    if isinstance(label_columns, str):
        label_columns = [label_columns]
    label_columns = list(label_columns)
    if not label_columns:
        raise ValueError("label_columns must contain at least one column name.")
    if slice_size < 1:
        raise ValueError("slice_size must be a positive integer.")

    price_col = 'Price'
    missing = [col for col in [price_col, *label_columns] if col not in df.columns]
    if missing:
        raise KeyError(f"Columns not found in DataFrame: {missing}")

    # The window start is drawn from [1, len(df) - slice_size - 1). That range
    # is only non-empty when the frame has at least slice_size + 3 rows, so the
    # guard has to be stricter than a plain `len(df) <= slice_size` check.
    start_upper = len(df) - slice_size - 1
    if start_upper <= 1:
        raise ValueError("DataFrame is too small for the specified slice size.")

    # Randomly select a slice that is not at the beginning or the end
    if seed is None:
        start_idx = np.random.randint(1, start_upper)
    else:
        start_idx = int(np.random.default_rng(seed).integers(1, start_upper))
    df_slice = df.iloc[start_idx:start_idx + slice_size]

    time = df_slice.index.to_numpy()
    price = df_slice[price_col].to_numpy()

    # Price range is the same for every subplot; NaN gaps are ignored.
    min_price = np.nanmin(price)
    max_price = np.nanmax(price)
    y_limits = (min_price - 0.5, max_price + 0.5)

    # Generate a list of distinct colors.
    # plt.get_cmap replaces cm.get_cmap, which newer Matplotlib releases removed.
    n_panels = len(label_columns)
    color_map = plt.get_cmap('tab20', n_panels)
    colors = [color_map(i) for i in range(n_panels)]

    # squeeze=False always yields a 2-D array of axes, even for a single panel
    fig, axes = plt.subplots(n_panels, 1, figsize=(12, 4 * n_panels),
                             sharex=True, squeeze=False)
    try:
        for ax, label_col, color in zip(axes[:, 0], label_columns, colors):
            labels = df_slice[label_col].to_numpy()

            ax.plot(time, price, label='Asset price', color='black', linewidth=2)
            ax.scatter(time, price, color='black', s=50, zorder=5)  # Big dots for price points
            ax.set_title(f'{label_col} Labeling', fontsize=16)
            ax.set_ylabel('Price', fontsize=14)
            ax.set_ylim(y_limits)
            # Primary y-axis shows only the lowest and highest price of the window
            ax.set_yticks([min_price, max_price])

            # Trend labels (-1, 0, 1) live on their own right-hand axis. Drawing
            # them on the price axis would push them outside its y-limits.
            secax = ax.twinx()
            secax.plot(time, labels, color=color, linewidth=2, zorder=4)
            secax.scatter(time, labels, color=color, label='Trend labels', zorder=5, s=100)
            secax.set_ylim(-1.5, 1.5)
            secax.set_yticks([-1, 0, 1])
            secax.set_ylabel('Labels', fontsize=14)

            # Single legend covering artists from both y-axes
            price_handles, price_names = ax.get_legend_handles_labels()
            label_handles, label_names = secax.get_legend_handles_labels()
            secax.legend(price_handles + label_handles, price_names + label_names)

        # Common X label: set on the bottom price axes explicitly. The pyplot
        # "current axes" is the last twin axes, whose x-axis is hidden.
        bottom_ax = axes[-1, 0]
        bottom_ax.set_xlabel('Number of Periods', fontsize=14)
        bottom_ax.tick_params(axis='x', labelrotation=45)
        fig.tight_layout(pad=2.0)

        # Create directory if it does not exist (a bare filename has none)
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        # Save the plot
        fig.savefig(save_path, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving failed
        plt.close(fig)