# Pytorch wrapper for DFCSEN12MSDataset

A DFCDataset object is created for the validation set by calling the constructor with the provided config values:
- val_dir is passed as the base_dir to point to the validation data directory
- val_mode is passed as 'test' to indicate this is the validation set
- clip_sample_values is set to True to clip/limit the pixel values as specified
- val_used_data_fraction of 1 means all validation data will be used
- image_px_size is set to 224 for a 224x224 image size
- cover_all_parts_validation is True, so a sliding window will be used to cover the entire validation image
seed is set to 42 for reproducibility


When initialized, the DFCDataset does the following main steps:
- Sets up the seasons list to contain only 'Seasons.TESTSET' for the validation set
- Reads in the validation observations CSV file from val_dir
- Samples 100% of rows from the CSV based on val_used_data_fraction=1
- No transforms are specified, so the default base transform (ToTensorV2) will be used


When getitem is called to retrieve an item:
- The observation for the index is retrieved from the DataFrame
- The S1, S2, and LC data is loaded using the scene info
- The data is clipped if clip_sample_values is True
- A sliding window is used to cover the entire original image since cover_all_parts_validation is True
- Only the base ToTensorV2 transform is applied
- The data is normalized
- A dict containing the transformed/normalized arrays is returned


The validation DFCDataset will load the full validation data using a sliding window approach, clip values, apply a standard transform/normalization, and return the image/label tensors.


The config allows control over the data sampling, transforms, normalizations etc. applied to the validation set.

## Wrapper

In [1]:
# from dfc_dataset import DFCDataset

# data_config = {
#     'train_dir': 'data/data_disini', # path to the training directory,  
#     'val_dir': 'data/data_disini', # path to the validation directory,
#     'train_mode': 'validation', # can be one of the following: 'test', 'validation'
#     'val_mode': 'test', # can be one of the following: 'test', 'validation'
#     'num_classes': 8, # number of classes in the dataset.
#     'clip_sample_values': True, # clip (limit) values
#     'train_used_data_fraction': 1, # fraction of data to use, should be in the range [0, 1]
#     'val_used_data_fraction': 1,
#     'image_px_size': 224, # image size (224x224)
#     'cover_all_parts_train': True, # if True, if image_px_size is not 224 during training, we use a random crop of the image
#     'cover_all_parts_validation': True, # if True, if image_px_size is not 224 during validation, we use a non-overlapping sliding window to cover the entire image
#     'seed': 42,
# }

In [2]:
# val_dataset = DFCDataset(
#     data_config['val_dir'],
#     mode=data_config['val_mode'],
#     clip_sample_values=data_config['clip_sample_values'],
#     used_data_fraction=data_config['val_used_data_fraction'],
#     image_px_size=data_config['image_px_size'],
#     cover_all_parts=data_config['cover_all_parts_validation'],
#     seed=data_config['seed'],
# )

## Breakdown

In [3]:
base_dir = 'data/data_disini'
mode = 'test'
clip_sample_values = True
used_data_fraction = 1
image_px_size = 224
cover_all_parts = True
seed = 42

# other parameter set to default
transforms = None
simclr_dataset = False
balanced_classes = False
sampling_seed = 42
normalize = True
moby_transform = None

cover_all_parts: 

if image_px_size is not 224, this makes sure that during validation the entire image is used during training, we read image part at random parts of the original image, during vaildation, use a non-overlapping sliding window to cover the entire image

selecting seasons

In [4]:
# from dfc_sen12ms_dataset import Seasons
from enum import Enum


class Seasons(Enum):
    SPRING = "ROIs1158_spring"
    SUMMER = "ROIs1868_summer"
    FALL = "ROIs1970_fall"
    WINTER = "ROIs2017_winter"
    AUTUMN_DFC = "ROIs0000_autumn"
    WINTER_DFC = "ROIs0000_winter"
    SPRING_DFC = "ROIs0000_spring"
    SUMMER_DFC = "ROIs0000_summer"
    TESTSET = "ROIs0000_test"
    VALSET = "ROIs0000_validation"
    TEST = [TESTSET]
    VALIDATION = [VALSET]
    TRAIN = [SPRING, SUMMER, FALL, WINTER]
    ALL = [SPRING, SUMMER, FALL, WINTER, VALIDATION, TEST]

