In [26]:
class AlbumentationsWrapper:
    """Wrap a keyword-callable transform and add fixed entries to its output.

    The wrapped transform is called as ``transform(**kwargs)`` and must return a
    mapping. On every call the entries in ``static_items`` are merged into that
    mapping; when a key exists in both, the static entry wins.

    Args:
        albumentations_transform: Callable accepting keyword arguments and
            returning a mapping (e.g. an albumentations ``Compose``).
        static_items: Mapping of key -> value. A value that is a dict containing
            an ``"init"`` entry is treated as a tensor spec and replaced by a
            tensor built once at construction time:

            * ``"init"``: ``"zeros"`` or ``"randn"``.
            * ``"shape"``: a sequence of ints, an empty sequence for a scalar
              (0-d) tensor, or a single int for a 1-d tensor.

            Any other value is stored unchanged.

    Raises:
        ValueError: If a tensor spec names an unsupported ``"init"`` type.
        KeyError: If a tensor spec has ``"init"`` but no ``"shape"``.
    """

    def __init__(self, albumentations_transform, static_items):
        self.transform = albumentations_transform
        self.static_items = self._init_static_items(static_items)

    def _init_static_items(self, items):
        """Materialise tensor specs in ``items``; pass other values through."""
        out = {}
        for key, spec in items.items():
            if not (isinstance(spec, dict) and "init" in spec):
                # Plain value (number, list, existing tensor, ...): keep as-is.
                out[key] = spec
                continue

            init_type = spec["init"]
            shape = spec["shape"]

            # `torch` is resolved here, at call time, so the class can be
            # defined before `import torch` has executed.
            if init_type == "zeros":
                factory = torch.zeros
            elif init_type == "randn":
                factory = torch.randn
            else:
                raise ValueError(f"Unsupported init type: {init_type}")

            # Accept a bare int as a 1-d shape.
            dims = (shape,) if isinstance(shape, int) else tuple(shape)
            # An empty shape means a scalar tensor; calling factory() with no
            # size argument would raise TypeError, so pass the empty size
            # explicitly in that case.
            out[key] = factory(*dims) if dims else factory(())
        return out

    def __call__(self, **kwargs):
        """Apply the wrapped transform and merge the static items into its result.

        Static items override same-named keys produced by the transform. The
        returned dict is new, but the static values themselves are the same
        objects on every call (they are not copied).
        """
        result = self.transform(**kwargs)
        return {**result, **self.static_items}


from albumentations import Compose, Resize
import torch
from terratorch.transforms import RemapKeys

# Per-sample pipeline: resize with Resize(64, 64), then rename the 'image'
# entry to 'pixels'.
resize_and_rename = Compose([
    Resize(64, 64),
    RemapKeys(key_map={'image': 'pixels'}),
])

# Fixed entries the wrapper merges into every transformed sample. Dicts with an
# 'init' entry are turned into tensors by AlbumentationsWrapper; the remaining
# values are passed through unchanged.
# NOTE(review): the leading 64 in the 'time'/'latlon' shapes equals the
# batch_size used for the datamodule further down — confirm these are meant to
# be per-sample shapes.
static_metadata = {
    'time': {'init': 'zeros', 'shape': [64, 4]},
    'platform': ["sentinel-2-l2a"],
    'latlon': {'init': 'zeros', 'shape': [64, 4]},
    'waves': {'init': 'zeros', 'shape': [4]},
    'gsd': 10,
}

transform = AlbumentationsWrapper(
    albumentations_transform=resize_and_rename,
    static_items=static_metadata,
)

In [27]:
import albumentations as A
from terratorch.transforms import RemapKeys, AddStaticKeys
from terratorch.datamodules.torchgeo_data_module import TorchNonGeoDataModule
import torchgeo 


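# NOTE(review): leftover YAML fragment — apparently the config-file form of the
# RemapKeys(key_map={'image': 'pixels'}) transform used above. Confirm or remove.
#       init_args:
#         key_map:
#           image: pixels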
# Wrap torchgeo's EuroSAT datamodule, applying the `transform` built in the
# previous cell. `kwargs_for_wrapped_datamodule` is not defined in this cell;
# it must already exist in the kernel namespace.
dm = TorchNonGeoDataModule(
    cls=torchgeo.datamodules.EuroSATDataModule,
    batch_size=64,
    num_workers=1,
    transforms=[transform],
    **kwargs_for_wrapped_datamodule,
)


In [28]:
# Run the datamodule's setup for the 'fit' stage; the next cell requests the
# training dataloader from it.
dm.setup(stage='fit')

In [29]:
# Pull a single batch from the training dataloader and display its keys to
# check that the renamed and static entries are present.
train_loader = dm.train_dataloader()
first_batch = next(iter(train_loader))
first_batch.keys()

dict_keys(['pixels', 'label', 'time', 'platform', 'latlon', 'waves', 'gsd'])