This is where the red, green, and blue channels are extracted and stacked into an RGB .png image, which is saved alongside the corresponding climate data, DEM maps, and land-cover maps.

In [None]:
import os
import glob
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from pathlib import Path
import gc
import dask.array as da

# Dataset splits to process. The folder name of each entry becomes the name of
# its output sub-directory under `output_base_dir`, so entries must be unique:
# a duplicate entry just re-renders every image and overwrites the same
# metadata.csv.
# NOTE(review): 'val_chopped' used to be listed twice here; if a different
# split was intended for the second slot, add it explicitly.
input_dirs = [
    r'C:\Workdir\Develop\greenearthnet\val_chopped',
    r'C:\Workdir\Develop\greenearthnet\ood-st_chopped',
    r'C:\Workdir\Develop\greenearthnet\ood-t_chopped',
    r'C:\Workdir\Develop\greenearthnet\ood-s_chopped',
    r'C:\Workdir\Develop\greenearthnet\train',
]
output_base_dir = r'C:\TjallingData\greenearthnet_additional'

# Land-cover class code -> label. The codes/labels match the ESA WorldCover
# legend (presumably the source of the `esawc_lc` layer - confirm). Each code
# is drawn with the colour at the same position in `colors`.
land_cover_classes = {
    10: 'Tree cover',
    20: 'Shrubland',
    30: 'Grassland',
    40: 'Cropland',
    50: 'Built-up',
    60: 'Bare / sparse vegetation',
    70: 'Snow and Ice',
    80: 'Permanent water bodies',
    90: 'Herbaceous wetland',
    95: 'Mangroves',
    100: 'Moss and lichen',
}

colors = [
    '#006400',  # 10  Tree cover
    '#8db360',  # 20  Shrubland
    '#bae4b3',  # 30  Grassland
    '#ffffcc',  # 40  Cropland
    '#ffeda0',  # 50  Built-up
    '#f2f0f7',  # 60  Bare / sparse vegetation
    '#f7fcf0',  # 70  Snow and Ice
    '#081d58',  # 80  Permanent water bodies
    '#c7e9b4',  # 90  Herbaceous wetland
    '#7fcdbb',  # 95  Mangroves
    '#00441b',  # 100 Moss and lichen
]
cmap = mcolors.ListedColormap(colors)

# Bin edges: every class code plus one closing edge past the last code, giving
# exactly one bin per class: [10, 20), [20, 30), ..., [95, 100), [100, 101).
bounds = [*land_cover_classes, max(land_cover_classes) + 1]
norm = mcolors.BoundaryNorm(bounds, cmap.N)

# E-OBS climate variables recorded in the metadata CSV. For each retained time
# step, process_file() writes the mean of each variable into a column named by
# its label. Code and label are kept together in one mapping so the two lists
# derived below can never drift out of sync (zip() would silently drop
# unmatched entries).
eobs_variable_labels = {
    'eobs_fg': 'Wind Speed',
    'eobs_hu': 'Relative Humidity',
    'eobs_rr': 'Rainfall',
    'eobs_pp': 'Sea-level Pressure',
    'eobs_qq': 'Shortwave Downwelling Radiation',
    'eobs_tg': 'Temperature Avg',
    'eobs_tn': 'Temperature Min',
    'eobs_tx': 'Temperature Max',
}
variables = list(eobs_variable_labels)
variable_names = list(eobs_variable_labels.values())

def _at_time(data_array, i):
    """Return ``data_array`` at time index ``i``; unchanged if it has no ``time`` dim.

    Static layers (land cover, DEMs) can gain a ``time`` dimension when
    ``Dataset.where`` broadcasts them against the time-varying cloud mask.
    Without that mask they may be purely spatial, and ``isel(time=...)``
    would raise.
    """
    if 'time' in data_array.dims:
        return data_array.isel(time=i)
    return data_array


