Functions from dataloader.py

In [139]:
import numpy as np


def encode_lat_lon(lat, lon) :
    """
    Turn latitude/longitude into four sine/cosine features rescaled to [0, 1].

    This is the simple WRAP positional encoding of Mac Aodha et al. (2019): each
    coordinate is divided by its half-range (90 for latitude, 180 for longitude),
    multiplied by pi, and passed through cos and sin.

    Args:
    - lat (float or np.array): the latitude(s), expected in [-90, 90]
    - lon (float or np.array): the longitude(s), expected in [-180, 180]

    Returns:
    - (lat_cos, lat_sin, lon_cos, lon_sin) (tuple): the encoded values, each in [0, 1]
    """

    def _to_unit_range(values) :
        # cos/sin take values in [-1, 1]; shift and halve to land in [0, 1]
        return (values + 1) / 2

    # Angles in radians: latitude covers [-pi, pi] over [-90, 90],
    # longitude covers [-pi, pi] over [-180, 180]
    lat_angle = np.pi * lat / 90
    lon_angle = np.pi * lon / 180

    return (
        _to_unit_range(np.cos(lat_angle)),
        _to_unit_range(np.sin(lat_angle)),
        _to_unit_range(np.cos(lon_angle)),
        _to_unit_range(np.sin(lon_angle)),
    )


