In [1]:
import sys, os 
sys.path.append(os.path.dirname(os.getcwd()))

In [2]:
from MontePython import monte_python

import re
from datetime import datetime
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import matplotlib
from glob import glob
from os.path import join, basename
from skimage.measure import regionprops, regionprops_table
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
plt.rcParams["animation.html"] = "jshtml"

In [3]:
from WoF_post.wofs.plotting.wofs_plotter import WoFSPlotter

In [4]:
# NWS reflectivity colour table: RGB triples in [0, 1], courtesy of the MetPy library.
# cNN is the colour for reflectivity values starting at NN dBZ.
c5 = (0.0, 0.9254901960784314, 0.9254901960784314)
c10 = (0.00392156862745098, 0.6274509803921569, 0.9647058823529412)
c15 = (0.0, 0.0, 0.9647058823529412)
c20 = (0.0, 1.0, 0.0)
c25 = (0.0, 0.7843137254901961, 0.0)
c30 = (0.0, 0.5647058823529412, 0.0)
c35 = (1.0, 1.0, 0.0)
c40 = (0.9058823529411765, 0.7529411764705882, 0.0)
c45 = (1.0, 0.5647058823529412, 0.0)
c50 = (1.0, 0.0, 0.0)
c55 = (0.8392156862745098, 0.0, 0.0)
c60 = (0.7529411764705882, 0.0, 0.0)
c65 = (1.0, 0.0, 1.0)
c70 = (0.6, 0.3333333333333333, 0.788235294117647)
c75 = (0.0, 0.0, 0.0)

# The 12 contour levels 20, 25, ..., 75 dBZ bound 11 filled intervals, so the
# colormap takes exactly the 11 colours c20 through c70 (c20 fills 20-25 dBZ, ...).
# c5-c15 and c75 are defined for completeness but unused here.
dz_levels_nws = np.arange(20.0, 80.0, 5.0)
nws_dz_cmap = matplotlib.colors.ListedColormap(
    [c20, c25, c30, c35, c40, c45, c50, c55, c60, c65, c70]
)

## Testing the concept of tracking a storm in MRMS, finding its counterpart in WoFS, and continuing the track
1. Track the MRMS storms up to the forecast initialization time.
2. Match MRMS storms to WoFS storms with `ObjectMatcher`, then re-label every WoFS track whose storm at t=0 matches an MRMS storm with that MRMS storm's label.
3. Track the WoFS storms forward in time.


In [5]:
def data_loader(files, var_name, bdry_thresh=43.0, qc_params=None, grid_shape=(300, 300)):
    """Load MRMS/WoFS composite reflectivity and label storm objects in each file.

    Each file is opened with xarray and ``var_name`` is read. Storms are then
    labelled with monte_python's single-threshold method and passed through
    ``QualityControler.quality_control``.

    Parameters
    ----------
    files : sequence of str
        Paths to the data files, in time order.
    var_name : str
        Reflectivity variable to read (this notebook uses 'dz_cress' for MRMS
        and 'comp_dz' for WoFS).
    bdry_thresh : float, optional
        Threshold passed to the single-threshold labeller as
        ``params={'bdry_thresh': ...}``. Default 43.0.
    qc_params : list of (str, value) tuples, optional
        Settings passed to ``quality_control``. Defaults to
        ``[("min_area", 15), ("merge_thresh", 3), ("max_thresh", (44.93, 100))]``.
    grid_shape : tuple of int, optional
        (ny, nx) of the 2-D field in every file. Default (300, 300).

    Returns
    -------
    all_labels : np.ndarray, int16, shape (len(files), ny, nx)
        Quality-controlled storm labels per file.
    z_set : np.ndarray, float64, shape (len(files), ny, nx)
        Reflectivity field read from each file.
    """
    if qc_params is None:
        qc_params = [
            ("min_area", 15),
            ("merge_thresh", 3),
            ("max_thresh", (44.93, 100)),
        ]
    qcer = monte_python.QualityControler()
    all_labels = np.zeros((len(files), *grid_shape), dtype=np.int16)
    z_set = np.zeros((len(files), *grid_shape))

    for i, f in enumerate(files):
        # The context manager closes the file even if reading the variable fails;
        # `.values` loads the data into memory before the file is closed.
        with xr.open_dataset(f, decode_times=False) as ds:
            dbz = ds[var_name].values.squeeze()
        # If a leading dimension survives squeeze(), keep only its first entry.
        if np.ndim(dbz) == 3:
            dbz = dbz[0, :, :]

        storm_labels, object_props = monte_python.label(
            input_data=dbz,
            method='single_threshold',
            return_object_properties=True,
            params={'bdry_thresh': bdry_thresh},
        )
        # Quality-control the raw labels using qc_params.
        storm_labels, _ = qcer.quality_control(dbz, storm_labels, object_props, qc_params)

        all_labels[i, :, :] = storm_labels
        z_set[i, :, :] = dbz

    return all_labels, z_set

