In [1]:
import pandas as pd
import numpy as np
import glob
import os

import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from sklearn.preprocessing import MinMaxScaler

# Clustering results for the Finance sector (the file name indicates Hurst-based
# clustering); the plotting loop reads its "stockID" and "clusterID" columns.
_cluster_results_csv = "../clustering_results/hurst_clustering_result_Finance.csv"
cluster_results = pd.read_csv(_cluster_results_csv)

In [2]:
# Sector sub-folder name used when building data and figure paths.
sector: str = "Finance"

In [3]:
# Inputs/outputs for the current sector. The dates are inclusive bounds that are
# compared as plain strings against each CSV's "Date" column.
# Follow `sector` instead of hard-coding "Finance" so the notebook can be re-pointed.
csv_file_paths = glob.glob(f"../stock_data_model/{sector}/*.csv")
end_date = "2024-06-20"
start_date = "2023-12-20"
stock_fig_dir = "../stock_figure_hurst/"


def get_golden_length(start_date, end_date, sector, data_dir="../stock_data_model/"):
    """Return the shape and dates of a reference ("golden") stock's price window.

    The reference ticker is ``AAME`` for the Finance sector and ``AAPL`` for any
    other sector. Its row count inside the window is used as the target length
    when other stocks are padded.

    Parameters
    ----------
    start_date, end_date : str
        Inclusive window bounds, compared as strings against the ``Date`` column
        (assumes ISO ``YYYY-MM-DD`` strings so lexical order equals date order).
    sector : str
        Sector sub-folder under ``data_dir``.
    data_dir : str, optional
        Root folder holding ``<sector>/<ticker>.csv`` files.

    Returns
    -------
    tuple
        ``(shape, date_range)``: ``shape`` is the shape of the filtered
        one-column ``Close`` frame, and ``date_range`` is a one-column
        DataFrame of the matching ``Date`` values (original index kept).
    """
    reference_ticker = "AAME" if sector == "Finance" else "AAPL"
    data = pd.read_csv(os.path.join(data_dir, sector, f"{reference_ticker}.csv"))

    # Build the window mask once and reuse it for both outputs.
    in_window = (data["Date"] >= start_date) & (data["Date"] <= end_date)
    window = data.loc[in_window]
    return window[["Close"]].shape, window[["Date"]]


# Reference window: shape of the golden stock's Close slice plus its dates.
golden_shape, date_range = get_golden_length(
    start_date=start_date,
    end_date=end_date,
    sector=sector,
)

In [4]:
def standardize_data(data):
    """Min-max scale the ``Close`` column and pair it with ``Date``.

    Despite the name, this is min-max normalization via ``MinMaxScaler``
    (default feature range [0, 1]), not z-score standardization.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain ``Date`` and ``Close`` columns.

    Returns
    -------
    pandas.DataFrame
        Columns ``Date`` and ``Close`` (scaled), indexed like ``data``.
    """
    scaled_close = pd.DataFrame(
        MinMaxScaler().fit_transform(data[["Close"]]),
        columns=["Close"],
        # Keep the caller's index: concat(axis=1) aligns on index, and a fresh
        # RangeIndex would misalign (NaN rows) for any filtered/re-indexed input.
        index=data.index,
    )
    return pd.concat([data["Date"], scaled_close], axis=1)

In [5]:
def read_and_filter_data(csv_file_path, start_date, end_date, golden_shape):
    """Load a stock CSV, scale its Close prices, and slice them to a date window.

    Scaling (``standardize_data``) runs on the whole file before the window is
    cut, so it reflects the full price history. When the window holds fewer rows
    than ``golden_shape[0]``, zeros are prepended until it matches; a longer
    window is returned as-is.

    Parameters
    ----------
    csv_file_path : str
        Path to a CSV with ``Date`` and ``Close`` columns.
    start_date, end_date : str
        Inclusive bounds, compared as strings against ``Date``.
    golden_shape : tuple
        Reference shape; only its first element (target length) is used.

    Returns
    -------
    numpy.ndarray
        1-D array of the (possibly zero-padded) scaled Close values.
    """
    scaled = standardize_data(pd.read_csv(csv_file_path))
    dates = scaled["Date"]
    window = scaled.loc[(dates >= start_date) & (dates <= end_date), ["Close"]]

    shortfall = golden_shape[0] - len(window)
    if shortfall > 0:
        # Too short: left-pad with zeros so every series reaches the reference length.
        print(f"filtered_data.shape = {window.shape}")
        padding = pd.DataFrame({"Close": [0] * shortfall})
        window = pd.concat([padding, window], axis=0, ignore_index=True)
    return window.to_numpy().ravel()

In [25]:
# Quick look at a single stock: load it and min-max scale Close through the shared
# helper, so the result keeps a named "Close" column (the inline version produced
# an unnamed column 0) and reads from the current sector's folder.
stockID = "AAME"
data = standardize_data(pd.read_csv(f"../stock_data_model/{sector}/{stockID}.csv"))

In [6]:
# Distinct cluster labels, in order of first appearance.
pd.unique(cluster_results["clusterID"])

array([2, 1, 0], dtype=int64)

In [9]:
# Plot each stock's scaled Close series over the reference window, one JPEG per
# stock, saved as <stock_fig_dir>/<sector>/<stockID>_<cluster>.jpg.
sector_fig_dir = os.path.join(stock_fig_dir, sector)
os.makedirs(sector_fig_dir, exist_ok=True)  # savefig raises if the folder is missing

# Parse the reference dates once. Plotted as raw strings they form a categorical
# axis, so MonthLocator would tick by row index instead of by calendar month.
plot_dates = pd.to_datetime(date_range["Date"]).to_numpy()

for cluster in cluster_results["clusterID"].unique():
    # Select this cluster's members with one boolean mask instead of re-looking up
    # every stock's clusterID inside a loop over all stocks (quadratic).
    member_ids = cluster_results.loc[
        cluster_results["clusterID"] == cluster, "stockID"
    ]
    for stockID in member_ids:
        close_values = read_and_filter_data(
            f"../stock_data_model/{sector}/{stockID}.csv",
            start_date,
            end_date,
            golden_shape,
        )
        # read_and_filter_data only pads short series; a longer one cannot be
        # plotted against the reference dates, so fail with a clear message.
        if len(close_values) != len(plot_dates):
            raise ValueError(
                f"{stockID}: {len(close_values)} Close values but "
                f"{len(plot_dates)} reference dates"
            )

        fig, ax = plt.subplots(figsize=(10, 6))
        ax.plot(plot_dates, close_values)
        ax.set_xlabel("Date")
        ax.set_ylabel("Standardized Close Value")
        ax.set_title(
            f"Standardized Close Prices for stockID = {stockID} cluster: {cluster}"
        )
        # One labelled tick per calendar month, with an explicit full-date format.
        ax.xaxis.set_major_locator(mdates.MonthLocator(interval=1))
        ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m-%d"))
        ax.tick_params(axis="x", labelrotation=45)
        fig.tight_layout()
        fig.savefig(os.path.join(sector_fig_dir, f"{stockID}_{cluster}.jpg"))
        plt.close(fig)  # one figure per stock; close it to avoid accumulating memory

# TODO: to plot a per-cluster average curve, collect close_values for each
# cluster and combine them with np.mean(..., axis=0).