In [3]:
import torch
from torchvision import transforms
import torchvision.transforms.functional as F
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os
import glob
from tqdm.notebook import tqdm
from google.colab import drive

# =================================================================
# PART 1: SETUP AND DATA DIRECTORY
# =================================================================

# Mount Google Drive so the dataset folders below are reachable.
drive.mount('/content/drive', force_remount=True)

#  1. INPUT PATH: The top-level folder containing your images.
DATA_DIR = '/content/drive/MyDrive/resized_dataset'

#  2. OUTPUT PATH: The folder where 0-1 normalized images will be saved.
OUTPUT_DIR = '/content/drive/MyDrive/normalized_dataset_0_1_output'

# Fail fast when the input is missing. isdir() (rather than exists()) also
# rejects the case where DATA_DIR points at a regular file, which would
# otherwise only surface later as "found 0 images".
if not os.path.isdir(DATA_DIR):
    print(" ERROR: Input directory not found. Please check 'DATA_DIR'.")
    # FileNotFoundError is an Exception subclass, so callers that catch
    # Exception keep working.
    raise FileNotFoundError("Input directory error.")
print(f" Input connected to: {DATA_DIR}")

# exist_ok=True makes re-running the cell safe when the folder already exists.
os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f" Output folder created at: {OUTPUT_DIR}")

# =================================================================
# PART 2: CUSTOM DATASET (Searches for images recursively)
# =================================================================

class RecursiveImageDataset(Dataset):
    """Dataset over every JPEG/PNG file found under a directory tree.

    Files are discovered once, at construction time, by a recursive glob
    below ``root_dir``. The extension check is case-insensitive, only regular
    files are kept, and the resulting ``image_paths`` list is sorted so that
    indices are stable between runs.

    Args:
        root_dir: Directory to search (including all sub-directories).
        transform: Optional callable applied to each loaded RGB PIL image.

    Raises:
        FileNotFoundError: If no image file is found under ``root_dir``.
    """

    # Lower-cased file extensions that count as images.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform

        # One recursive pattern plus a case-insensitive suffix filter replaces
        # a per-case pattern list: it also accepts mixed-case extensions
        # (".Jpg") and cannot list the same file twice on case-insensitive
        # filesystems. glob.escape keeps "[", "*" or "?" in root_dir literal.
        search_pattern = os.path.join(glob.escape(self.root_dir), '**', '*')
        candidates = glob.glob(search_pattern, recursive=True)

        # Sorted for a deterministic index -> file mapping; isfile() drops
        # directories whose names merely end in an image extension.
        self.image_paths = sorted(
            path for path in candidates
            if os.path.splitext(path)[1].lower() in self.IMAGE_EXTENSIONS
            and os.path.isfile(path)
        )

        if not self.image_paths:
            # Subclass of Exception, so existing `except Exception` still works.
            raise FileNotFoundError(
                f" ERROR: Found 0 images after searching recursively in {root_dir}."
            )
        print(f" Found total of {len(self.image_paths)} images.")

    def __len__(self):
        """Return the number of discovered image files."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, 0)`` for the file at position ``idx``.

        The image is converted to RGB and passed through ``self.transform``
        when one was given. The second element is a constant placeholder
        label.
        """
        img_path = self.image_paths[idx]
        # convert() loads the pixel data and returns a new image, so the
        # underlying file can be closed as soon as the `with` block exits.
        with Image.open(img_path) as source_image:
            image = source_image.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, 0

# =================================================================
# PART 3: CALCULATE GLOBAL MEAN AND STD DEV (Essential for model training)
# =================================================================

TEMP_SIZE = 224
temp_transform = transforms.Compose([
    transforms.Resize((TEMP_SIZE, TEMP_SIZE)),
    transforms.ToTensor() # Pixels are scaled to [0.0, 1.0] here
])

temp_dataset = RecursiveImageDataset(root_dir=DATA_DIR, transform=temp_transform)
temp_loader = DataLoader(temp_dataset, batch_size=128, shuffle=False, num_workers=2)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Both running sums live on the compute device and use float64, so adding up
# tens of millions of pixel values does not lose precision and no per-batch
# device-to-host copy is needed.
mean_sum = torch.zeros(3, dtype=torch.float64, device=device)
std_sum_sq = torch.zeros(3, dtype=torch.float64, device=device)
total_pixels = 0

print("\nCalculating global mean and standard deviation...")

# Two-pass calculation: pass 1 yields the per-channel mean, pass 2 sums the
# squared deviations from that mean.

# --- Pass 1: Calculate Mean ---
for images, _ in tqdm(temp_loader, desc="Pass 1/2: Mean"):
    images = images.to(device)
    # (N, C, H, W) -> (N*H*W, C): one row per pixel, one column per channel.
    images = images.permute(0, 2, 3, 1).reshape(-1, 3)
    mean_sum += images.sum(dim=0, dtype=torch.float64)
    total_pixels += images.shape[0]
# Cast back to float32 so the mean matches the dtype of the image batches.
mean_final = (mean_sum / total_pixels).to(torch.float32)

