# Large-Scale Image Preprocessing for SimCLR and Supervised Datasets

This notebook preprocesses large image datasets by resizing, normalizing, and saving images and tensors for efficient loading in downstream experiments. It handles both SimCLR (unlabeled) and supervised (labeled) datasets, ensuring portability and memory efficiency.

## 1. Import Required Libraries

In [None]:
import os
import glob
from PIL import Image
import torch
from torchvision import transforms
from tqdm import tqdm
import numpy as np
import kagglehub
import torchvision.datasets as datasets
import torchvision.transforms.functional as F

## 2. Image Transformations (Resize & Normalize)

In [None]:
# Per-channel (R, G, B) statistics handed to the normalization step below.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing pipeline: resize to 256x256 -> convert to tensor -> normalize
# each channel with the statistics above.
transform = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
)

## 3. Retrieval and Transformation of the Self-Supervised Learning Dataset

In [None]:
# Per-channel (R, G, B) statistics used to normalize tensors and, when
# exporting PNGs below, to undo that normalization again.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
N = 10  # number of processed images previewed after this step

# Resize to 256x256, convert to a tensor, then normalize each channel.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
])

output_dir = os.path.join(os.getcwd(), "datasets")
pretrain_img_dir = os.path.join(output_dir, "pretrain")
pretrain_pt_path = os.path.join(output_dir, "pretrain.pt")
os.makedirs(pretrain_img_dir, exist_ok=True)

if os.path.exists(pretrain_pt_path):
    print("Pretrain dataset already exists. Loading...")
    # weights_only=False performs a full unpickle; only load a file that this
    # notebook produced itself.
    pretrain_dataset = torch.load(pretrain_pt_path, weights_only=False)
else:
    print("Pretrain dataset does not exist. Downloading and processing datasets...")

    ds1_path = kagglehub.dataset_download("soumikrakshit/anime-faces")
    ds2_path = kagglehub.dataset_download("stevenevan99/face-of-pixiv-top-daily-illustration-2020")
    ds3_path = kagglehub.dataset_download("hirunkulphimsiri/fullbody-anime-girls-datasets")

    dataset1 = datasets.ImageFolder(root=ds1_path, transform=transform)
    dataset2 = datasets.ImageFolder(root=ds2_path, transform=transform)
    dataset3 = datasets.ImageFolder(root=ds3_path, transform=transform)
    pretrain_dataset = torch.utils.data.ConcatDataset([dataset1, dataset2, dataset3])

    # Save processed images as PNG, numbered 1..len(pretrain_dataset) in
    # concatenation order (dataset1, then dataset2, then dataset3).
    # The tensors coming out of `transform` are normalized, so the
    # normalization is inverted and the result clamped to [0, 1] before the
    # PIL conversion; converting the normalized values directly would write
    # PNGs with corrupted colours.
    mean = torch.tensor(IMAGENET_MEAN).view(-1, 1, 1)
    std = torch.tensor(IMAGENET_STD).view(-1, 1, 1)
    # Index-based access is the documented Dataset protocol, hence range(len()).
    for idx in range(len(pretrain_dataset)):
        img_tensor, _ = pretrain_dataset[idx]
        img = F.to_pil_image((img_tensor * std + mean).clamp(0.0, 1.0))
        img.save(os.path.join(pretrain_img_dir, f"{idx + 1}.png"))

    # NOTE(review): this pickles the dataset object itself, which presumably
    # refers to the downloaded files by path rather than embedding pixel
    # data -- confirm the .pt file loads on another machine before relying on it.
    torch.save(pretrain_dataset, pretrain_pt_path)

