In [None]:
import numpy as np
import pandas as pd

import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
import rasterio

from tqdm import tqdm

from PIL import Image

import os
import glob

from matplotlib import pyplot as plt

In [None]:
data_directory = '/workspace/processed_data/validate'
# glob returns paths in arbitrary filesystem order; sorting makes the chip order
# (and therefore any positional lookup such as tif_files[11]) reproducible.
tif_files = sorted(glob.glob(os.path.join(data_directory, '**', '*.tif'), recursive=True))
print(len(tif_files))

## Convert all elevations to relative elevation (height above each chip's minimum)

In [None]:
def to_relative_elevation(tif_path):
    """Rewrite band 1 of a GeoTIFF in place as height above its own minimum.

    The minimum is taken over valid pixels only: pixels equal to the file's
    nodata value and non-finite pixels (NaN/inf) are excluded, so a fill value
    such as -9999 is not mistaken for the lowest elevation. Because the output
    profile drops the nodata tag (``nodata=None``), those pixels are written as 0.

    Args:
        tif_path: path of the raster to convert (overwritten in place).

    Returns:
        True if the file was rewritten, False if it had no valid pixels and
        was left untouched.
    """
    with rasterio.open(tif_path, 'r') as src:
        profile = src.profile
        # masked=True masks pixels matching the dataset's nodata value
        elevation = src.read(1, masked=True).astype(np.float32)

    # also mask NaN/inf so they cannot poison the minimum
    elevation = np.ma.masked_invalid(elevation)
    if elevation.count() == 0:
        return False

    # cast explicitly so the array dtype matches the float32 profile written below
    relative = (elevation - elevation.min()).filled(0).astype(np.float32)

    # only band 1 is written, so declare a single-band output
    profile.update(
        dtype=rasterio.float32,
        nodata=None,
        count=1,
    )

    with rasterio.open(tif_path, 'w', **profile) as dst:
        dst.write(relative, 1)
    return True


skipped_files = []
for tif_file in tqdm(tif_files, total=len(tif_files)):
    if not to_relative_elevation(tif_file):
        skipped_files.append(tif_file)
print(f'{len(skipped_files)} file(s) skipped (no valid pixels)')


In [None]:
# Load every band of one sample chip for visual inspection.
sample_tif = tif_files[11]
with rasterio.open(sample_tif) as src:
    img = src.read()

In [None]:
np.shape(img)  # rasterio's read() returns (bands, rows, cols)

In [None]:
# Triplicate each band, then move the band axis last: (bands, rows, cols) -> (rows, cols, 3*bands).
repeated_bands = np.repeat(img, repeats=3, axis=0)
bands = np.moveaxis(repeated_bands, 0, -1)

In [None]:
bands.shape[-1]  # channel count (channels are the last axis of a 3-D array)

In [None]:
# PIL only builds a multi-channel image from uint8 data; float32/int16 rasters
# (e.g. the float32 relative-elevation chips) raise "Cannot handle this data type".
# Min-max rescale anything that is not already uint8 to the 0-255 range first.
if bands.dtype == np.uint8:
    display_bands = bands
else:
    low, high = np.nanmin(bands), np.nanmax(bands)
    span = high - low
    scaled = (bands - low) / span if span > 0 else np.zeros(bands.shape)
    display_bands = np.nan_to_num(scaled * 255).astype(np.uint8)
image = Image.fromarray(display_bands)

In [None]:
# bands is (rows, cols, channels): bands[2] would select the third *row*
# (a cols x channels strip), not a band. Index the channel axis instead.
plt.imshow(bands[:, :, 2])
plt.title('Channel 3 of the sample chip');

In [None]:
# Plot band 1 of the same sample chip straight from disk.
with rasterio.open(tif_files[11]) as src:
    first_band = src.read(1)
plt.imshow(first_band);

## Compute the mean and standard deviation over all image chips

In [None]:
class ImageData(Dataset):
    """Dataset that yields every band of a raster chip as a float32 array.

    Args:
        file_paths: sequence of raster paths to serve. When omitted, the
            notebook-level ``tif_files`` list is used (the original behaviour),
            so ``ImageData()`` keeps working unchanged.
    """

    def __init__(self, file_paths=None):
        super().__init__()
        # Accept an explicit list so the dataset is not tied to notebook globals.
        self.file_paths = tif_files if file_paths is None else file_paths

    def __len__(self):
        """Number of chips in the dataset."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Read all bands of chip ``idx`` as a (bands, rows, cols) float32 array."""
        file_path = self.file_paths[idx]
        with rasterio.open(file_path, mode='r') as src:
            data = src.read().astype('float32')

        return data


In [None]:
image_dataset = ImageData()

# One chip per batch, in file order, loaded on the main process.
loader_options = dict(batch_size=1, shuffle=False, num_workers=0)
image_loader = DataLoader(image_dataset, **loader_options)

In [None]:
# Accumulate in float64: summing many float32 elevations (and their squares)
# loses precision, and var = E[x^2] - E[x]^2 amplifies that error.
psum    = torch.zeros(1, dtype=torch.float64)
psum_sq = torch.zeros(1, dtype=torch.float64)
count = 0
index = 0
# loop through images
for inputs in tqdm(image_loader):
    values = inputs.double()
    psum += values.sum()
    psum_sq += (values ** 2).sum()
    # count the pixels actually seen rather than assuming 800x800 single-band chips
    count += values.numel()
    index += 1

if count == 0:
    raise ValueError('image_loader produced no pixels; check data_directory / tif_files')

# mean and STD (computed in float64, exposed as float32 tensors as before)
mean64 = psum / count
# clamp tiny negatives caused by floating-point cancellation before the sqrt
var64 = (psum_sq / count - mean64 ** 2).clamp_min(0)
total_mean = mean64.float()
total_var  = var64.float()
total_std  = torch.sqrt(var64).float()

# output
print('Training data stats:')
print('- mean: {:.4f}'.format(total_mean.item()))
print('- std:  {:.4f}'.format(total_std.item()))