In [None]:
# Use shared utils for persistent data dir and color normalization samples
import sys, os
from pathlib import Path
try:
    from shared import utils as u
except ImportError:
    # `shared` is not importable yet: locate (or fetch) the repository root and
    # put it on sys.path before retrying the import.
    repo_url = "https://github.com/anand-indx/dp-t25.git"; dest = "/content/dp-t25"
    if 'google.colab' in sys.modules:
        # On Colab, clone once; on later runs (e.g. after a kernel restart) the
        # checkout already exists and only needs to be added to sys.path.
        if not os.path.exists(dest):
            import subprocess
            clone = subprocess.run(['git', 'clone', '--depth', '1', repo_url, dest], check=False)
            if clone.returncode != 0:
                # Surface the cause; the import below will still raise ImportError.
                print(f"⚠️ git clone failed with exit code {clone.returncode}: {repo_url}")
        repo_root = dest
    else:
        # Ensure repo root is on sys.path when running locally
        # (taken to be the third ancestor of the current working directory).
        repo_root = str(Path.cwd().parents[2])
    if repo_root not in sys.path:
        sys.path.insert(0, repo_root)
    from shared import utils as u

# Resolve the data directory via the shared helper and make sure the
# color-normalization sample files are present there.
DATA_DIR = u.get_data_dir()
color_data_dir, norm_params = u.ensure_color_normalization_samples(DATA_DIR)
print(f"📁 Color normalization samples located at: {color_data_dir}")

In [None]:
# Optional: Download sample tiles/WSI archives into DATA_DIR (set TILES_ZIP_URL or enable defaults)
import os
import requests
from pathlib import Path
from tqdm import tqdm

def download_file_with_progress(url: str, dst: Path, description: str = "Downloading") -> bool:
    """Stream ``url`` to ``dst`` while showing a tqdm progress bar.

    The payload is first written to a sibling ``<dst name>.part`` file and only
    renamed to ``dst`` after the whole stream has been received, so an
    interrupted download never leaves a truncated file at ``dst`` (callers use
    ``dst.exists()`` to decide whether a download is still needed).

    Args:
        url: HTTP(S) address of the file to fetch.
        dst: Final location of the downloaded file.
        description: Label shown next to the progress bar.

    Returns:
        True if the file was downloaded completely, False on any error (the
        error is printed, not raised).
    """
    dst = Path(dst)
    partial = dst.with_name(dst.name + ".part")
    try:
        with requests.get(url, stream=True, timeout=60) as r:
            r.raise_for_status()
            # A missing/zero content-length means the size is unknown; pass None
            # so tqdm shows a plain counter instead of a bar against total=0.
            total = int(r.headers.get('content-length', 0)) or None
            with open(partial, 'wb') as f, tqdm(total=total, unit='B', unit_scale=True, desc=description) as pbar:
                for chunk in r.iter_content(chunk_size=8192):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)
                        pbar.update(len(chunk))
        # Publish the file only once it is complete.
        partial.replace(dst)
        return True
    except Exception as e:
        # Remove any partial payload before reporting the failure.
        try:
            partial.unlink()
        except OSError:
            pass
        print(f"⚠️ Download failed: {url} → {dst}\n   {e}")
        return False

# Configuration: set a tiles ZIP URL to download (optional). You can also place any .zip file into DATA_DIR manually.
from urllib.parse import unquote, urlparse

TILES_ZIP_URL = os.environ.get('TILES_ZIP_URL', '').strip()  # e.g., https://example.com/tiles.zip

# Known demo WSI (small) from OpenSlide (for reference); contains a single SVS, not tiles.
DEMO_WSI_URL = 'https://openslide.cs.cmu.edu/download/openslide-testdata/Aperio/CMU-1-Small-Region.svs'

if TILES_ZIP_URL:
    # Name the local file after the URL *path* only: a query string or fragment
    # (e.g. a signed URL "tiles.zip?sig=...") would otherwise end up in the file
    # name and the archive would no longer match a "*.zip" lookup in DATA_DIR.
    # Fall back to a fixed name when the URL has no file component at all.
    archive_name = Path(unquote(urlparse(TILES_ZIP_URL).path)).name or 'tiles.zip'
    tiles_zip_path = DATA_DIR / archive_name
    if not tiles_zip_path.exists():
        ok = download_file_with_progress(TILES_ZIP_URL, tiles_zip_path, description=f"Downloading {tiles_zip_path.name}")
        if ok:
            print(f"✅ Downloaded tiles archive to {tiles_zip_path}")
    else:
        print(f"✅ Tiles archive already exists at {tiles_zip_path}")
else:
    print("ℹ️ No TILES_ZIP_URL provided; skipping tile archive download. Place a .zip in DATA_DIR to use real tiles.")

## Color Normalization

Histopathology images can have significant color variations due to differences in staining and scanning. Color normalization aims to reduce this variation.

**Our Goals:**
1.  Understand the need for color normalization.
2.  Implement a simple stain normalization technique.

In [None]:
# Optional: Use real tiles if available by unzipping archives in DATA_DIR and sampling images
import zipfile, random
from PIL import Image
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt

TILES_DIR = DATA_DIR / "tiles"
TILES_DIR.mkdir(parents=True, exist_ok=True)

# Unzip any .zip archives found in DATA_DIR into TILES_DIR
for z in sorted(DATA_DIR.glob("*.zip")):
    try:
        with zipfile.ZipFile(z, 'r') as zip_ref:
            zip_ref.extractall(TILES_DIR)
            print(f"✅ Extracted {z.name} -> {TILES_DIR}")
    except Exception as e:
        # Best effort: a broken archive must not stop the notebook.
        print(f"⚠️ Failed to extract {z.name}: {e}")

# Collect candidate image files. Hidden files and macOS archive metadata
# ("__MACOSX/", "._name.png") carry image extensions but are not images.
image_exts = {".png", ".jpg", ".jpeg", ".tif", ".tiff"}
candidates = sorted(
    p for p in TILES_DIR.rglob("*")
    if p.is_file()
    and p.suffix.lower() in image_exts
    and not p.name.startswith(".")
    and "__MACOSX" not in p.parts
)


def _load_rgb_tile(path):
    """Open an image file, return it as an RGB PIL image, and close the file."""
    # Image.open keeps the file handle open; convert() yields an independent
    # in-memory image, so the handle can be released right away.
    with Image.open(path) as img:
        return img.convert('RGB')


source_image = None
target_image = None

if len(candidates) >= 2:
    chosen = random.sample(candidates, 2)
    try:
        loaded = [_load_rgb_tile(p) for p in chosen]
    except Exception as e:
        print(f"⚠️ Failed to load chosen tiles: {e}")
    else:
        # Assign only when both tiles loaded, so the pair is never half-set.
        source_image, target_image = loaded
        print(f"🧪 Using real tiles: {chosen[0].name} (source), {chosen[1].name} (target)")

# If not enough tiles, the next cell will generate synthetic images

In [None]:
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

# Fall back to synthetic stained images when the previous step did not leave
# both a source and a target image behind.
if globals().get('source_image') is None or globals().get('target_image') is None:
    # Two differently "stained" copies of the same toy tissue simulate the
    # colour variation; real image tiles would be used in practice.
    def create_stained_image(r_stain, g_stain, b_stain):
        """Build a 100x100 RGB test image and tint it with per-channel factors."""
        # Near-white background with a square "tissue" patch in the middle
        canvas = np.full((100, 100, 3), 250, dtype=np.uint8)
        canvas[20:80, 20:80, :] = [200, 150, 200]

        # Multiply each colour channel by its stain factor
        tinted = canvas.astype(np.float32)
        for channel, factor in enumerate((r_stain, g_stain, b_stain)):
            tinted[:, :, channel] *= factor

        return Image.fromarray(np.clip(tinted, 0, 255).astype(np.uint8))

    source_image = create_stained_image(0.9, 0.7, 0.85)
    target_image = create_stained_image(1.0, 0.8, 0.7)

# Show source and reference side by side
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
for ax, picture, title in zip(
    axes,
    (source_image, target_image),
    ('Source Image', 'Target Image (Reference)'),
):
    ax.imshow(picture)
    ax.set_title(title)
    ax.axis('off')
plt.show()

### 1. Simple Normalization using Mean and Standard Deviation

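A common method is to shift and scale each color channel of the source image so that its mean and standard deviation match those of a target (reference) image. For each channel: $x' = (x - \mu_{src}) \cdot \sigma_{tgt} / \sigma_{src} + \mu_{tgt}$.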

In [None]:
def normalize_color(source, target):
    """Match the per-channel mean and standard deviation of ``source`` to ``target``.

    Each RGB channel of ``source`` is shifted and scaled as
    ``(x - src_mean) * (tgt_std / src_std) + tgt_mean``.

    Args:
        source: Image to normalize. Either an object with a ``mode`` attribute
            (treated as a PIL image) or an array accepted by ``Image.fromarray``.
        target: Reference image, same accepted types as ``source``.

    Returns:
        A new 8-bit RGB PIL image with the size of ``source``.
    """
    def _as_rgb_array(img):
        # Anything without a `mode` attribute is treated as an array.
        pil_img = img if hasattr(img, 'mode') else Image.fromarray(img)
        return np.array(pil_img.convert('RGB'), dtype=np.float64)

    source_arr = _as_rgb_array(source)
    target_arr = _as_rgb_array(target)
    eps = 1e-6

    # Per-channel statistics over both spatial axes (one value per R, G, B).
    src_mean = source_arr.mean(axis=(0, 1))
    src_std = source_arr.std(axis=(0, 1))
    tgt_mean = target_arr.mean(axis=(0, 1))
    tgt_std = target_arr.std(axis=(0, 1))

    # Guard against a flat source channel (std == 0).
    scale = tgt_std / np.maximum(src_std, eps)
    normalized_arr = (source_arr - src_mean) * scale + tgt_mean

    # Round to nearest before the uint8 cast: a bare astype() truncates toward
    # zero and would pull every channel mean about half a level below target.
    return Image.fromarray(np.clip(np.rint(normalized_arr), 0, 255).astype(np.uint8))

# Normalize the source towards the reference image
normalized_image = normalize_color(source_image, target_image)

# Show source, reference and result next to each other
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
for ax, picture, title in zip(
    axes,
    (source_image, target_image, normalized_image),
    ('Source', 'Target', 'Normalized'),
):
    ax.imshow(picture)
    ax.set_title(title)
    ax.axis('off')
plt.show()

## ✅ Final Check

Let's check the mean of the normalized image's channels. They should be closer to the target's channel means.

In [None]:
def channel_means(img):
    """Return the mean of each RGB channel of ``img`` as a length-3 array.

    ``img`` must provide ``convert('RGB')``; the result is averaged over the
    first two (spatial) axes in float32.
    """
    rgb_pixels = np.array(img.convert('RGB'), dtype=np.float32)
    return rgb_pixels.mean(axis=(0, 1))

source_means = channel_means(source_image)
target_means = channel_means(target_image)
normalized_means = channel_means(normalized_image)

print(f"Source means: {np.round(source_means, 2)}")
print(f"Target means: {np.round(target_means, 2)}")
print(f"Normalized means: {np.round(normalized_means, 2)}")

# Check if normalized means are closer to target means (tolerant for real data).
# The result is stored as 8-bit integers, so its channel means can sit a
# fraction of an intensity level away from the target. If the source already
# matched the target, a strict "closer" test would fail on that noise alone,
# so a result within one intensity level of the target is accepted as well.
quantization_tol = 1.0
dist_before = float(np.linalg.norm(source_means - target_means))
dist_after = float(np.linalg.norm(normalized_means - target_means))
closer = dist_after < dist_before or dist_after <= quantization_tol
assert closer, "Normalized image stats are not closer to target; check inputs or tile selection."

print("\nSUCCESS: Normalized image stats are closer to the target.")