In [1]:
from functools import partial
import logging
import ipywidgets as wd
import polars as pl
from IPython.display import display, HTML, clear_output
import bokeh.plotting as bp
from bokeh.plotting import figure, show
from bokeh.io import output_notebook
from bokeh.layouts import column
from bokeh.models import ranges, LinearAxis, HoverTool, CrosshairTool, Slider, CustomJS, Span
from src.power_curve import get_power_at_wind_velocity, read_power_curve
from src.turbines import TURBINES
from src.impingement import calculate_impingement
import matplotlib.pyplot as plt

import math

# Configure Bokeh so that later show(...) calls render inline in this notebook.
output_notebook()




class App:
    """Notebook dashboard for exploring wind-turbine revenue.

    ``self.grid`` is a 1x2 ``GridspecLayout``:

    * left column: a turbine dropdown (options are the keys of ``TURBINES``),
      the selected turbine's parameters, a tip-speed/power plot with a
      slider, and an output area for the impingement plot;
    * right column: electricity-price and revenue plots.

    Expected call order::

        app = App()
        app.read_price_data(price_csv)
        app.read_weather_data(weather_csv)
        app._join_data()
        display(app.grid)
        app.refresh_figures()
    """

    # Set by read_price_data, read_weather_data and _join_data respectively.
    price_data: pl.LazyFrame
    weather_data: pl.LazyFrame
    _combined_data: pl.LazyFrame
    logger = logging.getLogger("App")
    logger.setLevel(logging.INFO)

    # Toolbar entries shared by every Bokeh figure below; they are appended
    # after any figure-specific tools (hover, crosshair).
    _PLOT_TOOLS = ("pan", "wheel_zoom", "box_zoom", "reset", "save")

    def __init__(self):
        """Build the widget layout and wire the turbine dropdown callbacks."""
        self.plot_outputs = wd.Output()
        self.impingement_plot_output = wd.Output()
        turbine_params = wd.Output()
        power_curves = wd.Output()

        turbine_dropdown = wd.Dropdown(
            description="Select wind turbine: ",
            options=tuple(TURBINES.keys()),
            value=None,
            style={"description_width": "initial"},
        )
        # Each observer receives the traitlets change dict; partial() binds
        # the Output widget it renders into.
        turbine_dropdown.observe(
            partial(self._set_turbine_params, output=turbine_params), names="value"
        )
        turbine_dropdown.observe(
            partial(self._show_power_curves, output=power_curves), names="value"
        )

        grid = wd.GridspecLayout(1, 2)
        grid[0, 0] = wd.VBox(
            (
                wd.Label("Wind Turbine"),
                turbine_dropdown,
                turbine_params,
                power_curves,
                self.impingement_plot_output,
            )
        )
        grid[0, 1] = wd.VBox((self.plot_outputs,))
        self.grid = grid

    @staticmethod
    def _set_turbine_params(change: dict, output: wd.Output):
        """Render the selected turbine's parameters as label rows into *output*.

        Args:
            change: traitlets change dict from the dropdown; ``change["new"]``
                is the selected ``TURBINES`` key, or ``None`` when cleared.
            output: widget the parameter table is rendered into.
        """
        turbine_name = change["new"]
        if turbine_name is None:
            # A cleared dropdown has nothing to show (TURBINES[None] would fail).
            with output:
                clear_output()
            return

        # "power_curve" is plotted separately rather than listed as text.
        rows = [
            wd.HBox(
                (
                    wd.Label(f"{key}:", layout={"width": "150px"}),
                    wd.Label(f"{value}"),
                )
            )
            for key, value in TURBINES[turbine_name].items()
            if key != "power_curve"
        ]
        with output:
            clear_output(wait=True)
            display(wd.VBox(rows))

    def _read_impingement_data(
        self,
        turbine_name: str = "IEA 15 240",
        windfarm: str = "nordsen iii vest",
        slider=1,
    ):
        """Run ``calculate_impingement`` and plot its loss vector over time.

        The defaults reproduce the previously hard-coded call.

        Args:
            turbine_name: key into ``TURBINES``.
            windfarm: forwarded to ``calculate_impingement``.
            slider: forwarded to ``calculate_impingement`` (meaning defined there).
        """
        # TODO: pass the turbine selected in the dropdown instead of the default.
        impingement_raw, _, _, lossvector = calculate_impingement(
            turbine=TURBINES[turbine_name], windfarm=windfarm, slider=slider
        )
        with self.impingement_plot_output:
            clear_output(wait=True)
            # Create the figure inside the Output context. Backends that may
            # display figures as soon as they are created then also render here.
            _, ax = plt.subplots(figsize=(10, 6))
            ax.plot(impingement_raw["timestamp"], lossvector)
            ax.set_xlabel("Timestamp")
            # NOTE(review): the plotted series is named `lossvector`; confirm it
            # really is an efficiency in percent, as the label states.
            ax.set_ylabel("Turbine Efficiency [%]")
            ax.set_title("Turbine Efficiency Over Time")
            plt.show()

    def read_price_data(self, path: str, end_year: int = 2020):
        """Load the price history CSV into ``self.price_data`` (lazy).

        Expected header:
        Country,ISO3 Code,Datetime (UTC),Datetime (Local),Price (EUR/MWhe)

        Keeps local time as ``time`` and price as ``price``. Drops rows whose
        year is ``end_year`` or later, then sorts by ``time``.
        """
        schema = {
            "Country": pl.Categorical,
            "ISO3 Code": pl.Categorical,
            "Datetime (UTC)": pl.Datetime("ms"),
            "Datetime (Local)": pl.Datetime("ms"),
            "Price (EUR/MWhe)": pl.Float32,
        }
        # NOTE(review): local time is joined to the weather "timestamp" by exact
        # equality in _join_data; confirm both use the same time zone.
        self.price_data = (
            pl.read_csv(path, schema=schema)
            .lazy()
            .select(
                [
                    pl.col("Datetime (Local)").alias("time"),
                    pl.col("Price (EUR/MWhe)").alias("price"),
                ]
            )
            .filter(pl.col("time").dt.year() < end_year)
            # Actually sort: set_sorted() only asserts sortedness, and the
            # time-based rolling means downstream need sorted input.
            .sort("time")
        )

    @property
    def combined_data(self):
        """Collect and return the joined data (re-runs the lazy query each access)."""
        return self._combined_data.collect()

    def read_weather_data(self, path: str):
        """Load the weather CSV into ``self.weather_data`` (lazy).

        Keeps ``timestamp`` as ``time`` and ``wsp_150.0`` as ``wind_speed``,
        then sorts by ``time``.
        """
        schema = {
            "timestamp": pl.Datetime("ms"),
            "rainc": pl.Float32,
            "qrain_120.0": pl.Float32,
            "rho_120.0": pl.Float32,
            "wsp_120.0": pl.Float32,
            "qrain_150.0": pl.Float32,
            "rho_150.0": pl.Float32,
            "wsp_150.0": pl.Float32,
        }
        self.weather_data = (
            pl.read_csv(path, schema=schema)
            .lazy()
            .select(
                [
                    pl.col("timestamp").alias("time"),
                    pl.col("wsp_150.0").alias("wind_speed"),
                ]
            )
            # Sort instead of set_sorted(): see read_price_data.
            .sort("time")
        )

    def _join_data(self):
        """Inner-join weather and price data on ``time`` and derive new columns.

        Stores a lazy query in ``self._combined_data`` with these added columns:

        * ``power``: ``get_power_at_wind_velocity(wind_speed) * 1e-6``;
        * ``income``: ``power * price``;
        * ``income_sma`` and ``price_sma``: one-month rolling means over ``time``.
        """
        self._combined_data = (
            self.weather_data.join(self.price_data, on="time", how="inner")
            # Join output order may not be guaranteed (depends on the polars
            # version); rolling_mean(by="time") requires sorted data.
            .sort("time")
            .with_columns(
                power=pl.col("wind_speed").map_elements(
                    lambda value: get_power_at_wind_velocity(value) * 1e-6,
                    # Declare the dtype so the lazy schema need not be inferred
                    # from the Python UDF.
                    return_dtype=pl.Float64,
                )
            )
            .with_columns(income=pl.col("power") * pl.col("price"))
            .with_columns(
                income_sma=pl.col("income").rolling_mean(window_size="1mo", by="time"),
                price_sma=pl.col("price").rolling_mean(window_size="1mo", by="time"),
            )
        )

    def _create_timeseries_plot(
        self, data, value_column, sma_column, value_label, title, y_axis_label
    ):
        """Build a datetime line plot of *value_column* plus its rolling mean.

        The raw series is drawn as a thin line and carries the vline hover
        tooltip. *sma_column* is overlaid as a thick red line.
        """
        source = bp.ColumnDataSource(
            data={
                "time": data.get_column("time"),
                value_column: data.get_column(value_column),
                sma_column: data.get_column(sma_column),
            }
        )
        hover_tool = HoverTool(
            tooltips=[
                ("Time", "@time{%F %T}"),
                (value_label, f"@{value_column}"),
                (f"{value_label} (rolling mean)", f"@{sma_column}"),
            ],
            formatters={"@time": "datetime"},
            mode="vline",
        )
        fig = figure(
            title=title,
            x_axis_type="datetime",
            y_axis_label=y_axis_label,
            width=600,
            height=300,
            tools=[hover_tool, CrosshairTool(), *self._PLOT_TOOLS],
        )
        # Attach the hover only to the raw line so that tooltips appear once.
        hover_tool.renderers = [fig.line("time", value_column, source=source)]
        fig.line("time", sma_column, source=source, color="red", line_width=3)
        return fig

    def _create_price_plot(self, data=None):
        """Return the price figure; *data* defaults to ``self.combined_data``."""
        return self._create_timeseries_plot(
            self.combined_data if data is None else data,
            value_column="price",
            sma_column="price_sma",
            value_label="Price",
            title="Electricity price (€/MWh), 1 hour average",
            y_axis_label="Price (€/MWh)",
        )

    def _create_income_plot(self, data=None):
        """Return the revenue figure; *data* defaults to ``self.combined_data``."""
        return self._create_timeseries_plot(
            self.combined_data if data is None else data,
            value_column="income",
            sma_column="income_sma",
            value_label="Income",
            title="Revenue (€), 1 hour average",
            y_axis_label="Revenue (€)",
        )

    def create_figures(self):
        """Return the price and revenue figures stacked in a Bokeh column.

        The combined query is collected once and shared by both figures. This
        avoids running the per-row power UDF twice.
        """
        data = self.combined_data
        return column(self._create_price_plot(data), self._create_income_plot(data))

    def refresh_figures(self):
        """(Re)render the price/revenue figures and the impingement plot."""
        with self.plot_outputs:
            # Replace rather than append on repeated refreshes.
            clear_output(wait=True)
            show(self.create_figures())
        self._read_impingement_data()

    def _create_wind_speed_curve_plots(self, source):
        """Return the wind speed-power and wind speed-rotor speed figures."""
        power_fig = figure(
            title="Wind speed-Power curve",
            y_axis_label="Power (MW)",
            width=600,
            height=200,
            tools=[CrosshairTool(), *self._PLOT_TOOLS],
        )
        power_fig.line("wind_speed", "power", source=source)

        rotor_fig = figure(
            title="Wind speed-Rotor speed curve",
            y_axis_label="Rotor speed (1/min)",
            x_axis_label="Wind speed (m/s)",
            width=600,
            height=200,
            tools=[CrosshairTool(), *self._PLOT_TOOLS],
        )
        rotor_fig.line("wind_speed", "rotor_speed", source=source)
        return power_fig, rotor_fig

    def _create_power_curve_plot(self, change: dict, include_wind_speed_curves: bool = False):
        """Build the tip-speed/power plot with a tip-speed slider.

        The curve is ``TURBINES[name]["power_curve"]``, restricted to rows with
        positive power and positive tip speed. Tip speed is
        ``rotor_speed * pi / 30 * radius``: rpm to rad/s, times the turbine's
        ``radius``.

        Moving the slider relocates a red marker to the curve point whose tip
        speed is nearest the slider value. It also moves a vertical red line to
        that point's power. The line is stored as ``self.power_cap``.

        Args:
            change: traitlets change dict; ``change["new"]`` is a ``TURBINES`` key.
            include_wind_speed_curves: also append the wind speed-power and
                wind speed-rotor speed figures below the slider plot.

        Returns:
            A Bokeh ``column`` layout.

        Raises:
            ValueError: if no curve point has positive power and tip speed.
        """
        turbine_name = change["new"]
        turbine = TURBINES[turbine_name]
        curve = (
            pl.from_dict(turbine["power_curve"])
            .lazy()
            .filter(pl.col("power") > 0)
            .with_columns(
                tip_speed=pl.col("rotor_speed").mul(math.pi / 30).mul(turbine["radius"])
            )
            .filter(pl.col("tip_speed") > 0)
            .collect()
        )
        if curve.height == 0:
            raise ValueError(
                f"Power curve of {turbine_name!r} has no points with positive "
                "power and tip speed"
            )

        data = curve.to_dict(as_series=False)
        source = bp.ColumnDataSource(data=data)
        tip_speeds = data["tip_speed"]
        min_tip_speed = min(tip_speeds)

        # Start the marker and the line on the point the slider initially
        # selects. This is the first point with the minimum tip speed, the same
        # point the JS nearest-point search picks.
        start_power = data["power"][tip_speeds.index(min_tip_speed)]
        point_source = bp.ColumnDataSource(data=dict(x=[start_power], y=[min_tip_speed]))

        fig = figure(
            title="Tip speed-Power curve",
            y_axis_label="Tip speed (m/s)",
            x_axis_label="Power (MW)",
            width=600,
            height=200,
            tools=["crosshair", *self._PLOT_TOOLS],
        )
        fig.line("power", "tip_speed", source=source)
        # scatter() draws circles by default; circle(size=...) is deprecated.
        fig.scatter("x", "y", source=point_source, size=10, color="red")

        slider = Slider(
            start=min_tip_speed,
            end=max(tip_speeds),
            value=min_tip_speed,
            step=1,
            title="Tip speed (m/s)",
        )

        # dimension="height" makes this a vertical line, placed at an x (power) value.
        power_cap = Span(
            location=start_power, dimension="height", line_color="red", line_width=1
        )
        self.power_cap = power_cap
        fig.add_layout(power_cap)

        # Browser-side callback. In the JS, P is tip speed and T is power. It
        # finds the point nearest the slider value, moves the marker there and
        # sets the Span (called "horizontal line" in the JS comment, but it is
        # vertical) to that point's power.
        callback = CustomJS(args=dict(source=source, point_source=point_source, slider=slider, vertical_line=power_cap),
                            code="""
            const data = source.data;
            const P = data['tip_speed'];
            const T = data['power'];
            const pos = slider.value;
            const point_data = point_source.data;

            // Find index on the curve closest to the slider value
            let index = 0;
            let min_dist = Number.MAX_VALUE;
            for (let i = 0; i < P.length; i++) {
                let dist = Math.abs(P[i] - pos);
                if (dist < min_dist) {
                    min_dist = dist;
                    index = i;
                }
            }

            // Update the point's position to the curve's corresponding y-value 
            point_data['x'][0] = T[index];
            point_data['y'][0] = P[index];
            point_source.change.emit();
            // Update the horizontal line to pass through the new point
            vertical_line.location = point_data['x'][0];
        """)
        slider.js_on_change("value", callback)

        children = [slider, fig]
        if include_wind_speed_curves:
            children.extend(self._create_wind_speed_curve_plots(source))
        return column(*children)

    def _show_power_curves(self, change: dict, output: wd.Output):
        """Render the power-curve layout for the selected turbine into *output*.

        Clears *output* when the dropdown value is ``None``.
        """
        if change["new"] is None:
            with output:
                clear_output()
            return
        layout = self._create_power_curve_plot(change)
        with output:
            clear_output(wait=True)
            show(layout)




# Build and show the dashboard. Order matters: _join_data() reads both
# self.weather_data and self.price_data, so it must come after both reads. The
# figures are then rendered into the Output widgets of the displayed grid.
app = App()
app.read_price_data("data/price_data/Denmark.csv")
app.read_weather_data("data/weather_data.csv")
app._join_data()
display(app.grid)
app.refresh_figures()



GridspecLayout(children=(VBox(children=(Label(value='Wind Turbine'), Dropdown(description='Select wind turbine…