In [26]:
from __future__ import annotations

import numpy as np
import jax.numpy as jnp
import pandas as pd
import polars as pl
import seaborn as sns
from nptyping import NDArray, Shape, Int, Float, Bool
import plotly.express as px
import plotly.graph_objects as go
from sklearn.linear_model import LinearRegression
from IPython.display import display, HTML
from rich import print


def load_mathjax():
    """Load MathJax library in JupyterLab to enable LaTeX rendering in Plotly charts.

    This function checks if the MathJax library is already loaded, and if not,
    it loads the library from the provided CDN link.
    """
    display(HTML("""
        <script>
            if (typeof MathJax === 'undefined') {
                var script = document.createElement('script');
                script.type = 'text/javascript';
                script.src = 'https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/MathJax.js?config=TeX-AMS-MML_HTMLorMML';
                document.head.appendChild(script);
            }
        </script>
    """))
# Load MathJax
load_mathjax()

In [23]:
df: pl.DataFrame = pl.from_pandas(sns.load_dataset("anscombe"))
dataset_names: pl.Series = df["dataset"].unique()
display(dataset_names)
colors = ["tab:blue", "tab:orange", "tab:green", "tab:red"]

df

dataset
str
"""I"""
"""IV"""
"""III"""
"""II"""


dataset,x,y
str,f64,f64
"""I""",10.0,8.04
"""I""",8.0,6.95
"""I""",13.0,7.58
"""I""",9.0,8.81
"""I""",11.0,8.33
"""I""",14.0,9.96
"""I""",6.0,7.24
"""I""",4.0,4.26
"""I""",12.0,10.84
"""I""",7.0,4.82


In [32]:
def plot_regression_line(df: pl.DataFrame, color: str | None = None) -> None:
    """Plots a scatter plot of the data points in the given DataFrame and adds a regression line.

    Parameters
    ----------
    df : pl.DataFrame
        A DataFrame containing the x and y values to be plotted.
    color : str, optional
        The color of the scatter points and regression line, by default None.
    """
    x = df.select('x').to_numpy()
    y = df.select('y').to_numpy()
    model = LinearRegression().fit(x, y)

    x_range = np.linspace(1, 20, num=20).reshape(-1, 1)  # Replace jnp with np
    y_pred = model.predict(x_range)
    df_pred = pl.DataFrame({'x': x_range.squeeze(), 'y': y_pred.squeeze()})
    slope: float = round(model.coef_[0][0], 4)
    intercept: float = round(model.intercept_[0], 4)

    fig = px.scatter(data_frame=df.to_pandas(), x='x', y='y', color_discrete_sequence=[color] if color else None)
    fig.add_scatter(x=df_pred['x'], y=df_pred['y'], mode='lines', line=dict(color=color), name=f'$y={slope}x + {intercept}$')

    fig.update_xaxes(range=[0, 20], title_text="$x$")
    fig.update_yaxes(range=[0, 14], title_text="$y$")
    fig.show()


df: pl.DataFrame = pl.from_pandas(sns.load_dataset("anscombe"))
dataset_names: pl.Series = df["dataset"].unique().sort()
colors: list[str] = ['blue', 'orange', 'green', 'red']

for dataset_name, color in zip(dataset_names, colors):
    df_sub = df.filter(pl.col('dataset') == dataset_name).select(['x', 'y'])
    plot_regression_line(df=df_sub, color=color)
