In [1]:
# Jupyter notebook related
%reload_ext autoreload
%autoreload 2

In [3]:
import importlib.metadata

In [6]:
importlib.metadata.version("numpy")

'1.24.3'

In [7]:
import numpy

In [8]:
numpy.__package__

'numpy'

In [2]:
import hvplot.xarray  # noqa
import hvplot.pandas  # noqa
import panel as pn  # noqa
import panel.widgets as pnw

In [3]:
import os
import geopandas as gpd
from dask import delayed

# from elogs import Elogs, ElogsTask

with open('../../../connstr_vegteam') as f:
    connect_str = f.read()
container_name = 'evotrain'

locs_fn = "../../../locations_v2.csv"

In [22]:
import pandas as pd

locs = pd.read_csv(locs_fn)

In [24]:
from satio_pc.utils.azure import AzureBlobReader
azure = AzureBlobReader(connect_str,
                        container_name)

# keys = azure.list_files()
def list_files(prefix=None):
    files = set()
    for blob in azure.container_client.list_blobs(name_starts_with=prefix):
        if blob.name[-1] != "/":
            files.add(blob.name)
    return files

keys = list_files(f'v2ts/{year}/')

In [23]:
"_".join(keys[0].split('.')[0].split('_')[-3:])

done_patch_ids = list(map(lambda key: "_".join(key.split('.')[0].split('_')[-3:]), keys))

locs = locs[~locs.patch_id.isin(done_patch_ids)]
locs.shape


KeyboardInterrupt



# Cluster setup

In [3]:
# from dask_gateway import Gateway
# gateway = Gateway()


# # List the clusters and get the cluster report
# clusters_reports = gateway.list_clusters()

# # Get the first cluster report
# cluster_report = clusters_reports[0]

# # Connect to the cluster using the cluster report
# cluster = gateway.connect(cluster_report)

# # Get the client object from the cluster
# client = cluster.get_client()

In [38]:
# stop clusters
from dask_gateway import Gateway
gateway = Gateway()
clusters_reports = gateway.list_clusters()

clusters = [gateway.stop_cluster(c.name) for c in clusters_reports]
clusters

[]

In [39]:
# create and scale cluster
from dask.distributed import PipInstall, Client
import dask_gateway

cluster = dask_gateway.GatewayCluster()
client = cluster.get_client()

print(client.dashboard_link)

cluster.scale(100)

cluster

https://pccompute.westeurope.cloudapp.azure.com/compute/services/dask-gateway/clusters/prod.4e5cba7ca94f4430b331a2207ef9e060/status