def encode_coords(central_lat, central_lon, patch_size, resolution = 10) :
    """
    Build per-pixel latitude/longitude grids for a patch around a central pixel,
    then encode them as sine/cosine features scaled to [0, 1].

    Pixel offsets (in pixels, relative to index `patch_size // 2` on each axis) are
    multiplied by `resolution` to get metres, then converted to degrees using a
    spherical Earth of radius 6371000 m. The longitude step is additionally divided
    by cos(central latitude).

    Args:
    - central_lat (float): the latitude of the central pixel, in degrees
    - central_lon (float): the longitude of the central pixel, in degrees
    - patch_size (tuple): the (rows, cols) size of the patch, in pixels
    - resolution (int): the pixel size, in meters

    Returns:
    - (lat_cos, lat_sin, lon_cos, lon_sin) (tuple): arrays of shape `patch_size`
      with the encoded latitude and longitude of every pixel
    """

    earth_radius_m = 6371000
    degrees_per_radian = 180 / np.pi

    # Row / column index of every pixel in the patch
    rows, cols = np.indices(patch_size)

    # Signed distance (meters) of each pixel from the central pixel
    lat_offset_m = (rows - patch_size[0] // 2) * resolution
    lon_offset_m = (cols - patch_size[1] // 2) * resolution

    # Latitude grows with the row index here.
    # NOTE(review): for north-up rasters the row index usually grows southward;
    # confirm this sign convention matches what the model was trained with.
    latitudes = central_lat + (lat_offset_m / earth_radius_m) * degrees_per_radian
    # A metre of easting spans more degrees of longitude away from the equator
    longitudes = central_lon + (lon_offset_m / earth_radius_m) * degrees_per_radian / np.cos(central_lat * np.pi / 180)

    return encode_lat_lon(latitudes, longitudes)


def normalize_data(data, norm_values, norm_strat, nodata_value = None) :
    """
    Normalize an array with one of three strategies:
    - mean_std: (data - mean) / std
    - pct: (data - p1) / (p99 - p1), then clipped to [0, 1]
    - min_max: (data - min) / (max - min)

    Args:
    - data (np.array): the data to normalize
    - norm_values (dict): the statistics for the chosen strategy, i.e. the keys
      'mean'/'std', 'p1'/'p99' or 'min'/'max'
    - norm_strat (str): the normalization strategy ('mean_std', 'pct' or 'min_max')
    - nodata_value (int/float): if not None, entries of `data` equal to this value
      are set to 0 in the output instead of being normalized

    Returns:
    - normalized_data (np.array): the normalized data

    Raises:
    - ValueError: if `norm_strat` is not one of the three strategies
    """

    # Every strategy is an affine rescaling: pick its offset and scale
    if norm_strat == 'mean_std' :
        offset, scale = norm_values['mean'], norm_values['std']
    elif norm_strat == 'pct' :
        p1, p99 = norm_values['p1'], norm_values['p99']
        offset, scale = p1, p99 - p1
    elif norm_strat == 'min_max' :
        min_val, max_val = norm_values['min'], norm_values['max']
        offset, scale = min_val, max_val - min_val
    else :
        raise ValueError(f'Normalization strategy `{norm_strat}` is not valid.')

    rescaled = (data - offset) / scale

    # Nodata pixels are forced to 0 rather than carrying a rescaled sentinel
    if nodata_value is not None :
        rescaled = np.where(data == nodata_value, 0, rescaled)

    # Only the percentile strategy bounds its output
    if norm_strat == 'pct' :
        rescaled = np.clip(rescaled, 0, 1)

    return rescaled


def normalize_bands(bands_data, norm_values, order, norm_strat, nodata_value = None) :
    """
    Normalize every band listed in `order`, one band at a time, skipping the
    'SCL' and 'transform' entries.

    Args:
    - bands_data (dict): maps band name -> band array
    - norm_values (dict): maps band name -> normalization values for that band
    - order (list): the band names to process, in the order they should be processed
    - norm_strat (str): the normalization strategy, forwarded to `normalize_data`
    - nodata_value (int/float): the nodata value, forwarded to `normalize_data`

    Returns:
    - normalized (dict): maps band name -> normalized band array, with keys
      inserted in the order given by `order` (minus the skipped entries)
    """
    # Entries that are not normalized and are left out of the result
    skipped = ('SCL', 'transform')

    normalized = {}
    for band in order :
        if band in skipped :
            continue
        print('normalizing ', band)
        band_stats = norm_values[band]
        normalized[band] = normalize_data(bands_data[band], band_stats, norm_strat, nodata_value)

    return normalized

Read the Sentinel-2 (S2) zip files and the ICESat cropped mosaic (with code from create_patches.py)

In [2]:
import os
import glob
import sys
import traceback
import geopandas as gpd
from zipfile import ZipFile
from shutil import rmtree
sys.path.insert(1, '/scratch2/biomass_estimation/code/patches')
from helper_patches import *
from create_patches import *
# Input locations: ICESat mosaic folder, list of tiles to run inference on,
# Sentinel-2 tiling grid, and the folder holding the zipped S2 L2A products.
path_icesat = '/scratch2/biomass_estimation/code/notebook/cropped_mosaic_no_nan'
tilenames = 'tile_names_inference.txt'
path_shp = os.path.join('/scratch2', 'biomass_estimation', 'code', 'notebook', 'S2_tiles_Siberia_polybox', 'S2_tiles_Siberia_all.geojson')
path_s2 = '/scratch3/Siberia'

# Read the Sentinel-2 grid shapefile
grid_df = gpd.read_file(path_shp, engine = 'pyogrio')

# List all S2 tiles and their geometries
tile_names, tile_geoms = list_s2_tiles(tilenames, grid_df, path_s2)

for s2_prod in tile_names:
    print(f'>> Extracting patches for product {s2_prod}.')
    # Get the filename that includes s2_prod in the path_s2 folder
    matching_paths = glob.glob(f"{path_s2}/*{s2_prod}*")
    if not matching_paths:
        # Without this guard an unmatched tile raised IndexError and aborted
        # the whole loop, unlike every other failure mode below which skips.
        print(f'>> Could not find any file for product {s2_prod} in {path_s2}.')
        continue
    total_s2_path = matching_paths[0]

    # Extract the folder and the filename separately
    s2_folder_path, s2_file_name = os.path.split(total_s2_path)

    print(f'>> Found {total_s2_path}.')

    # Unzip the S2 L2A product if it hasn't been done
    # (the last 4 characters, expected to be '.zip', are replaced by '.SAFE')
    total_unzipped_path = total_s2_path[:-4] + '.SAFE'
    if not os.path.exists(total_unzipped_path):
        try:
            with ZipFile(total_s2_path, 'r') as zip_ref:
                zip_ref.extractall(path_s2)
        except Exception as e:
            print(f'>> Could not unzip {s2_prod}.')
            print(e)
            continue

    # Reproject and upsample the S2 bands
    try:
        transform, upsampling_shape, processed_bands, crs_2, bounds = process_S2_tile(product=s2_file_name[:-4], path_s2 = s2_folder_path)
    except IndexError:
        # Retry with the nested folder layout; presumably some archives store the
        # .SAFE folder under this relative path, so it is not found at the top
        # level of path_s2 -- TODO confirm against process_S2_tile.
        s2_folder_path = os.path.join(path_s2, 'scratch2', 'gsialelli', 'S2_L2A', 'Siberia')
        total_unzipped_path = os.path.join(s2_folder_path, s2_file_name[:-4] + '.SAFE')
        try:
            transform, upsampling_shape, processed_bands, crs_2, bounds = process_S2_tile(product=s2_file_name[:-4], path_s2 = s2_folder_path)
        except Exception as e:
            print(f'>> Could not process product {s2_prod}.')
            print(traceback.format_exc())
            continue
    icesat_raw = load_BM_data(path_bm=path_icesat, tile_name=s2_prod)


    # Remove the unzipped S2 product
    rmtree(total_unzipped_path)
    # Only the first successfully processed tile is kept for the cells below
    break


>> Extracting patches for product 52VEL.
>> Found /scratch3/Siberia/S2A_MSIL2A_20200804T024551_N9999_R132_T52VEL_20240217T142025.zip.


normalize both icesat and s2

In [140]:
import pickle
from os.path import join
# Load the pre-computed normalization statistics.
# NOTE: pickle can execute arbitrary code on load; only use with trusted files.
norm_path = '/scratch2/biomass_estimation/code/ml/data'
with open(join(norm_path, 'normalization_values.pkl'), 'rb') as pkl_file:
    norm_values = pickle.load(pkl_file)

print("processed_bands keys ", processed_bands.keys())
# SCL has no entry in the normalization values, so it is presumably not a model input
print("norm_values[S2_bands] keys", norm_values['S2_bands'].keys())
print(norm_values.keys())

norm_strat = "mean_std"

# ICESat layers: -9999.0 marks nodata and becomes 0 after normalization
icesat_order = sorted(icesat_raw.keys())
icesat_norm = normalize_bands(
    icesat_raw, norm_values['BM'], icesat_order, norm_strat, nodata_value = -9999.0
)

# Sentinel-2 bands: 0 marks nodata and stays 0 after normalization
s2_order = sorted(processed_bands.keys())
s2_bands_dict = normalize_bands(
    processed_bands, norm_values['S2_bands'], s2_order, norm_strat, nodata_value = 0
)

# Position of every normalized band within the sorted list of all band names
s2_indices = [s2_order.index(band) for band in s2_bands_dict]



processed_bands keys  dict_keys(['B02', 'B03', 'B04', 'B08', 'B05', 'B06', 'B07', 'B8A', 'B11', 'B12', 'SCL', 'B01', 'B09'])
norm_values[S2_bands] keys dict_keys(['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B11', 'B12'])
dict_keys(['S2_bands', 'BM', 'Sentinel_metadata', 'GEDI'])
normalizing  bm
normalizing  std
normalizing  B01
normalizing  B02
normalizing  B03
normalizing  B04
normalizing  B05
normalizing  B06
normalizing  B07
normalizing  B08
normalizing  B09
normalizing  B11
normalizing  B12
normalizing  B8A


make a np array from a dict

In [142]:
# Stack the normalized bands into a single (rows, cols, bands) array, in the
# insertion order of s2_bands_dict.
# The former `s2_bands[:, :, s2_indices]` step was dropped: s2_indices holds
# positions within the full sorted band list (which also contains the skipped
# 'SCL' entry), while the stacked array only has the normalized bands. That
# indexing was therefore either an identity permutation (costing a full extra
# copy of the array) or an IndexError whenever a skipped band did not sort last.
s2_bands = np.stack([s2_bands_dict[band] for band in s2_bands_dict], axis=-1)
print(s2_bands.dtype)
print(s2_bands.shape)

float32
(10980, 10980, 12)


upsample icesat to s2 resolution (code from create_patches)

In [62]:
# Sanity checks on what the previous cells produced
print("icesat_raw keys " , icesat_raw.keys())
print("processed_bands keys ", processed_bands.keys())
print("s2 bands shape ",s2_bands_dict['B01'].shape)
print("icesat_raw bands shape ", icesat_norm['bm'].shape)
# Open question: the ICESat transform is printed but not used anywhere below
print("icesat_raw transform: \n", icesat_raw['transform'])

# Bring both ICESat layers onto the Sentinel-2 pixel grid.
# NOTE(review): icesat_norm was normalized with nodata_value = -9999.0, which
# replaces nodata pixels by 0. The -9999 nodata argument passed here therefore
# presumably matches nothing -- confirm whether upsampling should run on the raw
# layers before normalization instead.
target_shape = s2_bands_dict['B01'].shape
icesat = {
    layer: upsampling_with_nans(icesat_norm[layer], target_shape, -9999, 3)
    for layer in ('bm', 'std')
}

print("icesat bands shape after upsampling", icesat['bm'].shape)


icesat_raw keys  dict_keys(['bm', 'std', 'transform'])
processed_bands keys  dict_keys(['B02', 'B03', 'B04', 'B08', 'B05', 'B06', 'B07', 'B8A', 'B11', 'B12', 'SCL', 'B01', 'B09'])
s2 bands shape  (10980, 10980)
icesat_raw bands shape  (3612, 3612)
icesat_raw transform: 
 | 30.41, 0.00, 499955.10|
| 0.00,-30.41, 6600021.21|
| 0.00, 0.00, 1.00|
icesat bands shape after upsampling (10980, 10980)


define model

In [91]:
import torch.nn as nn
class SimpleFCN(nn.Module):
    """
    A simple fully convolutional neural network.

    The network is a stack of Conv2d -> BatchNorm2d -> ReLU stages (one per entry
    of `channel_dims`), followed by a 1x1 convolution producing `num_outputs`
    channels.
    """

    def __init__(self,
                 in_features=18,
                 channel_dims = (16, 32, 64, 128),
                 num_outputs=1,
                 kernel_size=3,
                 stride=1):
        """
        Args:
        - in_features (int): number of channels of the input
        - channel_dims (tuple): number of output channels of each conv stage
        - num_outputs (int): number of channels of the final 1x1 convolution
        - kernel_size (int): kernel size of the conv stages (padding is fixed to 1)
        - stride (int): stride of the conv stages
        """
        super().__init__()

        # A single ReLU module is shared by all stages
        self.relu = nn.ReLU(inplace = True)

        # Stage k reads the output of stage k-1; the first stage reads the input
        stage_inputs = (in_features, *channel_dims[:-1])

        layers = []
        for n_in, n_out in zip(stage_inputs, channel_dims):
            layers.extend([
                nn.Conv2d(in_channels=n_in,
                          out_channels=n_out,
                          kernel_size=kernel_size, stride=stride, padding=1),
                nn.BatchNorm2d(num_features=n_out),
                self.relu,
            ])
        print(layers)
        self.conv_layers = nn.Sequential(*layers)

        # 1x1 convolution mapping the last feature maps to the outputs
        self.conv_output = nn.Conv2d(in_channels=channel_dims[-1], out_channels=num_outputs,
                                     kernel_size=1, stride=1, padding=0, bias=True)

    def forward(self, x):
        """Run the conv stages, then the 1x1 output convolution, on `x`."""
        features = self.conv_layers(x)
        return self.conv_output(features)
    

In [143]:
from affine import Affine
from pyproj import Transformer
import torch
print("S2 transform")
print(transform)
print("S2 ", crs_2)
# Pixel -> map-coordinate affine transform. from_gdal takes the coefficients in
# GDAL order (c, a, b, f, d, e), hence the reshuffled indices.
fwd = Affine.from_gdal(transform[2], transform[0], transform[1], transform[5], transform[3], transform[4])

# Map coordinates (S2 CRS) -> geographic coordinates
coordinate_transformer = Transformer.from_crs(crs_2, 'epsg:4326')

model = SimpleFCN()
# map_location keeps the load working on a CPU-only machine even if the
# checkpoint was saved from a GPU; the inputs below are CPU tensors anyway.
model.load_state_dict(torch.load('/scratch2/biomass_estimation/code/ml/saved_model2.pth', map_location = 'cpu'))
model.eval()

# Patches are patch_size x patch_size windows centred on (row, col)
patch_size = 15
half_patch = patch_size // 2
n_rows, n_cols = s2_bands.shape[:2]

# Inference only: no autograd graph is needed
with torch.no_grad():
    # Stopping at `n - half_patch` guarantees every window lies fully inside the tile
    for row in range(half_patch, n_rows - half_patch, patch_size):
        for col in range(half_patch, n_cols - half_patch, patch_size):

            row_window = slice(row - half_patch, row + half_patch + 1)
            col_window = slice(col - half_patch, col + half_patch + 1)

            data = []
            s2_temp = s2_bands[row_window, col_window, :]
            data.append(s2_temp)
            print(len(s2_temp))

            # Affine transforms map (col, row) -> (x, y); the previous code passed
            # (row, col), which swaps the axes for every off-diagonal patch.
            x_map, y_map = fwd * (col, row)
            # With the default axis order, a projected (easting, northing) input
            # yields (latitude, longitude) for EPSG:4326.
            lat, lon = coordinate_transformer.transform(x_map, y_map)
            lat_cos, lat_sin, lon_cos, lon_sin = encode_coords(lat, lon, (patch_size, patch_size))
            data.extend([lat_cos[..., np.newaxis], lat_sin[..., np.newaxis], lon_cos[..., np.newaxis], lon_sin[..., np.newaxis]])

            icesat_temp_bm = icesat['bm'][row_window, col_window, np.newaxis]
            icesat_temp_std = icesat['std'][row_window, col_window, np.newaxis]
            data.extend([icesat_temp_bm, icesat_temp_std])
            # Iterate over the layers directly so the patch indices are not shadowed
            for layer in data:
                print(layer.shape)
            print(len(data))
            # Concatenate the data together; swapaxes(-1, 0) turns (H, W, C) into
            # (C, W, H), i.e. it also transposes the two spatial axes.
            # NOTE(review): kept as is on the assumption that training used the same layout.
            data = torch.from_numpy(np.concatenate(data, axis = -1).swapaxes(-1, 0)).to(torch.float)
            outputs = model(data.unsqueeze(0))
            print(outputs)
            # Debug run: only the first patch is evaluated
            break
        break


S2 transform
| 10.00, 0.00, 499980.00|
| 0.00,-10.00, 6600000.00|
| 0.00, 0.00, 1.00|
S2  EPSG:32652
[Conv2d(18, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), ReLU(inplace=True), Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), ReLU(inplace=True), Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), ReLU(inplace=True), Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), ReLU(inplace=True)]
15
(15, 15, 12)
(15, 15, 1)
(15, 15, 1)
(15, 15, 1)
(15, 15, 1)
(15, 15, 1)
(15, 15, 1)
7
tensor([[[[5.5139e+35, 8.6832e+35, 1.1441e+36, 1.1593e+36, 1.1866e+36,
           1.2026e+36, 1.2205e+36, 1.2377e+36, 1.2523e+36, 1.26