In [45]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

In [46]:
import pandas as pd

def load_formatted_csv(file_path):
    """
    Loads the formatted CSV file into a DataFrame.
    
    :param file_path: str
        Path to the formatted CSV file.
    :return: pd.DataFrame
        Loaded DataFrame.
    """
    return pd.read_csv(file_path)


In [47]:
def plot_time_series(dataframe):
    """
    Plots the time series data in two ways:
    1. y_true vs y_pred_avg
    2. y_true vs individual predictions (y_pred_1, y_pred_2, ...).
    
    :param dataframe: pd.DataFrame
        The DataFrame loaded from the formatted CSV.
    """
    # Identify input columns and calculate the number of inputs
    input_cols = [col for col in dataframe.columns if '_input' in col]
    num_inputs = len(input_cols)  # Number of input columns

    # Extract timestamps and exclude those for inputs
    timestamp_cols = [col for col in dataframe.columns if '_timestamp' in col]
    y_timestamps = timestamp_cols[num_inputs:]  # Exclude input timestamps

    # Extract true and prediction columns
    true_cols = [col for col in dataframe.columns if '_true' in col]
    pred_cols = [col for col in dataframe.columns if '_pred' in col]

    # Combine all timestamps, true values, and predictions across rows
    all_timestamps = []
    all_y_true = []
    all_y_pred_avg = []
    all_y_preds = {col: [] for col in pred_cols}

    for _, row in dataframe.iterrows():
        # Get timestamps, true values, and predictions for this row
        row_timestamps = pd.to_datetime(row[y_timestamps].values)
        row_y_true = row[true_cols].values
        row_y_pred_avg = row[pred_cols].mean()  # Correctly calculate row average

        # Append to the combined list
        all_timestamps.extend(row_timestamps)
        all_y_true.extend(row_y_true)
        all_y_pred_avg.extend([row_y_pred_avg] * len(row_y_true))
        
        for col in pred_cols:
            all_y_preds[col].extend(row[col] if isinstance(row[col], (list, np.ndarray)) else [row[col]])

    # Ensure all arrays have the same length
    lengths = [len(all_timestamps), len(all_y_true), len(all_y_pred_avg)] + [len(v) for v in all_y_preds.values()]
    if len(set(lengths)) > 1:
        raise ValueError(f"Length mismatch detected: {lengths}")

    # Deduplicate and sort timestamps
    timestamp_df = pd.DataFrame({
        "timestamps": all_timestamps,
        "y_true": all_y_true,
        "y_pred_avg": all_y_pred_avg,
        **{col: all_y_preds[col] for col in pred_cols}
    }).drop_duplicates(subset=["timestamps"]).sort_values(by="timestamps")

    # Extract sorted values
    sorted_timestamps = timestamp_df["timestamps"]
    sorted_y_true = timestamp_df["y_true"]
    sorted_y_pred_avg = timestamp_df["y_pred_avg"]
    sorted_y_preds = {col: timestamp_df[col] for col in pred_cols}

    # Plot 1: y_true vs y_pred_avg
    plt.figure(figsize=(12, 6))
    plt.plot(sorted_timestamps, sorted_y_true, label="y_true", marker='o')
    plt.plot(sorted_timestamps, sorted_y_pred_avg, label="y_pred_avg", marker='x')
    plt.title("y_true vs y_pred_avg")
    plt.xlabel("Timestamps")
    plt.ylabel("Values")
    plt.xticks(rotation=45)
    plt.legend()
    plt.grid(True)
    plt.show()

    # Plot 2: y_true vs individual predictions
    plt.figure(figsize=(12, 6))
    plt.plot(sorted_timestamps, sorted_y_true, label="y_true", marker='o')
    for idx, (col, values) in enumerate(sorted_y_preds.items()):
        plt.plot(sorted_timestamps, values, label=f"{col}", linestyle='--')
    plt.title("y_true vs individual predictions")
    plt.xlabel("Timestamps")
    plt.ylabel("Values")
    plt.xticks(rotation=45)
    plt.legend()
    plt.grid(True)
    plt.show()


In [48]:
# Load the formatted CSV
file_path = 'logs/logs_2025-01-16_21-15-51/inference_results_reformatted.csv'
df = load_formatted_csv(file_path)

# Plot the time series
plot_time_series(df)


  row_timestamps = pd.to_datetime(row[y_timestamps].values)
  row_timestamps = pd.to_datetime(row[y_timestamps].values)
  row_timestamps = pd.to_datetime(row[y_timestamps].values)
  row_timestamps = pd.to_datetime(row[y_timestamps].values)
  row_timestamps = pd.to_datetime(row[y_timestamps].values)
  row_timestamps = pd.to_datetime(row[y_timestamps].values)
  row_timestamps = pd.to_datetime(row[y_timestamps].values)
  row_timestamps = pd.to_datetime(row[y_timestamps].values)
  row_timestamps = pd.to_datetime(row[y_timestamps].values)
  row_timestamps = pd.to_datetime(row[y_timestamps].values)
  row_timestamps = pd.to_datetime(row[y_timestamps].values)
  row_timestamps = pd.to_datetime(row[y_timestamps].values)
  row_timestamps = pd.to_datetime(row[y_timestamps].values)
  row_timestamps = pd.to_datetime(row[y_timestamps].values)
  row_timestamps = pd.to_datetime(row[y_timestamps].values)
  row_timestamps = pd.to_datetime(row[y_timestamps].values)
  row_timestamps = pd.to_datetime(row[y_

ValueError: Length mismatch detected: [384, 384, 384, 64, 64, 64, 64, 64, 64]