In [None]:
import os
import staintools
from PIL import Image
import numpy as np
import cv2 as cv
import torch, torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from skimage import io
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
from pathlib import Path

# Params

In [None]:
# Root folder of the raw patches; it is walked recursively by the normalization loop.
source_path = "../data/VPC/multiscale_patches_Train/"
# Root folder that receives the stain-normalized copies (same relative layout).
normalized_path = "../data/VPC/Normalized/"
# Image the stain normalizer is fitted on.
reference_img = "../data/Reference.jpg"
# Method name handed to staintools.StainNormalizer: 'vahadane' or 'macenko'.
method = "vahadane"

# Normalizing with staintools

In [None]:
# Remove CUDA_VISIBLE_DEVICES from this process's environment if it is set
# (presumably to undo a GPU mask inherited from the launcher -- confirm).
os.environ.pop("CUDA_VISIBLE_DEVICES", None)

# Report the torch build and what it can see.
print("torch:", torch.__version__)
print("compiled_with_cuda:", torch.version.cuda)
cuda_ready = torch.cuda.is_available()
print("cuda.is_available:", cuda_ready)
print("device_count:", torch.cuda.device_count())

# Pick the compute device: CUDA when available, otherwise CPU.
if not cuda_ready:
    device = torch.device("cpu")
    print("⚠ Running on CPU")
else:
    device = torch.device("cuda")
    try:
        gpu_name = torch.cuda.get_device_name(0)
    except Exception:
        # Best effort only: a failed name lookup must not stop the notebook.
        print("GPU visible but get_device_name failed (continuing).")
    else:
        print("Using GPU:", gpu_name)

In [None]:
# Compatibility shim for old spams builds: make sure `np.bool` resolves to
# NumPy's boolean scalar type regardless of the installed NumPy version.
setattr(np, "bool", np.bool_)

# simple tissue check: treat pixels with high luminosity as background
def has_tissue(img, min_frac=0.10, lum_thresh=0.8):
    """Decide whether an image contains enough non-background pixels.

    A pixel counts as tissue when its luminance, computed as
    0.299*R + 0.587*G + 0.114*B on channels scaled to [0, 1], is strictly
    below ``lum_thresh``; brighter pixels are treated as background.

    Args:
        img: uint8 RGB array with the colour channels on the last axis.
        min_frac: minimum fraction of tissue pixels required.
        lum_thresh: luminance cut-off in [0, 1].

    Returns:
        A NumPy boolean: True when the tissue fraction is at least ``min_frac``.
    """
    scaled = img.astype(np.float32) / 255.0
    red, green, blue = scaled[..., 0], scaled[..., 1], scaled[..., 2]
    luminance = 0.299 * red + 0.587 * green + 0.114 * blue
    # Mean of a boolean mask == fraction of pixels darker than the cut-off.
    tissue_fraction = np.mean(luminance < lum_thresh)
    return tissue_fraction >= min_frac

In [None]:
# Fit the stain normalizer once on the reference image; the fitted object is
# reused for every patch in the loop below.
target = staintools.read_image(reference_img)
normalizer = staintools.StainNormalizer(method=method)
normalizer.fit(target)

Path(normalized_path).mkdir(parents=True, exist_ok=True)

# File-name suffixes (compared in lower case) that are treated as images.
image_exts = ('.png', '.jpg', '.jpeg', '.tif', '.tiff')
n_saved = 0
n_skipped = 0

# Walk the source tree recursively and mirror its layout under normalized_path.
for root, _, files in os.walk(source_path):
    for img_file in files:
        if not img_file.lower().endswith(image_exts):
            continue

        src_path = os.path.join(root, img_file)
        rel_path = os.path.relpath(src_path, source_path)
        dst_path = os.path.join(normalized_path, rel_path)

        img = staintools.read_image(src_path)

        # Skip patches that the luminance check classifies as background-only.
        if not has_tissue(img, min_frac=0.10, lum_thresh=0.8):
            n_skipped += 1
            continue

        norm_img = normalizer.transform(img)

        # Create the destination folder only once there is something to write,
        # so folders whose patches were all skipped are not left behind empty.
        os.makedirs(os.path.dirname(dst_path), exist_ok=True)
        io.imsave(dst_path, norm_img)
        n_saved += 1

# Make the outcome visible: os.walk yields nothing for a missing source_path,
# which would otherwise be indistinguishable from a successful run.
print(f"normalized: {n_saved}, skipped (no tissue): {n_skipped}")