In [37]:
import pandas as pd
import os
import pathlib
import torch
import rasterio
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm

In [22]:
class SatelliteData(Dataset):
    """Dataset of merged multi-band satellite chips stored as GeoTIFFs.

    Chips are selected from a tracker CSV under ``root``. A chip is kept only when:
    - its ``bad_pct_max`` is below ``max_bad_pct``,
    - its ``na_count`` is 0, and
    - a matching ``<chip_id>_merged.tif`` exists in the chip directory.

    Each item is the full raster, returned as a float32 tensor.

    Args:
        root: Dataset root directory.
        max_bad_pct: Exclusive upper bound on a chip's ``bad_pct_max``.
        chips_subdir: Directory under ``root`` that holds the chip GeoTIFFs.
        tracker_csv: File name of the chip tracker CSV under ``root``.
    """

    def __init__(self,
                 root,
                 *,
                 max_bad_pct=5,
                 chips_subdir="chips_filtered",
                 tracker_csv="final_chip_tracker.csv"):
        self.root_dir = pathlib.Path(root)
        self.image_dir = self.root_dir.joinpath(chips_subdir)
        self.tracker_path = self.root_dir.joinpath(tracker_csv)
        self.max_bad_pct = max_bad_pct
        self.tif_paths = self._get_tif_paths()

    def _get_tif_paths(self):
        """Return the sorted paths of chips that pass the quality filter and exist on disk."""
        tracker = pd.read_csv(self.tracker_path)
        is_clean = (tracker["bad_pct_max"] < self.max_bad_pct) & (tracker["na_count"] == 0)
        wanted = {
            self.image_dir.joinpath(f"{chip_id}_merged.tif")
            for chip_id in tracker.loc[is_clean, "chip_id"]
        }
        # Glob the directory once rather than stat-ing each path.
        # Tracker rows whose file is missing drop out of the intersection.
        on_disk = set(self.image_dir.glob("*.tif"))
        return sorted(wanted & on_disk)

    def __len__(self):
        return len(self.tif_paths)

    @staticmethod
    def _read_tif(path):
        """Read every band of the raster at ``path`` into a numpy array."""
        with rasterio.open(path) as src:
            return src.read()

    def __getitem__(self, index):
        """Return chip ``index`` as a float32 tensor laid out like rasterio's ``read()`` output (bands, rows, cols)."""
        # The merged tif is used as the ground truth.
        groundtruth = self._read_tif(self.tif_paths[index])
        return torch.tensor(groundtruth, dtype=torch.float32)

In [42]:
# Dataset root and batch size for the statistics pass.
DATA_ROOT = pathlib.Path("/workspace/data/gapfill6band")
BATCH_SIZE = 320

image_dataset = SatelliteData(root=DATA_ROOT)

# No shuffling: a fixed batch order keeps the pass reproducible.
image_loader = DataLoader(image_dataset,
                          batch_size=BATCH_SIZE,
                          shuffle=False)

# Sanity-check the shape of one batch, (batch, channels, rows, cols), before the full pass.
inputs = next(iter(image_loader))
print(inputs.shape)



torch.Size([320, 18, 224, 224])


In [43]:
# Running per-band sums for the mean/std pass.
# Each chip's channels are consecutive groups of the same 6 bands
# (channels 0-5, 6-11 and 12-17 for the 18-channel chips printed above).
# Every group is pooled into a single set of 6 band statistics.
# The running totals are float64, so that adding many large per-batch sums
# (especially sums of squares) does not drop low-order digits.
N_BANDS = 6
psum    = torch.zeros(N_BANDS, dtype=torch.float64)
psum_sq = torch.zeros(N_BANDS, dtype=torch.float64)

for batch in tqdm(image_loader):
    n_batch, n_channels, height, width = batch.shape
    assert n_channels % N_BANDS == 0, f"expected a multiple of {N_BANDS} channels, got {n_channels}"
    # (B, G*6, H, W) -> (B, G, 6, H, W): channel c lands in band c % 6.
    # Unlike torch.cat, this needs no extra copy of the batch.
    bands = batch.reshape(n_batch, n_channels // N_BANDS, N_BANDS, height, width)
    # Reduce each H x W plane first, then finish over chips and groups in float64.
    psum    += bands.sum(dim=(3, 4)).double().sum(dim=(0, 1))
    psum_sq += bands.square().sum(dim=(3, 4)).double().sum(dim=(0, 1))

100%|███████████████████████████████████████████| 25/25 [01:16<00:00,  3.05s/it]


In [46]:
# Number of values pooled into each band's sums.
# Assumes every chip has the same shape as the first one; DataLoader's default
# collate already requires equal shapes within each batch.
# Also assumes channels come in groups of psum.numel() bands, each group pooled
# into the same per-band sums.
n_channels, height, width = image_dataset[0].shape
count = len(image_dataset) * (n_channels // psum.numel()) * height * width

# Population mean/std per band, via Var[X] = E[X^2] - E[X]^2.
# The subtraction is done in float64 to limit cancellation.
# Tiny negative round-off is clamped so sqrt never yields NaN.
mean64 = psum.double() / count
var64  = (psum_sq.double() / count - mean64 ** 2).clamp_min(0.0)

total_mean = mean64.float()
total_var  = var64.float()
total_std  = var64.sqrt().float()

print(f"mean: {total_mean}")
print(f"std:  {total_std}")

mean: tensor([ 495.7316,  814.1386,  924.5740, 2962.5623, 2640.8833, 1740.3031])
std:  tensor([286.9569, 359.3304, 576.3471, 892.2656, 945.9432, 916.1625])


In [47]:
# Number of chips that pass the tracker quality filter and exist on disk.
n_chips = len(image_dataset)
print(n_chips)

7852
