<h1>I/O Tools - Matplotlib</h1><h2 align="center">Basic Charts</h2>

In [None]:
from matplotlib.pyplot import figure

# Open a new, empty matplotlib figure; as the cell's last expression, the
# returned Figure object is echoed by Jupyter.
figure()

In [None]:
from matplotlib.pyplot import gca
from matplotlib.axes import Axes


def set_chart_labels(
    ax: Axes, title: str = "", xlabel: str = "", ylabel: str = ""
) -> Axes:
    """Write a title and both axis labels onto ``ax``.

    Args:
        ax: the axes to decorate (modified in place).
        title: chart title; empty string by default.
        xlabel: x-axis label; empty string by default.
        ylabel: y-axis label; empty string by default.

    Returns:
        The same ``ax`` object, so calls can be chained.
    """
    # Applied in a fixed order: title, then x label, then y label.
    for apply_text, text in (
        (ax.set_title, title),
        (ax.set_xlabel, xlabel),
        (ax.set_ylabel, ylabel),
    ):
        apply_text(text)
    return ax

In [None]:
from datetime import datetime
from matplotlib.dates import AutoDateLocator, AutoDateFormatter
from matplotlib.axes import Axes


def set_chart_xticks(
    xvalues: list[str | int | float | datetime], ax: Axes, percentage: bool = False
) -> Axes:
    """Configure the x axis of ``ax`` for a series of x values.

    Behaviour, applied in order:
      * ``percentage=True`` fixes the *y* axis to the range [0, 1].
      * If the first value is a ``datetime``, an automatic date locator and a
        date formatter (default format ``%Y-%m-%d``) are installed on the x axis.
      * If every value is an ``int``/``float``, the x range is fitted from the
        smallest to the largest value and a tick is placed (and labelled) at
        each value; otherwise tick labels are rotated 45 degrees.
      * Tick labels are always drawn at size ``"xx-small"``.
    An empty ``xvalues`` leaves ``ax`` untouched.

    Args:
        xvalues: the x coordinates that will be plotted.
        ax: the axes to configure (modified in place).
        percentage: whether the y values are proportions in [0, 1].

    Returns:
        The same ``ax`` object, so calls can be chained.
    """
    if len(xvalues) == 0:
        return ax

    if percentage:
        ax.set_ylim(0.0, 1.0)

    if isinstance(xvalues[0], datetime):
        locator = AutoDateLocator()
        ax.xaxis.set_major_locator(locator)
        ax.xaxis.set_major_formatter(
            AutoDateFormatter(locator, defaultfmt="%Y-%m-%d")
        )

    rotation: int = 0
    if all(isinstance(x, (int, float)) for x in xvalues):
        # Use min/max rather than first/last element: unsorted input (e.g. the
        # index of value_counts(), ordered by frequency) would otherwise clip
        # or invert the axis.
        low, high = min(xvalues), max(xvalues)
        # Identical limits would make matplotlib warn about a singular range.
        if low < high:
            ax.set_xlim(left=low, right=high)
        ax.set_xticks(xvalues, labels=xvalues)
    else:
        rotation = 45

    ax.tick_params(axis="x", labelrotation=rotation, labelsize="xx-small")
    return ax

<h3>Line Charts</h3>

In [None]:
from numpy import std
from matplotlib.pyplot import show, savefig
from pandas import read_csv, DataFrame, Series
from config import LINE_COLOR, FILL_COLOR


def plot_line_chart(
    xvalues: list,
    yvalues: list,
    ax: Axes = None,  # type: ignore
    title: str = "",
    xlabel: str = "",
    ylabel: str = "",
    name: str = "",
    percentage: bool = False,
    show_stdev: bool = False,
) -> Axes:
    """Draw ``yvalues`` against ``xvalues`` as a single line.

    Args:
        xvalues: x coordinates; also drive the x-axis setup via set_chart_xticks.
        yvalues: y coordinates, same length as ``xvalues``.
        ax: axes to draw on; the current axes (``gca()``) when None.
        title, xlabel, ylabel: chart title and axis labels.
        name: label attached to the line (used if a legend is drawn).
        percentage: whether y values are proportions; fixes the y range to [0, 1].
        show_stdev: if True, shade a band of +/- one population standard
            deviation of ``yvalues`` around the line. NaN entries are ignored
            when computing the deviation.

    Returns:
        The axes drawn on.
    """
    from numpy import nanstd

    if ax is None:
        ax = gca()
    ax = set_chart_labels(ax=ax, title=title, xlabel=xlabel, ylabel=ylabel)
    ax = set_chart_xticks(xvalues, ax, percentage=percentage)
    ax.plot(xvalues, yvalues, c=LINE_COLOR, label=name)
    if show_stdev:
        # nanstd: with plain std a single missing reading yields NaN and the
        # band silently disappears. No rounding: rounding to 3 decimals would
        # give small-scale series a zero-width band.
        stdev: float = float(nanstd(yvalues))
        y_bottom: list[float] = [y - stdev for y in yvalues]
        y_top: list[float] = [y + stdev for y in yvalues]
        ax.fill_between(xvalues, y_bottom, y_top, color=FILL_COLOR, alpha=0.2)
    return ax


from os import makedirs

# NOTE(review): without parse_dates the "date" index is presumably read as
# plain strings, so matplotlib treats each date as a separate category and the
# datetime branch of set_chart_xticks never runs. Consider `parse_dates=True`
# once the file's date format (and day-first ordering) has been confirmed.
data: DataFrame = read_csv("data/algae.csv", index_col="date")

