*Recommended method: use the SEVIR dataloader to set up the train/dev/test splits. Simply modify the .yaml file with your own configs.*
*Prerequisite: install OmegaConf and pytorch_lightning.*

In [1]:
from omegaconf import OmegaConf
from sevir_loader.sevir_torch_wrap import get_sevir_datamodule


# Location of the example YAML config that describes the dataset settings.
config_path =  "/mnt/data/public_datasets/SEVIR/dataloader/sevir_example.yaml"  # Change to your project path
# Hand the path to OmegaConf directly so it opens and closes the file itself;
# wrapping it in a bare open() left the file handle unclosed.
oc_from_file = OmegaConf.load(config_path)
# Convert the `dataset` section of the config into a plain Python object
# before passing it to the datamodule factory.
dataset_oc = OmegaConf.to_object(oc_from_file.dataset)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Build the SEVIR datamodule from the parsed dataset config (num_workers=8).
dm = get_sevir_datamodule(dataset_oc=dataset_oc, num_workers=8)
dm.prepare_data()  # Check if SEVIR dataset is available
dm.setup()  # Preprocess train/val/test data set

# Fetch one loader per split, in train -> val -> test order.
train_data_loader, val_data_loader, test_data_loader = (
    dm.train_dataloader(),
    dm.val_dataloader(),
    dm.test_dataloader(),
)

In [None]:
# Walk through every training batch and print the shape of its 'vil' entry.
# The enumerate() index was never used, so iterate over the loader directly.
for batch in train_data_loader:
    data_seq = batch['vil']
    print(data_seq.shape)  # torch.Size([16, 29, 384, 384, 1])

*Alternatively, configure the dataloader manually. This method is more time-consuming because it runs without multiprocessing.*

In [2]:
import datetime
import numpy as np
from sevir_loader.sevir_dataloader import SEVIR_CATALOG, SEVIR_DATA_DIR, SEVIR_RAW_SEQ_LEN, \
    SEVIR_LR_CATALOG, SEVIR_LR_DATA_DIR, SEVIR_LR_RAW_SEQ_LEN, SEVIRDataLoader


# Name of the dataset variant to use; must be "sevir" or "sevir_lr".
dataset = 'sevir'

# Each supported name maps to (catalog path, data directory, raw sequence length).
_dataset_settings = {
    "sevir": (SEVIR_CATALOG, SEVIR_DATA_DIR, SEVIR_RAW_SEQ_LEN),
    "sevir_lr": (SEVIR_LR_CATALOG, SEVIR_LR_DATA_DIR, SEVIR_LR_RAW_SEQ_LEN),
}
if dataset not in _dataset_settings:
    raise ValueError(f"Invalid dataset: {dataset}")
catalog_path, data_dir, raw_seq_len = _dataset_settings[dataset]

# Loader settings.
# NOTE(review): the following cell redefines these settings (with stride = 5) and
# passes the SEVIR_* constants directly, so catalog_path / data_dir / raw_seq_len
# are not read anywhere in the visible notebook -- confirm whether this cell is still needed.
batch_size = 16
data_types = ["vil"]
layout = "NTHWC"
seq_len = 29
stride = seq_len  # same value as seq_len
sample_mode = "sequent"
start_date, end_date = datetime.datetime(2019, 5, 27), datetime.datetime(2019, 5, 29)

In [3]:
import datetime
from sevir_loader.sevir_dataloader import SEVIR_CATALOG, SEVIR_DATA_DIR, SEVIR_RAW_SEQ_LEN, \
    SEVIRDataLoader


# Settings for the manually configured SEVIR loader.
batch_size = 16
data_types = ["vil", ]
layout = "NTHWC"
seq_len = 29
stride = 5
sample_mode = "sequent"
# Date range passed to the loader as start_date / end_date.
start_date = datetime.datetime(2019, 5, 27)
end_date = datetime.datetime(2019, 5, 29)

# Construct the loader directly (no datamodule), as a single shard (num_shard=1, rank=0)
# reading from the default SEVIR catalog and data directory.
dataloader = SEVIRDataLoader(
    data_types=data_types,
    seq_len=seq_len,
    raw_seq_len=SEVIR_RAW_SEQ_LEN,
    sample_mode=sample_mode,
    stride=stride,
    batch_size=batch_size,
    layout=layout,
    num_shard=1, rank=0, split_mode="uneven",
    sevir_catalog=SEVIR_CATALOG,
    sevir_data_dir=SEVIR_DATA_DIR,
    start_date=start_date, end_date=end_date,
    shuffle=True)

# Print the shape of the 'vil' entry of every batch.
# The enumerate() index was never used, so iterate over the loader directly.
for data in dataloader:
    data_seq = data['vil']
    print(data_seq.shape)

torch.Size([16, 29, 384, 384, 1])
torch.Size([16, 29, 384, 384, 1])
torch.Size([16, 29, 384, 384, 1])
torch.Size([16, 29, 384, 384, 1])
torch.Size([16, 29, 384, 384, 1])
torch.Size([16, 29, 384, 384, 1])
torch.Size([16, 29, 384, 384, 1])
torch.Size([16, 29, 384, 384, 1])
torch.Size([16, 29, 384, 384, 1])
torch.Size([16, 29, 384, 384, 1])
torch.Size([16, 29, 384, 384, 1])
torch.Size([16, 29, 384, 384, 1])
torch.Size([16, 29, 384, 384, 1])
torch.Size([16, 29, 384, 384, 1])
torch.Size([16, 29, 384, 384, 1])
torch.Size([16, 29, 384, 384, 1])
torch.Size([16, 29, 384, 384, 1])
torch.Size([16, 29, 384, 384, 1])
torch.Size([16, 29, 384, 384, 1])
torch.Size([16, 29, 384, 384, 1])
