In [None]:
import os
import sys
from datetime import datetime

import geopandas as gpd
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import netCDF4 as nc
import numpy as np
import pandas as pd
from tqdm import tqdm

In [None]:
# Input locations, relative to the notebook's working directory.
LOCATIONS_DIR: str = "../../locs/"  # *.geojson region files are read from here
DATA_DIR: str = "../../datasets/modis/sea_surface_temp"  # *.nc data files are read from here

In [None]:
def geojson_context_figure(files: list[str], world_path: "str | None" = None):
    """Save a context map next to each GeoJSON file: its shape drawn over a world map.

    Intended as a visual check that each region of interest is where it is
    expected to be. For ``foo.geojson`` the figure is written to ``foo.png``
    in the same directory.

    Args:
        files: Paths to GeoJSON files.
        world_path: Optional path/URL of a world map readable by
            ``geopandas.read_file``. When omitted, geopandas' bundled
            ``naturalearth_lowres`` dataset is used; that dataset (and
            ``gpd.datasets``) no longer exists in geopandas >= 1.0, so callers
            on newer versions must pass this argument.
    """
    if world_path is None:
        world_path = gpd.datasets.get_path("naturalearth_lowres")  # type: ignore
    world = gpd.read_file(world_path)

    for file in files:
        region = gpd.read_file(file)  # the region shape

        fig, ax = plt.subplots(1, 1, figsize=(10, 6))
        try:
            world.plot(ax=ax, color="lightgrey")
            region.plot(ax=ax, edgecolor="red", facecolor="none")
            ax.set_axis_off()
            # splitext only swaps the final extension; str.replace would also
            # rewrite any ".geojson" appearing elsewhere in the path.
            output_path = os.path.splitext(file)[0] + ".png"
            fig.savefig(output_path, bbox_inches="tight")
        finally:
            # Close this specific figure even if plotting fails, so a loop over
            # many files does not accumulate open figures.
            plt.close(fig)

