# Data Visualization for Machine Learning

Data visualization is a crucial step in the machine learning workflow.  
It helps us to:

- **Explore and understand** the data before building models  
- **Identify patterns and relationships** between variables  
- **Detect anomalies and outliers** that could affect model performance  
- **Communicate results** clearly to both technical and non-technical audiences  

Summary statistics can describe data, but **plots can reveal insights that numbers alone might miss**.  

In this chapter, we will learn how to create effective visualizations using three popular Python libraries:

- **Matplotlib** → the foundational plotting library in Python, offering flexibility and fine control.  
- **Pandas plotting** → quick and convenient plotting methods built into Pandas DataFrames.  
- **Seaborn** → a high-level library built on Matplotlib, designed for statistical and aesthetically pleasing graphics.  

---

### 🎯 Learning Objectives
By the end of this section, you should be able to:
1. Create basic plots with Matplotlib.  
2. Visualize data directly from Pandas DataFrames.

In [None]:
# --- Imports ---

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

## 5.1 Getting Started with Matplotlib

[Matplotlib](https://matplotlib.org/) is a fundamental plotting library in Python, providing a flexible and powerful way to create static, interactive, and animated visualizations.

It is the **basis for many other plotting tools** (such as Pandas plotting and Seaborn) and offers fine-grained control over every element of a plot.

Understanding Matplotlib is essential for creating a wide variety of visualizations for exploring and presenting data in machine learning.


### 5.1.1 Motivational Example

Let’s compare the same dataset shown as numbers (a table) and as a plot.

In [None]:
# --- Motivational Example: display dataset as a table ---

# A simple dataset
x = np.arange(1, 11)
y = x ** 2

# Show as a table
pd.DataFrame({"x": x, "y": y})

In [None]:
# --- Motivational Example: display dataset as a plot ---

plt.plot(x, y, marker="o")
plt.title("A Simple Quadratic Function")
plt.xlabel("x")
plt.ylabel("y = x^2")
plt.grid(True)
plt.show()   # Explicitly display the figure (required when running as a script outside Jupyter)

Notice how the table shows only numbers, while the plot immediately reveals the *quadratic* relationship.  
This demonstrates why visualization is often the first step in data exploration.

### 5.1.2 Basic Anatomy of a Plot

A typical Matplotlib plot contains:
- **Data**: what we want to visualize  
- **Figure**: the overall container (the "canvas")  
- **Axes**: the area where data is plotted  
- **Labels and title**: describe what the plot shows  
- **Grid, legends, colors, markers**: improve readability  

In [None]:
# --- Line plot with labels, title, legend, and grid ---

x = np.linspace(0, 10, 100)
y = np.sin(x)

plt.plot(x, y, color="blue", label="sin(x)")
plt.title("Line Plot Example")
plt.xlabel("x")
plt.ylabel("y")
plt.legend()
plt.grid(True)
plt.show()
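
The same plot can also be written with Matplotlib's object-oriented interface, which makes the anatomy listed above explicit: `plt.subplots()` returns the Figure and an Axes, and the labels, legend, and grid are set through methods on the Axes. The following is a minimal sketch; the names `fig` and `ax` are conventional choices, not requirements.

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0, 10, 100)
y = np.sin(x)

# Figure = the whole canvas; Axes = the region where the data is drawn
fig, ax = plt.subplots(figsize=(6, 4))

ax.plot(x, y, color="blue", label="sin(x)")   # data
ax.set_title("Line Plot Example (object-oriented interface)")
ax.set_xlabel("x")                            # labels
ax.set_ylabel("y")
ax.legend()                                   # legend built from the label= arguments
ax.grid(True)

plt.show()
```

This style tends to pay off once a figure contains more than one Axes, because every call names the Axes it acts on.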

### 5.1.3 Common Plot Types
Matplotlib offers many plot types, but the following are especially common in data analysis:

- **Line Plots** — Show how a variable changes over time or another ordered sequence. Often used to track model performance (e.g., loss over epochs).  
- **Scatter Plots** — Display the relationship between two numerical variables. Useful for spotting correlations, clusters, or outliers.  
- **Bar Plots** — Compare values across categories. Each bar’s height represents the category value. Helpful for categorical distributions or comparing models.  
- **Histograms** — Visualize the distribution of a single numerical variable by grouping values into bins. Useful for spotting data shape, skewness, or multimodality.  
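
As an illustration of the line-plot use case mentioned above (loss over epochs), here is a small sketch. The loss values are synthetic, generated from made-up decay curves purely for illustration; they do not come from a real model.

```python
import numpy as np
import matplotlib.pyplot as plt

# Synthetic, made-up loss curves for illustration only (not from a real model)
epochs = np.arange(1, 21)
train_loss = np.exp(-0.25 * epochs) + 0.05
val_loss = np.exp(-0.20 * epochs) + 0.10

plt.plot(epochs, train_loss, label="training loss")
plt.plot(epochs, val_loss, linestyle="--", label="validation loss")
plt.title("Loss over Epochs (synthetic data)")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.grid(True)
plt.show()
```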

In [None]:
# --- Scatter plot ---

np.random.seed(42)
x = np.random.rand(50)
y = np.random.rand(50)

plt.scatter(x, y, color="red", marker="x")
plt.title("Scatter Plot Example")
plt.xlabel("x")
plt.ylabel("y")
plt.show()

In [None]:
# --- Bar chart ---

categories = ["A", "B", "C", "D"]
values = [3, 7, 5, 9]

plt.bar(categories, values, color="green")
plt.title("Bar Chart Example")
plt.ylabel("Value")
plt.show()

In [None]:
# --- Histogram ---

data = np.random.randn(1000)  # normally distributed data

plt.hist(data, bins=30, color="purple", alpha=0.7)
plt.title("Histogram Example")
plt.xlabel("Value")
plt.ylabel("Frequency")
plt.show()

In [None]:
# --- Example: Matplotlib options in action ---

# Generate sample data
x = np.linspace(0, 2*np.pi, 100)
y1, y2 = np.sin(x), np.cos(x)

# Create a figure with custom size (width=10, height=4 inches)
fig, axes = plt.subplots(1, 2, figsize=(10,4))

# --- First subplot ---
axes[0].plot(x, y1, color="blue", label="sin(x)")
axes[0].set_title("Sine")          # Set subplot title
axes[0].set_xlabel("x-axis")       # Label x-axis
axes[0].set_ylabel("y-axis")       # Label y-axis
axes[0].legend()                   # Show legend

# --- Second subplot ---
axes[1].plot(x, y2, color="red", label="cos(x)")
axes[1].set_title("Cosine")
axes[1].set_xlabel("x-axis")
axes[1].set_ylabel("y-axis")
axes[1].legend()

# Add a title for the whole figure
fig.suptitle("Example: figsize, subplots, axes/axis")
fig.tight_layout()  # Adjust spacing so titles and labels do not overlap

plt.show()
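
A figure can also be written to disk with `savefig`. The sketch below assumes an arbitrary example filename, `sine_plot.png`; the file extension selects the output format. Saving before calling `plt.show()` is the safer order, since some backends discard the figure once it has been shown.

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0, 2 * np.pi, 100)

fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(x, np.sin(x), label="sin(x)")
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.legend()

# "sine_plot.png" is an arbitrary example filename.
# dpi sets the resolution; bbox_inches="tight" trims surrounding whitespace.
fig.savefig("sine_plot.png", dpi=150, bbox_inches="tight")

plt.show()
```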

## 5.2 Data Visualization with Pandas

While Matplotlib provides powerful low-level control over plots, Pandas offers a more convenient, high-level interface.  
Pandas plotting uses Matplotlib as its default backend, so you can quickly create common visualizations directly from DataFrames and Series using the `.plot()` method.

Pandas plotting is especially useful when:
- You are exploring tabular data.
- You want fast visualizations with minimal code.
- You need plots that integrate seamlessly with DataFrame operations.

In this subsection, we will:
- Explore how to generate plots directly from Pandas objects.
- Compare Pandas plots with equivalent Matplotlib code.
- Demonstrate when Pandas is most effective for quick data analysis.

### 5.2.1 Basic Plotting from DataFrames

We can quickly visualize data from a single column or multiple columns.

In [None]:
# --- Create a sample dataset ---

data = {
    "Month": ["Jan", "Feb", "Mar", "Apr", "May"],
    "Sales_A": [250, 300, 400, 350, 500],
    "Sales_B": [200, 220, 300, 280, 400]
}

df = pd.DataFrame(data)
df

In [None]:
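# --- Line plot of a single column ---
# Use "Month" as the index so the x-axis shows month names instead of row numbers
df.set_index("Month")["Sales_A"].plot(title="Sales A Over Months")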

In [None]:
# --- Line plot of multiple columns ---
df.set_index("Month")[["Sales_A", "Sales_B"]].plot(title="Sales Comparison")
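
For comparison, here is one way to draw the same two-line sales chart with plain Matplotlib. It is a sketch rather than the only possible translation; the DataFrame is rebuilt so the cell runs on its own.

```python
import pandas as pd
import matplotlib.pyplot as plt

df = pd.DataFrame({
    "Month": ["Jan", "Feb", "Mar", "Apr", "May"],
    "Sales_A": [250, 300, 400, 350, 500],
    "Sales_B": [200, 220, 300, 280, 400],
})

fig, ax = plt.subplots()

# Pandas draws one line per column and builds the legend automatically;
# in Matplotlib each of those steps is written out.
for column in ["Sales_A", "Sales_B"]:
    ax.plot(df["Month"], df[column], label=column)

ax.set_title("Sales Comparison")
ax.set_xlabel("Month")
ax.legend()
plt.show()
```

The Pandas version needs a single line, while the Matplotlib version exposes each step for customization.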

### 5.2.2 Plot Types Supported by Pandas

Pandas supports many plot types. Here are the most common:

- **Line plot** (`kind="line"`) — default plot type  
- **Bar plot** (`kind="bar"`)  
- **Horizontal bar** (`kind="barh"`)  
- **Histogram** (`kind="hist"`)  
- **Box plot** (`kind="box"`)  
- **Area plot** (`kind="area"`)  
- **Scatter plot** (`kind="scatter"`, requires `x` and `y`)  
- **Pie chart** (`kind="pie"`)  

In [None]:
# --- Bar plot ---
df.set_index("Month")[["Sales_A", "Sales_B"]].plot(kind="bar", title="Sales Bar Plot")

In [None]:
# --- Histogram ---
np.random.seed(0)
df_random = pd.DataFrame(np.random.randn(1000, 2), columns=["Feature1", "Feature2"])
df_random.plot(kind="hist", bins=30, alpha=0.7, title="Histogram of Random Features")

In [None]:
# --- Scatter plot ---
df_random.plot(kind="scatter", x="Feature1", y="Feature2", title="Scatter Plot Example")
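
Two of the other kinds from the list above, `box` and `area`, can be sketched the same way. The data is rebuilt here so the cell is self-contained; the names `df_sales`, `ax_box`, and `ax_area` are illustrative choices.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Box plot: summarizes the distribution of each numeric column
np.random.seed(0)
df_random = pd.DataFrame(np.random.randn(1000, 2), columns=["Feature1", "Feature2"])
ax_box = df_random.plot(kind="box", title="Box Plot of Random Features")
plt.show()

# Area plot: columns are stacked by default, so the top edge shows the combined total
df_sales = pd.DataFrame(
    {"Sales_A": [250, 300, 400, 350, 500], "Sales_B": [200, 220, 300, 280, 400]},
    index=["Jan", "Feb", "Mar", "Apr", "May"],
)
ax_area = df_sales.plot(kind="area", title="Stacked Sales (Area Plot)")
plt.show()
```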

### 5.2.3 Customizing Pandas Plots

Pandas plotting allows customization using parameters or by accessing the underlying Matplotlib axes.

- **Titles and labels:** `title`, `xlabel`, `ylabel`  
- **Colors and styles:** `color`, `style`, `marker`  
- **Subplots:** `subplots=True` to plot each column separately  
- **Grid:** `grid=True` for better readability

In [None]:
# --- Customization example ---
df.set_index("Month")[["Sales_A", "Sales_B"]].plot(
    kind="line",
    title="Customized Sales Plot",
    color=["blue", "orange"],
    style=["-o", "--s"],
    grid=True,
    xlabel="Month",
    ylabel="Sales Units"
)

In [None]:
# --- Subplots example ---
df.set_index("Month")[["Sales_A", "Sales_B"]].plot(
    kind="line",
    subplots=True,
    title="Sales Subplots",
    grid=True
)
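
To illustrate the second route mentioned above, accessing the underlying Matplotlib axes: `.plot()` returns the Axes it drew on, so any Axes method can be applied afterwards. In this sketch the target value of 350 is a made-up number used only to demonstrate `axhline`.

```python
import pandas as pd
import matplotlib.pyplot as plt

df = pd.DataFrame({
    "Month": ["Jan", "Feb", "Mar", "Apr", "May"],
    "Sales_A": [250, 300, 400, 350, 500],
    "Sales_B": [200, 220, 300, 280, 400],
})

# .plot() returns the Matplotlib Axes object it drew on
ax = df.set_index("Month")[["Sales_A", "Sales_B"]].plot(title="Sales with Target Line")

# From here on, regular Matplotlib Axes methods apply.
# The target of 350 is a hypothetical value for illustration.
ax.axhline(350, color="gray", linestyle=":", label="Target (hypothetical)")
ax.set_ylabel("Sales Units")
ax.legend()  # rebuild the legend so it includes the new line
plt.show()
```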

Pandas plots are a convenient way to quickly visualize tabular data without writing a lot of Matplotlib code.  
For more advanced or publication-ready visualizations, we can combine Pandas plotting with Matplotlib commands or move on to **Seaborn**, which we will explore in the next subsection.

## 5.3 Data Visualization with Seaborn

Seaborn is a high-level data visualization library built on top of Matplotlib.  
It provides an easier interface for creating attractive and informative statistical graphics with less code.  
Seaborn also works seamlessly with Pandas DataFrames, making it a powerful tool for data exploration in machine learning.

Key advantages of Seaborn:
- Built-in themes and color palettes for more attractive plots.
- Concise syntax compared to Matplotlib.
- Specialized functions for visualizing statistical relationships (e.g., scatter plots with regression lines).
- Direct support for Pandas DataFrames.

In this subsection, we will:
- Learn how to get started with Seaborn.
- Explore commonly used plot types.
- Compare Seaborn’s simplicity and aesthetics with Matplotlib.

### 5.3.1 Basic Usage

Seaborn is designed to work seamlessly with Pandas DataFrames.  
Its syntax is simple: you specify the dataset and the column names for the axes, and Seaborn takes care of the rest.  
This makes it much easier to create clean, informative plots compared to writing the equivalent Matplotlib code.

In [None]:
# Load an example dataset (Iris); Seaborn downloads it on first use and caches it locally
iris = sns.load_dataset("iris")

# Display first few rows
iris.head()

In [None]:
# --- Basic scatter plot with Seaborn ---
sns.scatterplot(data=iris, x="sepal_length", y="sepal_width")
plt.show()  # Explicit show for consistency

In [None]:
# --- Add color by category (species) ---
sns.scatterplot(data=iris, x="sepal_length", y="sepal_width", hue="species")
plt.show()

In [None]:
# --- Compare with Matplotlib (colored scatter with legend) ---
plt.figure(figsize=(6,4))

# Plot each species separately with its own color
for species, color in zip(iris["species"].unique(), ["blue", "orange", "green"]):
    subset = iris[iris["species"] == species]
    plt.scatter(subset["sepal_length"], subset["sepal_width"], label=species, color=color)

plt.xlabel("sepal_length")
plt.ylabel("sepal_width")
plt.title("Matplotlib scatter plot (with categories)")
plt.legend(title="species")
plt.show()

### 5.3.2 Common Plot Types in Seaborn
Seaborn provides many specialized plot types for data exploration.  
Here are some of the most commonly used in machine learning contexts.

In [None]:
# --- Regression plot: scatter with fitted line ---
sns.regplot(data=iris, x="sepal_length", y="sepal_width")
plt.show()

In [None]:
# --- Histogram and KDE (Kernel Density Estimate) ---
sns.histplot(data=iris, x="petal_length", bins=20, kde=True)
plt.show()

In [None]:
# --- Boxplot: useful for comparing distributions across categories ---
sns.boxplot(data=iris, x="species", y="petal_length")
plt.show()

In [None]:
# --- Violin plot: boxplot-style summary plus the estimated density ---
sns.violinplot(data=iris, x="species", y="petal_length")
plt.show()

- **Boxplots** summarize the distribution of a numerical variable across categories.  
  Each box shows:  
  - The **median** (middle line inside the box).  
  - The **interquartile range (IQR)** — from the 25th percentile (Q1) to the 75th percentile (Q3).  
  - The **"whiskers"** extend to the most extreme data points that lie within 1.5 × IQR of the box edges (the default setting).  
  - Points outside this range are plotted as **outliers**.  
  In machine learning, boxplots are useful for comparing feature distributions across classes or groups.

- **Violin plots** combine a boxplot with a rotated **kernel density estimate (KDE)** on each side.  
  - The width of the “violin” shows how dense the data is at different values.  
  - Like boxplots, violin plots allow comparisons between categories, but they also reveal the **shape of the distribution** (e.g., skewness, multimodality).  
  Violin plots are often preferred when we want more detail about how the data is distributed.
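
The quantities a boxplot displays can be computed by hand, which may help when reading one. The sketch below uses a small made-up sample and assumes the default 1.5 × IQR rule described above.

```python
import numpy as np

# Small made-up sample; 30 is deliberately far from the rest
values = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 30])

q1, median, q3 = np.percentile(values, [25, 50, 75])
iqr = q3 - q1

# Fences: the furthest a whisker is allowed to reach
lower_fence = q1 - 1.5 * iqr
upper_fence = q3 + 1.5 * iqr

# Whiskers end at the most extreme data points inside the fences
inside = values[(values >= lower_fence) & (values <= upper_fence)]
whisker_low, whisker_high = inside.min(), inside.max()

# Points beyond the fences are drawn individually as outliers
outliers = values[(values < lower_fence) | (values > upper_fence)]

print(f"Q1={q1}, median={median}, Q3={q3}, IQR={iqr}")
print(f"whiskers: {whisker_low} to {whisker_high}; outliers: {outliers}")
```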

In [None]:
# --- Pair plot: shows a matrix of scatterplots and marginal distributions ---
sns.pairplot(iris, hue="species")
plt.show()

**Plot pairwise relationships in a dataset.**

By default, `sns.pairplot` creates a **grid of Axes** such that each numeric variable in the data is shared across the y-axes of a single row and the x-axes of a single column.
The **diagonal plots** are treated differently: a univariate distribution plot is drawn to show the marginal distribution of the data in each column.

It is also possible to show a **subset of variables** (with the `vars` parameter) or to plot different variables on the rows and columns (with `x_vars` and `y_vars`).

In [None]:
# --- Heatmap: visualize correlations between numerical features ---
corr = iris.corr(numeric_only=True)
sns.heatmap(corr, annot=True, cmap="coolwarm")
plt.show()

# 6. Exercises for Students (Self-study)

Work through the following exercises to strengthen your understanding of data visualization with **Matplotlib**, **Pandas**, and **Seaborn**.  

Start with the **Iris dataset**, then repeat the same tasks with another dataset of your choice from the list below:  

- **penguins** (Seaborn): Measurements for three penguin species (bill length/depth, flipper length, body mass). Good for categorical comparisons and scatter plots.  
- **tips** (Seaborn): Restaurant bills, tips, and categorical info (day, time, smoker/non-smoker). Great for categorical + numerical mixes.  
- **flights** (Seaborn): Monthly airline passenger counts over years. Useful for line plots and time series.  
- **diamonds** (Seaborn): Prices and characteristics of diamonds (carat, cut, color, clarity, price). Rich dataset for categorical and numerical visualizations.  

Compare how the choice of dataset affects the type of plots and the insights you can extract.  

---

1. **Getting Comfortable with Matplotlib**  
   - Create a simple line plot of any mathematical function (e.g., sine, cosine, quadratic).  
   - Customize the figure using `figsize`, add axis labels, a title, and a legend.  
   - Try saving the figure as an image file.  
   - *(Repeat with a time-related variable, e.g., passenger counts in `flights`.)*

2. **Exploring Plot Types**  
   - Using the Iris dataset, create a **scatter plot** of petal length vs. petal width, color-coded by species.  
   - Reproduce the same visualization using Seaborn.  
   - *(Repeat with another dataset, e.g., bill length vs. bill depth in `penguins` or carat vs. price in `diamonds`.)*

3. **Histograms and Distributions**  
   - Plot a histogram of sepal length using Matplotlib.  
   - Create the same plot with Pandas’ `.hist()` and Seaborn’s `histplot`.  
   - Compare the three plots: what advantages does each library offer?  
   - *(Repeat with another dataset, e.g., total bill in `tips` or diamond carat in `diamonds`.)*

4. **Categorical Visualizations**  
   - Use boxplots and violin plots to visualize petal length across Iris species.  
   - Compare what new information the violin plot provides compared to the boxplot.  
   - *(Repeat with another dataset, e.g., body mass across penguin species in `penguins` or diamond price across cut categories in `diamonds`.)*

5. **Bar Plots**  
   - Calculate the average sepal width per species with Pandas.  
   - Plot the results using both Pandas’ `.plot(kind="bar")` and Seaborn’s `barplot`.  
   - Compare the two visualizations.  
   - *(Repeat with another dataset, e.g., average tip per day of the week in `tips` or average diamond price per color in `diamonds`.)*

6. **Working with Subplots**  
   - Create a single figure with four subplots:  
     - Line plot  
     - Scatter plot  
     - Histogram  
     - Boxplot  
   - Use `plt.subplots()` to arrange them in a 2×2 grid.  
   - Add appropriate titles for each subplot.  
   - *(Repeat with another dataset of your choice from the list, selecting variables that make sense for each plot type.)*

---

💡 *Tip:* When switching datasets, ask yourself:  
- Which plot type is most appropriate for this new dataset?  
- Do numerical or categorical variables require different visualization strategies?  
- How do the insights differ compared to the Iris dataset?