Wellington case study steps:

- Load all datasets
- Regrid the DEM
- Regrid all other datasets onto the same grid as the DEM
- Save each dataset and push it to the Datamesh

# File Setup

In [14]:
from joblib import Parallel, delayed
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import xarray as xr
import cartopy
import datetime
import pyproj
import rioxarray
from scipy.interpolate import griddata
import scipy as sp
import folium
import os
import folium
from ipywidgets import interact, Dropdown
from IPython.display import display, clear_output
from itertools import combinations
import matplotlib
import dask.array as da
from concurrent.futures import ThreadPoolExecutor
from scipy.sparse import csr_matrix
from oceanum.datamesh import Connector
import netCDF4 as nc
# Read the Datamesh API token from the environment (e.g. `export DATAMESH_TOKEN=...`) instead of
# hard-coding it in the notebook; the token previously committed here should be treated as leaked and rotated.
datamesh = Connector(token=os.environ["DATAMESH_TOKEN"])
import dask

In [2]:
# Transformer from EPSG:2193 (NZGD2000 / NZ Transverse Mercator, the DEM's projection) to EPSG:4326 lon/lat.
# always_xy=True keeps (x, y) / (lon, lat) argument order regardless of each CRS's declared axis order.
current_crs, target_crs = pyproj.CRS("EPSG:2193"), pyproj.CRS("EPSG:4326")
transformer = pyproj.Transformer.from_crs(
    current_crs,
    target_crs,
    always_xy=True,
    allow_ballpark=False,
)

# Load and preprocess Datasets

This section only gets the data onto Datamesh; it is not needed once the data is there.

### DEM

In [3]:
# Fetch the LINZ Wellington DEM clipped to a lon/lat bounding box, keeping only its band_data variable.
xr_dem = datamesh.query({
    "datasource": "linz-wellington_2013-2014-dem_1m-2193",
    "geofilter": {"type": "bbox", "geom": [174.536322, -41.442487, 175.124777, -41.049083]},
})["band_data"]



In [4]:
# Replace the DEM's no-data sentinel with NaN. The sentinel is taken to be the smallest value in the
# raster (NOTE(review): confirm against the source's nodata / _FillValue metadata).
# np.nanmin is used because np.min returns NaN as soon as any NaN is present, which would make the
# comparison below match nothing and silently leave the sentinel in place.
dem_values = xr_dem.values
fill_value = np.nanmin(dem_values)

dem_values[dem_values == fill_value] = np.nan
xr_dem.values = dem_values

In [5]:
# Build a regular res x res grid spanning the DEM's projected (EPSG:2193) extent
res = 4000  # low res to make the initial calcs fast
x_regular, y_regular = (
    np.linspace(xr_dem[axis].min(), xr_dem[axis].max(), res) for axis in ("x", "y")
)

# Interpolate the first slice along the leading (band) axis onto the regular grid
data_interpolated = xr_dem[0, :, :].interp(y=y_regular, x=x_regular)

# Re-wrap with plain (y, x) coordinates
xr_new_DEM = xr.DataArray(data_interpolated, coords={"y": y_regular, "x": x_regular}, dims=["y", "x"])

# Relabel the axes in lon/lat.
# NOTE(review): x_regular and y_regular are transformed pairwise (points along the grid diagonal)
# and the results reused as independent 1-D lon/lat axes. EPSG:2193 is not aligned with lon/lat,
# so these coordinates are only approximate; rioxarray's rio.reproject would regrid exactly.
new_x, new_y = transformer.transform(x_regular, y_regular)

# Attach the lon/lat coordinates, then chunk to save memory
xr_new_DEM = xr_new_DEM.assign_coords(x=new_x, y=new_y).chunk(100)

In [6]:
# Inspect the regridded DEM (shape, chunking, dtype).
xr_new_DEM

Unnamed: 0,Array,Chunk
Bytes,122.07 MiB,78.12 kiB
Shape,"(4000, 4000)","(100, 100)"
Dask graph,1600 chunks in 1 graph layer,1600 chunks in 1 graph layer
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 122.07 MiB 78.12 kiB Shape (4000, 4000) (100, 100) Dask graph 1600 chunks in 1 graph layer Data type float64 numpy.ndarray",4000  4000,

Unnamed: 0,Array,Chunk
Bytes,122.07 MiB,78.12 kiB
Shape,"(4000, 4000)","(100, 100)"
Dask graph,1600 chunks in 1 graph layer,1600 chunks in 1 graph layer
Data type,float64 numpy.ndarray,float64 numpy.ndarray