if mode == "dfc":
    seasons = [
        Seasons.AUTUMN_DFC,
        Seasons.SPRING_DFC,
        Seasons.SUMMER_DFC,
        Seasons.WINTER_DFC,
    ]
elif mode == "test":
    seasons = [Seasons.TESTSET]
elif mode == "validation":
    seasons = [Seasons.VALSET]
elif mode == "sen12ms":
    seasons = [
        Seasons.SPRING,
        Seasons.SUMMER,
        Seasons.FALL,
        Seasons.WINTER,
    ]
else:
    raise ValueError(
        "Unsupported mode, must be in ['dfc', 'sen12ms', 'test', 'validation']"
    )

print(mode, seasons)

test [<Seasons.TESTSET: 'ROIs0000_test'>]


In [5]:
from dfc_sen12ms_dataset import DFCSEN12MSDataset

print(base_dir)
data = DFCSEN12MSDataset(base_dir)
data

data/data_disini


<dfc_sen12ms_dataset.DFCSEN12MSDataset at 0x107f524c0>

In [6]:
import pandas as pd
import os

if balanced_classes:
    observations = pd.read_csv(
        os.path.join(base_dir, mode + "_observations_balanced_classes.csv"),
        header=0,
        # names=["Season", "Scene", "ID", "dfc_label", "copy_nr"],
    )
else:
    observations = pd.read_csv(
        os.path.join(base_dir, mode + "_observations.csv"),
        header=None,
        names=["Season", "Scene", "ID"],
    )
print(balanced_classes)
print(observations.shape)
observations.head()

False
(5128, 3)


Unnamed: 0,Season,Scene,ID
0,Seasons.TESTSET,0,2014
1,Seasons.TESTSET,0,3654
2,Seasons.TESTSET,0,1101
3,Seasons.TESTSET,0,3235
4,Seasons.TESTSET,0,467


In [7]:
print(observations['Season'].value_counts())
print(observations['Scene'].value_counts())
print(observations['ID'].value_counts())

Season
Seasons.TESTSET    5128
Name: count, dtype: int64
Scene
0    5128
Name: count, dtype: int64
ID
2014    1
2604    1
385     1
1144    1
3611    1
       ..
1603    1
2716    1
3156    1
704     1
619     1
Name: count, Length: 5128, dtype: int64


1 season, 1 scene & 5128 unique ID's

In [8]:
if cover_all_parts:
    num_img_parts = int(256**2 / image_px_size**2)
    obs = []
    for season, scene, idx in observations.values:
        for i in range(num_img_parts):
            obs.append([season, scene, idx, i])

    observations = pd.DataFrame(
        obs, columns=["Season", "Scene", "ID", "ScenePart"]
    )
print(cover_all_parts)
print(image_px_size)
observations.head()

True
224


Unnamed: 0,Season,Scene,ID,ScenePart
0,Seasons.TESTSET,0,2014,0
1,Seasons.TESTSET,0,3654,0
2,Seasons.TESTSET,0,1101,0
3,Seasons.TESTSET,0,3235,0
4,Seasons.TESTSET,0,467,0


In [9]:
num_img_parts

1

num_img_parts = int(256^2 / 224^2) computes to int(65536 / 50176), which is int(1.306...). When converted to an integer, this will result in num_img_parts being 1.

This means that the loop for image parts will only iterate once (for i in range(1):), implying that each image is considered to be a single part of size 224x224.

As a result, for each original observation, one entry will be added to the obs list. This entry represents the whole image (since 224x224 is almost the full area of a 256x256 image, with very little cropping).

In [10]:
observations["ScenePart"].value_counts()

ScenePart
0    5128
Name: count, dtype: int64

Bagaimana kalo image_px_size kita turunkan angkanya?

