In [3]:
####################################
# Contours Visualization Pipeline
####################################

In [1]:
!pip install opencv-python-headless



In [2]:
from typing import List
import itertools
import os
import shutil
import uuid
from collections import Counter
from datetime import datetime, timedelta
from pathlib import Path
import subprocess
import tempfile
import time
import warnings
import numpy as np
import pandas as pd
import xarray as xr
import zarr
import fsspec

import cv2

from matplotlib import pyplot as plt
import matplotlib.dates as mdates
from matplotlib.patches import Rectangle
from matplotlib.patches import Polygon
from matplotlib.collections import PatchCollection
from matplotlib.patches import Rectangle

%matplotlib inline
%config InlineBackend.figure_format = 'retina'
plt.rcParams['figure.figsize'] = 12,8

import getpass
import azure.storage.blob
from azure.storage.blob import BlobClient, BlobServiceClient
from azure.core.exceptions import ResourceExistsError, HttpResponseError

In [11]:
# Base URL of the "cmip6" blob container; blob paths are appended to it.
URL_PREFIX = 'https://nasanex30analysis.blob.core.windows.net/cmip6'
# SAS token granting access to the whole "cmip6" container. It is read
# interactively so the credential is not written into the notebook source.
SAS_TOKEN = getpass.getpass()

 ·············································································································································


In [5]:
####################################
# CONSTANTS
####################################

# --- OpenCV contour-finding parameters ---
SMOOTH_RATIO = 0   # approxPolyDP epsilon, as a fraction of the closed contour's arc length
MIN_AREA = 10      # contours whose area (pixel units) is not above this are discarded
CONVEX = False     # when True, each contour is replaced by its convex hull

# --- Rolling-window aggregation parameter ---
ROLLING = 4        # number of consecutive days pooled into each rolling window


In [6]:
####################################
# Utils
####################################

class AzureSource:
    """One yearly extreme-heat NetCDF file stored in the Azure "cmip6" container.

    Attributes
    ----------
    filename : bare file name; also used as the local download path (relative
        to the current working directory).
    abspath : blob path below ``URL_PREFIX``.
    """

    def __init__(self, model: str, year: int):
        fn = f"Ext_max_t__Rgn_1__{year}__Abv_Avg_5_K_for_3_days__CMIP6_{model}_Avg_yrs_1950_79.nc"
        self.filename = fn
        self.abspath = f"extremes_max/{model}/Region_1/Avg_yrs_1950_79/Abv_Avg_5_K_for_3_days/{fn}"

    def download(self) -> None:
        """Download the blob to ``self.filename`` unless that file already exists.

        The blob is authenticated with the module-level ``SAS_TOKEN`` and located
        under ``URL_PREFIX``. Data is streamed into a temporary file in the
        destination directory and then atomically moved into place. An
        interrupted download therefore never leaves a partial file behind that
        the ``isfile`` cache check would later mistake for a complete one.
        """
        if os.path.isfile(self.filename):
            return

        sas_url = f"{URL_PREFIX}/{self.abspath}?{SAS_TOKEN}"
        blob_client = BlobClient.from_blob_url(sas_url)

        # Same directory as the target so os.replace is a same-filesystem rename.
        target_dir = os.path.dirname(os.path.abspath(self.filename))
        fd, tmp_path = tempfile.mkstemp(prefix=f"{self.filename}.", suffix=".tmp", dir=target_dir)
        try:
            # The handle is closed (and flushed) before the rename happens.
            with os.fdopen(fd, "wb") as fh:
                blob_client.download_blob().readinto(fh)
            os.replace(tmp_path, self.filename)
        finally:
            # After a successful os.replace the temp path no longer exists.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)


In [19]:
####################################
# Define Contour obj
####################################

"""
Bounding-contours algorithm that finds the spatial extent of heat events and
produces visualizations. Its input is the per-day heat-event yes/no dataset
(`extreme_yn`), which must be produced beforehand by the "Heatwave Analysis" algorithm.
"""

