**Note:** Outputs have been cleared.

In [None]:
import os
import matplotlib.pyplot as plt
import pandas as pd
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

In [None]:
# Per-channel normalisation statistics, pre-computed with `compute_mean_std()`
# (defined further down in this notebook):
#   Global Mean: tensor([0.0591, 0.0526, 0.0591])
#   Global Std:  tensor([0.0936, 0.0823, 0.0838])
# NOTE(review): the statistics cell below feeds un-resized images, while this
# pipeline resizes first; resizing may shift the std slightly — confirm this
# is acceptable.
GZ_CHANNEL_MEAN = [0.0591, 0.0526, 0.0591]
GZ_CHANNEL_STD = [0.0936, 0.0823, 0.0838]

# Resize every image to 224x224, convert it to a float tensor, then
# standardise each channel with the statistics above.
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=GZ_CHANNEL_MEAN, std=GZ_CHANNEL_STD),
])

In [None]:
class GZdataset(Dataset):
    """Map-style dataset of Galaxy Zoo images described by a catalog.

    Item ``idx`` is the image named by the ``filename`` column of catalog row
    ``idx`` (positional, via ``iloc``), read from ``image_dir``, converted to
    RGB and, if a transform was given, passed through it.

    Args:
        catalog: A pandas DataFrame, or a path (``str`` or ``os.PathLike``)
            to a parquet file that is loaded with ``pd.read_parquet``. Must
            provide a ``filename`` column.
        image_dir: Directory the ``filename`` values are relative to.
        transform: Optional callable applied to the RGB PIL image. When
            ``None`` the PIL image itself is returned.
    """

    def __init__(self, catalog, image_dir, transform=None):
        if isinstance(catalog, (str, os.PathLike)):
            self.catalog = pd.read_parquet(catalog)
        else:
            self.catalog = catalog

        self.image_dir = image_dir
        self.transform = transform

    def __len__(self):
        return len(self.catalog)

    def __getitem__(self, idx):
        row = self.catalog.iloc[idx]
        img_path = os.path.join(self.image_dir, row['filename'])
        # convert() returns a new, fully loaded image, so the source file can
        # be closed as soon as we leave the with-block.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        # transform defaults to None; previously that crashed with TypeError.
        if self.transform is not None:
            image = self.transform(image)
        return image

In [None]:
original_images_folder_path = '../gz/gz_candels/images' # relative path

# Load the original catalogs
train_catalog = pd.read_parquet('../gz/gz_candels/train_catalog.parquet')
test_catalog = pd.read_parquet('../gz/gz_candels/test_catalog.parquet')

# Sample images within their respective catalogs (reduced count from the
# original due to hardware constraints). A fixed seed makes the subsets, and
# therefore everything exported from them, reproducible.
subset_sample_seed = 42
train_subset_catalog = train_catalog.sample(16_000, random_state=subset_sample_seed)
test_subset_catalog = test_catalog.sample(4_000, random_state=subset_sample_seed)

# Save subsets to new parquet files (I put them on an external SSD)
tensor_output_dir = '/Volumes/Samsung T7 SSD/gz_candels_tensors'
train_subset_catalog.to_parquet(os.path.join(tensor_output_dir, 'train_catalog.parquet'))
test_subset_catalog.to_parquet(os.path.join(tensor_output_dir, 'test_catalog.parquet'))

In [None]:
# One dataset per sampled catalog: both read from the original image folder
# and apply the shared `image_transform`.
data_subset, test_subset = (
    GZdataset(
        catalog=subset_catalog,
        image_dir=original_images_folder_path,
        transform=image_transform,
    )
    for subset_catalog in (train_subset_catalog, test_subset_catalog)
)

# NOTE: train_loader shuffles, so its iteration order differs from the row
# order of train_subset_catalog; test_loader keeps catalog order.
train_loader = DataLoader(data_subset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_subset, batch_size=128, shuffle=False)

