# Mask remote-sensing reflectance (Rrs) data using sea-ice fraction

In [None]:
import numpy as np
import netCDF4 as nc
import xarray as xr
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import cartopy.crs as ccrs
import scipy


In [None]:
def load_data(filepath):
    """Open every file matching ``filepath`` as a single combined xarray Dataset.

    ``mask_and_scale=True`` asks xarray to apply each variable's fill-value
    masking and scale/offset decoding while reading.
    """
    return xr.open_mfdataset(filepath, mask_and_scale=True)

## Load Rrs and sea-ice fraction data

In [None]:
# AVHRR reflectance data first

ROOT = "/data/datasets/Projects/TuringCoccolithophoreBlooms"

# Monthly-mean AVHRR remote-sensing reflectance files
filepath = f"{ROOT}/AVHRR_reflectance/monthly_mean/*.nc"
varname = "filtered_remote_sensing_reflectance"
# Alternative input (regridded Rrs_560):
#filepath = ROOT+"/no_backup/TuringCoccolithophoreBlooms/regridded_data/Rrs_560/*.nc"
#varname = "Rrs_560"

ds = load_data(filepath)
print(ds)
rrs = ds[varname]

# Record NaN as the nodata value of the Rrs variable.
# NOTE(review): the `.rio` accessor comes from rioxarray, which this file does
# not import explicitly — confirm it is registered in this environment.
rrs.rio.write_nodata(np.nan, inplace=True)
print(f"nodata: {rrs.rio.nodata}")

In [None]:
# Comparative meteorological data: sea-ice fraction

# NOTE(review): ROOT already ends in "TuringCoccolithophoreBlooms", so this path
# repeats that directory name — confirm it is the intended location.
filepath = f"{ROOT}/no_backup/TuringCoccolithophoreBlooms/regridded_data/analysed_sst/*.nc"
varname = "sea_ice_fraction"

ds = load_data(filepath)
print(ds)
variable = ds[varname]

# Record NaN as the nodata value of the sea-ice variable (rioxarray accessor).
variable.rio.write_nodata(np.nan, inplace=True)
print(f"nodata: {variable.rio.nodata}")

# Mask the Rrs data by sea-ice fraction

In [None]:
# Resample both series onto the same monthly time labels so that they align
# element-wise in the masking step below.
# A mean is used rather than a sum. With one sample per month the two are
# identical. With several samples, a sum of sea-ice fractions could exceed 1,
# which would make the `< 0.15` threshold meaningless.
# skipna=False keeps a month NaN if any of its samples is NaN.
# NOTE(review): pandas parses '1m' as month-end ('M'); newer pandas deprecates
# that alias in favour of 'ME' — confirm against the installed version.
da = rrs.resample(time='1m').mean(skipna=False)
ds = variable.resample(time='1m').mean(skipna=False)

time = np.array(da.time)

In [None]:
# Keep Rrs only where the sea-ice fraction is below 0.15; every other cell becomes NaN.
# Cells whose ice fraction is NaN are masked as well, because NaN < 0.15 is False.
masked = da.where(cond=ds < 0.15)

In [None]:
# Validity flag for the unmasked Rrs: 1 where a value is present, 0 where it is NaN.
# notnull() avoids a -999 sentinel round-trip, which would also flag any
# genuine -999 value as missing. The dtype is kept equal to that of `da`.
relf = da.notnull().astype(da.dtype)

In [None]:
# Validity flag for the sea-ice-masked Rrs: 1 where a value survives the mask,
# 0 where it is NaN. notnull() avoids a -999 sentinel round-trip, which would
# also flag any genuine -999 value as missing. The dtype is kept equal to that of `masked`.
filt = masked.notnull().astype(masked.dtype)

In [None]:
# Pixels removed by the sea-ice mask: relf (1 = valid before masking) minus
# filt (1 = valid after masking). `masked` is NaN wherever `da` is NaN, so the
# result is 1 where a valid Rrs pixel was masked out and 0 everywhere else.
diff = relf - filt

In [None]:
# Keep only the removed pixels (value 1); all zeros become NaN (the where() default fill).
diff = diff.where(diff != 0)

In [None]:
# Quick look at the pixels removed by the sea-ice mask at time index 300
diff[300].plot()

# Calculate the latitude/longitude grid-cell edges

In [None]:
# Build grid-cell edges from the coordinates of `diff` (the pixels removed by
# the sea-ice filter, on the same grid as the Rrs data), assuming a regular
# lat/lon grid.
# NOTE(review): the fixed offsets assume the coordinates are cell centres with
# 0.1° longitude and 0.05° latitude spacing. Appending the last centre (rather
# than last centre + half a cell) makes the final row/column half-width —
# confirm against the actual grid spacing.
lon_centres = np.asarray(diff['longitude'])
lat_centres = np.asarray(diff['latitude'])