class Contour(object):
    """A single contour plus the lon/lat coordinates used to project it.

    The contour itself lives in array-index (pixel) space. ``lons`` is indexed
    with a point's x coordinate and ``lats`` with its y coordinate to map
    points onto the grid. All per-contour operations are managed here.
    """

    def __init__(self, cnt: np.array, lons, lats):
        # presumably an OpenCV contour of shape (N, 1, 2), as returned by
        # cv2.findContours / cv2.convexHull -- TODO confirm for all callers
        self.contour = cnt
        self.lons = lons
        self.lats = lats
        # Short random id used for display and membership checks; 6 hex chars,
        # so uniqueness is likely but not guaranteed.
        self.name = uuid.uuid4().hex[:6]

    def __repr__(self):
        return self.name

    @property
    def area(self):
        """Contour area in pixel units (cv2.contourArea)."""
        return cv2.contourArea(self.contour)

    @property
    def smoothened(self):
        """Contour simplified with approxPolyDP.

        The epsilon is ``SMOOTH_RATIO`` times the closed arc length.
        """
        cnt = self.contour
        epsilon = SMOOTH_RATIO * cv2.arcLength(cnt, True)
        return cv2.approxPolyDP(cnt, epsilon, True)

    @property
    def projected(self):
        """Smoothened contour mapped to (lon, lat) pairs, shape (N, 1, 2), float64."""
        # reshape (not squeeze) so a single-point contour still yields an (1, 2) array
        points = self.smoothened.reshape(-1, 2)
        proj = [(float(self.lons[x]), float(self.lats[y])) for x, y in points]
        return np.array(proj, dtype=np.float64).reshape((-1, 1, 2))

    @property
    def center(self):
        """(lon, lat) of the centroid, truncated to the containing grid index.

        Zero-area contours have no moment-based centroid, so the mean of their
        vertices is used instead.
        """
        M = cv2.moments(self.contour)
        if M["m00"] == 0:
            cX, cY = (int(v) for v in self.contour.reshape(-1, 2).mean(axis=0))
        else:
            cX = int(M["m10"] / M["m00"])
            cY = int(M["m01"] / M["m00"])
        return (float(self.lons[cX]), float(self.lats[cY]))

    def position_to(self, c2: "Contour") -> str:
        """Classify where this contour's vertices lie relative to ``c2``.

        Returns "inside" if every vertex is strictly inside ``c2`` and "outside"
        if every vertex is strictly outside. Otherwise (mixed, or any vertex on
        ``c2``'s boundary) it returns "intersect".

        Only this contour's vertices are tested, so the check is asymmetric. If
        ``c2`` lies entirely within ``self``, the result is "outside".
        """
        points = self.contour.reshape(-1, 2)
        # With measureDist=False, pointPolygonTest returns +1 inside, -1 outside
        # and 0 on the edge. The point is passed as a tuple of Python floats,
        # which every OpenCV binding version accepts.
        flags = np.array([
            int(cv2.pointPolygonTest(c2.contour, (float(x), float(y)), False))
            for x, y in points
        ])
        if np.all(flags == -1):
            return "outside"
        if np.all(flags == 1):
            return "inside"
        return "intersect"

    def __add__(self, obj2: "Contour"):
        """Fuse two contours ('bubbles') into the convex hull of both.

        This makes most sense when they intersect or one encloses the other.
        """
        fused = cv2.convexHull(np.vstack([self.contour, obj2.contour]))
        return self.__class__(fused, self.lons, self.lats)


class ContourCollection(list):
    """A list of contours whose ``in`` operator matches by (name, area).

    Membership is decided by equal ``name`` and equal ``area`` rather than by
    list's default ``==`` comparison.
    """

    def __init__(self, items: List["Contour"]):
        # The sequence passed at construction, kept for backward compatibility.
        # It is not updated when the collection is mutated afterwards.
        self.items = items
        super().__init__(items)

    def __contains__(self, x) -> bool:
        """Return True if any element has the same ``name`` and ``area`` as ``x``."""
        # Iterate the collection itself (not ``self.items``) so elements added
        # after construction via append/extend/+= are found as well.
        return any(x.name == c.name and x.area == c.area for c in self)
    
    