# --- Pass 2: Calculate Std Dev ---
for images, _ in tqdm(temp_loader, desc="Pass 2/2: Std Dev"):
    images = images.to(device)
    images = images.permute(0, 2, 3, 1).reshape(-1, 3)
    std_sum_sq += ((images - mean_final) ** 2).sum(dim=0, dtype=torch.float64)
# Population standard deviation (divides by the pixel count, not count - 1).
std_final = torch.sqrt(std_sum_sq / total_pixels).to(torch.float32).cpu()

DATASET_MEAN = mean_final.cpu().tolist()
DATASET_STD = std_final.tolist()

print("\n=========================================================")
print(f" Calculated GLOBAL Mean (R, G, B): {DATASET_MEAN}")
print(f" Calculated GLOBAL Std Dev (R, G, B): {DATASET_STD}")
print("=========================================================")


# =================================================================
# PART 4: 0-1 NORMALIZATION AND SAVING
# =================================================================

# Pipeline used when exporting: resize to the working resolution, then let
# ToTensor rescale pixel values to [0, 1] (pixels / 255).
# transforms.Normalize(mean, std) is deliberately left out of this pipeline.
saving_transform = transforms.Compose(
    [
        transforms.Resize(size=(TEMP_SIZE, TEMP_SIZE)),
        transforms.ToTensor(),
    ]
)

# Dataset over the same input folder, wired to the export pipeline above.
normalized_dataset_0_1 = RecursiveImageDataset(DATA_DIR, transform=saving_transform)

#  SIMPLIFIED SAVING FUNCTION: Converts the [0, 1] Tensor back to PIL Image.
def tensor_0_1_to_pil_image(tensor):
    """Convert a [0, 1] image tensor of shape (C, H, W) to a PIL Image.

    ``F.to_pil_image`` performs the scaling from [0, 1] to 8-bit [0, 255].
    Floating-point tensors are clamped to [0, 1] first, because values
    outside that range would overflow the 8-bit conversion instead of
    saturating. Non-float inputs are passed through unchanged.

    Args:
        tensor: Image tensor in (C, H, W) layout, expected to be in [0, 1].

    Returns:
        The corresponding PIL Image.
    """
    if torch.is_tensor(tensor) and tensor.is_floating_point():
        # clamp() returns a new tensor; the caller's tensor is not modified.
        tensor = tensor.clamp(0.0, 1.0)
    return F.to_pil_image(tensor)


print("\n--- Starting Saving Process (0-1 Normalized Images) ---")
total_saved_images = 0

# Output paths already written during this run. The input is searched
# recursively but the output folder is flat, so files such as "a/img.jpg",
# "b/img.jpg" and "img.png" all map to the same output name; without this
# bookkeeping later files would silently overwrite earlier ones.
used_output_paths = set()

# NOTE: PNG stores 8-bit integers, so the files written here hold the resized
# pixels in [0, 255]; the [0, 1] scaling exists only on the in-memory tensor.
for i, (image_tensor, _) in enumerate(tqdm(normalized_dataset_0_1, desc="Saving 0-1 Images")):

    original_path = normalized_dataset_0_1.image_paths[i]
    base_name = os.path.splitext(os.path.basename(original_path))[0]

    new_filename = f"{base_name}_0_1_normalized.png"
    new_path = os.path.join(OUTPUT_DIR, new_filename)

    # On a name collision, insert a numeric suffix. Names that do not collide
    # stay exactly as before.
    duplicate_index = 1
    while new_path in used_output_paths:
        new_filename = f"{base_name}_{duplicate_index}_0_1_normalized.png"
        new_path = os.path.join(OUTPUT_DIR, new_filename)
        duplicate_index += 1
    used_output_paths.add(new_path)

    # Save the 3-channel image
    pil_image = tensor_0_1_to_pil_image(image_tensor)
    pil_image.save(new_path)
    total_saved_images += 1

print(f"\n Process complete. Total 0-1 normalized images saved: {total_saved_images}")
print(f"Data is saved in the [0, 1] range as 3-channel PNGs in: {OUTPUT_DIR}")

Mounted at /content/drive
 Input connected to: /content/drive/MyDrive/resized_dataset
 Output folder created at: /content/drive/MyDrive/normalized_dataset_0_1_output
 Found total of 1584 images.

Calculating global mean and standard deviation...


Pass 1/2: Mean:   0%|          | 0/13 [00:00<?, ?it/s]

Pass 2/2: Std Dev:   0%|          | 0/13 [00:00<?, ?it/s]


 Calculated GLOBAL Mean (R, G, B): [0.509886622428894, 0.4991682171821594, 0.49289384484291077]
 Calculated GLOBAL Std Dev (R, G, B): [0.3258674144744873, 0.32224684953689575, 0.32660314440727234]
 Found total of 1584 images.

--- Starting Saving Process (0-1 Normalized Images) ---


Saving 0-1 Images:   0%|          | 0/1584 [00:00<?, ?it/s]


 Process complete. Total 0-1 normalized images saved: 1584
Data is saved in the [0, 1] range as 3-channel PNGs in: /content/drive/MyDrive/normalized_dataset_0_1_output