# Shift centres to lower edges, then append the end of the grid
lon = np.append(lon_centres - 0.05, lon_centres[-1])
lat = np.append(lat_centres - 0.025, lat_centres[-1])

# 2-D corner grids, shape (len(lat), len(lon)) == (n_lat + 1, n_lon + 1)
x, y = np.meshgrid(lon, lat)

# Calculate area of grid cells

In [None]:
# Radius of Earth (m)
R = 6371*1000

# Opposite corners of every cell: x/y[:-1, :-1] and x/y[1:, 1:]
x_shift1 = x[1:, 1:]
y_shift1 = y[1:, 1:]

# Area of a lat/lon cell on a sphere (m^2):
#   R^2 * Δλ[rad] * |sin φ2 - sin φ1|
# abs() keeps areas positive whichever direction the coordinates run; a
# descending latitude (or longitude) axis would otherwise give negative areas.
# Result has one entry per cell: shape (n_lat, n_lon).
A = np.abs(
    np.pi / 180 * R**2
    * (np.sin(y[:-1, :-1] * np.pi / 180) - np.sin(y_shift1 * np.pi / 180))
    * (x[:-1, :-1] - x_shift1)
)

In [None]:
# Decade windows for the time axis.
# The bounds are month-resolution strings, so each slice covers its whole end
# month. A day-resolution bound such as '1989-12-01' ends at midnight on Dec 1
# and would drop month-end labels like 1989-12-31 produced by the monthly
# resample.
slice_1 = slice('1981-01', '1989-12')
slice_2 = slice('1990-01', '1999-12')
slice_3 = slice('2000-01', '2009-12')
slice_4 = slice('2010-01', '2016-12')
slices = [slice_1, slice_2, slice_3, slice_4]

In [None]:
# One slot per decade for the min / mean / max monthly masked area
# (three independent arrays).
dec_area_min, dec_area_mean, dec_area_max = [np.zeros(4) for _ in range(3)]

In [None]:
# Area masked by sea ice, per decade.
# `diff` is 1 where a valid Rrs pixel was removed and NaN elsewhere, so summing
# A * diff over the spatial dims (NaNs skipped) gives the removed area (m^2)
# for each month. Each decade then records the min / mean / max of those
# monthly totals.
for c, s in enumerate(slices):
    dec_diff = diff.sel(time=s)
    if dec_diff.sizes['time'] == 0:
        # np.nanmin / np.nanmax raise on empty input; record NaN for an empty decade
        dec_area_min[c] = dec_area_mean[c] = dec_area_max[c] = np.nan
        continue
    # A single (lazy) reduction per decade instead of one compute per month.
    # Like the per-month indexing dec_diff[i, :, :], this assumes dims are
    # (time, lat, lon) so that A of shape (n_lat, n_lon) broadcasts — TODO confirm.
    monthly = (dec_diff * A).sum(dim=dec_diff.dims[1:], skipna=True)
    m_areas = np.asarray(monthly)

    dec_area_min[c] = np.nanmin(m_areas)
    dec_area_mean[c] = np.nanmean(m_areas)
    dec_area_max[c] = np.nanmax(m_areas)

In [None]:
1e-6 * dec_area_mean  # decadal mean masked area, m^2 -> km^2

In [None]:
max(1e-6 * dec_area_mean)  # largest decadal mean masked area, km^2

# Plot an example month (before and after masking)

In [None]:
from matplotlib import colors

# Side-by-side maps of one month: Rrs before (left) and after (right) the sea-ice mask
fig, ax = plt.subplots(nrows=1, ncols=2,
                       subplot_kw={'projection': ccrs.PlateCarree()}, figsize=(11, 8.5))

# The '1m' resample above is assumed to produce month-end labels
# (e.g. 2008-12-31) — confirm. Nearest-matching np.datetime64('2008-12')
# (i.e. 2008-12-01) would therefore pick 2008-11-30, i.e. November, so the
# target is the last day of the month instead.
index = np.datetime64('2008-12-31')
print(index)

panels = [
    (ax[0], da, 'Rrs (sr$^{-1}$)'),
    (ax[1], masked, 'Ice fraction masked Rrs (sr$^{-1}$)'),
]
for axis, field, label in panels:
    axis.coastlines()
    pcm = axis.pcolormesh(rrs.longitude, rrs.latitude,
                          field.sel(time=index, method='nearest'))
    cmap = plt.colorbar(pcm, shrink=0.3)
    axis.set_yticks([-50, 0, 50], crs=ccrs.PlateCarree())
    axis.set_xticks([-150, -100, -50, 0, 50, 100, 150], crs=ccrs.PlateCarree())
    cmap.set_label(label)

plt.show()

In [None]:
# Write the sea-ice-masked Rrs to a netCDF file
fileout = f"{ROOT}/data/rrs_masked_by_sea_ice.nc"
masked.to_netcdf(fileout)