In [None]:
import os
from osgeo import gdal, ogr, osr
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd
import osgeo
import copy

# Make the GDAL Python bindings raise Python exceptions on errors
gdal.UseExceptions()

In [None]:
# Tell PROJ where to find its data directory.
# NOTE(review): this is a hardcoded absolute path to one specific conda
# environment and it overwrites any PROJ_DATA already set in the process —
# confirm it matches the environment the kernel is running in.
os.environ['PROJ_DATA']='/workspace/.conda/envs/env_labels/share/proj' 
# os.environ['GDAL_DATA']='/workspace/.conda/envs/env_labels/share/gdal'
# os.environ['GTIFF_SRS_SOURCE'] = 'EPSG'

In [None]:
import pystac
from pystac import Link, Asset
from pystac.extensions.label import LabelExtension
from pystac.extensions.label import LabelType
from pystac.extensions.label import LabelClasses
from pystac.extensions.label import LabelStatistics
from pystac.extensions.version import ItemVersionExtension

In [None]:
# Start without a dataframe: the "Make Pandas dataframe" cell creates it for the
# first band and merges every further band into it.
# Re-running this cell discards the bands accumulated so far.
df = None

In [None]:
def pixel_to_coords(source, x, y):
    """Return the geographic (long, lat) coordinates of a pixel position.

    The pixel position is first mapped to the dataset's own coordinate system
    with the affine geotransform and then reprojected to the geographic
    coordinate system underlying the dataset's projection (WGS 84 / EPSG:4326
    for rasters defined on the WGS 84 datum).

    Parameters
    ----------
    source : GDAL dataset
        Any object exposing ``GetGeoTransform()`` and ``GetProjection()``.
    x, y : number
        Pixel column and row. Integer values address the corner of the pixel
        that the geotransform origin refers to, not the pixel centre.

    Returns
    -------
    tuple
        ``(long, lat)`` in traditional GIS axis order.
    """
    origin_x, pixel_w, row_rot, origin_y, col_rot, pixel_h = source.GetGeoTransform()

    # Full affine transform. The rotation terms are 0 for north-up rasters, in
    # which case this reduces to ``x * pixel_w + origin_x`` / ``y * pixel_h + origin_y``.
    px = x * pixel_w + y * row_rot + origin_x
    py = x * col_rot + y * pixel_h + origin_y

    srs = osr.SpatialReference()

    # GDAL 3 changes axis order: https://github.com/OSGeo/gdal/issues/1546
    # Parse the whole major version, not just its first character, so that a
    # two-digit major version is still recognised.
    gdal_major = int(osgeo.__version__.split('.')[0])
    if gdal_major >= 3:
        srs.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)

    srs.ImportFromWkt(source.GetProjection())

    # Geographic counterpart of the source projection (same datum)
    srs_geographic = srs.CloneGeogCS()
    to_geographic = osr.CoordinateTransformation(srs, srs_geographic)

    long, lat, _ = to_geographic.TransformPoint(px, py)

    return long, lat

## Read STAC Item

In [None]:
# URL of a single item of the `sentinel-s2-l2a-cogs` collection on the Earth Search STAC API
aws_url = "https://earth-search.aws.element84.com/v0/collections/sentinel-s2-l2a-cogs/items/S2B_10TFK_20210713_0_L2A"
# Read the STAC object behind the URL and show its properties
item = pystac.read_file(aws_url)
display(item.properties)

In [None]:
# List the asset keys of the item; each key can be used as `band` below
print(f'Available bands: {list(item.assets.keys())}')

## Extract values of selected band(s) for a defined random number of pixels
Select one of the highest-resolution bands, i.e. `B02`, `B03`, `B04` or `B08`.

In [None]:
# Highest-resolution bands; the band chosen in the next cell must be one of these
high_res_bands = ['B02', 'B03', 'B04', 'B08']

In [None]:
# Band used as the spatial reference for the pixel sample
band = 'B04'  # must be one of high_res_bands (checked on the next line)
assert band in high_res_bands, f'choose high res. band, ie {high_res_bands}'

In [None]:
# Describe the selected high-resolution band and open it with GDAL.
band_asset = item.assets[band]

print('Band:', band)
print(f'- Res: {band_asset.to_dict()["proj:transform"][0]}m')