In [11]:
def percobaan(image_px_size=224):
    import pandas as pd
    import os

    if balanced_classes:
        observations = pd.read_csv(
            os.path.join(base_dir, mode + "_observations_balanced_classes.csv"),
            header=0,
            # names=["Season", "Scene", "ID", "dfc_label", "copy_nr"],
        )
    else:
        observations = pd.read_csv(
            os.path.join(base_dir, mode + "_observations.csv"),
            header=None,
            names=["Season", "Scene", "ID"],
        )
    print(balanced_classes)
    
    if cover_all_parts:
        num_img_parts = int(256**2 / image_px_size**2)
        obs = []
        for season, scene, idx in observations.values:
            for i in range(num_img_parts):
                obs.append([season, scene, idx, i])

        observations = pd.DataFrame(
            obs, columns=["Season", "Scene", "ID", "ScenePart"]
        )
    print(cover_all_parts)
    print(image_px_size)
    return observations

observations_percobaan = percobaan(image_px_size=112)
print(observations_percobaan['ScenePart'].value_counts())
observations_percobaan.head()

False
True
112
ScenePart
0    5128
1    5128
2    5128
3    5128
4    5128
Name: count, dtype: int64


Unnamed: 0,Season,Scene,ID,ScenePart
0,Seasons.TESTSET,0,2014,0
1,Seasons.TESTSET,0,2014,1
2,Seasons.TESTSET,0,2014,2
3,Seasons.TESTSET,0,2014,3
4,Seasons.TESTSET,0,2014,4


Samples 100% of rows from the CSV based on val_used_data_fraction=1

In [12]:
print(used_data_fraction, sampling_seed)
print(observations.shape)
observations = observations.sample(frac=used_data_fraction, random_state=sampling_seed).sort_index()
print(observations.shape)

1 42
(5128, 4)
(5128, 4)


In [13]:
import albumentations as A
from utils import AlbumentationsToTorchTransform
from albumentations.pytorch import ToTensorV2

print(transforms)
base_aug = A.Compose([ToTensorV2()])
base_transform = AlbumentationsToTorchTransform(base_aug)

None


  warn(f"Failed to load image Python extension: {e}")


getitem

In [14]:
# coba index 2
idx = 2

In [15]:
obs = observations.iloc[idx]

class Seasons(Enum):
    SPRING = "ROIs1158_spring"
    SUMMER = "ROIs1868_summer"
    FALL = "ROIs1970_fall"
    WINTER = "ROIs2017_winter"
    AUTUMN_DFC = "ROIs0000_autumn"
    WINTER_DFC = "ROIs0000_winter"
    SPRING_DFC = "ROIs0000_spring"
    SUMMER_DFC = "ROIs0000_summer"
    TESTSET = "ROIs0000_test"
    VALSET = "ROIs0000_validation"
    TEST = [TESTSET]
    VALIDATION = [VALSET]
    TRAIN = [SPRING, SUMMER, FALL, WINTER]
    ALL = [SPRING, SUMMER, FALL, WINTER, VALIDATION, TEST]

season = Seasons[obs.Season[len("Seasons.") :]]

print(season)
obs

Seasons.TESTSET


Season       Seasons.TESTSET
Scene                      0
ID                      1101
ScenePart                  0
Name: 2, dtype: object

windowing

In [16]:
import numpy as np
from rasterio.windows import Window

if image_px_size != 256:
    # crop the data to image_px_size times image_px_size (e.g. 128x128)
    x_offset, y_offset = np.random.randint(0, 256 - image_px_size, 2)
    window = Window(x_offset, y_offset, image_px_size, image_px_size)
else:
    window = None
print(x_offset, y_offset, image_px_size)
window

22 15 224


Window(col_off=22, row_off=15, width=224, height=224)

Window is a rasterio class that defines a rectangular subset of a raster dataset in terms of row and column offsets and width and height in rows and columns.

We creates a Window object with the offsets and dimensions specified by image_px_size. The Window specifies the rectangular region of the image to be cropped.

If we set image_px_size to 224, the code would generate random x_offset and y_offset values between 0 and 32 (256 - 224 = 32). 