In [7]:
# datamesh.write_datasource(datasource_id='regridded_dem', 
#                           name="Regridded DEM",
#                           description="1m DEM of Wellington which has been regridded for casestudy", 
#                           data=xr_new_DEM, 
#                           coordinates={"x":"x","y":"y"},
#                           geom={'type':'Polygon','coordinates':[[[xr_new_DEM.x.min(),xr_new_DEM.y.min()],
#                                                                  [xr_new_DEM.x.max(),xr_new_DEM.y.min()],
#                                                                  [xr_new_DEM.x.max(),xr_new_DEM.y.max()],
#                                                                  [xr_new_DEM.x.min(),xr_new_DEM.y.max()]]]},
                                                                 
#                          )

In [8]:
# Target grid for every other dataset: the regridded DEM's lon/lat axes as float32.
# np.meshgrid defaults to 'xy' indexing, so X and Y are shaped (len(y), len(x)).
x, y = (xr_new_DEM[dim].astype(np.float32) for dim in ("x", "y"))
X, Y = np.meshgrid(x, y)

In [10]:
# Drop the intermediate DEM objects; xr_new_DEM and the (x, y, X, Y) grid are kept.
# del xr_new_DEM
del data_interpolated, dem_values, xr_dem

### Seismic Data

In [None]:
# Long term Jack should put results straight on datamesh, but this is still needed
def _grid_seismic_displacement(path):
    """Load one coseismic displacement table and grid its absVLM column onto (X, Y).

    The .npy rows are [ID, Lon, Lat, Uplift, Subsidence, absVLM]. The position columns are passed
    through ``transformer`` (EPSG:2193 -> EPSG:4326), so they are taken to be NZTM coordinates on
    disk. Returns the converted DataFrame and a float32 grid shaped like X (NaN outside the sites'
    convex hull, griddata's default fill value).
    """
    df = pd.DataFrame(np.load(path))
    df.columns = ['ID', 'Lon', 'Lat', 'Uplift', 'Subsidence', 'absVLM']
    # One vectorised transform call instead of a Python-level call per row.
    df['Lon'], df['Lat'] = transformer.transform(df['Lon'].to_numpy(), df['Lat'].to_numpy())
    grid = griddata((df['Lon'], df['Lat']), df.absVLM, (X, Y), method='linear')
    return df, grid.astype(np.float32)


df_2pc, array_2pc = _grid_seismic_displacement('for_DataMesh/for_DataMesh/2perc_disps_sites_c_MDEz_uniform.npy')
df_10pc, array_10pc = _grid_seismic_displacement('for_DataMesh/for_DataMesh/10perc_disps_sites_c_MDEz_uniform.npy')
# "No earthquake" scenario: zero displacement everywhere.
array_none = np.zeros_like(array_10pc)

# Put all into one (len(y), len(x), 3) array ordered [none, 2%, 10%].
# np.stack keeps float32; the previous np.empty(...) defaulted to float64, doubling memory and
# undoing the float32 casts above.
seismic_array = np.stack([array_none, array_2pc, array_10pc], axis=-1)


In [None]:
# Put all into an xarray.
# griddata on the (X, Y) meshgrid returns arrays shaped (len(y), len(x)), so the leading axes are
# (y, x). Labelling them ('x', 'y') transposed the data; this went unnoticed because len(x) == len(y).
xr_seismic = xr.DataArray(seismic_array, coords=[y, x, [0, 2, 10]], dims=['y', 'x', 'exc_prob'])

In [None]:
xr_seismic = xr_seismic.chunk(chunks=100)  # dask chunks of (up to) 100 along every dimension

In [None]:
# datamesh.write_datasource(datasource_id='coseismic_vlm_displacement', 
#                           name="Coseismic VLM displacement",
#                           description="Coseismic VLM displacement for none, 2% and 10% exceedance probabilities", 
#                           data=xr_seismic, 
#                           coordinates={"x":"Lon","y":"Lat"},
#                           geom={'type':'Polygon','coordinates':[[[xr_seismic.Lon.min(),xr_seismic.Lat.min()],
#                                                                  [xr_seismic.Lon.max(),xr_seismic.Lat.min()],
#                                                                  [xr_seismic.Lon.max(),xr_seismic.Lat.max()],
#                                                                  [xr_seismic.Lon.min(),xr_seismic.Lat.max()]]]},
                                                                 
