<link rel="stylesheet" href="berkeley.css">
<h1 class="cal cal-h1">Introduction to Plotting with Plotly – CS 189, Fall 2025</h1>



Building visualizations is an important part of machine learning. In this notebook, we will explore how to create interactive visualizations using Plotly, a powerful library for creating dynamic and engaging plots.

Critically, we will learn how to use Plotly to visualize data in a way that helps us debug machine learning algorithms and communicate results effectively.

There are many different visualization libraries available in Python, but Plotly stands out for its interactivity and ease of use. It allows us to create plots that can be easily shared and embedded in web applications.  However, you will likely also encounter Matplotlib and its more friendly Seaborn wrapper. Matplotlib is a more traditional plotting library that is widely used in the Python community. It is a good choice for creating static plots, but it does not have the same level of interactivity as Plotly. Seaborn is a higher-level interface to Matplotlib that makes it easier to create complex visualizations with less code.

We have chosen to prioritize Plotly in this course because we believe it is important to be able to interact with your data as you explore it.  

## Toy Data

Here we will use the [auto-mpg dataset](https://archive.ics.uci.edu/ml/datasets/auto+mpg) from the UCI Machine Learning Repository, which contains information about various cars, including their miles per gallon (MPG), number of cylinders, horsepower, and more. This dataset is commonly used for regression tasks in machine learning.

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt 
import seaborn as sns

# Reading an hf:// path requires the huggingface_hub package to be installed
mpg = pd.read_csv("hf://datasets/scikit-learn/auto-mpg/auto-mpg.csv")
mpg['origin'] = mpg['origin'].map({1: 'USA', 2: 'Europe', 3: 'Japan'})
mpg

## Matplotlib and Seaborn

### Matplotlib
Matplotlib is a versatile Python library for creating static, animated, and interactive visualizations. It offers a low-level interface for highly customizable plots, suitable for publication-quality visualizations.

#### Types of Plots:
- Line Plots
- Scatter Plots
- Bar Charts
- Histograms
- Box Plots
- Heatmaps

### Seaborn
Seaborn is a high-level interface built on top of Matplotlib, designed for statistical data visualization. It provides an intuitive interface and aesthetically pleasing default styles, working seamlessly with Pandas DataFrames.

#### Types of Plots:
- Relational Plots (scatter, line)
- Categorical Plots (bar, box, violin, swarm)
- Distribution Plots (histograms, KDE, rug)
- Regression Plots
- Heatmaps

Matplotlib offers more control, while Seaborn simplifies the creation of visually appealing plots.

In [None]:
mpg['make'] = mpg['car name'].str.split(' ').str[0]

yearly_mpg = (
    mpg.groupby(['origin', 'model year', 'make'])
    [['mpg', 'displacement', 'weight']]
    .mean().reset_index()
)

In [None]:
# Line Plot
sns.lineplot(data=yearly_mpg, x='model year', y='mpg', hue='origin', marker='o')
plt.title('Average MPG by Model Year and Origin')
plt.xlabel('Model Year')
plt.ylabel('Miles Per Gallon (MPG)')
plt.legend(title='Origin')
plt.show()

In [None]:
# Scatter Plot
sns.scatterplot(data=mpg, x='weight', y='mpg', hue='origin')
plt.title('MPG vs. Weight by Origin')
plt.xlabel('Weight (lbs)')
plt.ylabel('Miles Per Gallon (MPG)')
plt.legend(title='Origin')
plt.show()

In [None]:
# Bar Chart
mpg.groupby('origin')['mpg'].mean().plot(kind='bar', color=['blue', 'orange', 'green'])
plt.title('Average MPG by Origin')
plt.ylabel('Average MPG')
plt.xlabel('Origin')
plt.show()

In [None]:
# Histogram
sns.histplot(data=mpg, x='mpg', hue='origin', element='step', stat='count', common_norm=False)
plt.title('MPG Distribution by Origin')
plt.xlabel('Miles Per Gallon (MPG)')
plt.ylabel('Count')
plt.show()

In [None]:
# Box Plot
sns.boxplot(data=mpg, x='origin', y='mpg', hue='origin', palette='Set2')
plt.title('MPG Distribution by Origin')
plt.xlabel('Origin')
plt.ylabel('Miles Per Gallon (MPG)')
plt.show()

In [None]:
# Heatmap
corr = mpg[['mpg', 'cylinders', 'displacement', 'weight', 'acceleration']].corr()
sns.heatmap(corr, annot=True, cmap='coolwarm', fmt='.2f')
plt.title('Correlation Heatmap')
plt.show()

In [None]:
sns.scatterplot(data=mpg, x='weight', y='mpg', hue='origin', size='cylinders')
plt.title('MPG by Weight and Origin')
plt.xlabel('Weight (lbs)')
plt.ylabel('Miles Per Gallon (MPG)')
plt.show()

## Value of Interactive Visualizations

Static visualizations are great for presenting results, but they can be limiting when it comes to exploring data. Interactive visualizations allow us to:
- **Zoom and Pan**: Focus on specific areas of the plot.
- **Hover for Details**: Get more information about specific data points.
- **Filter Data**: Select subsets of data to visualize.
- **Change Parameters**: Adjust parameters dynamically to see how they affect the visualization.

These features make it easier to understand complex datasets and identify patterns or anomalies.

## Creating an Interactive Scatter Plot with Plotly Express

This code creates an **interactive scatter plot** using the `plotly.express` library. This plot explores the relationship between car weight and fuel efficiency, considering origin and cylinder count.
Key components:

1. **`px.scatter`**: Generates a scatter plot to visualize relationships between two numerical variables.
2. **Parameters**:
    - **`mpg`**: Dataset containing car information.
    - **`x='weight'`**: X-axis represents car weight.
    - **`y='mpg'`**: Y-axis represents miles per gallon.
    - **`color='origin'`**: Groups points by car origin.
    - **`size='cylinders'`**: Marker size reflects the number of cylinders.
    - **`size_max=12`**: Limits marker size.
    - **`hover_data=mpg.columns`**: Displays all dataset columns on hover.
    - **`title='MPG vs. Weight by Origin'`**: Adds a plot title.
    - **`labels={'weight': 'Weight (lbs)', 'mpg': 'Miles Per Gallon (MPG)'}`**: Customizes axis labels.
    - **`width=800, height=600`**: Sets plot dimensions.

In [None]:
import plotly.express as px

px.scatter(mpg, x='weight', y='mpg', color='origin', 
           size='cylinders', size_max=12,
           hover_data=mpg.columns,
           title='MPG vs. Weight by Origin',
           labels={'weight': 'Weight (lbs)', 'mpg': 'Miles Per Gallon (MPG)'},
           width=800, height=600)

We can even make interactive visualizations that allow us to explore how a model's predictions change over time or with different parameters. The code below creates an **animated scatter plot** using Plotly Express, with one animation frame per model year.

### Key Components:

1. **Parameters**:
    - **`hover_data=yearly_mpg.columns`**: Displays all columns of the `yearly_mpg` DataFrame when hovering over a point.
    - **`animation_frame='model year'`**: Animates the plot over the `model year` column, showing changes over time.
    - **`animation_group='make'`**: Groups points by the car's make to track them across animation frames.

2. **`fig.update_layout`**:
    - **`xaxis_title` and `yaxis_title`**: Sets the axis titles.
    - **`xaxis_range` and `yaxis_range`**: Defines the range of the x and y axes.
    - **`legend_title_text`**: Sets the title for the legend.

In [None]:
fig = px.scatter(yearly_mpg, x='weight', y='mpg', color='origin',
                 hover_data=yearly_mpg.columns,
                 animation_frame='model year', animation_group='make', 
                 title='MPG vs. Weight by Origin',
                 labels={'weight': 'Weight (lbs)', 'mpg': 'Miles Per Gallon (MPG)'},
                 width=800, height=600)
fig.update_layout(
    xaxis_title='Weight (lbs)',
    yaxis_title='Miles Per Gallon (MPG)',
    xaxis_range=[1500, 5000],
    yaxis_range=[10, 50],
    legend_title_text='Origin',
)
fig.show()

## Three Modes for Plotly 

There are three modes for using Plotly that we will explore in this course:

1. **Pandas Plotting**: A convenient interface for creating Plotly visualizations directly from Pandas DataFrames. It simplifies the process of generating plots from tabular data.
2. **Plotly Express**: A high-level interface for creating plots with minimal code. It is ideal for quick visualizations and exploratory data analysis. It plays the same role for Plotly that Seaborn plays for Matplotlib. Like Pandas Plotting, it is designed to work seamlessly with Pandas DataFrames and provides a simple API for creating a wide range of visualizations.
3. **Graph Objects**: A more flexible and powerful interface that allows for fine-grained control over the appearance and behavior of plots. It is suitable for creating complex visualizations and custom layouts.


## Using Pandas Plotting

To use Plotly in Pandas, you need to set the plotting backend to Plotly. This allows you to use the `plot` method on Pandas DataFrames to create interactive plots.


In [None]:
pd.set_option('plotting.backend', 'plotly')

Now we can use the `plot` method on our `DataFrame` to create interactive plots. 

In [None]:
mpg.plot(
    kind='scatter',
    x='weight', y='mpg', color='origin', size='cylinders',
    title='MPG vs. Weight by Origin',
    width=800, height=600)

Notice how we specify the `kind` of plot as well as how the data should be mapped to the axes, color, and size. This is an interactive plot, so you can mouse over the points to see more information. Click a legend entry to hide or show that series, or double-click it to isolate that series. You can also zoom in by clicking and dragging on the plot area; double-click the plot area to reset the view. Here we also set the width and height of the plot to make it larger and more readable.


All the basic plotting functions in the Pandas Plotly backend and in Plotly Express return `Figure` objects (`plotly.graph_objects.Figure`), which can be further customized with methods such as `update_layout` and `update_traces`.

In [None]:
fig = mpg.plot(
    kind='scatter',
    x='weight', y='mpg', color='origin', size='cylinders',
    title='MPG vs. Weight by Origin',
    width=800, height=600)

# change the style (template) of the figure
fig.update_layout(template='plotly_dark')
# fig.update_layout(template='plotly_white')
# fig.update_layout(template='ggplot2')
# fig.update_layout(template='seaborn')
fig.update_layout(xaxis_title='Weight (lbs)',
                  yaxis_title='Miles per Gallon (MPG)',
                  legend_title='Origin')
fig.show()

We can also save plots to HTML files, which can be shared and embedded in web applications. This is useful for creating interactive reports and dashboards.



In [None]:
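# Write a standalone HTML file that loads plotly.js from a CDN
fig.write_html('mpg_scatter.html', include_plotlyjs='cdn')
# Static image export requires the kaleido package (pip install kaleido)
fig.write_image('mpg_scatter.png', scale=2, width=800, height=600)
fig.write_image('mpg_scatter.pdf', scale=2, width=800, height=600)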

The figure object is made of two key components: the data and the layout. The data is a list of traces, which are the individual plots that make up the figure. The layout is a dictionary that contains information about the appearance of the plot, such as the title, axis labels, and legend.

In [None]:
display(fig.data)
display(fig.layout)

## Using Plotly Express

Plotly Express is closely related to Pandas Plotting: with the Plotly backend enabled, the Pandas `plot` method hands its arguments to Plotly Express. Plotly Express itself is the `plotly.express` module, which provides a high-level interface for creating plots. It is designed to work seamlessly with `pandas` `DataFrames` and provides a simple API for creating a wide range of visualizations. Calling Plotly Express directly offers more flexibility and customization options than `pandas` plotting, making it a powerful tool for creating complex visualizations.

In [None]:
import plotly.express as px

In [None]:
px.scatter(mpg, x='weight', y='mpg', color='origin', size='cylinders',
           title='MPG vs. Weight by Origin',
           width=800, height=600, 
           template='plotly_dark')

Just as before we get back a `Figure` object that we can further customize.

In [None]:
fig = px.scatter(mpg, x='weight', y='mpg', color='origin', size='cylinders',
                 title='MPG vs. Weight by Origin',
                 width=800, height=600, 
                 template='plotly_dark')
# change the marker symbol for the USA trace
fig.update_traces(marker=dict(symbol="square"), selector=dict(name="USA")) 
# you can also just modify the data dictionary directly
#fig.data[0]['marker']['symbol'] = "square"

# change formatting (layout) of the figure
fig.update_layout(font=dict(family="Courier New, monospace", size=16))
# You can also refer to the font family and size directly
fig.update_layout(font_family="Courier New, monospace", font_size=16)
fig

## Using Plotly Graph Objects

Graph Objects (`plotly.graph_objects`) is a more flexible and powerful interface that allows for fine-grained control over the appearance and behavior of plots. It is suitable for creating complex visualizations and custom layouts.


In [None]:
from plotly import graph_objects as go

In [None]:
fig = go.Figure()
max_size = 20

# Iterate over unique origins and create a scatter trace for each
for i, origin in enumerate(mpg['origin'].unique()):
    # Filter the DataFrame for the current origin
    subset = mpg[mpg['origin'] == origin]
    marker_sizes = max_size*subset['cylinders']/subset['cylinders'].max()
    # Create a hover text for each point
    hover_text = (
            subset['origin'] + "<br>"
                  "Weight: " + subset['weight'].astype(str) + "<br>"
                  "MPG: " + subset['mpg'].astype(str) + "<br>"
                  "Cylinders: " + subset['cylinders'].astype(str))
    # add a trace to the figure
    fig.add_trace(
        go.Scatter(
            x=subset['weight'], y=subset['mpg'],
            mode='markers',
            name=origin,
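            # pick an explicit color per origin from Plotly's default qualitative palette
            marker=dict(size=marker_sizes, color=px.colors.qualitative.Plotly[i]),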
            text=hover_text,
        )
    )
fig.add_annotation(
    text="Data source: Auto MPG dataset",
    xref="paper", yref="paper",
    x=0, y=-0.1,
    showarrow=False,
    font=dict(size=12, color="gray")
)
fig.update_layout(
    title='MPG vs. Weight by Origin',
    xaxis_title='Weight (lbs)',
    yaxis_title='Miles per Gallon (MPG)',
    width=800, height=600,
    template='plotly_white',
    font_family="Times", font_size=16,
)
fig.show()

## Visualizing Different Kinds of Data

Now that we have seen the basics of using Plotly, let's explore how to visualize different kinds of data.


### Histograms

In [None]:
px.histogram(mpg, x='mpg', facet_row='origin')

In [None]:
# `nbins` is the Plotly Express argument for the number of bins
mpg.hist(x='mpg', color='origin', nbins=10, barmode='overlay')

In [None]:
fig = mpg.hist(x='mpg', color='origin', nbins=10, facet_row='origin',
               title='MPG Distribution by Origin',
               width=800, height=600)
fig

In [None]:
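mpg['make'] = mpg['car name'].str.split(' ').str[0]
# No y is given, so each car contributes one stacked segment and
# the bar height is the number of cars for that make
mpg.plot(kind='bar',
         x='make', color='origin', 
         hover_data=['mpg', 'cylinders', 'car name'],
         title='Number of Cars by Make and Origin',
         width=800, height=600)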

### Scatter and Line Plots

In [None]:
yearly_mpg = (
    mpg
    .groupby(['origin', 'model year'])
    [['mpg', 'displacement', 'weight']]
    .mean().reset_index()
)
yearly_mpg.head()

In [None]:
yearly_mpg.plot(kind='line', 
                x='model year', y='mpg', color='origin',
                markers=True,
                title='Average MPG by Model Year and Origin',
                width=800, height=600)


In [None]:
px.line(yearly_mpg, x='model year', y='mpg', color='origin',
        markers=True,
        title='Average MPG by Model Year and Origin',
        width=800, height=600)