# Preview the first `num_images` processed images (sorted by filename).
def visualize_processed_images(img_dir, num_images=5):
    """Show up to ``num_images`` images from ``img_dir`` side by side.

    Files are taken in lexicographic filename order (not at random), so
    numerically named files appear as ``1.png``, ``10.png``, ``100.png``, ...

    Args:
        img_dir: Directory containing the image files to preview.
        num_images: Maximum number of images to display. Values below 1
            display nothing.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    # Imported locally: this cell runs before the notebook's only
    # module-level matplotlib import, so `plt` would otherwise be undefined.
    import matplotlib.pyplot as plt

    img_files = sorted(os.listdir(img_dir))[:max(num_images, 0)]
    if not img_files:
        print(f"No images found in {img_dir}; nothing to visualize.")
        return

    # squeeze=False keeps `axes` two-dimensional even for a single image, so
    # the row is always iterable. One axis is created per file actually found,
    # which avoids empty frames when the directory holds fewer files.
    _, axes = plt.subplots(1, len(img_files), figsize=(15, 5), squeeze=False)
    for ax, img_file in zip(axes[0], img_files):
        # Context manager closes the underlying file handle once drawn.
        with Image.open(os.path.join(img_dir, img_file)) as img:
            ax.imshow(img)
        ax.axis('off')
        ax.set_title(img_file)
    plt.show()


visualize_processed_images(pretrain_img_dir, num_images=N)

## 4. Dataset retrieval and transformation for Supervised fine-tuning

In [None]:
# Mount Google Drive at /content/gdrive/ so that files stored on Drive can be
# read by path; force_remount=True requests a fresh mount on every run.
# NOTE(review): `google.colab` is presumably only importable inside a Colab
# runtime -- confirm before running this cell elsewhere.
from google.colab import drive # Mount Google Drive to access dataset
drive.mount('/content/gdrive/', force_remount=True)

In [None]:
import zipfile
import shutil
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.transforms.functional as F
from PIL import Image

# Per-channel (R, G, B) statistics used to normalize tensors and, when
# exporting PNGs below, to undo that normalization again.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
N = 10  # number of processed images previewed after this step

# Resize to 256x256, convert to a tensor, then normalize each channel.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
])

output_dir = os.path.join(os.getcwd(), "datasets")
finetune_img_dir = os.path.join(output_dir, "finetune")
finetune_pt_path = os.path.join(output_dir, "finetune.pt")
os.makedirs(finetune_img_dir, exist_ok=True)

dataset_path_compressed = "/content/gdrive/MyDrive/GenshinImageClassifier/dataset.zip"

if os.path.exists(finetune_pt_path):
    print("Finetune dataset already exists. Loading...")
    # weights_only=False performs a full unpickle; only load a file that this
    # notebook produced itself.
    finetune_dataset = torch.load(finetune_pt_path, weights_only=False)
else:
    # Fail immediately with a clear message instead of printing a warning and
    # then crashing inside shutil.copy on the very next step.
    if not os.path.exists(dataset_path_compressed):
        raise FileNotFoundError(
            f"Dataset file at {dataset_path_compressed} does not exist. Please download it or update the path."
        )

    # Working directory for the local copy of the archive and its contents.
    tmp_dir = os.path.join(os.getcwd(), "tmp")
    os.makedirs(tmp_dir, exist_ok=True)

    local_zip_path = os.path.join(tmp_dir, "dataset.zip")
    print(f"Copying dataset file to {tmp_dir} directory...")
    shutil.copy(dataset_path_compressed, local_zip_path)

    with zipfile.ZipFile(local_zip_path, 'r') as zip_ref:
        zip_ref.extractall(tmp_dir)

    # The archive is expected to unpack into a top-level "dataset" folder.
    dataset_path = os.path.join(tmp_dir, "dataset")
    if not os.path.isdir(dataset_path):
        raise FileNotFoundError(
            f"Dataset path {dataset_path} does not exist. Please check the extraction."
        )

    finetune_dataset = datasets.ImageFolder(root=dataset_path, transform=transform)

    # Save processed images as PNG, numbered 1..len(finetune_dataset).
    # The tensors coming out of `transform` are normalized, so the
    # normalization is inverted and the result clamped to [0, 1] before the
    # PIL conversion; converting the normalized values directly would write
    # PNGs with corrupted colours.
    mean = torch.tensor(IMAGENET_MEAN).view(-1, 1, 1)
    std = torch.tensor(IMAGENET_STD).view(-1, 1, 1)
    # Index-based access is the documented Dataset protocol, hence range(len()).
    for i in range(len(finetune_dataset)):
        img_tensor, _ = finetune_dataset[i]
        img = F.to_pil_image((img_tensor * std + mean).clamp(0.0, 1.0))
        img.save(os.path.join(finetune_img_dir, f"{i + 1}.png"))

    # NOTE(review): this pickles the dataset object itself, which presumably
    # refers to the extracted files by path rather than embedding pixel
    # data -- confirm the .pt file loads on another machine before relying on it.
    torch.save(finetune_dataset, finetune_pt_path)

def visualize_processed_images(img_dir, num_images=5):
    """Show up to ``num_images`` images from ``img_dir`` side by side.

    Files are taken in lexicographic filename order (not at random), so
    numerically named files appear as ``1.png``, ``10.png``, ``100.png``, ...

    Args:
        img_dir: Directory containing the image files to preview.
        num_images: Maximum number of images to display. Values below 1
            display nothing.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    img_files = sorted(os.listdir(img_dir))[:max(num_images, 0)]
    if not img_files:
        print(f"No images found in {img_dir}; nothing to visualize.")
        return

    # squeeze=False keeps `axes` two-dimensional even for a single image, so
    # the row is always iterable. One axis is created per file actually found,
    # which avoids empty frames when the directory holds fewer files.
    _, axes = plt.subplots(1, len(img_files), figsize=(15, 5), squeeze=False)
    for ax, img_file in zip(axes[0], img_files):
        # Context manager closes the underlying file handle once drawn.
        with Image.open(os.path.join(img_dir, img_file)) as img:
            ax.imshow(img)
        ax.axis('off')
        ax.set_title(img_file)
    plt.show()


visualize_processed_images(finetune_img_dir, num_images=N)

## 5. Zip and Save Preprocessed Images and Tensors

In [None]:
import zipfile
import os

# Source folder to archive and the destination archive, both under the
# current working directory.
datasets_dir = os.path.join(os.getcwd(), "datasets")
zip_path = os.path.join(os.getcwd(), "datasets.zip")

# Add every file below datasets_dir to a deflate-compressed archive; entry
# names are stored relative to datasets_dir.
with zipfile.ZipFile(zip_path, mode='w', compression=zipfile.ZIP_DEFLATED) as archive:
    for folder, _subfolders, filenames in os.walk(datasets_dir):
        for filename in filenames:
            source = os.path.join(folder, filename)
            archive.write(source, arcname=os.path.relpath(source, start=datasets_dir))
print(f"Zipped datasets folder to {zip_path}")