In [None]:
# Convert images to tensor files and save them.
# Each split is exported in catalog row order, so image i of the saved tensor
# corresponds to row i of the saved catalog parquet. `train_loader` is built
# with shuffle=True, so iterating it directly would scramble that
# correspondence (and mis-align any labels taken from the catalog). Instead we
# iterate a non-shuffled loader over the same dataset.
tensor_output_dir = '/Volumes/Samsung T7 SSD/gz_candels_tensors'


def _stack_dataset_in_order(loader):
    """Return every sample of ``loader.dataset`` concatenated in dataset order.

    Re-uses the loader's batch size, worker count and collate function, but
    never shuffles.
    """
    ordered_loader = DataLoader(
        loader.dataset,
        batch_size=loader.batch_size,
        shuffle=False,
        num_workers=loader.num_workers,
        collate_fn=loader.collate_fn,
    )
    # The batch list is local, so the individual batches are freed once
    # torch.cat returns instead of staying alive next to the stacked tensor.
    return torch.cat(list(tqdm(ordered_loader)))


train_images_tensor = _stack_dataset_in_order(train_loader)
torch.save(train_images_tensor, os.path.join(tensor_output_dir, 'train_images_tensor.pt'))

test_images_tensor = _stack_dataset_in_order(test_loader)
torch.save(test_images_tensor, os.path.join(tensor_output_dir, 'test_images_tensor.pt'))

In [None]:
def compute_mean_std(dataset, batch_size=128, *, progress=True):
    """Compute the per-channel mean and population std of an image dataset.

    The dataset is iterated in order. For each channel, the sum and the sum of
    squares of all pixel values are accumulated. The results are then
    ``mean = E[x]`` and ``std = sqrt(E[x^2] - E[x]^2)``.

    Args:
        dataset: Map-style dataset whose items are ``(C, H, W)`` tensors of a
            common size, so the default collate function can stack them into
            ``(B, C, H, W)`` batches.
        batch_size: Number of items loaded per batch.
        progress: Show a tqdm progress bar while iterating.

    Returns:
        Tuple ``(mean, std)`` of float32 tensors of shape ``(C,)``.

    Raises:
        ValueError: If the dataset yields no pixels.
    """
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    batches = tqdm(dataloader, desc="Computing global mean/std") if progress else dataloader

    # Running totals are kept in float64. Accumulating across many batches in
    # float32 loses precision. They are created lazily so the channel count
    # comes from the data instead of being hard-coded.
    sum_channels = None
    sum_channels_sq = None
    n_pixels_total = 0

    for data in batches:
        b, c, h, w = data.shape

        # (B, C, H, W) -> (B, C, H*W); reshape also handles non-contiguous input.
        data = data.reshape(b, c, -1)

        if sum_channels is None:
            sum_channels = torch.zeros(c, dtype=torch.float64)
            sum_channels_sq = torch.zeros(c, dtype=torch.float64)

        sum_channels += data.sum(dim=(0, 2)).double()
        sum_channels_sq += (data ** 2).sum(dim=(0, 2)).double()

        n_pixels_total += b * h * w

    if n_pixels_total == 0:
        raise ValueError("compute_mean_std: dataset contains no pixels")

    mean_pixels = sum_channels / n_pixels_total
    # Rounding can make E[x^2] - E[x]^2 marginally negative for near-constant
    # channels; clamp so sqrt yields 0 rather than NaN.
    variance = (sum_channels_sq / n_pixels_total - mean_pixels ** 2).clamp_min(0)
    std_pixels = torch.sqrt(variance)

    return mean_pixels.float(), std_pixels.float()


# Measure statistics on raw pixel values: the transform only turns each PIL
# image into a float tensor, with no resizing or normalisation.
# NOTE(review): compute_mean_std batches with the default collate function,
# which stacks images, so this assumes every image in the full catalog has
# the same size — confirm.
tensor_transform = transforms.Compose([transforms.ToTensor()])

# Use the full training catalog (not the 16k subset) for the statistics.
mean_std_dataset = GZdataset(
    catalog=train_catalog,
    image_dir=original_images_folder_path,
    transform=tensor_transform,
)

mean, std = compute_mean_std(mean_std_dataset)
for label, value in (("Global Mean:", mean), ("Global Std: ", std)):
    print(label, value)