# Advanced Usage

This notebook covers advanced Plotly customization and styling patterns.

In [None]:
import plotly.express as px
import xarray as xr

from xarray_plotly import config, xpx

# Configure Plotly for notebook rendering via xarray_plotly's config helper
# (presumably selects a notebook-friendly renderer — see the xarray_plotly docs).
config.notebook()

## Load Sample Data

In [None]:
# --- Stock prices: wide (date x company) table -> 2D DataArray ---
df_stocks = px.data.stocks().set_index("date")
df_stocks.index = df_stocks.index.astype("datetime64[ns]")

stocks = xr.DataArray(
    data=df_stocks.values,
    coords={"date": df_stocks.index, "company": df_stocks.columns.tolist()},
    dims=["date", "company"],
    attrs={"long_name": "Stock Price", "units": "normalized"},
    name="price",
)

# --- Gapminder: life expectancy per year for a handful of countries ---
df_gap = px.data.gapminder()
countries = ["United States", "China", "Germany", "Brazil", "Nigeria"]

df_life = df_gap.loc[df_gap["country"].isin(countries)].pivot(
    index="year", columns="country", values="lifeExp"
)
life_exp = xr.DataArray(
    data=df_life.values,
    coords={"year": df_life.index, "country": df_life.columns.tolist()},
    dims=["year", "country"],
    attrs={"long_name": "Life Expectancy", "units": "years"},
    name="life_expectancy",
)

## Working with xarray Attributes

xarray_plotly automatically builds axis and legend labels from the DataArray's name and attributes (such as `long_name` and `units`):

In [None]:
# Inspect the name and metadata attached to the stock-price DataArray
print(f"Name: {stocks.name}\nAttrs: {stocks.attrs}")

In [None]:
# No explicit labels: xarray_plotly derives them from the DataArray's metadata
fig = xpx(stocks).line(
    title="Auto-Labels from Metadata",
)
fig

### Configuring Label Behavior

Use `config.set_options()` to control how labels are extracted:

In [None]:
# Turn off units in labels for this figure only; the option is scoped to the `with` block
with config.set_options(label_include_units=False):
    fig = xpx(stocks).line(
        title="Without Units in Labels",
    )
fig

### Overriding Labels

Pass a `labels` dict to override the automatic labels for specific dimensions or for the data variable itself:

In [None]:
# Explicit `labels` replace the metadata-derived ones, keyed by variable/dimension name
fig = xpx(stocks).line(
    title="Custom Labels",
    labels={
        "price": "Normalized Price",
        "date": "Trading Date",
        "company": "Ticker",
    },
)
fig

## Advanced Dimension Assignment

### Using Multiple Visual Encodings

Map dimensions to different visual encodings (color, line dash, and facets) to show more than two dimensions in a single figure:

In [None]:
# Build a second (year x country) array: GDP per capita, expressed in thousands.
df_gdp = df_gap[df_gap["country"].isin(countries)].pivot(
    index="year", columns="country", values="gdpPercap"
)
gdp = xr.DataArray(
    df_gdp.values / 1000,
    dims=["year", "country"],
    coords={"year": df_gdp.index, "country": df_gdp.columns.tolist()},
    name="gdp",
    attrs={"long_name": "GDP per Capita", "units": "thousands"},
)

# Stack both metrics along a new "metric" dimension -> (metric, year, country).
# By default xr.concat keeps the FIRST array's name and attrs, which would label
# the whole result (GDP included) as life expectancy in years. Drop the attrs and
# give the stacked array a neutral name; the metric labels carry the units instead.
combined = xr.concat(
    [life_exp, gdp],
    dim=xr.Variable("metric", ["Life Expectancy (years)", "GDP per Capita (thousands)"]),
    combine_attrs="drop",
).rename("value")
print(f"Combined shape: {dict(combined.sizes)}")

In [None]:
# One column of subplots per metric; the remaining "country" dimension sets line color
fig = xpx(combined).line(title="Multiple Metrics Comparison", facet_col="metric")
fig

In [None]:
# Distinguish three tickers by dash pattern; color=None switches off color mapping
fig = xpx(stocks.sel(company=["GOOG", "AAPL", "MSFT"])).line(
    title="Using Line Dash Instead of Color",
    line_dash="company",
    color=None,
)
fig

## Custom Styling

### Themes

In [None]:
# Built-in Plotly template with a dark background
fig = xpx(stocks).line(title="Dark Theme", template="plotly_dark")
fig

In [None]:
# Built-in Plotly template styled after seaborn
fig = xpx(stocks).line(title="Seaborn Theme", template="seaborn")
fig

### Custom Colors

In [None]:
# Use one of Plotly's qualitative palettes for the per-company line colors
palette_fig_kwargs = None  # placeholder removed below; keep namespace unchanged
del palette_fig_kwargs
fig = xpx(stocks).line(
    title="Set2 Color Palette",
    color_discrete_sequence=px.colors.qualitative.Set2,
)
fig