This would define the top-left corner of a 224x224 pixel window that the Window object represents. This window would be used to crop a 224x224 pixel region from a larger 256x256 pixel raster image. The purpose of this random offset is likely to introduce randomness in the cropping process, which can be a form of data augmentation or just a way to get different subsets of data from a larger image dataset.

In [17]:
mode

'test'

In [18]:
class S1Bands(Enum):
    VV = 1
    VH = 2
    ALL = [VV, VH]
    NONE = None


class S2Bands(Enum):
    B01 = aerosol = 1
    B02 = blue = 2
    B03 = green = 3
    B04 = red = 4
    B05 = re1 = 5
    B06 = re2 = 6
    B07 = re3 = 7
    B08 = nir1 = 8
    B08A = nir2 = 9
    B09 = vapor = 10
    B10 = cirrus = 11
    B11 = swir1 = 12
    B12 = swir2 = 13
    ALL = [B01, B02, B03, B04, B05, B06, B07, B08, B08A, B09, B10, B11, B12]
    RGB = [B04, B03, B02]
    NONE = None


class LCBands(Enum):
    LC = lc = 0
    DFC = dfc = 1
    ALL = [DFC]
    NONE = None

class Sensor(Enum):
    s1 = "s1"
    s2 = "s2"
    lc = "lc"
    dfc = "dfc"

In [19]:
season

<Seasons.TESTSET: 'ROIs0000_test'>

In [20]:
Seasons(season).value

'ROIs0000_test'

In [21]:
# s2_bands=S2Bands.ALL

# s1, s2, lc, bounds = [
#     x.astype(np.float32) if type(x) == np.ndarray else x
#     for x in data.get_s1_s2_lc_dfc_quad(
#         season,
#         obs.Scene,
#         int(obs.ID),
#         s1_bands=S1Bands.ALL,
#         s2_bands=s2_bands,
#         lc_bands=LCBands.LC,
#         dfc_bands=LCBands.DFC,
#         include_dfc=False,
#         window=window,
#     )
# ]

In [44]:
dfc = None

In [47]:
import rasterio

IGBP2DFC = np.array([0, 1, 1, 1, 1, 1, 2, 2, 3, 3, 4, 5, 6, 7, 6, 8, 9, 10])

def get_patch(season, scene_id, patch_id, bands, window=None):
    """
        Returns raster data and image bounds for the defined bands of a specific patch
        This method only loads a sinlge patch from a single sensor as defined by the bands specified
    """
    season = Seasons(season).value
    sensor = None

    if not bands:
        return None, None

    if isinstance(bands, (list, tuple)):
        b = bands[0]
    else:
        b = bands
    
    if isinstance(b, S1Bands):
        sensor = Sensor.s1.value
        bandEnum = S1Bands
    elif isinstance(b, S2Bands):
        sensor = Sensor.s2.value
        bandEnum = S2Bands
    elif isinstance(b, LCBands):
        if LCBands(bands) == LCBands.LC:
            sensor = Sensor.lc.value 
        else:
            sensor = Sensor.dfc.value 

        bands = LCBands(1)
        bandEnum = LCBands
    else:
        raise Exception("Invalid bands specified")

    if isinstance(bands, (list, tuple)):
        bands = [b.value for b in bands]
    else:
        bands = bandEnum(bands).value

    scene = "{}_{}".format(sensor, scene_id)
    filename = "{}_{}_p{}.tif".format(season, scene, patch_id)
    patch_path = os.path.join(base_dir, season, scene, filename)
    print(patch_path)

    with rasterio.open(patch_path) as patch:
        if window is not None:
            data = patch.read(bands, window=window) 
        else:
            data = patch.read(bands)
        bounds = patch.bounds

    # Remap IGBP to DFC bands
    if sensor  == "lc":
        data = IGBP2DFC[data]

    if len(data.shape) == 2:
        data = np.expand_dims(data, axis=0)

    return data, bounds

