# Custom Colors Plotly
----------------------------------

## Imports

In [None]:
import random

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio

## Generate Custom Style

In [None]:
from statworx_theme import (
    apply_custom_colors_plotly,
    get_stwx_cmaps,
    register_blended_cmap,
)

# Custom brand colors, given as hex codes
BLACK = "#000000"
WHITE = "#ffffff"
DARK_BLUE = "#0A526B"
DARK_RED = "#6B0020"
GREY = "#808285"
DARK_GREY = "#313132"
LIGHT_GREY = "#ADAFB3"

# Register a blended colormap built from black, dark blue and white under the
# name "stwx:custom_blue_fade"; the return value is discarded.
_ = register_blended_cmap([BLACK, DARK_BLUE, WHITE], "stwx:custom_blue_fade")

# Categorical palette for the custom theme
custom_colors = [DARK_BLUE, DARK_RED, GREY, DARK_GREY, LIGHT_GREY]

# Colormap registry, fetched after registration so it includes the blend above
stwx_cmaps = get_stwx_cmaps()

## Colormaps for Customizing the Theme

In [None]:
# Show the names of all available statworx colormaps, including custom ones
# registered via register_blended_cmap (e.g. "stwx:custom_blue_fade").
# `stwx_cmaps` is already fetched in the setup cell after the registration,
# so calling get_stwx_cmaps() a second time here is unnecessary.
stwx_cmaps.keys()

## Apply Custom Style Plotly

In [None]:
# Apply the custom theme. Each slot accepts any statworx colormap or any list of
# hex color codes; available slots: category, diverging, sequential,
# sequential_minus, heatmap.
apply_custom_colors_plotly(
    sequential=stwx_cmaps["stwx:custom_blue_fade"],
    category=custom_colors,
)

# Print the name of the currently active default plotly template
print(pio.templates.default)

## Plot

In [None]:
# Categorical coloring: one custom color per iris species
df = px.data.iris()

# update_layout returns the figure itself, so it can be chained directly
fig = px.scatter(
    df,
    x="sepal_width",
    y="sepal_length",
    color="species",
    title="Sepal characteristics per species",
).update_layout(width=800, height=500)

fig.show()

In [None]:
# Continuous coloring by petal length, using the custom sequential colorscale
# Source: https://plotly.com/python/line-and-scatter/
df = px.data.iris()

fig = px.scatter(
    df, x="sepal_width", y="sepal_length", color="petal_length"
).update_layout(width=800, height=500)

fig.show()

In [None]:
# Example matrix with values between -1 and 1. With the default sequential
# colorscale the sign of each value is hard to read, so a sequential scheme is
# not a good fit here.

corr = [
    [-1, -0.3, 0.5, 0.7, -0.9],
    [1, 0.8, 0.6, -0.4, -0.2],
    [0.2, 0, -0.5, 0.7, -0.9],
    [0.9, -0.8, -0.4, 0.2, 0],
    [0.3, 0.4, 0.5, 0.7, 1],
]

fig = px.imshow(corr, text_auto=True).update_layout(width=800, height=500)

fig.show()

In [None]:
# Explicitly use the diverging colorscale of the currently active default template
fig = px.imshow(
    corr,
    text_auto=True,
    color_continuous_scale=(
        pio.templates[pio.templates.default].layout.colorscale.diverging
    ),
).update_layout(width=800, height=500)

fig.show()

In [None]:
# Pick another diverging colorscale straight from the statworx colormap registry
fig = px.imshow(
    corr,
    text_auto=True,
    color_continuous_scale=stwx_cmaps["stwx:BlYw_diverging"],
).update_layout(width=800, height=500)
fig.show()

In [None]:
# For plotly graph object heatmaps the diverging color scheme is automatically applied
# (note that no colorscale is passed to go.Heatmap below).

# Source: https://plotly.com/python/heatmaps/

# Fixed seeds keep the synthetic data, and therefore the rendered heatmap, reproducible.
np.random.seed(0)
focus = np.random.rand(150)
np.random.seed(3000)
# Negated product with focus, so distraction correlates negatively with focus.
distraction = -(focus * np.random.rand(150))
np.random.seed(22000)
productivity = np.random.rand(150)
# random.choices draws from Python's own `random` module, whose state is NOT
# affected by np.random.seed, so it needs a seed of its own to be reproducible.
random.seed(0)
home_office = random.choices([0, 1], k=150)

df = pd.DataFrame(
    {
        "focus": focus,
        "distraction": distraction,
        "productivity": productivity,
        "home_office": home_office,
    }
)
df_corr = df.corr()

# zmin/zmax pin the color range to the full correlation range [-1, 1].
fig = go.Figure(
    data=go.Heatmap(
        x=df_corr.columns,
        y=df_corr.index,
        z=df_corr.to_numpy(),
        text=df_corr.values,
        texttemplate="%{text:.2f}",
        zmin=-1,
        zmax=1,
    ),
)

fig.update_layout(
    width=800,
    height=500,
)

fig.show()

In [None]:
# In case of non-diverging scales, the heatmap color scheme needs to be manually adapted.

# Seed both RNGs: np.random drives the three continuous columns and Python's
# `random` drives `home_office`. Seeding only one of them leaves the other
# dependent on whatever state earlier cells left behind.
random.seed(123)
np.random.seed(123)
focus = np.random.rand(150)
distraction = focus * np.random.rand(150)
productivity = np.random.rand(150)
df = pd.DataFrame(
    {
        "focus": focus,
        "distraction": distraction,
        "productivity": productivity,
        "home_office": random.choices([0, 1], k=150),
    }
)
# Min-max scale each column of the correlation matrix to [0, 1] so the values
# suit a sequential (non-diverging) colorscale.
df_corr = df.corr().apply(lambda x: (x - x.min()) / (x.max() - x.min()))

# Use the template's sequential colorscale and omit zmin/zmax: the scaled values
# lie in [0, 1] rather than the [-1, 1] range of a raw correlation.
fig = go.Figure(
    data=go.Heatmap(
        x=df_corr.columns,
        y=df_corr.index,
        z=df_corr.to_numpy(),
        text=df_corr.values,
        texttemplate="%{text:.2f}",
        colorscale=pio.templates[pio.templates.default].layout.colorscale.sequential,
    ),
)

fig.update_layout(
    width=800,
    height=500,
)

fig.show()