In [1]:
import gc
import os
import pathlib
import subprocess
from functools import partial
from os.path import join

import cartopy.crs as ccrs
import cartopy.feature as cfeature
import joblib
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import pandas as pd
import xarray as xr
from joblib import Parallel, delayed


In [2]:
def plot_forecast_and_diff(ds, true_ds, init_ds, forecast_dir, forecast_hour=-1, diff_max=20):
    """Save one comparison figure per channel for a single forecast lead time.

    Each figure has four map panels -- initial conditions, prediction,
    observations, and prediction minus observations -- plus one colorbar for the
    shared field scale and one for the difference. Figures are written to
    ``<forecast_dir>/gifs/C<channel>/C<channel>_FH<forecast_hour>.png``.

    Parameters
    ----------
    ds : xarray.Dataset
        Forecast containing ``BT_or_R`` with ``channel``, ``t``, ``latitude`` and
        ``longitude`` coordinates. The prediction is cropped with
        ``isel(latitude=slice(11, -10), longitude=slice(19, -18))`` before
        plotting (presumably to drop padding -- TODO confirm).
    true_ds : xarray.Dataset
        Observations matched to the forecast times. Its ``t`` values are
        replaced by ``ds.t`` on a local copy so that ``pred - true`` aligns;
        the caller's dataset is not modified.
    init_ds : xarray.Dataset
        Observations at the initialization time (first panel and its title).
    forecast_dir : str
        Directory under which ``gifs/C<channel>/`` is created.
    forecast_hour : int, optional
        Lead time shown in the suptitle and used in the file name.
    diff_max : float, optional
        Symmetric colour limit of the difference panel for channels > 4.
        Channels <= 4 use a 0-0.2 field scale and a +/-0.1 difference scale.
    """
    # Kept in a variable so the f-strings below do not need nested double
    # quotes (only valid from Python 3.12 onward).
    time_fmt = "%Y-%m-%dT%H:%M:%S"
    proj = ccrs.PlateCarree()

    # Local copy whose times equal the forecast times: the nearest-match
    # observation times would otherwise not align in ``pred - true``.
    true_ds = true_ds.assign_coords(t=ds.t.values)

    for channel in ds.channel:
        pred = ds.sel(channel=[channel]).BT_or_R.isel(latitude=slice(11, -10), longitude=slice(19, -18))
        true = true_ds.sel(channel=[channel]).BT_or_R
        init = init_ds.sel(channel=[channel]).BT_or_R

        # Channels > 4 use a 200-300 scale (presumably brightness temperature, K);
        # the others (C04 in this run) use 0-0.2 (presumably reflectance).
        if channel > 4:
            vmin, vmax, diff_lim = 200, 300, diff_max
        else:
            vmin, vmax, diff_lim = 0.0, 0.2, 0.1

        figname = f"C{channel.values:02}_FH{forecast_hour:02}.png"
        save_dir = join(forecast_dir, f"gifs/C{channel.values:02}")

        fig = plt.figure(figsize=(15, 5))
        try:
            gs = gridspec.GridSpec(
                1, 6,
                width_ratios=[1, 1, 1, 0.1, 1, 0.1],
                wspace=0.1,
                figure=fig,
            )
            ax0 = fig.add_subplot(gs[0, 0], projection=proj)
            ax1 = fig.add_subplot(gs[0, 1], projection=proj)
            ax2 = fig.add_subplot(gs[0, 2], projection=proj)
            cax = fig.add_subplot(gs[0, 3])   # holder for the field colorbar
            ax3 = fig.add_subplot(gs[0, 4], projection=proj)
            cax2 = fig.add_subplot(gs[0, 5])  # holder for the difference colorbar

            # Initial conditions, prediction and observations share one colour scale.
            field_kw = dict(transform=proj, cmap="Spectral_r", vmin=vmin, vmax=vmax, add_colorbar=False)
            init.plot(ax=ax0, **field_kw)
            pred.plot(ax=ax1, **field_kw)
            im_obs = true.plot(ax=ax2, **field_kw)
            im_diff = (pred - true).plot(
                ax=ax3, transform=proj, cmap="seismic", vmin=-diff_lim, vmax=diff_lim, add_colorbar=False
            )

            init_time = pd.Timestamp(init.t.values).strftime(time_fmt)
            ax0.set_title(f"Initial Conditions {init_time}")
            ax1.set_title("Prediction")
            ax2.set_title("Observations")
            ax3.set_title("Prediction - Observations")

            ax0.set_yticks(list(range(-50, 51, 25)))
            for ax in (ax0, ax1, ax2, ax3):
                ax.add_feature(cfeature.COASTLINE)
                ax.set_xticks(list(range(-120, -29, 30)))

            # The colorbar holder axes are hidden; the colorbars take their space.
            for ax in (cax, cax2):
                ax.axis("off")
            fig.colorbar(im_obs, ax=cax, orientation="vertical", fraction=1, shrink=0.75)
            fig.colorbar(im_diff, ax=cax2, orientation="vertical", fraction=1, shrink=0.75)

            valid_time = pd.Timestamp(pred.t.values[0]).strftime(time_fmt)
            fig.suptitle(f"Channel {channel.values}, FH {forecast_hour:02}\n{valid_time}")

            os.makedirs(save_dir, exist_ok=True)
            fig.savefig(join(save_dir, figname), format="png")
            print(f"saved {figname}")
        finally:
            # Always release the figure, even if plotting or saving failed.
            plt.close(fig)

