[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Guidosalimbeni/aicaricaturist/blob/main/lora_img2img/utils.ipynb)

In [None]:
# Mount Google Drive at /content/drive so later cells can read files stored
# under /content/drive/MyDrive (the example-usage cell loads its image there).
# NOTE(review): google.colab is only importable inside a Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

_NORMALIZATION_METHODS = ("standard", "minmax", "meanstd", "raw")


def _normalize(tensor, method, eps=1e-8):
    """Return *tensor* normalized according to *method*.

    Degenerate inputs (a constant image: zero value range or zero std) map to
    all zeros instead of producing NaNs from a 0/0 division.

    Raises:
        ValueError: if *method* is not one of ``_NORMALIZATION_METHODS``.
    """
    if method == "standard":
        # ToTensor() yields values in [0, 1]; map them linearly onto [-1, 1].
        return 2 * tensor - 1
    if method == "minmax":
        min_val, max_val = tensor.min(), tensor.max()
        value_range = max_val - min_val
        if value_range <= eps:
            # Constant image: every pixel sits at the midpoint of [-1, 1].
            return torch.zeros_like(tensor)
        return 2 * (tensor - min_val) / value_range - 1
    if method == "meanstd":
        mean, std = tensor.mean(), tensor.std()
        if std <= eps:
            # Constant image: (tensor - mean) is already all zeros.
            return torch.zeros_like(tensor)
        return (tensor - mean) / std
    if method == "raw":
        return tensor
    raise ValueError(
        f"Unknown method {method!r}; expected one of {_NORMALIZATION_METHODS}"
    )


def _to_unit_range(tensor):
    """Linearly rescale *tensor* into [0, 1] for display (constant -> zeros)."""
    lo, hi = tensor.min(), tensor.max()
    if hi <= lo:
        return torch.zeros_like(tensor)
    return (tensor - lo) / (hi - lo)


def test_transformations(image_path, method="standard", size=(256, 256)):
    """
    Test different grayscale and normalization approaches.

    Loads *image_path* as grayscale, resizes it, applies the chosen
    normalization, repeats it to 3 channels, and plots the original, the
    normalized single channel, and the 3-channel version side by side.

    Methods:
    - "standard": Simple grayscale + [-1,1] normalization
    - "minmax": MinMax scaling to [-1,1]
    - "meanstd": Mean-std normalization
    - "raw": Just convert to grayscale, no normalization

    Args:
        image_path: path of the image file to load.
        method: one of the method names above.
        size: (height, width) passed to ``transforms.Resize``; defaults to
            (256, 256).

    Returns:
        (tensor, tensor_3ch): the normalized 1xHxW tensor and the same data
        repeated along the channel axis to 3xHxW.

    Raises:
        ValueError: if *method* is not one of the names above.

    Note:
        A new matplotlib figure is created but ``plt.show()`` is not called;
        the caller decides when to display or close it.
    """
    # Validate before touching the filesystem so typos fail fast.
    if method not in _NORMALIZATION_METHODS:
        raise ValueError(
            f"Unknown method {method!r}; expected one of {_NORMALIZATION_METHODS}"
        )

    # Open in a context manager so the file handle is released; convert()
    # loads the pixel data, so `img` stays valid after the file closes.
    with Image.open(image_path) as src:
        img = src.convert('L')

    # Base transform (just resize and convert to tensor)
    base_transform = transforms.Compose([
        transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.ToTensor(),
    ])
    tensor = _normalize(base_transform(img), method)

    # Repeat channels for model compatibility
    tensor_3ch = tensor.repeat(3, 1, 1)

    # Visualize results
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # Original grayscale (at its native resolution)
    axes[0].imshow(img, cmap='gray')
    axes[0].set_title('Original Grayscale')

    # Normalized (single channel). With a colormap, imshow autoscales to the
    # data range, so negative values are displayed correctly.
    axes[1].imshow(tensor.squeeze(0).numpy(), cmap='gray')
    axes[1].set_title(
        f'Normalized ({method})\n'
        f'Range: [{tensor.min().item():.2f}, {tensor.max().item():.2f}]'
    )

    # 3-channel repeated. imshow clips float RGB data to [0, 1], which would
    # paint every negative value black, so rescale for display only; the
    # returned tensors are not modified.
    axes[2].imshow(_to_unit_range(tensor_3ch).permute(1, 2, 0).numpy())
    axes[2].set_title('3-Channel Repeated')

    for ax in axes:
        ax.axis('off')

    fig.tight_layout()
    return tensor, tensor_3ch

# Example usage: render every normalization strategy for one sample image.
# NOTE(review): assumes Google Drive is mounted at /content/drive and that this
# project folder exists in MyDrive — adjust the path for other setups.
image_path = (
    "/content/drive/MyDrive/caricature Project Diffusion/"
    "paired_caricature/001_f.png"
)

# One figure per method; plt.show() displays each as soon as it is built.
for method in ("standard", "minmax", "meanstd", "raw"):
    header = f"\nTesting {method} method:"
    print(header)
    single_ch, three_ch = test_transformations(image_path, method=method)
    plt.show()