# Preparation
Here we only define the paths that are used throughout the notebook.

In [1]:
path_split = "/data/kaiclasen/split.csv"
path_labels = "/data/kaiclasen/lbls.parquet"
path_image_lmdb = "/data/kaiclasen/BENv2.lmdb"
path_s2v2name_to_s1v1name = "/data/kaiclasen/new_s2s1_mapping.parquet"

# Read patch names

For the dataset that we use later, we can either use all data in the database or supply a list of patch names that we want to use.
To supply the _official split_, we simply read `split.csv` and filter it by the `split` column into the three predefined splits.
However, this `.csv` currently contains patches that are filtered out during creation of the dataset, which means that not all names contained here actually exist within the LMDB file or have labels.
Therefore, we remove all patches that don't have any labels associated with them.
For easier access later, we save this information in a `dict`.

Of course, generating this list is ultimately up to the user of the dataset (DS) and can be substituted with any other selection method, such as season- or country-based selection.
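
As an illustration of such an alternative selection, the following is a minimal sketch of a season-based filter. It assumes that the third `_`-separated field of a patch name is the acquisition timestamp (as the names printed in this notebook suggest) and uses northern-hemisphere meteorological seasons. The helpers `acquisition_time` and `filter_by_season` are hypothetical and not part of the provided code.

```python
from datetime import datetime

# Assumption: meteorological seasons of the northern hemisphere.
SEASON_BY_MONTH = {
    12: "winter", 1: "winter", 2: "winter",
    3: "spring", 4: "spring", 5: "spring",
    6: "summer", 7: "summer", 8: "summer",
    9: "autumn", 10: "autumn", 11: "autumn",
}


def acquisition_time(patch_name: str) -> datetime:
    """Parse the timestamp, assumed to be the third '_'-separated field."""
    return datetime.strptime(patch_name.split("_")[2], "%Y%m%dT%H%M%S")


def filter_by_season(patch_names, season):
    """Keep only the patch names whose acquisition month falls into `season`."""
    return [
        name
        for name in patch_names
        if SEASON_BY_MONTH[acquisition_time(name).month] == season
    ]


# Two patch names taken from the output of this notebook (both from June 2017).
example_names = [
    "S2A_MSIL2A_20170613T101031_N9999_R022_T33UUP_26_57",
    "S2A_MSIL2A_20170613T101031_N9999_R022_T33UUP_27_55",
]
summer_patches = filter_by_season(example_names, "summer")
winter_patches = filter_by_season(example_names, "winter")
```

The resulting list could then be passed to the dataset in place of the official split.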

In [2]:
from pprint import pprint
import pandas as pd

# read the csv for split info
df = pd.read_csv(path_split)

# get all patches without duplicates
# -> this table therefore no longer contains every label of a patch, but we don't use it for label information at this point anyway
lbls = pd.read_parquet(path_labels).drop_duplicates(['patch'])

# get only patches that also have a label, drop the additional columns
df = df.merge(lbls, how='inner', left_on=['name'], right_on=['patch']).drop(['lbl_19', 'patch'], axis=1)
print(df, end='\n\n')

# filter by split column and write into dict
patches = {
    split: sorted(df.loc[df["split"] == split, "name"]) for split in ['train', 'validation', 'test']
}

# show first 10 entries in each split
# due to the naming schema and the frame-like splitting, we can nicely see
#    how the patches differ between the splits
pprint({s: patches[s][:10] for s in patches.keys()})
print()
pprint({s: len(patches[s]) for s in patches.keys()})

                                                     name       split
0       S2B_MSIL2A_20171206T094349_N9999_R036_T34TCR_8...       train
1       S2A_MSIL2A_20171002T094031_N9999_R036_T34TCR_8...       train
2       S2A_MSIL2A_20170613T101031_N9999_R022_T34VER_8...       train
3       S2B_MSIL2A_20180525T094029_N9999_R036_T35VNK_4...  validation
4       S2B_MSIL2A_20180522T093029_N9999_R136_T35VPJ_3...        test
...                                                   ...         ...
549483  S2A_MSIL2A_20180413T095031_N9999_R079_T35VLG_6...        test
549484  S2B_MSIL2A_20180224T112109_N9999_R037_T29SNC_0...       train
549485  S2B_MSIL2A_20180223T101019_N9999_R022_T34WFT_7...  validation
549486  S2B_MSIL2A_20170817T101019_N9999_R022_T34WFS_4...        test
549487  S2B_MSIL2A_20180326T112109_N9999_R037_T29SNB_4...        test

[549488 rows x 2 columns]
{'test': ['S2A_MSIL2A_20170613T101031_N9999_R022_T33UUP_26_57',
          'S2A_MSIL2A_20170613T101031_N9999_R022_T33UUP_27_55',
     

# Creating the DataSet

To create the dataset for the respective splits, we can just use the DataSet class provided and later pass it to the DataLoader.
We have to supply the files for the image data `image_lmdb_file`, the file containing the labels `label_file` and the mapping for lookup from the _NEW_ S2 name to the _OLD_ S1 name `s2s1_mapping_file`. 
This mapping is not required if you only want to use the S2 images.

Optionally we can also supply the band names that should be returned in `bands` in a list like `["B02", "B03", "B04", "B12", "VV"]`. 
Only those bands will then be read and returned in a dict. Restricting this to only S1 or only S2 band names reduces the number of read accesses and increases throughput.
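
A small sketch of how such a restricted band list could be built is shown here. It assumes that `VV` and `VH` are the only Sentinel-1 channels in the list and that every other name is a Sentinel-2 band; `split_bands_by_sensor` is a hypothetical helper, not part of the provided code.

```python
# Assumption: "VV" and "VH" are the Sentinel-1 channels; everything else is
# treated as a Sentinel-2 band.
S1_BANDS = {"VV", "VH"}


def split_bands_by_sensor(bands):
    """Split a band list into (S1 bands, S2 bands), preserving the order."""
    s1 = [b for b in bands if b in S1_BANDS]
    s2 = [b for b in bands if b not in S1_BANDS]
    return s1, s2


all_bands = ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B11", "B12", "B8A", "VV", "VH"]
s1_only, s2_only = split_bands_by_sensor(all_bands)
```

Passing `s2_only` as `bands` would then request only Sentinel-2 data.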

The data is returned as `Tuple[Dict[str, numpy.ndarray], List[str]]`.
No resizing or stacking is performed on the bands. If you want to use the DS for PyTorch training inside a DataLoader, the values have to be returned as tensors with some dimension restrictions.
To account for that, the DS accepts two parameters, `process_bands_fn` and `process_labels_fn`, each a function that accepts the respective element of the tuple mentioned above.
`process_bands_fn` also takes the names of the bands as a second parameter.
If you simply want to stack the bands and interpolate them to an equal size, there is a function provided in `BENv2TorchUtils` for which you can select the interpolation mode via `functools.partial`.
To convert the `List[str]` of labels to a multi-hot tensor, a function is provided in `BENv2TorchUtils` as well.
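
To illustrate what such a label conversion does, here is a minimal sketch in plain Python. It is not the implementation of `ben_19_labels_to_multi_hot`: the helper `labels_to_multi_hot` and the four-entry `toy_vocabulary` are hypothetical stand-ins, and the provided function presumably returns a tensor over the 19 classes rather than a list.

```python
def labels_to_multi_hot(labels, vocabulary):
    """Return a list with 1.0 at the index of every given label, 0.0 elsewhere."""
    index_of = {name: i for i, name in enumerate(vocabulary)}
    multi_hot = [0.0] * len(vocabulary)
    for label in labels:
        multi_hot[index_of[label]] = 1.0
    return multi_hot


# Hypothetical placeholder class names, not the real label set.
toy_vocabulary = ["class_a", "class_b", "class_c", "class_d"]
encoded = labels_to_multi_hot(["class_d", "class_b"], toy_vocabulary)
```

Because the vector has a fixed length regardless of how many labels a patch has, samples encoded this way can be batched by a DataLoader.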

To select which images to include in the DS, a list of `keys` can be passed. 
If `None` is passed (default), all S2 keys in the LMDB database will be used.
Here we will use the patch names that we read earlier.

Augmentations can be passed as `transforms`. 
They have to accept whatever is returned from the `process_bands_fn` function if set or the dictionary that is returned by default.

If we need the patch name that was read - e.g. for visualization purposes - we can set `return_patchname` to `True`. 
Instead of `(Image, Label)` the DS will then return `(Image, Label, Patch name)`.

If you want additional information about what the DataSet is doing, you can pass `verbose=True` to print status updates.

In [3]:
from functools import partial

import torchvision

from BENv2DataSet import BENv2DataSet
from BENv2Stats import means, stds
from BENv2TorchUtils import ben_19_labels_to_multi_hot
from BENv2TorchUtils import stack_and_interpolate

image_size = 120
upsample_mode = "nearest"

bands = ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B11", "B12", "B8A", "VV", "VH"]

# statistics are available for several size/mode combinations, select the matching one
mean = means[f"{image_size}_{upsample_mode}"]
std = stds[f"{image_size}_{upsample_mode}"]
# only select the right bands
mean = [mean[b] for b in bands]
std = [std[b] for b in bands]

# some default transformations that we can pass to the data set class
transforms = {
    "train": torchvision.transforms.Compose(
        [
            torchvision.transforms.RandomHorizontalFlip(),
            torchvision.transforms.RandomVerticalFlip(),
            torchvision.transforms.Normalize(mean, std),
        ]
    ),
    "validation": torchvision.transforms.Compose([torchvision.transforms.Normalize(mean, std)]),
    "test": torchvision.transforms.Compose([torchvision.transforms.Normalize(mean, std)]),
}

# create a DataSet for each split
ds = {
    s: BENv2DataSet(
        image_lmdb_file=path_image_lmdb,
        label_file=path_labels,
        s2s1_mapping_file=path_s2v2name_to_s1v1name,
        bands=bands,
        process_bands_fn=partial(
            stack_and_interpolate, img_size=image_size, upsample_mode=upsample_mode
        ),
        process_labels_fn=ben_19_labels_to_multi_hot,
        transforms=transforms[s],
        keys=patches[s],
        return_patchname=False,
    )
    for s in patches
}


In [4]:
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

bs = 256
workers = 8

# create a DataLoader for each DS
dl = {
    "train": DataLoader(
        ds["train"], batch_size=bs, num_workers=workers, pin_memory=False, shuffle=True
    ),
    "validation": DataLoader(
        ds["validation"],
        batch_size=bs,
        num_workers=workers,
        pin_memory=False,
        shuffle=False,
    ),
    "test": DataLoader(
        ds["test"], batch_size=bs, num_workers=workers, pin_memory=False, shuffle=False
    ),
}


# create a fake train-val-loop
epochs = 5
for e in tqdm(range(epochs), position=0, desc="Epoch progress"):
    for batch in tqdm(dl["train"], position=1, leave=False, desc="training"):
        pass

    for batch in tqdm(dl["validation"], position=1, leave=False, desc="validating"):
        pass


for batch in tqdm(dl["test"], position=1, leave=False, desc="testing"):
    pass


Epoch progress:   0%|          | 0/5 [00:00<?, ?it/s]

training:   0%|          | 0/1065 [00:00<?, ?it/s]

validating:   0%|          | 0/546 [00:00<?, ?it/s]

KeyboardInterrupt: 

# Data Module

The Data Module is a wrapper around the DataSet and DataLoader that is used by PyTorch Lightning. It also handles the splitting of the data into train, validation and test sets according to the split file, as shown above. If no augmentation is passed to the Data Module, it handles augmentation in the same way as shown above.

In [None]:
from time import time
from tqdm.notebook import tqdm

from BENv2DataModule import BENv2DataModule

total = 64

bs = 256
workers = 8

dm = BENv2DataModule(
    image_lmdb_file=path_image_lmdb,
    label_file=path_labels,
    s2s1_mapping_file=path_s2v2name_to_s1v1name,
    split_file=path_split,
    batch_size=bs,
    num_workers=workers
)

dm.setup("fit")

t0 = time()
for j in range(2):
    for i, batch in tqdm(enumerate(dm.train_dataloader()), total=len(dm.train_dataloader())):
        pass

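t1 = time()
# approximate sample count: assumes every batch is full (the last batch of an epoch may be smaller)
n_samples = (i + 1) * bs * (j + 1)
print(f"{bs:4d} {workers:2d} --  Took {t1 - t0}s for {n_samples} samples")
print(f"            {n_samples / (t1 - t0)} samples/s")
print(f"            {(t1 - t0) / n_samples} s/sample")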
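
The sample count printed above is derived from the number of batches and the batch size, so it is only exact if every batch is full. A minimal sketch of a helper that counts the samples actually seen is given here; `measure_throughput` and its `batch_len` parameter are hypothetical names, and the toy batches are plain lists standing in for real batches.

```python
from time import perf_counter


def measure_throughput(batches, batch_len=len):
    """Iterate over all batches and return (samples seen, elapsed seconds)."""
    n_samples = 0
    t0 = perf_counter()
    for batch in batches:
        # count the real size of each batch instead of assuming it is full
        n_samples += batch_len(batch)
    elapsed = perf_counter() - t0
    return n_samples, elapsed


# Toy stand-in: two full batches of 256 and a smaller final batch of 100.
toy_batches = [[0] * 256, [0] * 256, [0] * 100]
n_samples, elapsed = measure_throughput(toy_batches)
```

With a real DataLoader, and assuming each batch is an `(images, labels)` pair, something like `batch_len=lambda b: len(b[1])` could be passed to count the labels in each batch.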