# Process a whole dataset

In [None]:
#Autoreload .py files
%load_ext autoreload
%autoreload 2

#https://github.com/chmp/ipytest/issues/80
import sys
sys.breakpointhook = sys.__breakpointhook__

In [None]:
#Load libs
import xarray as xr
import numpy as np
from pathlib import Path
import satalign
import matplotlib.pyplot as plt
from xarrayvideo import xarray2video, video2xarray, gap_fill, plot_image, to_netcdf
from tqdm.notebook import tqdm
import warnings

In [None]:
#Set parameters
# Input minicubes (one .zarr store per cube, one folder level below dataset_in_path),
# output folder for the encoded videos, and output folder for the RGB preview images.
dataset_in_path= Path('/scratch/users/deepextremes/deepextremes-minicubes/full')
dataset_out_path= Path('/scratch/users/deepextremes/video')
images_out_path= Path('./deepextremes_images')

# Each entry: output name -> (source variable(s), dimension names, 'lossy' | 'lossless').
conversion_rules= {
    # Configurations that did not work, kept for reference:
    # 'r': ( 'B04', ('time','x','y'), 'lossy'), #1 channel + lossy not working
    # 'ir4': ( ('B8A','B07','B06','B05'), ('time','x','y'), 'lossy'), #4 channels not working
    'rgb': ( ('B04','B03','B02'), ('time','x','y'), 'lossy'),
    'ir3': ( ('B8A','B06','B05'), ('time','x','y'), 'lossy'),
    'cm': ( 'cloudmask_en', ('time','x','y'), 'lossless'),
    'scl': ( 'SCL', ('time','x','y'), 'lossless'),
    }

# Encoder settings used for the 'lossy' rules above.
lossy_params = {
    'c:v': 'libx264',  #libx264 always seems better for rgb[libx264, libx265, vp9, ffv1]
    'preset': 'slow',  #Preset for quality/encoding speed tradeoff: quick, medium, slow (better)
    'crf': 11, #14 default, 11 for higher quality and size
    }

# Path.glob yields entries in arbitrary filesystem order; sort so that the processing
# order (and hence any index-based skipping/stopping in the main loop) is reproducible.
files= sorted(dataset_in_path.glob('*/*.zarr'))

# When True, plot alignment diagnostics and compute compression stats for every cube.
verbose=False

In [None]:
#Run for all cubes
import traceback

# Debug switches:
#   max_cubes   -- stop after this many cubes; None processes the whole dataset.
#   plot_images -- decode each written video back and save an RGB preview image.
max_cubes = None
plot_images = True

# Spectral bands that are co-registered with satalign (loop-invariant, so defined once).
# B07 is not listed because it is dropped from every cube below.
bands = ['B04','B03','B02','B8A','B06','B05']

# Create the output folders up front, so a missing directory fails once, loudly,
# instead of once per cube inside the try-block below.
# NOTE(review): unclear whether xarray2video / plot_image create missing parents themselves.
dataset_out_path.mkdir(parents=True, exist_ok=True)
images_out_path.mkdir(parents=True, exist_ok=True)

for i, input_path in (pbar:=tqdm(enumerate(files), total=len(files))):
    # Optional early stop for debugging runs. The previous `if i > 10: break` was
    # indented inside the except-handler, so it only fired when a cube with index > 10
    # raised an exception, never on the normal path.
    if max_cubes is not None and i >= max_cubes:
        break

    # Cube id = 2nd and 3rd '_'-separated fields of the file stem. This cannot raise,
    # so it lives outside the try; the error message below can always use it.
    array_id= '_'.join(input_path.stem.split('_')[1:3])
    pbar.set_description(array_id)

    try:
        #Load
        minicube= xr.open_dataset(input_path, engine='zarr')
        minicube['SCL']= minicube['SCL'].astype(np.uint8) #Fixes problem with the dataset
        minicube['cloudmask_en']= minicube['cloudmask_en'].astype(np.uint8)
        minicube= minicube.drop_vars('B07') #We drop a variable for now

        #Align: the reference is the per-band temporal mean from time index 74 onwards
        # NOTE(review): the reason for starting at index 74 is not documented here.
        reference_image= minicube[bands].isel(time=slice(74,None)).mean("time").to_array().transpose('variable', 'y', 'x')
        datacube= minicube[bands].to_array().transpose('time', 'variable', 'y', 'x')

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")  # silence any warnings emitted during alignment
            syncmodel= satalign.PCC( #PCC quicker, ECC more precise
                datacube=datacube, # T x C x H x W
                reference=reference_image, # C x H x W
                channel="mean", crop_center=96, num_threads=1)
            new_cube, warps= syncmodel.run_multicore()
        # Write the aligned bands back into the minicube
        for b in bands: minicube[b]= new_cube.sel(variable=b)

        if verbose:
            warp_df = satalign.utils.warp2df(warps, datacube.time.values)
            satalign.utils.plot_s2_scatter(warp_df)
            plt.show()

        #Compress
        arr_dict= xarray2video(minicube, array_id, conversion_rules, value_range=(0.,1.),
                               lossy_params=lossy_params, fmt='mkv', exceptions='raise',
                               output_path=dataset_out_path, use_ssim=False, compute_stats=verbose,
                               loglevel='quiet', #verbose, quiet
                               )

        #Plot image: decode the written video and save an RGB preview.
        # The first cube (i == 0) is skipped, as in the original notebook.
        # NOTE(review): the reason for skipping i == 0 is not visible here; confirm it is intended.
        if plot_images and i > 0:
            minicube_new= video2xarray(dataset_out_path, array_id, fmt='mkv')
            plot_image(minicube_new, ['B04','B03','B02'], save_name=str(images_out_path/f'{array_id}.jpg'), show=False)

    except Exception as e:
        # Best effort: report the failing cube and continue with the next one.
        print(f'Exception processing {array_id=}: {e}')
        if verbose:
            traceback.print_exc()