# Introduction to data analysis in Python: Week 2

This week, we're going to be using the [Palmer Penguins](https://allisonhorst.github.io/palmerpenguins/articles/intro.html) dataset. This dataset contains size measurements for three penguin species observed on three islands in the Palmer Archipelago, Antarctica. The data was collected between 2007 and 2009 by Dr. Kristen Gorman with the Palmer Station Long Term Ecological Research Program.

The dataset contains the following columns:

1. **Species:** The dataset includes three species of penguins: Adelie, Chinstrap, and Gentoo. Each species has distinct physical characteristics and behaviors.
2. **Island:** Penguins in the dataset were observed on three different islands: Biscoe, Dream, and Torgersen. These islands are part of the Palmer Archipelago in Antarctica.
3. **Bill Length (mm):** Length of a penguin's bill in millimeters. Bill length can vary significantly between species and is an important characteristic for identifying species.
4. **Bill Depth (mm):** Depth of a penguin's bill in millimeters. Bill depth, like bill length, is an important characteristic for species identification.
5. **Flipper Length (mm):** Length of a penguin's flipper in millimeters. Flipper length is related to a penguin's swimming ability and varies between species.
6. **Body Mass (g):** Body mass of a penguin in grams. Body mass can provide insights into the health and nutrition of the penguins.
7. **Sex:** Male or female.
8. **Year:** The year when the observation was made.

## Getting set up

As with last week, the first thing we need to do is import the packages we'll be using and then read in our dataset. This week, since we're moving on to data visualisation, we'll also need to load a couple of additional packages: ``seaborn`` is a wrapper on ``matplotlib`` which makes it easy to work with pandas dataframes, but we still need to import matplotlib to get access to some basic functions.

In [None]:
# Import the packages
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Read in the data
data = pd.read_csv("penguins.csv")

In [None]:
# Visually inspect the data
data.head(10)

In [None]:
# Get a quick summary of the dataframe
data.info()

In [None]:
data.isnull().sum()

Ok, so it doesn't look like we've got much missing data in this dataset, that's nice! But I can see one thing here I want to deal with before I start my analysis: the ``rowid`` column is stored as an integer, but ID columns should really be categorical (since it would be pretty meaningless to calculate any statistics on them, like a mean!). Here's an easy way to convert it:

In [None]:
data["rowid"] = data["rowid"].astype("category")

You can also decide whether you'd like to do anything with the missing data -- it might not be a terrible idea to just drop all rows with any missing data in this case since there's only 11 rows in that category. But it might also depend on whether you're actually interested in sex, which is where most of the missing data is -- no point throwing away rows with data you do care about if you don't care about sex!

In [None]:
# Deal with missing data as desired
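
If you're not sure where to start with the cell above, here's a minimal sketch of the two options just described. It uses a tiny made-up dataframe called ``toy`` (not the real penguin measurements) so that it runs on its own; with the real data you'd call the same methods on ``data``.

```python
import numpy as np
import pandas as pd

# Toy stand-in for the penguins data (values are made up for illustration)
toy = pd.DataFrame({
    "species": ["Adelie", "Adelie", "Gentoo", "Chinstrap"],
    "body_mass_g": [3750.0, np.nan, 5000.0, 3800.0],
    "sex": ["male", np.nan, np.nan, "female"],
})

# Option 1: drop every row that has at least one missing value
complete_rows = toy.dropna()

# Option 2: only drop rows that are missing a column you actually care about,
# keeping rows where only sex is unknown
has_body_mass = toy.dropna(subset=["body_mass_g"])

print(len(toy), len(complete_rows), len(has_body_mass))
```

Note that ``dropna()`` returns a new dataframe rather than changing the original, so you need to assign the result to a variable to keep it.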

Finally, you might want to perform some data cleaning here, like sending all the text columns to lowercase and stripping whitespace. Refer back to the Week 1 notebook to remind you how to perform these steps!

In [None]:
# Carry out any necessary data cleaning
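
As a reminder of roughly what that cleaning could look like, here's a small sketch on a made-up dataframe (the messy values below are invented for illustration, and the Week 1 notebook may have done this slightly differently).

```python
import pandas as pd

# Toy dataframe with deliberately messy text (made up for illustration)
toy = pd.DataFrame({
    "species": [" Adelie", "GENTOO ", "chinstrap"],
    "island": ["Torgersen ", " biscoe", "Dream"],
})

# Strip leading/trailing whitespace and send everything to lowercase
for column in ["species", "island"]:
    toy[column] = toy[column].str.strip().str.lower()

print(toy)
```

One thing to watch out for: if you lowercase the ``species`` column in the real data, any later code that refers to the capitalised names (like ``"Gentoo"``) will need to use the lowercase versions instead.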

## Calculating basic summary statistics

Pandas provides several methods to generate descriptive statistics that summarise the central tendency, dispersion, and shape of a dataset's distribution.

The ``describe()`` function generates summary statistics for numerical columns in the dataset, including count (the number of rows with a non-null value), mean, standard deviation, minimum and maximum values, as well as the 25th, 50th (median), and 75th percentiles.

In [None]:
# Generate summary statistics for numerical columns
data.describe()

We can also give ``describe()`` an additional argument to generate a comprehensive overview of the whole dataset, including summary statistics for both numerical and categorical columns.

In [None]:
data.describe(include="all")

We can also calculate any of the statistics included in these tables individually. The next cell calculates the mean (average) value for each numerical column in the dataset. The mean provides a measure of the central tendency of the data.

In [None]:
data.mean(numeric_only=True)

We can also call functions like these on a specific column or columns, using the syntax we learned last week to select subsets of columns.

In [None]:
data[["bill_length_mm", "flipper_length_mm"]].mean()

The `median()` function calculates the median value for each numerical column in the dataset. The median is the middle value when the data is sorted, providing another measure of central tendency that is less sensitive to outliers than the mean.

In [None]:
data.median(numeric_only=True)

The next cell computes the standard deviation for each numerical column in the dataset. The standard deviation measures the amount of variation or dispersion in the data, indicating how spread out the values are from the mean.

In [None]:
data.std(numeric_only=True)

The ``min()`` and ``max()`` functions calculate the minimum and maximum value for each numerical column. These functions help identify the range of the data.

In [None]:
data.min(numeric_only=True)

In [None]:
data.max(numeric_only=True)

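The next cell calculates the variance for each numerical column in the dataset. Variance is the average squared deviation of each value from the mean (by default, pandas divides by n - 1 rather than n, which gives the sample variance). It is the square of the standard deviation, so it provides another way to understand data dispersion.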

In [None]:
data.var(numeric_only=True)
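
To make the link between variance and standard deviation concrete, here's a quick sketch using four made-up body masses. The ``ddof`` argument controls whether pandas divides by n - 1 (``ddof=1``, the default) or by n (``ddof=0``).

```python
import numpy as np
import pandas as pd

# Four made-up body masses in grams (for illustration only)
masses = pd.Series([3500.0, 3800.0, 4100.0, 4400.0])

# pandas default: sample variance, dividing by n - 1
sample_variance = masses.var()

# Population variance, dividing by n
population_variance = masses.var(ddof=0)

# The variance is the square of the standard deviation
print(sample_variance, masses.std() ** 2, population_variance)
```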

We can also calculate summary statistics for categorical columns. A very common one is ``value_counts()``, which calculates how often each unique value in a given column occurs.

In [None]:
data["species"].value_counts()

We can also add an extra argument to this function to get normalised value counts, which are proportions of the total count rather than raw counts.

In [None]:
data["species"].value_counts(normalize=True)

This is a fairly simple dataset, but if you were working with data where your column of interest had many unique values, it helps to have the counts sorted by frequency so you can see at a glance which value has the most/least observations. ``value_counts()`` already sorts from most to least common by default, but the explicit ``sort_values()`` call in the following cell lets you control the order (change ``ascending=False`` to ``True`` to see the least common species first). By default, ``value_counts()`` will ignore missing data, but if you want to include a count of missing rows, you can add ``dropna=False`` as an argument.

In [None]:
data["species"].value_counts().sort_values(ascending=False)
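
Here's a small sketch of what ``dropna=False`` changes, using a made-up column with a couple of missing values (the species column in the real data doesn't have any, so you wouldn't see a difference there).

```python
import numpy as np
import pandas as pd

# Made-up column with two missing values (for illustration only)
sex = pd.Series(["male", "female", np.nan, "male", np.nan, "male"])

# Default behaviour: missing values are left out of the counts
counts_without_missing = sex.value_counts()

# dropna=False adds a row counting the missing values
counts_with_missing = sex.value_counts(dropna=False)

print(counts_without_missing)
print(counts_with_missing)
```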

You can also get a list of the unique values in a given column, as follows:

In [None]:
list(data["species"].unique())

Obviously in this case you can see at a glance how many different species there are, but again, if you were working with a more complex dataset, you might need to extract that information with code.

In [None]:
data["species"].nunique()

Think about any other statistics you might like to calculate on a particular column and have a go at looking up the functions you'd need for those!

Last week, we saw that we could use filtering to focus on specific parts of a dataset which meet certain criteria. We can combine this logic with the summary statistics we've learned above. For example, maybe we only want to know about flipper length for Gentoo penguins.

In [None]:
data[data["species"] == "Gentoo"]["flipper_length_mm"].mean()

Have a play with some more complex filtering operations and calculate summary statistics on the resulting output.
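
As a starting point, here's a sketch of a filter that combines two conditions, run on a made-up dataframe so that it works on its own. Each condition needs its own set of brackets, and ``&`` means "and" (use ``|`` for "or").

```python
import pandas as pd

# Toy stand-in for the penguins data (values are made up for illustration)
toy = pd.DataFrame({
    "species": ["Gentoo", "Gentoo", "Gentoo", "Adelie", "Adelie"],
    "sex": ["male", "female", "male", "male", "female"],
    "flipper_length_mm": [220.0, 210.0, 230.0, 190.0, 185.0],
})

# True for rows that are both Gentoo and male
is_male_gentoo = (toy["species"] == "Gentoo") & (toy["sex"] == "male")

# Select the matching rows and one column, then summarise
mean_flipper = toy.loc[is_male_gentoo, "flipper_length_mm"].mean()
print(mean_flipper)
```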

## Grouping data

Grouping is a super common operation in data analysis. In short, grouping by a variable means collapsing a dataframe down to one row per unique value of that variable. For example, if we group by species in the penguins dataset, we would be left with 3 rows. We can also group by multiple variables, which will collapse the dataframe down to one row per unique combination of all values of all variables. For example, if we grouped by species and sex, we would be left with 6 rows (3 species x 2 sexes). Once we have our groups, we can calculate statistics with them in the same way as we did above. For example:

In [None]:
data.groupby("species").mean(numeric_only=True)

In [None]:
data.groupby(["species", "sex"]).mean(numeric_only=True)

We can also calculate statistics for just one column or a subset of columns, as follows:

In [None]:
# Calculate the mean of a single column
data.groupby("species")["bill_length_mm"].mean()

In [None]:
# Calculate the mean of a subset of columns
data.groupby("species")[["bill_length_mm", "flipper_length_mm"]].mean()

We can also use the ``agg()`` function to apply multiple aggregate functions to the grouped data, as follows:

In [None]:
data.groupby("species").agg({
    "bill_length_mm": "mean",
    "flipper_length_mm": "max",
    "body_mass_g":"std"
})
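
Two variations on ``agg()`` that you might find handy, sketched on a made-up dataframe: passing a list to get several statistics for the same column, and "named aggregation", where you choose the names of the output columns yourself. The output names ``mean_bill`` and ``heaviest`` are just ones I've picked for this example.

```python
import pandas as pd

# Toy stand-in for the penguins data (values are made up for illustration)
toy = pd.DataFrame({
    "species": ["Adelie", "Adelie", "Gentoo", "Gentoo"],
    "bill_length_mm": [38.0, 40.0, 46.0, 50.0],
    "body_mass_g": [3600.0, 3800.0, 5000.0, 5400.0],
})

# Several statistics for one column: pass a list instead of a single name
per_column = toy.groupby("species").agg({"bill_length_mm": ["mean", "max"]})

# Named aggregation: output_name=(input_column, statistic)
named = toy.groupby("species").agg(
    mean_bill=("bill_length_mm", "mean"),
    heaviest=("body_mass_g", "max"),
)

print(per_column)
print(named)
```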

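The ``transform()`` function lets us apply a function to each group and return a result with the same number of rows as the original data. For example, we might want to measure body mass for each individual relative to other penguins of the same species. The code below does this by subtracting the species mean from each penguin's body mass and then dividing by the species standard deviation -- this is known as "standardising" (or calculating a z-score). It puts penguins with an average body mass for their species at zero, those who are heavier than average above zero, and those who are lighter than average below zero, with the distance from zero measured in standard deviations rather than grams. (If you only subtract the mean and skip the division, that's known as "centering".)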

In [None]:
data["normalised_body_mass"] = data.groupby("species")["body_mass_g"].transform(lambda x: (x - x.mean()) / x.std())
data.head()
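
Centering and standardising are easy to mix up, so here's a minimal sketch showing both side by side on a made-up dataframe. The column names ``centred_body_mass`` and ``standardised_body_mass`` are just ones chosen for this example.

```python
import pandas as pd

# Toy stand-in for the penguins data (values are made up for illustration)
toy = pd.DataFrame({
    "species": ["Adelie", "Adelie", "Adelie", "Gentoo", "Gentoo", "Gentoo"],
    "body_mass_g": [3500.0, 3700.0, 3900.0, 4800.0, 5000.0, 5200.0],
})

# Centering: subtract the species mean, so the result is still in grams
toy["centred_body_mass"] = toy.groupby("species")["body_mass_g"].transform(
    lambda x: x - x.mean()
)

# Standardising: also divide by the species standard deviation,
# so the result is in standard deviations (a z-score)
toy["standardised_body_mass"] = toy.groupby("species")["body_mass_g"].transform(
    lambda x: (x - x.mean()) / x.std()
)

print(toy)
```

Unlike the aggregations earlier in this section, ``transform()`` gives back one value per original row, which is why the result can be stored straight into a new column.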

Try grouping by some different variables and calculating some other summary statistics. Also, as an exercise, think about what would happen if you tried grouping by a numerical column and why this might not be a great idea!
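
Once you've had a think about that exercise: if you do want to group by something numerical, one common approach is to sort the values into bins first with ``pd.cut()`` and then group by the bins. This is only a sketch on made-up values, and the bin edges and labels are arbitrary choices for illustration.

```python
import pandas as pd

# Toy stand-in for the penguins data (values are made up for illustration)
toy = pd.DataFrame({
    "body_mass_g": [3200.0, 3900.0, 4100.0, 4900.0, 5600.0],
    "flipper_length_mm": [185.0, 192.0, 198.0, 215.0, 225.0],
})

# Sort each body mass into one of three bands (edges chosen arbitrarily)
toy["mass_band"] = pd.cut(
    toy["body_mass_g"],
    bins=[3000, 4000, 5000, 6000],
    labels=["3-4 kg", "4-5 kg", "5-6 kg"],
)

# Group by the band instead of the raw numerical column
band_means = toy.groupby("mass_band", observed=True)["flipper_length_mm"].mean()
print(band_means)
```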

## Cross-tabs and pivot tables

Cross-tabulations and pivot tables allow us to summarise the data in a matrix format, providing insights into relationships between variables.

We can use the ``crosstab()`` function to create a cross-tabulation of the species and island columns. The resulting table shows the frequency of each species on each island. The `margins=True` parameter in the `crosstab()` function adds row and column totals to the cross-tabulation. This helps in understanding the overall distribution and totals for each category -- see what it looks like if you remove it!

In [None]:
pd.crosstab(data['species'], data['island'], margins=True)

We can also normalise cross-tabs to show proportions instead of raw counts. ``normalize="index"`` scales the counts to proportions within each row. You can also use ``normalize="columns"`` or ``normalize="all"``.

In [None]:
pd.crosstab(data['species'], data['island'], margins=True, normalize="index")
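
To see how the three ``normalize`` options differ, here's a sketch on a tiny made-up dataframe where the numbers are easy to check by eye.

```python
import pandas as pd

# Toy stand-in for the penguins data (values are made up for illustration)
toy = pd.DataFrame({
    "species": ["Adelie", "Adelie", "Adelie", "Gentoo"],
    "island": ["Biscoe", "Dream", "Dream", "Biscoe"],
})

# Each row sums to 1: where is each species found?
by_row = pd.crosstab(toy["species"], toy["island"], normalize="index")

# Each column sums to 1: which species make up each island?
by_column = pd.crosstab(toy["species"], toy["island"], normalize="columns")

# The whole table sums to 1: share of all penguins in each cell
overall = pd.crosstab(toy["species"], toy["island"], normalize="all")

print(by_row)
print(by_column)
print(overall)
```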

The `pivot_table()` function is used to create a basic pivot table with species as the index and sex as the columns. The values in the following pivot table are the average body mass (body_mass_g) for each combination of species and sex, but you can try passing a different statistic to ``aggfunc``.

In [None]:
data.pivot_table(values='body_mass_g', index='species', columns='sex', aggfunc='mean')

You can also pass a list of aggregation functions to ``aggfunc``, as follows:

In [None]:
data.pivot_table(values='body_mass_g', index='species', columns='sex', aggfunc=['mean', 'median'])

Finally, pivot tables can also have multiple index columns, as follows:

In [None]:
data.pivot_table(values='body_mass_g', index=['species', 'island'], columns='sex', aggfunc=['mean', 'median'])
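
If pivot tables feel a bit like grouping, that's because they are closely related. As a sketch (on made-up values), a pivot table of means holds the same numbers as grouping by both variables, taking the mean, and then using ``unstack()`` to move one of the grouping variables into the columns.

```python
import pandas as pd

# Toy stand-in for the penguins data (values are made up for illustration)
toy = pd.DataFrame({
    "species": ["Adelie", "Adelie", "Adelie", "Gentoo", "Gentoo", "Gentoo"],
    "sex": ["male", "male", "female", "male", "female", "female"],
    "body_mass_g": [4000.0, 4200.0, 3400.0, 5500.0, 4600.0, 4800.0],
})

# Pivot table: species down the side, sex across the top
pivoted = toy.pivot_table(
    values="body_mass_g", index="species", columns="sex", aggfunc="mean"
)

# The same numbers via groupby: unstack() moves sex into the columns
grouped = toy.groupby(["species", "sex"])["body_mass_g"].mean().unstack()

print(pivoted)
print(grouped)
```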

Try making some other cross-tabs and pivot tables for different variables and statistics.

## Data visualisation

Ok, we've looked at a lot of tables and other text-based methods of representing our data -- now let's see how to summarise our data more visually! For this, we're going to use the ``seaborn`` library which we loaded at the top of this notebook. The documentation for this library has a really nice [tutorial section](https://seaborn.pydata.org/tutorial.html) which is worth checking out.

**A confession and a disclaimer:** I don't really get on with Python for data visualisation. I find ggplot in R much more intuitive, simple and easily customisable, so I almost never make visualisations that will actually go into my talks/papers in Python -- I generally just use it for quick and dirty plots to help me make sense of data that I'm working with on the fly. That being said, I learned base ``matplotlib`` when I first started doing data analysis in Python, and the ``seaborn`` package we're going to use here is definitely a lot more user friendly than that, so you might quite like it!

First, I'm going to add a column to the dataframe to assign a colour to each species because I like the colour palette used in the [Palmer Penguins R package](https://allisonhorst.github.io/palmerpenguins/articles/intro.html). We can then use this palette to colour the data in our plots. If you don't like this palette, you can specify a different one! There's a loooong [list of named colours](https://matplotlib.org/stable/gallery/color/named_colors.html) available in matplotlib.

In [None]:
# Create the colour mapping
colour_palette = {
    "Adelie": "darkorange",
    "Chinstrap": "purple",
    "Gentoo": "cadetblue"
}

To start, let's just make a histogram to see how many penguins we have of each species. There's a bunch of arguments to the ``sns.histplot()`` function below -- the first few should be obvious, but try playing around with the others to see what they do!

In [None]:
sns.histplot(data=data, x="species", hue="species", palette=colour_palette, shrink=0.8, alpha=1, edgecolor="none", legend=False)

We might also want to see how flipper length and bill length are related: presumably penguins with bigger flippers are just bigger overall, and therefore also have bigger bills, so there should be a positive correlation?

In [None]:
sns.scatterplot(data=data, x="flipper_length_mm", y="bill_length_mm", color="black")

Now let's add some more nuance to this, by colouring the points according to penguin species. We can also use some basic matplotlib functions to customise the axis and legend labels and add a title. You can use any number of these commands as long as they're in the same cell as the line of code that creates the plot they apply to.

In [None]:
sns.scatterplot(data=data, x="flipper_length_mm", y="bill_length_mm", hue="species", palette=colour_palette)
plt.xlabel("Flipper length (mm)")
plt.ylabel("Bill length (mm)")
plt.title("Flipper length vs. bill length for three penguin species")
plt.legend(title="Species")

One thing to notice is that when I wanted all the points to be the same colour, I did this with ``color="black"``, but when I wanted to colour the points according to a variable in the dataframe, the argument was called ``hue`` instead. Mixing these two things up is a common cause of error messages when working with colour!
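
Here's a small sketch that puts the two side by side, using a made-up dataframe and a cut-down palette (``toy_palette``, a name invented for this example) so that it runs on its own.

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Toy stand-in for the penguins data (values are made up for illustration)
toy = pd.DataFrame({
    "species": ["Adelie", "Adelie", "Gentoo", "Gentoo"],
    "flipper_length_mm": [185.0, 195.0, 215.0, 225.0],
    "bill_length_mm": [38.0, 40.0, 46.0, 49.0],
})
toy_palette = {"Adelie": "darkorange", "Gentoo": "cadetblue"}

fig, (left, right) = plt.subplots(1, 2, figsize=(8, 4))

# One fixed colour for every point: use `color`
sns.scatterplot(data=toy, x="flipper_length_mm", y="bill_length_mm",
                color="black", ax=left)

# Colour decided by a column: `hue` names the column,
# and `palette` says which colour each value of that column gets
sns.scatterplot(data=toy, x="flipper_length_mm", y="bill_length_mm",
                hue="species", palette=toy_palette, ax=right)

fig.tight_layout()
```

Only the right-hand plot gets a legend, because that's the only one where colour carries any information.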

You can also do scatter plots with line(s) of best fit. To fit a single line of best fit to all the data, use ``sns.regplot``. To fit a different line to different groups, use ``sns.lmplot``.

In [None]:
sns.regplot(data=data, x="flipper_length_mm", y="bill_length_mm", color="black")

In [None]:
sns.lmplot(data=data, x="flipper_length_mm", y="bill_length_mm", hue="species", palette=colour_palette, legend=False)

Both of these functions will fit a linear regression line by default, but you can specify different kinds of fit (like quadratic) with the ``order`` parameter, or you can specify ``logistic=True`` to get a sigmoid curve. We won't go through these more advanced options today -- in this case, linear looks good! 

Now let's have a look at a different way of visualising flipper length, as a distribution. In the next cell, the ``multiple="layer"`` argument tells Seaborn what to do with overlapping data. The other options for this are ``"stack"``, ``"fill"`` or ``"dodge"`` -- try them all and see which you like best! You can also leave this option out to see what the default behaviour for overlapping data is.

In [None]:
sns.histplot(data=data, x="flipper_length_mm", hue="species", palette=colour_palette, multiple="layer")

We can also use facets to prevent the data from overlapping altogether by specifying ``col="species"``.

In [None]:
sns.displot(data=data, x="flipper_length_mm", hue="species", palette=colour_palette, col="species", legend=False)

Another way to visualise distributions is kernel density estimation. Run the code in the next cell, and then try changing the x-axis label and adding a title.

In [None]:
sns.kdeplot(data=data, x="flipper_length_mm", hue="species", palette=colour_palette, fill=True, alpha=0.5)

We can also combine multiple plots at the same time, by first laying out a grid with ``plt.subplots`` and then plotting different data in the different panels. In the following code:
- ``fig`` is the overall figure which contains all the different subplots
- Each element of ``axes`` is a subplot
- The first two arguments to ``plt.subplots`` specify how many rows and columns we want -- you could swap these numbers around if you want the subplots to stack vertically instead of horizontally (in which case you'd also need to change ``width_ratios`` to ``height_ratios``, since there would only be one column)
- ``figsize`` specifies the width and height (in inches) of the overall figure which will contain the subplots
- ``gridspec_kw`` passes extra layout options to the grid; here, ``width_ratios`` specifies the width of each subplot relative to the others -- you can remove this argument to make them all the same size
- ``sns.scatterplot`` will be the first subplot because it's assigned to ``axes[0]``; ``sns.histplot`` will be the second subplot because it's assigned to ``axes[1]`` (remember Python starts counting at zero!)
- ``fig.tight_layout()`` adjusts the spacing so that the subplots and their labels don't overlap

It looks like a lot of elements, but just play around with them and see what happens! 

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(8, 4), gridspec_kw=dict(width_ratios=[4, 3]))
sns.scatterplot(data=data, x="flipper_length_mm", y="bill_length_mm", hue="species", palette=colour_palette, ax=axes[0])
sns.histplot(data=data, x="species", hue="species", palette=colour_palette, shrink=.8, alpha=.8, legend=False, ax=axes[1])
fig.tight_layout()

Seaborn also makes it easy to create some pretty complex and aesthetically pleasing plots. For example, we can show the joint distribution of two variables along with marginal axes that show the univariate distribution of each one separately, as follows:

In [None]:
sns.jointplot(data=data, x="flipper_length_mm", y="bill_length_mm", hue="species", palette=colour_palette)

To save any of your plots, you can add ``plt.savefig("plot_name_here.png")`` (or whatever format you want instead of .png) in the same cell as the code that generates the plot. There are various things you can specify when you save, like size, transparency and DPI. It's good practice to call ``plt.close()`` after saving to free up memory.
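
Here's a minimal sketch of saving with a few of those options set. The points are made up, and the file goes into your system's temporary folder just so the example doesn't clutter your working directory; in practice you'd probably pass a plain filename like ``"flipper_vs_bill.png"``.

```python
import os
import tempfile

import matplotlib.pyplot as plt

# The size of the figure (in inches) is set when it's created
fig, ax = plt.subplots(figsize=(6, 4))
ax.scatter([185, 195, 215], [38, 40, 47], color="black")  # made-up points
ax.set_xlabel("Flipper length (mm)")
ax.set_ylabel("Bill length (mm)")

# Temporary location used for this example only
output_path = os.path.join(tempfile.gettempdir(), "flipper_vs_bill.png")

# dpi sets the resolution, transparent removes the white background,
# and bbox_inches="tight" trims spare whitespace around the edges
fig.savefig(output_path, dpi=300, transparent=True, bbox_inches="tight")

# Close the figure once it's saved to free up memory
plt.close(fig)
```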

Play around with the dataset and see what other graphs you can come up with. For example, you could see whether various aspects of a penguin's body size seem to vary more according to their species or their sex. Feel free to try some of these plotting functions with your own datasets too -- there's plenty of other kinds of plots that might be useful in your own work but wouldn't make sense for the penguins dataset e.g. time series data. Remember to refer to the [seaborn documentation](https://seaborn.pydata.org/tutorial.html) to get new ideas or help implementing something you're trying!