In [6]:
# Case to analyse: date (YYYYMMDD) and forecast initialization time (HHMM).
date, init_time = '20210526', '0000'

# MRMS reflectivity for the hour before the init time: after sorting by name,
# keep the first 13 files (time steps 0 through 12 inclusive).
mrmsPath = f'/work/brian.matilla/WOFS_2021/MRMS/RAD_AZS_MSH_AGG/{date}/{init_time}/'
mrms_files = sorted(glob(join(mrmsPath, 'wofs_RAD_*')))[:12 + 1]

# WoFS summary files for the hour after the init time, limited the same way.
wofsPath = f'/work/mflora/SummaryFiles/{date}/{init_time}/'
files = sorted(glob(join(wofsPath, 'wofs_ENS_*')))[:12 + 1]

# Label storms in both datasets; each call returns (labels, reflectivity) arrays.
mrms_labels, mrms_dbz = data_loader(mrms_files, 'dz_cress')
wofs_labels, wofs_dbz = data_loader(files, 'comp_dz')

In [7]:
# Link the MRMS storm objects through time into tracks.
# Tracker settings: percent_overlap=0.0, mend_tracks=True (see monte_python.ObjectTracker).
tracker = monte_python.ObjectTracker(mend_tracks=True, percent_overlap=0.0)
mrms_tracks = tracker.track(mrms_labels)

In [8]:
# Link the WoFS storm objects through time, with the same tracker settings as MRMS.
tracker = monte_python.ObjectTracker(mend_tracks=True, percent_overlap=0.0)
wofs_tracks = tracker.track(wofs_labels)

# Shift every non-background WoFS label up by the largest MRMS label so the two
# label sets cannot collide once they are combined.
wofs_tracks[wofs_tracks > 0] += mrms_tracks.max()

In [9]:
# Match storms in the last MRMS frame (presumably the initialization time) to storms
# in the first WoFS frame: one-to-one matching, centroid/minimum distance limits of 10
# (grid points, presumably) and no time offset.
matcher = monte_python.ObjectMatcher(cent_dist_max=10, min_dist_max=10, time_max=0, one_to_one=True)
# Dedicated names: reusing `mrms_labels` / `wofs_labels` here would overwrite the
# per-frame label arrays loaded earlier, so re-running the tracking cells afterwards
# would track a list of matched IDs instead of label grids.
matched_mrms_labels, matched_wofs_labels, _ = matcher.match_objects(
    mrms_tracks[-1, :, :], wofs_tracks[0, :, :]
)

## Re-label each matched WoFS track (at every time step) with its MRMS counterpart's label.
# Labels are looked up in the unmodified `wofs_tracks`, so one relabel can never feed into another.
wofs_track_relabel = np.copy(wofs_tracks)
for wofs_label, mrms_label in zip(matched_wofs_labels, matched_mrms_labels):
    wofs_track_relabel[wofs_tracks == wofs_label] = mrms_label

In [10]:
# Stack the MRMS (before initialization) and WoFS (after) frames into one time series.
all_tracks = np.concatenate([mrms_tracks, wofs_track_relabel], axis=0)
all_dbz = np.concatenate([mrms_dbz, wofs_dbz], axis=0)