####################################
# Find the independent contours for a given day 
####################################

def find_daily_contours(ds: xr.Dataset) -> tuple:
    """Find the heat-event contours for every day in ``ds``.

    For each day, ``ds['extreme_yn']`` is processed in three steps:
      * binarised (any non-zero value becomes 1);
      * small gaps are closed with a 10x10 morphological closing;
      * the outer boundaries of the resulting blobs are extracted.

    Returns
    -------
    (all_contours, days)
        all_contours : list of ContourCollection, one per day.
        days : list of "YYYY-MM-DD" strings aligned with ``all_contours``.
    """
    # Pixel -> coordinate lookups are the same for every day; fetch them once.
    lons = ds.coords['lon']
    lats = ds.coords['lat']

    def find_contours(arr2d: np.array,
                      convex: bool = False,
                      min_area: int = 150,
                      kernel_size: int = 10) -> ContourCollection:
        """Encapsulate islands of non-zero pixels and return them as Contour objects.

        input:  one day-slice of the ``extreme_yn`` array.
        output: ContourCollection of the contours whose area exceeds ``min_area``.
        """
        # Casting NaN to uint8 is undefined; treat missing cells as "no event".
        H = np.nan_to_num(arr2d, nan=0).astype(np.uint8)
        # thresh=0, maxval=1: every pixel > 0 becomes 1. The type is the
        # 4th positional argument (the 5th would be the output array `dst`).
        _, thresh = cv2.threshold(H, 0, 1, cv2.THRESH_BINARY)

        kernel = np.ones((kernel_size, kernel_size), np.uint8)
        thresh = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kernel)
        contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        if convex:
            contours = [cv2.convexHull(c) for c in contours]

        # Single-point contours have no extent; drop them before wrapping.
        objs = [Contour(c, lons, lats) for c in contours if c.shape[0] > 1]
        return ContourCollection([c for c in objs if c.area > min_area])

    dr = pd.DatetimeIndex(ds['time'].dt.floor('D').values.astype('str'))

    all_contours = []
    days = []
    for d in dr:
        day = d.strftime("%Y-%m-%d")
        # assumes one time step per day: only the first match for `day` is used
        arr2d = ds['extreme_yn'].sel(time=day).values[0]
        all_contours.append(find_contours(arr2d, convex=CONVEX, min_area=MIN_AREA))
        days.append(day)

    return all_contours, days


####################################
# Rolling-window contours summation on time axis
####################################

def collapse(contours: "List[Contour]") -> "List[Contour]":
    """Repeatedly fuse overlapping contours until no pair overlaps.

    Two contours are fused (via ``+``, i.e. the convex hull of both) when
    ``position_to`` reports anything other than "outside" in EITHER
    direction. That covers intersection and full enclosure, whichever of the
    two is the enclosing one.

    A float NaN (what a row without a full rolling window holds) or None yields
    []. The input sequence is never mutated; a new list is returned.
    """
    if contours is None or (isinstance(contours, float) and pd.isna(contours)):
        return []

    merged = list(contours)  # copy: don't mutate the caller's collection
    changed = True
    while changed:
        changed = False
        for cnt1, cnt2 in itertools.combinations(merged, 2):
            # position_to only tests the first contour's vertices, so a contour
            # that encloses the other reports "outside"; check both orders.
            if cnt1.position_to(cnt2) != "outside" or cnt2.position_to(cnt1) != "outside":
                merged.remove(cnt1)
                merged.remove(cnt2)
                merged.append(cnt1 + cnt2)
                changed = True
                break  # the list changed: restart the pairwise scan
    return merged