#                          )

In [None]:
# Drop the intermediate numpy names; xr_seismic is kept for the rest of the notebook.
# del xr_seismic
del seismic_array, array_2pc, array_10pc, array_none

### Storm Surge

Dataset (COAST-RP): https://data.4tu.nl/articles/_/13392314/1

In [None]:
# COAST-RP storm-tide dataset; also keep its storm_tide_rp_0100 variable (presumably the 100-year return period).
xr_storm_surge = xr.open_dataset('COAST-RP.nc')
xr_100rp = xr_storm_surge.storm_tide_rp_0100

In [None]:
# Wellington reference point (lat, lon in degrees) used to pick the nearest storm-tide station.
well_y, well_x = -41.274678, 174.854143

In [None]:
# Flatten the storm-surge Dataset into a DataFrame so stations can be ranked by distance below.
df_storm_surge = xr_storm_surge.to_dataframe()

In [None]:
# Squared lon/lat distance (degrees^2) from the reference point; only used for ranking stations.
df_storm_surge.loc[:, 'euclidean'] = (
    (df_storm_surge.station_x_coordinate - well_x) ** 2
    + (df_storm_surge.station_y_coordinate - well_y) ** 2
)
df_storm_surge = df_storm_surge.sort_values('euclidean').reset_index(drop=True)

# Storm-tide return-period levels of the closest row, keyed by column name.
rp_columns = [col for col in df_storm_surge.columns if 'storm_tide_rp' in col]
storm_tide_rps_dict = df_storm_surge.loc[0, rp_columns].to_dict()

In [None]:
# Keep only the text after the last '_' as the key, e.g. 'storm_tide_rp_0100' -> '0100'.
storm_tide_rps_dict = {name.rsplit('_', 1)[-1]: level for name, level in storm_tide_rps_dict.items()}

In [None]:
# Show the nearest station's storm-tide level for each return-period key.
storm_tide_rps_dict

# Load VLMs

In [None]:
# Tab-separated VLM table, cast entirely to float32 to save memory.
df_vlms = pd.read_csv('Welly_VLM_2018-2023_100m.txt', delimiter='\t').astype(np.float32)

In [None]:
# Two header cells read as '  0.000000' (pandas suffixes the duplicate with '.1'); they hold lon/lat.
vlm_coord_names = {'  0.000000': 'lon', '  0.000000.1': 'lat'}
df_vlms = df_vlms.rename(columns=vlm_coord_names)

In [None]:
# Parse the remaining column labels (presumably time stamps) as floats while lat/lon sit in the
# index, then restore lat/lon as ordinary leading columns (order: lat, lon, <times...>).
df_vlms = df_vlms.set_index(['lat', 'lon']).rename(columns=float).reset_index()

In [None]:
# Every non-coordinate column label, in file order.
cols = [label for label in df_vlms.columns if label not in ('lon', 'lat')]

In [None]:
from scipy.interpolate import LinearNDInterpolator
from scipy.spatial import Delaunay

# Every year shares the same scattered VLM points, so triangulate them once.
# griddata(method='linear') builds an equivalent LinearNDInterpolator but re-triangulates on every call.
_vlm_tri = Delaunay(np.column_stack([df_vlms['lon'].astype(np.float32), df_vlms['lat'].astype(np.float32)]))


def interpolate_year(year):
    """Linearly interpolate one VLM column onto the (X, Y) grid.

    Returns a float32 array shaped like X; cells outside the convex hull of the VLM points
    (NaN from the interpolator) are set to 0.
    """
    grid = LinearNDInterpolator(_vlm_tri, df_vlms[year].astype(np.float32))(X, Y)
    return np.nan_to_num(grid, nan=0).astype(np.float32)


def process_chunk(chunk):
    """Interpolate several VLM columns at once; returns a list of float32 grids, one per column.

    A (n_points, n_columns) value array is interpolated in a single call, so point location and
    barycentric weights are computed once for the whole chunk; each column is still interpolated
    independently, giving the same values interpolate_year would.
    """
    values = df_vlms[list(chunk)].to_numpy(dtype=np.float32)
    grids = LinearNDInterpolator(_vlm_tri, values)(X, Y)
    grids = np.nan_to_num(grids, nan=0, copy=False).astype(np.float32)
    return [grids[..., i] for i in range(grids.shape[-1])]


