# Matplotlib: Data Visualization in Python

## Introduction

Matplotlib is a comprehensive library for creating static, animated, and interactive visualizations in Python. Its `pyplot` module provides a MATLAB-like, state-based interface. Matplotlib is also the foundation for other Python visualization tools, such as seaborn and the built-in plotting methods of pandas.

This tutorial will introduce you to the most essential Matplotlib concepts and functions that are widely used in machine learning and data science projects for creating professional visualizations.

**Source:** [Matplotlib Documentation](https://matplotlib.org/stable/index.html) and [Data Visualization with Matplotlib](https://github.com/pb111/Data-Analysis-and-Visualization-with-Python/blob/master/Matplotlib_Basics.ipynb)

## Why Matplotlib?

Matplotlib is a powerful tool for data visualization because:

- It provides a comprehensive set of visualization tools for creating publication-quality figures
- It offers both high-level (pyplot) and object-oriented interfaces for different use cases
- It works well with NumPy and Pandas
- It's highly customizable, allowing for precise control over every aspect of a visualization
- It supports multiple output formats (PNG, PDF, SVG, etc.) and interactive environments

These capabilities make Matplotlib a common default for customized data visualization in machine learning workflows.

### Key Advantages of Matplotlib:

1. **Flexibility**: From simple plots to complex visualizations, Matplotlib can handle it all
2. **Integration**: Seamlessly works with other scientific Python libraries
3. **Customization**: Fine-grained control over every element of a plot
4. **Multiple Backends**: Works in various environments (Jupyter, scripts, web applications)
5. **Community Support**: Extensive documentation and a large user community

In [None]:
# Import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Set the style
plt.style.use('seaborn-v0_8-whitegrid')

# Check matplotlib version
import matplotlib
print(f"Matplotlib version: {matplotlib.__version__}")

## 1. Matplotlib Basics

### Concept: Matplotlib Architecture

Matplotlib has a hierarchical structure with three main layers:

1. **Backend Layer**: Handles the rendering of plots to different outputs (screen, file, etc.)
2. **Artist Layer**: Contains objects that represent elements of a plot (lines, text, etc.)
3. **Scripting Layer (pyplot)**: Provides a simplified interface for common plotting tasks

Most users interact with Matplotlib through the pyplot interface, which provides a MATLAB-like experience. However, for more complex visualizations or fine-grained control, the object-oriented interface is often preferred.
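
The three layers can be seen directly in a few lines of code. The minimal sketch below skips pyplot (the scripting layer) entirely: it builds a `Figure` and `Axes` as plain Artist objects and then hands them to an Agg canvas (the backend layer), which rasterizes them into an in-memory RGBA buffer. The figure size and the specific checks are illustrative choices, not requirements.

```python
import numpy as np
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure
from matplotlib.lines import Line2D

# Artist layer: build the Figure and Axes directly, without pyplot
fig = Figure(figsize=(4, 3), dpi=100)
ax = fig.add_subplot()
(line,) = ax.plot(np.arange(5), np.arange(5) ** 2)

# Every visible element is an Artist, and Artists form a tree rooted at the Figure
print(isinstance(line, Line2D), line in ax.get_children(), ax in fig.get_children())

# Backend layer: an Agg canvas rasterizes the Artist tree into an RGBA pixel buffer
canvas = FigureCanvasAgg(fig)
canvas.draw()
pixels = np.asarray(canvas.buffer_rgba())
print(pixels.shape)  # (height, width, RGBA channels) for 4 x 3 inches at 100 dpi
```

Because this figure is never registered with pyplot, it is not shown automatically in a notebook; this is the usual pattern for generating images in scripts or web servers.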

In [None]:
# Simple line plot using pyplot interface
x = np.linspace(0, 10, 100)
y = np.sin(x)

plt.figure(figsize=(10, 6))
plt.plot(x, y, label='sin(x)')
plt.title('Simple Sine Wave')
plt.xlabel('x')
plt.ylabel('sin(x)')
plt.legend()
plt.grid(True)
plt.show()

# Same plot using object-oriented interface
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(x, y, label='sin(x)')
ax.set_title('Simple Sine Wave (Object-Oriented Interface)')
ax.set_xlabel('x')
ax.set_ylabel('sin(x)')
ax.legend()
ax.grid(True)
plt.show()

### Concept: Figures and Axes

In Matplotlib, a **Figure** is the top-level container that holds all plot elements. Each Figure can contain one or more **Axes**, which are the actual plotting areas. Understanding this relationship is crucial for creating complex visualizations with multiple subplots.
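
A quick way to see the Figure/Axes relationship is to inspect the objects that `plt.subplots` returns. This minimal sketch checks the shape of the returned Axes array, the Figure's own list of Axes, and the back-reference from each Axes to its Figure; the inset position is an arbitrary example.

```python
import matplotlib.pyplot as plt

# One Figure holding a 2x3 grid of Axes
fig, axes = plt.subplots(2, 3, figsize=(9, 5))
print(axes.shape)                # the Axes come back as a 2x3 NumPy array
print(len(fig.axes))             # the Figure also keeps a flat list of all its Axes
print(axes[1, 2].figure is fig)  # each Axes holds a reference to its parent Figure

# add_axes places an extra Axes at [left, bottom, width, height] in figure-fraction units,
# here a small inset in the upper-right corner
inset = fig.add_axes([0.7, 0.7, 0.2, 0.2])
print(len(fig.axes))
plt.close(fig)  # release the figure without displaying it
```

This prints `(2, 3)`, `6`, `True`, and then `7` once the inset has been added.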

In [None]:
# Create a figure with multiple subplots
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Flatten the 2D array of axes for easier indexing
axes = axes.flatten()

# Plot different functions on each subplot
x = np.linspace(0, 10, 100)

# Sine wave
axes[0].plot(x, np.sin(x), 'b-', label='sin(x)')
axes[0].set_title('Sine Function')
axes[0].set_xlabel('x')
axes[0].set_ylabel('sin(x)')
axes[0].legend()
axes[0].grid(True)

# Cosine wave
axes[1].plot(x, np.cos(x), 'r-', label='cos(x)')
axes[1].set_title('Cosine Function')
axes[1].set_xlabel('x')
axes[1].set_ylabel('cos(x)')
axes[1].legend()
axes[1].grid(True)

# Parabola
axes[2].plot(x, x**2, 'g-', label='x²')
axes[2].set_title('Quadratic Function')
axes[2].set_xlabel('x')
axes[2].set_ylabel('x²')
axes[2].legend()
axes[2].grid(True)

# Exponential
axes[3].plot(x, np.exp(x/5), 'm-', label='exp(x/5)')
axes[3].set_title('Exponential Function')
axes[3].set_xlabel('x')
axes[3].set_ylabel('exp(x/5)')
axes[3].legend()
axes[3].grid(True)

# Adjust layout
plt.tight_layout()
plt.show()

# Create a figure with subplots of different sizes (GridSpec)
import matplotlib.gridspec as gridspec

fig = plt.figure(figsize=(12, 8))
gs = gridspec.GridSpec(2, 3, figure=fig)

# Create axes of different sizes
ax1 = fig.add_subplot(gs[0, :])
ax2 = fig.add_subplot(gs[1, 0])
ax3 = fig.add_subplot(gs[1, 1])
ax4 = fig.add_subplot(gs[1, 2])

# Plot data
ax1.plot(x, np.sin(x), 'b-', label='sin(x)')
ax1.set_title('Sine Function (Full Width)')
ax1.legend()
ax1.grid(True)

ax2.plot(x, np.cos(x), 'r-')
ax2.set_title('Cosine')
ax2.grid(True)

ax3.plot(x, x**2, 'g-')
ax3.set_title('Quadratic')
ax3.grid(True)

ax4.plot(x, np.exp(x/5), 'm-')
ax4.set_title('Exponential')
ax4.grid(True)

plt.tight_layout()
plt.show()

### Concept: Basic Plot Types

Matplotlib supports a wide variety of plot types to visualize different kinds of data. Here are some of the most commonly used plot types:
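
Each plotting function returns the Artist objects it created, which is useful when a plot needs to be modified later. The sketch below, using small made-up inputs, shows the return types of `plot`, `scatter`, `bar`, and `hist`, and confirms that a histogram's counts add up to the number of samples.

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import PathCollection
from matplotlib.container import BarContainer

rng = np.random.default_rng(0)
samples = rng.normal(size=200)

fig, (ax_line, ax_scatter, ax_bar, ax_hist) = plt.subplots(1, 4, figsize=(14, 3))
lines = ax_line.plot([0, 1, 2], [1, 3, 2])               # list of Line2D objects
points = ax_scatter.scatter([0, 1, 2], [2, 1, 3])        # a single PathCollection
bars = ax_bar.bar(["A", "B", "C"], [3, 5, 2])            # BarContainer holding one Rectangle per bar
counts, edges, patches = ax_hist.hist(samples, bins=10)  # bin counts, bin edges, bar patches

print(len(lines), isinstance(points, PathCollection), isinstance(bars, BarContainer))
print(int(counts.sum()), len(edges))  # all 200 samples are counted; 10 bins need 11 edges
plt.show()
```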

In [None]:
# Generate sample data
np.random.seed(42)
x = np.linspace(0, 10, 30)
y = np.sin(x) + np.random.normal(0, 0.2, size=len(x))
categories = ['A', 'B', 'C', 'D', 'E']
values = [25, 40, 30, 55, 15]

# Line plot
plt.figure(figsize=(12, 8))

plt.subplot(2, 3, 1)
plt.plot(x, y, 'o-', label='Data')
plt.title('Line Plot')
plt.xlabel('x')
plt.ylabel('y')
plt.legend()
plt.grid(True)

# Scatter plot
plt.subplot(2, 3, 2)
plt.scatter(x, y, c=y, cmap='viridis', s=100, alpha=0.7)
plt.title('Scatter Plot')
plt.xlabel('x')
plt.ylabel('y')
plt.colorbar(label='y value')
plt.grid(True)

# Bar plot
plt.subplot(2, 3, 3)
plt.bar(categories, values, color='skyblue', edgecolor='navy')
plt.title('Bar Plot')
plt.xlabel('Category')
plt.ylabel('Value')
plt.grid(axis='y')

# Histogram
plt.subplot(2, 3, 4)
plt.hist(y, bins=10, color='lightgreen', edgecolor='darkgreen', alpha=0.7)
plt.title('Histogram')
plt.xlabel('Value')
plt.ylabel('Frequency')
plt.grid(True)

# Pie chart
plt.subplot(2, 3, 5)
plt.pie(values, labels=categories, autopct='%1.1f%%', startangle=90, 
        colors=plt.cm.Paired(np.linspace(0, 1, len(categories))))
plt.title('Pie Chart')
plt.axis('equal')  # Equal aspect ratio ensures that pie is drawn as a circle

# Box plot
plt.subplot(2, 3, 6)
data = [np.random.normal(0, std, 100) for std in range(1, 6)]
plt.boxplot(data, labels=categories, patch_artist=True)  # Matplotlib >= 3.9 renames 'labels' to 'tick_labels' (old name deprecated)
plt.title('Box Plot')
plt.xlabel('Category')
plt.ylabel('Value')
plt.grid(True)

plt.tight_layout()
plt.show()

### Concept: Customizing Plot Appearance

Matplotlib provides extensive options for customizing the appearance of plots. This includes controlling colors, line styles, markers, fonts, and more.
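
The format strings used throughout this notebook (`'b-'`, `'r--'`, `'bo'`) are shorthand for the `color`, `marker`, and `linestyle` keyword arguments. This minimal sketch draws one line each way and prints the resulting properties, which come out identical.

```python
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgba

fig, ax = plt.subplots()
# Shorthand format string: color 'r', marker 'o', linestyle '--'
(short,) = ax.plot([0, 1, 2], [0, 1, 0], "ro--")
# The same styling spelled out with keyword arguments
(explicit,) = ax.plot([0, 1, 2], [1, 0, 1], color="red", marker="o", linestyle="--")

for line in (short, explicit):
    # to_rgba normalizes 'r' and 'red' to the same RGBA tuple
    print(to_rgba(line.get_color()), line.get_marker(), line.get_linestyle())
plt.show()
```

Keyword arguments are more verbose but also accept values a format string cannot express, such as hex colors or custom dash patterns.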

In [None]:
# Generate data
x = np.linspace(0, 10, 100)
y1 = np.sin(x)
y2 = np.cos(x)
y3 = np.sin(x) * np.cos(x)

# Create a customized plot
plt.figure(figsize=(12, 8))

# Plot with different line styles, colors, and markers
plt.plot(x, y1, 'b-', linewidth=2, label='sin(x)')
plt.plot(x, y2, 'r--', linewidth=2, label='cos(x)')
plt.plot(x, y3, 'g-.', linewidth=2, label='sin(x)cos(x)')

# Add points at specific locations
special_points_x = [np.pi/2, np.pi, 3*np.pi/2, 2*np.pi]
special_points_y1 = np.sin(special_points_x)
special_points_y2 = np.cos(special_points_x)
plt.plot(special_points_x, special_points_y1, 'bo', markersize=10)
plt.plot(special_points_x, special_points_y2, 'ro', markersize=10)

# Customize the plot
plt.title('Trigonometric Functions', fontsize=18, fontweight='bold')
plt.xlabel('x', fontsize=14)
plt.ylabel('y', fontsize=14)

# Add a grid with custom properties
plt.grid(True, linestyle='--', alpha=0.7)

# Customize the axis limits
plt.xlim(0, 2*np.pi)
plt.ylim(-1.2, 1.2)

# Add custom tick marks
plt.xticks([0, np.pi/2, np.pi, 3*np.pi/2, 2*np.pi],
           ['0', r'$\pi/2$', r'$\pi$', r'$3\pi/2$', r'$2\pi$'],  # raw strings avoid Python's invalid-escape warning for '\p'
           fontsize=12)

# Add a legend with custom properties
plt.legend(loc='upper right', fontsize=12, frameon=True, framealpha=0.9, 
           facecolor='white', edgecolor='gray')

# Add annotations
plt.annotate('sin(x) = 1', xy=(np.pi/2, 1), xytext=(np.pi/2 + 0.5, 1.1),
             arrowprops=dict(facecolor='black', shrink=0.05, width=1.5),
             fontsize=12)

plt.annotate('cos(x) = -1', xy=(np.pi, -1), xytext=(np.pi + 0.5, -1.1),
             arrowprops=dict(facecolor='black', shrink=0.05, width=1.5),
             fontsize=12)

# Add a horizontal line at y=0
plt.axhline(y=0, color='k', linestyle='-', alpha=0.3)

# Add vertical lines at special x values
for x_val in special_points_x:
    plt.axvline(x=x_val, color='k', linestyle=':', alpha=0.3)

# Add text
plt.text(0.5, -0.8, 'Customized Matplotlib Plot', fontsize=14, 
         bbox=dict(facecolor='yellow', alpha=0.2))

plt.tight_layout()
plt.show()

### Concept: Saving Figures

Matplotlib can save figures in various formats, including PNG, PDF, SVG, and more. This is essential for including visualizations in reports, presentations, or publications.
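
`savefig` also accepts any binary file-like object, which makes it easy to check what each format produces without touching the disk. The sketch below saves one small figure to in-memory buffers and reads the PNG header to confirm that `dpi` sets the pixel size (inches times dpi); the header offsets follow the PNG file format.

```python
import io
import struct

import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(4, 3))
ax.plot([0, 1, 2], [0, 1, 0])

buffers = {fmt: io.BytesIO() for fmt in ("png", "pdf", "svg")}
for fmt, buf in buffers.items():
    # dpi matters for raster output: 4 x 3 inches at 100 dpi gives a 400 x 300 pixel PNG
    fig.savefig(buf, format=fmt, dpi=100)
plt.close(fig)

png_bytes = buffers["png"].getvalue()
# The PNG header stores width and height as big-endian 32-bit integers at bytes 16-23
width, height = struct.unpack(">II", png_bytes[16:24])
print(png_bytes[:8] == b"\x89PNG\r\n\x1a\n", (width, height))
print(buffers["pdf"].getvalue()[:5] == b"%PDF-")
print(b"<svg" in buffers["svg"].getvalue())
```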

In [None]:
# Create a simple figure
plt.figure(figsize=(10, 6))
plt.plot(np.random.randn(50).cumsum(), 'r-', linewidth=2)
plt.title('Random Walk')
plt.xlabel('Step')
plt.ylabel('Position')
plt.grid(True)

# Save the figure in different formats
# PNG (raster format, good for web)
plt.savefig('random_walk.png', dpi=300, bbox_inches='tight')

# PDF (vector format, good for publications)
plt.savefig('random_walk.pdf', bbox_inches='tight')

# SVG (vector format, good for web and editing)
plt.savefig('random_walk.svg', bbox_inches='tight')

plt.show()

print("Figure saved in PNG, PDF, and SVG formats.")

## 2. Plotting with Pandas and NumPy

### Concept: Integration with Pandas

Pandas DataFrames and Series have built-in plotting methods that are based on Matplotlib. This integration makes it easy to visualize data directly from Pandas objects.
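
Because pandas draws onto ordinary Matplotlib Axes, the two APIs can be mixed freely. This minimal sketch passes an existing Axes to `DataFrame.plot`, then uses regular Matplotlib calls on the same object; the tiny DataFrame is made up for illustration.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

df = pd.DataFrame({"A": np.arange(5), "B": np.arange(5) ** 2})

fig, ax = plt.subplots(figsize=(6, 4))
returned = df.plot(ax=ax, title="Drawn by pandas on a Matplotlib Axes")
ax.set_ylabel("Value")  # ordinary Matplotlib calls keep working on the same Axes

print(returned is ax)                                       # pandas returns the Axes it drew on
print(len(ax.get_lines()))                                  # one Line2D per column
print([t.get_text() for t in ax.get_legend().get_texts()])  # legend built from column names
plt.show()
```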

In [None]:
# Create a sample DataFrame
dates = pd.date_range('2023-01-01', periods=100)
df = pd.DataFrame({
    'A': np.random.randn(100).cumsum(),
    'B': np.random.randn(100).cumsum(),
    'C': np.random.randn(100).cumsum(),
    'D': np.random.randn(100).cumsum()
}, index=dates)

print("Sample DataFrame:")
print(df.head())

# Line plot (df.plot creates its own figure, so a separate plt.figure() call would leave an empty one)
df.plot(figsize=(12, 6), grid=True)
plt.title('Line Plot from DataFrame')
plt.xlabel('Date')
plt.ylabel('Value')
plt.legend(loc='best')
plt.show()

# Multiple plot types
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Line plot
df.plot(ax=axes[0, 0], title='Line Plot')
axes[0, 0].set_xlabel('Date')
axes[0, 0].set_ylabel('Value')
axes[0, 0].legend(loc='best')
axes[0, 0].grid(True)

# Bar plot
df.iloc[-10:].plot.bar(ax=axes[0, 1], title='Bar Plot (Last 10 Days)')
axes[0, 1].set_xlabel('Date')
axes[0, 1].set_ylabel('Value')
axes[0, 1].legend(loc='best')
axes[0, 1].grid(True)

# Scatter plot
df.plot.scatter(x='A', y='B', ax=axes[1, 0], c='C', cmap='viridis', 
                s=df['D']**2 * 10 + 50, alpha=0.6, title='Scatter Plot (A vs B)')
axes[1, 0].set_xlabel('A')
axes[1, 0].set_ylabel('B')
axes[1, 0].grid(True)

# Histogram
df.plot.hist(ax=axes[1, 1], bins=20, alpha=0.7, title='Histogram')
axes[1, 1].set_xlabel('Value')
axes[1, 1].set_ylabel('Frequency')
axes[1, 1].grid(True)

plt.tight_layout()
plt.show()

# Additional Pandas plot types
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Box plot
df.plot.box(ax=axes[0, 0], title='Box Plot')
axes[0, 0].set_ylabel('Value')
axes[0, 0].grid(True)

# Area plot
df.plot.area(ax=axes[0, 1], alpha=0.4, title='Area Plot')
axes[0, 1].set_xlabel('Date')
axes[0, 1].set_ylabel('Value')
axes[0, 1].grid(True)

# Kernel Density Estimation (KDE) plot
df.plot.kde(ax=axes[1, 0], title='KDE Plot')
axes[1, 0].set_xlabel('Value')
axes[1, 0].set_ylabel('Density')
axes[1, 0].grid(True)

# Hexbin plot
df.plot.hexbin(x='A', y='B', ax=axes[1, 1], gridsize=20, cmap='Blues', title='Hexbin Plot (A vs B)')
axes[1, 1].set_xlabel('A')
axes[1, 1].set_ylabel('B')

plt.tight_layout()
plt.show()

### Concept: Statistical Visualization

Matplotlib, together with seaborn (a statistical plotting library built on top of Matplotlib), can be used to create visualizations that reveal the distributions of variables and the relationships between them.
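
A correlation heatmap does not strictly require seaborn. The minimal sketch below builds one with plain Matplotlib (`imshow` plus one `text` call per cell), using synthetic data drawn from an assumed 3x3 covariance matrix; it approximates what `sns.heatmap(..., annot=True)` produces rather than reproducing it exactly.

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(42)
cov = [[1.0, 0.7, 0.3],
       [0.7, 1.0, 0.4],
       [0.3, 0.4, 1.0]]
data = rng.multivariate_normal(mean=[0, 0, 0], cov=cov, size=1000)
corr = np.corrcoef(data, rowvar=False)  # 3x3 sample correlation matrix
labels = ["A", "B", "C"]

fig, ax = plt.subplots(figsize=(5, 4))
# Fix the color limits to [-1, 1] so a given color means the same r in every plot
im = ax.imshow(corr, cmap="coolwarm", vmin=-1, vmax=1)
ax.set_xticks(range(len(labels)))
ax.set_xticklabels(labels)
ax.set_yticks(range(len(labels)))
ax.set_yticklabels(labels)

# Write each coefficient in its cell; imshow places cell (i, j) at x=j, y=i
for i in range(corr.shape[0]):
    for j in range(corr.shape[1]):
        ax.text(j, i, f"{corr[i, j]:.2f}", ha="center", va="center")

fig.colorbar(im, ax=ax, label="Pearson r")
ax.set_title("Correlation Matrix (Matplotlib only)")
plt.show()
```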

In [None]:
# Import seaborn for enhanced statistical visualizations
import seaborn as sns

# Generate multivariate normal data
np.random.seed(42)
mean = [0, 0, 0, 0]
cov = [[1, 0.7, 0.3, 0.2],
       [0.7, 1, 0.4, 0.1],
       [0.3, 0.4, 1, 0.5],
       [0.2, 0.1, 0.5, 1]]
data = np.random.multivariate_normal(mean, cov, 1000)
df_stats = pd.DataFrame(data, columns=['A', 'B', 'C', 'D'])

# Correlation matrix heatmap
plt.figure(figsize=(10, 8))
corr = df_stats.corr()
sns.heatmap(corr, annot=True, cmap='coolwarm', linewidths=0.5)
plt.title('Correlation Matrix Heatmap')
plt.tight_layout()
plt.show()

# Pair plot
sns.pairplot(df_stats, diag_kind='kde', plot_kws={'alpha': 0.6})
plt.suptitle('Pair Plot of Variables', y=1.02, fontsize=16)
plt.show()

# Joint distribution plot (jointplot creates its own figure; its size is set with `height`)
g = sns.jointplot(x='A', y='B', data=df_stats, kind='hex', height=8, ratio=5, 
                 marginal_kws={'bins': 20, 'fill': True})
g.fig.suptitle('Joint Distribution of A and B', y=1.02, fontsize=16)
plt.tight_layout()
plt.show()

# Violin plot
plt.figure(figsize=(12, 6))
sns.violinplot(data=df_stats, palette='Set3')
plt.title('Violin Plot of Variables')
plt.ylabel('Value')
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.show()

### Concept: 3D Plotting

Matplotlib's `mplot3d` toolkit allows for creating three-dimensional plots, which can be useful for visualizing relationships between three variables.
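
Most 3D surface plots follow the same recipe: build a grid with `np.meshgrid`, evaluate a function on it, and pass the three 2D arrays to `plot_surface` on an Axes created with `projection="3d"`. This sketch uses a Gaussian bump as an arbitrary example function and sets the camera angle explicitly.

```python
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D  # imported here only for the isinstance check

x = np.linspace(-2, 2, 30)
y = np.linspace(-2, 2, 40)
X, Y = np.meshgrid(x, y)    # both grids have shape (len(y), len(x)) = (40, 30)
Z = np.exp(-(X**2 + Y**2))  # a Gaussian bump evaluated at every grid point

fig = plt.figure(figsize=(6, 5))
ax = fig.add_subplot(projection="3d")  # the projection keyword makes this a 3D Axes
surf = ax.plot_surface(X, Y, Z, cmap="viridis")
ax.view_init(elev=30, azim=45)         # camera elevation and azimuth, in degrees
fig.colorbar(surf, ax=ax, shrink=0.6)
print(isinstance(ax, Axes3D), X.shape, Z.shape)
plt.show()
```

Note the shape convention: `meshgrid` puts `y` along the rows, so `Z[i, j]` corresponds to `(x[j], y[i])`.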

In [None]:
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection (automatic since Matplotlib 3.2; kept for older versions)

# 3D Line Plot
fig = plt.figure(figsize=(12, 10))
ax = fig.add_subplot(111, projection='3d')

# Generate a spiral
t = np.linspace(0, 20, 1000)
x = np.cos(t)
y = np.sin(t)
z = t

ax.plot(x, y, z, label='3D Spiral', linewidth=2)
ax.set_title('3D Line Plot')
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.legend()
plt.show()

# 3D Surface Plot
fig = plt.figure(figsize=(14, 12))

# First subplot: Simple surface
ax1 = fig.add_subplot(221, projection='3d')
X = np.arange(-5, 5, 0.25)
Y = np.arange(-5, 5, 0.25)
X, Y = np.meshgrid(X, Y)
R = np.sqrt(X**2 + Y**2)
Z = np.sin(R)

surf = ax1.plot_surface(X, Y, Z, cmap='viridis', linewidth=0, antialiased=True)
ax1.set_title('3D Surface Plot: sin(sqrt(x² + y²))')
fig.colorbar(surf, ax=ax1, shrink=0.5, aspect=5)

# Second subplot: Wireframe
ax2 = fig.add_subplot(222, projection='3d')
ax2.plot_wireframe(X, Y, Z, color='green', rstride=1, cstride=1, alpha=0.5)
ax2.set_title('3D Wireframe Plot')

# Third subplot: Contour plot
ax3 = fig.add_subplot(223, projection='3d')
ax3.contour3D(X, Y, Z, 50, cmap='binary')
ax3.set_title('3D Contour Plot')

# Fourth subplot: Scatter plot
ax4 = fig.add_subplot(224, projection='3d')
n = 100
xs = np.random.rand(n) * 4 - 2
ys = np.random.rand(n) * 4 - 2
zs = np.random.rand(n) * 4 - 2
colors = np.random.rand(n)
sizes = np.random.rand(n) * 100 + 20

ax4.scatter(xs, ys, zs, c=colors, s=sizes, alpha=0.6, cmap='viridis')
ax4.set_title('3D Scatter Plot')

plt.tight_layout()
plt.show()

## 3. Advanced Matplotlib Techniques

### Concept: Custom Colormaps and Color Schemes

Matplotlib provides a wide range of colormaps for visualizing data. You can also create custom colormaps to suit specific needs.
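
Coloring data is a two-step mapping: a `Normalize` object maps the data range onto `[0, 1]`, and the colormap maps `[0, 1]` onto RGBA colors. `imshow(data, cmap=...)` performs both steps for you, with `vmin`/`vmax` defaulting to the data's minimum and maximum. This minimal sketch does the two steps by hand with an assumed blue-white-red colormap.

```python
import numpy as np
from matplotlib.colors import LinearSegmentedColormap, Normalize, to_rgba

# A three-color colormap built with from_list (the same constructor used in the cell below)
cmap = LinearSegmentedColormap.from_list("blue_white_red", ["blue", "white", "red"], N=256)

# Step 1: Normalize maps the data range [vmin, vmax] onto [0, 1]
norm = Normalize(vmin=-5, vmax=5)
values = np.array([-5.0, 0.0, 5.0])
print(norm(values))

# Step 2: calling the colormap maps [0, 1] onto RGBA colors, one row per value
rgba = cmap(norm(values))
print(rgba.shape)
print(np.allclose(rgba[0], to_rgba("blue")), np.allclose(rgba[-1], to_rgba("red")))
```

The middle value lands very close to, but not exactly on, white, because the 256-entry lookup table has no entry at exactly 0.5.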

In [None]:
import matplotlib.colors as mcolors
from matplotlib.colors import LinearSegmentedColormap

# Display built-in colormaps
cmaps = ['viridis', 'plasma', 'inferno', 'magma', 'cividis',
         'Greys', 'Blues', 'Reds', 'YlOrBr', 'RdPu',
         'coolwarm', 'bwr', 'seismic', 'twilight', 'hsv']

# Create a sample image
data = np.random.rand(20, 20)
fig, axes = plt.subplots(3, 5, figsize=(15, 9), constrained_layout=True)
axes = axes.flatten()

for ax, cmap_name in zip(axes, cmaps):
    im = ax.imshow(data, cmap=cmap_name)
    ax.set_title(cmap_name)
    ax.set_xticks([])
    ax.set_yticks([])

plt.suptitle('Built-in Colormaps', fontsize=16)
plt.show()

# Create a custom colormap
colors = [(0, 'navy'), (0.4, 'blue'), (0.6, 'lime'), (0.8, 'yellow'), (1, 'red')]
cmap_name = 'custom_diverging'
cm = LinearSegmentedColormap.from_list(cmap_name, colors, N=256)

# Create a sample heatmap with the custom colormap
plt.figure(figsize=(10, 8))
data = np.random.randn(20, 20)
plt.imshow(data, cmap=cm)
plt.colorbar(label='Value')
plt.title('Heatmap with Custom Colormap')
plt.xticks([])
plt.yticks([])
plt.show()

# Create a visualization showing the colormap gradient
plt.figure(figsize=(12, 4))
gradient = np.linspace(0, 1, 256).reshape(1, -1)
plt.imshow(gradient, aspect='auto', cmap=cm)
plt.title('Custom Colormap Gradient')
plt.xticks([])
plt.yticks([])
plt.show()

### Concept: Annotations and Text

Annotations and text elements can significantly enhance the clarity and information content of visualizations.
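
Text and annotations can be positioned in different coordinate systems. Data coordinates (the default) tie a label to a data point, while `transform=ax.transAxes` places it at a fixed fraction of the Axes, so it stays put when the data range changes. This minimal sketch, using a sine curve as example data, shows both and checks what `transAxes` maps to.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 200)
y = np.sin(x)
peak = (x[np.argmax(y)], y.max())

fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(x, y)

# Data coordinates (the default): the arrow tip lands exactly on the data point
ax.annotate("maximum", xy=peak, xytext=(peak[0] + 1.5, 0.6),
            arrowprops=dict(arrowstyle="->"))

# Axes-fraction coordinates: (0.02, 0.95) is always near the top-left corner of the Axes
corner = ax.text(0.02, 0.95, "fixed corner label", transform=ax.transAxes,
                 va="top", bbox=dict(facecolor="white", alpha=0.8))

ax.set_ylim(-3, 3)  # changing the limits moves data-anchored items, not the corner label
fig.canvas.draw()

# transAxes maps (0, 0) and (1, 1) to the Axes' lower-left and upper-right corners in pixels
print(np.allclose(ax.transAxes.transform([(0, 0), (1, 1)]),
                  [[ax.bbox.x0, ax.bbox.y0], [ax.bbox.x1, ax.bbox.y1]]))
plt.show()
```

Axes-fraction placement is often a sturdier choice for text boxes such as fitted equations or summary statistics, whose ideal data-coordinate position depends on the random data.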

In [None]:
# Create a sample dataset
np.random.seed(42)
x = np.linspace(0, 10, 100)
y = 4 + 2 * np.sin(2 * x) + np.random.randn(100) * 0.5

# Fit a polynomial to the data
coeffs = np.polyfit(x, y, 3)
poly_fit = np.polyval(coeffs, x)

# Create the plot
plt.figure(figsize=(12, 8))

# Plot the data and fit
plt.scatter(x, y, alpha=0.7, label='Data')
plt.plot(x, poly_fit, 'r-', linewidth=2, label='Polynomial Fit')

# Add a title and axis labels
plt.title('Advanced Annotations Example', fontsize=16)
plt.xlabel('X-axis', fontsize=12)
plt.ylabel('Y-axis', fontsize=12)
plt.grid(True, linestyle='--', alpha=0.7)
plt.legend()

# Find the maximum and minimum points of the fit
max_idx = np.argmax(poly_fit)
min_idx = np.argmin(poly_fit)
max_point = (x[max_idx], poly_fit[max_idx])
min_point = (x[min_idx], poly_fit[min_idx])

# Annotate the maximum point
plt.annotate('Maximum', xy=max_point, xytext=(max_point[0]+1, max_point[1]+1),
             arrowprops=dict(facecolor='black', shrink=0.05, width=1.5),
             fontsize=12, ha='center')

# Annotate the minimum point
plt.annotate('Minimum', xy=min_point, xytext=(min_point[0]-1, min_point[1]-1),
             arrowprops=dict(facecolor='black', shrink=0.05, width=1.5),
             fontsize=12, ha='center')

# Add a text box with the polynomial coefficients
coeff_text = f"$y = {coeffs[0]:.3f}x^3 + {coeffs[1]:.3f}x^2 + {coeffs[2]:.3f}x + {coeffs[3]:.3f}$"
plt.text(0.5, 2.5, coeff_text, fontsize=12, 
         bbox=dict(facecolor='white', edgecolor='gray', alpha=0.8))

# Add a text box with statistics
residuals = y - poly_fit
mse = np.mean(residuals**2)
r_squared = 1 - np.sum(residuals**2) / np.sum((y - np.mean(y))**2)

stats_text = f"MSE: {mse:.4f}\nR²: {r_squared:.4f}"
plt.text(8, 6, stats_text, fontsize=12,
         bbox=dict(facecolor='lightblue', edgecolor='blue', alpha=0.8))

# Add vertical spans to highlight regions
plt.axvspan(2, 4, alpha=0.2, color='green', label='Region of Interest 1')
plt.axvspan(7, 9, alpha=0.2, color='red', label='Region of Interest 2')
plt.legend()  # Redraw the legend so it also lists the two shaded regions

# Add horizontal lines at specific y-values
plt.axhline(y=4, color='gray', linestyle='--', alpha=0.7)
plt.text(0.2, 4.1, 'Mean Level', fontsize=10)

# Add arrows pointing to specific data points
outlier_idx = np.argmax(np.abs(residuals))
outlier_point = (x[outlier_idx], y[outlier_idx])
plt.annotate('Largest Residual', xy=outlier_point, xytext=(outlier_point[0]+1.5, outlier_point[1]),
             arrowprops=dict(facecolor='purple', shrink=0.05, width=1.5),
             fontsize=10, color='purple')

plt.tight_layout()
plt.show()

### Concept: Custom Styling and Themes

Matplotlib allows for extensive customization of plot styles through style sheets and rcParams.
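
Besides setting `rcParams` globally and resetting them afterwards, settings can be scoped to a block of code with `matplotlib.rc_context` (`plt.style.context` works the same way for whole style sheets). This minimal sketch overrides two settings temporarily and checks that the global value is restored on exit.

```python
import matplotlib as mpl
import matplotlib.pyplot as plt

before = mpl.rcParams["lines.linewidth"]

# Overrides apply only inside the with-block and are undone on exit
with mpl.rc_context({"lines.linewidth": 4, "axes.grid": True}):
    fig, ax = plt.subplots()
    (line,) = ax.plot([0, 1, 2], [0, 1, 0])

after = mpl.rcParams["lines.linewidth"]
print(line.get_linewidth())  # 4.0: the line was created while the override was active
print(before == after)       # True: the global setting was restored
plt.show()
```

The scoped form avoids the side effect of `plt.rcdefaults()`, which also discards any style chosen earlier in the session.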

In [None]:
# List available styles
print("Available styles:")
print(plt.style.available)

# Create a sample plot with different styles
x = np.linspace(0, 10, 100)
y = np.sin(x) * np.exp(-0.1 * x)

styles = ['default', 'seaborn-v0_8', 'ggplot', 'bmh', 'dark_background', 'fivethirtyeight']
fig = plt.figure(figsize=(15, 12), constrained_layout=True)

for i, style in enumerate(styles, start=1):
    # Many style settings (background color, color cycle, tick colors) are read when
    # an Axes is created, so each Axes is created inside its own style context
    with plt.style.context(style):
        ax = fig.add_subplot(3, 2, i)
        ax.plot(x, y, 'o-', linewidth=2, markersize=4)
        ax.set_title(f"Style: {style}", fontsize=14)
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        ax.grid(True)

fig.suptitle('Matplotlib Style Comparison', fontsize=18)
plt.show()

# Create a custom style
custom_params = {
    'figure.figsize': (12, 8),
    'figure.facecolor': '#f8f9fa',
    'axes.facecolor': '#f8f9fa',
    'axes.edgecolor': '#343a40',
    'axes.labelcolor': '#343a40',
    'axes.grid': True,
    'axes.grid.which': 'both',
    'axes.grid.axis': 'both',
    'grid.color': '#ced4da',
    'grid.linestyle': '--',
    'grid.linewidth': 0.8,
    'grid.alpha': 0.5,
    'xtick.color': '#343a40',
    'ytick.color': '#343a40',
    'text.color': '#343a40',
    'font.family': 'sans-serif',
    'font.sans-serif': ['Arial', 'DejaVu Sans', 'Liberation Sans', 'Bitstream Vera Sans', 'sans-serif'],
    'lines.linewidth': 2.5,
    'lines.markersize': 8,
    'legend.frameon': True,
    'legend.framealpha': 0.8,
    'legend.edgecolor': '#ced4da',
    'savefig.dpi': 300,
    'savefig.bbox': 'tight',
    'savefig.pad_inches': 0.1,
}

# Apply the custom style
plt.rcParams.update(custom_params)

# Create a plot with the custom style
plt.figure()
x = np.linspace(0, 10, 100)
plt.plot(x, np.sin(x), label='sin(x)')
plt.plot(x, np.cos(x), label='cos(x)')
plt.plot(x, np.sin(x) * np.cos(x), label='sin(x)cos(x)')
plt.title('Custom Styled Plot', fontsize=16)
plt.xlabel('x', fontsize=12)
plt.ylabel('y', fontsize=12)
plt.legend()
plt.show()

# Reset to Matplotlib's built-in defaults (this also discards the style chosen with plt.style.use at the start)
plt.rcdefaults()

### Concept: Interactive Plots

Matplotlib provides basic interactivity features, which can be enhanced with libraries like ipywidgets for Jupyter notebooks.
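
Matplotlib also ships its own widgets in `matplotlib.widgets`, which work in any interactive backend (in Jupyter this typically means an interactive backend such as `%matplotlib widget`; the default inline backend only renders static images). This minimal sketch wires a `Slider` to a sine curve and calls `set_val` to simulate a user moving it, so the callback logic can be checked without a GUI.

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.widgets import Slider

x = np.linspace(0, 10, 500)
fig, ax = plt.subplots(figsize=(8, 4))
fig.subplots_adjust(bottom=0.25)  # leave room for the slider below the plot
(line,) = ax.plot(x, np.sin(x))
ax.set_ylim(-1.1, 1.1)

# The slider lives in its own small Axes: [left, bottom, width, height] in figure fractions
slider_ax = fig.add_axes([0.15, 0.08, 0.7, 0.04])
freq = Slider(slider_ax, "Frequency", valmin=0.1, valmax=3.0, valinit=1.0)

def update(value):
    # Called with the new slider value every time it changes
    line.set_ydata(np.sin(value * x))
    fig.canvas.draw_idle()  # request a redraw when the GUI is idle

freq.on_changed(update)
freq.set_val(2.0)  # simulate a user dragging the slider to 2.0
plt.show()
```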

In [None]:
from ipywidgets import interact, FloatSlider, Dropdown

# Create a function for interactive plotting
def plot_function(function='sin', amplitude=1.0, frequency=1.0, phase=0.0):
    plt.figure(figsize=(10, 6))
    x = np.linspace(0, 10, 1000)
    
    # Slider values are formatted to two decimals to keep the title tidy
    if function == 'sin':
        y = amplitude * np.sin(frequency * x + phase)
        title = f'y = {amplitude:.2f} * sin({frequency:.2f}x + {phase:.2f})'
    elif function == 'cos':
        y = amplitude * np.cos(frequency * x + phase)
        title = f'y = {amplitude:.2f} * cos({frequency:.2f}x + {phase:.2f})'
    elif function == 'tan':
        y = amplitude * np.tan(frequency * x + phase)
        # Clip extreme values for better visualization
        y = np.clip(y, -10, 10)
        title = f'y = {amplitude:.2f} * tan({frequency:.2f}x + {phase:.2f}) (clipped)'
    elif function == 'exp':
        y = amplitude * np.exp(frequency * x / 5)
        title = f'y = {amplitude:.2f} * exp({frequency:.2f}x/5)'
    
    plt.plot(x, y, 'b-', linewidth=2)
    plt.title(title, fontsize=14)
    plt.xlabel('x', fontsize=12)
    plt.ylabel('y', fontsize=12)
    plt.grid(True)
    plt.ylim(-5, 5)
    plt.show()

# Create interactive widgets
interact(plot_function, 
         function=Dropdown(options=['sin', 'cos', 'tan', 'exp'], value='sin', description='Function:'),
         amplitude=FloatSlider(min=0.1, max=3.0, step=0.1, value=1.0, description='Amplitude:'),
         frequency=FloatSlider(min=0.1, max=3.0, step=0.1, value=1.0, description='Frequency:'),
         phase=FloatSlider(min=0.0, max=6.28, step=0.1, value=0.0, description='Phase:'));

## 4. Real-world Applications of Matplotlib in Machine Learning

### 4.1 Visualizing Model Performance

Matplotlib is essential for evaluating and communicating the performance of machine learning models.
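
Before using scikit-learn's `roc_curve` and `auc` below, it can help to see what they compute. This minimal NumPy sketch sorts samples by score, accumulates true and false positive rates as the threshold drops, and integrates with the trapezoidal rule. It assumes binary 0/1 labels and no tied scores; scikit-learn's implementation also handles ties and drops redundant thresholds. The four-sample input is a tiny hand-made example.

```python
import matplotlib.pyplot as plt
import numpy as np

def roc_points(y_true, scores):
    """Return FPR and TPR arrays, treating every score as a threshold (no ties assumed)."""
    order = np.argsort(-scores)             # highest score first
    y_sorted = np.asarray(y_true)[order]
    tps = np.cumsum(y_sorted)               # true positives as the threshold drops
    fps = np.cumsum(1 - y_sorted)           # false positives as the threshold drops
    tpr = np.concatenate([[0.0], tps / tps[-1]])
    fpr = np.concatenate([[0.0], fps / fps[-1]])
    return fpr, tpr

def trapezoid_area(x, y):
    # Area under a piecewise-linear curve (the trapezoidal rule)
    return float(np.sum(np.diff(x) * (y[1:] + y[:-1]) / 2))

y_true = np.array([0, 0, 1, 1])
scores = np.array([0.1, 0.4, 0.35, 0.8])
fpr, tpr = roc_points(y_true, scores)
roc_auc = trapezoid_area(fpr, tpr)

fig, ax = plt.subplots(figsize=(5, 5))
ax.plot(fpr, tpr, "o-", label=f"ROC curve (area = {roc_auc:.2f})")
ax.plot([0, 1], [0, 1], "k--", label="Chance")
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
ax.legend(loc="lower right")
plt.show()
print(roc_auc)  # 0.75
```

An area of 0.75 means that, for a randomly chosen positive and negative pair, the positive gets the higher score 75% of the time (here 3 of the 4 pairs).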

In [None]:
from sklearn.datasets import load_iris, load_breast_cancer
from sklearn.model_selection import train_test_split, learning_curve, validation_curve
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.metrics import confusion_matrix, roc_curve, precision_recall_curve, auc
from sklearn.metrics import classification_report, accuracy_score, precision_score, recall_score, f1_score

# Load a dataset
data = load_breast_cancer()
X = data.data
y = data.target
feature_names = data.feature_names
target_names = data.target_names

# Split the data
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# Standardize the features
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# Train a model
model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X_train_scaled, y_train)

# Make predictions
y_pred = model.predict(X_test_scaled)
y_pred_proba = model.predict_proba(X_test_scaled)[:, 1]

# Calculate metrics
accuracy = accuracy_score(y_test, y_pred)
precision = precision_score(y_test, y_pred)
recall = recall_score(y_test, y_pred)
f1 = f1_score(y_test, y_pred)

print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")
print("\nClassification Report:")
print(classification_report(y_test, y_pred, target_names=target_names))

# Visualize confusion matrix
cm = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=target_names, yticklabels=target_names)
plt.title('Confusion Matrix', fontsize=16)
plt.xlabel('Predicted Label', fontsize=12)
plt.ylabel('True Label', fontsize=12)
plt.show()

# Visualize ROC curve
fpr, tpr, thresholds = roc_curve(y_test, y_pred_proba)
roc_auc = auc(fpr, tpr)

plt.figure(figsize=(10, 8))
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate', fontsize=12)
plt.ylabel('True Positive Rate', fontsize=12)
plt.title('Receiver Operating Characteristic (ROC) Curve', fontsize=16)
plt.legend(loc="lower right")
plt.grid(True, linestyle='--', alpha=0.7)
plt.show()

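# Visualize Precision-Recall curve
from sklearn.metrics import average_precision_score

precision_curve, recall_curve, _ = precision_recall_curve(y_test, y_pred_proba)
# Average precision (AP) summarizes the PR curve; the trapezoidal auc() can be overly optimistic here
ap = average_precision_score(y_test, y_pred_proba)

plt.figure(figsize=(10, 8))
plt.plot(recall_curve, precision_curve, color='green', lw=2,
         label=f'Precision-Recall curve (AP = {ap:.2f})')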
plt.xlabel('Recall', fontsize=12)
plt.ylabel('Precision', fontsize=12)
plt.title('Precision-Recall Curve', fontsize=16)
plt.legend(loc="best")
plt.grid(True, linestyle='--', alpha=0.7)
plt.show()

# Visualize feature importance
importances = model.feature_importances_
indices = np.argsort(importances)[::-1]

plt.figure(figsize=(12, 10))
plt.title('Feature Importances', fontsize=16)
plt.bar(range(10), importances[indices[:10]], align='center', alpha=0.7)
plt.xticks(range(10), [feature_names[i] for i in indices[:10]], rotation=45, ha='right')
plt.tight_layout()
plt.show()

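# Visualize learning curves
# X_scaled was never defined; instead, a pipeline refits the scaler inside each CV fold,
# so no validation data leaks into the scaling step
from sklearn.pipeline import make_pipeline

train_sizes, train_scores, test_scores = learning_curve(
    make_pipeline(StandardScaler(), RandomForestClassifier(n_estimators=100, random_state=42)), X, y,
    cv=5, n_jobs=-1, train_sizes=np.linspace(0.1, 1.0, 10))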

train_mean = np.mean(train_scores, axis=1)
train_std = np.std(train_scores, axis=1)
test_mean = np.mean(test_scores, axis=1)
test_std = np.std(test_scores, axis=1)

plt.figure(figsize=(10, 8))
plt.plot(train_sizes, train_mean, 'o-', color='r', label='Training score')
plt.fill_between(train_sizes, train_mean - train_std, train_mean + train_std, alpha=0.1, color='r')
plt.plot(train_sizes, test_mean, 'o-', color='g', label='Cross-validation score')
plt.fill_between(train_sizes, test_mean - test_std, test_mean + test_std, alpha=0.1, color='g')
plt.xlabel('Training examples', fontsize=12)
plt.ylabel('Score', fontsize=12)
plt.title('Learning Curves', fontsize=16)
plt.legend(loc='best')
plt.grid(True, linestyle='--', alpha=0.7)
plt.show()

### 4.2 Visualizing Data Distributions and Relationships

Understanding the distribution of features and their relationships is crucial for effective machine learning.
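
When comparing per-class distributions, normalizing each class keeps groups of different sizes comparable. This minimal sketch does that with plain Matplotlib histograms (`density=True`) on synthetic data for three hypothetical classes; it is a histogram-based alternative to the KDE loop below, not a reproduction of it, and it verifies that each normalized histogram has total area 1.

```python
import matplotlib.pyplot as plt
import numpy as np

# Synthetic measurements for three hypothetical classes (not real data)
rng = np.random.default_rng(0)
groups = {
    "class 0": rng.normal(1.5, 0.2, 50),
    "class 1": rng.normal(4.3, 0.5, 50),
    "class 2": rng.normal(5.6, 0.6, 50),
}

fig, ax = plt.subplots(figsize=(8, 4))
areas = {}
for name, values in groups.items():
    # density=True rescales bar heights so that the total bar area is 1
    counts, edges, _ = ax.hist(values, bins=15, density=True, alpha=0.5, label=name)
    areas[name] = float(np.sum(counts * np.diff(edges)))

ax.set_xlabel("Measurement")
ax.set_ylabel("Density")
ax.legend()
plt.show()
print(areas)  # each area is 1.0 up to floating-point rounding
```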

In [None]:
# Load the Iris dataset
iris = load_iris()
X_iris = iris.data
y_iris = iris.target
feature_names_iris = iris.feature_names
target_names_iris = iris.target_names

# Create a DataFrame
iris_df = pd.DataFrame(X_iris, columns=feature_names_iris)
iris_df['species'] = [target_names_iris[i] for i in y_iris]

# Visualize feature distributions by class
plt.figure(figsize=(15, 10))

for i, feature in enumerate(feature_names_iris):
    plt.subplot(2, 2, i+1)
    for species in target_names_iris:
        subset = iris_df[iris_df['species'] == species]
        sns.kdeplot(subset[feature], label=species, fill=True)  # 'shade' is deprecated in seaborn >= 0.11
    plt.title(f'Distribution of {feature}')
    plt.xlabel(feature)
    plt.ylabel('Density')
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.7)

plt.tight_layout()
plt.show()

# Visualize pairwise relationships
sns.pairplot(iris_df, hue='species', height=2.5, diag_kind='kde')
plt.suptitle('Pairwise Relationships in Iris Dataset', y=1.02, fontsize=16)
plt.show()

# Visualize feature correlations
plt.figure(figsize=(10, 8))
corr = iris_df.drop('species', axis=1).corr()
mask = np.triu(np.ones_like(corr, dtype=bool))
sns.heatmap(corr, mask=mask, cmap='coolwarm', annot=True, fmt='.2f', square=True, linewidths=0.5)
plt.title('Feature Correlation Matrix', fontsize=16)
plt.tight_layout()
plt.show()

# Visualize data in 3D
fig = plt.figure(figsize=(12, 10))
ax = fig.add_subplot(111, projection='3d')

colors = ['navy', 'turquoise', 'darkorange']
for i, species in enumerate(target_names_iris):
    subset = iris_df[iris_df['species'] == species]
    ax.scatter(subset['sepal length (cm)'], 
               subset['sepal width (cm)'], 
               subset['petal length (cm)'],
               c=colors[i], label=species, s=60, alpha=0.8)

ax.set_xlabel('Sepal Length (cm)', fontsize=12)
ax.set_ylabel('Sepal Width (cm)', fontsize=12)
ax.set_zlabel('Petal Length (cm)', fontsize=12)
ax.set_title('3D Visualization of Iris Dataset', fontsize=16)
ax.legend()
plt.show()

### 4.3 Visualizing Model Decision Boundaries

Visualizing decision boundaries helps understand how machine learning models classify data points.
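
The decision-boundary function below follows a four-step grid recipe: build a dense mesh, flatten it into an `(n_points, 2)` array, predict a class for every point, and reshape the predictions back to the mesh for `contourf`. This minimal sketch isolates that recipe, with a hand-written nearest-centroid rule and hypothetical class centers standing in for a trained scikit-learn model.

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap

def nearest_centroid_predict(points, centroids):
    # Distances from every point (n, 2) to every centroid (k, 2) -> (n, k); pick the closest
    distances = np.linalg.norm(points[:, None, :] - centroids[None, :, :], axis=2)
    return distances.argmin(axis=1)

# Hypothetical class centers standing in for a trained model
centroids = np.array([[0.0, 0.0], [3.0, 0.0], [0.0, 3.0]])

# 1. Build a dense grid covering the plot area
xx, yy = np.meshgrid(np.linspace(-2, 5, 200), np.linspace(-2, 5, 150))
# 2. Flatten it into an (n_points, 2) array, 3. predict, 4. reshape back to the grid
grid_points = np.c_[xx.ravel(), yy.ravel()]
Z = nearest_centroid_predict(grid_points, centroids).reshape(xx.shape)

fig, ax = plt.subplots(figsize=(6, 5))
# Half-integer levels give each integer class label exactly one color band
ax.contourf(xx, yy, Z, levels=[-0.5, 0.5, 1.5, 2.5],
            cmap=ListedColormap(["#FFAAAA", "#AAFFAA", "#AAAAFF"]), alpha=0.6)
ax.scatter(centroids[:, 0], centroids[:, 1], c=[0, 1, 2],
           cmap=ListedColormap(["#FF0000", "#00FF00", "#0000FF"]), edgecolor="k", s=80)
ax.set_title("Nearest-centroid decision regions")
plt.show()
```

Explicit half-integer `levels` make the mapping from class label to color band unambiguous, whereas automatic contour levels may split or merge bands.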

In [None]:
from sklearn.decomposition import PCA
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.neural_network import MLPClassifier
from matplotlib.colors import ListedColormap

# Function to plot decision boundaries
def plot_decision_boundaries(X, y, model, model_name):
    # Reduce to 2D using PCA if more than 2 dimensions
    if X.shape[1] > 2:
        pca = PCA(n_components=2)
        X_2d = pca.fit_transform(X)
        print(f"Explained variance ratio: {pca.explained_variance_ratio_}")
    else:
        X_2d = X
    
    # Train the model on the 2D data
    model.fit(X_2d, y)
    
    # Create a mesh grid
    h = 0.02  # Step size in the mesh
    x_min, x_max = X_2d[:, 0].min() - 1, X_2d[:, 0].max() + 1
    y_min, y_max = X_2d[:, 1].min() - 1, X_2d[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
    
    # Predict class for each point in the mesh
    Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)
    
    # Create color maps
    cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA', '#AAAAFF'])
    cmap_bold = ListedColormap(['#FF0000', '#00FF00', '#0000FF'])
    
    # Plot the decision boundary
    plt.figure(figsize=(10, 8))
    plt.contourf(xx, yy, Z, alpha=0.4, cmap=cmap_light)
    
    # Plot the training points
    scatter = plt.scatter(X_2d[:, 0], X_2d[:, 1], c=y, cmap=cmap_bold, edgecolor='k', s=50, alpha=0.8)
    
    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())
    plt.title(f"{model_name} Decision Boundary", fontsize=16)
    plt.xlabel('Feature 1', fontsize=12)
    plt.ylabel('Feature 2', fontsize=12)
    
    # Add a legend
    legend1 = plt.legend(*scatter.legend_elements(), title="Classes")
    plt.gca().add_artist(legend1)
    
    # Training accuracy of the 2D model (measured on the same data it was fit on)
    accuracy = model.score(X_2d, y)
    plt.text(xx.max() - 0.3 * (xx.max() - xx.min()), 
             yy.min() + 0.1 * (yy.max() - yy.min()), 
             f'Training accuracy: {accuracy:.2f}', 
             bbox=dict(facecolor='white', alpha=0.8))
    
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.tight_layout()
    plt.show()

# Use the Iris dataset
X_iris = iris.data
y_iris = iris.target

# Create different models
models = [
    (LogisticRegression(max_iter=1000, random_state=42), "Logistic Regression"),
    (KNeighborsClassifier(n_neighbors=15), "K-Nearest Neighbors"),
    (DecisionTreeClassifier(max_depth=5, random_state=42), "Decision Tree"),
    (SVC(kernel='rbf', gamma=0.5, C=1.0, probability=True, random_state=42), "Support Vector Machine"),
    (MLPClassifier(hidden_layer_sizes=(10, 5), max_iter=1000, random_state=42), "Neural Network")
]

# Plot decision boundaries for each model
for model, name in models:
    plot_decision_boundaries(X_iris, y_iris, model, name)
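The core of `plot_decision_boundaries` is a three-step process: evaluate a classifier at every point of a regular grid, then color each grid cell by its predicted class. The sketch below isolates that step using plain NumPy. It uses a hand-written nearest-centroid classifier on synthetic 2D blobs. This classifier is a stand-in chosen for illustration and is not one of the scikit-learn models above. The names `grid_predictions` and `nearest_centroid_predict` are hypothetical helpers.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

def nearest_centroid_predict(centroids, points):
    """Assign each point to the class whose centroid is closest (Euclidean distance)."""
    # Pairwise distances, shape (n_points, n_classes)
    d = np.linalg.norm(points[:, None, :] - centroids[None, :, :], axis=2)
    return d.argmin(axis=1)

def grid_predictions(predict, X, h=0.05, pad=1.0):
    """Evaluate `predict` on a regular grid covering X plus a padding margin."""
    x_min, x_max = X[:, 0].min() - pad, X[:, 0].max() + pad
    y_min, y_max = X[:, 1].min() - pad, X[:, 1].max() + pad
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
    Z = predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)
    return xx, yy, Z

# Synthetic 2D data: three Gaussian blobs of 30 points each
rng = np.random.default_rng(0)
X_blobs = np.vstack([rng.normal(loc, 0.5, size=(30, 2))
                     for loc in ([0, 0], [3, 0], [1.5, 2.5])])
y_blobs = np.repeat([0, 1, 2], 30)
centroids = np.array([X_blobs[y_blobs == k].mean(axis=0) for k in range(3)])

xx, yy, Z = grid_predictions(lambda P: nearest_centroid_predict(centroids, P), X_blobs)

fig, ax = plt.subplots(figsize=(6, 5))
# Levels halfway between the integer labels give one filled band per class
ax.contourf(xx, yy, Z, levels=[-0.5, 0.5, 1.5, 2.5],
            cmap=ListedColormap(['#FFAAAA', '#AAFFAA', '#AAAAFF']))
ax.scatter(X_blobs[:, 0], X_blobs[:, 1], c=y_blobs,
           cmap=ListedColormap(['#FF0000', '#00FF00', '#0000FF']), edgecolor='k')
ax.set_title('Nearest-centroid decision regions (synthetic data)')
plt.show()
```

Any object with a `predict` method, such as a fitted scikit-learn model, can replace the lambda.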

### 4.4 Visualizing Model Training Progress

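Plotting training and validation metrics for each epoch helps diagnose training problems. If validation loss starts rising while training loss keeps falling, the model is likely overfitting. If both curves level off at a high loss, it is likely underfitting.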
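One common way to read these curves is to find the epoch where validation loss reaches its minimum. The sketch below marks that epoch on a loss plot. It expects a dictionary with `loss` and `val_loss` lists, which is the shape of Keras's `History.history`. The numbers in `toy_history` are made up for illustration only.

```python
import numpy as np
import matplotlib.pyplot as plt

def best_epoch(val_loss):
    """0-based index of the epoch with the lowest validation loss."""
    return int(np.argmin(val_loss))

def plot_loss_curves(history, ax=None):
    """Plot training/validation loss and mark the epoch with the lowest validation loss.

    `history` is assumed to be a dict with 'loss' and 'val_loss' lists,
    like the `history.history` attribute returned by Keras `model.fit`.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(8, 5))
    epochs = np.arange(1, len(history['loss']) + 1)  # 1-based epoch numbers for the x-axis
    ax.plot(epochs, history['loss'], label='Training loss')
    ax.plot(epochs, history['val_loss'], label='Validation loss')
    best = best_epoch(history['val_loss'])
    ax.axvline(epochs[best], color='gray', linestyle=':',
               label=f'Lowest validation loss (epoch {epochs[best]})')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.legend()
    return ax, best

# Hand-made numbers for illustration only: validation loss bottoms out, then rises
toy_history = {
    'loss':     [0.90, 0.60, 0.45, 0.35, 0.30, 0.26],
    'val_loss': [0.95, 0.70, 0.55, 0.52, 0.56, 0.63],
}
ax, best = plot_loss_curves(toy_history)
plt.show()
```

Keras's `EarlyStopping` callback, for example with `monitor='val_loss'` and `restore_best_weights=True`, automates the same idea during training.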

In [None]:
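import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout
from sklearn.datasets import load_breast_cancer
from sklearn.metrics import roc_curve, auc
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

# Seed TensorFlow's global random number generator (weight initialization, dropout masks)
tf.random.set_seed(42)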

# Load the breast cancer dataset
data = load_breast_cancer()
X = data.data
y = data.target

# Split the data
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Standardize the features
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# Create a neural network model
model = Sequential([
    Dense(64, activation='relu', input_shape=(X_train.shape[1],)),
    Dropout(0.2),
    Dense(32, activation='relu'),
    Dropout(0.2),
    Dense(16, activation='relu'),
    Dense(1, activation='sigmoid')
])

# Compile the model
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Train the model and store the history
history = model.fit(
    X_train_scaled, y_train,
    epochs=100,
    batch_size=32,
    validation_split=0.2,
    verbose=0
)

# Visualize training history
plt.figure(figsize=(12, 5))

# Plot training & validation accuracy
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.title('Model Accuracy', fontsize=14)
plt.xlabel('Epoch', fontsize=12)
plt.ylabel('Accuracy', fontsize=12)
plt.legend(loc='lower right')
plt.grid(True, linestyle='--', alpha=0.7)

# Plot training & validation loss
plt.subplot(1, 2, 2)
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Model Loss', fontsize=14)
plt.xlabel('Epoch', fontsize=12)
plt.ylabel('Loss', fontsize=12)
plt.legend(loc='upper right')
plt.grid(True, linestyle='--', alpha=0.7)

plt.tight_layout()
plt.show()

# Evaluate the model on the test set
test_loss, test_accuracy = model.evaluate(X_test_scaled, y_test, verbose=0)
print(f"Test accuracy: {test_accuracy:.4f}")
print(f"Test loss: {test_loss:.4f}")

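# Predicted probabilities for the positive class, plus hard 0/1 labels at a 0.5 threshold
y_pred_proba = model.predict(X_test_scaled, verbose=0).flatten()
y_pred = (y_pred_proba > 0.5).astype(int)  # not needed for the ROC curve, which uses the probabilities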

# Visualize ROC curve
fpr, tpr, thresholds = roc_curve(y_test, y_pred_proba)
roc_auc = auc(fpr, tpr)

plt.figure(figsize=(10, 8))
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random classifier')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate', fontsize=12)
plt.ylabel('True Positive Rate', fontsize=12)
plt.title('Receiver Operating Characteristic (ROC) Curve', fontsize=16)
plt.legend(loc="lower right")
plt.grid(True, linestyle='--', alpha=0.7)
plt.show()
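To show what `roc_curve` and `auc` compute, the sketch below builds the ROC points by hand. It works in four steps:

1. Sort the examples by score.
2. Sweep the decision threshold from high to low.
3. Record the cumulative true-positive and false-positive rates at each threshold.
4. Integrate the resulting curve with the trapezoidal rule.

This is a simplified illustration, not scikit-learn's implementation. For example, scikit-learn also drops some suboptimal thresholds by default (`drop_intermediate=True`). The sketch assumes both classes are present in `y_true`.

```python
import numpy as np

def roc_points(y_true, scores):
    """Return (fpr, tpr) arrays for binary 0/1 labels and real-valued scores."""
    y_true = np.asarray(y_true)
    scores = np.asarray(scores, dtype=float)
    order = np.argsort(-scores, kind='mergesort')  # highest score first
    y_sorted, s_sorted = y_true[order], scores[order]
    # Last position of each distinct score: each one is a candidate threshold
    distinct = np.where(np.diff(s_sorted) != 0)[0]
    cut = np.r_[distinct, y_sorted.size - 1]
    tps = np.cumsum(y_sorted)[cut]   # positives scored at or above each threshold
    fps = (cut + 1) - tps            # negatives scored at or above each threshold
    # Prepend the (0, 0) point, i.e. a threshold above every score
    tpr = np.r_[0, tps] / tps[-1]
    fpr = np.r_[0, fps] / fps[-1]
    return fpr, tpr

def auc_trapezoid(fpr, tpr):
    """Area under the piecewise-linear ROC curve (trapezoidal rule)."""
    return float(np.sum(np.diff(fpr) * (tpr[1:] + tpr[:-1]) / 2))

fpr_demo, tpr_demo = roc_points([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
print(auc_trapezoid(fpr_demo, tpr_demo))  # 0.75
```

On the test-set probabilities above, `auc_trapezoid(*roc_points(y_test, y_pred_proba))` should agree with `roc_auc` up to floating-point rounding.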

## Practice Problems

Now that you've learned the fundamentals of Matplotlib, try solving these practice problems to test your understanding.

### Problem 1: Custom Visualization

Create a custom visualization that includes:
1. A figure with 2 subplots arranged horizontally
2. In the first subplot, plot a sine wave and a cosine wave with different colors and line styles
3. In the second subplot, create a scatter plot of random data with points colored according to their y-values
4. Add appropriate titles, labels, and a legend to each subplot
5. Customize the appearance with a grid, custom font sizes, and a figure title

In [None]:
# Your solution here
import numpy as np
import matplotlib.pyplot as plt

# Set random seed for reproducibility
np.random.seed(42)

# Create a figure with 2 subplots arranged horizontally
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
fig.suptitle('Custom Visualization for Practice Problem 1', fontsize=16, fontweight='bold')

# First subplot: sine and cosine waves
x = np.linspace(0, 2*np.pi, 100)
y_sin = np.sin(x)
y_cos = np.cos(x)

ax1.plot(x, y_sin, 'b-', linewidth=2, label='sin(x)')
ax1.plot(x, y_cos, 'r--', linewidth=2, label='cos(x)')
ax1.set_title('Sine and Cosine Waves', fontsize=14)
ax1.set_xlabel('x', fontsize=12)
ax1.set_ylabel('y', fontsize=12)
ax1.grid(True, linestyle='--', alpha=0.7)
ax1.legend(fontsize=10)
ax1.set_xlim(0, 2*np.pi)
ax1.set_xticks([0, np.pi/2, np.pi, 3*np.pi/2, 2*np.pi])
ax1.set_xticklabels(['0', r'$\pi/2$', r'$\pi$', r'$3\pi/2$', r'$2\pi$'])  # raw strings keep the backslashes for mathtext

# Second subplot: scatter plot of random data
n = 100
x_random = np.random.rand(n) * 10
y_random = np.random.rand(n) * 10

scatter = ax2.scatter(x_random, y_random, c=y_random, cmap='viridis', 
                     s=100, alpha=0.7, edgecolors='k')
ax2.set_title('Random Data Scatter Plot', fontsize=14)
ax2.set_xlabel('X Value', fontsize=12)
ax2.set_ylabel('Y Value', fontsize=12)
ax2.grid(True, linestyle='--', alpha=0.7)
cbar = plt.colorbar(scatter, ax=ax2)
cbar.set_label('Y Value', fontsize=10)

# Adjust layout
plt.tight_layout(rect=[0, 0, 1, 0.95])  # Make room for the figure title
plt.show()
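This solution hard-codes five tick positions and their labels. A more general approach lets Matplotlib place a tick every π/2 with `MultipleLocator`, then labels each tick with a `FuncFormatter`, as sketched below. `format_pi` is a hypothetical helper written for this example. It assumes ticks fall on multiples of π/2 and rounds each value to the nearest one.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter, MultipleLocator

def format_pi(x, pos):
    """Label a tick value as a multiple of pi/2 using mathtext (pos is unused)."""
    n = int(np.round(2 * x / np.pi))  # number of half-pi steps
    if n == 0:
        return '0'
    if n % 2 == 0:  # whole multiple of pi
        k = n // 2
        if k == 1:
            return r'$\pi$'
        if k == -1:
            return r'$-\pi$'
        return rf'${k}\pi$'
    # Odd multiple of pi/2
    if n == 1:
        return r'$\pi/2$'
    if n == -1:
        return r'$-\pi/2$'
    return rf'${n}\pi/2$'

x_pi = np.linspace(0, 2 * np.pi, 100)
fig_pi, ax_pi = plt.subplots(figsize=(7, 4))
ax_pi.plot(x_pi, np.sin(x_pi), label='sin(x)')
ax_pi.set_xlim(0, 2 * np.pi)
ax_pi.xaxis.set_major_locator(MultipleLocator(np.pi / 2))  # a tick every pi/2
ax_pi.xaxis.set_major_formatter(FuncFormatter(format_pi))   # label ticks with format_pi
ax_pi.legend()
plt.show()
```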

### Problem 2: Data Visualization Dashboard

Create a visualization dashboard for the Iris dataset that includes:
1. A figure with 4 subplots arranged in a 2x2 grid
2. A scatter plot of sepal length vs. sepal width, colored by species
3. A scatter plot of petal length vs. petal width, colored by species
4. A histogram of sepal length with bins colored by species
5. A box plot showing the distribution of petal width for each species
6. Add appropriate titles, labels, and legends to each subplot

In [None]:
# Your solution here
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris

# Load the Iris dataset
iris = load_iris()
X = iris.data
y = iris.target
feature_names = iris.feature_names
target_names = iris.target_names

# Create a figure with 4 subplots arranged in a 2x2 grid
fig, axes = plt.subplots(2, 2, figsize=(14, 12))
fig.suptitle('Iris Dataset Visualization Dashboard', fontsize=18, fontweight='bold', y=0.98)

# Define colors for each species
colors = ['navy', 'turquoise', 'darkorange']

# 1. Scatter plot of sepal length vs. sepal width
for i, species in enumerate(target_names):
    species_indices = y == i
    axes[0, 0].scatter(X[species_indices, 0], X[species_indices, 1], 
                      c=colors[i], label=species, s=60, alpha=0.7, edgecolors='k')
    
axes[0, 0].set_title('Sepal Length vs. Sepal Width', fontsize=14)
axes[0, 0].set_xlabel('Sepal Length (cm)', fontsize=12)
axes[0, 0].set_ylabel('Sepal Width (cm)', fontsize=12)
axes[0, 0].legend(fontsize=10)
axes[0, 0].grid(True, linestyle='--', alpha=0.7)

# 2. Scatter plot of petal length vs. petal width
for i, species in enumerate(target_names):
    species_indices = y == i
    axes[0, 1].scatter(X[species_indices, 2], X[species_indices, 3], 
                      c=colors[i], label=species, s=60, alpha=0.7, edgecolors='k')
    
axes[0, 1].set_title('Petal Length vs. Petal Width', fontsize=14)
axes[0, 1].set_xlabel('Petal Length (cm)', fontsize=12)
axes[0, 1].set_ylabel('Petal Width (cm)', fontsize=12)
axes[0, 1].legend(fontsize=10)
axes[0, 1].grid(True, linestyle='--', alpha=0.7)

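# 3. Histogram of sepal length (shared bin edges so the species' bars line up)
bins = np.linspace(X[:, 0].min(), X[:, 0].max(), 11)  # 10 equal-width bins over the full range
for i, species in enumerate(target_names):
    species_indices = y == i
    axes[1, 0].hist(X[species_indices, 0], bins=bins, alpha=0.5, 
                   color=colors[i], label=species, edgecolor='black')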
    
axes[1, 0].set_title('Histogram of Sepal Length', fontsize=14)
axes[1, 0].set_xlabel('Sepal Length (cm)', fontsize=12)
axes[1, 0].set_ylabel('Frequency', fontsize=12)
axes[1, 0].legend(fontsize=10)
axes[1, 0].grid(True, linestyle='--', alpha=0.7)

# 4. Box plot of petal width by species
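data_to_plot = [X[y == i, 3] for i in range(len(target_names))]
box = axes[1, 1].boxplot(data_to_plot, patch_artist=True)
# Label the boxes by species (boxplot's `labels` argument was renamed `tick_labels` in Matplotlib 3.9)
axes[1, 1].set_xticklabels(target_names)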

# Color the boxes
for patch, color in zip(box['boxes'], colors):
    patch.set_facecolor(color)
    patch.set_alpha(0.7)
    
axes[1, 1].set_title('Box Plot of Petal Width by Species', fontsize=14)
axes[1, 1].set_xlabel('Species', fontsize=12)
axes[1, 1].set_ylabel('Petal Width (cm)', fontsize=12)
axes[1, 1].grid(True, linestyle='--', alpha=0.7, axis='y')

# Adjust layout
plt.tight_layout(rect=[0, 0, 1, 0.95])  # Make room for the figure title
plt.show()
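The two scatter panels in this solution repeat the same loop with different column indices. A small helper, such as the hypothetical `scatter_by_class` below, removes that duplication. The sketch runs on random stand-in data rather than the Iris measurements, so it is self-contained.

```python
import numpy as np
import matplotlib.pyplot as plt

def scatter_by_class(ax, X, y, i, j, class_names, colors, xlabel, ylabel):
    """Scatter column i against column j of X, one colored series per class."""
    for k, (name, color) in enumerate(zip(class_names, colors)):
        mask = y == k
        ax.scatter(X[mask, i], X[mask, j], color=color, label=name,
                   s=60, alpha=0.7, edgecolors='k')
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.legend(fontsize=10)
    ax.grid(True, linestyle='--', alpha=0.7)

rng = np.random.default_rng(0)
X_demo = rng.normal(size=(90, 4))   # random stand-in for a 4-feature dataset
y_demo = np.repeat([0, 1, 2], 30)   # three classes of 30 samples each
names = ['class 0', 'class 1', 'class 2']
demo_colors = ['navy', 'turquoise', 'darkorange']

fig_demo, axes_demo = plt.subplots(1, 2, figsize=(12, 5))
scatter_by_class(axes_demo[0], X_demo, y_demo, 0, 1, names, demo_colors, 'Feature 0', 'Feature 1')
scatter_by_class(axes_demo[1], X_demo, y_demo, 2, 3, names, demo_colors, 'Feature 2', 'Feature 3')
plt.show()
```

With the Iris arrays from the solution, the first panel becomes `scatter_by_class(axes[0, 0], X, y, 0, 1, target_names, colors, 'Sepal Length (cm)', 'Sepal Width (cm)')`. The second panel uses the analogous call with columns 2 and 3.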

### Problem 3: Interactive Time Series Visualization

Create an interactive visualization for stock price data (the widgets require Jupyter with `ipywidgets` installed) that includes:
1. Generate synthetic stock price data for 3 companies over 1 year
2. Create a line plot showing the stock prices over time
3. Add interactive widgets to select which companies to display
4. Add a widget to select the time range (e.g., 1 month, 3 months, 6 months, 1 year)
5. Add annotations for the highest and lowest points in the selected time range

In [None]:
# Your solution here
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import ipywidgets as widgets
from ipywidgets import interact

# Generate synthetic stock price data
np.random.seed(42)
dates = pd.date_range(start='2023-01-01', end='2023-12-31', freq='B')  # Business days
n_days = len(dates)

# Initial prices
initial_prices = [100, 200, 150]
company_names = ['Company A', 'Company B', 'Company C']

# Generate random walk for stock prices
stock_data = {}
for i, company in enumerate(company_names):
    # Geometric random walk with upward drift: daily returns ~ N(0.05%, 1%), compounded
    returns = np.random.normal(0.0005, 0.01, n_days)
    price_series = initial_prices[i] * (1 + returns).cumprod()
    stock_data[company] = price_series

# Create DataFrame
df = pd.DataFrame(stock_data, index=dates)

# Interactive visualization function
def plot_stock_prices(company_a=True, company_b=True, company_c=True, time_range='1 Year'):
    # Filter companies to display
    companies_to_plot = []
    if company_a:
        companies_to_plot.append('Company A')
    if company_b:
        companies_to_plot.append('Company B')
    if company_c:
        companies_to_plot.append('Company C')
    
    if not companies_to_plot:
        print("Please select at least one company to display.")
        return
    
    # Filter time range
    if time_range == '1 Month':
        df_filtered = df.iloc[-21:]  # About 21 business days in a month
    elif time_range == '3 Months':
        df_filtered = df.iloc[-63:]  # About 63 business days in 3 months
    elif time_range == '6 Months':
        df_filtered = df.iloc[-126:]  # About 126 business days in 6 months
    else:  # 1 Year
        df_filtered = df
    
    # Create the plot
    plt.figure(figsize=(14, 8))
    
    for company in companies_to_plot:
        plt.plot(df_filtered.index, df_filtered[company], linewidth=2, label=company)
        
        # Find highest and lowest points
        max_idx = df_filtered[company].idxmax()
        min_idx = df_filtered[company].idxmin()
        max_val = df_filtered[company].max()
        min_val = df_filtered[company].min()
        
        # Annotate highest point; the text offset is in points, so it suits any price scale
        plt.annotate(f'High: ${max_val:.2f}', 
                     xy=(max_idx, max_val),
                     xytext=(0, 20), textcoords='offset points', ha='center',
                     arrowprops=dict(facecolor='black', shrink=0.05, width=1),
                     fontsize=10)
        
        # Annotate lowest point
        plt.annotate(f'Low: ${min_val:.2f}', 
                     xy=(min_idx, min_val),
                     xytext=(0, -25), textcoords='offset points', ha='center', va='top',
                     arrowprops=dict(facecolor='black', shrink=0.05, width=1),
                     fontsize=10)
    
    plt.title(f'Stock Prices ({time_range})', fontsize=16)
    plt.xlabel('Date', fontsize=12)
    plt.ylabel('Price ($)', fontsize=12)
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.legend(fontsize=12)
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()

# Create interactive widgets
interact(plot_stock_prices, 
         company_a=widgets.Checkbox(value=True, description='Company A'),
         company_b=widgets.Checkbox(value=True, description='Company B'),
         company_c=widgets.Checkbox(value=True, description='Company C'),
         time_range=widgets.Dropdown(
             options=['1 Month', '3 Months', '6 Months', '1 Year'],
             value='1 Year',
             description='Time Range:'
         ));
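This solution approximates each time range with a fixed count of trailing business days. A calendar-based alternative subtracts a `pd.DateOffset` from the last date, then slices the index with `.loc`. The `filter_time_range` helper and its month mapping below are illustrative choices. The sketch assumes the DataFrame has a sorted `DatetimeIndex`.

```python
import numpy as np
import pandas as pd

# Number of calendar months covered by each dropdown option (illustrative mapping)
RANGE_MONTHS = {'1 Month': 1, '3 Months': 3, '6 Months': 6, '1 Year': 12}

def filter_time_range(df, time_range):
    """Return the rows of df within `time_range` of its final date."""
    start = df.index[-1] - pd.DateOffset(months=RANGE_MONTHS[time_range])
    return df.loc[start:]  # label-based slice on a sorted DatetimeIndex (start inclusive)

dates = pd.date_range(start='2023-01-01', end='2023-12-31', freq='B')
demo = pd.DataFrame({'Company A': np.linspace(100, 120, len(dates))}, index=dates)
last_month = filter_time_range(demo, '1 Month')
print(last_month.index[0].date(), last_month.index[-1].date())  # 2023-11-29 2023-12-29
```

Inside `plot_stock_prices`, the whole `if/elif` chain could then become `df_filtered = filter_time_range(df, time_range)`.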

## Additional Resources

To further enhance your Matplotlib skills, check out these resources:

- [Matplotlib Documentation](https://matplotlib.org/stable/index.html)
- [Matplotlib Gallery](https://matplotlib.org/stable/gallery/index.html)
- [Python Data Science Handbook - Matplotlib Chapter](https://jakevdp.github.io/PythonDataScienceHandbook/04.00-introduction-to-matplotlib.html)
- [Matplotlib Cheat Sheet](https://github.com/matplotlib/cheatsheets)
- [Data Visualization with Matplotlib and Python](https://github.com/pb111/Data-Analysis-and-Visualization-with-Python/blob/master/Matplotlib_Basics.ipynb)
- [Matplotlib Tutorials](https://matplotlib.org/stable/tutorials/index.html)