VBox(children=(HTML(value='<h2>GatewayCluster</h2>'), HBox(children=(HTML(value='\n<div>\n<style scoped>\n    …

In [40]:
# Once cluster is scaled, install satio_pc
# satio_pc_url = "https://s3-eu-central-1.amazonaws.com/vito-worldcover-public/wheels/satio_pc-0.0.1-py3-none-any.whl"
satio_pc_url = "satio-pc==0.0.3"
plugin = PipInstall(packages=[satio_pc_url])
client.register_worker_plugin(plugin)

{'tls://10.244.114.12:44169': {'status': 'OK'},
 'tls://10.244.114.13:36851': {'status': 'OK'},
 'tls://10.244.116.12:32829': {'status': 'OK'},
 'tls://10.244.116.13:46173': {'status': 'OK'},
 'tls://10.244.116.14:44309': {'status': 'OK'},
 'tls://10.244.117.12:32815': {'status': 'OK'},
 'tls://10.244.118.12:33825': {'status': 'OK'},
 'tls://10.244.118.13:34261': {'status': 'OK'},
 'tls://10.244.118.14:46373': {'status': 'OK'},
 'tls://10.244.119.12:38931': {'status': 'OK'},
 'tls://10.244.119.13:32937': {'status': 'OK'},
 'tls://10.244.137.10:32915': {'status': 'OK'},
 'tls://10.244.137.9:46183': {'status': 'OK'},
 'tls://10.244.138.9:33759': {'status': 'OK'},
 'tls://10.244.140.10:35741': {'status': 'OK'},
 'tls://10.244.140.11:38103': {'status': 'OK'},
 'tls://10.244.140.9:35731': {'status': 'OK'},
 'tls://10.244.141.10:38527': {'status': 'OK'},
 'tls://10.244.141.11:32769': {'status': 'OK'},
 'tls://10.244.141.9:43795': {'status': 'OK'},
 'tls://10.244.142.9:33477': {'status': 'OK'

In [None]:
# check logs
logs = client.get_worker_logs()

print(len(logs))

for worker, worker_logs in logs.items():
    print(f"Logs for worker {worker}:")
    for log in worker_logs:
        print(log)
    print()
    print('*'*100)

In [15]:
# shutdown cluster
# cluster.shutdown()

# Training data extraction

In [4]:
max_cloud_cover = 90
max_workers = 10
year = 2020

In [5]:
import time
from pathlib import Path


def store_ts(ts, fn, complevel=5):
    if 'spec' in ts.attrs.keys():
        ts.attrs['spec'] = str(ts.attrs['spec'])
    for var in ['proj:bbox', 'proj:shape', 'proj:transform']:
        if var in ts.coords:
            ts = ts.drop_vars(var)
    encoding = {ts.name: {'zlib': True, 'complevel': complevel}}
    # Save to a NetCDF file with compression
    ts.to_netcdf(fn, encoding=encoding)


In [6]:
def _extract_loc(patch_id, tile, epsg, xmin, ymin, xmax, ymax, year):
    from loguru import logger
    from satio_pc.reader import S2TileReader
    from satio_pc.preprocessing.clouds import preprocess_scl
    from satio_pc.sentinel2 import BANDS_RESOLUTION
    from satio_pc.geotiff import slash_tile
    from satio_pc.utils.azure import AzureBlobReader

    from rasterio.enums import Resampling
    
    container_name = 'evotrain'
    azure = AzureBlobReader(connect_str,
                            container_name)

    bounds = xmin, ymin, xmax, ymax
    
    
    fns = [f'v2ts/{year}/{slash_tile(tile)}/{pre}/evotrain_v2ts_{year}_{patch_id}_{pre}.nc'
           for pre in '10m 20m 60m'.split()]
    
    if all(map(lambda fn: azure.check_file_exists(fn), fns)):
        logger.warning(f"Targets {fns} exists, skipping...")
        return True

    start_date = f'{year}-01-01'
    end_date = f'{year + 1}-01-01'

    reader = S2TileReader(tile,
                          start_date,
                          end_date,
                          max_cloud_cover)

    logger.info("Loading SCL mask")
    scl = reader.read(bounds, epsg, ['SCL'], max_workers=max_workers)
    
    logger.info("Filtering no data obs")
    valid_obs = scl.mean(dim=('y', 'x', 'band')) > 0

    scl = scl.sel(time=valid_obs)
    
    reader._items = [i for i, b in zip(reader._items, valid_obs) if b]

    bands = ['B01', 'B09',
             'B02', 'B03', 'B04', 'B08', 
             'B05', 'B06', 'B07', 'B8A', 'B11', 'B12']

    bands_10m = [b for b in bands if BANDS_RESOLUTION[b] == 10]
    bands_20m = [b for b in bands if BANDS_RESOLUTION[b] == 20] + ['SCL']
    bands_60m = [b for b in bands if BANDS_RESOLUTION[b] == 60] + ['AOT', 'WVP']
    
    logger.info("Loading data")
    ts10 = reader.read(bounds, epsg, bands_10m, max_workers=20,
                      resolution=10).ewc.harmonize()
    ts20 = reader.read(bounds, epsg, bands_20m, max_workers=20,
                       resolution=20).ewc.harmonize()
    ts60 = reader.read(bounds, epsg, bands_60m, max_workers=20,
                       resolution=60).ewc.harmonize()
    
    for ts, pre in zip([ts10, ts20, ts60],
                       ['10m', '20m', '60m']):
        fn = f'evotrain_v2ts_{year}_{patch_id}_{pre}.nc'
        dst_fn = f"v2ts/{year}/{slash_tile(tile)}/{pre}/{fn}"
        
        logger.info(f"Saving {fn} and uploading to {dst_fn}")
        ts.name = f'evotrain_v2ts_{year}_{patch_id}_{pre}'
        store_ts(ts, fn, complevel=4)
    
        azure.upload_file(fn,
                          dst_fn,
                          overwrite=True)
        os.remove(fn)
    
    logger.success("Done")
    return True

In [7]:
def extract_loc(tup):
    patch_id, tile, epsg, xmin, ymin, xmax, ymax, year = tup
    xmin, ymin, xmax, ymax = list(map(float, (xmin, ymin, xmax, ymax)))
    epsg = int(epsg)
    year = int(year)
    
    try:
        from loguru import logger
        return _extract_loc(patch_id, tile, epsg, xmin, ymin, xmax, ymax, year)
    except Exception as e:
        logger.exception(e)
        return False

In [28]:
import pandas as pd
import xarray as xr
import numpy as np

def split_dataframe(df, n, nth_id):
    if nth_id >= n or nth_id < 0:
        raise ValueError("Invalid nth_id value")
    chunk_size = len(df) // n
    remainder = len(df) % n
    start_idx = nth_id * chunk_size + min(nth_id, remainder)
    end_idx = start_idx + chunk_size + (1 if nth_id < remainder else 0)
    return df.iloc[start_idx:end_idx]


locs = pd.read_csv(locs_fn)
locs = locs.sample(frac=1, random_state=0)  # shuffle

chunk_id = 0
locs = split_dataframe(locs, 10, chunk_id)

cols = ['patch_id', 'tile', 'epsg', 'xmin', 'ymin', 'xmax', 'ymax']
locs = locs[cols]

In [29]:
locs.shape

(17735, 7)

In [8]:
from satio_pc.utils.azure import AzureBlobReader
azure = AzureBlobReader(connect_str,
                        container_name)

# keys = azure.list_files()
def list_files(prefix=None):
    files = set()
    for blob in azure.container_client.list_blobs(name_starts_with=prefix):
        if blob.name[-1] != "/":
            files.add(blob.name)
    return files

In [9]:
keys = list_files(f'v2ts/{year}/')

In [12]:
azure.download_file('v2ts/2020/18/R/TN/20m/evotrain_v2ts_2020_18RTN_099_34_20m.nc', 'evotrain_v2ts_2020_18RTN_099_34_20m.nc')

In [13]:
import xarray as xr
da = xr.open_dataarray('evotrain_v2ts_2020_18RTN_099_34_20m.nc')

In [17]:
from satio_pc.extension import ESAWorldCoverTimeSeries

In [18]:
da.sel(band=['SCL']).ewc.show()

In [20]:
da.band

In [11]:
keys

{'v2ts/2020/48/U/VE/10m/evotrain_v2ts_2020_48UVE_056_14_10m.nc',
 'v2ts/2020/16/T/GT/10m/evotrain_v2ts_2020_16TGT_070_05_10m.nc',
 'v2ts/2020/04/W/DC/10m/evotrain_v2ts_2020_04WDC_026_46_10m.nc',
 'v2ts/2020/28/N/GP/10m/evotrain_v2ts_2020_28NGP_016_48_10m.nc',
 'v2ts/2020/18/R/TN/20m/evotrain_v2ts_2020_18RTN_099_34_20m.nc',
 'v2ts/2020/20/T/LT/20m/evotrain_v2ts_2020_20TLT_049_32_20m.nc',
 'v2ts/2020/50/J/MT/60m/evotrain_v2ts_2020_50JMT_107_05_60m.nc',
 'v2ts/2020/48/Q/VL/20m/evotrain_v2ts_2020_48QVL_013_35_20m.nc',
 'v2ts/2020/39/U/WV/20m/evotrain_v2ts_2020_39UWV_080_34_20m.nc',
 'v2ts/2020/12/S/YG/20m/evotrain_v2ts_2020_12SYG_035_26_20m.nc',
 'v2ts/2020/27/X/VB/10m/evotrain_v2ts_2020_27XVB_089_47_10m.nc',
 'v2ts/2020/35/L/QD/20m/evotrain_v2ts_2020_35LQD_082_59_20m.nc',
 'v2ts/2020/22/K/EE/20m/evotrain_v2ts_2020_22KEE_106_62_20m.nc',
 'v2ts/2020/39/R/WM/10m/evotrain_v2ts_2020_39RWM_006_22_10m.nc',
 'v2ts/2020/11/S/LB/60m/evotrain_v2ts_2020_11SLB_069_21_60m.nc',
 'v2ts/2020/22/X/ES/10m/e

In [9]:
done_patch_ids = list(map(lambda key: "_".join(key.split('.')[0].split('_')[-3:]), keys))

locs = locs[~locs.patch_id.isin(done_patch_ids)]
locs.shape

(177341, 7)

### Cluster processing

In [34]:
# 'patch_id', 'tile', 'epsg', 'xmin', 'ymin', 'xmax', 'ymax'
args = [(loc.patch_id,
         loc.tile,
         loc.epsg,
         loc.xmin,
         loc.ymin,
         loc.xmax,
         loc.ymax,
         year)
        for loc in locs.itertuples()
        ]
len(args)

17735

In [35]:
import numpy as np
args = np.array(args)

In [36]:
import warnings
warnings.filterwarnings("ignore")

In [37]:
final = extract_loc(args[0])

[32m2023-10-02 13:36:20.691[0m | [1mINFO    [0m | [36m__main__[0m:[36m_extract_loc[0m:[36m33[0m - [1mLoading SCL mask[0m
[32m2023-10-02 13:36:23.661[0m | [1mINFO    [0m | [36m__main__[0m:[36m_extract_loc[0m:[36m36[0m - [1mFiltering no data obs[0m
[32m2023-10-02 13:36:23.719[0m | [1mINFO    [0m | [36m__main__[0m:[36m_extract_loc[0m:[36m51[0m - [1mLoading data[0m
[32m2023-10-02 13:36:32.601[0m | [1mINFO    [0m | [36m__main__[0m:[36m_extract_loc[0m:[36m64[0m - [1mSaving evotrain_v2ts_2020_24LXR_100_13_10m.nc and uploading to v2ts/2020/24/L/XR/10m/evotrain_v2ts_2020_24LXR_100_13_10m.nc[0m
[32m2023-10-02 13:36:33.191[0m | [1mINFO    [0m | [36m__main__[0m:[36m_extract_loc[0m:[36m64[0m - [1mSaving evotrain_v2ts_2020_24LXR_100_13_20m.nc and uploading to v2ts/2020/24/L/XR/20m/evotrain_v2ts_2020_24LXR_100_13_20m.nc[0m
[32m2023-10-02 13:36:33.611[0m | [1mINFO    [0m | [36m__main__[0m:[36m_extract_loc[0m:[36m64[0m - [1mSaving e

In [45]:
import dask

extract_delayed = dask.delayed(extract_loc)

lazy_results = [extract_delayed(ag)
                for ag in args]

In [None]:
results = dask.compute(*lazy_results)

# finish download missing

In [10]:
locs_all = pd.read_csv(locs_fn)

args = []

for year in range(2018, 2023):
    
    print(year)
    keys = list_files(f'evotrain/v2/{year}/')

    done_patch_ids = list(map(lambda key: "_".join(key.split('.')[0].split('_')[-3:]), keys))

    locs = locs_all[~locs_all.patch_id.isin(done_patch_ids)]
    
    args += [(loc.patch_id,
         loc.tile,
         loc.epsg,
         loc.xmin,
         loc.ymin,
         loc.xmax,
         loc.ymax,
         year)
        for loc in locs.itertuples()
        ]
    print(year, locs.shape)
    
args = np.array(args)

args.size

2018
2018 (199, 28)
2019
2019 (75, 28)
2020
2020 (59, 28)
2021
2021 (260, 28)
2022
2022 (741, 28)


10672

In [14]:
import dask

extract_delayed = dask.delayed(extract_loc)

lazy_results = [extract_delayed(ag)
                for ag in args]

results = dask.compute(*lazy_results)