# Location of the band's raster file
b_href = band_asset.href
print('- href:', b_href)

# GDAL dataset of the high-resolution band; kept as the spatial reference
# for converting pixel indices to coordinates in the cells below.
b_g_10m = gdal.Open(b_href)

In [None]:
# Run the function for testing.
# The result is compared with a small absolute tolerance (in degrees) instead of
# exact float equality, so that last-digit differences between PROJ/GDAL builds
# do not fail the check.
expected_long_lat = (-121.8056363767779, 40.63567108385486)
np.testing.assert_allclose(
    pixel_to_coords(b_g_10m, 100, 100), expected_long_lat, rtol=0, atol=1e-9
)

### Generate (fixed) random sample of image coordinates 

In [None]:
# Size of the random pixel sample
no_pixels = 500

# Seed NumPy's global generator so the sample is reproducible (keep this fixed)
np.random.seed(42)

# One (x, y) pixel-index pair per row, drawn from [1, 10980).
# NOTE(review): 10980 is hardcoded; presumably the width/height in pixels of the
# high-resolution raster — confirm against the opened dataset.
xy = np.random.randint(low=1, high=10980, size=(no_pixels, 2))

# Preview the first pairs
xy[:5]

### Extract values of a specific band for each pair of coordinates

In [None]:
# Band whose values are extracted below; this rebinds `band`, which is also
# used as the column name of the extracted values in the dataframe.
band = 'SCL'

In [None]:
print('Working on band:', band)

# Open gdal object and get raster band
# (b_rst is reset to None at the end of the extraction cell, so this cell has
# to be run again before the extraction cell is re-run)
b_g = gdal.Open(item.assets[band].href)
b_rst = b_g.GetRasterBand(1)

In [None]:
print('Extracting values')
x_values = []
y_values = []

# Pixel size of the reference raster (the one the `xy` indices refer to) and of
# the band being sampled.
ref_res = b_g_10m.GetGeoTransform()[1]
band_res = b_g.GetGeoTransform()[1]

# Number of reference pixels covered by one pixel of the sampled band
scale = band_res / ref_res
if scale < 1 or scale != int(scale):
    # Previously any resolution other than 10 m / 20 m silently produced empty
    # lists; fail loudly instead.
    raise ValueError(
        f"Cannot map {ref_res:g}m pixel indices onto a band with res = {band_res:g}m"
    )
scale = int(scale)

if scale == 1:
    print(f"res = {band_res:g}m, no need to rescale")
else:
    print(f"Rescaling to res = {band_res:g}m")

# Index of the coarser pixel that CONTAINS each reference pixel. This needs
# floor division: rounding (e.g. 3 / 2 -> 2) selects the neighbouring pixel and
# can run past the last row/column of the coarser raster.
# Assumes both rasters share the same origin and extent — TODO confirm.
xy_band = xy // scale

for (col_ref, row_ref), (col_band, row_band) in zip(xy, xy_band):

    # Coordinates always come from the reference raster so that they are the
    # same for every band
    x_values.append([*pixel_to_coords(b_g_10m, col_ref, row_ref)])

    # Read the single pixel at the (rescaled) position
    y_values.append(
        int(
            b_rst.ReadAsArray(
                xoff=int(col_band), yoff=int(row_band), win_xsize=1, win_ysize=1
            )[0][0]
        )
    )

# Empty b_rst
b_rst = None

print(x_values[:10])
print(y_values[:10])

## Make Pandas dataframe
**Note**: The *pandas* dataframe will be used as input for the EDA Notebook from the ARSET training.

In [None]:
# Column-oriented dictionary for the dataframe: the two coordinate columns
# first, then the values of the current band under the band's name.
coord_columns = ('long', 'lat')
data = {
    column: [pair[position] for pair in x_values]
    for position, column in enumerate(coord_columns)
}
data[band] = y_values

In [None]:
# Frame holding the coordinates and the values of the band just extracted
df2 = pd.DataFrame(data)
df2.index.name = 'Index'

if df is None:
    print('Creating Dataframe')
    # First band: the new frame becomes the dataframe
    df = df2

    # Create backup
    data_bk = df.copy()