figure(figsize=(12, 4))
var = "pH"
plot_line_chart(
    data.index.to_list(),
    data[var].to_list(),
    title=f"{var} variation",
    xlabel=data.index.name,
    ylabel=var,
)
# savefig raises if the target folder does not exist (e.g. on a fresh checkout).
makedirs("images", exist_ok=True)
savefig("images/pH_variation.png")
show()

<h3>Bar Charts</h3>

In [None]:
from matplotlib.container import BarContainer
from dslabs_functions import FONT_TEXT


def plot_bar_chart(
    xvalues: list,
    yvalues: list,
    ax: Axes = None,  # type: ignore
    title: str = "",
    xlabel: str = "",
    ylabel: str = "",
    percentage: bool = False,
) -> Axes:
    """Draw a vertical bar chart and print each bar's height on top of it.

    Args:
        xvalues: bar positions / categories; also used as tick labels.
        yvalues: bar heights, same length as ``xvalues``.
        ax: axes to draw on; the current axes (``gca()``) when None.
        title, xlabel, ylabel: chart title and axis labels.
        percentage: whether heights are proportions; fixes the y range to
            [0, 1] and prints bar labels with two decimals instead of none.

    Returns:
        The axes drawn on.
    """
    target: Axes = gca() if ax is None else ax
    target = set_chart_labels(ax=target, title=title, xlabel=xlabel, ylabel=ylabel)
    target = set_chart_xticks(xvalues, ax=target, percentage=percentage)

    bars: BarContainer = target.bar(
        xvalues,
        yvalues,
        label=yvalues,
        edgecolor=LINE_COLOR,
        color=FILL_COLOR,
        tick_label=xvalues,
    )
    # Two decimals for proportions, whole numbers otherwise.
    label_fmt = "%.2f" if percentage else "%.0f"
    target.bar_label(bars, fmt=label_fmt, fontproperties=FONT_TEXT)
    return target


# Frequency of each season value (pandas value_counts sorts by count, descending).
var = "season"
counts: Series = data[var].value_counts()
figure()
plot_bar_chart(
    xvalues=counts.index.to_list(),
    yvalues=counts.to_list(),
    title=f"{var} distribution",
    xlabel=var,
    ylabel="frequency",
)
show()

<h3>Horizontal Bar Charts</h3>

In [None]:
from numpy import arange


def plot_horizontal_bar_chart(
    elements: list,
    values: list,
    error: list | None = None,
    ax: Axes = None,  # type: ignore
    title: str = "",
    xlabel: str = "",
    ylabel: str = "",
    percentage: bool = False,
) -> Axes:
    """Draw a horizontal bar chart, first element at the top.

    Args:
        elements: category labels, one per bar, shown on the y axis.
        values: bar lengths, same length as ``elements``.
        error: optional per-bar error values drawn as red x error bars;
            None or an empty sequence means no error (all zeros). Any sized
            sequence (list, numpy array, pandas Series) is accepted.
        ax: axes to draw on; the current axes (``gca()``) when None.
        title, xlabel, ylabel: chart title and axis labels.
        percentage: whether values are proportions; fixes the x range to [0, 1].

    Returns:
        The axes drawn on.
    """
    if ax is None:
        ax = gca()
    if percentage:
        ax.set_xlim((0, 1))
    # `error == []` would compare element-wise (and raise) for numpy arrays or
    # pandas Series, so test emptiness by length instead.
    if error is None or len(error) == 0:
        error = [0] * len(elements)
    ax = set_chart_labels(ax=ax, title=title, xlabel=xlabel, ylabel=ylabel)
    y_pos: list = list(arange(len(elements)))

    ax.barh(
        y_pos, values, xerr=error, align="center", error_kw={"lw": 0.5, "ecolor": "r"}
    )
    ax.set_yticks(y_pos, labels=elements)
    # Labels read top-to-bottom. set_inverted(True) is idempotent, whereas
    # invert_yaxis() toggles and would undo itself if drawn twice on one axes.
    ax.yaxis.set_inverted(True)
    return ax


figure()
var = "season"
counts: Series = data[var].value_counts()
print(counts)
# Horizontal bars: the counts run along the x axis and the categories along y,
# so the axis labels are the reverse of the vertical bar chart's.
plot_horizontal_bar_chart(
    counts.index.to_list(),
    counts.to_list(),
    title=f"{var} distribution",
    xlabel="frequency",
    ylabel=var,
)
show()

<h3>Scatter Plots</h3>

In [None]:
def plot_scatter_chart(
    var1: list,
    var2: list,
    ax: Axes = None,  # type: ignore
    title: str = "",
    xlabel: str = "",
    ylabel: str = "",
) -> Axes:
    """Scatter ``var2`` (y) against ``var1`` (x).

    Args:
        var1: x coordinates of the points.
        var2: y coordinates of the points, same length as ``var1``.
        ax: axes to draw on; the current axes (``gca()``) when None.
        title, xlabel, ylabel: chart title and axis labels.

    Returns:
        The axes drawn on.
    """
    chart: Axes = ax if ax is not None else gca()
    chart = set_chart_labels(ax=chart, title=title, xlabel=xlabel, ylabel=ylabel)
    chart.scatter(var1, var2)
    return chart


# Compare two related measurements point by point.
var1 = "Phosphate"
var2 = "Orthophosphate"
figure()
plot_scatter_chart(
    var1=data[var1].to_list(),
    var2=data[var2].to_list(),
    title=f"{var1} x {var2}",
    xlabel=var1,
    ylabel=var2,
)
show()