In [None]:
import os
import sys

# Make the project root (one level up) importable, e.g. for `src.agent...`.
# Only append when missing: re-running this cell would otherwise add a
# duplicate entry to sys.path on every execution.
_PROJECT_ROOT = os.path.abspath("..")
if _PROJECT_ROOT not in sys.path:
    sys.path.append(_PROJECT_ROOT)

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

In [None]:
import base64
import io
from typing import Dict, Optional

import pandas as pd
from smolagents import tools

from src.agent.adapters.tools.base import BaseTool

In [None]:
## monkey patching
# Widen the set of type names smolagents accepts (`tools.AUTHORIZED_TYPES`),
# presumably so PlotData's "dataframe" input and "dict" output pass the
# library's type validation.
# Merge with the library's own list instead of replacing it with a hand-copied
# one, so any types defined by the installed smolagents version are kept.
# dict.fromkeys de-duplicates while preserving order, which keeps the cell
# idempotent when it is re-run.
_NOTEBOOK_TOOL_TYPES = [
    "string",
    "boolean",
    "integer",
    "number",
    "image",
    "audio",
    "array",
    "object",
    "any",
    "null",
    "list",
    "dict",
    "dataframe",
]
tools.AUTHORIZED_TYPES = list(
    dict.fromkeys([*tools.AUTHORIZED_TYPES, *_NOTEBOOK_TOOL_TYPES])
)

In [None]:
class PlotData(BaseTool):
    """Plot every column of a time-indexed DataFrame and return it as base64 PNG.

    Each column is drawn as one dashed line with circle markers. The x-axis
    labels are formatted according to the index frequency detected by
    ``simplify_time_index``.
    """

    name = "plot_data"
    description = """Plot data from data."""
    inputs = {
        "data": {"type": "dataframe", "description": "asset id data"},
    }
    # NOTE(review): "str" is not one of the smolagents type names patched in
    # this notebook ("string" is); confirm how BaseTool consumes `outputs`.
    outputs = {"plot": {"type": "str", "description": "encoded plot"}}
    output_type = "dict"

    # Aim for roughly this many x-axis tick labels regardless of series length.
    _TARGET_XTICKS = 20
    # strftime patterns for tick labels, keyed by the frequency string that
    # simplify_time_index returns; any other value uses the default pattern.
    _XTICK_FORMATS = {"D": "%Y-%m-%d", "h": "%Y-%m-%d %H:%M"}
    _DEFAULT_XTICK_FORMAT = "%Y-%m-%d %H:%M:%S"

    def forward(self, data: pd.DataFrame) -> Dict[str, Optional[str]]:
        """Render ``data`` as a time-series plot.

        Args:
            data: DataFrame whose index can be converted with ``pd.to_datetime``.
                Each column is plotted as a separate series. The caller's frame
                is not modified (a copy is taken).

        Returns:
            ``{"plot": <base64-encoded PNG str>}``, or ``{"plot": None}`` when
            ``data`` is empty.

        Errors raised by ``pd.to_datetime`` for an unparsable index propagate.
        """
        if data.empty:
            return {"plot": None}

        data, freq = self.simplify_time_index(data.copy())

        fig, ax = plt.subplots(figsize=(12, 6))
        try:
            self._draw(ax, data, freq)
            fig.tight_layout()
            return {"plot": self._figure_to_base64(fig)}
        finally:
            # Always release the figure, even when drawing/encoding raises,
            # so repeated calls do not accumulate open matplotlib figures.
            plt.close(fig)

    def _draw(self, ax, data: pd.DataFrame, freq: Optional[str]) -> None:
        """Plot each column of ``data`` on ``ax`` and label the axes."""
        for column_name in data.columns:
            ax.plot(
                data.index,
                data[column_name],
                label=column_name,
                marker="o",
                linestyle="--",
            )

        # Thin the ticks so long series do not produce unreadable labels.
        step = max(1, round(len(data.index) / self._TARGET_XTICKS))
        xticks = data.index[::step]
        label_format = self._XTICK_FORMATS.get(freq, self._DEFAULT_XTICK_FORMAT)

        ax.set_xticks(xticks)
        ax.set_xticklabels(xticks.strftime(label_format), rotation=45, ha="right")

        ax.set_xlabel("Date")
        ax.set_ylabel("Value")
        ax.set_title("Time Series Plot")
        ax.grid(True)
        ax.legend(title="Series Name")

    @staticmethod
    def _figure_to_base64(fig) -> str:
        """Serialize ``fig`` as PNG and return the bytes base64-encoded as text."""
        with io.BytesIO() as buf:
            # bbox_inches="tight" trims extra whitespace around the plot.
            fig.savefig(buf, format="png", bbox_inches="tight")
            return base64.b64encode(buf.getvalue()).decode("utf-8")

    def simplify_time_index(self, data):
        """Convert the index to datetimes and snap it to a daily/hourly grid.

        ``data`` is modified in place (its index is reassigned) and also
        returned.

        Returns:
            tuple: (pd.DataFrame, str or None)
                The modified DataFrame and the detected frequency: "D" (index
                normalized to midnight), "h" (index floored to the hour), or
                None when neither applies (index left unchanged after
                conversion).
        """
        data.index = pd.to_datetime(data.index)
        index = data.index

        # pd.infer_freq raises ValueError for fewer than 3 timestamps; treat
        # short series as "unknown" and let the timestamp checks below decide.
        freq = pd.infer_freq(index) if len(index) >= 3 else None
        # Strip a numeric multiplier ("2h", "3H", "2D") so multiples of a unit
        # are recognized as that unit; upper() covers "H" vs. "h" spellings.
        unit = freq.lstrip("0123456789").upper() if freq else None

        if unit == "D":
            data.index = index.normalize()
            return data, "D"
        if unit == "H":
            data.index = index.floor("h")
            return data, "h"

        # Fallback when no (daily/hourly) frequency could be inferred: inspect
        # the timestamps themselves.
        on_the_hour = (index.minute == 0).all() and (index.second == 0).all()
        if on_the_hour and (index.hour == 0).all():
            data.index = index.normalize()
            return data, "D"
        if on_the_hour:
            data.index = index.floor("h")
            return data, "h"

        return data, None