In [3]:
def plot_forecast_and_diff_combined(ds, true_ds, init_ds, forecast_dir, forecast_hour=-1, diff_max=20):
    """Save one figure with a comparison row per channel for a single lead time.

    Each row is its own subfigure containing four map panels (initial
    conditions, prediction, observations, prediction minus observations), two
    colorbars and a per-channel suptitle. The figure is written to
    ``<forecast_dir>/gifs/combined/combined_FH<forecast_hour>.png``.

    Parameters
    ----------
    ds : xarray.Dataset
        Forecast containing ``BT_or_R`` with ``channel``, ``t``, ``latitude`` and
        ``longitude`` coordinates. The prediction is cropped with
        ``isel(latitude=slice(11, -10), longitude=slice(19, -18))`` before
        plotting (presumably to drop padding -- TODO confirm).
    true_ds : xarray.Dataset
        Observations matched to the forecast times. Its ``t`` values are
        replaced by ``ds.t`` on a local copy so that ``pred - true`` aligns;
        the caller's dataset is not modified.
    init_ds : xarray.Dataset
        Observations at the initialization time (first panel of each row).
    forecast_dir : str
        Directory under which ``gifs/combined/`` is created.
    forecast_hour : int, optional
        Lead time shown in the row titles and used in the file name.
    diff_max : float, optional
        Symmetric colour limit of the difference panel for channels > 4.
        Channels <= 4 use a 0-0.2 field scale and a +/-0.1 difference scale.
    """
    # Kept in a variable so the f-strings below do not need nested double
    # quotes (only valid from Python 3.12 onward).
    time_fmt = "%Y-%m-%dT%H:%M:%S"
    proj = ccrs.PlateCarree()

    # Local copy whose times equal the forecast times: the nearest-match
    # observation times would otherwise not align in ``pred - true``.
    true_ds = true_ds.assign_coords(t=ds.t.values)

    num_channels = len(ds.channel)
    if num_channels == 0:
        # Nothing to draw, and a zero-row grid would raise.
        return

    mainfig = plt.figure(figsize=(18, 5 * num_channels), layout="constrained")
    try:
        # One subfigure per channel. The axes and colorbars must be children of
        # the subfigure (not of ``mainfig``) so the row suptitle and constrained
        # layout treat each row as a single unit.
        row_figs = mainfig.subfigures(num_channels, 1, squeeze=False)[:, 0]

        for subfig, channel in zip(row_figs, ds.channel):
            pred = ds.sel(channel=[channel]).BT_or_R.isel(latitude=slice(11, -10), longitude=slice(19, -18))
            true = true_ds.sel(channel=[channel]).BT_or_R
            init = init_ds.sel(channel=[channel]).BT_or_R

            # Channels > 4 use a 200-300 scale (presumably brightness temperature, K);
            # the others (C04 in this run) use 0-0.2 (presumably reflectance).
            if channel > 4:
                vmin, vmax, diff_lim = 200, 300, diff_max
            else:
                vmin, vmax, diff_lim = 0.0, 0.2, 0.1

            gs = subfig.add_gridspec(1, 6, width_ratios=[1, 1, 1, 0.1, 1, 0.1], wspace=0.1)
            ax0 = subfig.add_subplot(gs[0, 0], projection=proj)
            ax1 = subfig.add_subplot(gs[0, 1], projection=proj)
            ax2 = subfig.add_subplot(gs[0, 2], projection=proj)
            cax = subfig.add_subplot(gs[0, 3])   # holder for the field colorbar
            ax3 = subfig.add_subplot(gs[0, 4], projection=proj)
            cax2 = subfig.add_subplot(gs[0, 5])  # holder for the difference colorbar

            # Initial conditions, prediction and observations share one colour scale.
            field_kw = dict(transform=proj, cmap="Spectral_r", vmin=vmin, vmax=vmax, add_colorbar=False)
            init.plot(ax=ax0, **field_kw)
            pred.plot(ax=ax1, **field_kw)
            im_obs = true.plot(ax=ax2, **field_kw)
            im_diff = (pred - true).plot(
                ax=ax3, transform=proj, cmap="seismic", vmin=-diff_lim, vmax=diff_lim, add_colorbar=False
            )

            init_time = pd.Timestamp(init.t.values).strftime(time_fmt)
            ax0.set_title(f"Initial Conditions {init_time}")
            ax1.set_title("Prediction")
            ax2.set_title("Observations")
            ax3.set_title("Prediction - Observations")

            ax0.set_yticks(list(range(-50, 51, 25)))
            for ax in (ax0, ax1, ax2, ax3):
                ax.add_feature(cfeature.COASTLINE)
                ax.set_xticks(list(range(-120, -29, 30)))

            # The colorbar holder axes are hidden; the colorbars take their space.
            for ax in (cax, cax2):
                ax.axis("off")
            subfig.colorbar(im_obs, ax=cax, orientation="vertical", fraction=1, shrink=0.75)
            subfig.colorbar(im_diff, ax=cax2, orientation="vertical", fraction=1, shrink=0.75)

            valid_time = pd.Timestamp(pred.t.values[0]).strftime(time_fmt)
            subfig.suptitle(f"Channel {channel.values}, FH {forecast_hour:02}\n{valid_time}")

        figname = f"combined_FH{forecast_hour:02}.png"
        save_dir = join(forecast_dir, "gifs/combined")
        os.makedirs(save_dir, exist_ok=True)
        mainfig.savefig(join(save_dir, figname), dpi=mainfig.dpi)
        print(f"saved {figname}")
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(mainfig)

