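# Visualization of Inference

This notebook uses geemap to display the results of the inference step over entire river basins. Predictions above a display threshold are overlaid on the Sentinel-2 image and a satellite basemap. See `inference_oos.ipynb` for the code that generates the inference outputs.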

In [3]:
import geemap
import ee
import matplotlib.cm as cm
import matplotlib.colors
import numpy as np
import sys
import os
sys.path.insert(0, os.path.abspath('..'))

In [4]:
# Initialise Earth Engine from stored credentials. If that fails for any
# reason, run the interactive authentication flow once, then initialise again.
try:
    ee.Initialize()
except Exception:
    ee.Authenticate()  # interactive: prompts the user for authorization
    ee.Initialize()

In [6]:
from experiment_configs.configs import satmae_large_inf_config

# Inference run to visualise. The W&B run id is the last '/'-separated
# component of the configured wandb_id. The decision threshold is the
# config's mean_threshold.
inference_config = satmae_large_inf_config
wandb_id = inference_config.wandb_id.rsplit('/', 1)[-1]
threshold = inference_config.mean_threshold

# Cloud Storage location of the inference outputs for one river and date.
bucket = 'gs://sand_mining_inference'
river = 'sone'
date = '2023-05-01'

run_dir = f'{bucket}/{river}/{date}'
prediction_path = f'{run_dir}/{river}_prediction_{date}_{wandb_id}.tif'
s2_path = f'{run_dir}/S2/{river}_s2_{date}.tif'
s1_path = f'{run_dir}/S1/{river}_s1_{date}.tif'

In [9]:
# Load the per-pixel prediction raster written by the inference step.
prediction_raw = ee.Image.loadGeoTIFF(prediction_path)

# Display cut-off: a fixed fraction of the config's mean_threshold.
# Derive a new name rather than rescaling `threshold` in place. Otherwise each
# re-run of this cell would lower the threshold by another 20%.
# NOTE(review): the rationale for the 0.8 factor is not recorded here; confirm it.
THRESHOLD_SCALE = 0.8
display_threshold = threshold * THRESHOLD_SCALE

# Show only pixels that are non-zero and fall in (display_threshold, 1.0].
# Assumes predictions are scores in [0, 1], so anything above 1.0 is treated
# as invalid/fill. TODO: confirm against the inference export.
prediction_mask = (
    prediction_raw.neq(0)
    .And(prediction_raw.gt(display_threshold))
    .And(prediction_raw.lte(1.0))
)
prediction = prediction_raw.updateMask(prediction_mask)

# 256-step magma palette as hex strings, for Earth Engine's `palette` option.
# `matplotlib.cm.get_cmap` is deprecated (3.7) and removed (3.9), so use the
# colormap registry instead.
magma_cmap = matplotlib.colormaps['magma']
magma_hex = [
    matplotlib.colors.rgb2hex(rgba)
    for rgba in magma_cmap(np.linspace(0.0, 1.0, 256))
]

prediction_vis_params = {
    'min': 0,
    'max': 1,
    'palette': magma_hex,
    'opacity': 0.5,
}

# Sentinel-2 layer. Zero-valued pixels are masked out (presumably no-data
# outside the basin extent).
s2_image_params = {
    'min': 0,
    'max': 3000,
    # NOTE(review): bands from loadGeoTIFF appear to be named positionally
    # (B0, B1, ...). Confirm B3/B2/B1 is the intended R/G/B order for this export.
    'bands': ['B3', 'B2', 'B1'],
    'gamma': 1.4,
}
s2_raw = ee.Image.loadGeoTIFF(s2_path)
s2_image = s2_raw.updateMask(s2_raw.neq(0))

# Optional Sentinel-1 layer, off by default. Set to True to overlay it.
# TODO: stretch S1 to the 98th percentile of its histogram instead of a fixed range.
SHOW_S1 = False
if SHOW_S1:
    s1_image_params = {
        'min': -20,
        'max': 0,
        'bands': ['B1', 'B0', 'B0'],
        'gamma': 1.0,
    }
    s1_raw = ee.Image.loadGeoTIFF(s1_path)
    s1_image = s1_raw.updateMask(s1_raw.neq(0))

# Layers are drawn in insertion order: basemap, (S1), S2, then predictions on top.
Map = geemap.Map(height='800px')
Map.add_basemap('SATELLITE')
if SHOW_S1:
    Map.addLayer(s1_image, s1_image_params, 'S1')
Map.addLayer(s2_image, s2_image_params, 'S2 RGB')
Map.addLayer(prediction, prediction_vis_params, 'Predictions')

# Center on the prediction raster at zoom level 10 and display the map.
Map.centerObject(prediction, 10)
Map


Map(center=[25.078281891522835, 83.62381047257249], controls=(WidgetControl(options=['position', 'transparent_…