# Determine chunk size and number of chunks
chunk_size = 10  # columns per batch; adjust based on available memory
chunks = [cols[i:i + chunk_size] for i in range(0, len(cols), chunk_size)]

# Process each chunk, then concatenate along the time axis
all_chunks_results = [np.stack(process_chunk(chunk), axis=-1) for chunk in chunks]
vlm_stack = np.concatenate(all_chunks_results, axis=-1)

# The grids are shaped (len(y), len(x)) (same shape as the X/Y meshgrid), so the leading axes are
# (y, x); the previous ('x', 'y') labels transposed the data (unnoticed because len(x) == len(y)).
xr_vlm_grid = xr.DataArray(vlm_stack, coords={"y": np.array(y), "x": np.array(x), "years": cols}, dims=['y', 'x', 'years'])

# NOTE: the former interpolate_na step was removed because it was a no-op. nan_to_num above has
# already replaced every NaN with 0, so (y-interp + x-interp) / 2 returned the input unchanged.
# To gap-fill by interpolation instead of zero-filling, drop the nan_to_num calls.

# rechunk
xr_vlm_grid = xr_vlm_grid.chunk(100)

In [None]:
# Inspect the gridded VLM DataArray (dims, chunking, dtype).
xr_vlm_grid

In [None]:
# datamesh.write_datasource(datasource_id='vlm_displacement', 
#                           name="VLM displacement",
#                           description="VLM displacement, iterpolated onto a 2D grid from Ian's first pass of Wellington", 
#                           data=xr_vlms, 
#                           coordinates={"x":"lon","y":"lat"},
#                           geom={'type':'Polygon','coordinates':[[[xr_vlms.lon.min(),xr_vlms.lat.min()],
#                                                                  [xr_vlms.lon.max(),xr_vlms.lat.min()],
#                                                                  [xr_vlms.lon.max(),xr_vlms.lat.max()],
#                                                                  [xr_vlms.lon.min(),xr_vlms.lat.max()]]]},
                                                                 
#                          )

In [None]:
# Drop the per-chunk grids and the raw VLM table; xr_vlm_grid is kept.
del all_chunks_results, df_vlms

# Load SLR Data

In [9]:
# Load one SLR .mat file (a dict of MATLAB variables) to get the sorted, de-duplicated projection years.
xr_slr = sp.io.loadmat(
    'Ian_relabelled_sites/total_volc_noVLMssp585_medium_confidence_values.mat'
)
years = np.unique(xr_slr['years'].ravel())


In [10]:
num_years = years.size  # number of distinct projection years (years is a 1-D array from np.unique)

In [11]:
# SLR scenario files: names containing 'total', 'medium' and '_noVLM'.
# os.listdir returns entries in arbitrary, filesystem-dependent order; sorting makes the order of the
# 'scenarios' axis built from this list reproducible across machines and runs.
file_names = sorted(
    name for name in os.listdir("Ian_relabelled_sites/")
    if 'total' in name and 'medium' in name and '_noVLM' in name
)

quantiles = [0.17, 0.5, 0.83]


In [12]:
from concurrent.futures import ProcessPoolExecutor

In [15]:
# Initialize a Dask array to store the results: (y, x, years, quantiles, scenarios).
# griddata on the (X, Y) meshgrid returns arrays shaped (len(y), len(x)), so the leading axes are
# (y, x). They were previously sized and labelled (x, y) / ('Lon', 'Lat'), which transposed the
# labels (unnoticed because len(x) == len(y)).
slr_array = da.zeros(
    (len(y), len(x), len(np.unique(sp.io.loadmat(f'Ian_relabelled_sites/{file_names[0]}')['years'])), len(quantiles), len(file_names)),
    dtype=np.float32, chunks=(len(y), len(x), 1, 1, 1)
)