In [None]:
# Constructor settings for PlotData; presumably consumed by BaseTool (not shown here).
kwargs = dict(
    tools_api_base="http://localhost:5050",
    tools_api_limit="100",
)

In [None]:
# Instantiate the tool under test with the settings from the previous cell.
tool = PlotData(**kwargs)

In [None]:
# Sample input data. The path can be overridden with the DAILY_DATA_CSV
# environment variable so the notebook is not tied to one machine's layout.
DAILY_DATA_CSV = os.environ.get("DAILY_DATA_CSV", "/Users/steffen/dev/daily_data.csv")
# Index by the "timestamp" column; PlotData converts the index to datetimes itself.
temp = pd.read_csv(DAILY_DATA_CSV).set_index("timestamp")

In [None]:
# Call forward() directly. It returns {"plot": <base64 PNG str>}, or
# {"plot": None} when the input frame is empty.
out = tool.forward(temp)

In [None]:
# `forward` returns {"plot": None} for an empty DataFrame; fail here with a
# clear message rather than with a TypeError inside base64.b64decode later.
base64_string_to_decode = out["plot"]
assert base64_string_to_decode is not None, "PlotData returned no plot (input DataFrame was empty)"

In [None]:
import matplotlib.image as mpimg

In [None]:
# Round-trip check: decode the base64 PNG produced by the tool and display it.

# 1. Decode the base64 text back into raw PNG bytes.
image_bytes = base64.b64decode(base64_string_to_decode)

# 2. Load the PNG from an in-memory buffer (mpimg.imread accepts file-like
#    objects; the format must be given explicitly for a buffer).
with io.BytesIO(image_bytes) as image_buffer:
    img_data = mpimg.imread(image_buffer, format="png")

# 3. Show the image. Hide the outer axes, because the plot's own axes are
#    already part of the image.
fig, ax = plt.subplots(figsize=(8, 6))
ax.imshow(img_data)
ax.axis("off")
plt.show()