# `files` holds only the WoFS files on the first run, but already includes the MRMS
# files if this cell is re-run. Keep only its trailing WoFS entries so the combined
# list is identical either way.
n_wofs = len(wofs_dbz)
files = mrms_files + files[len(files) - n_wofs:]
# Build the per-frame source tags from the frame counts, not from the (now combined) `files`.
modes = ['MRMS'] * len(mrms_files) + ['WOFS'] * n_wofs

# Every frame needs exactly one file name and one source tag.
assert len(files) == len(modes) == len(all_dbz), (len(files), len(modes), len(all_dbz))

In [11]:
class TrackAnimator:
    """
    Animate tracked storm objects over a reflectivity field.

    Each frame shows the reflectivity (NWS colour table), the tracked objects as a
    translucent overlay, each object's label at its centroid, and the path of every
    track from the first frame up to the current one.

    Parameters
    ----------
    tracked_objects : np.ndarray, shape (n_times, ny, nx)
        Integer track labels per frame; values < 1 are treated as background.
    dataset : np.ndarray, shape (n_times, ny, nx)
        Reflectivity field per frame.
    modes : sequence of str
        Per-frame source tag drawn above the axes (e.g. 'MRMS', 'WOFS').
    fnames : sequence of str
        Per-frame file path; the valid time is parsed from its basename.
    """

    # File-name pattern, e.g. wofs_ENS_00_20210526_0000_0005.nc; the `date` (YYYYMMDD)
    # and `time` (HHMM) groups form the valid time shown on each frame.
    _FNAME_PATTERN = re.compile(
        r'wofs_(\S{3})_(\d{2})_(?P<date>\d{8})_(\d{4})_(?P<time>\d{4})\.nc'
    )

    def __init__(self, tracked_objects, dataset, modes, fnames):
        self.tracked_objects = tracked_objects
        self.dataset = dataset
        self.modes = modes
        self.fnames = fnames
        # Set by create_initial_frame(); animate() redraws on self.ax.
        self.fig = None
        self.ax = None

        # Per-frame region properties, used to place labels at object centroids.
        self.object_props = [regionprops(tracks) for tracks in tracked_objects]

        # Per-frame (label, centroid) tables, used to build the track paths.
        objects_df = [
            pd.DataFrame(regionprops_table(tracks, properties=['label', 'centroid']))
            for tracks in tracked_objects
        ]
        # x_cent[label] holds centroid ROW indices and y_cent[label] COLUMN indices,
        # one entry per frame (NaN where the label is absent from that frame).
        self.x_cent, self.y_cent = self.get_track_path(objects_df, tracked_objects)

        # Pixel-index coordinate grids matching the field (x = column, y = row).
        ny, nx = np.shape(dataset)[-2:]
        self.x, self.y = np.meshgrid(np.arange(nx), np.arange(ny))

    def create_initial_frame(self):
        """Create the figure, draw frame 0 plus a reflectivity colorbar, and return (fig, ax)."""
        fig, ax = plt.subplots(dpi=300, figsize=(8, 7))
        cont = self._draw_frame(ax, 0)
        fig.colorbar(cont, label='Reflectivity')
        self.fig, self.ax = fig, ax
        return fig, ax

    def animate(self, i):
        """FuncAnimation callback: clear the axes and redraw frame `i`; return the reflectivity contour set."""
        if self.ax is None:
            raise RuntimeError('create_initial_frame() must be called before animate().')
        self.ax.clear()
        return self._draw_frame(self.ax, i)

    def _draw_frame(self, ax, i):
        """Draw every layer of frame `i` on `ax` and return the reflectivity contour set."""
        tracks = self.tracked_objects[i]
        cont = ax.contourf(self.x, self.y, self.dataset[i],
                           cmap=nws_dz_cmap, levels=dz_levels_nws)
        ax.contourf(self.x, self.y, np.ma.masked_where(tracks < 1, tracks),
                    cmap='tab20b', alpha=0.6)
        self.label_centroid(ax, self.object_props[i])
        # Columns go on the horizontal axis and rows on the vertical axis.
        self.plot_track(ax, self.y_cent, self.x_cent, i)
        self.set_label(ax, self.modes[i], self.fnames[i])
        return cont

    def label_centroid(self, ax, object_props):
        """Write each object's label at its (truncated) centroid."""
        for region in object_props:
            row, col = (int(c) for c in region.centroid)
            xx, yy = self.x[row, col], self.y[row, col]
            # Smaller font for labels with two or more digits.
            fontsize = 6.5 if region.label >= 10 else 8
            ax.text(xx, yy, region.label, fontsize=fontsize,
                    ha='center', va='center', color='k')

    def get_centroid(self, df, label):
        """Return the truncated (row, col) centroid of `label` in `df`, or (nan, nan) if it is absent."""
        if 'label' not in df.columns:
            return np.nan, np.nan
        rows = df.loc[df['label'] == label]
        if rows.empty:
            return np.nan, np.nan
        return int(rows['centroid-0'].iloc[0]), int(rows['centroid-1'].iloc[0])

    def get_track_path(self, object_props, tracked_objects):
        """Collect each track's centroid in every frame.

        object_props is a list of per-frame DataFrames with 'label', 'centroid-0'
        and 'centroid-1' columns. tracked_objects supplies the set of track labels.

        Returns (centroid_x, centroid_y): dicts mapping label -> per-frame list of
        row (centroid_x) or column (centroid_y) indices, NaN where the label is absent.
        """
        unique_labels = np.unique(tracked_objects)
        # Drop background explicitly instead of assuming it is the first unique value.
        unique_labels = unique_labels[unique_labels > 0]
        centroid_x = {label: [] for label in unique_labels}
        centroid_y = {label: [] for label in unique_labels}

        for df in object_props:
            for label in unique_labels:
                row, col = self.get_centroid(df, label)
                centroid_x[label].append(row)
                centroid_y[label].append(col)

        return centroid_x, centroid_y

    def plot_track(self, ax, x_cent, y_cent, ind):
        """Draw each track's path from frame 0 through frame `ind` inclusive.

        x_cent / y_cent map label -> per-frame horizontal / vertical positions.
        NaN entries (label absent) leave gaps in the line.
        """
        for label, xs in x_cent.items():
            ys = y_cent[label]
            # Include frame `ind` so the path reaches the object's current centroid.
            ax.plot(xs[:ind + 1], ys[:ind + 1], lw=0.5)

    def get_label(self, file):
        """Return 'Valid: YYYY-MM-DD, HHMM UTC' parsed from a file path.

        Raises ValueError if the basename does not match the expected pattern.
        """
        fname = basename(file)
        match = self._FNAME_PATTERN.match(fname)
        if match is None:
            raise ValueError(f'Unrecognized file name (cannot parse valid time): {fname}')
        valid_dt = datetime.strptime(match['date'] + match['time'], '%Y%m%d%H%M')
        return valid_dt.strftime('Valid: %Y-%m-%d, %H%M UTC')

    def set_label(self, ax, mode, file):
        """Annotate the axes with the data-source tag and the valid time."""
        label = self.get_label(file)
        ax.annotate(mode, xy=(0.9, 1.05), xycoords='axes fraction')
        ax.annotate(label, xy=(1.0, 1.01), xycoords='axes fraction', ha='right')
            

In [12]:
# Build the animator and draw the first frame; `ax` is the axes that animate()
# clears and redraws on every frame.
animator = TrackAnimator(all_tracks, all_dbz, modes, files)
fig, ax = animator.create_initial_frame()
fig.tight_layout()

# blit=False: each frame is fully redrawn (animate() clears and re-plots the axes).
# interval is the delay between frames in milliseconds.
anim = FuncAnimation(fig, animator.animate, frames=len(all_dbz), interval=40, repeat=True, blit=False)
HTML(anim.to_jshtml())
# Optional: write the animation to a GIF instead of displaying it inline.
#anim.save(f'mrms_to_wofs_{date}_{init_time}.gif', writer='pillow', fps=5)