def rolling_sum(all_contours: list, window: int = ROLLING) -> pd.DataFrame:
    """Pool each day's contours with those of the previous ``window - 1`` days and fuse overlaps.

    Equivalent in spirit to ``df['contours'].rolling(window).sum()``. pandas
    cannot do that directly because its ``.rolling`` refuses list-valued cells.

    Parameters
    ----------
    all_contours : one collection of contours per day, in chronological order
        (as returned by ``find_daily_contours``).
    window : number of consecutive days pooled per row; must be >= 1.

    Returns
    -------
    DataFrame with one row per day and columns
        contours    : the input collections;
        rolling_sum : list of fused contours for the window ending on that day,
                      or [] for the first ``window - 1`` rows (no full window yet).

    Raises
    ------
    ValueError if ``window`` < 1.
    """
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")
    if window == 1:
        warnings.warn("window=1 just returns contours as-is.")

    df = pd.DataFrame(dict(contours=all_contours))

    rolled = []
    for end in range(len(all_contours)):
        if end < window - 1:
            rolled.append([])  # not enough history for a complete window
            continue
        # Newest day first, then each earlier day of the window in turn.
        pooled = list(itertools.chain.from_iterable(all_contours[end - k] for k in range(window)))
        rolled.append(collapse(pooled))

    df['rolling_sum'] = pd.Series(rolled, index=df.index, dtype=object)
    return df

####################################
# Serialize metadata ready to json
####################################

def serialize(df: pd.DataFrame) -> pd.DataFrame:
    """Flatten daily and rolling-sum contours into one long, JSON-ready frame.

    Every contour in ``df['contours']`` becomes a row with type 'daily', and
    every contour in ``df['rolling_sum']`` a row with type 'rolling_sum'. Rows
    are ordered by (days, type). Days with no contours are dropped. The
    contour objects themselves are replaced by their name, center, area and
    projected coordinates.
    """
    def _long_form(column: str, label: str) -> pd.DataFrame:
        # One row per contour; an empty list explodes to a single NaN row.
        exploded = df.explode(column)[['days', column]].reset_index(drop=True)
        exploded['type'] = label
        return exploded.rename({column: 'contour'}, axis=1)

    stacked = pd.concat([_long_form('contours', 'daily'),
                         _long_form('rolling_sum', 'rolling_sum')], axis=0)
    stacked = (stacked.sort_values(by=['days', 'type'], ascending=True)
                      .dropna()
                      .reset_index(drop=True))

    shapes = list(stacked['contour'])
    stacked['name'] = [s.name for s in shapes]
    stacked['center'] = [s.center for s in shapes]
    stacked['area'] = [s.area for s in shapes]
    stacked['projected'] = [s.projected for s in shapes]

    return stacked.drop('contour', axis=1)


In [20]:
####################################
# Generate figures for each day with contours
####################################

def validate(df: pd.DataFrame) -> None:
    """Sanity-check the frame produced by ``rolling_sum`` before plotting.

    Raises AssertionError if a required column is missing or the index is not
    monotonically increasing. ``create_figures`` reaches previous days via
    ``idx - k``, so the index order matters.
    """
    assert "contours" in df.columns, "missing 'contours' column"
    assert "rolling_sum" in df.columns, "missing 'rolling_sum' column"
    # Index.is_monotonic was deprecated in pandas 1.5 and removed in 2.0.
    assert df.index.is_monotonic_increasing, "index must be monotonically increasing"
    