s1, bounds1 = get_patch(season=season, scene_id=obs.Scene, patch_id=int(obs.ID), bands=S1Bands.ALL, window=window)
s2, bounds2 = get_patch(season, scene_id=obs.Scene, patch_id=int(obs.ID), bands=S2Bands.ALL, window=window)
lc, bounds3 = get_patch(season, scene_id=obs.Scene, patch_id=int(obs.ID), bands=LCBands.LC, window=window)

bounds = next(filter(None, [bounds1, bounds2, bounds3]), None)

data/data_disini/ROIs0000_test/s1_0/ROIs0000_test_s1_0_p1101.tif
data/data_disini/ROIs0000_test/s2_0/ROIs0000_test_s2_0_p1101.tif
data/data_disini/ROIs0000_test/lc_0/ROIs0000_test_lc_0_p1101.tif


In [23]:
lc

array([[[6, 6, 6, ..., 6, 6, 6],
        [6, 6, 6, ..., 6, 6, 6],
        [6, 6, 6, ..., 6, 6, 6],
        ...,
        [6, 6, 6, ..., 6, 6, 6],
        [6, 6, 6, ..., 6, 6, 6],
        [6, 6, 6, ..., 6, 6, 6]]])

In [24]:
lc[lc == 3] = 0
lc[lc == 8] = 0
lc[lc >= 3] -= 1
lc[lc >= 8] -= 1
lc -= 1

# print("Number of invalid pixels:", lc[lc == -1].size)
lc[lc == -1] = 255

In [25]:
lc

array([[[4, 4, 4, ..., 4, 4, 4],
        [4, 4, 4, ..., 4, 4, 4],
        [4, 4, 4, ..., 4, 4, 4],
        ...,
        [4, 4, 4, ..., 4, 4, 4],
        [4, 4, 4, ..., 4, 4, 4],
        [4, 4, 4, ..., 4, 4, 4]]])

In [26]:
lc_unique, lc_counts = np.unique(lc, return_counts=True)
print(lc_unique, lc_counts)

[4] [50176]


In [27]:
DFC_map_clean = {
    0: "Forest",
    1: "Shrubland",
    2: "Grassland",
    3: "Wetlands",
    4: "Croplands",
    5: "Urban/Built-up",
    6: "Barren",
    7: "Water",
    255: "Invalid",
}

lc_label = lc_unique[
    lc_counts.argmax()
]  # this is already mapped to dfc in data.get_s1_s2_lc_dfc_quad
lc_label_str = DFC_map_clean[int(lc_label)]

In [28]:
print(lc_label, lc_label_str)

4 Croplands


In [29]:
import torch

lc_multilabel = torch.tensor(
    [
        class_idx
        for class_idx, num in zip(lc_unique, lc_counts)
        if num / image_px_size**2 >= 0.1 and class_idx != 255
    ]
).long()
lc_multilabel_one_hot = torch.nn.functional.one_hot(
    lc_multilabel.flatten(), num_classes=8
).float()
lc_multilabel_one_hot = lc_multilabel_one_hot.sum(dim=0)

In [30]:
lc_multilabel

tensor([4])

In [31]:
lc_multilabel_one_hot

tensor([0., 0., 0., 0., 1., 0., 0., 0.])

In [32]:
print(clip_sample_values)
if clip_sample_values:
    s1 = np.clip(s1, a_min=-25, a_max=0)
    s1 = (
        s1 + 25
    )  # go from [-25,0] to [0,25] interval to make normalization easier
    s2 = np.clip(s2, a_min=0, a_max=1e4)

True


In [33]:
s1

