In [31]:
import torch
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

from datasets.Era5 import Era5
from datasets.Dem import Dem
from datasets.Sentinel3 import Sentinel3
from datasets.Sentinel5 import Sentinel5
from datasets.LandCover import LandCover
from datasets.CollectionDataset import CollectionDataset
from utils import process_images, get_day_of_year_and_day_of_week

from presto.presto import Encoder, Decoder, Presto
from PixelTimeseries import PixelTimeSeries
from PrestoMaskedLanguageModel import make_mask

# Names of the masking strategies; one of them is passed to `make_mask` below.
MASK_STRATEGIES = ("group_bands", "random_timesteps", "chunk_timesteps", "random_combinations")

# Pixel time-series dataset read from the small local data file.
# NOTE(review): `num_timesteps` / `jump` presumably are window length and stride
# over the time axis — confirm against PixelTimeSeries.
dataset = PixelTimeSeries(
    num_timesteps=5,
    jump=5,
    input_data_path="data_small_file.pt",
)

In [32]:
# Run `make_mask` over every batch of `dataset` and keep the inputs and masks of
# the final batch so they can be inspected in the following cell.
BATCH_SIZE = 64
MASK_STRATEGY = MASK_STRATEGIES[2]  # "chunk_timesteps"
# NOTE(review): presumably the target fraction of positions to mask — confirm in make_mask.
MASK_RATIO = 0.1

# shuffle=False keeps batches in dataset order, so the "last batch" is reproducible.
train_dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

last_x = None
last_hard_mask = None
last_soft_mask = None
last_batch_size = None

# `latlons`, `day_of_year` and `day_of_week` are not used by this check, but are
# still bound (as before) in case other cells read them.
for x, hard_mask, latlons, day_of_year, day_of_week in train_dataloader:
    soft_mask = make_mask(x, hard_mask, MASK_STRATEGY, MASK_RATIO)
    last_x = x
    last_hard_mask = hard_mask
    last_soft_mask = soft_mask
    last_batch_size = x.shape[0]

# Fail here with a clear message rather than later with `range(None)` / NameError.
assert last_x is not None, "train_dataloader yielded no batches; check `dataset`"

In [33]:
import numpy as np
# Sanity-check the soft mask of the last batch against its hard mask. Per sample, print:
#   1. the number of positions set in BOTH masks (expected 0),
#   2. soft-masked positions divided by positions not hard-masked,
#   3. raw counts and shapes for eyeballing.
# NOTE(review): `soft_mask` is the value left over from the last iteration of the
# DataLoader loop in the previous cell, so it pairs with `last_hard_mask`.
for i in range(last_batch_size):
    sample_soft = soft_mask[i]
    sample_hard = last_hard_mask[i]

    # torch.logical_and stays in torch (np.logical_and would round-trip through
    # NumPy and fails on CUDA tensors).
    overlap = torch.logical_and(sample_soft, sample_hard).sum()
    print(overlap)

    n_positions = last_x[i].shape[0] * last_x[i].shape[1]
    # Use the same hard mask as the overlap check (previously the loop-leaked
    # `hard_mask` was used here).
    ratio = sample_soft.sum() / (n_positions - sample_hard.sum())
    print(ratio)

    print(sample_soft.sum(), sample_soft.shape, last_x[i].shape, sample_hard.sum(), "\n")

    # The soft mask must only pick positions that are not already hard-masked.
    assert overlap == 0, f"sample {i}: soft mask overlaps hard mask at {int(overlap)} positions"

tensor(0)
tensor(0.2190)
tensor(46) torch.Size([5, 46]) torch.Size([5, 46]) tensor(20.) 

tensor(0)
tensor(0.2255)
tensor(46) torch.Size([5, 46]) torch.Size([5, 46]) tensor(26.) 

tensor(0)
tensor(0.2312)
tensor(46) torch.Size([5, 46]) torch.Size([5, 46]) tensor(31.) 

tensor(0)
tensor(0.2180)
tensor(46) torch.Size([5, 46]) torch.Size([5, 46]) tensor(19.) 

tensor(0)
tensor(0.2233)
tensor(46) torch.Size([5, 46]) torch.Size([5, 46]) tensor(24.) 

tensor(0)
tensor(0.2312)
tensor(46) torch.Size([5, 46]) torch.Size([5, 46]) tensor(31.) 

tensor(0)
tensor(0.2201)
tensor(46) torch.Size([5, 46]) torch.Size([5, 46]) tensor(21.) 

tensor(0)
tensor(0.2201)
tensor(46) torch.Size([5, 46]) torch.Size([5, 46]) tensor(21.) 

tensor(0)
tensor(0.2266)
tensor(46) torch.Size([5, 46]) torch.Size([5, 46]) tensor(27.) 

tensor(0)
tensor(0.2201)
tensor(46) torch.Size([5, 46]) torch.Size([5, 46]) tensor(21.) 

tensor(0)
tensor(0.2222)
tensor(46) torch.Size([5, 46]) torch.Size([5, 46]) tensor(23.) 

tensor(0)


In [34]:
from datasets.CollectionDataset import BANDS, BANDS_GROUPS_IDX, BAND_EXPANSION
# A bare expression that is not the last line of a cell produces no output in
# Jupyter, so show BAND_EXPANSION explicitly via IPython's `display` (a builtin in
# Jupyter kernels); BANDS_GROUPS_IDX remains the cell's rendered output.
display(BAND_EXPANSION)
BANDS_GROUPS_IDX

OrderedDict([('S3', [0, 1, 2, 3, 4]),
             ('S5', [5, 6, 7, 8, 9, 10, 11, 12, 13, 14]),
             ('ERA5',
              [15,
               16,
               17,
               18,
               19,
               20,
               21,
               22,
               23,
               24,
               25,
               26,
               27,
               28,
               29,
               30,
               31,
               32,
               33,
               34]),
             ('DEM', [35]),
             ('LC', [36, 37, 38, 39, 40, 41, 42, 43, 44, 45])])