else:
    print('Adding to existing Dataframe')

    coord_cols = ['long', 'lat']

    # The two frames must describe the same pixels, row by row; otherwise the
    # merge below would silently drop the rows that do not match.
    assert len(df) == len(df2) == no_pixels, 'unexpected number of rows'
    assert np.array_equal(
        df[coord_cols].to_numpy(), df2[coord_cols].to_numpy()
    ), 'long/lat values differ between the dataframes'

    # If this band was merged before (cell re-run), replace its column instead
    # of ending up with `<band>_x` / `<band>_y` duplicates.
    df = df.drop(columns=[band], errors='ignore')

    # Merge temp dataframe with original dataframe, based on matching columns
    df = pd.merge(df, df2, on=['Index', 'long', 'lat'], validate='one_to_one')
    assert len(df) == no_pixels, 'rows were lost while merging'

# Empty memory
df2 = None

display(df)

In [None]:
# NOTE: the "Test with multiple variables" cells further down read
# 'dataframe_multiband.csv', so the export below has to have been run
# (uncommented) at least once before those cells can work.
# # Export dataframe 
# df.to_csv('dataframe_multiband.csv')

## Split dataset into train, validation and test

In [None]:
def to_geojson(t, x, y):
    """Write labelled points to a GeoJSON file.

    The file is saved in the current directory with the name
    ``label-{t}.geojson`` and contains one point feature per entry of ``y``.

    Parameters
    ----------
    t : str
        Name of the dataset split (e.g. train, test, validate); only used to
        build the file name.
    x : sequence of (long, lat) pairs
        Point coordinates; ``x[i]`` belongs to ``y[i]``.
    y : sequence of numbers
        Label of each point; cast to ``int`` and stored in the ``class`` field.

    Raises
    ------
    RuntimeError
        If the output file cannot be created.
    """

    field_name = "class"
    field_type = ogr.OFTInteger

    geojson_filename = f"label-{t}.geojson"

    # Create the output GeoJSON
    out_driver = ogr.GetDriverByName("GeoJSON")
    out_datasource = out_driver.CreateDataSource(geojson_filename)
    if out_datasource is None:
        # Without OGR exceptions enabled a failed creation returns None
        raise RuntimeError(f"Could not create {geojson_filename}")

    # The features written below are points, so declare a point layer
    out_layer = out_datasource.CreateLayer("labels", geom_type=ogr.wkbPoint)
    out_layer.CreateField(ogr.FieldDefn(field_name, field_type))

    # Get the output Layer's Feature Definition
    feature_def = out_layer.GetLayerDefn()

    for index, label in enumerate(y):
        # AddPoint_2D keeps the geometry two-dimensional; AddPoint would add a
        # z coordinate of 0 to every point.
        point = ogr.Geometry(ogr.wkbPoint)
        point.AddPoint_2D(float(x[index][0]), float(x[index][1]))

        # Create the feature, attach geometry and label, and add it to the layer
        out_feature = ogr.Feature(feature_def)
        out_feature.SetGeometry(point)
        out_feature.SetField(field_name, int(label))
        out_layer.CreateFeature(out_feature)

        # dereference the feature
        out_feature = None

    # Save and close DataSources (release the layer before its datasource)
    out_layer = None
    out_datasource = None

### Test with 1 variable 

In [None]:
# Preview the first ten coordinate pairs
x_values[:10]

In [None]:
# Preview the first ten extracted band values
y_values[:10]

In [None]:
# Every coordinate pair needs exactly one label
assert len(x_values) == len(y_values)
print(len(x_values), len(y_values))

In [None]:
# 80 % of the samples are used for training; the remainder is split again in
# the next cell. The fixed random_state keeps the split identical no matter how
# often or in which order the cell is run.
x_train, x_rem, y_train, y_rem = train_test_split(
    np.array(x_values), np.array(y_values), train_size=0.8, random_state=42
)
print(len(x_train),len(y_train))
print(len(x_rem),len(y_rem))

In [None]:
# Split the remainder of the training split in half: validation and test.
# The fixed random_state makes the split reproducible.
x_valid, x_test, y_valid, y_test = train_test_split(
    x_rem, y_rem, test_size=0.5, random_state=42
)
print(len(x_valid),len(y_valid))
print(len(x_test),len(y_test))