In [4]:
#parameters
forecast_dir = "/glade/derecho/scratch/dkimpara/goes_10km_train/simple_wxformer_lognormal/forecasts/2023-12-31T23:55:05"
num_cpus = 3
test = 0

In [5]:
# Observation store used for both the "truth" panels and the initial conditions.
zarr_path = "/glade/derecho/scratch/dkimpara/goes-cloud-dataset/goes_10km.zarr"

# Forecast files in name order. Only NetCDF files are considered, so stray files
# (logs, temporaries) in the run directory cannot shift the FH numbering below.
files = sorted(
    name for name in os.listdir(forecast_dir)
    if name.endswith(".nc") and os.path.isfile(join(forecast_dir, name))
)

if test:
    files = files[:2]


def make_gif(forecast_dir, file_and_index, zarr_path=zarr_path):
    """Render the per-channel and combined comparison PNGs for one forecast file.

    Parameters
    ----------
    forecast_dir : str
        Run directory. Its name (or its parent's name, if a file path is
        given) is parsed as the initialization time.
    file_and_index : tuple[int, str]
        ``(k, filename)`` as produced by ``enumerate(files)``. The forecast hour
        is taken as ``k + 1``, i.e. the first sorted file is FH01.
    zarr_path : str, optional
        Path of the observation Zarr store.
    """
    k, file = file_and_index
    fh = k + 1

    file = join(forecast_dir, file)
    print(f"processing {file}")

    run_path = pathlib.Path(forecast_dir)
    time_str = run_path.name if run_path.is_dir() else run_path.parent.name
    init_time = pd.Timestamp(time_str)

    # Context managers close the underlying file handles once plotting is done.
    # joblib may reuse worker processes, so unclosed handles would accumulate.
    with xr.open_dataset(file, engine="netcdf4") as ds:
        with xr.open_dataset(zarr_path, engine="zarr", consolidated=False) as zarr_raw:
            zarr_ds = (
                zarr_raw
                .drop_duplicates(dim="t")
                .sortby("t")
                .transpose("t", "channel", "latitude", "longitude")
            )
            true_ds = zarr_ds.sel(t=ds.t, method="nearest")
            init_ds = zarr_ds.sel(t=init_time, method="nearest")

            plot_forecast_and_diff(ds, true_ds, init_ds, forecast_dir, fh)
            plot_forecast_and_diff_combined(ds, true_ds, init_ds, forecast_dir, fh)

    gc.collect()