array([[[10.02639401,  9.78071514, 10.13958169, ..., 11.13204785,
         10.49177439, 11.37512439],
        [10.71481945,  9.88525474,  9.2328549 , ..., 10.48376037,
         10.79516637, 11.75380497],
        [10.92238227, 10.84129933,  9.19645813, ..., 10.61944233,
         10.50448479, 11.16301832],
        ...,
        [ 9.52744245,  9.67639517,  9.89710906, ...,  9.5774624 ,
         10.12303341, 10.31890257],
        [ 9.53173707, 10.2359588 ,  9.76507098, ..., 10.12211425,
         10.66147032, 10.8344312 ],
        [10.28666748,  9.97718324, 10.86651068, ...,  9.81332924,
         10.48862606, 11.14928442]],

       [[ 4.0022885 ,  5.46716512,  6.54422221, ...,  5.26850347,
          4.57792922,  5.76102369],
        [ 5.24813458,  5.01389379,  5.43050958, ...,  3.70929459,
          4.72105449,  5.34384327],
        [ 5.06700391,  6.88751629,  4.68191418, ...,  4.53374084,
          4.98471528,  5.67624763],
        ...,
        [ 3.83213516,  3.93717569,  5.4616625 , ...,  

In [34]:
s2

array([[[1467., 1671., 1671., ..., 1397., 1397., 1397.],
        [1467., 1671., 1671., ..., 1397., 1397., 1397.],
        [1467., 1671., 1671., ..., 1397., 1397., 1397.],
        ...,
        [1413., 1413., 1413., ..., 1410., 1398., 1398.],
        [1413., 1413., 1413., ..., 1410., 1398., 1398.],
        [1413., 1413., 1413., ..., 1410., 1398., 1398.]],

       [[1256., 1285., 1410., ..., 1304., 1304., 1291.],
        [1260., 1273., 1312., ..., 1292., 1286., 1286.],
        [1279., 1282., 1310., ..., 1253., 1271., 1271.],
        ...,
        [1362., 1355., 1303., ..., 1319., 1298., 1296.],
        [1225., 1256., 1241., ..., 1309., 1304., 1281.],
        [1247., 1206., 1206., ..., 1288., 1301., 1265.]],

       [[1234., 1249., 1426., ..., 1339., 1339., 1292.],
        [1211., 1220., 1300., ..., 1346., 1266., 1266.],
        [1216., 1239., 1293., ..., 1291., 1265., 1265.],
        ...,
        [1458., 1422., 1354., ..., 1368., 1352., 1375.],
        [1236., 1274., 1253., ..., 1372., 138

In [35]:
s1 = base_transform(np.moveaxis(s1, 0, -1))
s2 = base_transform(np.moveaxis(s2, 0, -1))

In [36]:
s1

tensor([[[10.0264,  9.7807, 10.1396,  ..., 11.1320, 10.4918, 11.3751],
         [10.7148,  9.8853,  9.2329,  ..., 10.4838, 10.7952, 11.7538],
         [10.9224, 10.8413,  9.1965,  ..., 10.6194, 10.5045, 11.1630],
         ...,
         [ 9.5274,  9.6764,  9.8971,  ...,  9.5775, 10.1230, 10.3189],
         [ 9.5317, 10.2360,  9.7651,  ..., 10.1221, 10.6615, 10.8344],
         [10.2867,  9.9772, 10.8665,  ...,  9.8133, 10.4886, 11.1493]],

        [[ 4.0023,  5.4672,  6.5442,  ...,  5.2685,  4.5779,  5.7610],
         [ 5.2481,  5.0139,  5.4305,  ...,  3.7093,  4.7211,  5.3438],
         [ 5.0670,  6.8875,  4.6819,  ...,  4.5337,  4.9847,  5.6762],
         ...,
         [ 3.8321,  3.9372,  5.4617,  ...,  4.5156,  4.3102,  5.3107],
         [ 3.7988,  4.4828,  3.7054,  ...,  3.9380,  5.2030,  4.6722],
         [ 2.4013,  3.0792,  2.5382,  ...,  3.5661,  4.0790,  4.2157]]],
       dtype=torch.float64)

In [37]:
s2

tensor([[[1467., 1671., 1671.,  ..., 1397., 1397., 1397.],
         [1467., 1671., 1671.,  ..., 1397., 1397., 1397.],
         [1467., 1671., 1671.,  ..., 1397., 1397., 1397.],
         ...,
         [1413., 1413., 1413.,  ..., 1410., 1398., 1398.],
         [1413., 1413., 1413.,  ..., 1410., 1398., 1398.],
         [1413., 1413., 1413.,  ..., 1410., 1398., 1398.]],

        [[1256., 1285., 1410.,  ..., 1304., 1304., 1291.],
         [1260., 1273., 1312.,  ..., 1292., 1286., 1286.],
         [1279., 1282., 1310.,  ..., 1253., 1271., 1271.],
         ...,
         [1362., 1355., 1303.,  ..., 1319., 1298., 1296.],
         [1225., 1256., 1241.,  ..., 1309., 1304., 1281.],
         [1247., 1206., 1206.,  ..., 1288., 1301., 1265.]],

        [[1234., 1249., 1426.,  ..., 1339., 1339., 1292.],
         [1211., 1220., 1300.,  ..., 1346., 1266., 1266.],
         [1216., 1239., 1293.,  ..., 1291., 1265., 1265.],
         ...,
         [1458., 1422., 1354.,  ..., 1368., 1352., 1375.],
         [

In [38]:
# normalize images channel wise
s1_maxs = []
for ch_idx in range(s1.shape[0]):
    s1_maxs.append(
        torch.ones((s1.shape[-2], s1.shape[-1])) * s1[ch_idx].max().item()
        + 1e-5
    )
s1_maxs = torch.stack(s1_maxs)
s1_maxs

tensor([[[22.6798, 22.6798, 22.6798,  ..., 22.6798, 22.6798, 22.6798],
         [22.6798, 22.6798, 22.6798,  ..., 22.6798, 22.6798, 22.6798],
         [22.6798, 22.6798, 22.6798,  ..., 22.6798, 22.6798, 22.6798],
         ...,
         [22.6798, 22.6798, 22.6798,  ..., 22.6798, 22.6798, 22.6798],
         [22.6798, 22.6798, 22.6798,  ..., 22.6798, 22.6798, 22.6798],
         [22.6798, 22.6798, 22.6798,  ..., 22.6798, 22.6798, 22.6798]],

        [[14.9390, 14.9390, 14.9390,  ..., 14.9390, 14.9390, 14.9390],
         [14.9390, 14.9390, 14.9390,  ..., 14.9390, 14.9390, 14.9390],
         [14.9390, 14.9390, 14.9390,  ..., 14.9390, 14.9390, 14.9390],
         ...,
         [14.9390, 14.9390, 14.9390,  ..., 14.9390, 14.9390, 14.9390],
         [14.9390, 14.9390, 14.9390,  ..., 14.9390, 14.9390, 14.9390],
         [14.9390, 14.9390, 14.9390,  ..., 14.9390, 14.9390, 14.9390]]])

In [39]:
s2_maxs = []
for b_idx in range(s2.shape[0]):
    s2_maxs.append(
        torch.ones((s2.shape[-2], s2.shape[-1])) * s2[b_idx].max().item() + 1e-5
    )
s2_maxs = torch.stack(s2_maxs)

In [40]:
if normalize:
    s1 = s1 / s1_maxs
    s2 = s2 / s2_maxs

In [41]:
s1

tensor([[[0.4421, 0.4313, 0.4471,  ..., 0.4908, 0.4626, 0.5016],
         [0.4724, 0.4359, 0.4071,  ..., 0.4623, 0.4760, 0.5182],
         [0.4816, 0.4780, 0.4055,  ..., 0.4682, 0.4632, 0.4922],
         ...,
         [0.4201, 0.4267, 0.4364,  ..., 0.4223, 0.4463, 0.4550],
         [0.4203, 0.4513, 0.4306,  ..., 0.4463, 0.4701, 0.4777],
         [0.4536, 0.4399, 0.4791,  ..., 0.4327, 0.4625, 0.4916]],

        [[0.2679, 0.3660, 0.4381,  ..., 0.3527, 0.3064, 0.3856],
         [0.3513, 0.3356, 0.3635,  ..., 0.2483, 0.3160, 0.3577],
         [0.3392, 0.4610, 0.3134,  ..., 0.3035, 0.3337, 0.3800],
         ...,
         [0.2565, 0.2636, 0.3656,  ..., 0.3023, 0.2885, 0.3555],
         [0.2543, 0.3001, 0.2480,  ..., 0.2636, 0.3483, 0.3128],
         [0.1607, 0.2061, 0.1699,  ..., 0.2387, 0.2730, 0.2822]]],
       dtype=torch.float64)

In [42]:
s2

tensor([[[0.8388, 0.9554, 0.9554,  ..., 0.7987, 0.7987, 0.7987],
         [0.8388, 0.9554, 0.9554,  ..., 0.7987, 0.7987, 0.7987],
         [0.8388, 0.9554, 0.9554,  ..., 0.7987, 0.7987, 0.7987],
         ...,
         [0.8079, 0.8079, 0.8079,  ..., 0.8062, 0.7993, 0.7993],
         [0.8079, 0.8079, 0.8079,  ..., 0.8062, 0.7993, 0.7993],
         [0.8079, 0.8079, 0.8079,  ..., 0.8062, 0.7993, 0.7993]],

        [[0.5244, 0.5365, 0.5887,  ..., 0.5445, 0.5445, 0.5390],
         [0.5261, 0.5315, 0.5478,  ..., 0.5395, 0.5370, 0.5370],
         [0.5340, 0.5353, 0.5470,  ..., 0.5232, 0.5307, 0.5307],
         ...,
         [0.5687, 0.5658, 0.5441,  ..., 0.5507, 0.5420, 0.5411],
         [0.5115, 0.5244, 0.5182,  ..., 0.5466, 0.5445, 0.5349],
         [0.5207, 0.5035, 0.5035,  ..., 0.5378, 0.5432, 0.5282]],

        [[0.4515, 0.4570, 0.5218,  ..., 0.4899, 0.4899, 0.4727],
         [0.4431, 0.4464, 0.4757,  ..., 0.4925, 0.4632, 0.4632],
         [0.4449, 0.4533, 0.4731,  ..., 0.4724, 0.4629, 0.

In [45]:
output = {
    "s1": s1,
    "s2": s2,
    "lc": lc,
    "bounds": bounds,
    "idx": idx,
    "lc_label": lc_label,
    "lc_label_str": lc_label_str,
    "lc_multilabel": lc_multilabel.numpy().tolist(),
    "lc_multilabel_one_hot": lc_multilabel_one_hot,
    "season": str(season.value),
    "scene": obs.Scene,
    "id": obs.ID,
}

In [46]:
output

{'s1': tensor([[[0.4421, 0.4313, 0.4471,  ..., 0.4908, 0.4626, 0.5016],
          [0.4724, 0.4359, 0.4071,  ..., 0.4623, 0.4760, 0.5182],
          [0.4816, 0.4780, 0.4055,  ..., 0.4682, 0.4632, 0.4922],
          ...,
          [0.4201, 0.4267, 0.4364,  ..., 0.4223, 0.4463, 0.4550],
          [0.4203, 0.4513, 0.4306,  ..., 0.4463, 0.4701, 0.4777],
          [0.4536, 0.4399, 0.4791,  ..., 0.4327, 0.4625, 0.4916]],
 
         [[0.2679, 0.3660, 0.4381,  ..., 0.3527, 0.3064, 0.3856],
          [0.3513, 0.3356, 0.3635,  ..., 0.2483, 0.3160, 0.3577],
          [0.3392, 0.4610, 0.3134,  ..., 0.3035, 0.3337, 0.3800],
          ...,
          [0.2565, 0.2636, 0.3656,  ..., 0.3023, 0.2885, 0.3555],
          [0.2543, 0.3001, 0.2480,  ..., 0.2636, 0.3483, 0.3128],
          [0.1607, 0.2061, 0.1699,  ..., 0.2387, 0.2730, 0.2822]]],
        dtype=torch.float64),
 's2': tensor([[[0.8388, 0.9554, 0.9554,  ..., 0.7987, 0.7987, 0.7987],
          [0.8388, 0.9554, 0.9554,  ..., 0.7987, 0.7987, 0.7987],