## Color Normalization

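Histopathology images can vary significantly in color because of differences in staining (protocols, reagent batches) and in slide scanners. Color normalization reduces this variation by mapping images onto a common color appearance, so that downstream analysis is less sensitive to how a slide was prepared and digitized.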

**Our Goals:**
1.  Understand the need for color normalization.
2.  Implement a simple color normalization technique (per-channel mean and standard-deviation matching).

In [None]:
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

# Build two synthetic "stained" tiles to simulate stain variation.
# Real workflows would load actual image tiles instead.
def create_stained_image(r_stain, g_stain, b_stain):
    """Return a 100x100 synthetic RGB tile tinted by per-channel factors.

    The tile is a light background (250, 250, 250) with a central 60x60
    "tissue" square (200, 150, 200). Each channel is multiplied by its
    stain factor, clipped to [0, 255] and converted to uint8 (truncating
    any fractional part).
    """
    canvas = np.full((100, 100, 3), 250, dtype=np.float32)
    canvas[20:80, 20:80] = (200, 150, 200)  # the "tissue" region

    # Tint each channel in place: R by r_stain, G by g_stain, B by b_stain.
    for channel, factor in enumerate((r_stain, g_stain, b_stain)):
        canvas[:, :, channel] *= factor

    return Image.fromarray(canvas.clip(0, 255).astype(np.uint8))

source_image = create_stained_image(0.9, 0.7, 0.85)
target_image = create_stained_image(1.0, 0.8, 0.7)

# Display the two synthetic tiles side by side for comparison.
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
for ax, img, title in zip(
    axes,
    (source_image, target_image),
    ('Source Image', 'Target Image (Reference)'),
):
    ax.imshow(img)
    ax.set_title(title)
plt.show()

### 1. Simple Normalization using Mean and Standard Deviation

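A common method is to rescale each color channel of the source image so that its mean and standard deviation match those of a target (reference) image. For each channel $c$ with source statistics $\mu_c^{\text{src}}, \sigma_c^{\text{src}}$ and target statistics $\mu_c^{\text{tgt}}, \sigma_c^{\text{tgt}}$:

$$x'_c = \frac{x_c - \mu_c^{\text{src}}}{\sigma_c^{\text{src}}}\,\sigma_c^{\text{tgt}} + \mu_c^{\text{tgt}}$$

This is a simplified variant of Reinhard et al.'s color transfer. Reinhard's method applies the same idea in the decorrelated $l\alpha\beta$ color space; here it is applied directly in RGB.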

In [None]:
def normalize_color(source, target):
    """Match each RGB channel of ``source`` to the mean and std of ``target``.

    Every channel ``c`` in (R, G, B) is linearly remapped as
    ``(x - mean_src) / std_src * std_tgt + mean_tgt``. Up to rounding and
    clipping, the result therefore has the target's per-channel mean and
    standard deviation. Statistics are computed in RGB space over all
    pixels of each image.

    Args:
        source: Image to normalize. It can be anything ``np.array`` turns into
            an ``H x W x C`` array with ``C >= 3`` (e.g. an RGB or RGBA PIL
            image).
        target: Reference image supplying the channel statistics. It has the
            same requirements as ``source``, but its height and width may
            differ.

    Returns:
        A ``PIL.Image.Image`` built from the normalized ``uint8`` array.
        Channels beyond the first three (e.g. alpha) are copied unchanged
        from ``source``.
    """
    source_arr = np.array(source, dtype=np.float32)
    target_arr = np.array(target, dtype=np.float32)

    # Start from a copy so that extra channels such as alpha pass through
    # untouched. Zero-filling them would make RGBA output fully transparent.
    normalized_arr = source_arr.copy()

    for i in range(3):  # For each channel (R, G, B)
        source_channel = source_arr[:, :, i]
        target_channel = target_arr[:, :, i]

        src_mean, src_std = source_channel.mean(), source_channel.std()
        tgt_mean, tgt_std = target_channel.mean(), target_channel.std()

        if src_std > 0:
            normalized_arr[:, :, i] = (
                (source_channel - src_mean) / src_std * tgt_std + tgt_mean
            )
        else:
            # A constant channel has no spread to rescale. Dividing by zero
            # would produce NaN, whose uint8 cast is undefined, so map the
            # whole channel to the target mean instead.
            normalized_arr[:, :, i] = tgt_mean

    # Round before casting: astype(uint8) truncates, so float32 results such
    # as 249.99998 would otherwise become 249 instead of 250.
    out = np.clip(np.rint(normalized_arr), 0, 255).astype(np.uint8)
    return Image.fromarray(out)

normalized_image = normalize_color(source_image, target_image)

# Display source, reference and normalized result in one row.
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
panels = (
    (source_image, 'Source'),
    (target_image, 'Target'),
    (normalized_image, 'Normalized'),
)
for ax, (img, title) in zip(axes, panels):
    ax.imshow(img)
    ax.set_title(title)
plt.show()

## ✅ Final Check

Let's compare the per-channel means. In every channel, the normalized image's mean should be closer to the target's mean than the source image's mean was.

In [None]:
# Per-channel means of each image, averaged over height and width.
source_means, target_means, normalized_means = (
    np.array(img).mean(axis=(0, 1))
    for img in (source_image, target_image, normalized_image)
)

for label, means in (
    ("Source", source_means),
    ("Target", target_means),
    ("Normalized", normalized_means),
):
    print(f"{label} means: {means}")

# Every channel of the normalized image must be strictly closer to the
# target than the corresponding source channel was.
source_gap = np.abs(source_means - target_means)
normalized_gap = np.abs(normalized_means - target_means)
assert (normalized_gap < source_gap).all()

print("\nSUCCESS: Normalized image stats are closer to the target.")