In [None]:
def plot_data(
    region_name,
    data,
    latitude,
    longitude,
    start_date,
    end_date,
    vmin=0,
    vmax=0.75,
    label="Chlorophyll-a concentration",
):
    """Save a colour map of one gridded field for a region and time period.

    The figure is written to ``./plots/<region_name>/<start_date>_<end_date>.png``
    (the directory is created if needed).

    Args:
        region_name: Region identifier, used in the title and output path.
        data: 2-D values to plot; presumably shaped (len(latitude), len(longitude))
            as pcolormesh expects for 1-D coordinates — TODO confirm at call site.
        latitude: Latitude coordinates of ``data`` rows.
        longitude: Longitude coordinates of ``data`` columns.
        start_date: Period start, formatted into the title and file name.
        end_date: Period end, formatted into the title and file name.
        vmin: Lower colour-scale limit (default 0).
        vmax: Upper colour-scale limit (default 0.75, tuned for chlorophyll-a);
            values above it are shown with the colorbar's "max" extension.
        label: Colorbar label (default describes chlorophyll-a).
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        ax.set_aspect("equal")
        mesh = ax.pcolormesh(
            longitude, latitude, data, shading="auto", vmin=vmin, vmax=vmax
        )
        fig.colorbar(mesh, ax=ax, label=label, extend="max")
        ax.set_title(f"Region: {region_name}\n{start_date} to {end_date}")
        ax.set_xlabel("Longitude")
        ax.set_ylabel("Latitude")
        out_dir = f"./plots/{region_name}"
        os.makedirs(out_dir, exist_ok=True)
        fig.savefig(f"{out_dir}/{start_date}_{end_date}.png")
    finally:
        plt.close(fig)

In [None]:
def plot_data_mean(region_name: str, mean_values: dict[str, float], ylabel: str = "default ylabel", title: str = "default title"):
    """Plot a region's mean value per time period as a line and save it.

    The figure is written to ``./plots/<region_name>/<region_name>_mean.png``
    (the directory is created if needed).

    Args:
        region_name: Region identifier used in the output path.
        mean_values: Maps a period key whose first 8 characters are a
            ``YYYYMMDD`` date (e.g. ``"20210101_20210131"``) to that period's mean.
        ylabel: Y-axis label.
        title: Figure title.
    """
    # Sort chronologically so the line never zig-zags, whatever the dict order.
    points = sorted(
        (
            (datetime.strptime(period[:8], "%Y%m%d"), value)
            for period, value in mean_values.items()
        ),
        key=lambda point: point[0],
    )
    dates = [when for when, _ in points]
    values = [value for _, value in points]

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        # Axes.plot_date is deprecated; plot() handles datetime x-values directly.
        ax.plot(dates, values, "-")

        ax.set_xlabel("Time")
        ax.set_ylabel(ylabel)
        ax.set_title(title)

        # A MonthLocator puts one tick per month, which becomes hundreds of
        # overlapping labels for a multi-year series; let AutoDateLocator pick
        # a readable spacing while keeping the year-month label format.
        ax.xaxis.set_major_locator(mdates.AutoDateLocator())
        ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m"))
        ax.tick_params(axis="x", labelrotation=45)

        os.makedirs(f"./plots/{region_name}", exist_ok=True)
        fig.savefig(f"./plots/{region_name}/{region_name}_mean.png")
    finally:
        plt.close(fig)

In [None]:
def plot_time_series(
    data_series,
    region_name,
    ylabel="Mean Chlorophyll-a Concentration",
    title=None,
):
    """Plot the spatial mean of each record against its start date and save it.

    The figure is written to ``./plots/<region_name>/time_series.png`` (the
    directory is created if needed). Records with an empty ``data`` array are
    skipped entirely. Masked elements count as missing, NaNs are ignored in the
    mean, and an all-missing record is plotted as NaN (a gap).

    Args:
        data_series: Records with a ``"start_date"`` and a ``"data"`` array
            (plain or masked).
        region_name: Region identifier used in the default title and the path.
        ylabel: Y-axis label (defaults to the chlorophyll-a label).
        title: Figure title; defaults to
            ``"Time Series of Mean Chlorophyll-a Concentration for <region_name>"``.
    """
    if title is None:
        title = f"Time Series of Mean Chlorophyll-a Concentration for {region_name}"

    # Build dates and means together. Filtering only the means (as before) left
    # the two lists with different lengths whenever a crop was empty, and
    # plot() then raised.
    dates = []
    means = []
    for info in data_series:
        if info["data"].size == 0:
            continue
        # Masked elements -> NaN; np.array() alone would expose the raw values
        # hidden under the mask.
        values = np.ma.filled(np.ma.asarray(info["data"], dtype=float), np.nan)
        dates.append(info["start_date"])
        means.append(np.nanmean(values) if np.any(~np.isnan(values)) else np.nan)

    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        ax.plot(dates, means, marker="o", linestyle="-")
        ax.set_title(title)
        ax.set_xlabel("Date")
        ax.set_ylabel(ylabel)
        ax.grid(True)
        ax.tick_params(axis="x", labelrotation=45)
        fig.tight_layout()
        out_dir = f"./plots/{region_name}"
        os.makedirs(out_dir, exist_ok=True)
        fig.savefig(f"{out_dir}/time_series.png")
    finally:
        plt.close(fig)

In [None]:
def write_to_csv(data_series, region_name, quantity="chlorophyll"):
    """Write a year x month table of per-record spatial means for a region to CSV.

    Each record's mean goes into the cell for the year and month of its
    ``start_date``:
    - Masked elements are treated as missing and NaNs are ignored.
    - An all-missing record yields NaN.
    - Years whose row is entirely NaN are dropped.

    The table is written to
    ``./data_csv/<region_name>/<quantity>_monthly_means_<region_name>.csv``.

    Any exception is reported with ``print`` and not re-raised, so a loop over
    regions continues past one bad region.

    Args:
        data_series: Records with ``"start_date"``/``"end_date"`` datetimes
            and a ``"data"`` array (plain or masked).
        region_name: Region identifier used in the output path and messages.
        quantity: Name of the averaged quantity used in the file name
            (defaults to ``"chlorophyll"``, the historical file name).
    """
    # Fixed English column names. Looking them up by month number (instead of
    # strftime("%b")) keeps the lookup independent of the process locale;
    # a localized "%b" would not match and would add stray columns.
    months = [
        "Jan",
        "Feb",
        "Mar",
        "Apr",
        "May",
        "Jun",
        "Jul",
        "Aug",
        "Sep",
        "Oct",
        "Nov",
        "Dec",
    ]
    try:
        valid_data_series = [
            d
            for d in data_series
            if "start_date" in d and "end_date" in d and d["data"].size > 0
        ]

        if not valid_data_series:
            print(f"No valid data to process for {region_name}.")
            return

        # Years as index, months as columns, initially all NaN.
        years = sorted({d["start_date"].year for d in valid_data_series})
        df = pd.DataFrame(index=years, columns=months, dtype=float)

        for entry in valid_data_series:
            # Masked elements (e.g. netCDF fill values) -> NaN. np.array() on a
            # masked array silently keeps the raw values underneath the mask.
            values = np.ma.filled(np.ma.asarray(entry["data"], dtype=float), np.nan)
            mean_value = np.nanmean(values) if np.any(~np.isnan(values)) else np.nan
            start = entry["start_date"]
            df.at[start.year, months[start.month - 1]] = mean_value

        # Remove any rows that are entirely NaN
        df = df.dropna(how="all")

        out_dir = f"./data_csv/{region_name}"
        os.makedirs(out_dir, exist_ok=True)

        csv_path = f"{out_dir}/{quantity}_monthly_means_{region_name}.csv"
        df.to_csv(csv_path, index_label="year")
        print(f"CSV file created for {region_name} at {csv_path}")
    except Exception as e:
        print(f"Failed to process data for {region_name}. Error: {e}")

In [None]:
def get_files(dir_path: str, file_type: str):
    """List the entries of ``dir_path`` whose names end in ``.<file_type>``.

    Only the top level of ``dir_path`` is searched. The suffix match is
    case-sensitive, and results come in ``os.listdir`` order. Each returned
    path is ``dir_path`` joined with the entry name.
    """
    suffix = f".{file_type}"
    matches = []
    for name in os.listdir(dir_path):
        if name.endswith(suffix):
            matches.append(os.path.join(dir_path, name))
    return matches

In [None]:
## Load the GeoJSON files that describe the region(s) of interest.
# Every *.geojson file in LOCATIONS_DIR is treated as one region.
locs_files = get_files(LOCATIONS_DIR, "geojson")

# Save a context map (PNG) next to each GeoJSON file for a visual check.
geojson_context_figure(locs_files)

# One GeoDataFrame per region, in the same order as locs_files.
gdf_list = list(map(gpd.read_file, locs_files))

# Per-region bounding-box limits taken from total_bounds (minx, miny, maxx, maxy):
# index 0 -> x_min, 1 -> y_min, 2 -> x_max, 3 -> y_max.
x_min_list, y_min_list, x_max_list, y_max_list = (
    [gdf.total_bounds[k] for gdf in gdf_list] for k in range(4)
)

In [None]:
# All netCDF files in DATA_DIR, sorted by path (chronological when the names
# share a prefix and their date field is YYYYMMDD).
data_files = get_files(DATA_DIR, "nc")
data_files.sort()

# File names are expected to look like
#   AQUA_MODIS.20210101_20210131.L3m.MO.CHL.chlor_a.4km.nc
# i.e. the second dot-separated field is "<start>_<end>" in YYYYMMDD form.
# Both dates are parsed once per file. Files that do not follow the pattern
# are skipped and reported instead of aborting the whole cell.
data_files_info = []
skipped_files = []
for file in data_files:
    # basename() rather than split("/") so OS-specific separators also work.
    name_fields = os.path.basename(file).split(".")
    try:
        start_str, end_str = name_fields[1].split("_")[:2]
        start = datetime.strptime(start_str, "%Y%m%d")
        end = datetime.strptime(end_str, "%Y%m%d")
    except (IndexError, ValueError):
        skipped_files.append(file)
        continue
    data_files_info.append({"filename": file, "start_date": start, "end_date": end})

# Per-file start/end dates, aligned with data_files_info.
start_date = [info["start_date"] for info in data_files_info]
end_date = [info["end_date"] for info in data_files_info]

# specify the start and end dates for the desired date range
start_date_range = datetime(2000, 1, 1)
end_date_range = datetime.now()  # use e.g. datetime(2022, 12, 31) for a fixed, reproducible range

if skipped_files:
    print(
        f"Skipped {len(skipped_files)} file(s) without a <start>_<end> date in the name: {skipped_files}"
    )

# print the number of data files whose dates could be parsed
print(f"Total data files available: {len(data_files_info)}")

In [None]:
# Keep only files whose start AND end date both lie inside
# [start_date_range, end_date_range].
filtered_files_info = [
    info
    for info in data_files_info
    if all(
        start_date_range <= info[key] <= end_date_range
        for key in ("start_date", "end_date")
    )
]

print(f"Selected date range: {start_date_range} to {end_date_range}")
print(f"Total data files available within the date range: {len(filtered_files_info)}")

In [None]:
# Open every selected file. netCDF4 reads variable data only when a variable
# is indexed (done in the cropping step), so the datasets stay open until then.
successful_loads = 0
unsuccessful_loads = 0
data_list = []
failed_loads = []  # (filename, error message) for files that could not be opened
loaded_bytes = 0  # total on-disk size of the files opened successfully

for file_info in tqdm(filtered_files_info, desc="Loading global data"):
    try:
        file_bytes = os.path.getsize(file_info["filename"])
        data = nc.Dataset(file_info["filename"], "r")  # type: ignore
    except Exception as exc:
        # Best effort: a corrupt/unreadable file is recorded and skipped rather
        # than aborting the run. Unlike the former bare `except:`, this keeps the
        # reason and does not swallow KeyboardInterrupt.
        unsuccessful_loads += 1
        failed_loads.append((file_info["filename"], str(exc)))
        continue
    successful_loads += 1
    loaded_bytes += file_bytes
    data_list.append(
        {
            "data": data,
            "start_date": file_info["start_date"],
            "end_date": file_info["end_date"],
        }
    )

print(f"Successful loads: {successful_loads}")
print(f"Unsuccessful loads: {unsuccessful_loads}")
for failed_name, reason in failed_loads:
    print(f"  failed: {failed_name}: {reason}")
# sys.getsizeof(data_list) only measured the list object itself (a pointer per
# entry), not the data; report the on-disk size of the opened files instead.
print(
    f"Size of data files opened: {loaded_bytes} bytes (~{(loaded_bytes / 1024**3):.2f} GB on disk)"
)

In [None]:
# One initially empty list of cropped records per region, keyed by the
# GeoJSON file name with ".geojson" removed.
time_series_data = {
    os.path.basename(path).replace(".geojson", ""): [] for path in locs_files
}

In [None]:
# Crop every loaded global SST grid to each region's bounding box and collect
# the crops per region in `time_series_data`.

# Start from empty per-region lists so re-running this cell does not append
# duplicate records.
for region_records in time_series_data.values():
    region_records.clear()

# Lowest SST kept as valid. The previous threshold of 0 came from the
# chlorophyll version of this code; seawater freezes near -1.9 degC, so
# slightly negative SST is a real (polar) measurement.
# NOTE(review): assumes the "sst" variable is in degC — confirm via its `units`.
SST_MIN_VALID = -2.0

for data_info in tqdm(data_list, desc="Cropping data to region of interest"):
    dataset = data_info["data"]
    latitude = dataset["lat"][:]
    longitude = dataset["lon"][:]
    sst_var = dataset["sst"]
    # netCDF4 may hand back a masked array. Turn it into a plain float array
    # with NaN in masked cells so the gaps survive downstream nan-aware stats.
    sst = np.ma.filled(np.ma.asarray(sst_var[:], dtype=float), np.nan)
    # Unmasked cells can still hold the raw fill value (e.g. if auto-masking is
    # off). Not every file is guaranteed to define _FillValue.
    fill_value = getattr(sst_var, "_FillValue", None)
    if fill_value is not None:
        sst[sst == fill_value] = np.nan
    sst[sst < SST_MIN_VALID] = np.nan

    # One crop per region; assumes sst is indexed (lat, lon) — TODO confirm.
    for locs_file, x_min, y_min, x_max, y_max in zip(
        locs_files, x_min_list, y_min_list, x_max_list, y_max_list
    ):
        lat_mask = np.asarray((latitude >= y_min) & (latitude <= y_max))
        lon_mask = np.asarray((longitude >= x_min) & (longitude <= x_max))
        region_name = os.path.basename(locs_file).replace(".geojson", "")
        time_series_data[region_name].append(
            {
                # np.ix_ applies both 1-D masks at once (a copy, not a view).
                "data": sst[np.ix_(lat_mask, lon_mask)],
                "latitude": latitude[lat_mask],
                "longitude": longitude[lon_mask],
                "start_date": data_info["start_date"],
                "end_date": data_info["end_date"],
                "region_name": region_name,
            }
        )

In [None]:
# Write one CSV of monthly means per region (plots directory prepared as well).
for region_name, data_series in time_series_data.items():
    os.makedirs(f"./plots/{region_name}", exist_ok=True)
    write_to_csv(data_series, region_name)

# The cropping step already read the needed values into memory, so the global
# datasets can be closed. `del` alone does not close them and would leave the
# file handles open.
for loaded in data_list:
    loaded["data"].close()
del data_list
print("Data cropped and monthly means written to CSV.")

In [None]:
# Per region: report the share of missing pixels for each period, then plot the
# spatial mean over time.
for region, records in time_series_data.items():
    mean_values = {}
    for data_info in records:
        time_period = f"{data_info['start_date'].strftime('%Y%m%d')}_{data_info['end_date'].strftime('%Y%m%d')}"
        # Masked elements count as missing, like NaN.
        values = np.ma.filled(np.ma.asarray(data_info["data"], dtype=float), np.nan)
        nans = int(np.isnan(values).sum())
        total = values.size
        # An empty crop (region outside the grid) would otherwise divide by zero.
        nan_pct = nans / total * 100 if total else float("nan")
        print(
            f"Region: {region}, time period: {time_period}, NaNs: {nans}, Total: {total}, % NaNs: {nan_pct:.2f}%"
        )

        # nanmean ignores cloud/land pixels. np.mean would return NaN as soon as
        # a single pixel is missing, leaving the whole series blank.
        mean_values[time_period] = (
            float(np.nanmean(values)) if nans < total else float("nan")
        )

    os.makedirs(f"./plots/{region}", exist_ok=True)
    plot_data_mean(
        region_name=region,
        mean_values=mean_values,
        ylabel="temp",
        title="sea surface temp over time",
    )