def process_file(k, file):
    """Grid one SLR scenario file onto (X, Y) and write it into ``slr_array[..., k]``.

    Loads the .mat file and keeps the requested ``quantiles`` at sites strictly inside the grid's
    lon/lat extent. Each (year, quantile) field is linearly interpolated in parallel with joblib,
    and the stacked result is stored in slice ``k`` of the global dask array ``slr_array``.

    Returns ``k``. Raises ValueError if the file does not yield exactly
    ``num_years * len(quantiles)`` grids; the reshape below would fail on that mismatch anyway,
    and this check makes the cause explicit.
    """
    data = sp.io.loadmat(f'Ian_relabelled_sites/{file}')
    years = np.unique(data['years'])
    locations = data['locations'].squeeze()
    lats = dict(zip(locations, data['lat'].squeeze()))
    lons = dict(zip(locations, data['lon'].squeeze()))

    slr_da = xr.DataArray(
        data['sea_level_change'],
        coords={
            'locations': locations,
            'years': data['years'].squeeze(),
            'quantiles': data['quantiles'].squeeze()
        },
        dims=['locations', 'years', 'quantiles']
    ).sel(quantiles=quantiles).to_dataframe('slr')

    slr_da['Lat'] = slr_da.index.get_level_values('locations').map(lats)
    slr_da['Lon'] = slr_da.index.get_level_values('locations').map(lons)

    # Filter the DataFrame for the Wellington region
    slr_da = slr_da[
        (slr_da.Lon > np.min(x.values)) &
        (slr_da.Lat > np.min(y.values)) &
        (slr_da.Lon < np.max(x.values)) &
        (slr_da.Lat < np.max(y.values))
    ].reset_index().drop(columns='locations').astype(np.float32)

    # Convert DataFrame columns to NumPy arrays
    years_arr = slr_da['years'].to_numpy()
    quantiles_arr = slr_da['quantiles'].to_numpy()
    lon_arr = slr_da['Lon'].to_numpy()
    lat_arr = slr_da['Lat'].to_numpy()
    slr_arr = slr_da['slr'].to_numpy()

    def process_year_quantile(year, quantile):
        """Grid one (year, quantile) field onto (X, Y); None if the file has no rows for it."""
        mask = (years_arr == year) & (quantiles_arr == quantile)
        if np.any(mask):
            return griddata((lon_arr[mask], lat_arr[mask]), slr_arr[mask], (X, Y), method='linear').astype(np.float32)
        return None

    grids = Parallel(n_jobs=-1)(
        delayed(process_year_quantile)(year, quantile)
        for year in years
        for quantile in quantiles
    )

    # Filter out None results and stack into a Dask array
    valid_grids = [grid for grid in grids if grid is not None]
    if valid_grids:
        expected = num_years * len(quantiles)
        if len(valid_grids) != expected:
            raise ValueError(
                f"{file}: produced {len(valid_grids)} gridded fields but {expected} "
                f"({num_years} years x {len(quantiles)} quantiles) are needed"
            )
        # Grids are ordered year-major / quantile-minor, so a C-order reshape of the stacked last
        # axis yields (years, quantiles). The stacked array is 3-D, so only three chunk sizes apply.
        stacked_grids = da.stack(valid_grids, axis=-1).rechunk((len(y), len(x), num_years))
        slr_array[:, :, :, :, k] = stacked_grids.reshape((len(y), len(x), num_years, len(quantiles)))

    return k


# Process each file
for k, file in enumerate(file_names):
    process_file(k, file)

# Create an xarray DataArray
xr_slr = xr.DataArray(
    slr_array,
    coords={"Lat": np.array(y), "Lon": np.array(x), "years": years, "quantiles": quantiles, "scenarios": file_names},
    dims=['Lat', 'Lon', 'years', 'quantiles', 'scenarios']
)

# Fill missing values (cells outside the sites' convex hull) by propagating edge values.
# Same physical order as before: along axis 0 (now correctly labelled 'Lat') first, then axis 1.
xr_slr_filled = xr_slr.ffill('Lat').bfill('Lat').ffill('Lon').bfill('Lon')

# rechunking
xr_slr_filled = xr_slr_filled.chunk(100)

    >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ...     array.reshape(shape)

To avoid creating the large chunks, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    >>> array.reshape(shape, limit='128 MiB')
  if await self.run_code(code, result, async_=asy):
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ...     array.reshape(shape)

To avoid creating the large chunks, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    >>> array.reshape(shape, limit='128 MiB')
  if await self.run_code(code, result, async_=asy):
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ...     array.reshape(shape)

