In [1]:
import os
import random
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import tqdm
from lightning.data import StreamingDataLoader, StreamingDataset

from data_loading import LeJEPADataset
from transforms import make_transforms


In [2]:
class LeJEPAStreamingDataset(StreamingDataset):
    """StreamingDataset that returns ``num_views`` independently transformed views per modality.

    Each item read by the parent ``StreamingDataset`` is unpacked as
    ``(s1_image, s2_image, labels)``. The S1 image is passed through
    ``transform_s1`` and the S2 image through ``transform_s2``, each ``num_views``
    times, and the resulting views are stacked along a new leading dimension.

    Args:
        metadata: Parquet source (anything ``pd.read_parquet`` accepts) containing
            at least the columns ``split``, ``s1_name`` and ``patch_id``.
        split: Value of the ``split`` column whose rows are kept (e.g. ``'train'``).
        transform_s1: Callable producing one view from the S1 image.
        transform_s2: Callable producing one view from the S2 image.
        num_views: Number of views generated per modality; must be >= 1.
        *args, **kwargs: Forwarded unchanged to ``StreamingDataset.__init__``
            (e.g. ``input_dir``).

    Raises:
        ValueError: If ``num_views`` is smaller than 1.
    """

    def __init__(self, metadata, split, transform_s1, transform_s2, num_views=4, *args, **kwargs):
        if num_views < 1:
            # With zero views torch.stack([]) would fail on every single item;
            # reject the configuration up front instead.
            raise ValueError(f"num_views must be >= 1, got {num_views}")
        super().__init__(*args, **kwargs)
        self.split = split
        self.metadata = pd.read_parquet(metadata)
        # Apply the split filter once and reuse it for both per-split lists.
        # NOTE(review): metadata1/metadata2 are not read by __getitem__ here;
        # presumably kept for inspection or alignment checks — confirm usage.
        split_rows = self.metadata[self.metadata['split'] == split]
        self.metadata1 = split_rows['s1_name'].to_list()
        self.metadata2 = split_rows['patch_id'].to_list()
        self.transform_s1 = transform_s1
        self.transform_s2 = transform_s2

        # Number of views per modality, used by __getitem__.
        self.V = num_views

    def __getitem__(self, idx):
        """Return ``(views_s1, views_s2, labels)`` for item ``idx``.

        ``views_s1`` / ``views_s2`` are ``torch.stack`` of ``self.V`` transform
        outputs, so they gain a leading dimension of size ``self.V``.
        Assumes the optimized dataset stored each item as a 3-tuple of
        (S1 image, S2 image, labels); the ``c_h_w`` names suggest (C, H, W)
        tensors — TODO confirm against the data-optimization script.
        """
        mc_img_c_h_w1, mc_img_c_h_w2, labels = super().__getitem__(idx)

        # Transforms are called once per view; any randomness inside them is
        # what makes the views differ.
        views_s1 = torch.stack([self.transform_s1(mc_img_c_h_w1) for _ in range(self.V)])
        views_s2 = torch.stack([self.transform_s2(mc_img_c_h_w2) for _ in range(self.V)])
        return views_s1, views_s2, labels

# Work from the project root so the relative 'data/...' paths used below resolve.
# The target directory is captured on the first execution only: a bare
# os.chdir('..') climbs one more level every time this cell is re-run.
if 'PROJECT_ROOT' not in globals():
    PROJECT_ROOT = os.path.abspath(os.pardir)
os.chdir(PROJECT_ROOT)

# Prepare training data.
# NOTE(review): the arguments look like (num_channels, image_size) for the
# S1 (2) and S2 (10) inputs — confirm against transforms.make_transforms.
transform_s1 = make_transforms(2, 120)
transform_s2 = make_transforms(10, 120)


In [3]:
%%time

# Streaming variant: reads the pre-optimized chunks under data/opt_BEN_14k/train,
# with the split taken from the shared metadata parquet.
ds1 = LeJEPAStreamingDataset(
    input_dir='data/opt_BEN_14k/train',
    metadata='data/BEN_14k/serbia_metadata.parquet',
    split='train',
    transform_s1=transform_s1,
    transform_s2=transform_s2,
)

