**Author:** Shahab Fatemi

**Email:** shahab.fatemi@umu.se   ;   shahab.fatemi@amitiscode.com

**Created:** 2025-06-03

**Last update:** 2025-08-24

**MIT License** — Shahab Fatemi (2025); For use in the *Machine Learning in Physics* course, Umeå University, Sweden; See the full license text in the parent folder.

<hr>

# Data Visualization

Data visualization is a core part of data analysis: a well-chosen plot reveals trends, outliers, and relationships that are hard to spot in raw numbers.

In this notebook, we will explore the basics of data plotting using `Matplotlib` and `Seaborn`, two of the most widely used libraries for data visualization in Python.

## Matplotlib

### Introduction to Matplotlib

Matplotlib is a powerful library for creating static, animated, and interactive visualizations in Python. As with other libraries, we need to import it first.

```python
import matplotlib.pyplot as plt
```

### Basic line plot

In MATLAB:
```matlab
    >> x = linspace(0, 10, 100);  % 100 points from 0 to 10
    >> y = sin(x);                % Sine function
    >> plot(x, y);
    >> xlabel('X');
    >> ylabel('Y');
    >> title('Sine Wave');
    >> grid on;
```

In Python:

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Generating sample data
x = np.linspace(0, 10, 100)  # 100 points from 0 to 10
y = np.sin(x)  # Sine function

plt.plot(x, y)   # Plot the sine wave
plt.xlabel('X')  # X-axis label
plt.ylabel('Y')  # Y-axis label
plt.title('Sine Wave')  # Title of the plot
plt.grid(True)  # Add grid
plt.show()      # Display the plot

### Scatter plot

In [None]:
plt.scatter(x, y, color='blue', alpha=0.7)  # alpha is between 0 and 1, used for transparency
plt.xlabel('X')
plt.ylabel('Y')
plt.title('Scatter plot')
plt.grid(True)
plt.show()

### `imshow` for images and 2D data

In [None]:
# Generate a 2D array of random values
image_data = np.random.rand(10, 10)

# Display data
plt.imshow(image_data, cmap='viridis', interpolation='nearest')
plt.colorbar()  # Add color bar
plt.title('Random image for MLP')
plt.show()

### Customizing figure size

In [None]:
plt.figure( figsize=(8, 4) )  # Width:8, Height:4
plt.scatter(x, y, color='tomato')
plt.xlabel('X')
plt.ylabel('Y')
plt.title('Scatter plot')
plt.grid(True)
plt.show()

### Figure display resolution

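For publications, we should prepare high-quality figures. The `dpi` (dots per inch) argument sets the resolution: `figsize` is given in inches, so the size of the figure in pixels is `figsize` multiplied by `dpi`.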

In [None]:
plt.figure( figsize=(8, 4) , dpi=200)  # Width:8, Height:4 with 200 DPI 
plt.scatter(x, y, color='tomato')
plt.xlabel('X')
plt.ylabel('Y')
plt.title('Scatter plot')
plt.grid(True)
plt.show()
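
To make the link between `figsize` and `dpi` concrete, the small sketch below computes the pixel dimensions of a figure. The variable names (`width_px`, `height_px`) are only illustrative, and the calculation uses the `dpi` value requested in the code, which some high-resolution displays may scale further.

```python
import matplotlib.pyplot as plt

figsize = (8, 4)  # Width and height in inches
dpi = 200         # Dots (pixels) per inch

fig = plt.figure(figsize=figsize, dpi=dpi)

# Pixel dimensions = size in inches * dpi
width_in, height_in = fig.get_size_inches()
width_px = width_in * dpi
height_px = height_in * dpi
print(f"{width_px:.0f} x {height_px:.0f} pixels")

plt.close(fig)  # Nothing is drawn here, so close the empty figure
```

Doubling `dpi` while keeping `figsize` fixed doubles both pixel dimensions, which is why a high `dpi` gives sharper figures.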

### Save figures

In [None]:
plt.figure( figsize=(8, 4) )  # Width:8, Height:4
plt.scatter(x, y, color='tomato')
plt.xlabel('X')
plt.ylabel('Y')
plt.title('Scatter plot')
plt.grid(True)
plt.savefig('wave_test.png', dpi=300)  # Save as PNG with DPI 300
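
As a possible variation, the sketch below saves the same figure both as a raster image (PNG) and as a vector graphic (PDF). The temporary output folder and the file names are assumptions made for this example; `bbox_inches="tight"` trims the empty margin around the figure.

```python
import os
import tempfile

import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0, 10, 100)
y = np.sin(x)

fig = plt.figure(figsize=(8, 4))
plt.scatter(x, y, color="tomato")
plt.xlabel("X")
plt.ylabel("Y")
plt.title("Scatter plot")
plt.grid(True)

# Hypothetical output location: a temporary folder, so nothing is overwritten
out_dir = tempfile.mkdtemp()
png_path = os.path.join(out_dir, "wave_test.png")
pdf_path = os.path.join(out_dir, "wave_test.pdf")

# Raster format: the resolution is controlled by dpi
fig.savefig(png_path, dpi=300, bbox_inches="tight")
# Vector format: stays sharp at any zoom level
fig.savefig(pdf_path, bbox_inches="tight")

print(sorted(os.listdir(out_dir)))
plt.close(fig)
```

Matplotlib picks the output format from the file extension, so no extra argument is needed to switch between PNG and PDF.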

### Subplots

Subplots are used to display multiple plots in a single figure.

In [None]:
# Create subplots
fig, axs = plt.subplots(2, 2, figsize=(10, 8))  # 2 rows, 2 columns

# 1st subplot
axs[0, 0].plot(x, y)
axs[0, 0].set_title('Sine Wave')

# 2nd subplot
axs[0, 1].scatter(x, y)
axs[0, 1].set_title('Scatter plot')

# 3rd subplot
axs[1, 0].imshow(image_data, cmap='magma')
axs[1, 0].set_title('Random Image')

# 4th subplot
axs[1, 1].hist(y, bins=10, color='orange')
axs[1, 1].set_title('Histogram')

plt.tight_layout()  # Adjust figure layout
plt.show()
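
When all panels show the same kind of plot, it can be convenient to loop over the axes instead of indexing them one by one. The sketch below is one way to do this; the list `frequencies` is an illustrative choice, and `sharex`/`sharey` make the panels use common axis limits.

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0, 10, 100)
frequencies = [1, 2, 3, 4]  # Illustrative values, one per panel

# axs is a 2x2 NumPy array of Axes objects
fig, axs = plt.subplots(2, 2, figsize=(10, 8), sharex=True, sharey=True)

# axs.flat iterates over the four axes row by row
for ax, freq in zip(axs.flat, frequencies):
    ax.plot(x, np.sin(freq * x))
    ax.set_title(f"sin({freq}x)")
    ax.grid(True)

fig.tight_layout()  # Adjust figure layout
plt.show()
```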

### Labels and changing font size

In [None]:
plt.figure()
plt.plot(x, y)
plt.xlabel('X', fontsize=12)
plt.ylabel('Y', fontsize=12)
plt.title('Sine Wave', fontsize=14, weight='bold')
plt.grid(True)
plt.show()
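
Instead of passing `fontsize` to every call, font sizes can also be set once through Matplotlib's runtime configuration (`rcParams`). The sketch below uses `plt.rc_context` so that the settings apply only inside the `with` block; the specific sizes are illustrative.

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0, 10, 100)
y = np.sin(x)

# Illustrative values; adjust to taste
font_settings = {
    "font.size": 12,             # Default text size
    "axes.labelsize": 12,        # X and Y axis labels
    "axes.titlesize": 14,        # Axes title
    "axes.titleweight": "bold",  # Bold title
}

# The settings are active only inside the "with" block
with plt.rc_context(font_settings):
    fig, ax = plt.subplots()
    ax.plot(x, y)
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_title("Sine Wave")
    ax.grid(True)
    plt.show()
```

Figures created after the `with` block use the default settings again.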

### Custom setting

Matplotlib provides extensive options for customizing your plots, including the appearance of lines and markers.

In the example below, I plot a damped oscillation, `y(t) = A * exp(-d * t) * cos(omega * t)`, for two sets of parameters.

In [None]:
# Computes displacement: y(t) = A * exp(-d * t) * cos(omega * t)
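def damped_oscillator(t, A, d, omega):
    """
    Return the displacement of a damped oscillator at time t.
    (This triple-quoted text is a docstring: it documents the function.)

    t     : time array
    A     : amplitude
    d     : damping coefficient
    omega : angular frequency
    """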
    return A * np.exp(-d * t) * np.cos(omega * t)

t  = np.linspace(0, 5*np.pi, 100)
y1 = damped_oscillator(t, A=1, d=0.1, omega=1)
y2 = damped_oscillator(t, A=0.5, d=0.2, omega=2)

# Create line plots
plt.plot(t, y1, linestyle='-.' , linewidth=2.5, label="Wave 1")  # Line width of 2.5
plt.plot(t, y2, linestyle='--' , linewidth=1.0, label="Wave 2")  # Line width of 1.0
plt.xlabel("t")
plt.ylabel("y(t)")
plt.legend()     # Add a legend
plt.grid(True)   # Add a grid
plt.show()       # Display the plot

### Customizing marker size

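When creating scatter plots or adding markers to line plots, you can customize the size of the markers. `plt.scatter` uses the `s` parameter (marker area in points squared), while `plt.plot` uses the `markersize` parameter (marker size in points).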

In [None]:
# Create a scatter plot with custom marker size (s)
plt.scatter(t, y1, marker="o", color="royalblue", s=70, alpha=0.7, edgecolor="black", label="Wave 1")
# "x" is an unfilled marker, so it takes no edgecolor
plt.scatter(t, y2, marker="x", color="forestgreen", s=80, alpha=0.6, label="Wave 2")
plt.xlabel("X axis")
plt.ylabel("Y axis")
plt.legend()
plt.grid(True)
plt.show()
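
The two size parameters use different units, which is easy to mix up. The sketch below draws one marker with `plot` and one with `scatter`; setting `s = markersize**2` should make them appear about the same size. The value `marker_diameter = 10` is an arbitrary choice for illustration.

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()

marker_diameter = 10  # In points (illustrative value)

# plot: markersize is the marker size in points
(line,) = ax.plot([1], [1], marker="o", linestyle="none",
                  markersize=marker_diameter, label="plot, markersize=10")

# scatter: s is the marker area in points**2
points = ax.scatter([2], [1], marker="o", s=marker_diameter**2,
                    label="scatter, s=100")

ax.set_xlim(0, 3)
ax.legend()
plt.show()
```

A practical consequence is that doubling `s` does not double the marker diameter; it doubles the area.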

### Importance of Alpha in Data Visualization

Using `alpha < 1` makes the markers semi-transparent, so overlapping points remain visible instead of hiding behind one another. Regions with many points then appear darker, which makes patterns in the data easier to see.

Let us see this in an example. First, I generate data using `make_blobs` from `sklearn.datasets`. You have not seen this function yet, so just accept it as it is for now.

In [None]:
from sklearn.datasets import make_blobs

# Generate synthetic data 
X_blobs, y_blobs = make_blobs(n_samples=500, centers=1, random_state=42, cluster_std=0.1)

Plot data without alpha blending:

In [None]:
# Visualize the blobs
plt.figure(figsize=(4, 4), dpi=200)
plt.scatter(X_blobs[:, 0], X_blobs[:, 1], s=70, color='royalblue', alpha=1.0, edgecolor='k')
plt.xlabel('X')
plt.ylabel('Y')
plt.title('Without alpha blending')
plt.grid(True)
plt.show()

With alpha blending:

In [None]:
# Visualize the blobs
plt.figure(figsize=(4, 4), dpi=200)
plt.scatter(X_blobs[:, 0], X_blobs[:, 1], s=70, color='royalblue', alpha=0.3, edgecolor='k')
plt.xlabel('X')
plt.ylabel('Y')
plt.title('With alpha blending')
plt.grid(True)
plt.show()

With alpha blending you can see the overlap between points more clearly, which can be useful for understanding the density of the data.
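
A rough way to quantify this effect is sketched below. It assumes standard alpha compositing, where each semi-transparent marker lets a fraction `1 - alpha` of the background through, and it ignores the marker edges. The helper name `stacked_opacity` is made up for this example.

```python
def stacked_opacity(alpha, n):
    """Combined opacity of n overlapping markers, each drawn with the same alpha."""
    return 1 - (1 - alpha) ** n


for n in [1, 2, 5, 10]:
    print(f"n = {n:2d}: alpha=0.3 -> {stacked_opacity(0.3, n):.2f}, "
          f"alpha=1.0 -> {stacked_opacity(1.0, n):.2f}")
```

With `alpha=1.0`, a single marker is already fully opaque, so the number of overlapping points cannot be seen. With `alpha=0.3`, the opacity keeps increasing as more points overlap, so the color intensity carries information about density.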

### Advanced data generation and visualization

#### Visualizing a 2D Gaussian distribution

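The code below evaluates a 2D Gaussian distribution on a grid. We define the mean and covariance of the distribution, then build a grid of `(x, y)` points using `meshgrid`. In MATLAB, these points are reshaped into a list of coordinates to evaluate the multivariate normal probability density function (`mvnpdf`), and the resulting values (`Z`) are reshaped back into a grid so they match the original `X, Y` layout. In Python, `np.dstack` stacks the two grids into a single array of shape `(100, 100, 2)`, which `multivariate_normal.pdf` accepts directly, so no reshaping is needed. The MATLAB version draws the result as a surface, while the Python version draws contour lines and filled contours.

In MATLAB: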
```matlab
    >> % Mean and covariance
    >> mu = [0 0];
    >> Sigma = [1 0.5; 0.5 1];

    >> % Grid of points
    >> x = linspace(-3, 3, 100);
    >> y = linspace(-3, 3, 100);
    >> [X, Y] = meshgrid(x, y);
    >> pos = [X(:), Y(:)];

    >> % Evaluate PDF
    >> Z = mvnpdf(pos, mu, Sigma);
    >> Z = reshape(Z, size(X));

    >> surf(X, Y, Z);
    >> xlabel('X');
    >> ylabel('Y');
    >> zlabel('Probability Density');
```

In Python:


In [None]:
from scipy.stats import multivariate_normal

# Mean vector and covariance matrix
mean = [0, 0]
covariance = [[1, 0.5], [0.5, 1]]  # Covariance matrix

# Create a grid of (x, y) coordinates
x = np.linspace(-3, 3, 100)
y = np.linspace(-3, 3, 100)
X, Y = np.meshgrid(x, y)

# Calculate the Gaussian distribution
pos = np.dstack((X, Y))
Z = multivariate_normal(mean, covariance).pdf(pos)

# Plot data
plt.figure(figsize=(10, 6))
contour = plt.contour(X, Y, Z, levels=10, cmap="viridis")
plt.xlabel("X")
plt.ylabel("Y")
plt.colorbar(contour, label="Probability Density")
plt.title("Gaussian Contour Plot")
plt.grid(True)
plt.show()

In [None]:
plt.figure(figsize=(10, 6))
filled_contour = plt.contourf(X, Y, Z, levels=10, cmap="inferno")
plt.xlabel('X')
plt.ylabel('Y')
plt.colorbar(filled_contour, label='Probability Density')
plt.title('Gaussian Filled Contour Plot')
plt.grid(True)
plt.show()
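
The MATLAB-style "flatten, evaluate, reshape" steps can also be written in Python. The sketch below compares them with the `np.dstack` approach and checks the peak value against the analytic formula for a 2D Gaussian. The names `Z_stacked`, `Z_reshaped`, and `peak_expected` are illustrative.

```python
import numpy as np
from scipy.stats import multivariate_normal

mean = [0, 0]
covariance = [[1, 0.5], [0.5, 1]]
rv = multivariate_normal(mean, covariance)

x = np.linspace(-3, 3, 100)
y = np.linspace(-3, 3, 100)
X, Y = np.meshgrid(x, y)

# Approach 1: stack the grids into an array of shape (100, 100, 2)
Z_stacked = rv.pdf(np.dstack((X, Y)))

# Approach 2 (MATLAB style): flatten to a list of points, then reshape back
points = np.column_stack((X.ravel(), Y.ravel()))  # Shape (10000, 2)
Z_reshaped = rv.pdf(points).reshape(X.shape)      # Shape (100, 100)

print(np.allclose(Z_stacked, Z_reshaped))

# Sanity check: the density at the mean is 1 / (2 * pi * sqrt(det(Sigma)))
peak_expected = 1.0 / (2 * np.pi * np.sqrt(np.linalg.det(covariance)))
print(np.isclose(rv.pdf(mean), peak_expected))
```

Both approaches give the same grid of values; the `np.dstack` version simply avoids the explicit reshaping.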

### Matplotlib Gallery

For more examples, visit https://matplotlib.org/stable/gallery/index.html 

***
## Data Visualization with Seaborn

Seaborn is a Python visualization library based on Matplotlib that provides a high-level interface for drawing attractive statistical graphics. This part of the notebook introduces the basics of Seaborn.

We need data to visualize, so we load one of Seaborn's example datasets, which are hosted at https://github.com/mwaskom/seaborn-data

I've selected the "Tips" dataset. According to https://www.geeksforgeeks.org/data-science/seaborn-datasets-for-data-science/#1-tips-dataset :

The Tips dataset contains information about tips received by waitstaff in a restaurant. It's commonly used for regression and exploratory data analysis (EDA). The dataset includes features such as total bill amount, tip amount, gender of the person paying the bill, whether the person is a smoker, day of the week, time of day, and size of the party.

- Advantages: Simple and intuitive, good for demonstrating basic statistical analysis and visualization.
- Disadvantages: Small size limits complexity of analyses, limited to restaurant tipping context.

Features and characteristics:
- `total_bill`: Total bill amount (numerical)
- `tip`: Tip amount (numerical)
- `sex`: Gender of the person paying the bill (categorical)
- `smoker`: Whether the person is a smoker (categorical)
- `day`: Day of the week (categorical)
- `time`: Time of day (Lunch/Dinner) (categorical)
- `size`: Size of the party (numerical)

In [None]:
import seaborn as sns

# Load the tips dataset
tips = sns.load_dataset("tips")

# Display the first few rows of the dataset
tips.head()

### ⚠️ Note:
In Python, both ' (single quotes) and " (double quotes) are valid ways to write strings. It is just a matter of style. Again, due to my C/C++ background, I feel more comfortable using double quotes.

### Basic scatter plot

We make a scatter plot using a table of data called `tips`.
We want to show the relationship between:
- x: how much people paid in total (`total_bill`), and
- y: how much they gave as a tip (`tip`).

Each dot is one customer.

We also want to:
- Color the dots by the day (Thursday, Friday, etc.)
- Change the shape of the dots depending on whether it was lunch or dinner.
- Scale the size of the dots by the size of the party.

So, the plot helps us answer:
- Do people tip more when they pay more?
- Does this change on different days or at different times?

In the code below, each argument of `sns.scatterplot` is mapped to a column of the dataset. For example,
 - `x='total_bill'` uses the 'total_bill' column for the x-axis
 - `y='tip'` uses the 'tip' column for the y-axis
 - `hue='day'` colors the points by the 'day' column
 - `style='time'` changes the marker style by the 'time' column
 - `size='size'` changes the marker size by the 'size' column

In [None]:
plt.figure(figsize=(6, 4), dpi=200)
# Scatter plot
sns.scatterplot(data=tips, x="total_bill", y="tip",
                hue="day", style="time", size="size",
                palette="deep")  # Set color palette. Other palette options are 'pastel', 'dark', 'colorblind', etc.

# Move legend outside the plot
plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left", borderaxespad=0)
plt.tight_layout()  # Avoid cutting off labels
plt.show()

***
### ✅ Check your understanding
Study the figure above and make sure you understand all aspects of the plotted data. What does the figure show? Discuss with your friends. If you have any questions, please do not hesitate to ask.
***


### Box plot
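
A box plot helps us understand the distribution of a dataset.
It summarizes the data with:
- the lower whisker (the minimum value, excluding outliers),
- first quartile (25th percentile),
- median (50th percentile),
- third quartile (75th percentile),
- the upper whisker (the maximum value, excluding outliers), and
- outliers, drawn as individual points. By default, these are points farther than 1.5 times the interquartile range from the box.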

In [None]:
# Box plot
sns.boxplot(data=tips, x="day", y="total_bill", hue="day", palette="Set2")
plt.title("Total bill by day")
plt.show()
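
The quantities drawn in a box plot can be computed by hand. The sketch below does this with NumPy for a small made-up array (not the tips data), assuming the default rule that points farther than 1.5 times the interquartile range (IQR) from the box count as outliers.

```python
import numpy as np

# Made-up values, chosen only to illustrate the calculation
data = np.array([1, 2, 3, 4, 5, 6, 7, 8, 30])

# The box: first quartile, median, third quartile
q1, median, q3 = np.percentile(data, [25, 50, 75])
iqr = q3 - q1

# Points beyond 1.5 * IQR from the box are treated as outliers
lower_fence = q1 - 1.5 * iqr
upper_fence = q3 + 1.5 * iqr
inside = data[(data >= lower_fence) & (data <= upper_fence)]
outliers = data[(data < lower_fence) | (data > upper_fence)]

print("Q1, median, Q3:", q1, median, q3)
print("Whiskers:", inside.min(), inside.max())  # Most extreme non-outlier values
print("Outliers:", outliers)
```

Here the value 30 lies above the upper fence, so the upper whisker stops at 8 and 30 is drawn as a separate point.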

### Strip plot

In [None]:
# Strip plot
sns.stripplot(data=tips, x="day", y="total_bill", hue="day", jitter=True, palette="Set2")
plt.title("Strip plot of total bill by day")
plt.show()

***
### ✅ Check your understanding
Study and compare the last two figures. What do they show? Make sure you understand all aspects of the plotted data.
***


### Heatmap

In [None]:
# Create a pivot table
pivot_table = tips.pivot_table(values="tip", index="day", columns="time", aggfunc="mean", observed=True)

# Heatmap
plt.figure(dpi=200)
sns.heatmap(pivot_table, annot=True, cmap="Spectral_r")  # I reversed the color palette by using '_r' at the end of its name.
plt.title("Heatmap of Average tips by day and time")
plt.show()
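
To see what `pivot_table` does before the result is passed to the heatmap, the sketch below applies it to a tiny made-up table that reuses the column names `day`, `time`, and `tip`. The values are invented for illustration and are not taken from the tips dataset.

```python
import pandas as pd

# Tiny made-up table with the same column names as the tips dataset
toy = pd.DataFrame({
    "day":  ["Thur", "Thur", "Fri", "Fri"],
    "time": ["Lunch", "Dinner", "Lunch", "Lunch"],
    "tip":  [2.0, 3.0, 4.0, 6.0],
})

# Rows: day, columns: time, cell value: mean tip of the matching rows
toy_pivot = toy.pivot_table(values="tip", index="day", columns="time", aggfunc="mean")
print(toy_pivot)
```

Each cell holds the average of all rows with that combination of `day` and `time`. A combination with no rows (here Friday dinner) becomes `NaN`, and such cells are typically left blank in the heatmap.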

### Pair plot

In [None]:
# Pair plot
g = sns.pairplot(tips, hue="day", palette="colorblind")

# Add a title to the whole figure
g.figure.suptitle("Pair plot of Tips data", y=1.05)

# Set DPI
g.figure.set_dpi(200)

***
### ✅ Check your understanding
Carefully study the last two figures. What do they show? Make sure you understand all aspects of the plotted data. If you have any questions, please do not hesitate to ask. 

***
END
***