To avoid creating the large chunks, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    >>> array.reshape(shape, limit='128 MiB')
  if await self.run_code(code, result, async

In [23]:
# Inspect the filled SLR array (shape, chunking, dtype).
xr_slr_filled

Unnamed: 0,Array,Chunk
Bytes,12.52 GiB,8.01 MiB
Shape,"(4000, 4000, 14, 3, 5)","(100, 100, 14, 3, 5)"
Dask graph,1600 chunks in 22 graph layers,1600 chunks in 22 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 12.52 GiB 8.01 MiB Shape (4000, 4000, 14, 3, 5) (100, 100, 14, 3, 5) Dask graph 1600 chunks in 22 graph layers Data type float32 numpy.ndarray",4000  4000  5  3  14,

Unnamed: 0,Array,Chunk
Bytes,12.52 GiB,8.01 MiB
Shape,"(4000, 4000, 14, 3, 5)","(100, 100, 14, 3, 5)"
Dask graph,1600 chunks in 22 graph layers,1600 chunks in 22 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


# Adding SLR and VLM data together

In [None]:
# Find a mean of the rates and project forward
# NOTE(review): `duplicated_arr` is not defined anywhere in this notebook (hidden kernel state). It
# presumably holds the gridded VLM rates with dims 'lon', 'lat', 'years' — confirm before re-running
# from a fresh kernel.
xr_vlm_mean = duplicated_arr.mean(dim='years',skipna=True)
# Repeat the mean rate for every SLR projection year (expand_dims inserts 'years' as the first axis).
xr_vlm_mean = xr_vlm_mean.expand_dims({'years':xr_slr.years.values})

# Years elapsed since the first SLR projection year.
year_count = (xr_vlm_mean.years-np.min(xr_vlm_mean.years)) #should really start when the DEM was made for
# Cumulative VLM = mean rate x elapsed years; [:, None, None] broadcasts over the leading 'years' axis.
xr_vlm_forecast = xr_vlm_mean*np.array(year_count)[:,None,None]
xr_vlm_forecast = xr_vlm_forecast.transpose('lon','lat','years')
# Positional subtraction broadcast over the quantiles and scenarios axes.
# NOTE(review): relies on xr_slr_filled's first three axes lining up with the forecast's
# (lon, lat, years) order; there is no label alignment here — confirm the axis order.
xr_slr_vlm_adjusted = xr_slr_filled-np.array(xr_vlm_forecast)[:,:,:,None,None]

# Flooding the DEM

In [None]:
# Mask of DEM no-data cells: 1 where new_DEM is NaN, NaN everywhere else.
# The previous x/x + sentinel trick also flagged cells with elevation exactly 0 (0/0 -> NaN) as no-data.
# NOTE(review): `new_DEM` is not defined in this notebook (hidden kernel state) — confirm its source.
new_DEM_mask = xr.where(new_DEM.isnull(), 1.0, np.nan)

In [None]:
# Shape of each per-scenario result array: (storm-tide RPs, earthquake scenarios, years, SLR files, x, y).
# Computed directly from the lengths; the previous np.empty(...).astype(np.float32) allocated the full
# float64 array plus a float32 copy just to read .shape.
# NOTE(review): `earthquake_scens` is not defined in this notebook (hidden kernel state).
(len(storm_tide_rps_dict), len(earthquake_scens), len(years), len(file_names), len(x), len(y))

In [None]:
def _empty_scenario_array():
    """Return a NaN-filled float32 DataArray over (stormsurge, earthquake, years, file, lon, lat).

    The array is allocated directly as float32 and filled with NaN. np.empty(...).astype(np.float32)
    created a float64 temporary twice the size, and left uninitialised values in any cell that was
    never written.
    NOTE(review): the spatial axes are labelled (lon, lat) with coords (x, y); confirm this matches
    the layout of the arrays assigned in the flooding loop.
    """
    shape = (len(storm_tide_rps_dict), len(earthquake_scens), len(years), len(file_names), len(x), len(y))
    return xr.DataArray(
        np.full(shape, np.nan, dtype=np.float32),
        dims=['stormsurge', 'earthquake', 'years', 'file', 'lon', 'lat'],
        coords=[list(storm_tide_rps_dict.keys()), list(earthquake_scens.keys()), years, file_names, x, y],
    )


masked_xarray = _empty_scenario_array()
flooded_xarray = _empty_scenario_array()

In [None]:
masked_array_dict = {}
flooded_dict = {}

# Quantiles used for the low / mid / high SLR estimates (hoisted out of the loop).
q_low, q_mid, q_high = np.min(quantiles), 0.5, np.max(quantiles)


def _flood_indicator(slr_year, quantile, e_array, storm_value):
    """Return 1 where the adjusted DEM is below zero for one SLR quantile, NaN elsewhere.

    Computes new_DEM - SLR/1000 - e_array - storm_value (the /1000 presumably converts mm to m)
    and marks cells where the result is negative.
    """
    freeboard = new_DEM - np.array(slr_year.sel(quantiles=quantile)) / 1000 - e_array - float(storm_value)
    return freeboard.where(freeboard < 0, np.nan) * 0 + 1


def _normalised_flood_mask(flooded):
    """Scale flood counts to [0, 1] for display and reverse the row order for the image overlay.

    ``flooded`` holds how many of the three quantiles flood each cell (1-3, NaN where none).
    The largest count maps to 0. When every flooded cell has the same count the range is zero;
    those cells map to 0 instead of all being masked out by the 0/0 division.
    """
    counts = np.ma.masked_invalid(flooded)
    hi, lo = np.max(counts), np.min(counts)
    if counts.count() and hi == lo:
        scaled = counts * 0
    else:
        scaled = (hi - counts) / (hi - lo)
    return scaled[::-1, :]


# NOTE(review): `earthquake_scens` and `new_DEM` are not defined in this notebook (hidden kernel state).
for storm_rps, storm_value in storm_tide_rps_dict.items():
    for e_scen, e_array in earthquake_scens.items():
        for year in years:
            for scenario in file_names:
                xr_slr_vlm_adjusted_year = xr_slr_vlm_adjusted.sel(years=year, scenarios=scenario)

                # Number of SLR quantiles (low / mid / high) under which each cell floods: 1-3, NaN if none.
                flooded_low, flooded_mid, flooded_high = (
                    _flood_indicator(xr_slr_vlm_adjusted_year, q, e_array, storm_value).fillna(0)
                    for q in (q_low, q_mid, q_high)
                )
                flooded = flooded_low + flooded_mid + flooded_high
                flooded = flooded.where(flooded > 0, np.nan)

                masked_flooded = _normalised_flood_mask(flooded)

                key = (year, scenario, e_scen, storm_rps)
                masked_array_dict[key] = masked_flooded
                flooded_dict[key] = flooded

                selector = {'stormsurge': storm_rps, 'earthquake': e_scen, 'years': year, 'file': scenario}
                # NOTE(review): masked_flooded is row-flipped for display while flooded is not.
                masked_xarray.loc[selector] = masked_flooded
                flooded_xarray.loc[selector] = flooded

     

In [None]:
# NOTE(review): this cell repeats the body of the flooding loop in the previous cell for a single
# combination. It uses the year / scenario / e_scen / e_array / storm_rps / storm_value values left
# over from that loop's last iteration (hidden state) and rewrites the same dict keys and xarray
# slices the loop already wrote — candidate for deletion.
xr_slr_vlm_adjusted_year = xr_slr_vlm_adjusted.sel(years=year,scenarios=scenario)

# Adjusted elevation for the lowest SLR quantile; 1 where it is below zero, NaN elsewhere.
new_DEM_flooded_low = new_DEM-np.array(xr_slr_vlm_adjusted_year.sel(quantiles=np.min(quantiles)))/1000-e_array-float(storm_value)
new_DEM_flooded_low = new_DEM_flooded_low  # no-op self-assignment
flooded_low = new_DEM_flooded_low.where(new_DEM_flooded_low<0,np.nan)*0+1

# Same for the median (0.5) quantile.
new_DEM_flooded_mid = new_DEM-np.array(xr_slr_vlm_adjusted_year.sel(quantiles=0.5))/1000-e_array-float(storm_value)
new_DEM_flooded_mid = new_DEM_flooded_mid  # no-op self-assignment
flooded_mid = new_DEM_flooded_mid.where(new_DEM_flooded_mid<0,np.nan)*0+1

# Same for the highest SLR quantile.
new_DEM_flooded_high = new_DEM-np.array(xr_slr_vlm_adjusted_year.sel(quantiles=np.max(quantiles)))/1000-e_array-float(storm_value)
new_DEM_flooded_high = new_DEM_flooded_high  # no-op self-assignment
flooded_high = new_DEM_flooded_high.where(new_DEM_flooded_high<0,np.nan)*0+1

# Count of quantiles under which each cell floods (1-3); NaN where none do.
flooded = flooded_low.fillna(0)+flooded_mid.fillna(0)+flooded_high.fillna(0)
flooded = flooded.where(flooded>0,np.nan)

# Scale to [0, 1] (largest count -> 0) and flip rows for the image overlay.
# NOTE(review): if every flooded cell has the same count the denominator is 0 and np.ma masks every cell.
masked_flooded = np.ma.masked_invalid(flooded)
masked_flooded = (np.max(masked_flooded)-masked_flooded)/(np.max(masked_flooded)-np.min(masked_flooded))
masked_flooded = masked_flooded[::-1,:]

masked_array_dict.update({
    (year,scenario,e_scen,storm_rps):masked_flooded
})

flooded_dict.update({
    (year,scenario,e_scen,storm_rps):flooded
})

masked_xarray.loc[{'stormsurge': storm_rps, 'earthquake': e_scen, 'years': year, 'file': scenario}] = masked_flooded
flooded_xarray.loc[{'stormsurge': storm_rps, 'earthquake': e_scen, 'years': year, 'file': scenario}] = flooded

# Save to the mesh

In [None]:
# Combine the display mask ('mask') and the flood-count array ('data') into one Dataset.
data_vars = dict(mask=masked_xarray, data=flooded_xarray)
xr_flooded = xr.Dataset(data_vars)

In [None]:
# Inspect the combined flood Dataset.
xr_flooded

In [None]:
# datamesh.write_datasource(datasource_id="wellington_dynamic_shoreline_projections_mask",
#                           name="Wellington Casestudy Dynamic Shoreline Projections Mask",
#                           description="Wellington Casestudy Dynamic Shoreline Projections Mask",
#                           data=xr_flooded,
#                           # coordinates={'lon':'longitude','lat':'latitude'},
#                          tags=['demo', 'wellington','lower hutt'],
#                          )

# Visualise

In [None]:
# Colormap for the flood overlay. plt.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in
# 3.9; the matplotlib.colormaps registry is the supported replacement.
cmap = matplotlib.colormaps['brg']

wellington_coords = [-41.28664, 174.77557]
zoom = 11

def create_map(year, ssp, earthquake, storm_surge):
    """Build a folium satellite map with the flood overlay for one scenario combination.

    ``ssp`` is the SLR scenario *file name* used in the ``masked_array_dict`` / ``flooded_dict``
    keys. Reads the Mapbox token from the MAPBOX_TOKEN environment variable. Raises KeyError if
    that variable or the scenario key is missing.
    """
    # Token comes from the environment instead of being committed to the notebook.
    token = os.environ["MAPBOX_TOKEN"]
    tileurl = 'https://api.mapbox.com/v4/mapbox.satellite/{z}/{x}/{y}@2x.png?access_token=' + str(token)
    m = folium.Map(location=wellington_coords, zoom_start=zoom)
    folium.TileLayer(tiles=tileurl, name='Satellite', attr='Mapbox', overlay=False).add_to(m)

    key = (year, ssp, earthquake, storm_surge)
    masked_flooded = masked_array_dict[key]
    flooded = flooded_dict[key]

    # Bounds are [[min lat, min lon], [max lat, max lon]] taken from the flood grid's coordinates.
    folium.raster_layers.ImageOverlay(
        cmap(masked_flooded),
        [[flooded.y.values.min(), flooded.x.values.min()],
         [flooded.y.values.max(), flooded.x.values.max()]],
        opacity=.5,
    ).add_to(m)

    return m

# Define dropdown options ('ssp126' was previously listed twice)
dropdown_options1 = ['ssp126', 'ssp245', 'ssp370', 'ssp585']
dropdown_options2 = years
dropdown_options3 = list(earthquake_scens.keys())
dropdown_options4 = list(storm_tide_rps_dict.keys())

# Define callback function to update map when dropdown value changes
def update_map(year, ssp, earthquake, storm_surge):
    """Redraw the map for the chosen options, printing any lookup/rendering error instead of raising."""
    clear_output(wait=True)
    try:
        m = create_map(year, f'total_volc_noVLM{ssp}_medium_confidence_values.mat', earthquake, storm_surge)
        display(m)
    except Exception as e:
        print(f"Error occurred: {e}")

# Create interactive dropdowns
interact(update_map, year=dropdown_options2, ssp=dropdown_options1, earthquake=dropdown_options3, storm_surge=dropdown_options4)


In [None]:
# Shape of the last display mask produced by the flooding cells.
masked_flooded.shape

In [None]:
# array_10pc is deleted in the seismic clean-up cell, so `array_10pc.shape` raises NameError under
# Restart & Run All; read the same spatial grid size from xr_seismic instead.
xr_seismic.shape[:2]