def _normalize_band(band):
    """Min-max scale ``band`` to [0, 1]; a constant (or all-NaN) band maps to zeros."""
    band = np.asarray(band, dtype=float)
    lo = np.nanmin(band)
    hi = np.nanmax(band)
    span = hi - lo
    # Guard the division: a flat band would otherwise yield 0/0 = NaN pixels.
    if not np.isfinite(span) or span == 0:
        return np.zeros_like(band)
    return (band - lo) / span


def _save_image(image, path, **imshow_kwargs):
    """Render ``image`` without axes to ``path`` and always close the figure."""
    fig, ax = plt.subplots(figsize=(5, 4))
    try:
        ax.imshow(image, **imshow_kwargs)
        ax.axis('off')
        fig.savefig(path, bbox_inches='tight', pad_inches=0, dpi=300)
    finally:
        # Close even if rendering/saving fails, so figures do not accumulate
        # across the many files processed by the main loop.
        plt.close(fig)


def process_file(nc_file, output_dir, variables, variable_names, subfolder_prefix):
    """Render per-time-step images for one minicube and collect climate averages.

    Opens ``nc_file`` and keeps every fifth time step, starting at index 4.
    If ``s2_dlmask`` is present, every pixel where it is non-zero is masked
    out. Only time steps whose ``s2_B02``/``s2_B03``/``s2_B04``/``s2_B8A``
    bands have no missing pixels are kept. For each remaining time step it
    writes into ``output_dir``:

    * ``RGB/<prefix>.png`` - per-channel min-max normalised B04/B03/B02 stack,
    * ``LandcoverMaps/<prefix>.png`` - ``esawc_lc`` drawn with ``cmap``/``norm``,
    * ``DEM/<prefix>_<name>.png`` for ``nasa_dem``, ``cop_dem`` and ``alos_dem``,

    where ``<prefix>`` is ``{subfolder_prefix}_{file stem}_{YYYY-MM-DD}``.

    Args:
        nc_file: path to the minicube NetCDF file.
        output_dir: split output directory; its ``RGB``, ``LandcoverMaps``
            and ``DEM`` sub-directories must already exist.
        variables: dataset variable names to average at each time step.
        variable_names: CSV column name for each entry of ``variables``.
        subfolder_prefix: string prepended to every output file name.

    Returns:
        A DataFrame with a ``file_name`` column (the RGB image name) followed
        by one column per ``variable_names`` entry. Returns ``None`` when the
        file is skipped: a band is missing, no time step is usable, or any
        error occurs. The reason is printed in each case.
    """
    try:
        with xr.open_dataset(nc_file) as minicube:
            # NOTE(review): the stride of 5 presumably selects the days with a
            # satellite acquisition - confirm against the dataset layout.
            minicube_with_data = minicube.isel(time=slice(4, None, 5))

            # Check if 's2_dlmask' exists
            if 's2_dlmask' in minicube_with_data:
                # Non-zero mask pixels become NaN in every variable; drop=True
                # also removes coordinate labels that are masked everywhere.
                minicube_clear = minicube_with_data.where(minicube_with_data.s2_dlmask == 0, drop=True)
            else:
                print(f"'s2_dlmask' not found in {nc_file}. Skipping mask filtering.")
                minicube_clear = minicube_with_data

            # Every band must exist. A time step is dropped if any of its band
            # pixels is NaN, which includes pixels masked out above.
            for band in ('s2_B02', 's2_B03', 's2_B04', 's2_B8A'):
                if band not in minicube_clear:
                    print(f"Band {band} not found in {nc_file}. Skipping file.")
                    return None
                minicube_clear = minicube_clear.dropna(dim='time', subset=[band], how='any')

            if minicube_clear.sizes['time'] == 0:
                print(f"No valid data found in {nc_file}")
                return None

            time_dates = pd.to_datetime(minicube_clear.time.values)
            base_name = os.path.splitext(os.path.basename(nc_file))[0]
            rgb_dir = os.path.join(output_dir, 'RGB')
            landcover_dir = os.path.join(output_dir, 'LandcoverMaps')
            dem_dir = os.path.join(output_dir, 'DEM')

            rows = []
            for i, timestamp in enumerate(time_dates):
                date_str = timestamp.strftime('%Y-%m-%d')
                filename_prefix = f"{subfolder_prefix}_{base_name}_{date_str}"

                # RGB composite: red = B04, green = B03, blue = B02, each
                # stretched to [0, 1] independently.
                rgb_normalized = np.dstack([
                    _normalize_band(_at_time(minicube_clear[band], i).values)
                    for band in ('s2_B04', 's2_B03', 's2_B02')
                ])
                _save_image(rgb_normalized, os.path.join(rgb_dir, f'{filename_prefix}.png'))

                # Land cover is categorical: nearest-neighbour resampling keeps
                # class colours from being blended at the edges.
                landcover = _at_time(minicube_clear['esawc_lc'], i).values
                _save_image(
                    landcover,
                    os.path.join(landcover_dir, f'{filename_prefix}.png'),
                    cmap=cmap,
                    norm=norm,
                    interpolation='nearest',
                )

                for elev_name in ('nasa_dem', 'cop_dem', 'alos_dem'):
                    elev_data = _at_time(minicube_clear[elev_name], i).values
                    _save_image(
                        elev_data,
                        os.path.join(dem_dir, f'{filename_prefix}_{elev_name}.png'),
                        cmap='terrain',
                    )

                # Metadata row: image name followed by the mean of each variable
                # at this time step.
                row = {'file_name': f'{filename_prefix}.png'}
                for var, name in zip(variables, variable_names):
                    row[name] = float(_at_time(minicube_clear[var], i).mean().values)
                rows.append(row)

                print(f'Processed {filename_prefix}')

            # At least one time step exists here, so `rows` is never empty.
            return pd.DataFrame(rows)
    except Exception as e:
        # Best-effort batch processing: report the failure and move on.
        print(f"An error occurred while processing {nc_file}: {e}")
        return None

