# 🌁 Image normalization statistics

# 📚 Libraries

In [18]:
import os
import rasterio
from tqdm.notebook import tqdm
from glob import glob
import torch
import json

# 🛩️ Mean & Std for Aerial images

In [5]:
def read_tif(path_file):
    """Load a raster file and return its contents as a ``torch.uint8`` tensor.

    Args:
        path_file: Path of the file handed to ``rasterio.open``.

    Returns:
        A tensor built from everything ``read()`` returns for the dataset,
        cast to ``torch.uint8``.
        NOTE(review): the array is presumably laid out band-first
        (bands, rows, cols) -- confirm against the rasterio documentation.
    """
    # The array is fully materialised by read(), so the dataset can be closed
    # before the tensor conversion happens.
    with rasterio.open(path_file) as dataset:
        raster = dataset.read()

    # Values are expected to already lie in 0..255.
    return torch.from_numpy(raster).to(torch.uint8)

In [19]:
def _load_channel_pixels(path_image, num_channels):
    """Read one image as a (num_channels, height * width) float64 tensor scaled by 1/255."""
    image = read_tif(path_image)

    # The image must be band-first with exactly `num_channels` bands; anything
    # else would silently mix channels together once flattened.
    if image.dim() != 3 or image.shape[0] != num_channels:
        raise ValueError(
            f'Expected an image of shape ({num_channels}, H, W) for {path_image}, '
            f'got {tuple(image.shape)}'
        )

    # Flatten the spatial axes only, so that each row holds one channel.
    # float64 keeps the running sums accurate over a very large pixel count.
    return image.reshape(num_channels, -1).to(torch.float64) / 255


def compute_stats(template_path_image, filename, num_channels):
    """Compute the per-channel mean and standard deviation of a set of images.

    The statistics are computed on pixel values divided by 255, in two passes
    over the files: the first accumulates the per-channel sums to obtain the
    mean, the second accumulates the squared deviations from that mean. The
    standard deviation is the population one (divided by the pixel count).

    Args:
        template_path_image: Glob pattern (``**`` is expanded recursively)
            selecting the image files.
        filename: Path of the JSON file to write, with keys ``'mean'`` and
            ``'std'`` each holding a list of ``num_channels`` floats.
        num_channels: Number of channels every image must have.

    Raises:
        FileNotFoundError: If no file matches ``template_path_image``.
        ValueError: If an image is not laid out as (num_channels, H, W).
    """
    list_path_image = glob(template_path_image, recursive=True)
    if not list_path_image:
        # Without this guard the statistics would be 0 / 0 = NaN and an
        # unusable file would be written.
        raise FileNotFoundError(f'No image matches the pattern: {template_path_image}')

    # First pass: per-channel sums and total number of pixels per channel.
    channel_sum = torch.zeros(num_channels, dtype=torch.float64)
    total_pixels = 0
    for path_image in tqdm(list_path_image, desc='Compute sum by channel'):
        pixels = _load_channel_pixels(path_image, num_channels)
        channel_sum += pixels.sum(dim=1)
        total_pixels += pixels.shape[1]

    channel_mean = channel_sum / total_pixels

    # Second pass: per-channel sums of squared deviations from the mean.
    channel_squared_diff_sum = torch.zeros(num_channels, dtype=torch.float64)
    for path_image in tqdm(list_path_image, desc='Compute squared diff sum by channel'):
        pixels = _load_channel_pixels(path_image, num_channels)
        # unsqueeze(1) broadcasts each channel mean across that channel's pixels.
        diff = pixels - channel_mean.unsqueeze(1)
        channel_squared_diff_sum += (diff * diff).sum(dim=1)

    channel_std = torch.sqrt(channel_squared_diff_sum / total_pixels)

    dataset_dict = {
        'mean': channel_mean.tolist(),
        'std': channel_std.tolist()
    }

    with open(filename, 'w', encoding='UTF8') as f:
        json.dump(dataset_dict, f)

In [20]:
# Both the input images and the output file live under the sibling `data` directory.
data_dir = os.path.join(os.pardir, 'data')

# Recursive pattern matching every .tif file below the aerial training folder.
template_path_image = os.path.join(data_dir, 'raw', 'train', 'aerial', '**', '*.tif')
filename = os.path.join(data_dir, 'aerial_pixels_metadata.json')

# Statistics are computed over 5 channels per image.
compute_stats(template_path_image, filename, num_channels=5)

Compute sum by channel:   0%|          | 0/61712 [00:00<?, ?it/s]

Compute squared diff sum by channel:   0%|          | 0/61712 [00:00<?, ?it/s]