# Large-Scale Image Preprocessing for SimCLR and Supervised Datasets

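This notebook preprocesses large image datasets: it resizes and normalizes every image, writes a copy of each image to a processed directory, and stores the normalized tensors in a single `.pt` file so downstream experiments can load them directly instead of re-decoding images. It handles both SimCLR (unlabeled) and supervised (labeled) datasets, with the goal of portable outputs that are cheap to load.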

## 1. Import Required Libraries

In [None]:
import os
import glob
from PIL import Image
import torch
from torchvision import transforms
from tqdm import tqdm
import numpy as np

## 2. Define Dataset Paths

In [None]:
# Example directory structure (customize as needed).
# All paths hang off a single data root so the notebook can be moved to another
# machine by setting the SSL_ML_DATA_ROOT environment variable instead of
# editing every path; the default reproduces the original absolute paths.
DATA_ROOT = os.environ.get("SSL_ML_DATA_ROOT", "/home/zerotwo/ssl-ml/data")

# Raw inputs: an unlabeled image tree for SimCLR and a class-per-folder tree
# for supervised classification.
RAW_SIMCLR_DIR = os.path.join(DATA_ROOT, "raw", "simclr_dataset")
RAW_SUPERVISED_DIR = os.path.join(DATA_ROOT, "raw", "genshin_classification")

# Output directories for the processed image copies.
PROCESSED_SIMCLR_IMG_DIR = os.path.join(DATA_ROOT, "processed", "simclr_images")
PROCESSED_SUPERVISED_IMG_DIR = os.path.join(DATA_ROOT, "processed", "genshin_images")

# Output files for the stacked, normalized tensors.
PROCESSED_SIMCLR_TENSOR_PATH = os.path.join(DATA_ROOT, "processed", "simclr_dataset.pt")
PROCESSED_SUPERVISED_TENSOR_PATH = os.path.join(DATA_ROOT, "processed", "genshin_dataset.pt")

## 3. Image Transformations (Resize & Normalize)

In [None]:
# Per-channel (R, G, B) normalization statistics, named as the ImageNet values.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Target (height, width) that every image is resized to before conversion.
IMAGE_SIZE = (256, 256)

# Shared preprocessing pipeline: fixed-size resize on the PIL image, conversion
# to a float tensor, then per-channel standardization.
_pipeline_steps = [
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
]
transform = transforms.Compose(_pipeline_steps)

## 4. Preprocess and Save SimCLR Images and Tensors

In [None]:
os.makedirs(PROCESSED_SIMCLR_IMG_DIR, exist_ok=True)

# Extensions are matched case-insensitively: a case-sensitive '*.jpg' glob
# silently skips files such as 'IMG_001.JPG' or 'photo.jpeg' on Linux.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

# Sorted so that row i of the saved tensor maps to the same file on every run
# (glob order depends on the filesystem).
image_paths = sorted(
    path
    for path in glob.glob(os.path.join(RAW_SIMCLR_DIR, '**', '*'), recursive=True)
    if os.path.isfile(path) and path.lower().endswith(IMAGE_EXTENSIONS)
)
if not image_paths:
    raise FileNotFoundError(f"No .jpg/.jpeg/.png images found under {RAW_SIMCLR_DIR}")

# Split the section-3 pipeline at its ToTensor step: the PIL-level part (the
# Resize) yields the image that is saved to disk, and the remaining steps
# (ToTensor + Normalize) turn that same image into the stored tensor. Running
# both halves in sequence is exactly what `transform(img)` does.
# Assumes `transform` is a Compose containing a ToTensor step.
_steps = transform.transforms
_cut = next(i for i, step in enumerate(_steps) if isinstance(step, transforms.ToTensor))
to_resized_image = transforms.Compose(_steps[:_cut])
to_tensor = transforms.Compose(_steps[_cut:])

simclr_tensor = None
for row, img_path in enumerate(tqdm(image_paths, desc='SimCLR Preprocessing')):
    # Image.open keeps the file open lazily; the context manager releases it.
    # convert() loads the pixels into a new image, so it outlives the handle.
    with Image.open(img_path) as raw:
        img = raw.convert('RGB')
    resized = to_resized_image(img)
    tensor = to_tensor(resized)
    # Preallocate the output once the per-image shape is known, instead of
    # accumulating a list and torch.stack-ing it (which briefly holds two full
    # copies of the dataset in memory).
    if simclr_tensor is None:
        simclr_tensor = torch.empty((len(image_paths), *tensor.shape), dtype=tensor.dtype)
    simclr_tensor[row] = tensor
    # Save the resized (not normalized) image, keeping the original file name and
    # format; normalization lives only in the tensor. The source sub-directory
    # layout is mirrored so same-named files in different folders don't collide.
    out_path = os.path.join(PROCESSED_SIMCLR_IMG_DIR, os.path.relpath(img_path, RAW_SIMCLR_DIR))
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    resized.save(out_path)

torch.save(simclr_tensor, PROCESSED_SIMCLR_TENSOR_PATH)
print(f"SimCLR processed tensor shape: {simclr_tensor.shape}")

## 5. Preprocess and Save Supervised Images and Tensors (No Splits)

In [None]:
os.makedirs(PROCESSED_SUPERVISED_IMG_DIR, exist_ok=True)

# Extensions are matched case-insensitively so '.JPG' / '.jpeg' files are not
# silently skipped by a case-sensitive glob.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

# Expect RAW_SUPERVISED_DIR/<class_name>/<image>.{jpg,jpeg,png}.
# Collect (path, class_name) pairs up front, in sorted order, so that the row
# order of the saved tensor is reproducible and the total count is known before
# any tensor memory is allocated.
samples = []
for class_name in sorted(os.listdir(RAW_SUPERVISED_DIR)):
    class_dir = os.path.join(RAW_SUPERVISED_DIR, class_name)
    if not os.path.isdir(class_dir):
        continue
    # glob('*') skips dot-files (e.g. macOS '._x.jpg' resource forks), like the
    # original extension globs did.
    for img_path in sorted(glob.glob(os.path.join(class_dir, '*'))):
        if os.path.isfile(img_path) and img_path.lower().endswith(IMAGE_EXTENSIONS):
            samples.append((img_path, class_name))

if not samples:
    raise FileNotFoundError(f"No .jpg/.jpeg/.png images found in class folders under {RAW_SUPERVISED_DIR}")

# Convert labels to indices: classes are numbered in sorted-name order, and
# only classes that contributed at least one image are included.
class_names = sorted({name for _, name in samples})
label_to_idx = {name: idx for idx, name in enumerate(class_names)}
labels_tensor = torch.tensor([label_to_idx[name] for _, name in samples])

# Split the section-3 pipeline at its ToTensor step: the PIL-level part (the
# Resize) yields the image saved to disk, the remaining steps (ToTensor +
# Normalize) turn that same image into the stored tensor. Running both halves in
# sequence is exactly what `transform(img)` does.
# Assumes `transform` is a Compose containing a ToTensor step.
_steps = transform.transforms
_cut = next(i for i, step in enumerate(_steps) if isinstance(step, transforms.ToTensor))
to_resized_image = transforms.Compose(_steps[:_cut])
to_tensor = transforms.Compose(_steps[_cut:])

supervised_tensor = None
for row, (img_path, class_name) in enumerate(tqdm(samples, desc='Supervised Preprocessing')):
    # Close the lazily-opened file; convert() materializes a new image first.
    with Image.open(img_path) as raw:
        img = raw.convert('RGB')
    resized = to_resized_image(img)
    tensor = to_tensor(resized)
    # Preallocate instead of list + torch.stack to avoid holding two full
    # copies of the dataset in memory at once.
    if supervised_tensor is None:
        supervised_tensor = torch.empty((len(samples), *tensor.shape), dtype=tensor.dtype)
    supervised_tensor[row] = tensor
    # Save the resized (not normalized) image under its class folder.
    out_dir = os.path.join(PROCESSED_SUPERVISED_IMG_DIR, class_name)
    os.makedirs(out_dir, exist_ok=True)
    resized.save(os.path.join(out_dir, os.path.basename(img_path)))

torch.save({'images': supervised_tensor, 'labels': labels_tensor, 'class_names': class_names}, PROCESSED_SUPERVISED_TENSOR_PATH)
print(f"Supervised processed tensor shape: {supervised_tensor.shape}, labels shape: {labels_tensor.shape}")

## 6. Example: Loading Preprocessed Tensors in Training Notebook

In [None]:
# SimCLR: the file holds one stacked tensor of normalized images.
simclr_tensor = torch.load(PROCESSED_SIMCLR_TENSOR_PATH)
print(simclr_tensor.shape)

# Supervised: the file holds a dict bundling the image tensor, the integer
# label tensor and the list of class names (label i -> class_names[i]).
supervised_data = torch.load(PROCESSED_SUPERVISED_TENSOR_PATH)
images, labels, class_names = (
    supervised_data[key] for key in ('images', 'labels', 'class_names')
)
print(images.shape, labels.shape, class_names)