# Analyzing Effects of Light and Nutrient Availability, Seasonal Variations, and Sea Ice Concentration on Ocean Phytoplankton in a Changing Climate: Modeling Bloom Dynamics, Carbon Flux, and Sea-Ice Interactions

**IF ATTEMPTING TO USE NOTEBOOK TO RECREATE FIGURES, PLEASE SEE README FILE FOR PROPER INSTRUCTIONS TO CREATE A WORKING DIRECTORY.**

This Jupyter Notebook contains the code necessary to recreate all of my figures and work for my study on how variables such as light availability, nutrient availability, seasonal variations, and sea ice concentration affect ocean phytoplankton. This project is currently mentored by Dr. Andrew Thompson, the director of Caltech's Linde and Maxine Center for Global Environmental Science, and Sarah Zhang, a current Caltech graduate student in the Linde and Maxine Center for Global Environmental Science. Though I gave a final presentation to conclude this project for the summer, this is an ongoing project, and thus there may be updates.

**NOTE:** This project was originally completed on Google Colab and was adapted into a Jupyter Notebook. Please see README for instructions to create a proper working directory. If you need any other resources (i.e. would like to recreate the project within Google Colab, additional data, etc.) feel free to email me at audreyjunma@gmail.com.

See images of figures from this notebook here: https://drive.google.com/drive/folders/1amnPzMdTfGwJJ76Sbm3Ov9KRWdxvALKo?usp=sharing

See the final presentation here: https://www.canva.com/design/DAG6QTY85S8/68w1TFFip0YZWgzkiJEIug/edit?utm_content=DAG6QTY85S8&utm_campaign=designshare&utm_medium=link2&utm_source=sharebutton

## Relevant Information

### Background Information

The Southern Ocean's sea ice concentration reached a record high in 2014 and has been decreasing ever since, dropping noticeably starting in 2016, until it reached a record low in February of 2023. The Southern Ocean accounts for around 40% of the world's oceans' global uptake of anthropogenic carbon dioxide, and phytoplankton activity (photosynthetic processes) in the Southern Ocean accounts for 50-75% of that 40% contribution (i.e. 20-30% of the world's oceans' carbon uptake). In the Southern Ocean, photosynthetic activity of phytoplankton depends on two main factors: light availability and nutrient availability, specifically the availability of iron. Since sea-ice cover can block light from reaching the phytoplankton, this project aims to analyze how seasonal variations and sea-ice interactions that affect light availability influence phytoplankton populations living underneath sea ice.
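
The 20-30% range quoted above is the product of the two shares. A quick check, using only the figures stated in this paragraph (the variable names are illustrative):

```python
# Share of the global ocean's anthropogenic CO2 uptake attributed to the Southern Ocean
southern_ocean_share = 0.40
# Share of that Southern Ocean uptake attributed to phytoplankton (low and high estimates)
phyto_share_low, phyto_share_high = 0.50, 0.75

# Phytoplankton contribution expressed as a share of the global ocean uptake
global_low = southern_ocean_share * phyto_share_low
global_high = southern_ocean_share * phyto_share_high

print(f"{global_low:.0%} to {global_high:.0%} of the global ocean's carbon uptake")
```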

### Relevant Papers

*Boyd et al. 2024* **"The role of biota in the Southern Ocean carbon cycle"**

Link: https://www.nature.com/articles/s43017-024-00531-3#Abs1

*Deppeler and Davidson 2017* **"Southern Ocean Phytoplankton in a Changing Climate"**

Link: https://www.frontiersin.org/journals/marine-science/articles/10.3389/fmars.2017.00040/full

*Purich and Doddridge 2023* **"Record low Antarctic sea ice coverage indicates a new sea ice state"**

Link: https://www.nature.com/articles/s43247-023-00961-9

*Meiners et al. 2016* **"Introduction: SIPEX-2: A study of sea-ice physical, biogeochemical and ecosystem processes off East Antarctica during spring 2012"**

Link: https://www.sciencedirect.com/science/article/abs/pii/S0967064516301710?via%3Dihub


## Code (work, figures, etc.)

### Imports

In [None]:
import numpy as np
import scipy
from scipy.stats import linregress
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import xarray as xr
import pandas as pd
import cartopy.crs as ccrs
from pyproj import Transformer
from collections import defaultdict
from calendar import month_name
from datetime import datetime
import matplotlib.gridspec as gridspec
import pickle

### Loading + Compiling Data

In [None]:
fp = 'DIRECTORY NAME' # change this to your working directory, ending with a trailing slash (later cells build paths as fp + 'Data/...')

list_profiles = xr.open_dataset(fp+'/Data/MEOP/global_list_profiles.nc')

list_profiles

We can make a few simple heatmaps to see where our data is being collected (by the seals, so essentially, where they are swimming).

In [None]:
# define grid size
grid_size = 1.0  # adjust for finer/coarser grids

# create 2D histogram of seal paths
hist, xedges, yedges = np.histogram2d(
    list_profiles['LONGITUDE'],
    list_profiles['LATITUDE'],
    bins=[np.arange(-180, 180 + grid_size, grid_size),
          np.arange(-90, -40 + grid_size, grid_size)]
)

# create figure + axis
fig = plt.figure(figsize=(10, 10))
ax = plt.subplot(1, 1, 1, projection=ccrs.SouthPolarStereo())

# plot histogram as a pcolormesh
c = ax.pcolormesh(xedges, yedges, hist.T, shading='auto', transform=ccrs.PlateCarree(), cmap='Reds', alpha=0.7)

# add coastlines + gridlines
ax.coastlines()
gl = ax.gridlines(draw_labels=True, linewidth=1, color='gray', alpha=0.5, linestyle='--')

# add colorbar
cbar = plt.colorbar(c, ax=ax, orientation='vertical', pad=0.01)
cbar.set_label('Number of Profiles (N)')  # each histogram count is one profile, not one seal

# set plot extent
ax.set_extent([-180, 180, -90, -40], ccrs.PlateCarree())

# display plot
plt.show()
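
To make the binning step concrete, here is a minimal sketch of how `np.histogram2d` counts positions per grid cell and why the result is transposed before plotting. The coordinates are made up for illustration and are not taken from the MEOP data.

```python
import numpy as np

# Hypothetical profile positions (illustrative values only)
lons = np.array([10.2, 10.7, 10.9, -170.5, 45.0])
lats = np.array([-65.1, -65.4, -65.8, -72.3, -50.5])

grid_size = 1.0
lon_edges = np.arange(-180, 180 + grid_size, grid_size)
lat_edges = np.arange(-90, -40 + grid_size, grid_size)

# The first axis of hist follows the first argument (longitude),
# the second axis follows the second argument (latitude)
hist, xedges, yedges = np.histogram2d(lons, lats, bins=[lon_edges, lat_edges])

print(hist.shape)      # (360, 50): 360 longitude bins x 50 latitude bins
print(hist[190, 24])   # 3.0: three profiles fall in lon [10, 11), lat [-66, -65)

# pcolormesh expects an array shaped (n_y, n_x), so the histogram is transposed
print(hist.T.shape)    # (50, 360)
```

Each count is one profile, so a cell visited repeatedly by a single seal can still show a high value.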

In [None]:
# same plot, but with a capped colorbar so that sparsely sampled cells are easier to see

# define grid size
grid_size = 1.0  # adjust for finer/coarser grids

# create 2D histogram of seal paths
hist, xedges, yedges = np.histogram2d(
    list_profiles['LONGITUDE'],
    list_profiles['LATITUDE'],
    bins=[np.arange(-180, 180 + grid_size, grid_size),
          np.arange(-90, -40 + grid_size, grid_size)]
)

# create figure + axis
fig = plt.figure(figsize=(10, 10))
ax = plt.subplot(1, 1, 1, projection=ccrs.SouthPolarStereo())

# set colorbar cap
colorbar_cap = 100  # set this to max display value

# plot histogram as a pcolormesh with capped colorbar
c = ax.pcolormesh(xedges, yedges, hist.T, shading='auto', transform=ccrs.PlateCarree(),
                  cmap='Reds', alpha=0.7, vmax=colorbar_cap)

# add coastlines + gridlines
ax.coastlines()
gl = ax.gridlines(draw_labels=True, linewidth=1, color='gray', alpha=0.5, linestyle='--')

# add colorbar
cbar = plt.colorbar(c, ax=ax, orientation='vertical', pad=0.01)
cbar.set_label('Number of Profiles (N)')  # each histogram count is one profile, not one seal

# set plot extent
ax.set_extent([-180, 180, -90, -40], ccrs.PlateCarree())

# display plot
plt.show()

In [None]:
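# same plot but with a coarser grid (3° cells)
# NOTE: with_CHLA is not defined in this notebook and must be created before running this cell
# (the name suggests the subset of profiles that include CHLA)

grid_size = 3.0
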
# create 2D histogram of seal paths
hist, xedges, yedges = np.histogram2d(
    with_CHLA['LONGITUDE'],
    with_CHLA['LATITUDE'],
    bins=[np.arange(-180, 180 + grid_size, grid_size),
          np.arange(-90, -40 + grid_size, grid_size)]
)

# create figure + axis
fig = plt.figure(figsize=(10, 10))
ax = plt.subplot(1, 1, 1, projection=ccrs.SouthPolarStereo())

# set colorbar cap
colorbar_cap = 100  # set this to max display value

# plot histogram as a pcolormesh with capped colorbar
c = ax.pcolormesh(xedges, yedges, hist.T, shading='auto', transform=ccrs.PlateCarree(),
                  cmap='Reds', alpha=0.7, vmax=colorbar_cap)

# add coastlines + gridlines
ax.coastlines()
gl = ax.gridlines(draw_labels=True, linewidth=1, color='gray', alpha=0.5, linestyle='--')

# add colorbar
cbar = plt.colorbar(c, ax=ax, orientation='vertical', pad=0.01)
cbar.set_label('Number of Profiles (N)')  # each histogram count is one profile, not one seal

# set plot extent
ax.set_extent([-180, 180, -90, -40], ccrs.PlateCarree())

# display plot
plt.show()

### Data Analysis + Visualization (plots)

#### Single Seal Plots

In [None]:
# chlorophyll vs time, with point color determined by SIC data (monthly SIC)

plt.figure(figsize=(12, 6))
sc = plt.scatter(time.values, chla, c=sic_monthly, cmap='viridis', s=30, edgecolor='none')
plt.colorbar(sc, label='Sea Ice Concentration (Monthly)')
plt.xlabel('Date')
plt.ylabel('Chlorophyll-a (mg/m³)')
plt.title(f'Chlorophyll-a over Time (Seal: {seal_id})\nColored by Monthly Sea Ice Concentration')
plt.grid(True)
plt.tight_layout()
plt.show()

In [None]:
# chlorophyll vs time, with point color determined by SIC data (daily SIC)

plt.figure(figsize=(12, 6))
sc = plt.scatter(time.values, chla, c=sic_daily, cmap='viridis', s=30, edgecolor='none')
plt.colorbar(sc, label='Sea Ice Concentration (Daily)')
plt.xlabel('Date')
plt.ylabel('Chlorophyll-a (mg/m³)')
plt.title(f'Chlorophyll-a over Time (Seal: {seal_id})\nColored by Daily Sea Ice Concentration')
plt.grid(True)
plt.tight_layout()
plt.show()

In [None]:
# chlorophyll vs time, highlighting profiles where monthly SIC > threshold (0.70)
threshold = 0.70
high_sic_mask = (sic_monthly > threshold)

plt.figure(figsize=(12, 6))
plt.scatter(time.values[~high_sic_mask], chla[~high_sic_mask], c='gray', s=20, label=f'SIC ≤ {threshold:.2f}')
plt.scatter(time.values[high_sic_mask], chla[high_sic_mask], c='red', s=35, label=f'SIC > {threshold:.2f}')
plt.xlabel('Date')
plt.ylabel('Chlorophyll-a (mg/m³)')
plt.title(f'Chlorophyll-a with High Sea Ice Concentration (> {threshold}) Highlighted')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

In [None]:
# map features (land, coastline, borders); only cartopy.crs is imported at the top of the notebook
import cartopy.feature as cfeature

# create figure
fig, ax = plt.subplots(figsize=(10, 8), subplot_kw={'projection': ccrs.SouthPolarStereo()})

# plot SIC with lat lon grid for SIC
# (transform tells cartopy the coordinates are in degrees, not in the map projection's units)
sic_plot = ax.pcolormesh(lon, lat, sic_daily_reshaped, shading='auto', cmap='Blues', alpha=0.5,
                         transform=ccrs.PlateCarree())

# SIC colorbar
cbar = plt.colorbar(sic_plot, ax=ax, orientation='vertical', pad=0.02)
cbar.set_label('Sea Ice Concentration (%)')

# seal path colored by CHLA
sc = ax.scatter(seal_longitude, seal_latitude, c=seal_chla, cmap='viridis', marker='o', edgecolor='k',
                transform=ccrs.PlateCarree())

# CHLA colorbar
cbar_chla = plt.colorbar(sc, ax=ax, orientation='vertical', pad=0.02)
cbar_chla.set_label('Chlorophyll (CHLA) Concentration')

# add map features
ax.add_feature(cfeature.LAND, zorder=1, edgecolor='black')
ax.add_feature(cfeature.COASTLINE, zorder=2)
ax.add_feature(cfeature.BORDERS, linestyle=':')

# set title
ax.set_title(f'Seal Path and Sea Ice Concentration for Seal ID: {seal_id}')

# display plot
plt.show()

#### Compiling Data for Multiple Seals (filtering by latitude)

In [None]:
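# determine if all (non-NaN) positions for a seal are south of 55°S
def is_south_of_55(latitudes):
    lats = np.asarray(latitudes, dtype=float)
    lats = lats[~np.isnan(lats)]  # ignore missing positions
    return lats.size > 0 and bool(np.all(lats <= -55))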

# filter seals that stay south of 55°S and record their dive timeframes

southern_seals = []
missing_seals = []

# ensure no duplicates of seal IDs
available_seals = np.unique([s.decode('utf-8') if isinstance(s, bytes) else s
                             for s in list_profiles['SMRU_PLATFORM_CODE'].values])

for seal_id in available_seals:
    try:
        seal_ds = load_single_seal(seal_id, load=False)
        lats = seal_ds['LATITUDE']
        times = seal_ds['time']
        if is_south_of_55(lats):
            start_time = str(pd.to_datetime(times.min().values).date())
            end_time = str(pd.to_datetime(times.max().values).date())
            southern_seals.append((seal_id, start_time, end_time))
    except FileNotFoundError:
        missing_seals.append(seal_id)
    except Exception:
        continue

# print seal IDs and timeframe of dives
print(f"Found {len(southern_seals)} seals that stayed entirely south of 55°S:\n")
for seal_id, start, end in southern_seals:
    print(f"• Seal ID: {seal_id} — Dive timeframe: {start} to {end}")
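
Whether a seal counts as staying south of 55°S depends on whether the test is applied to any position or to every position along its track. A minimal sketch on made-up latitude tracks (the arrays and names are illustrative, not data from this notebook):

```python
import numpy as np

# Hypothetical latitude tracks in degrees north (negative = south); not real seal data
tracks = {
    "stays_south": np.array([-62.0, -58.5, -55.0]),
    "crosses_north": np.array([-62.0, -54.0, -49.5]),
}

# "any": at least one position is at or south of 55°S
any_south = {name: bool(np.any(lats <= -55)) for name, lats in tracks.items()}
# "all": every position is at or south of 55°S
all_south = {name: bool(np.all(lats <= -55)) for name, lats in tracks.items()}

print(any_south)  # {'stays_south': True, 'crosses_north': True}
print(all_south)  # {'stays_south': True, 'crosses_north': False}
```

Only the "all" test excludes a seal that wanders north of 55°S for part of its deployment.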

In [None]:
# check from filtered seals to see which seals have a timeframe of < 1 yr
from datetime import datetime

short_duration_seals = []

for seal_id, start, end in southern_seals:
    start_dt = datetime.strptime(start, "%Y-%m-%d")
    end_dt = datetime.strptime(end, "%Y-%m-%d")
    duration_days = (end_dt - start_dt).days

    if duration_days < 365:
        short_duration_seals.append((seal_id, start, end, duration_days))

# print results
print(f"Seals with dive timeframes shorter than 1 year ({len(short_duration_seals)} found):\n")
for seal_id, start, end, duration in short_duration_seals:
    print(f"• {seal_id}: {start} to {end} — {duration} days")

In [None]:
# initialize dictionary with month names
monthly_bins = {
    month: {
        '55S_to_60S': set(),
        '60S_to_70S': set(),
        '70S_to_80S': set()
    } for month in list(month_name)[1:]
}

# loop through each seal
for seal_id, start, end, duration in chlorophyll_seals:
    try:
        seal_ds = load_single_seal(seal_id, load=False)
        lats = seal_ds['LATITUDE'].values
        times = pd.to_datetime(seal_ds['time'].values)

        # group latitudes for each month
        monthly_data = defaultdict(list)
        for lat, time in zip(lats, times):
            if np.isnan(lat):
                continue
            calendar_month = time.strftime('%B')
            monthly_data[calendar_month].append(lat)

        # compute latitude average for each monthly bin
        for calendar_month, lat_list in monthly_data.items():
            if not lat_list:
                continue
            avg_lat = np.nanmean(lat_list)

            if -60 <= avg_lat < -55:
                monthly_bins[calendar_month]['55S_to_60S'].add(seal_id)
            elif -70 <= avg_lat < -60:
                monthly_bins[calendar_month]['60S_to_70S'].add(seal_id)
            elif -80 <= avg_lat < -70:
                monthly_bins[calendar_month]['70S_to_80S'].add(seal_id)

    except Exception as e:
        print(f"Error with seal {seal_id}: {e}")
        continue

print("\nSeals grouped by calendar month and latitude band:\n")
for month in list(month_name)[1:]:  # jan to dec
    print(month.upper())  # header
    print()

    for band in ['55S_to_60S', '60S_to_70S', '70S_to_80S']:
        seals = sorted(monthly_bins[month][band])
        band_label = band.replace('_', ' ')
        print(f"{band_label}")
        print("__________")

        if seals:
            for seal_id in seals:
                print(seal_id)
        else:
            print("(none)")

        print()  

    print()  

In [None]:
# initialize dictionary grouped by latitude band
lat_band_bins = {
    '55S_to_60S': [],
    '60S_to_70S': [],
    '70S_to_80S': []
}

# process each filtered seal
for seal_id, start, end, duration in chlorophyll_seals:
    try:
        seal_ds = load_single_seal(seal_id, load=False)
        lats = seal_ds['LATITUDE'].values
        times = pd.to_datetime(seal_ds['time'].values)

        # pair latitudes with times + remove NaNs
        valid_coords = [(lat, time) for lat, time in zip(lats, times) if not np.isnan(lat)]

        if not valid_coords:
            continue

        # extract latitudes for averaging
        latitudes = [lat for lat, _ in valid_coords]
        avg_lat = np.nanmean(latitudes)

        # bin seals by average latitude
        if -60 <= avg_lat < -55:
            lat_band_bins['55S_to_60S'].append((seal_id, start, end))
        elif -70 <= avg_lat < -60:
            lat_band_bins['60S_to_70S'].append((seal_id, start, end))
        elif -80 <= avg_lat < -70:
            lat_band_bins['70S_to_80S'].append((seal_id, start, end))

    except Exception as e:
        print(f"Error with seal {seal_id}: {e}")
        continue

# display results
print("\nSeals grouped by latitude band (sorted by start date):\n")
for band in ['55S_to_60S', '60S_to_70S', '70S_to_80S']:
    print(band.replace('_', ' '))
    print("__________")

    seals = lat_band_bins[band]
    if seals:
        # Sort by start date
        seals_sorted = sorted(seals, key=lambda x: pd.to_datetime(x[1]))
        for seal_id, start, end in seals_sorted:
            print(f"{seal_id}: {start} → {end}")
    else:
        print("(none)")

    print()
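
The if/elif chains above assign each average latitude to a half-open band. As an alternative sketch (not the method used in this notebook), `np.digitize` can perform the same lookup from a single list of band edges. The `band_for` helper and the sample latitudes are hypothetical.

```python
import numpy as np

# Band edges in ascending order; each band is [lower, upper), matching the comparisons above
edges = np.array([-80.0, -70.0, -60.0, -55.0])
labels = ['70S_to_80S', '60S_to_70S', '55S_to_60S']

def band_for(avg_lat):
    """Return the band label for an average latitude, or None outside 80°S to 55°S."""
    idx = np.digitize(avg_lat, edges)  # 0 means below -80, len(edges) means at or above -55
    if idx == 0 or idx == len(edges):
        return None
    return labels[idx - 1]

for lat in [-57.2, -60.0, -75.3, -50.0]:
    print(lat, band_for(lat))
```

Keeping the edges in one array makes it harder for the band boundaries to drift apart between cells.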

#### Plots for Multiple Seals

In [None]:
print(len(chlorophyll_seals))
print(chlorophyll_seals)

seal_ids = [str(t[0]) for t in chlorophyll_seals]
print(seal_ids)

# sanity check
yearly_sic

In [None]:
path = fp + '/Data/MEOP/CHLA_subset/'
sic_fp = fp + 'Data/SIC/'
daily_sic_fp = fp + 'Data/SIC/Daily SIC/'

# load the cached SIC dictionaries once, rather than on every loop iteration
with open(fp + "Code/Processed Data/yearly_sic.pkl", 'rb') as file:
    yearly_sic = pickle.load(file)
with open(fp + "Code/Processed Data/seal_sic_dict.pkl", 'rb') as file:
    seal_sic_dict = pickle.load(file)

for seal_id in seal_ids:
    ds = xr.open_dataset(path + seal_id + '_all_prof.nc', decode_times=False)

    # JULD is in days since 1950-01-01: add the whole days, then the fractional day in seconds
    juld = ds.JULD.values
    time_array = (np.datetime64('1950-01-01', 'ns')
                  + juld.astype('timedelta64[D]')
                  + ((juld - np.floor(juld)) * 86400).astype('timedelta64[s]'))

    # calendar year(s) covered by this seal; this is an array and may hold more than one year
    year = np.unique(time_array.astype('datetime64[Y]'))

    # compute + cache the daily gridded SIC for these year(s) if missing
    if str(year) + '_daily' not in yearly_sic:
        yearly_sic[str(year) + '_daily'] = get_and_project_sic_daily(year)
        with open(fp + "Code/Processed Data/yearly_sic.pkl", 'wb') as file:
            pickle.dump(yearly_sic, file)

    # compute + cache the daily SIC along this seal's track if missing
    seal_id_daily = seal_id + '_daily'
    if seal_id_daily not in seal_sic_dict:
        seal_sic_dict[seal_id_daily] = get_seal_sic_daily(seal_id)
        with open(fp + "Code/Processed Data/seal_sic_dict.pkl", 'wb') as file:
            pickle.dump(seal_sic_dict, file)
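
The time conversion in the cell above treats `JULD` as fractional days since 1950-01-01. A minimal sketch of that conversion on made-up values (the `juld` array here is hypothetical, not read from a MEOP file):

```python
import numpy as np

# Hypothetical JULD values: fractional days since 1950-01-01
juld = np.array([0.0, 1.5, 25202.25])

reference = np.datetime64('1950-01-01', 'ns')
# Whole days and the remaining fraction of a day (converted to seconds)
whole_days = np.floor(juld).astype('int64').astype('timedelta64[D]')
day_fraction = np.round((juld - np.floor(juld)) * 86400).astype('int64').astype('timedelta64[s]')

times = reference + whole_days + day_fraction
print(times)

# Calendar years covered; np.unique returns an array, which can hold more than one year
years = np.unique(times.astype('datetime64[Y]'))
print([str(y) for y in years])  # ['1950', '2019']
```

Because `years` is an array, a deployment that crosses New Year yields several entries, which is worth keeping in mind when a year is used as a dictionary key.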

In [None]:
collected = []

# choosing a seal
for i in seal_ids:
  seal_ds = load_single_seal(i)
  with open(fp + f"Code/Processed Data/seal_sic_dict.pkl", 'rb') as file:
      seal_sic_dict = pickle.load(file)
  chla = seal_ds['CHLA'].mean(dim='N_LEVELS', skipna=True)  # Mean chl-a per profile
  time = seal_ds['time']

  collected.append({
          "seal_id": i,
          "chlorophyll": chla,
          "time": time
      })

In [None]:
for seal_data in collected:
    seal_id = seal_data['seal_id']
    time_array = seal_data['time']
    chla_array = seal_data['chlorophyll']

    sic_daily = seal_sic_dict[seal_id + '_daily']  # daily SIC is stored under the '<seal_id>_daily' key

    plt.figure(figsize=(12, 6))
    sc = plt.scatter(time_array, chla_array, c=sic_daily, cmap='viridis', s=30, edgecolor='none')
    plt.colorbar(sc, label='Sea Ice Concentration (Daily)')
    plt.xlabel('Date')
    plt.ylabel('Chlorophyll-a (mg/m³)')
    plt.title(f'Chlorophyll-a over Time for Seal ID {seal_id} Colored by Daily Sea Ice Concentration')
    plt.grid(True)
    plt.tight_layout()
    plt.show()

In [None]:
# reload existing SIC dictionary
with open(fp + f"Code/Processed Data/seal_sic_dict.pkl", 'rb') as file:
    seal_sic_dict = pickle.load(file)

# add missing seals' SIC data
for seal_info in chlorophyll_seals:
    seal_id = seal_info[0]
    seal_id_daily = seal_id + '_daily'

    # compute + store monthly SIC if missing
    if seal_id not in seal_sic_dict:
        try:
            seal_sic_dict[seal_id] = get_seal_sic(seal_id)
        except Exception as e:
            print(f"Failed monthly SIC for {seal_id}: {e}")

    # compute + store daily SIC if missing
    if seal_id_daily not in seal_sic_dict:
        try:
            seal_sic_dict[seal_id_daily] = get_seal_sic_daily(seal_id)
        except Exception as e:
            print(f"Failed daily SIC for {seal_id}: {e}")

# save updated dictionary
with open(fp + f"Code/Processed Data/seal_sic_dict.pkl", 'wb') as file:
    pickle.dump(seal_sic_dict, file)

In [None]:
results = []

with open(fp + f"Code/Processed Data/seal_sic_dict.pkl", 'rb') as file:
    seal_sic_dict = pickle.load(file)

for i in chlorophyll_seals:
    seal_id = i[0]
    seal_ds = load_single_seal(seal_id)

    # pull SIC data
    sic_monthly = seal_sic_dict.get(seal_id)
    sic_daily = seal_sic_dict.get(seal_id + '_daily')

    # compute mean CHLA across N_LEVELS
    chla = seal_ds['CHLA'].mean(dim='N_LEVELS', skipna=True)
    time = seal_ds['time']

    # append to results list
    results.append({
        'seal_id': seal_id,
        'chla': chla,
        'time': time,
        'sic_monthly': sic_monthly,
        'sic_daily': sic_daily,
        'latitude': seal_ds['LATITUDE'].values,
        'longitude': seal_ds['LONGITUDE'].values
    })

In [None]:
results
# load sic monthly + sic daily
# paper estimates of chlorophyll concentrations under sea ice (Arctic + Antarctic); Antarctic sea ice is best, but data may be limited (date, location within the area, ice thickness)
# blooms (sudden increase) vs. presence at all
# connection between chlorophyll and tons of carbon (measuring phytoplankton)

In [None]:
# convert results to DF
results_df = pd.DataFrame(results)

# choose a seal
selected_seal = results_df.iloc[0]  # CHANGE INDEX HERE
seal_id = selected_seal['seal_id']
seal_chla = selected_seal['chla']
seal_time = selected_seal['time']
seal_latitude = selected_seal['latitude']
seal_longitude = selected_seal['longitude']
sic_monthly = selected_seal['sic_monthly']

# create figure
fig, ax = plt.subplots(figsize=(10, 8), subplot_kw={'projection': ccrs.SouthPolarStereo()})

# plot SIC with lat lon grid for SIC (this assumes sic_monthly is a 2D lat x lon array)
# transform tells cartopy the coordinates are in degrees, not in the map projection's units
lon = np.linspace(-180, 180, sic_monthly.shape[1])
lat = np.linspace(-90, 0, sic_monthly.shape[0])
sic_plot = ax.pcolormesh(lon, lat, sic_monthly, shading='auto', cmap='Blues', alpha=0.5,
                         transform=ccrs.PlateCarree())

# sic colorbar
cbar = plt.colorbar(sic_plot, ax=ax, orientation='vertical', pad=0.02)
cbar.set_label('Sea Ice Concentration')

# seal path colored by CHLA
sc = ax.scatter(seal_longitude, seal_latitude, c=seal_chla, cmap='viridis', marker='o', edgecolor='k',
                transform=ccrs.PlateCarree())

# CHLA colorbar
cbar_chla = plt.colorbar(sc, ax=ax, orientation='vertical', pad=0.02)
cbar_chla.set_label('Chlorophyll (CHLA) Concentration')

# add map features
ax.coastlines()
ax.set_title(f'Seal Path and Sea Ice Concentration for Seal ID: {seal_id}')

plt.show()

In [None]:
# seal path + depth/water column plots (temperature, salinity, CHLA) + SIC plots
from matplotlib.colors import LogNorm  # log color scaling for the chlorophyll panels

def basic_seal_plots_sic(seal_id, constrain = 0, seaice_option = 'simple'):
  seal = load_single_seal(seal_id)
  seal_sic_monthly = seal_sic_dict[seal_id]
  seal_sic_daily = seal_sic_dict[seal_id + '_daily']

  ## figure showing the path of this seal, colored by month
  fig, ax = plt.subplots(1,1,figsize=(10, 10),subplot_kw={'projection': ccrs.SouthPolarStereo()})
  ax.coastlines()
  gl = ax.gridlines(draw_labels=True, linewidth=1, color='gray', alpha=0.5, linestyle='--')
  ax.set_extent([-180, 180, -90, -40], ccrs.PlateCarree())
  c = ax.scatter(seal['LONGITUDE'],seal['LATITUDE'], s=5, c = seal['time'].dt.month,cmap = 'hsv',transform=ccrs.PlateCarree())
  cbar = plt.colorbar(c, shrink = 0.5)
  plt.show()
  #
  pressure = seal.PRES_ADJUSTED
  temp     = seal.TEMP_ADJUSTED
  salinity = seal.PSAL_ADJUSTED
  chla     = seal.CHLA_ADJUSTED
  time     = seal.time
  # SIC along the seal's track; missing values are treated as 0 (no ice)
  seal_sic_monthly = xr.DataArray(seal_sic_monthly, dims = ['time'], coords = [time]).fillna(0)
  seal_sic_daily = xr.DataArray(seal_sic_daily, dims = ['time'], coords = [time]).fillna(0)
  _,time_grid = np.meshgrid(seal.N_LEVELS,time) # for plotting purposes
  # t,s,chla plots
  fig, axs = plt.subplots(4,1,figsize=(25,14), gridspec_kw={'hspace':0.1})
  s = 8
  # color limits: currently the full data range; tighten these if warm waters at the start dominate the colormap
  vmin = temp.min(); vmax = temp.max()
  c = axs[0].scatter(time_grid,pressure, c = temp, s=s, vmin=vmin, vmax=vmax, cmap = 'plasma')
  cax = fig.add_axes([0.07, 0.7, 0.01, 0.18]); cbar=plt.colorbar(c,cax = cax,orientation='vertical', extend='max'); cbar.set_label(r'[$^\circ$C]', fontsize = 12)
  axs[0].invert_yaxis()
  axs[0].set_xlim(time[0], time[-1])
  vmin = salinity.min(); vmax = salinity.max()
  c = axs[1].scatter(time_grid,pressure, c = salinity, s=s, vmin=vmin, vmax=vmax, cmap = 'viridis')
  cax = fig.add_axes([0.07, 0.5, 0.01, 0.18]); cbar=plt.colorbar(c,cax = cax,orientation='vertical', extend='max'); cbar.set_label('[psu]', fontsize = 12)
  axs[1].invert_yaxis()
  axs[1].set_xlim(time[0], time[-1])
  # Looking at other papers, I noticed chlorophyll fluorescence is usually shown logarithmically
  c = axs[2].scatter(time_grid,pressure, c = chla, s=s, norm=LogNorm(vmin=0.1, vmax=2), cmap = 'YlGn')
  cax = fig.add_axes([0.07, 0.3, 0.01, 0.18]); cbar=plt.colorbar(c,cax = cax,orientation='vertical', extend='max'); cbar.set_label('[mg m$^{-3}$]', fontsize = 12)
  axs[2].invert_yaxis()
  axs[2].set_xlim(time[0], time[-1])
  if constrain > 0:
    # constrain y axis so they all match, only have chla for upper 175 m
    axs[0].set_ylim([constrain,0])
    axs[1].set_ylim([constrain,0])
    axs[2].set_ylim([constrain,0])
  if seaice_option == 'simple': # simple version, one color
    axs[3].plot(time,seal_sic_monthly, 'bo', markersize=3, label='monthly SIC', alpha=0.25)
    # axs[3].plot(time,seal_sic_daily, 'ro', markersize=3, label='daily SIC')
  elif seaice_option == 'mean_chla':
    axs[3].scatter(time,seal_sic_daily, c = chla.mean('N_LEVELS'), s = s,norm=LogNorm(vmin=0.1, vmax=1), cmap = 'YlGn', label='daily SIC, colored by mean Chla') # colored according to mean Chla
  axs[3].set_ylabel('sea-ice concentration')
  axs[3].set_ylim([0,1])
  axs[3].set_xlim(time[0], time[-1])
  axs[3].legend()


  plt.show()

In [None]:
def get_chla_timeseries(seal_id):
  seal = load_single_seal(seal_id)
  pressure = seal.PRES_ADJUSTED
  chla     = seal.CHLA_ADJUSTED
  time     = seal.time

  chla_ts = np.nansum(chla, axis=1)
  return chla_ts, time
    
# pick 'under sea ice' threshold (generally 0.7, but you can change to preference)
sic_threshold = 0.7

# per-profile daily SIC for the currently selected seal
seal_sic = np.asarray(seal_sic_dict[seal_id + '_daily'], dtype=float)

# set SICs to booleans (True/False) based on whether SIC >= sic_threshold; note that (np.nan >= sic_threshold) returns False, which is what we want
seal_sic_bools = (seal_sic >= sic_threshold)
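
A short sketch of how the threshold comparison treats boundary values and missing data. The `example_sic` array is made up for illustration.

```python
import numpy as np

sic_threshold = 0.7
# Hypothetical per-profile SIC values; NaN marks a profile with no SIC match
example_sic = np.array([0.0, 0.69, 0.7, 0.95, np.nan])

# Comparisons with NaN evaluate to False, so unmatched profiles are labeled "not under ice"
under_ice = example_sic >= sic_threshold
print(under_ice)  # [False False  True  True False]

# To keep missing SIC separate from open water, track it with its own mask
missing = np.isnan(example_sic)
print(missing)    # [False False False False  True]
```

If missing SIC should not be plotted as open water, the `missing` mask can be used to drop those profiles first.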

In [None]:
# plot total CHLA, summed over entire water column

seal_chla_ts, seal_time = get_chla_timeseries(seal_id)

colors = {True: "lightskyblue", False: "orange"}

fig, ax = plt.subplots(figsize=(8, 5))
ax.scatter(seal_time, seal_chla_ts, c=[colors[b] for b in seal_sic_bools], s=5)
ax.legend(handles=[
    mpatches.Patch(color=colors[True], label=f"SIC >= {sic_threshold}"),
    mpatches.Patch(color=colors[False], label=f"SIC < {sic_threshold}")
])
ax.set_ylabel("Total Chlorophyll-a [mg m$^{-3}$]")
ax.set_title(seal_id + ', Daily SIC')
fig.show()

save_filename = fp + f"JupyterNotebooks/Figures/total_chla_{seal_id}_dailysic.png"
fig.savefig(save_filename, bbox_inches='tight')

#### Regression Plots with Line of Best Fit

In [None]:
# regression plots with line of best fit

# ensure lengths + NaNs are aligned properly
valid_mask = ~np.isnan(seal_chla_ts) & ~np.isnan(seal_sic)
sic_values = seal_sic[valid_mask]
chla_values = seal_chla_ts[valid_mask]

# scatter plot
fig, ax = plt.subplots(figsize=(8, 5))
ax.scatter(sic_values, chla_values, s=10, alpha=0.6, label='Data Points')

# line of best fit
slope, intercept, r_value, p_value, std_err = linregress(sic_values, chla_values)
x_vals = np.linspace(0, 1, 100)
y_vals = slope * x_vals + intercept
ax.plot(x_vals, y_vals, color='red', label=f'Best Fit Line\n$R^2$={r_value**2:.2f}, p={p_value:.2g}')

# labels + title
ax.set_xlabel("Sea-Ice Concentration (SIC)")
ax.set_ylabel("Total Chlorophyll-a [mg m$^{-3}$]")
ax.set_title(f'{seal_id}: CHLA vs SIC')
ax.legend()

# display plot
plt.show()
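
As a sanity check on the regression step, the sketch below runs `linregress` on synthetic data with a known slope and intercept, after masking NaNs the same way as above. The data, the random seed, and the variable names are assumptions for illustration only.

```python
import numpy as np
from scipy.stats import linregress

rng = np.random.default_rng(0)

# Synthetic SIC in [0, 1] and a made-up linear CHLA response with noise
sic = rng.uniform(0, 1, 200)
chla = 2.0 - 1.5 * sic + rng.normal(0, 0.1, 200)
sic[::20] = np.nan  # simulate profiles with no SIC match

# Keep only pairs where both values are finite; linregress does not skip NaNs itself
valid = np.isfinite(sic) & np.isfinite(chla)
fit = linregress(sic[valid], chla[valid])
r_squared = fit.rvalue ** 2

print(f"slope={fit.slope:.2f}, intercept={fit.intercept:.2f}, R^2={r_squared:.2f}, p={fit.pvalue:.2g}")
```

The recovered slope and intercept should land close to the values used to generate the data (-1.5 and 2.0), which confirms the masking and fitting steps behave as expected before applying them to the seal data.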

In [None]:
all_chla = []
all_sic = []

seal_ids = [k for k in seal_sic_dict.keys() if not k.endswith('_daily')]

for seal_id in seal_ids:
    try:
        # load seal data
        seal = load_single_seal(seal_id)
        chla = seal.CHLA_ADJUSTED.values  # NumPy array, shape: (n_profiles, n_levels)
        time = pd.to_datetime(seal.time.values)

        # load SIC 
        sic_daily = seal_sic_dict.get(seal_id + '_daily')
        if sic_daily is None:
            continue

        sic_series = pd.Series(sic_daily, index=time)

        # loop over each profile (time step)
        for i in range(chla.shape[0]):
            chla_profile = chla[i, :]  # CHLA at all depths for this time
            sic_value = sic_series.iloc[i]

            if not np.isfinite(sic_value):
                continue  # skip if SIC is missing

            # mask CHLA profile to keep only finite values
            valid_mask = np.isfinite(chla_profile)
            valid_chla = chla_profile[valid_mask]

            # append SIC value once for each valid CHLA value
            all_sic.extend([sic_value] * len(valid_chla))
            all_chla.extend(valid_chla)

    except Exception as e:
        print(f"Skipping {seal_id} due to error: {e}")
        continue

# convert to numpy arrays
all_sic = np.array(all_sic)
all_chla = np.array(all_chla)

# plot
fig, ax = plt.subplots(figsize=(8, 5))
ax.scatter(all_sic, all_chla, s=3, alpha=0.3, label='All CHLA points')

# line of best fit
if len(all_sic) > 2:
    slope, intercept, r_value, p_value, std_err = linregress(all_sic, all_chla)
    x_vals = np.linspace(0, 1, 100)
    y_vals = slope * x_vals + intercept
    ax.plot(x_vals, y_vals, color='red', label=f'Fit: $R^2$={r_value**2:.2f}, p={p_value:.2g}')

# labels
ax.set_xlabel("Sea-Ice Concentration (SIC)")
ax.set_ylabel("Chlorophyll-a [mg m$^{-3}$]")
ax.set_title("All CHLA Points vs Daily SIC (All Seals, All Depths)")
ax.legend()

# display plot
plt.show()

In [None]:
all_chla = []
all_sic = []

seal_ids = [k for k in seal_sic_dict.keys() if not k.endswith('_daily')]

for seal_id in seal_ids:
    try:
        # load seal data
        seal = load_single_seal(seal_id)
        chla = seal.CHLA_ADJUSTED.values  # NumPy array, shape: (n_profiles, n_levels)
        time = pd.to_datetime(seal.time.values)

        # load monthly SIC
        sic_monthly = seal_sic_dict.get(seal_id)
        if sic_monthly is None:
            continue

        # create pandas series for monthly SIC with the same time index as seal data
        sic_series = pd.Series(sic_monthly, index=time)

        # loop over each profile (time step)
        for i in range(chla.shape[0]):
            chla_profile = chla[i, :]  # CHLA at all depths for this time
            sic_value = sic_series.iloc[i]  # get SIC value for this time step

            if not np.isfinite(sic_value):
                continue  # skip if SIC is missing

            # mask CHLA profile to keep only finite values
            valid_mask = np.isfinite(chla_profile)
            valid_chla = chla_profile[valid_mask]

            # append SIC value once for each valid CHLA value
            all_sic.extend([sic_value] * len(valid_chla))
            all_chla.extend(valid_chla)

    except Exception as e:
        print(f"Skipping {seal_id} due to error: {e}")
        continue

# convert to numpy arrays
all_sic = np.array(all_sic)
all_chla = np.array(all_chla)

# plot
fig, ax = plt.subplots(figsize=(8, 5))
ax.scatter(all_sic, all_chla, s=3, alpha=0.3, label='All CHLA points')

# line of best fit
if len(all_sic) > 2:
    slope, intercept, r_value, p_value, std_err = linregress(all_sic, all_chla)
    x_vals = np.linspace(0, 1, 100)
    y_vals = slope * x_vals + intercept
    ax.plot(x_vals, y_vals, color='red', label=f'Fit: $R^2$={r_value**2:.2f}, p={p_value:.2g}')

# labels
ax.set_xlabel("Sea-Ice Concentration (SIC)")
ax.set_ylabel("Chlorophyll-a [mg m$^{-3}$]")
ax.set_title("All CHLA Points vs Monthly SIC (All Seals, All Depths)")
ax.legend()

# display plot
plt.show()

In [None]:
# clear previous values
all_chla_ts = []
all_sic_vals = []

# threshold for SIC coloring
sic_threshold = 0.7
colors = {True: "lightskyblue", False: "orange"}

# go through all seals (non-daily keys)
seal_ids = [k for k in seal_sic_dict.keys() if not k.endswith('_daily')]

for seal_id in seal_ids:
    try:
        # load CHLA time series + SIC for the seal
        seal_chla_ts, seal_time = get_chla_timeseries(seal_id)
        seal_sic = seal_sic_dict.get(seal_id + '_daily')
        if seal_sic is None:
            continue

        # convert to arrays
        chla = np.array(seal_chla_ts)
        sic = np.array(seal_sic)

        # remove NaNs
        valid_mask = np.isfinite(chla) & np.isfinite(sic)
        all_chla_ts.extend(chla[valid_mask])
        all_sic_vals.extend(sic[valid_mask])

    except Exception as e:
        print(f"Skipping {seal_id} due to error: {e}")
        continue

# convert to numpy arrays
all_chla_ts = np.array(all_chla_ts)
all_sic_vals = np.array(all_sic_vals)
sic_bools = all_sic_vals >= sic_threshold
point_colors = [colors[b] for b in sic_bools]

# line of best fit
slope, intercept, r_value, p_value, std_err = linregress(all_sic_vals, all_chla_ts)
x_vals = np.linspace(0, 1, 100)
y_vals = slope * x_vals + intercept

# plot: CHLA vs SIC for all seals
fig, ax = plt.subplots(figsize=(8, 5))
ax.scatter(all_sic_vals, all_chla_ts, c=point_colors, s=5, alpha=0.6)

# best fit line
ax.plot(x_vals, y_vals, color='red', label=f'Best Fit Line\n$R^2$={r_value**2:.2f}, p={p_value:.2g}')

# legend
ax.legend(handles=[
    mpatches.Patch(color=colors[True], label=f"SIC ≥ {sic_threshold}"),
    mpatches.Patch(color=colors[False], label=f"SIC < {sic_threshold}"),
    ax.lines[-1]  # add the fit line to the legend
])

# labels + title
ax.set_xlabel("Sea-Ice Concentration (SIC)")
ax.set_ylabel("Total Chlorophyll-a [mg m$^{-3}$]")
ax.set_title("All Seals: CHLA vs Daily SIC (with Best Fit Line)")
ax.grid(True)

# display plot
plt.show()

In [None]:
# same plot, but separated by region

# sector definitions (longitudes in degrees)
regions = {
    'ewedd': [-40, 15],
    'wwedd': [-62, -40],
    'ab': [-140, -62],
    'ross': [160, -140], 
    'pac': [90, 160],
    'indian': [15, 90]
}
titles = {
    'ewedd':'East Weddell', 'wwedd': 'West Weddell', 'ab': 'Amundsen & Bellingshausen',
    'ross': 'Ross', 'pac': 'Pacific', 'indian': 'Indian'
}

def get_region_mask_for_profiles(lons, regions):
    region_profile_masks = {}
    for region, (lon_min, lon_max) in regions.items():
        if region != 'ross':
            mask = (lons > lon_min) & (lons <= lon_max)
        else:
            # ross wraps around 180/-180 longitude: east of 160°E or west of 140°W
            mask = (lons > lon_min) | (lons <= lon_max)
        region_profile_masks[region] = mask
    return region_profile_masks

all_region_data = {region: {'sic': [], 'chla': []} for region in regions}

seal_ids = [k for k in seal_sic_dict.keys() if not k.endswith('_daily')]

for seal_id in seal_ids:
    try:
        seal = load_single_seal(seal_id)
        chla = seal.CHLA_ADJUSTED.values  # NumPy array, shape (n_profiles, n_levels)
        time = pd.to_datetime(seal.time.values)
        lons = seal['LONGITUDE'].values  # shape (n_profiles,)

        sic_daily = seal_sic_dict.get(seal_id + '_daily')
        if sic_daily is None:
            continue
        # align sic_daily with profiles by time (assumed same length + order)
        sic_series = pd.Series(sic_daily, index=time)

        # get masks for each sector
        region_masks = get_region_mask_for_profiles(lons, regions)

        for i in range(chla.shape[0]):
            sic_value = sic_series.iloc[i]
            if not np.isfinite(sic_value):
                continue
            chla_profile = chla[i, :]
            valid_mask = np.isfinite(chla_profile)
            valid_chla = chla_profile[valid_mask]

            for region, mask in region_masks.items():
                if mask[i]:
                    # Append SIC once per each valid CHLA point at this profile in this region
                    all_region_data[region]['sic'].extend([sic_value] * len(valid_chla))
                    all_region_data[region]['chla'].extend(valid_chla)

    except Exception as e:
        print(f"Skipping {seal_id} due to error: {e}")
        continue

# plotting
fig, axs = plt.subplots(len(regions), 1, figsize=(8, 5*len(regions)), constrained_layout=True)

for ax, (region, data) in zip(axs, all_region_data.items()):
    sic = np.array(data['sic'])
    chla = np.array(data['chla'])

    ax.scatter(sic, chla, s=5, alpha=0.3, label='Data points')
    if len(sic) > 2:
        slope, intercept, r_value, p_value, std_err = linregress(sic, chla)
        x_vals = np.linspace(0, 1, 100)
        y_vals = slope * x_vals + intercept
        ax.plot(x_vals, y_vals, color='red', label=f'Fit $R^2$={r_value**2:.2f}, p={p_value:.2g}')

    ax.set_xlabel('Daily Sea-Ice Concentration (SIC)')
    ax.set_ylabel('Chlorophyll-a [mg m$^{-3}$]')
    ax.set_title(titles.get(region, region))
    ax.legend()

# display plot
plt.show()

#### Time Trend Plots

In the Southern Ocean, sea ice has been declining since its record high in 2014. The decline became noticeable after 2016 and continued until sea ice reached a record low in February 2023. To show how the CHLA vs SIC relationship differs between these two periods, I created two plots of CHLA vs SIC: one with all data points collected before 2016, and another with all data points collected in 2016 and onwards.
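The pre-/post-2016 split described above can be sketched as a small standalone helper. This is only an illustration under stated assumptions: `split_and_fit` is a hypothetical name that does not appear in the notebook, and the input arrays are synthetic values rather than seal data. It uses the same `linregress` call as the cells below.

```python
import numpy as np
from scipy.stats import linregress

def split_and_fit(times, sic, chla, cutoff="2016-01-01"):
    """Split paired SIC/CHLA samples at `cutoff` and fit a line to each half.

    Hypothetical helper for illustration; not part of the notebook.
    """
    times = np.asarray(times, dtype="datetime64[ns]")
    sic = np.asarray(sic, dtype=float)
    chla = np.asarray(chla, dtype=float)

    # keep only samples where both SIC and CHLA are finite
    valid = np.isfinite(sic) & np.isfinite(chla)
    before = valid & (times < np.datetime64(cutoff))
    after = valid & (times >= np.datetime64(cutoff))

    fits = {}
    for name, mask in (("pre", before), ("post", after)):
        # require more than two points and some spread in SIC before fitting
        if mask.sum() > 2 and np.ptp(sic[mask]) > 0:
            fits[name] = linregress(sic[mask], chla[mask])
        else:
            fits[name] = None
    return fits

# synthetic example data (made-up values, not from the seal dataset)
times = np.array(["2014-06-01", "2015-01-15", "2015-09-30", "2015-12-31",
                  "2016-01-01", "2017-03-10", "2018-08-20", "2019-11-05"],
                 dtype="datetime64[ns]")
sic = np.array([0.1, 0.5, 0.9, 0.3, 0.2, 0.6, 0.8, 0.4])
chla = 1.0 - 0.5 * sic  # exactly linear, so both fits recover the same line

fits = split_and_fit(times, sic, chla)
for period, fit in fits.items():
    print(period, fit.slope, fit.intercept)
```

Returning `None` for a period with too few points mirrors the `len(...) > 2` guard used before each `linregress` call in the cells below.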

In [None]:
all_chla_pre2016 = []
all_sic_pre2016 = []
all_chla_post2016 = []
all_sic_post2016 = []

seal_ids = [k for k in seal_sic_dict.keys() if not k.endswith('_daily')]

for seal_id in seal_ids:
    try:
        # load seal data
        seal = load_single_seal(seal_id)
        chla = seal.CHLA_ADJUSTED.values  # NumPy array, shape: (n_profiles, n_levels)
        time = pd.to_datetime(seal.time.values)

        # load SIC for seal (monthly SIC instead of daily)
        sic_monthly = seal_sic_dict.get(seal_id)  
        if sic_monthly is None:
            continue

        # create pandas series for monthly SIC with the same time index as seal data
        sic_series = pd.Series(sic_monthly, index=time)

        # loop over each profile (time step)
        for i in range(chla.shape[0]):
            chla_profile = chla[i, :]  # CHLA at all depths for this time
            sic_value = sic_series.iloc[i]  # get SIC value for this time step

            if not np.isfinite(sic_value):
                continue  # skip if SIC is missing

            # mask CHLA profile to keep only finite values
            valid_mask = np.isfinite(chla_profile)
            valid_chla = chla_profile[valid_mask]

            # separate by year
            if time[i].year < 2016:
                # Pre-2016 data
                all_sic_pre2016.extend([sic_value] * len(valid_chla))
                all_chla_pre2016.extend(valid_chla)
            else:
                # 2016 and after data
                all_sic_post2016.extend([sic_value] * len(valid_chla))
                all_chla_post2016.extend(valid_chla)

    except Exception as e:
        print(f"Skipping {seal_id} due to error: {e}")
        continue

# convert to numpy arrays
all_sic_pre2016 = np.array(all_sic_pre2016)
all_chla_pre2016 = np.array(all_chla_pre2016)
all_sic_post2016 = np.array(all_sic_post2016)
all_chla_post2016 = np.array(all_chla_post2016)

# create figure with two subplots
fig, axs = plt.subplots(1, 2, figsize=(15, 6))

# plot for Pre-2016
axs[0].scatter(all_sic_pre2016, all_chla_pre2016, s=3, alpha=0.3, color='blue', label='Pre-2016 CHLA points')
if len(all_sic_pre2016) > 2:
    slope, intercept, r_value, p_value, std_err = linregress(all_sic_pre2016, all_chla_pre2016)
    x_vals = np.linspace(0, 1, 100)
    y_vals = slope * x_vals + intercept
    axs[0].plot(x_vals, y_vals, color='blue', label=f'Fit: $R^2$={r_value**2:.2f}, p={p_value:.2g}')
axs[0].set_xlabel("Sea-Ice Concentration (SIC)")
axs[0].set_ylabel("Chlorophyll-a [mg m$^{-3}$]")
axs[0].set_title("Pre-2016: CHLA vs SIC")
axs[0].legend()

# plot for 2016 and After
axs[1].scatter(all_sic_post2016, all_chla_post2016, s=3, alpha=0.3, color='green', label='2016 and after CHLA points')
if len(all_sic_post2016) > 2:
    slope, intercept, r_value, p_value, std_err = linregress(all_sic_post2016, all_chla_post2016)
    x_vals = np.linspace(0, 1, 100)
    y_vals = slope * x_vals + intercept
    axs[1].plot(x_vals, y_vals, color='green', label=f'Fit: $R^2$={r_value**2:.2f}, p={p_value:.2g}')
axs[1].set_xlabel("Sea-Ice Concentration (SIC)")
axs[1].set_ylabel("Chlorophyll-a [mg m$^{-3}$]")
axs[1].set_title("2016 and After: CHLA vs SIC")
axs[1].legend()

# display plot
plt.tight_layout()
plt.show()

In [None]:
all_chla_pre2016 = []
all_sic_pre2016 = []
all_chla_post2016 = []
all_sic_post2016 = []

seal_ids = [k for k in seal_sic_dict.keys() if not k.endswith('_daily')]

for seal_id in seal_ids:
    try:
        # load seal data
        seal = load_single_seal(seal_id)
        chla = seal.CHLA_ADJUSTED.values  # NumPy array, shape: (n_profiles, n_levels)
        time = pd.to_datetime(seal.time.values)

        # load SIC for this seal (monthly SIC instead of daily)
        sic_monthly = seal_sic_dict.get(seal_id)
        if sic_monthly is None:
            continue

        # create pandas series for monthly SIC with the same time index as seal data
        sic_series = pd.Series(sic_monthly, index=time)

        # loop over each profile (time step)
        for i in range(chla.shape[0]):
            chla_profile = chla[i, :]
            sic_value = sic_series.iloc[i]

            if not np.isfinite(sic_value):
                continue

            valid_mask = np.isfinite(chla_profile)
            valid_chla = chla_profile[valid_mask]

            # separate by year
            if time[i].year < 2016:
                all_sic_pre2016.extend([sic_value] * len(valid_chla))
                all_chla_pre2016.extend(valid_chla)
            else:
                all_sic_post2016.extend([sic_value] * len(valid_chla))
                all_chla_post2016.extend(valid_chla)

    except Exception as e:
        print(f"Skipping {seal_id} due to error: {e}")
        continue

# convert to numpy arrays
all_sic_pre2016 = np.array(all_sic_pre2016)
all_chla_pre2016 = np.array(all_chla_pre2016)
all_sic_post2016 = np.array(all_sic_post2016)
all_chla_post2016 = np.array(all_chla_post2016)

# create figure with two subplots
fig, axs = plt.subplots(1, 2, figsize=(15, 6))

# plot for Pre-2016
axs[0].scatter(all_sic_pre2016, all_chla_pre2016, s=3, alpha=0.3, color='blue', label='Pre-2016 CHLA points')
if len(all_sic_pre2016) > 2:
    slope, intercept, r_value, p_value, std_err = linregress(all_sic_pre2016, all_chla_pre2016)
    x_vals = np.linspace(0, 2.5, 100)
    y_vals = slope * x_vals + intercept
    axs[0].plot(x_vals, y_vals, color='blue', label=f'Fit: $R^2$={r_value**2:.2f}, p={p_value:.2g}')
axs[0].set_xlabel("Sea-Ice Concentration (SIC)")
axs[0].set_ylabel("Chlorophyll-a [mg m$^{-3}$]")
axs[0].set_title("Pre-2016: CHLA vs SIC")
axs[0].legend()

# plot for 2016 and After
axs[1].scatter(all_sic_post2016, all_chla_post2016, s=3, alpha=0.3, color='green', label='2016 and after CHLA points')
if len(all_sic_post2016) > 2:
    slope, intercept, r_value, p_value, std_err = linregress(all_sic_post2016, all_chla_post2016)
    x_vals = np.linspace(0, 2.5, 100)
    y_vals = slope * x_vals + intercept
    axs[1].plot(x_vals, y_vals, color='green', label=f'Fit: $R^2$={r_value**2:.2f}, p={p_value:.2g}')
axs[1].set_xlabel("Sea-Ice Concentration (SIC)")
axs[1].set_ylabel("Chlorophyll-a [mg m$^{-3}$]")
axs[1].set_title("2016 and After: CHLA vs SIC")
axs[1].legend()

# give both subplots the same limits (SIC from 0 to 1, CHLA from 0 to 5)
for ax in axs:
    ax.set_xlim(0, 1.0)
    ax.set_ylim(0, 5.0)

# display plot
plt.tight_layout()
plt.show()

In [None]:
all_chla_pre2016 = []
all_sic_pre2016 = []
all_chla_post2016 = []
all_sic_post2016 = []

seal_ids = [k for k in seal_sic_dict.keys() if not k.endswith('_daily')]

for seal_id in seal_ids:
    try:
        # load seal data
        seal = load_single_seal(seal_id)
        chla = seal.CHLA_ADJUSTED.values  # NumPy array, shape: (n_profiles, n_levels)
        time = pd.to_datetime(seal.time.values)

        # load SIC for this seal (monthly SIC instead of daily)
        sic_monthly = seal_sic_dict.get(seal_id) 
        if sic_monthly is None:
            continue

        # Create a pandas series for monthly SIC with the same time index as seal data
        sic_series = pd.Series(sic_monthly, index=time)

        # loop over each profile (time step)
        for i in range(chla.shape[0]):
            chla_profile = chla[i, :]  # CHLA at all depths for this time
            sic_value = sic_series.iloc[i]  # get SIC value for this time step

            if not np.isfinite(sic_value):
                continue  # skip if SIC is missing

            # mask CHLA profile to keep only finite values
            valid_mask = np.isfinite(chla_profile)
            valid_chla = chla_profile[valid_mask]

            # separate by year
            if time[i].year < 2016:
                # Pre-2016 data
                all_sic_pre2016.extend([sic_value] * len(valid_chla))
                all_chla_pre2016.extend(valid_chla)
            else:
                # 2016 and after data
                all_sic_post2016.extend([sic_value] * len(valid_chla))
                all_chla_post2016.extend(valid_chla)

    except Exception as e:
        print(f"Skipping {seal_id} due to error: {e}")
        continue

# convert to numpy arrays
all_sic_pre2016 = np.array(all_sic_pre2016)
all_chla_pre2016 = np.array(all_chla_pre2016)
all_sic_post2016 = np.array(all_sic_post2016)
all_chla_post2016 = np.array(all_chla_post2016)

# get y-axis limits (same for both plots)
y_min = min(np.min(all_chla_pre2016), np.min(all_chla_post2016))
y_max = max(np.max(all_chla_pre2016), np.max(all_chla_post2016))

# create figure with two subplots
fig, axs = plt.subplots(1, 2, figsize=(15, 6))

# plot for Pre-2016
axs[0].scatter(all_sic_pre2016, all_chla_pre2016, s=3, alpha=0.3, color='blue', label='Pre-2016 CHLA points')
if len(all_sic_pre2016) > 2:
    slope, intercept, r_value, p_value, std_err = linregress(all_sic_pre2016, all_chla_pre2016)
    x_vals = np.linspace(0, 1, 100)
    y_vals = slope * x_vals + intercept
    axs[0].plot(x_vals, y_vals, color='blue', label=f'Fit: $R^2$={r_value**2:.2f}, p={p_value:.2g}')
axs[0].set_xlabel("Sea-Ice Concentration (SIC)")
axs[0].set_ylabel("Chlorophyll-a [mg m$^{-3}$]")
axs[0].set_title("Pre-2016: CHLA vs SIC")
axs[0].set_ylim([y_min, y_max])  # Set the same y-axis limits for both plots
axs[0].legend()

# plot for 2016 and After
axs[1].scatter(all_sic_post2016, all_chla_post2016, s=3, alpha=0.3, color='green', label='2016 and after CHLA points')
if len(all_sic_post2016) > 2:
    slope, intercept, r_value, p_value, std_err = linregress(all_sic_post2016, all_chla_post2016)
    x_vals = np.linspace(0, 1, 100)
    y_vals = slope * x_vals + intercept
    axs[1].plot(x_vals, y_vals, color='green', label=f'Fit: $R^2$={r_value**2:.2f}, p={p_value:.2g}')
axs[1].set_xlabel("Sea-Ice Concentration (SIC)")
axs[1].set_ylabel("Chlorophyll-a [mg m$^{-3}$]")
axs[1].set_title("2016 and After: CHLA vs SIC")
axs[1].set_ylim([y_min, y_max])  # Set the same y-axis limits for both plots
axs[1].legend()

# display plot
plt.tight_layout()
plt.show()

In [None]:
sic_threshold = 0.7
colors_map = {True: "lightskyblue", False: "orange"}

fig, axs = plt.subplots(1, 2, figsize=(18, 6), sharey=True)
axs[0].set_title("Pre 2016")
axs[1].set_title("2016 and After")
axs[0].set_xlabel("Time")
axs[1].set_xlabel("Time")
axs[0].set_ylabel("Chlorophyll-a [mg m$^{-3}$]")

for seal_id in seal_sic_dict.keys():
    if seal_id.endswith('_daily'):
        continue

    seal = load_single_seal(seal_id)
    seal_sic_monthly = seal_sic_dict[seal_id]
    time = seal.time
    sic = xr.DataArray(seal_sic_monthly, dims=['time'], coords=[time]).fillna(0)
    chla = seal.CHLA_ADJUSTED  # shape: (time, depth)

    pre_2016_mask = time < np.datetime64('2016-01-01')
    post_2016_mask = time >= np.datetime64('2016-01-01')

    # flatten chlorophyll for plotting individual points with time repeated per depth
    for mask, ax in zip([pre_2016_mask, post_2016_mask], axs):
        times_plot = np.repeat(time[mask].values, chla.shape[1])
        chla_plot = chla[mask, :].values.flatten()
        sic_bool_plot = np.repeat((sic.values >= sic_threshold)[mask], chla.shape[1])

        # remove NaNs for plotting
        valid = ~np.isnan(chla_plot)
        times_plot = times_plot[valid]
        chla_plot = chla_plot[valid]
        sic_bool_plot = sic_bool_plot[valid]

        ax.scatter(times_plot, chla_plot,
                   c=[colors_map[b] for b in sic_bool_plot],
                   s=5, alpha=0.5)

handles = [mpatches.Patch(color=colors_map[True], label=f"SIC >= {sic_threshold}"),
           mpatches.Patch(color=colors_map[False], label=f"SIC < {sic_threshold}")]
axs[1].legend(handles=handles, loc='upper right', title='Sea Ice Concentration')

# display plot
plt.suptitle("Chlorophyll-a Points for All Seals Colored by Monthly SIC Threshold")
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()

In [None]:
sic_threshold = 0.7
colors_map = {True: "lightskyblue", False: "orange"}

fig, axs = plt.subplots(1, 2, figsize=(18, 6), sharey=True)
axs[0].set_title("Pre 2016")
axs[1].set_title("2016 and After")
axs[0].set_xlabel("Monthly Sea Ice Concentration")
axs[1].set_xlabel("Monthly Sea Ice Concentration")
axs[0].set_ylabel("Chlorophyll-a [mg m$^{-3}$]")

for seal_id in seal_sic_dict.keys():
    if seal_id.endswith('_daily'):
        continue

    seal = load_single_seal(seal_id)
    seal_sic_monthly = seal_sic_dict[seal_id]
    time = seal.time
    sic = xr.DataArray(seal_sic_monthly, dims=['time'], coords=[time]).fillna(0)
    chla = seal.CHLA_ADJUSTED  # shape: (time, depth)

    pre_2016_mask = time < np.datetime64('2016-01-01')
    post_2016_mask = time >= np.datetime64('2016-01-01')

    for mask, ax in zip([pre_2016_mask, post_2016_mask], axs):
        # repeat SIC + chlorophyll for each depth level
        sic_plot = np.repeat(sic.values[mask], chla.shape[1])
        chla_plot = chla[mask, :].values.flatten()
        sic_bool_plot = np.repeat((sic.values >= sic_threshold)[mask], chla.shape[1])

        # remove NaNs
        valid = ~np.isnan(chla_plot)
        sic_plot = sic_plot[valid]
        chla_plot = chla_plot[valid]
        sic_bool_plot = sic_bool_plot[valid]

        ax.scatter(sic_plot, chla_plot,
                   c=[colors_map[b] for b in sic_bool_plot],
                   s=5, alpha=0.5)

handles = [mpatches.Patch(color=colors_map[True], label=f"SIC >= {sic_threshold}"),
           mpatches.Patch(color=colors_map[False], label=f"SIC < {sic_threshold}")]
axs[1].legend(handles=handles, loc='upper right', title='Sea Ice Concentration')

# display plot
plt.suptitle("Chlorophyll-a Points for All Seals\nColored by Monthly SIC Threshold")
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()

#### Heatmaps (CHLA vs SIC) for All Seals

These heatmaps show the number of measurements in each binned range of SIC and CHLA, as well as in each binned range of SIC and depth. Both the mean and the median are shown, because outliers may influence the mean.
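As a compact illustration of the binning idea, the sketch below computes the count, mean, and median of CHLA in each SIC-depth bin with `scipy.stats.binned_statistic_2d`. It is not the method used in the cells that follow, which accumulate sums in a loop. The helper name `binned_chla_stats`, the bin edges, and the sample values are all assumptions made up for this example.

```python
import numpy as np
from scipy.stats import binned_statistic_2d

def binned_chla_stats(sic, depth, chla, sic_edges, depth_edges):
    """Count, mean, and median of CHLA in each (depth, SIC) bin.

    Hypothetical helper for illustration; not part of the notebook.
    """
    sic, depth, chla = (np.asarray(a, dtype=float) for a in (sic, depth, chla))

    # drop samples with any missing value before binning
    valid = np.isfinite(sic) & np.isfinite(depth) & np.isfinite(chla)

    stats = {}
    for name in ("count", "mean", "median"):
        result = binned_statistic_2d(
            sic[valid], depth[valid], chla[valid],
            statistic=name, bins=[sic_edges, depth_edges],
        )
        # transpose to (depth, SIC) so rows map to the y-axis in pcolormesh
        stats[name] = result.statistic.T
    return stats

# synthetic example data (made-up values, not from the seal dataset)
sic = np.array([0.0, 0.2, 0.2, 0.9, 1.0, 1.0])
depth = np.array([5.0, 5.0, 8.0, 60.0, 60.0, 90.0])
chla = np.array([1.0, 2.0, 6.0, 0.5, 0.1, 0.3])
sic_edges = np.array([0.0, 0.5, 1.0])
depth_edges = np.array([0.0, 50.0, 100.0])

stats = binned_chla_stats(sic, depth, chla, sic_edges, depth_edges)
print(stats["count"])
print(stats["mean"])
print(stats["median"])
```

Two properties are worth noting. Values equal to the last bin edge (for example SIC = 1.0) are placed in the last bin, and empty bins come back as NaN for the mean and median, so they can be told apart from bins whose average CHLA is genuinely zero. In the shallow, low-SIC bin of this example the mean (3.0) is pulled up by a single large value while the median (2.0) is not, which is the reason for reporting both.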

In [None]:
def plot_all_seals_together(seal_sic_dict, num_bins=10):
    # create figure + axis
    fig, ax = plt.subplots(figsize=(12, 8))

    # prepare lists to store all time, SIC, + CHLA values
    all_time = []
    all_sic = []
    all_chla = []

    # iterate over all seals + collect their data
    for seal_id in seal_sic_dict.keys():
        if '_daily' in seal_id:
            seal_id_monthly = seal_id.replace('_daily', '')
            if seal_id_monthly in seal_sic_dict:
                seal = load_single_seal(seal_id_monthly)
                seal_sic_monthly = seal_sic_dict[seal_id_monthly]
                seal_sic_monthly = xr.DataArray(seal_sic_monthly, dims=['time'], coords=[seal.time]).fillna(0)
                chla = seal.CHLA_ADJUSTED
                time = seal.time

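                # convert datetimes to matplotlib date numbers (days since the epoch)
                # so the date formatter on the x-axis shows the correct dates
                time_numeric = mdates.date2num(time.values)

                # collect data for plotting; repeat time + SIC once per depth level
                # so they stay aligned with the flattened CHLA array
                n_levels = chla.shape[1] if chla.ndim > 1 else 1
                all_time.extend(np.repeat(time_numeric, n_levels))
                all_sic.extend(np.repeat(seal_sic_monthly.values, n_levels))
                all_chla.extend(chla.values.flatten())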

    # convert collected data to numpy arrays
    all_time = np.array(all_time)
    all_sic = np.array(all_sic)
    all_chla = np.array(all_chla)

    # bin time + SIC data
    time_bins = np.linspace(all_time.min(), all_time.max(), num_bins + 1)
    sic_bins = np.linspace(all_sic.min(), all_sic.max(), num_bins + 1)

    # initialize array to store average CHLA values for each SIC-time bin
    avg_chla = np.zeros((num_bins, num_bins))
    count_chla = np.zeros((num_bins, num_bins))

    # bin data based on SIC + time; clip so values on the top edge land in the last bin
    time_bin_indices = np.clip(np.digitize(all_time, time_bins) - 1, 0, num_bins - 1)
    sic_bin_indices = np.clip(np.digitize(all_sic, sic_bins) - 1, 0, num_bins - 1)

    # update CHLA sum + count for each bin, skipping missing CHLA values
    for t_bin, s_bin, chla_value in zip(time_bin_indices, sic_bin_indices, all_chla):
        if np.isfinite(chla_value):
            avg_chla[s_bin, t_bin] += chla_value
            count_chla[s_bin, t_bin] += 1

    # normalize average CHLA by number of points in each bin
    avg_chla[count_chla > 0] /= count_chla[count_chla > 0]

    # create heatmap using pcolormesh
    time_bin_centers = (time_bins[:-1] + time_bins[1:]) / 2
    sic_bin_centers = (sic_bins[:-1] + sic_bins[1:]) / 2

    c = ax.pcolormesh(time_bin_centers, sic_bin_centers, avg_chla, cmap='YlGn', shading='auto')

    # add colorbar for average CHLA with units
    cbar = plt.colorbar(c, ax=ax)
    cbar.set_label('Average CHLA (mg/m³)')  # Added units for CHLA

    # format x-axis to show actual dates
    ax.xaxis.set_major_locator(mdates.MonthLocator(interval=1))  # Set major ticks to every month
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m'))  # Format the date to Year-Month

    # add labels + title
    ax.set_xlabel('Time')
    ax.set_ylabel('Monthly SIC')
    ax.set_title('All Seals - Monthly SIC vs Time with Average CHLA')
    # Rotate the x-axis labels for better readability
    plt.xticks(rotation=45)
    plt.tight_layout()

    # display plot
    plt.show()

# call function to generate the plot for all seals together with smaller bins
plot_all_seals_together(seal_sic_dict, num_bins=40)

In [None]:
def plot_all_seals_together(seal_sic_dict, num_bins=10):
    # create figure + axis
    fig, ax = plt.subplots(figsize=(60, 8))

    # prepare lists to store all time, SIC, + CHLA values
    all_time = []
    all_sic = []
    all_chla = []

    # iterate over all seals + collect their data
    for seal_id in seal_sic_dict.keys():
        if '_daily' in seal_id:
            seal_id_monthly = seal_id.replace('_daily', '')
            if seal_id_monthly in seal_sic_dict:
                seal = load_single_seal(seal_id_monthly)
                seal_sic_monthly = seal_sic_dict[seal_id_monthly]
                seal_sic_monthly = xr.DataArray(seal_sic_monthly, dims=['time'], coords=[seal.time]).fillna(0)
                chla = seal.CHLA_ADJUSTED
                time = seal.time

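                # collect data for plotting; repeat time + SIC once per depth level
                # so they stay aligned with the flattened CHLA array
                n_levels = chla.shape[1] if chla.ndim > 1 else 1
                all_time.extend(np.repeat(time.values, n_levels))  # use original datetime values
                all_sic.extend(np.repeat(seal_sic_monthly.values, n_levels))
                all_chla.extend(chla.values.flatten())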

    # convert collected data to numpy arrays
    all_time = np.array(all_time, dtype='datetime64[ns]')  # Ensure time is in datetime format
    all_sic = np.array(all_sic)
    all_chla = np.array(all_chla)

    # bin time + SIC data
    time_bins = np.linspace(all_time.min().astype('datetime64[D]').astype(int),
                            all_time.max().astype('datetime64[D]').astype(int),
                            num_bins + 1)
    sic_bins = np.linspace(all_sic.min(), all_sic.max(), num_bins + 1)

    # initialize array to store average CHLA values for each SIC-time bin
    avg_chla = np.zeros((num_bins, num_bins))
    count_chla = np.zeros((num_bins, num_bins))

    # bin data based on SIC + time; clip so values on the top edge land in the last bin
    time_bin_indices = np.clip(
        np.digitize(all_time.astype('datetime64[D]').astype(int), time_bins) - 1, 0, num_bins - 1)
    sic_bin_indices = np.clip(np.digitize(all_sic, sic_bins) - 1, 0, num_bins - 1)

    # update CHLA sum + count for each bin, skipping missing CHLA values
    for t_bin, s_bin, chla_value in zip(time_bin_indices, sic_bin_indices, all_chla):
        if np.isfinite(chla_value):
            avg_chla[s_bin, t_bin] += chla_value
            count_chla[s_bin, t_bin] += 1

    # normalize average CHLA by number of points in each bin
    avg_chla[count_chla > 0] /= count_chla[count_chla > 0]

    # create heatmap using pcolormesh
    time_bin_centers = (time_bins[:-1] + time_bins[1:]) / 2
    sic_bin_centers = (sic_bins[:-1] + sic_bins[1:]) / 2

    c = ax.pcolormesh(time_bin_centers, sic_bin_centers, avg_chla, cmap='YlGn', shading='auto')

    # add colorbar for average CHLA with units
    cbar = plt.colorbar(c, ax=ax)
    cbar.set_label('Average CHLA (mg/m³)')

    # format x-axis to show actual dates
    ax.xaxis.set_major_locator(mdates.MonthLocator(interval=1))  # Set major ticks to every month
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m'))  # Format the date to Year-Month

    # add labels + title
    ax.set_xlabel('Time')
    ax.set_ylabel('Monthly SIC')
    ax.set_title('All Seals - Monthly SIC vs Time with Average CHLA')

    # rotate x-axis labels for better readability
    plt.xticks(rotation=45)
    plt.tight_layout()

    # display plot
    plt.show()

# call function to generate plot for all seals together with smaller bins
plot_all_seals_together(seal_sic_dict, num_bins=100)

In [None]:
def plot_all_seals_together(seal_sic_dict, num_bins=10):
    # create figure + axis with increased width
    fig, ax = plt.subplots(figsize=(16, 8))  # increased width to 16

    # prepare lists to store all time, SIC, + CHLA values
    all_time = []
    all_sic = []
    all_chla = []

    # iterate over all seals + collect their data
    for seal_id in seal_sic_dict.keys():
        if '_daily' in seal_id:
            seal_id_monthly = seal_id.replace('_daily', '')
            if seal_id_monthly in seal_sic_dict:
                seal = load_single_seal(seal_id_monthly)
                seal_sic_monthly = seal_sic_dict[seal_id_monthly]
                seal_sic_monthly = xr.DataArray(seal_sic_monthly, dims=['time'], coords=[seal.time]).fillna(0)
                chla = seal.CHLA_ADJUSTED
                time = seal.time

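                # collect data for plotting, excluding zero CHLA values;
                # repeat time + SIC once per depth level so they align with the flattened CHLA
                n_levels = chla.shape[1] if chla.ndim > 1 else 1
                time_repeated = np.repeat(time.values, n_levels)
                sic_repeated = np.repeat(seal_sic_monthly.values, n_levels)
                for t, s, chla_value in zip(time_repeated, sic_repeated, chla.values.flatten()):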
                    if chla_value > 0:  # only include non-zero CHLA values
                        all_time.append(t)
                        all_sic.append(s)
                        all_chla.append(chla_value)

    # convert collected data to numpy arrays
    all_time = np.array(all_time, dtype='datetime64[ns]')  # ensure time is in datetime format
    all_sic = np.array(all_sic)
    all_chla = np.array(all_chla)

    # bin time + SIC data
    time_bins = np.linspace(all_time.min().astype('datetime64[D]').astype(int),
                            all_time.max().astype('datetime64[D]').astype(int),
                            num_bins + 1)
    sic_bins = np.linspace(all_sic.min(), all_sic.max(), num_bins + 1)

    # initialize array to store average CHLA values for each SIC-time bin
    avg_chla = np.zeros((num_bins, num_bins))
    count_chla = np.zeros((num_bins, num_bins))

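    # bin data based on SIC + time; clip so values on the top edge land in the last bin
    time_bin_indices = np.clip(
        np.digitize(all_time.astype('datetime64[D]').astype(int), time_bins) - 1, 0, num_bins - 1)
    sic_bin_indices = np.clip(np.digitize(all_sic, sic_bins) - 1, 0, num_bins - 1)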

    # update average CHLA + count for each bin
    for t_bin, s_bin, chla_value in zip(time_bin_indices, sic_bin_indices, all_chla):
        if 0 <= t_bin < num_bins and 0 <= s_bin < num_bins:
            avg_chla[s_bin, t_bin] += chla_value
            count_chla[s_bin, t_bin] += 1

    # normalize average CHLA by number of points in each bin
    avg_chla[count_chla > 0] /= count_chla[count_chla > 0]

    # create heatmap using pcolormesh
    time_bin_centers = (time_bins[:-1] + time_bins[1:]) / 2
    sic_bin_centers = (sic_bins[:-1] + sic_bins[1:]) / 2

    c = ax.pcolormesh(time_bin_centers, sic_bin_centers, avg_chla, cmap='YlGn', shading='auto')

    # add colorbar for average CHLA with units
    cbar = plt.colorbar(c, ax=ax)
    cbar.set_label('Average CHLA (mg/m³)')

    # format x-axis to show actual dates
    ax.xaxis.set_major_locator(mdates.MonthLocator(interval=6))  # Set major ticks to every 6 months
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m'))  # Format the date to Year-Month

    # add labels + title
    ax.set_xlabel('Time')
    ax.set_ylabel('Monthly SIC')
    ax.set_title('All Seals - Monthly SIC vs Time with Average CHLA')

    # rotate x-axis labels for better readability
    plt.xticks(rotation=45)
    plt.tight_layout()

    # display plot
    plt.show()

# call function to generate plot for all seals together with smaller bins
plot_all_seals_together(seal_sic_dict, num_bins=100)

In [None]:
def plot_all_seals_together(seal_sic_dict, num_bins=10):
    fig, ax = plt.subplots(figsize=(16, 8))

    all_time = []
    all_sic = []
    all_chla = []
    all_pressure = []

    for seal_id in seal_sic_dict.keys():
        if '_daily' in seal_id:
            seal_id_monthly = seal_id.replace('_daily', '')
            if seal_id_monthly in seal_sic_dict:
                seal = load_single_seal(seal_id_monthly)
                seal_sic_monthly = seal_sic_dict[seal_id_monthly]
                seal_sic_monthly = xr.DataArray(seal_sic_monthly, dims=['time'], coords=[seal.time]).fillna(0)
                chla = seal.CHLA_ADJUSTED
                pressure = seal.PRES_ADJUSTED
                time = seal.time

                # flatten chla + pressure arrays
                chla_flat = chla.values.flatten()
                pressure_flat = pressure.values.flatten()

                # repeat time + SIC to match flattened arrays length
                time_repeated = np.repeat(time.values, chla.shape[1])
                sic_repeated = np.repeat(seal_sic_monthly.values, chla.shape[1])

                for t, s, chla_value, p in zip(time_repeated, sic_repeated, chla_flat, pressure_flat):
                    if chla_value > 0:
                        all_time.append(t)
                        all_sic.append(s)
                        all_chla.append(chla_value)
                        all_pressure.append(p)

    all_time = np.array(all_time, dtype='datetime64[ns]')
    all_sic = np.array(all_sic)
    all_chla = np.array(all_chla)
    all_pressure = np.array(all_pressure)

    depth = all_pressure  # 1 dbar ~ 1 meter

    # bin SIC + depth instead of time + SIC 
    sic_bins = np.linspace(all_sic.min(), all_sic.max(), num_bins + 1)
    depth_bins = np.linspace(depth.min(), depth.max(), num_bins + 1)

    avg_chla = np.zeros((num_bins, num_bins))
    count_chla = np.zeros((num_bins, num_bins))

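    # clip so values on the top edge (e.g. the maximum SIC) land in the last bin
    sic_bin_indices = np.clip(np.digitize(all_sic, sic_bins) - 1, 0, num_bins - 1)
    depth_bin_indices = np.clip(np.digitize(depth, depth_bins) - 1, 0, num_bins - 1)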

    for s_bin, d_bin, chla_value in zip(sic_bin_indices, depth_bin_indices, all_chla):
        if 0 <= s_bin < num_bins and 0 <= d_bin < num_bins:
            avg_chla[d_bin, s_bin] += chla_value
            count_chla[d_bin, s_bin] += 1

    avg_chla[count_chla > 0] /= count_chla[count_chla > 0]

    sic_bin_centers = (sic_bins[:-1] + sic_bins[1:]) / 2
    depth_bin_centers = (depth_bins[:-1] + depth_bins[1:]) / 2

    c = ax.pcolormesh(sic_bin_centers, depth_bin_centers, avg_chla, cmap='YlGn', shading='auto')

    cbar = plt.colorbar(c, ax=ax)
    cbar.set_label('Average CHLA (mg/m³)')

    ax.set_xlabel('Sea Ice Concentration (SIC)')
    ax.set_ylabel('Depth (m)')
    ax.set_title('All Seals - SIC vs Depth with Average CHLA')
    ax.invert_yaxis()  # depth increases downward

    plt.tight_layout()
    plt.show()


# call function to generate plot for all seals together with smaller bins
plot_all_seals_together(seal_sic_dict, num_bins=100)

In [None]:
def plot_all_seals_together(seal_sic_dict, num_bins=10):
    fig, ax = plt.subplots(figsize=(16, 8))

    all_time = []
    all_sic = []
    all_chla = []
    all_pressure = []

    for seal_id in seal_sic_dict.keys():
        if '_daily' in seal_id:
            seal_id_monthly = seal_id.replace('_daily', '')
            if seal_id_monthly in seal_sic_dict:
                seal = load_single_seal(seal_id_monthly)
                seal_sic_monthly = seal_sic_dict[seal_id_monthly]
                seal_sic_monthly = xr.DataArray(seal_sic_monthly, dims=['time'], coords=[seal.time]).fillna(0)
                chla = seal.CHLA_ADJUSTED
                pressure = seal.PRES_ADJUSTED
                time = seal.time

                # flatten chla + pressure arrays
                chla_flat = chla.values.flatten()
                pressure_flat = pressure.values.flatten()

                # repeat time + SIC to match flattened arrays length
                time_repeated = np.repeat(time.values, chla.shape[1])
                sic_repeated = np.repeat(seal_sic_monthly.values, chla.shape[1])

                for t, s, chla_value, p in zip(time_repeated, sic_repeated, chla_flat, pressure_flat):
                    if chla_value > 0:
                        all_time.append(t)
                        all_sic.append(s)
                        all_chla.append(chla_value)
                        all_pressure.append(p)

    all_time = np.array(all_time, dtype='datetime64[ns]')
    all_sic = np.array(all_sic)
    all_chla = np.array(all_chla)
    all_pressure = np.array(all_pressure)

    depth = all_pressure  # 1 dbar ~ 1 meter

    # bin SIC + depth instead of time + SIC 
    sic_bins = np.linspace(all_sic.min(), all_sic.max(), num_bins + 1)
    depth_bins = np.linspace(depth.min(), depth.max(), num_bins + 1)

    avg_chla = np.zeros((num_bins, num_bins))
    count_chla = np.zeros((num_bins, num_bins))

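    # np.digitize assigns values equal to the last edge to bin num_bins, so clip
    # them into the final bin; otherwise the maximum SIC and depth are dropped
    sic_bin_indices = np.clip(np.digitize(all_sic, sic_bins) - 1, 0, num_bins - 1)
    depth_bin_indices = np.clip(np.digitize(depth, depth_bins) - 1, 0, num_bins - 1)

    # accumulate the CHLA sum and sample count in each (depth, SIC) bin
    for s_bin, d_bin, chla_value in zip(sic_bin_indices, depth_bin_indices, all_chla):
        avg_chla[d_bin, s_bin] += chla_value
        count_chla[d_bin, s_bin] += 1

    # turn sums into means; bins with no samples become NaN so they plot blank
    avg_chla[count_chla > 0] /= count_chla[count_chla > 0]
    avg_chla[count_chla == 0] = np.nan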

    sic_bin_centers = (sic_bins[:-1] + sic_bins[1:]) / 2
    depth_bin_centers = (depth_bins[:-1] + depth_bins[1:]) / 2

    c = ax.pcolormesh(sic_bin_centers, depth_bin_centers, avg_chla, cmap='YlGn', shading='auto', vmax = 5)

    cbar = plt.colorbar(c, ax=ax)
    cbar.set_label('Average CHLA (mg/m³)')

    ax.set_xlabel('Sea Ice Concentration (SIC)')
    ax.set_ylabel('Depth (m)')
    ax.set_title('All Seals - SIC vs Depth with Average CHLA')
    ax.invert_yaxis()  # depth increases downward

    plt.tight_layout()
    plt.show()


# call function to generate plot for all seals together with smaller bins
plot_all_seals_together(seal_sic_dict, num_bins=100)
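
The SIC/depth averaging in the cell above can also be written without an explicit Python loop. The sketch below is one possible way to do it with `np.histogram2d`, which counts values on the last bin edge in the final bin. The helper name `binned_mean_2d` and the synthetic arrays are assumptions made for illustration; they are not part of the notebook's data pipeline.

```python
import numpy as np


def binned_mean_2d(x, y, values, num_bins=10):
    """Mean of `values` on a num_bins x num_bins grid over (x, y).

    Hypothetical helper for illustration. Rows follow y and columns follow x,
    matching the (depth, SIC) layout passed to pcolormesh in the cell above.
    """
    x_edges = np.linspace(x.min(), x.max(), num_bins + 1)
    y_edges = np.linspace(y.min(), y.max(), num_bins + 1)

    # weighted histogram = sum of values per bin; unweighted = sample count per bin
    sums, _, _ = np.histogram2d(y, x, bins=[y_edges, x_edges], weights=values)
    counts, _, _ = np.histogram2d(y, x, bins=[y_edges, x_edges])

    # bins with no samples become NaN instead of 0
    with np.errstate(invalid="ignore", divide="ignore"):
        mean = np.where(counts > 0, sums / counts, np.nan)
    return mean, x_edges, y_edges


# synthetic stand-ins for SIC, depth and CHLA (not real seal data)
sic = np.array([0.0, 0.0, 1.0, 1.0])
depth = np.array([0.0, 0.0, 10.0, 10.0])
chla = np.array([1.0, 3.0, 2.0, 4.0])

mean_chla, sic_edges, depth_edges = binned_mean_2d(sic, depth, chla, num_bins=2)
print(mean_chla)
```

Because the helper returns bin edges rather than centers, the result could be passed to `ax.pcolormesh(sic_edges, depth_edges, mean_chla)`, where matplotlib draws each cell between consecutive edges.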

#### Seasonal Plots

In [None]:
def plot_seasonal_seals(seal_sic_dict, num_bins=10):
    # define seasons in Southern Ocean
    seasons = {
        'Spring (Sept - Nov)': (8, 10),   # Sept = 8, Nov = 10 (0-indexed)
        'Summer (Dec - Feb)': (11, 1),    # Dec = 11, Feb = 1
        'Autumn (Mar - May)': (2, 4),     # Mar = 2, May = 4
        'Winter (Jun - Aug)': (5, 7),     # Jun = 5, Aug = 7
    }

    fig, axes = plt.subplots(2, 2, figsize=(16, 16))  # 2x2 grid of subplots
    axes = axes.flatten()  # flatten to access easily in a loop

    for season, (start_month, end_month) in seasons.items():
        ax = axes[list(seasons.keys()).index(season)]
        all_time = []
        all_sic = []
        all_chla = []
        all_pressure = []

        for seal_id in seal_sic_dict.keys():
            if '_daily' in seal_id:
                seal_id_monthly = seal_id.replace('_daily', '')
                if seal_id_monthly in seal_sic_dict:
                    seal = load_single_seal(seal_id_monthly)
                    seal_sic_monthly = seal_sic_dict[seal_id_monthly]
                    seal_sic_monthly = xr.DataArray(seal_sic_monthly, dims=['time'], coords=[seal.time]).fillna(0)
                    chla = seal.CHLA_ADJUSTED
                    pressure = seal.PRES_ADJUSTED
                    time = seal.time

                    # each profile has one time and one SIC value, so repeat them
                    # to line up with the flattened CHLA and pressure arrays
                    time_repeated = np.repeat(time.values, chla.shape[1])
                    sic_repeated = np.repeat(seal_sic_monthly.values, chla.shape[1])

                    # filter by season based on time of year
                    for t, s, chla_value, p in zip(time_repeated, sic_repeated, chla.values.flatten(), pressure.values.flatten()):
                        if chla_value > 0:
                            # pandas months are 1-indexed; the seasons table above is 0-indexed
                            month = pd.to_datetime(t).month - 1

                            # Summer (Dec - Feb) wraps past the end of the year, so start > end
                            if start_month <= end_month:
                                in_season = start_month <= month <= end_month
                            else:
                                in_season = month >= start_month or month <= end_month

                            if in_season:
                                all_time.append(t)
                                all_sic.append(s)
                                all_chla.append(chla_value)
                                all_pressure.append(p)

        if not all_time:  
            continue

        all_time = np.array(all_time, dtype='datetime64[ns]')
        all_sic = np.array(all_sic)
        all_chla = np.array(all_chla)
        all_pressure = np.array(all_pressure)

        depth = all_pressure  # 1 dbar ~ 1 meter

        # bin SIC + depth
        sic_bins = np.linspace(all_sic.min(), all_sic.max(), num_bins + 1)
        depth_bins = np.linspace(depth.min(), depth.max(), num_bins + 1)

        avg_chla = np.zeros((num_bins, num_bins))
        count_chla = np.zeros((num_bins, num_bins))

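        # np.digitize assigns values equal to the last edge to bin num_bins, so clip
        # them into the final bin; otherwise the maximum SIC and depth are dropped
        sic_bin_indices = np.clip(np.digitize(all_sic, sic_bins) - 1, 0, num_bins - 1)
        depth_bin_indices = np.clip(np.digitize(depth, depth_bins) - 1, 0, num_bins - 1)

        # accumulate the CHLA sum and sample count in each (depth, SIC) bin
        for s_bin, d_bin, chla_value in zip(sic_bin_indices, depth_bin_indices, all_chla):
            avg_chla[d_bin, s_bin] += chla_value
            count_chla[d_bin, s_bin] += 1

        # turn sums into means; bins with no samples become NaN so they plot blank
        avg_chla[count_chla > 0] /= count_chla[count_chla > 0]
        avg_chla[count_chla == 0] = np.nan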

        sic_bin_centers = (sic_bins[:-1] + sic_bins[1:]) / 2
        depth_bin_centers = (depth_bins[:-1] + depth_bins[1:]) / 2

        c = ax.pcolormesh(sic_bin_centers, depth_bin_centers, avg_chla, cmap='YlGn', shading='auto', vmax = 5)
        cbar = plt.colorbar(c, ax=ax)
        cbar.set_label('Average CHLA (mg/m³)')

        ax.set_xlabel('Sea Ice Concentration (SIC)')
        ax.set_ylabel('Depth (m)')
        ax.set_title(f'{season} - SIC vs Depth with Average CHLA')
        ax.invert_yaxis()  # depth increases downward

    plt.tight_layout()
    plt.show()


# call function to generate plot for all seals, separated by season
plot_seasonal_seals(seal_sic_dict, num_bins=100)

In [None]:
def plot_seasonal_seals(seal_sic_dict, num_bins=10):
    # define seasons in Southern Ocean
    seasons = {
        'Spring (Sept - Nov)': (8, 10),   # Sept = 8, Nov = 10 (0-indexed)
        'Summer (Dec - Feb)': (11, 1),    # Dec = 11, Feb = 1
        'Autumn (Mar - May)': (2, 4),     # Mar = 2, May = 4
        'Winter (Jun - Aug)': (5, 7),     # Jun = 5, Aug = 7
    }

    fig, axes = plt.subplots(2, 2, figsize=(16, 16))  # 2x2 grid of subplots
    axes = axes.flatten()  # flatten to access easily in a loop

    for season, (start_month, end_month) in seasons.items():
        ax = axes[list(seasons.keys()).index(season)]
        all_time = []
        all_sic = []
        all_chla = []
        all_pressure = []

        for seal_id in seal_sic_dict.keys():
            if '_daily' in seal_id:
                seal_id_monthly = seal_id.replace('_daily', '')
                if seal_id_monthly in seal_sic_dict:
                    seal = load_single_seal(seal_id_monthly)
                    seal_sic_monthly = seal_sic_dict[seal_id_monthly]
                    seal_sic_monthly = xr.DataArray(seal_sic_monthly, dims=['time'], coords=[seal.time]).fillna(0)
                    chla = seal.CHLA_ADJUSTED
                    pressure = seal.PRES_ADJUSTED
                    time = seal.time

                    # each profile has one time and one SIC value, so repeat them
                    # to line up with the flattened CHLA and pressure arrays
                    time_repeated = np.repeat(time.values, chla.shape[1])
                    sic_repeated = np.repeat(seal_sic_monthly.values, chla.shape[1])

                    # filter by season based on time of year
                    for t, s, chla_value, p in zip(time_repeated, sic_repeated, chla.values.flatten(), pressure.values.flatten()):
                        if chla_value > 0:
                            # pandas months are 1-indexed; the seasons table above is 0-indexed
                            month = pd.to_datetime(t).month - 1

                            # Summer (Dec - Feb) wraps past the end of the year, so start > end
                            if start_month <= end_month:
                                in_season = start_month <= month <= end_month
                            else:
                                in_season = month >= start_month or month <= end_month

                            if in_season:
                                all_time.append(t)
                                all_sic.append(s)
                                all_chla.append(chla_value)
                                all_pressure.append(p)

        if not all_time:  
            continue

        all_time = np.array(all_time, dtype='datetime64[ns]')
        all_sic = np.array(all_sic)
        all_chla = np.array(all_chla)
        all_pressure = np.array(all_pressure)

        depth = all_pressure  # 1 dbar ~ 1 meter

        # bin SIC + depth
        sic_bins = np.linspace(all_sic.min(), all_sic.max(), num_bins + 1)
        depth_bins = np.linspace(depth.min(), depth.max(), num_bins + 1)

        avg_chla = np.zeros((num_bins, num_bins))
        count_chla = np.zeros((num_bins, num_bins))

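        # np.digitize assigns values equal to the last edge to bin num_bins, so clip
        # them into the final bin; otherwise the maximum SIC and depth are dropped
        sic_bin_indices = np.clip(np.digitize(all_sic, sic_bins) - 1, 0, num_bins - 1)
        depth_bin_indices = np.clip(np.digitize(depth, depth_bins) - 1, 0, num_bins - 1)

        # accumulate the CHLA sum and sample count in each (depth, SIC) bin
        for s_bin, d_bin, chla_value in zip(sic_bin_indices, depth_bin_indices, all_chla):
            avg_chla[d_bin, s_bin] += chla_value
            count_chla[d_bin, s_bin] += 1

        # turn sums into means; bins with no samples become NaN so they plot blank
        avg_chla[count_chla > 0] /= count_chla[count_chla > 0]
        avg_chla[count_chla == 0] = np.nan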

        sic_bin_centers = (sic_bins[:-1] + sic_bins[1:]) / 2
        depth_bin_centers = (depth_bins[:-1] + depth_bins[1:]) / 2

        c = ax.pcolormesh(sic_bin_centers, depth_bin_centers, avg_chla, cmap='YlGn', shading='auto', vmax = 5)
        cbar = plt.colorbar(c, ax=ax)
        cbar.set_label('Average CHLA (mg/m³)')

        ax.set_xlabel('Sea Ice Concentration (SIC)')
        ax.set_ylabel('Depth (m)')
        ax.set_title(f'{season} - SIC vs Depth with Average CHLA')
        ax.invert_yaxis()  # depth increases downward

    plt.tight_layout()
    plt.show()


# call function to generate plot for all seals, separated by season
plot_seasonal_seals(seal_sic_dict, num_bins=100)
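
The season filter in the cells above has to handle the Dec - Feb wrap-around, which is easy to get wrong when month bounds and `t.month` use different indexing. As a minimal sketch, the check can be isolated in a small helper and tested on its own. The names `SEASONS`, `in_season`, and `season_of` are hypothetical, and the bounds here are 1-indexed to match `datetime.month`, unlike the 0-indexed tuples used in the cells above.

```python
from datetime import datetime

# Southern Hemisphere seasons as (first month, last month), 1-indexed
SEASONS = {
    "Spring (Sept - Nov)": (9, 11),
    "Summer (Dec - Feb)": (12, 2),
    "Autumn (Mar - May)": (3, 5),
    "Winter (Jun - Aug)": (6, 8),
}


def in_season(month, start_month, end_month):
    """True if a 1-indexed month lies in the season, wrapping past December."""
    if start_month <= end_month:
        return start_month <= month <= end_month
    # start > end means the season spans the year boundary (e.g. Dec - Feb)
    return month >= start_month or month <= end_month


def season_of(timestamp):
    """Return the season label for any object with a 1-indexed .month attribute."""
    for name, (start, end) in SEASONS.items():
        if in_season(timestamp.month, start, end):
            return name
    raise ValueError(f"month {timestamp.month} is not covered by SEASONS")


print(season_of(datetime(2020, 1, 15)))
print(season_of(datetime(2020, 9, 1)))
```

Since `pd.Timestamp` also exposes a 1-indexed `.month`, the same helper should work on the values produced by `pd.to_datetime(t)`.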

In [None]:
def plot_all_seals_by_season(seal_sic_dict, num_bins=10):
    """
    Plots CHLA vs. SIC vs. Depth for all seals, separated by season
    (Southern Ocean: DJF=summer, MAM=autumn, JJA=winter, SON=spring)
    """

    # define Southern Hemisphere seasons
    season_months = {
        "Summer (Dec–Feb)": [12, 1, 2],
        "Autumn (Mar–May)": [3, 4, 5],
        "Winter (Jun–Aug)": [6, 7, 8],
        "Spring (Sep–Nov)": [9, 10, 11],
    }

    # collect all data
    all_time, all_sic, all_chla, all_pressure = [], [], [], []

    for seal_id in seal_sic_dict.keys():
        if '_daily' in seal_id:
            seal_id_monthly = seal_id.replace('_daily', '')
            if seal_id_monthly in seal_sic_dict:
                seal = load_single_seal(seal_id_monthly)
                seal_sic_monthly = seal_sic_dict[seal_id_monthly]
                seal_sic_monthly = xr.DataArray(seal_sic_monthly, dims=['time'], coords=[seal.time]).fillna(0)

                chla = seal.CHLA_ADJUSTED
                pressure = seal.PRES_ADJUSTED
                time = seal.time

                # flatten CHLA + pressure arrays
                chla_flat = chla.values.flatten()
                pressure_flat = pressure.values.flatten()

                # repeat time + SIC to match flattened arrays
                time_repeated = np.repeat(time.values, chla.shape[1])
                sic_repeated = np.repeat(seal_sic_monthly.values, chla.shape[1])

                # append valid CHLA points
                for t, s, chla_value, p in zip(time_repeated, sic_repeated, chla_flat, pressure_flat):
                    if chla_value > 0:
                        all_time.append(t)
                        all_sic.append(s)
                        all_chla.append(chla_value)
                        all_pressure.append(p)

    # convert to arrays
    all_time = np.array(all_time, dtype='datetime64[ns]')
    all_sic = np.array(all_sic)
    all_chla = np.array(all_chla)
    all_pressure = np.array(all_pressure)
    all_depth = all_pressure  # 1 dbar ≈ 1 m

    # prepare figure
    fig, axs = plt.subplots(2, 2, figsize=(18, 10))
    axs = axs.flatten()

    # loop over seasons
    for ax, (season_name, months) in zip(axs, season_months.items()):
        # mask data for current season (months since 1970-01, folded to 1..12)
        months_array = all_time.astype('datetime64[M]').astype(int) % 12 + 1
        season_mask = np.isin(months_array, months)

        if not np.any(season_mask):
            ax.set_title(f"{season_name} (no data)")
            ax.axis('off')
            continue

        time = all_time[season_mask]
        sic = all_sic[season_mask]
        chla = all_chla[season_mask]
        depth = all_depth[season_mask]

        # bin data
        sic_bins = np.linspace(sic.min(), sic.max(), num_bins + 1)
        depth_bins = np.linspace(depth.min(), depth.max(), num_bins + 1)
        chla_values_bin = {(i, j): [] for i in range(num_bins) for j in range(num_bins)}

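        # np.digitize assigns values equal to the last edge to bin num_bins, so clip
        # them into the final bin; otherwise the maximum SIC and depth are dropped
        sic_bin_indices = np.clip(np.digitize(sic, sic_bins) - 1, 0, num_bins - 1)
        depth_bin_indices = np.clip(np.digitize(depth, depth_bins) - 1, 0, num_bins - 1)

        # collect the CHLA samples that fall in each (depth, SIC) bin
        for s_bin, d_bin, chla_value in zip(sic_bin_indices, depth_bin_indices, chla):
            chla_values_bin[(d_bin, s_bin)].append(chla_value)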

        # compute median CHLA for each bin
        median_chla = np.full((num_bins, num_bins), np.nan)
        for (d_bin, s_bin), chla_list in chla_values_bin.items():
            if chla_list:
                median_chla[d_bin, s_bin] = np.median(chla_list)

        # plot
        sic_bin_centers = (sic_bins[:-1] + sic_bins[1:]) / 2
        depth_bin_centers = (depth_bins[:-1] + depth_bins[1:]) / 2

        c = ax.pcolormesh(sic_bin_centers, depth_bin_centers, median_chla,
                          cmap='YlGn', shading='auto', vmax=5)
        ax.set_title(season_name)
        ax.set_xlabel('Sea Ice Concentration (SIC)')
        ax.set_ylabel('Depth (m)')
        ax.invert_yaxis()

    # shared colorbar
    cbar = fig.colorbar(c, ax=axs, orientation='vertical', fraction=0.025, pad=0.02)
    cbar.set_label('Median CHLA (mg/m³)')

    plt.suptitle('All Seals - SIC vs Depth with Median CHLA by Season', fontsize=16)
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    plt.show()

# call function
plot_all_seals_by_season(seal_sic_dict, num_bins=100)
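
The median-per-bin step above keeps a Python list for each of the `num_bins * num_bins` cells. A possible alternative, sketched here under the assumption that the inputs are plain 1-D NumPy arrays, sorts the samples once by flattened bin index and takes the median of each contiguous group. The helper name `binned_median_2d` and the synthetic data are hypothetical and only illustrate the idea.

```python
import numpy as np


def binned_median_2d(x, y, values, num_bins=10):
    """Median of `values` on a num_bins x num_bins grid (rows = y, columns = x)."""
    x_edges = np.linspace(x.min(), x.max(), num_bins + 1)
    y_edges = np.linspace(y.min(), y.max(), num_bins + 1)

    # digitizing against interior edges gives indices 0..num_bins-1,
    # with the maximum value landing in the last bin
    x_idx = np.digitize(x, x_edges[1:-1])
    y_idx = np.digitize(y, y_edges[1:-1])

    # one integer label per sample, then sort samples by label
    flat = y_idx * num_bins + x_idx
    order = np.argsort(flat, kind="stable")
    flat_sorted = flat[order]
    values_sorted = values[order]

    # split the sorted values into one chunk per occupied bin
    occupied, starts = np.unique(flat_sorted, return_index=True)
    chunks = np.split(values_sorted, starts[1:])

    # unoccupied bins stay NaN so they plot blank
    median = np.full(num_bins * num_bins, np.nan)
    median[occupied] = [np.median(chunk) for chunk in chunks]
    return median.reshape(num_bins, num_bins), x_edges, y_edges


# synthetic stand-ins for SIC, depth and CHLA (not real seal data)
sic = np.array([0.0, 0.0, 0.0, 1.0, 1.0])
depth = np.array([0.0, 0.0, 0.0, 10.0, 10.0])
chla = np.array([1.0, 2.0, 9.0, 4.0, 6.0])

median_chla, sic_edges, depth_edges = binned_median_2d(sic, depth, chla, num_bins=2)
print(median_chla)
```

The median is less sensitive than the mean to a few very high CHLA readings in a bin, which is presumably why the last cell switches to it; the grouping trick only changes how the per-bin samples are gathered, not the statistic itself.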