f = partial(make_gif, forecast_dir)

# joblib rejects n_jobs=0, so keep at least one worker when num_cpus == 1.
result = Parallel(n_jobs=max(1, num_cpus - 1))(
    delayed(f)(file_and_index) for file_and_index in enumerate(files)
)

    

processing /glade/derecho/scratch/dkimpara/goes_10km_train/simple_wxformer_lognormal/forecasts/2023-12-31T23:55:05/2024-01-01T01:55:05.nc
saved C04_FH02.png
saved C07_FH02.png
saved C08_FH02.png
saved C09_FH02.png
saved C10_FH02.png
saved C13_FH02.png
saved combined_FH02.png
processing /glade/derecho/scratch/dkimpara/goes_10km_train/simple_wxformer_lognormal/forecasts/2023-12-31T23:55:05/2024-01-01T03:55:05.nc
saved C04_FH04.png
saved C07_FH04.png
saved C08_FH04.png
saved C09_FH04.png
saved C10_FH04.png
saved C13_FH04.png
saved combined_FH04.png




processing /glade/derecho/scratch/dkimpara/goes_10km_train/simple_wxformer_lognormal/forecasts/2023-12-31T23:55:05/2024-01-01T05:55:06.nc
saved C04_FH06.png
saved C07_FH06.png
saved C08_FH06.png
saved C09_FH06.png
saved C10_FH06.png
saved C13_FH06.png
saved combined_FH06.png
processing /glade/derecho/scratch/dkimpara/goes_10km_train/simple_wxformer_lognormal/forecasts/2023-12-31T23:55:05/2024-01-01T07:55:06.nc
saved C04_FH08.png
saved C07_FH08.png
saved C08_FH08.png
saved C09_FH08.png
saved C10_FH08.png
saved C13_FH08.png
saved combined_FH08.png
processing /glade/derecho/scratch/dkimpara/goes_10km_train/simple_wxformer_lognormal/forecasts/2023-12-31T23:55:05/2024-01-01T09:55:06.nc
saved C04_FH10.png
saved C07_FH10.png
saved C08_FH10.png
saved C09_FH10.png
saved C10_FH10.png
saved C13_FH10.png
saved combined_FH10.png
processing /glade/derecho/scratch/dkimpara/goes_10km_train/simple_wxformer_lognormal/forecasts/2023-12-31T23:55:05/2024-01-01T00:55:05.nc
saved C04_FH01.png
saved C07_FH01.

In [6]:
# Assemble one animated GIF per channel from the PNG frames written above.
# Channel folders (gifs/C<digits>) are discovered on disk rather than hardcoded.
# subprocess.run with check=True raises if ImageMagick fails; a `!magick` line
# would only print the error and carry on. The argument list also avoids shell
# quoting problems with the interpolated paths.
gif_root = pathlib.Path(forecast_dir) / "gifs"
channel_dirs = sorted(
    p for p in gif_root.glob("C*")
    if p.is_dir() and p.name[1:].isdigit()
)
for channel_dir in channel_dirs:
    frames = sorted(str(p) for p in channel_dir.glob("*.png"))
    if not frames:
        print(f"no PNG frames in {channel_dir}, skipping")
        continue
    # ImageMagick: -delay is in 1/100 s ticks; -loop 0 would repeat forever.
    subprocess.run(
        ["magick", "-delay", "20", "-loop", "1", *frames, str(gif_root / f"{channel_dir.name}.gif")],
        check=True,
    )

In [7]:
# Assemble the multi-channel frames into one animated GIF.
# subprocess.run with check=True raises if ImageMagick fails, and the argument
# list avoids shell quoting problems with the interpolated paths.
combined_dir = pathlib.Path(forecast_dir) / "gifs" / "combined"
combined_frames = sorted(str(p) for p in combined_dir.glob("*.png"))
if combined_frames:
    # ImageMagick: -delay is in 1/100 s ticks; -loop 0 would repeat forever.
    subprocess.run(
        ["magick", "-delay", "30", "-loop", "1", *combined_frames, str(combined_dir.parent / "combined.gif")],
        check=True,
    )
else:
    print(f"no PNG frames in {combined_dir}, skipping")