def create_figures(df: pd.DataFrame, window: int, save=False, folder: str = None, dataset=None):
    """Render one two-panel map per day with that day's contours overlaid.

    Left panel: ``extreme_yn``. Right panel: ``above_threshold`` with the
    coolwarm colormap clipped to [-4, 4]. Both panels show:
      * the daily contours of the ``window`` days pooled into the row's rolling
        sum, one colour per day;
      * the fused ``rolling_sum`` contours, in thick green, when available.

    Parameters
    ----------
    df : output of ``rolling_sum``. Its index labels are used as positions
        into the dataset's time axis.
    window : rolling-window length used to build ``df['rolling_sum']``.
    save : if True, each figure is saved as ``{folder}/{day}.jpg`` and closed.
        Otherwise the figures stay open for inline display.
    folder : output directory, created if needed; required when ``save`` is True.
    dataset : the xr.Dataset the contours were computed from. Defaults to the
        notebook-global ``ds`` (the previous behaviour).

    Raises
    ------
    AssertionError from ``validate(df)``; ValueError if ``save`` is True and
    ``folder`` is None.
    """
    validate(df)
    if save and folder is None:
        raise ValueError("folder is required when save=True")
    if dataset is None:
        dataset = ds  # legacy fallback to the notebook-global dataset

    # Day labels are the same for every iteration; compute them once.
    day_labels = pd.DatetimeIndex(dataset['time'].dt.floor('D').values.astype('str'))
    palette = 'r b c w m g y'.split()

    def add_patches(axes, column: str, row_label, color: str, linewidth: float, alpha=1):
        """Draw the contours stored in df[column] at ``row_label`` on every axis in ``axes``."""
        contours = df.loc[df.index == row_label, column].iloc[0]
        patches = [Polygon(c.projected.reshape(-1, 2), closed=True) for c in contours]
        centers = [c.center for c in contours]
        for ax in axes:
            # A PatchCollection can belong to only one Axes, so build one per panel.
            ax.add_collection(PatchCollection(
                patches, edgecolors=(color,), linewidths=(linewidth,),
                facecolor="none", alpha=alpha))
            for cx, cy in centers:
                ax.scatter(x=cx, y=cy, c=color, s=3)

    for i, idx in enumerate(df.index):
        day = day_labels[idx].strftime("%Y-%m-%d")
        tdiff = dataset['above_threshold'].sel(time=day)
        extreme = dataset['extreme_yn'].sel(time=day)

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(24, 8))
        extreme.squeeze().plot.imshow(ax=ax1, cmap='cividis')
        tdiff.squeeze().plot.imshow(ax=ax2, cmap='coolwarm', vmin=-4, vmax=4, alpha=0.8)

        # The rolling sum for row i pools rows i, i-1, ..., i-window+1, so draw
        # exactly those days (fewer at the start of the series).
        for lag in range(min(i + 1, window)):
            add_patches((ax1, ax2), 'contours', idx - lag, palette[(i - lag) % len(palette)], 1.5)
        # rolling_sum is [] for rows without a complete window, so this is a no-op there.
        add_patches((ax1, ax2), 'rolling_sum', idx, 'g', 4, alpha=0.8)

        fig.tight_layout()

        if save:
            os.makedirs(folder, exist_ok=True)
            fig.savefig(f"{folder}/{day}.jpg")
            plt.close(fig)

####################################
# Compile a video from images
####################################

