In [5]:
import os
import sys
import torch
import numpy as np
import cv2
from typing import List


sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "../..")))

from src.data_processing.dataset_loader import CoastData

In [12]:
def calculate_mean_std(image_paths: List[str]):
    """
    Computes the mean and standard deviation of a dataset.

    Parameters:
        image_paths (List[str]): List of image file paths.

    Returns:
        Tuple[Tuple[float, float, float], Tuple[float, float, float]]: Mean and standard deviation for each channel.
    """
    sum_rgb = torch.zeros(3)
    sum_rgb_sq = torch.zeros(3)
    num_pixels = 0

    for path in image_paths:
        # Load image
        img = cv2.imread(path['image'])
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # Convert BGR to RGB
        img = img / 255.0  # Normalize to [0,1]
        
        img_tensor = torch.tensor(img, dtype=torch.float32)
        
        # Sum per channel
        sum_rgb += img_tensor.mean(dim=(0, 1))
        sum_rgb_sq += (img_tensor ** 2).mean(dim=(0, 1))
        num_pixels += 1

    # Compute mean and std
    mean = sum_rgb / num_pixels
    std = torch.sqrt(sum_rgb_sq / num_pixels - mean ** 2)

    return mean.tolist(), std.tolist()

In [14]:
data_path = os.path.abspath(os.path.join(os.getcwd(), "../../data/processed/"))

# Load the data, all the different stations
data = CoastData(data_path)


filtered_data = data.get_images_and_masks() 


mean, std = calculate_mean_std(filtered_data)
print("Mean:", mean)
print("Std:", std)

CoastData: global - 1717 images
Mean: [0.4288156032562256, 0.45132672786712646, 0.4600674510002136]
Std: [0.31724053621292114, 0.3093735873699188, 0.31197479367256165]
