In [None]:
from eib import *

# Matching thresholds forwarded to every huracanpy.assess.match call below.
MIN_OVERLAP = 3  # passed as `min_overlap`; presumably the minimum number of shared time steps -- TODO confirm
MAX_DIST = 400  # passed as `max_dist`; presumably a distance in km -- TODO confirm against huracanpy docs

In [None]:
# Build one "extended IBTrACS" NetCDF file per basin: IBTrACS tracks augmented,
# along a `source` dimension, with the matching tracks of every input dataset.
# Basins whose output file already exists are skipped, so the cell can simply be
# re-run after an interruption.
for b in BASINS:
    if not os.path.isfile("extended-ibtracs/extended-ibtracs_"+b+".nc"):
        print('-------', b, '-------')
        # Load IBTrACS for the given basin
        print("... Loading ...")
        ib = huracanpy.load("ibtracs/ibtracs_"+b+".csv")

        # Load tracks data. This is deliberately redone for every basin: Step 1
        # below overwrites each entry of `tracks` with its basin-specific subset.
        # NOTE(review): unpickling executes arbitrary code; only use trusted input files.
        TRACK_flist = glob("input/TRACK-*.pkl")
        SyCLoPS_flist = glob("input/SyCLoPS-*.pkl")
        tracks = {}
        for f in TRACK_flist + SyCLoPS_flist:
            # Source name = file name without directory and extension
            name = f.split("/")[-1].split('.')[0]
            with open(f, "rb") as file:
                tracks[name] = pkl.load(file)

        # =============   MATCHING  ============= #
        print("... Matching ...")
        # Step 1: For each source, perform individual matching and treat duplicates
        for ds in tqdm(tracks):
            ## Step 1a: Match the source with ibtracs
            M = huracanpy.assess.match([ib, tracks[ds]], ["ib", ds], min_overlap = MIN_OVERLAP, max_dist = MAX_DIST,
                                      tracks1_is_ref = True,   # Treat the duplicates where one reanalysis track has several corresponding IBTrACS:
                                                               # Keep the couple with longest overlap
                                      )
            ## Step 1b: Treat the duplicates where several RA tracks correspond to 1 obs.: Merge them together
            ### Merge with M to subset matched tracks and assign matching distance
            tracks[ds] = tracks[ds].to_dataframe().merge(M.rename(columns = {"id_"+ds:"track_id"})[["track_id", "dist"]])
            ### Update merged ids: all source tracks matched to the same IBTrACS track
            ### share one new id made of their original ids joined by "+".
            ### (Cast the values element-wise: str() on the whole Series would return
            ### its printed repr, and join would then separate every character of it.)
            new_ids = M.groupby('id_ib')["id_"+ds].apply(lambda s: '+'.join(s.astype(str)))
            replace = M.join(new_ids, on = "id_ib", lsuffix = "_old", rsuffix = "_new")[["id_"+ds+"_old", "id_"+ds+"_new"]]
            tracks[ds]["track_id"] = tracks[ds]["track_id"].replace(replace.set_index("id_"+ds+"_old")["id_"+ds+"_new"].to_dict())
            ### When several points share one (track_id, time): keep the one belonging to the track closest to IBTrACS
            tracks[ds] = tracks[ds].sort_values("dist").groupby(["track_id", "time"]).first().reset_index()
            ### Retransform into xarray
            tracks[ds] = tracks[ds].to_xarray().rename({"index":"record"})

        # Step 2: Redo the matching now that the tracks have been merged
        M = huracanpy.assess.match([ib, ib, *tracks.values()], ["ib", "ib2", *tracks.keys()],
                                   min_overlap = MIN_OVERLAP, max_dist = MAX_DIST,
                                   tracks1_is_ref = True,
                                  ).drop(columns = "id_ib2").drop_duplicates()
        # Note : The double ib is a trick to make sure all tracks from IBTrACS are included, since they all match with themselves
        # Remove matches where IBTrACS is not involved: IBTrACS is considered ground truth for whether a track was tropical
        M  = M[~(M.id_ib.isna())]

        # Step 3: Merge tracks together
        tracks_matched = []
        for id_ib in tqdm(np.unique(ib.track_id)): # Loop over IBTrACS track ids
            # Start from the IBTrACS positions, indexed by time, as the first entry of the `source` dimension
            track = ib[["track_id", "lon", "lat", "time",]].hrcn.sel_id(id_ib).set_coords("time").swap_dims({"record":"time"}
                                                            ).assign_coords(source = "IBTrACS").expand_dims(dim="source")
            if len(track.time) > 1: # IBTrACS tracks with a single point are left out
                # Matches
                m = M[M.id_ib == id_ib]

                # Storm attributes
                name = ib.hrcn.sel_id(id_ib).name[0].values

                # Merge all data about the track in one source
                if len(m)>0:
                    for ds in tracks:
                        ds_id = getattr(m, "id_"+ds).values[0] # ID in the given source (first matching row of M)
                        vars2keep = [v for v in ["track_id", "lon", "lat", "pres", "wind10", "short_label",] if v in list(tracks[ds].variables.keys())]
                        tid = tracks[ds].hrcn.sel_id(ds_id).set_coords("time").swap_dims({"record":"time"}
                                                                                        ).reset_coords()[vars2keep
                                ].assign(source = ds).set_coords("source")
                        # Concatenating along `source` aligns the source's points with IBTrACS on time
                        track = xr.concat([track, tid], dim = "source")

                # Save into tracks: back to a `record` dimension, keeping the per-source ids
                # aside and using the IBTrACS id as the common track_id
                tracks_matched.append(track.swap_dims({"time":"record"}).rename({"track_id":"track_id_source"}).assign(track_id = id_ib, name = name))

        # Concatenate all tracks in a new source
        eib = xr.concat(tracks_matched, dim = "record")
        # Remove track_id_source
        eib = eib.drop_vars("track_id_source")
        # Filter 6-hourly
        eib = eib.where(eib.time.dt.hour % 6 == 0, drop = True)

        # =============   RE-ADD IBTRACS ATTRIBUTES  ============= #
        print("... IBTrACS Attributes ...")
        ## Transform into dataframes
        ib_df = ib.to_dataframe()
        teib_df = eib.sel(source = "IBTrACS").squeeze().reset_coords().to_dataframe()
        # Merge data: look up every extended-IBTrACS record in the original IBTrACS table
        M = teib_df[["time","track_id"]].reset_index().merge(ib_df, how = "left", on = ["time", "track_id"])
        assert len(M) == len(teib_df) # Check that no duplicates were created
        # Transform to xarray and select attributes
        M_xr = M.set_index("record").to_xarray()
        # Merge back into the extended ibtracs (lon/lat are dropped: eib already holds them per source)
        eib = eib.reset_coords().merge(M_xr.drop_vars(["lon", "lat",]))

        # =============   COMPUTE TRANSLATION SPEED AND AZIMUTH  ============= #
        print("... Computing Translation speed and azimuth ...")
        S = []
        A = []
        for ds in tqdm(eib.source):
            eib_ds = eib.sel(source = ds) # Subset source
            eib_ds = eib_ds.where(~np.isnan(eib_ds.lon), drop = True) # Remove points where given source does not have data
            speed = huracanpy.calc.translation_speed(eib_ds.lon, eib_ds.lat, eib_ds.time, eib_ds.track_id)
            S.append(speed)
            azimuth = huracanpy.calc.azimuth(eib_ds.lon, eib_ds.lat, eib_ds.track_id)
            A.append(azimuth)
        eib["translation_speed"] = xr.concat(S, dim = "source")
        eib["azimuth"] = xr.concat(A, dim = "source")

        # =============   ADD CPS PARAMETERS  ============= #
        print("... Adding CPS parameters ...")
        # NOTE(review): unpickling executes arbitrary code; only use trusted input files.
        with open("input/tracks_with_CPS_data.pkl", "rb") as handle:
            tracks_with_CPS = pkl.load(handle)

        L = []
        for ds in tqdm(tracks_with_CPS):
            # Transform the sources into dataframes
            t_CPS_df = tracks_with_CPS[ds].to_dataframe()
            teib_df = eib.sel(source = ds).squeeze().reset_coords().to_dataframe()
            for var in ["lon", "lat"]: # Round the coordinates (0 decimals) for the matching to work
                N = 0
                t_CPS_df[var] = np.round(t_CPS_df[var], N)
                teib_df[var] = np.round(teib_df[var], N)
            # Merge extended ib and CPS source on lon, lat and time
            t_CPS_df = t_CPS_df[~t_CPS_df[["lon", "lat", "time",]].duplicated()] # Remove duplicates
            M = teib_df[["lon", "lat", "time",]].reset_index().merge(t_CPS_df[["lon", "lat", "time", "vtl", "vtu", "b"]], how = "left", on = ["lon", "lat", "time"])
            assert len(M) == len(teib_df) # Check that no duplicates were created
            M_xr = M.set_index("record").to_xarray().assign(source = ds).set_coords("source")[["vtu", "vtl", "b"]] # convert back to source and format
            L.append(M_xr)
        CPS = xr.concat(L, dim = "source") # Merge the CPS variables over sources
        eib = eib.merge(CPS) # Merge back into the full extended ibtracs source

        # Compute 1-day rolling means of CPS parameters (centered window of 5 points per source and track)
        eib_df = eib.to_dataframe()
        CPS_roll = eib_df.reset_index().set_index("record").sort_values("time").groupby(["source", "track_id"]).rolling(5, center = True)[["vtu", "vtl", "b"]].mean()
        eib[["vtu_roll", "vtl_roll", "b_roll"]] = CPS_roll.reset_index().set_index(["source", "record"]).to_xarray()[["vtu", "vtl", "b"]]

        # =============   ADD WCSI FLAGS  ============= #
        print("... Adding WCSI flags ...")
        # NOTE(review): unpickling executes arbitrary code; only use trusted input files.
        with open("input/tracks_WCSI.pkl", "rb") as handle:
            tracks_WCSI = pkl.load(handle)

        L = []
        for ds in tqdm(tracks_WCSI):
            # Transform the sources into dataframes
            t_WCSI_df = tracks_WCSI[ds].to_dataframe()
            teib_df = eib.sel(source = ds).squeeze().reset_coords().to_dataframe()
            for var in ["lon", "lat"]: # Round the coordinates (0 decimals) for the matching to work
                N = 0
                t_WCSI_df[var] = np.round(t_WCSI_df[var], N)
                teib_df[var] = np.round(teib_df[var], N)
            # Merge extended ib and WCSI source on lon, lat and time
            t_WCSI_df = t_WCSI_df[~t_WCSI_df[["lon", "lat", "time",]].duplicated()] # Remove duplicates
            M = teib_df[["lon", "lat", "time",]].reset_index().merge(t_WCSI_df[["lon", "lat", "time", "is_tc"]], how = "left", on = ["lon", "lat", "time"])
            assert len(M) == len(teib_df) # Check that no duplicates were created
            M = M.rename(columns = {"is_tc":"WCSI"})
            M_xr = M.set_index("record").to_xarray().assign(source = ds).set_coords("source")[["WCSI"]] # convert back to source and format
            L.append(M_xr)
        WCSI = xr.concat(L, dim = "source") # Merge the WCSI variables over sources
        eib = eib.merge(WCSI) # Merge back into the full extended ibtracs source
        eib = eib.assign(WCSI = eib.WCSI.fillna(False).astype(bool)) # Convert to boolean (missing -> False)

        # =============   ADD IS_TC & IS_ETC FLAGS  ============= #
        print("... Adding is_tc & is_etc flags ...")

        # Is TC if WCSI (TRACK with CPS) or TC (SyCLoPS) or TS nature (IBTrACS)
        eib["is_tc"] = (eib.WCSI == True) | (eib.short_label == "TC")
        eib.is_tc.loc[dict(source="IBTrACS")] = (eib.nature == "TS")

        # Is ETC if Cold Core Asymmetric (CCA, TRACK with CPS) or EX (SyCLoPS) or ET nature (IBTrACS)
        eib["CCA"] = (np.abs(eib.b_roll) > 15) & (eib.vtl_roll < 0)
        eib["is_etc"] = (eib.CCA) | (eib.short_label == "EX")
        eib.is_etc.loc[dict(source="IBTrACS")] = (eib.nature == "ET")

        # =============   ADD ET FLAGS  ============= #
        print("... Adding ET flags ...")
        # ET encoding: +1 on the last is_tc point of a track, -1 on the first
        # is_etc point at or after it, 0 everywhere else.

        # Create ET variable
        eib["ET"]  = xr.DataArray(
            np.full(eib['lon'].shape, np.nan, dtype=object),
            dims=eib.dims,
            coords=eib.coords
        )

        # Convert to dataframe to be able to group by two variables
        eib_df = eib.to_dataframe().reset_index()

        # Identify last TS point
        TS = eib_df[eib_df.is_tc]
        last_TS = TS.sort_values("time").groupby(["source", "track_id"]).last()[["record", "time"]].reset_index()
        eib.ET.loc[dict(record = last_TS.record.to_xarray(), source = last_TS.source.to_xarray())] = +1

        # Subset points post ET onset
        tmp = eib_df.merge(last_TS, on = ["track_id", "source"], suffixes=['', '_onset'])
        tmp = tmp.assign(delta = tmp.time - tmp.time_onset)
        post_onset = tmp[tmp.delta >= np.timedelta64(0)]

        # Identify first ETC point
        EX = post_onset[post_onset.is_etc]
        first_EX = EX.sort_values("time").groupby(["source", "track_id"]).first()[["record", "time"]].reset_index()
        eib.ET.loc[dict(record = first_EX.record.to_xarray(), source = first_EX.source.to_xarray())] = -1

        # Convert ET to int8
        eib = eib.assign(ET = eib.ET.fillna(0).astype(np.int8))

        # =============   SAVE  ============= #
        print("... Saving ...")
        ## Specify time encoding
        eib.time.encoding['units'] = 'seconds since 1900-01-01'
        eib.time.encoding['calendar'] = 'standard'
        ## Changes types for lighter files
        ### Convert to float32 (only the variables actually present in eib)
        varlist = list(eib.variables.keys())
        floatvars = list(set(varlist) & (set(["b_roll", "vtu_roll", "vtl_roll", "azimuth", "translation_speed", "lon", "lat"])))
        for var in floatvars:
            eib[var] = eib[var].astype(np.float32)
        ## Actual saving
        eib.to_netcdf("extended-ibtracs/extended-ibtracs_"+b+".nc")

        print("Done!")