def create_video(files: List[str], fn_out: str, fps: int = 10) -> None:
    """Encode the images in ``files`` (in the given order) into an H.264 MP4 at ``fn_out``.

    The process has two steps:
      * OpenCV writes the frames to an intermediate XVID .avi inside a private
        temporary directory (nothing is left in the CWD);
      * the ``ffmpeg`` executable, which must be on PATH, transcodes that file
        to MP4.
    An existing ``fn_out`` is overwritten.

    Raises
    ------
    ValueError if ``files`` is empty or an image cannot be read.
    RuntimeError if ffmpeg exits with a non-zero status.
    """
    if not files:
        raise ValueError("create_video: no input frames")
    first = cv2.imread(files[0])
    if first is None:
        raise ValueError(f"create_video: cannot read image {files[0]!r}")
    h, w = first.shape[:2]

    with tempfile.TemporaryDirectory() as tmp_dir:
        avi_path = os.path.join(tmp_dir, "out.avi")
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        writer = cv2.VideoWriter(avi_path, fourcc, fps, (w, h))
        try:
            for fn in files:
                frame = cv2.imread(fn)
                if frame is None:
                    raise ValueError(f"create_video: cannot read image {fn!r}")
                writer.write(frame)
        finally:
            writer.release()

        # List-form argv (no shell): file names are never interpreted by a shell.
        cmd = [
            "ffmpeg", "-y", "-i", avi_path,
            "-ac", "2", "-b:v", "2000k", "-c:a", "aac", "-c:v", "libx264",
            "-b:a", "160k", "-vprofile", "high", "-bf", "0",
            "-strict", "experimental", "-f", "mp4", fn_out,
        ]
        result = subprocess.run(cmd, stdin=subprocess.DEVNULL,
                                stdout=subprocess.DEVNULL, stderr=subprocess.PIPE)
        if result.returncode != 0:
            tail = result.stderr.decode(errors="replace")[-2000:]
            raise RuntimeError(f"ffmpeg failed with code {result.returncode} for {fn_out!r}:\n{tail}")

In [None]:
####################################
# Run the Pipeline
####################################

models = ["GISS_E2_1_G_ssp585"]  # also available: "GFDL_ESM4_ssp245", "GFDL_ESM4_ssp585", "GISS_E2_1_G_ssp245"
years = list(range(2020, 2030))
EXPORT_METADATA = False  # True -> also write, upload and clean up a per-year JSON of contour metadata

for model in models:

    for year in years:

        t1 = time.time()

        img_folder = f"{model}_{year}"
        path_meta = f"{model}_{year}.json"
        path_video = f"{model}_{year}.mp4"

        # import dataset
        ################################
        at = AzureSource(model, year)
        at.download()

        # Use the dataset as a context manager so its file handle is released
        # before the local .nc file is deleted below. The name `ds` stays a
        # notebook-level variable because helper functions may read it as a
        # global. open_mfdataset loads lazily, so all reading (contours and
        # figures) happens inside this block.
        with xr.open_mfdataset(at.filename) as ds:
            days_above = ds.attrs['Number of continuous days to be considered extreme']
            kelv_above = ds.attrs['threshold']
            upload_folder = f"NEWcontours_{days_above}days_{kelv_above}K/{model}"

            # find contours
            ################################
            dc, days = find_daily_contours(ds)
            df_daily = rolling_sum(dc)
            df_daily['days'] = days
            assert len(df_daily) == len(ds['extreme_yn']), "expected one row per time step"

            # create metadata
            ################################
            if EXPORT_METADATA:
                serialize(df_daily).to_json(path_meta)

            # create images
            ################################
            create_figures(df_daily, window=ROLLING, save=True, folder=img_folder)

        figs = sorted(str(p) for p in Path(img_folder).rglob("*.jpg"))

        # create video
        ################################
        create_video(figs, path_video)

        # export all to Azure
        ################################
        # NOTE(review): AzureTarget is not defined anywhere in the visible notebook;
        # it must come from another (possibly deleted) cell, so this loop will fail
        # on a fresh kernel until that definition is restored.
        if EXPORT_METADATA:
            AzureTarget(path_meta).upload(upload_folder)
        for fn in figs:
            AzureTarget(fn).upload(upload_folder)
        AzureTarget(path_video).upload(upload_folder)

        # delete local files
        ################################
        if EXPORT_METADATA:
            os.remove(path_meta)
        shutil.rmtree(img_folder)
        os.remove(path_video)
        os.remove(at.filename)

        print(f"{model}\t{year}\t{round((time.time()-t1)/60,2)} min")


GISS_E2_1_G_ssp585	2020	12.79 min
GISS_E2_1_G_ssp585	2021	12.58 min
GISS_E2_1_G_ssp585	2022	12.35 min
GISS_E2_1_G_ssp585	2023	13.28 min
GISS_E2_1_G_ssp585	2024	12.47 min