dl1 = StreamingDataLoader(
    ds1,
    batch_size=256,
    shuffle=True,
    num_workers=12,
    pin_memory=True,
)



CPU times: user 61.9 ms, sys: 19.1 ms, total: 81.1 ms
Wall time: 412 ms


In [4]:
%%time

# Baseline: map-style dataset reading from the BigEarthNet-S1 / BigEarthNet-S2
# directories, wrapped in a standard torch DataLoader.
ds2 = LeJEPADataset(
    'data/BEN_14k/BigEarthNet-S1',
    'data/BEN_14k/BigEarthNet-S2',
    metadata='data/BEN_14k/serbia_metadata.parquet',
    split='train',
    transform_s1=transform_s1,
    transform_s2=transform_s2,
)

dl2 = torch.utils.data.DataLoader(
    ds2,
    batch_size=256,
    shuffle=True,
    num_workers=12,
    pin_memory=True,
)

CPU times: user 251 ms, sys: 11 ms, total: 262 ms
Wall time: 378 ms


In [5]:
# Time full passes over the streaming loader (dl1).
# time.perf_counter() is monotonic and high-resolution; time.time() follows the
# wall clock and can jump if the system clock is adjusted mid-measurement.
# Per-epoch durations are appended as they finish, so they remain available
# in epoch_times_dl1 even if the run is interrupted.
epoch_times_dl1 = []
for _ in range(100):
    start = time.perf_counter()
    for batch in tqdm.tqdm(dl1):
        pass
    elapsed = time.perf_counter() - start
    epoch_times_dl1.append(elapsed)
    print(round(elapsed, 2))





100%|██████████| 29/29 [01:51<00:00,  3.85s/it]


111.78


100%|██████████| 29/29 [00:13<00:00,  2.17it/s]


13.38


100%|██████████| 29/29 [00:12<00:00,  2.38it/s]


12.16


100%|██████████| 29/29 [00:12<00:00,  2.36it/s]


12.3


100%|██████████| 29/29 [00:12<00:00,  2.37it/s]


12.22


100%|██████████| 29/29 [00:12<00:00,  2.33it/s]


12.43


100%|██████████| 29/29 [00:12<00:00,  2.39it/s]


12.12


100%|██████████| 29/29 [00:12<00:00,  2.37it/s]


12.22


100%|██████████| 29/29 [00:12<00:00,  2.40it/s]


12.09


 69%|██████▉   | 20/29 [00:12<00:05,  1.64it/s]


KeyboardInterrupt: 

In [7]:
# Time full passes over the baseline torch DataLoader (dl2).
# time.perf_counter() is monotonic and high-resolution; time.time() follows the
# wall clock and can jump if the system clock is adjusted mid-measurement.
# Per-epoch durations are appended as they finish, so they remain available
# in epoch_times_dl2 even if the run is interrupted.
epoch_times_dl2 = []
for _ in range(100):
    start = time.perf_counter()
    for batch in tqdm.tqdm(dl2):
        pass
    elapsed = time.perf_counter() - start
    epoch_times_dl2.append(elapsed)
    print(round(elapsed, 2))


100%|██████████| 29/29 [00:40<00:00,  1.38s/it]


40.16


100%|██████████| 29/29 [00:16<00:00,  1.72it/s]


16.91


100%|██████████| 29/29 [00:16<00:00,  1.77it/s]


16.42


100%|██████████| 29/29 [00:16<00:00,  1.75it/s]


16.58


100%|██████████| 29/29 [00:16<00:00,  1.75it/s]


16.59


100%|██████████| 29/29 [00:16<00:00,  1.71it/s]


16.98


100%|██████████| 29/29 [00:16<00:00,  1.74it/s]


16.71


100%|██████████| 29/29 [00:16<00:00,  1.73it/s]


16.74


100%|██████████| 29/29 [00:16<00:00,  1.75it/s]


16.56


 41%|████▏     | 12/29 [00:07<00:10,  1.61it/s]


KeyboardInterrupt: 