In [None]:
# Hand-picked hex colors, one per country in the data
fig = xpx(life_exp).line(
    title="Custom Color Sequence",
    color_discrete_sequence=[
        "#E63946",
        "#457B9D",
        "#2A9D8F",
        "#E9C46A",
        "#F4A261",
    ],
)
fig

### Heatmap Colorscales

In [None]:
# Heatmap of the (year x country) array with a sequential colorscale
fig = xpx(life_exp).imshow(title="Viridis Colorscale", color_continuous_scale="Viridis")
fig

In [None]:
# Diverging colorscale with midpoint
# Calculate change from first year
life_change = life_exp - life_exp.isel(year=0)
life_change.name = "change"
# Arithmetic results may not keep the source attrs (depends on xarray's keep_attrs
# setting), so set label metadata explicitly for the colorbar and hover text.
life_change.attrs = {
    "long_name": f"Change in Life Expectancy since {int(life_exp['year'][0])}",
    "units": "years",
}

fig = xpx(life_change).imshow(
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0,  # center the diverging scale on "no change"
    title="Life Expectancy Change (Diverging Colorscale)",
)
fig

## Post-Creation Customization

All plots return Plotly `Figure` objects that you can customize further:

In [None]:
fig = xpx(stocks).line()

# Dashed gray reference line at y = 1.0, labelled "Baseline"
fig.add_hline(y=1.0, line_color="gray", line_dash="dash", annotation_text="Baseline")

# Title plus a horizontal legend anchored just above the top-right of the plot area
fig.update_layout(
    title="Stock Prices with Reference Line",
    legend=dict(orientation="h", x=1, xanchor="right", y=1.02, yanchor="bottom"),
)

fig

In [None]:
# Add vertical line for an event on the date axis
fig = xpx(stocks).line(title="Stock Prices with Event Marker")

event_date = "2018-07-01"
fig.add_vline(x=event_date, line_dash="dot", line_color="red")

# Label the line with a separate annotation: in some plotly.py versions,
# add_vline(annotation_text=...) raises on date axes because it averages the
# x positions numerically. add_annotation accepts a date string directly.
fig.add_annotation(
    x=event_date,
    y=1,
    yref="paper",  # y in plot-area fraction: 1 = top edge
    text="Mid-2018",
    showarrow=False,
    xanchor="left",
    yanchor="top",
    font={"color": "red"},
)
fig

### Modifying Traces

In [None]:
fig = xpx(stocks).line()
fig.update_traces(line={"width": 3})  # widen every trace's line
fig.update_layout(title="Thicker Lines")
fig

In [None]:
fig = xpx(stocks).scatter()
# Larger, semi-transparent markers on every trace (magic-underscore form of marker={...})
fig.update_traces(marker_size=10, marker_opacity=0.7)
fig.update_layout(title="Custom Marker Style")
fig

### Adding Annotations

In [None]:
fig = xpx(life_exp).line(title="Life Expectancy with Annotations")

# Point an arrow at one data point. `.item()` yields a plain Python scalar;
# `.values` would hand Plotly a 0-d NumPy array instead.
fig.add_annotation(
    x=2007,
    y=life_exp.sel(year=2007, country="China").item(),
    text="China 2007",
    showarrow=True,
    arrowhead=2,
)
fig

## Exporting Figures

### Interactive HTML

```python
fig.write_html("interactive_plot.html")
```

### Static Images

Requires `kaleido`: `pip install kaleido`

```python
fig.write_image("plot.png", scale=2)  # High resolution
fig.write_image("plot.svg")  # Vector format
fig.write_image("plot.pdf")  # PDF
```

## Faceted Subplots: Shared vs. Independent Axes

In [None]:
# Faceted subplots share (match) their y-axis by default, which squashes metrics
# with very different ranges. Unlink them so each facet gets its own y-range.
fig = xpx(combined).line(
    facet_col="metric",
    title="Facets with Independent Y-Axes",
)

# With shared axes Plotly Express hides tick labels on every column but the
# first; once the ranges differ, those labels must be shown again.
fig.update_yaxes(matches=None, showticklabels=True)
fig

## Combining with Plotly Graph Objects

You can add additional traces using Plotly's graph objects:

In [None]:
import plotly.graph_objects as go

goog = stocks.sel(company="GOOG")
fig = xpx(goog).line(title="GOOG with Moving Average")

# Centered rolling mean over 20 consecutive samples (NaN where the window is
# incomplete). NOTE: px.data.stocks() appears to be sampled weekly, so the
# window spans 20 rows (~20 weeks), not 20 days — hence the "period" label.
ma_20 = goog.rolling(date=20, center=True).mean()

fig.add_trace(
    go.Scatter(
        x=ma_20["date"].values,
        y=ma_20.values,
        mode="lines",
        name="20-period centered MA",
        line={"dash": "dash", "color": "red"},
    )
)
fig