# Main processing loop: for every split, render the per-time-step images and
# write one combined metadata.csv into the split's output directory.
batch_size = 10  # .nc files whose metadata is buffered before writing a batch CSV

for input_dir in input_dirs:
    map_name = os.path.basename(input_dir)
    output_dir = os.path.join(output_base_dir, map_name)

    for image_subdir in ('RGB', 'LandcoverMaps', 'DEM'):
        os.makedirs(os.path.join(output_dir, image_subdir), exist_ok=True)

    # Sort so that processing order, batch numbering and the row order of
    # metadata.csv are reproducible (scandir/glob order is OS-dependent).
    subfolders = sorted(f.path for f in os.scandir(input_dir) if f.is_dir())

    # Batch CSVs written during THIS run, in write order. Combining only these
    # (instead of globbing) ignores stale batch files left by an interrupted run.
    batch_paths = []

    for subfolder in subfolders:
        subfolder_name = os.path.basename(subfolder)
        nc_files = sorted(glob.glob(os.path.join(subfolder, '*.nc')))

        for start in range(0, len(nc_files), batch_size):
            batch_files = nc_files[start:start + batch_size]
            df_list = []

            for nc_file in batch_files:
                df_temp = process_file(nc_file, output_dir, variables, variable_names, subfolder_name)
                if df_temp is not None:
                    df_list.append(df_temp)

                # Clear memory
                gc.collect()

            # Concatenate and save batch results
            if df_list:
                batch_path = os.path.join(output_dir, f'metadata_batch_{subfolder_name}_{start // batch_size}.csv')
                pd.concat(df_list, ignore_index=True).to_csv(batch_path, index=False)
                batch_paths.append(batch_path)

            del df_list
            gc.collect()

        print(f'Completed processing for {map_name}/{subfolder_name}')

    # Combine this run's batch results for the current input_dir, then clean up.
    if batch_paths:
        df = pd.concat([pd.read_csv(f) for f in batch_paths], ignore_index=True)
        df.to_csv(os.path.join(output_dir, 'metadata.csv'), index=False)

        for f in batch_paths:
            os.remove(f)
    else:
        print(f"No valid metadata batches found for {map_name}")

print('All processing complete.')