In [None]:
# Shapes of the coordinate arrays of the three splits
x_valid.shape, x_train.shape, x_test.shape

In [None]:
# Shapes of the label arrays of the three splits
y_valid.shape, y_train.shape, y_test.shape

In [None]:
# Write the training split to label-train.geojson
# (plain string literals: an f-string without placeholders is not needed)
to_geojson("train", x_train, y_train)
# to_geojson("test", x_test, y_test)
# to_geojson("validate", x_valid, y_valid)

### Test with multiple variables (one geojson per variable)

In [None]:
# Read CSV file and make pandas dataframe
# NOTE(review): the cell that writes 'dataframe_multiband.csv' is commented out
# above, so the file has to exist already — confirm before running.
df_pandas = pd.read_csv('dataframe_multiband.csv')
df_pandas

In [None]:
# Columns of df_pandas for which one GeoJSON file each is written in the loop below
y_names = ['B02', 'B03', 'B04', 'B08', 'SCL']

In [None]:
# Take the coordinates from the same dataframe as the band values, so that each
# (long, lat) pair is guaranteed to belong to the row of its label and the cell
# does not depend on `x_values` left over from an earlier cell.
xy_coords = df_pandas[['long', 'lat']].to_numpy()

for y_name in y_names:
    print('Working on:', y_name)

    y_values = df_pandas[y_name].to_numpy()
    assert len(xy_coords) == len(y_values)
    print('Total # of values:', len(xy_coords), len(y_values))

    # Training. The fixed random_state gives every band the SAME split, so the
    # per-band files describe the same set of points.
    x_train, x_rem, y_train, y_rem = train_test_split(
        xy_coords, y_values, train_size=0.8, random_state=42
    )
    print('# used for training:', len(x_train),len(y_train))
    print('# residuals:', len(x_rem),len(y_rem))

    # Testing and Validation: the remainder is split in half
    x_valid, x_test, y_valid, y_test = train_test_split(
        x_rem, y_rem, test_size=0.5, random_state=42
    )
    print('# used for validation:', len(x_valid),len(y_valid))
    print('# used for testing:', len(x_test),len(y_test))

    print('x_shapes:', x_valid.shape, x_train.shape, x_test.shape)
    print('y_shapes:', y_valid.shape, y_train.shape, y_test.shape)

    # Now creating the geojson files
    to_geojson(f"train_{y_name}", x_train, y_train)
    # to_geojson(f"test_{y_name}", x_test, y_test)
    # to_geojson(f"validate_{y_name}", x_valid, y_valid)
    print()

### Test with multiple variables in a single GeoJSON

In [None]:
# Show the dataframe that is split in the cells below
df_pandas

In [None]:
# Target column, kept as a pandas Series (not converted to an array)
y_val = df_pandas['SCL']#.values

In [None]:
# Coordinates and band values that are split together with the target column
X = df_pandas[['long', 'lat', 'B02', 'B03', 'B04', 'B08']]

In [None]:
# Training (directly from dataframe)
# All columns are passed to a single call so that every output uses the same
# row selection; 80 % of the rows go to training.
columns_to_split = [X['long'], X['lat'], X['B02'], X['B03'], X['B04'], X['B08'], y_val]

(
    long_train, long_rem,
    lat_train, lat_rem,
    B02_train, B02_rem,
    B03_train, B03_rem,
    B04_train, B04_rem,
    B08_train, B08_rem,
    SCL_train, SCL_rem,
) = train_test_split(*columns_to_split, train_size=0.8, random_state=42)

print('# used for training:', len(long_train))
print('# residuals:', len(long_rem))

In [None]:
# Testing and Validation
# The remainder of the training split is divided in half. random_state is fixed,
# as in the training split, so that the result is reproducible.
remaining_columns = [long_rem, lat_rem, B02_rem, B03_rem, B04_rem, B08_rem, SCL_rem]

(
    long_valid, long_test,
    lat_valid, lat_test,
    B02_valid, B02_test,
    B03_valid, B03_test,
    B04_valid, B04_test,
    B08_valid, B08_test,
    SCL_valid, SCL_test,
) = train_test_split(*remaining_columns, train_size=0.5, random_state=42)

print('# used for validation:', len(long_valid))
print('# used for testing:', len(long_test))