In [1]:
import pandas as pd
import os
import pathlib
import torch
import rasterio
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm

In [13]:
class SatelliteData(Dataset):
    """Dataset of TIFF rasters stored in ``<root>/images``.

    Each item is one file read in full with ``rasterio`` and converted to a
    ``torch.float32`` tensor. The array layout is whatever ``src.read()``
    returns (the batch shape printed further down in this notebook is
    ``[320, 4, 512, 512]``, i.e. bands first).
    """

    # File suffixes (compared case-insensitively) that count as TIFF rasters.
    TIF_SUFFIXES = (".tif", ".tiff")

    def __init__(self, 
                 root):
        """Index the TIFF files found under ``root/images``.

        Args:
            root: Path (str or ``pathlib.Path``) of the directory that
                contains the ``images`` sub-directory.
        """
        self.root_dir = pathlib.Path(root)
        self.image_dir = self.root_dir.joinpath("images")
        self.tif_paths = self._get_tif_paths()

    def _get_tif_paths(self):
        """Return the TIFF files in ``image_dir`` as a sorted list of paths.

        Only regular files with a ``.tif``/``.tiff`` suffix are kept, so stray
        entries such as ``.ipynb_checkpoints`` or sidecar files never become
        dataset items that would fail inside ``rasterio.open``. Sorting makes
        the index -> file mapping independent of the filesystem's listing
        order.
        """
        return sorted(
            path
            for path in self.image_dir.iterdir()
            if path.is_file() and path.suffix.lower() in self.TIF_SUFFIXES
        )

    def __len__(self):
        """Return the number of indexed TIFF files."""
        return len(self.tif_paths)

    def __getitem__(self, index):
        """Read the raster at ``index`` and return it as a float32 tensor."""
        # The context manager closes the rasterio dataset handle as soon as
        # the pixel data has been copied into memory.
        with rasterio.open(self.tif_paths[index]) as src:
            bands = src.read()
        return torch.tensor(bands, dtype=torch.float32)

In [14]:
# Wrap the image folder in a dataset and iterate it sequentially (no shuffling).
image_dataset = SatelliteData(root="/home/data")
image_loader = DataLoader(dataset=image_dataset, batch_size=320, shuffle=False)

# Sanity check: show the shape of the first batch, then stop.
for batch_idx, inputs in enumerate(image_loader):
    print(inputs.shape)
    break



torch.Size([320, 4, 512, 512])


In [15]:
# Per-channel running totals of pixel values and squared pixel values.
# They are kept in float64: float32 carries only ~7 significant digits, which
# is far too few for totals accumulated over every pixel of every image, and
# the later E[x^2] - mean^2 step amplifies any rounding error in these sums.
# The length 4 matches the 4 channels seen in the batch shape printed above.
psum    = torch.zeros(4, dtype=torch.float64)
psum_sq = torch.zeros(4, dtype=torch.float64)


# loop through images
for tensor in tqdm(image_loader):

    # Reduce over batch, height and width, leaving one total per channel;
    # dtype=float64 makes the reduction itself run in double precision.
    psum    += tensor.sum(dim=(0, 2, 3), dtype=torch.float64)
    psum_sq += tensor.pow(2).sum(dim=(0, 2, 3), dtype=torch.float64)


100%|████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:09<00:00,  1.42s/it]


In [19]:
# Per-channel running extrema. Seed with +inf / -inf rather than 0 so the
# starting value can never win the comparison: with a 0 seed the reported
# minimum could not exceed 0 and the reported maximum could not fall below 0,
# regardless of what the data contains.
pmin    = torch.full((4,), float("inf"))
pmax    = torch.full((4,), float("-inf"))

for tensor in tqdm(image_loader):
    # Reduce over batch, height and width, leaving one value per channel.
    tensor_min = torch.amin(tensor, dim=(0, 2, 3))
    tensor_max = torch.amax(tensor, dim=(0, 2, 3))
    pmin = torch.minimum(pmin, tensor_min)
    pmax = torch.maximum(pmax, tensor_max)

print(pmin)
print(pmax)

100%|████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:09<00:00,  1.40s/it]

tensor([0., 0., 0., 0.])
tensor([255., 255., 255., 255.])





In [17]:
# Number of pixels per channel over the whole dataset. The spatial size is
# taken from the first image instead of being hard-coded; the default batch
# collation used by the loader already requires every image to share it.
height, width = image_dataset[0].shape[-2:]
count = len(image_dataset) * height * width

# mean and std
# Work in float64: E[x^2] - mean^2 subtracts two nearly equal numbers when a
# channel is almost constant, so single precision loses most of the result.
total_mean = psum.double() / count
total_var  = (psum_sq.double() / count) - (total_mean ** 2)
# Rounding can leave a tiny negative variance; clamp so sqrt never yields NaN.
total_std  = torch.sqrt(total_var.clamp(min=0))

# output
print('mean: '  + str(total_mean))
print('std:  '  + str(total_std))

mean: tensor([112.1533, 114.1895,  92.9912, 254.9901])
std:  tensor([39.2668, 34.4827, 34.6238,  1.5245])


In [18]:
# Report how many images the dataset indexes.
n_images = len(image_dataset)
print(n_images)

2074
