In [None]:
# Cell 1: Install Dependencies (libraries are imported in Cell 2)
import sys
import subprocess

# Install required packages
def install_package(package):
    """Install ``package`` (a pip requirement string) into this kernel's environment.

    Runs ``sys.executable -m pip install`` so the package lands in the same
    interpreter that executes the notebook.

    Raises:
        subprocess.CalledProcessError: if pip exits with a non-zero status.
    """
    subprocess.check_call([sys.executable, "-m", "pip", "install", package])

# scikit-learn provides ``train_test_split`` (imported from sklearn in Cell 2);
# without it a fresh environment fails at the import cell.
packages = [
    "tensorflow>=2.8.0",
    "opencv-python",
    "matplotlib",
    "tqdm",
    "Pillow",
    "numpy",
    "scikit-image",
    "scikit-learn",
]

for package in packages:
    try:
        install_package(package)
    except (subprocess.CalledProcessError, OSError) as error:
        # Best-effort: report and continue so one failing package does not block
        # the rest. A narrow except keeps Ctrl-C / kernel interrupts working.
        print(f"Warning: Could not install {package} ({error})")

print("Installation complete!")

In [None]:
# Cell 2: Import Libraries
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt
import cv2
import os
import time
from tqdm import tqdm
import random
from PIL import Image
from PIL import ImageFile
import glob
from sklearn.model_selection import train_test_split

# Large slide images exceed PIL's default decompression-bomb pixel limit;
# setting the limit to None disables that check entirely.
Image.MAX_IMAGE_PIXELS = None
# Let PIL load truncated files instead of raising.
# NOTE(review): nothing below reports which files were truncated — confirm
# that silently accepting partial images is acceptable for training data.
ImageFile.LOAD_TRUNCATED_IMAGES = True

for label, value in (("TensorFlow version:", tf.__version__),
                     ("GPU Available:", tf.config.list_physical_devices('GPU'))):
    print(label, value)

In [None]:
# Cell 3: Configure GPU and Memory Settings
def _enable_memory_growth(physical_gpus):
    """Turn on on-demand memory allocation for every GPU in ``physical_gpus``.

    TensorFlow raises RuntimeError if memory growth is changed after the GPUs
    have been initialised; that case is reported rather than raised.
    """
    try:
        for device in physical_gpus:
            tf.config.experimental.set_memory_growth(device, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(f"{len(physical_gpus)} Physical GPUs, {len(logical_gpus)} Logical GPUs")
    except RuntimeError as e:
        print(f"GPU configuration error: {e}")

gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    _enable_memory_growth(gpus)

In [None]:
# Cell 4: Define Configuration and Constants
class Config:
    """Hyper-parameters and paths shared by the cells of this notebook.

    Everything is a class attribute; ``config = Config()`` just provides a short
    handle used throughout.
    """

    # Data paths. DATA_DIR keeps the original absolute default but can be
    # overridden with the PIX2PIXHD_DATA_DIR environment variable, so the
    # notebook runs on another machine without editing this cell.
    DATA_DIR = os.environ.get("PIX2PIXHD_DATA_DIR", "/RegWSI_pass1")
    OUTPUT_DIR = "./pix2pixhd_output"
    CHECKPOINT_DIR = "./pix2pixhd_checkpoints"
    SAMPLE_DIR = "./pix2pixhd_samples"

    # Training parameters
    BATCH_SIZE = 1  # Pix2PixHD typically uses batch size 1
    PATCH_SIZE = 512  # Larger patches for better quality
    STRIDE = 256  # 50% overlap for 512-px patches
    # Patches are fed to the networks without resizing, so the model input size
    # must equal PATCH_SIZE; derive it instead of repeating the number.
    IMG_HEIGHT = PATCH_SIZE
    IMG_WIDTH = PATCH_SIZE
    EPOCHS = 200

    # Model parameters
    LAMBDA_L1 = 100  # Weight of the pixel-wise L1 term
    LAMBDA_FEAT = 10  # Feature matching loss weight
    N_LAYERS_D = 3  # Discriminator layers
    NUM_D = 2  # Number of discriminators (multi-scale)

    # Optimization (Adam)
    LEARNING_RATE = 0.0002
    BETA1 = 0.5
    BETA2 = 0.999

    # Checkpointing, in epochs
    CHECKPOINT_INTERVAL = 10
    SAMPLE_INTERVAL = 5

config = Config()

# Make sure every output location exists before anything is written to it.
for _output_dir in (config.OUTPUT_DIR, config.CHECKPOINT_DIR, config.SAMPLE_DIR):
    os.makedirs(_output_dir, exist_ok=True)

In [None]:
# Cell 5: Data Loading and Preprocessing Functions
def load_tiff_image(path):
    """Load an image file (TIFF here) with PIL and return it as an RGB numpy array.

    Args:
        path: Path to the image file.

    Returns:
        The ``np.array`` of ``Image.convert("RGB")``, or ``None`` if opening or
        decoding fails (the error is printed, not raised).
    """
    try:
        # The context manager closes the file handle. PIL keeps it open after
        # loading for formats that may hold several frames (such as TIFF), so
        # without it every loaded slide leaked a descriptor.
        with Image.open(path) as img:
            return np.array(img.convert("RGB"))
    except Exception as e:
        print(f"Error loading {path}: {e}")
        return None

def load_image_pair(pair_dir):
    """Read ``warped_source.tiff`` and ``target.tiff`` from one pair directory.

    If the two images differ in shape, the target is resized to the source's
    height and width with ``cv2.resize``.

    Returns:
        ``(source, target)`` arrays, or ``(None, None)`` when either file is
        missing or fails to load.
    """
    source_path, target_path = (
        os.path.join(pair_dir, file_name)
        for file_name in ("warped_source.tiff", "target.tiff")
    )

    if not (os.path.exists(source_path) and os.path.exists(target_path)):
        print(f"Missing files in {pair_dir}")
        return None, None

    source = load_tiff_image(source_path)
    target = load_tiff_image(target_path)
    if any(image is None for image in (source, target)):
        return None, None

    if source.shape != target.shape:
        print(f"Resizing images in {pair_dir}")
        height, width = source.shape[0], source.shape[1]
        target = cv2.resize(target, (width, height))  # cv2 takes (width, height)

    return source, target

def create_patches(image, patch_size, stride):
    """Cut ``image`` into square ``patch_size`` tiles on a ``stride`` grid.

    Tiles are taken row-major (top-to-bottom, then left-to-right) and are
    slices of ``image``, not copies. Only tiles that fit fully inside the image
    are produced, so any right/bottom remainder smaller than a stride is not
    covered. Returns an empty list when the image is smaller than a patch.
    """
    height, width = image.shape[:2]
    candidates = (
        image[top:top + patch_size, left:left + patch_size]
        for top in range(0, height - patch_size + 1, stride)
        for left in range(0, width - patch_size + 1, stride)
    )
    return [tile for tile in candidates if tile.shape[:2] == (patch_size, patch_size)]

def load_all_data(data_dir=None, pair_ids=range(1, 21)):
    """Load every registered source/target pair and cut both into aligned patches.

    Args:
        data_dir: Root directory containing ``Pair<i>`` sub-directories.
            Defaults to ``config.DATA_DIR``.
        pair_ids: Pair numbers to load; defaults to Pair1 .. Pair20.

    Returns:
        ``(source_patches, target_patches)`` as numpy arrays built with
        ``np.array`` over the patch lists; index ``k`` of both refers to the
        same tile position, since both images share a shape and the same grid.

    Raises:
        RuntimeError: if no patch could be extracted from any pair. Previously
            empty arrays were returned, which failed later with unrelated errors.
    """
    if data_dir is None:
        data_dir = config.DATA_DIR

    print("Loading data from all pairs...")

    all_source_patches = []
    all_target_patches = []

    pair_dirs = [os.path.join(data_dir, f"Pair{i}") for i in pair_ids]

    for pair_dir in tqdm(pair_dirs, desc="Loading pairs"):
        if not os.path.exists(pair_dir):
            print(f"Warning: {pair_dir} does not exist")
            continue

        source, target = load_image_pair(pair_dir)
        if source is None or target is None:
            continue

        print(f"Processing {pair_dir}: {source.shape}")

        # load_image_pair resizes target to source's size, so both grids match.
        source_patches = create_patches(source, config.PATCH_SIZE, config.STRIDE)
        target_patches = create_patches(target, config.PATCH_SIZE, config.STRIDE)

        print(f"Created {len(source_patches)} patches from {pair_dir}")

        all_source_patches.extend(source_patches)
        all_target_patches.extend(target_patches)

    print(f"Total patches: {len(all_source_patches)}")
    if not all_source_patches:
        raise RuntimeError(
            f"No patches were extracted from {data_dir}; check the data directory "
            f"and that images are at least {config.PATCH_SIZE}px on each side."
        )
    return np.array(all_source_patches), np.array(all_target_patches)

In [None]:
# Cell 6: Load and Prepare Dataset
print("Loading dataset...")
source_patches, target_patches = load_all_data()

print(f"Source patches shape: {source_patches.shape}")
print(f"Target patches shape: {target_patches.shape}")

# Preview four randomly chosen aligned pairs: source (HE) on the top row,
# target (Cd8) directly below it.
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
n_patches = len(source_patches)
for column in range(4):
    idx = random.randint(0, n_patches - 1)
    panels = (
        (axes[0, column], source_patches[idx], f"HE Patch {idx}"),
        (axes[1, column], target_patches[idx], f"Cd8 Patch {idx}"),
    )
    for ax, patch, title in panels:
        ax.imshow(patch)
        ax.set_title(title)
        ax.axis('off')

fig.tight_layout()
fig.savefig(os.path.join(config.SAMPLE_DIR, "sample_patches.png"), dpi=150)
plt.show()

In [None]:
# Cell 7: Data Preprocessing and Dataset Creation
def normalize_image(image):
    """Cast to float32 and rescale via x / 127.5 - 1.

    Assumes 8-bit pixel values, so [0, 255] maps onto [-1, 1] (the generator's
    tanh range).
    """
    as_float = tf.cast(image, tf.float32)
    return as_float / 127.5 - 1.0

def preprocess_image_pair(source, target):
    """Apply ``normalize_image`` to both halves of a (source, target) pair."""
    return normalize_image(source), normalize_image(target)

# Create train/validation split (random_state fixes which patches are held out).
# NOTE(review): patches overlap (STRIDE < PATCH_SIZE) and the split is per patch,
# so neighbouring tiles of the same slide land in both sets; validation results
# are therefore optimistic. Splitting by pair would avoid that leakage.
train_source, val_source, train_target, val_target = train_test_split(
    source_patches, target_patches, test_size=0.2, random_state=42
)

print(f"Train samples: {len(train_source)}")
print(f"Validation samples: {len(val_source)}")

# Create TensorFlow datasets
AUTOTUNE = tf.data.AUTOTUNE

# Shuffle BEFORE normalising. With uint8 patches (as produced by PIL RGB loading)
# a buffered (source, target) pair of 512x512x3 tiles is ~1.5 MiB, versus ~6 MiB
# once cast to float32. A 1000-element buffer is thus ~1.5 GiB instead of
# ~6 GiB. The map is element-wise and deterministic, so only memory changes.
train_dataset = tf.data.Dataset.from_tensor_slices((train_source, train_target))
train_dataset = train_dataset.shuffle(1000)
train_dataset = train_dataset.map(preprocess_image_pair, num_parallel_calls=AUTOTUNE)
train_dataset = train_dataset.batch(config.BATCH_SIZE)
train_dataset = train_dataset.prefetch(AUTOTUNE)

val_dataset = tf.data.Dataset.from_tensor_slices((val_source, val_target))
val_dataset = val_dataset.map(preprocess_image_pair, num_parallel_calls=AUTOTUNE)
val_dataset = val_dataset.batch(config.BATCH_SIZE)
val_dataset = val_dataset.prefetch(AUTOTUNE)

In [None]:
# Cell 8: Pix2PixHD Generator Architecture
def residual_block(x, filters, kernel_size=3):
    """Two Conv-BatchNorm stages (ReLU between them), an identity skip, then ReLU.

    The skip connection is a plain ``Add``, so ``x`` must already have
    ``filters`` channels.
    NOTE(review): the reference pix2pixHD ResnetBlock uses instance norm and no
    activation after the addition; this variant differs on both points.
    """
    residual = x
    for stage in range(2):
        residual = layers.Conv2D(filters, kernel_size, padding='same')(residual)
        residual = layers.BatchNormalization()(residual)
        if stage == 0:
            residual = layers.ReLU()(residual)
    merged = layers.Add()([residual, x])
    return layers.ReLU()(merged)

def global_generator():
    """Build the encoder / residual / decoder generator (pix2pixHD "global" net).

    A 7x7 stem is followed by three stride-2 convolutions (64 -> 128 -> 256 ->
    512 channels, 8x downsampling) and nine 512-channel residual blocks. Three
    stride-2 transposed convolutions restore resolution, and a 7x7 tanh
    convolution emits 3 channels in [-1, 1].

    Returns:
        ``keras.Model`` named ``global_generator`` mapping a
        ``(config.IMG_HEIGHT, config.IMG_WIDTH, 3)`` image to one of the same size.
    """
    def conv_bn_relu(tensor, conv_layer):
        # Apply the given conv layer, then BatchNorm and ReLU.
        tensor = conv_layer(tensor)
        tensor = layers.BatchNormalization()(tensor)
        return layers.ReLU()(tensor)

    inputs = layers.Input(shape=[config.IMG_HEIGHT, config.IMG_WIDTH, 3])

    # Stem: wide 7x7 receptive field at full resolution.
    x = conv_bn_relu(inputs, layers.Conv2D(64, 7, padding='same'))

    # Encoder
    for filters in (128, 256, 512):
        x = conv_bn_relu(x, layers.Conv2D(filters, 3, strides=2, padding='same'))

    # Bottleneck
    for _ in range(9):
        x = residual_block(x, 512)

    # Decoder
    for filters in (256, 128, 64):
        x = conv_bn_relu(x, layers.Conv2DTranspose(filters, 3, strides=2, padding='same'))

    outputs = layers.Conv2D(3, 7, padding='same', activation='tanh')(x)
    return keras.Model(inputs=inputs, outputs=outputs, name='global_generator')

In [None]:
# Cell 9: Multi-Scale Discriminator
def discriminator_block(x, filters, stride=2, normalization=True):
    """4x4 Conv -> optional BatchNorm -> LeakyReLU(0.2) stage of the PatchGAN trunk."""
    x = layers.Conv2D(filters, 4, strides=stride, padding='same')(x)
    if normalization:
        x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(0.2)(x)
    return x

def create_discriminator(name, scale=0, n_layers=None):
    """Create one conditional PatchGAN discriminator.

    Args:
        name: Keras model name.
        scale: Number of 2x average-pool downsamplings applied to the
            concatenated inputs before the convolutional trunk. ``0`` keeps full
            resolution (the previous behaviour of every discriminator).
        n_layers: Number of stride-2 convolution stages. Defaults to
            ``config.N_LAYERS_D``; a value of 3 reproduces the original trunk
            64 -> 128 -> 256 (stride 2) -> 512 (stride 1).

    Returns:
        ``keras.Model`` taking ``[input_img, target_img]``, each of shape
        ``(config.IMG_HEIGHT, config.IMG_WIDTH, 3)``. The output is a one-channel
        score map with no activation, as expected by the least-squares losses.
    """
    if n_layers is None:
        n_layers = config.N_LAYERS_D

    input_img = layers.Input(shape=[config.IMG_HEIGHT, config.IMG_WIDTH, 3])
    target_img = layers.Input(shape=[config.IMG_HEIGHT, config.IMG_WIDTH, 3])

    x = layers.Concatenate()([input_img, target_img])

    # Multi-scale: each extra scale halves the resolution with a 3x3 / stride-2
    # average pool, as pix2pixHD does. Without this, every "multi-scale"
    # discriminator judged exactly the same full-resolution image.
    for _ in range(scale):
        x = layers.AveragePooling2D(pool_size=3, strides=2, padding='same')(x)

    filters = 64
    x = discriminator_block(x, filters, normalization=False)
    for _ in range(1, n_layers):
        filters = min(filters * 2, 512)
        x = discriminator_block(x, filters)
    filters = min(filters * 2, 512)
    x = discriminator_block(x, filters, stride=1)

    x = layers.Conv2D(1, 4, strides=1, padding='same')(x)

    return keras.Model(inputs=[input_img, target_img], outputs=x, name=name)

def multi_scale_discriminator():
    """Create ``config.NUM_D`` discriminators; the i-th sees inputs pooled ``i`` times (1/2**i resolution)."""
    return [create_discriminator(f'discriminator_{i}', scale=i) for i in range(config.NUM_D)]



In [None]:
# Cell 10: Loss Functions
def discriminator_loss(real_output, fake_output):
    """Least-squares (LSGAN) discriminator loss: real scores -> 1, fake scores -> 0, averaged."""
    loss_on_real = tf.reduce_mean(tf.square(real_output - 1))
    loss_on_fake = tf.reduce_mean(tf.square(fake_output))
    return 0.5 * (loss_on_real + loss_on_fake)

def generator_adversarial_loss(fake_output):
    """Least-squares generator loss: push the discriminator's fake scores towards 1."""
    return tf.reduce_mean(tf.square(fake_output - 1))

def feature_matching_loss(real_features, fake_features):
    """Sum over paired feature maps of their mean absolute difference (0 for empty lists)."""
    return sum(
        [tf.reduce_mean(tf.abs(real_feat - fake_feat))
         for real_feat, fake_feat in zip(real_features, fake_features)],
        0,
    )

def l1_loss(real_image, fake_image):
    """Mean absolute pixel error between target and generated image."""
    return tf.reduce_mean(tf.abs(real_image - fake_image))

def generator_loss(fake_disc_outputs, real_features, fake_features, real_image, fake_image):
    """Total generator objective and its components.

    total = adversarial + LAMBDA_FEAT * feature_matching + LAMBDA_L1 * L1.
    The adversarial and feature-matching terms are summed over discriminators;
    ``real_features`` / ``fake_features`` hold one feature list per discriminator.

    Returns:
        ``(total_loss, adv_loss, feat_loss, pixel_loss)``.
    """
    adv_loss = sum([generator_adversarial_loss(output) for output in fake_disc_outputs], 0)
    feat_loss = sum(
        [feature_matching_loss(real_feat, fake_feat)
         for real_feat, fake_feat in zip(real_features, fake_features)],
        0,
    )
    pixel_loss = l1_loss(real_image, fake_image)

    weighted_total = adv_loss + config.LAMBDA_FEAT * feat_loss + config.LAMBDA_L1 * pixel_loss
    return weighted_total, adv_loss, feat_loss, pixel_loss

In [None]:
# Cell 11: Create Models and Optimizers
print("Creating models...")

generator = global_generator()
print("Generator created")

discriminators = multi_scale_discriminator()
print(f"Created {len(discriminators)} discriminators")

def _make_adam():
    """Adam optimizer configured from the shared Config hyper-parameters."""
    return keras.optimizers.Adam(config.LEARNING_RATE, beta_1=config.BETA1, beta_2=config.BETA2)

# One optimizer for the generator and a separate one per discriminator, so each
# model keeps its own moment estimates.
gen_optimizer = _make_adam()
disc_optimizers = [_make_adam() for _ in range(config.NUM_D)]

print("Models and optimizers created")

for header, model in (("\nGenerator Summary:", generator),
                      (f"\nDiscriminator Summary (showing first discriminator):", discriminators[0])):
    print(header)
    model.summary()


In [None]:
# Cell 12: Training Step Functions
# Compiled step functions keyed by (id(discriminator), id(optimizer)). The value
# also holds the two objects so their ids cannot be reused while cached.
_DISC_STEP_CACHE = {}

def _make_discriminator_step(discriminator, optimizer):
    """Build a dedicated ``tf.function`` updating ``discriminator`` with ``optimizer``."""
    @tf.function
    def step(real_image, input_image, fake_image):
        with tf.GradientTape() as tape:
            real_output = discriminator([input_image, real_image], training=True)
            fake_output = discriminator([input_image, fake_image], training=True)
            disc_loss = discriminator_loss(real_output, fake_output)

        gradients = tape.gradient(disc_loss, discriminator.trainable_variables)
        optimizer.apply_gradients(zip(gradients, discriminator.trainable_variables))
        return disc_loss

    return step

def train_discriminator_step(real_image, input_image, fake_image, discriminator, optimizer):
    """One update of ``discriminator`` on a (real, fake) pair; returns its LSGAN loss.

    A ``tf.function`` may create variables only on its first call. Each optimizer
    creates its slot variables lazily in its first ``apply_gradients``. Sharing
    one compiled function across several (discriminator, optimizer) pairs
    therefore raises ``ValueError`` when the second optimizer is first used, so
    each pair gets its own compiled step instead.
    """
    key = (id(discriminator), id(optimizer))
    entry = _DISC_STEP_CACHE.get(key)
    if entry is None:
        entry = (discriminator, optimizer, _make_discriminator_step(discriminator, optimizer))
        _DISC_STEP_CACHE[key] = entry
    return entry[2](real_image, input_image, fake_image)

@tf.function
def train_generator_step(real_image, input_image):
    """One generator update against the current discriminators.

    Returns:
        ``(gen_loss, adv_loss, feat_loss, pixel_loss, fake_image)``. ``fake_image``
        is the training-mode generator output from this step's forward pass,
        computed before the weight update.

    NOTE(review): the discriminators return only their final score map, so the
    per-discriminator feature lists passed on are empty. ``feat_loss`` is
    therefore always 0, and ``config.LAMBDA_FEAT`` has no effect yet.
    """
    with tf.GradientTape() as tape:
        fake_image = generator(input_image, training=True)

        fake_disc_outputs = [disc([input_image, fake_image], training=True) for disc in discriminators]
        # Placeholders: real feature matching needs intermediate discriminator activations.
        real_features = [[] for _ in discriminators]
        fake_features = [[] for _ in discriminators]

        gen_loss, adv_loss, feat_loss, pixel_loss = generator_loss(
            fake_disc_outputs, real_features, fake_features, real_image, fake_image)

    generator_vars = generator.trainable_variables
    gen_optimizer.apply_gradients(zip(tape.gradient(gen_loss, generator_vars), generator_vars))

    return gen_loss, adv_loss, feat_loss, pixel_loss, fake_image


In [None]:
# Cell 13: Training Utilities
def generate_and_save_samples(generator, test_input, test_target, epoch, save_dir):
    """Run ``generator`` in inference mode and save an input/target/prediction strip.

    Only the first item of the batch is plotted. Values are mapped with
    ``x * 0.5 + 0.5`` (undoing the [-1, 1] normalisation) for display. The figure
    is written to ``save_dir/epoch_XXXX.png`` and then closed.
    """
    predictions = generator(test_input, training=False)
    panels = (
        (test_input[0], 'Input (HE)'),
        (test_target[0], 'Target (Cd8)'),
        (predictions[0], 'Generated (Cd8)'),
    )

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (img, title) in zip(axes, panels):
        ax.imshow(img * 0.5 + 0.5)  # Denormalize
        ax.set_title(title)
        ax.axis('off')

    fig.tight_layout()
    fig.savefig(os.path.join(save_dir, f'epoch_{epoch:04d}.png'), dpi=150, bbox_inches='tight')
    plt.close(fig)

def save_model_checkpoint(generator, discriminators, epoch, save_dir):
    """Save all model weights under ``save_dir/checkpoint_epoch_<epoch>/``.

    Files are ``generator.weights.h5`` and ``discriminator_<i>.weights.h5``. The
    ``.weights.h5`` suffix is mandatory in Keras 3 (TensorFlow >= 2.16, which the
    unpinned ``tensorflow>=2.8.0`` install can pull in), where ``save_weights``
    rejects a plain ``.h5`` name. Older Keras still writes HDF5 because the name
    ends in ``.h5``.
    """
    checkpoint_path = os.path.join(save_dir, f'checkpoint_epoch_{epoch}')
    os.makedirs(checkpoint_path, exist_ok=True)

    generator.save_weights(os.path.join(checkpoint_path, 'generator.weights.h5'))

    for i, disc in enumerate(discriminators):
        disc.save_weights(os.path.join(checkpoint_path, f'discriminator_{i}.weights.h5'))

    print(f"Checkpoint saved at epoch {epoch}")

In [None]:
# Cell 14: Main Training Loop
def train_pix2pixhd(train_dataset, val_dataset, epochs):
    """Train the generator and all discriminators for ``epochs`` epochs.

    Each step follows pix2pixHD's ordering:
      1. one training-mode generator forward pass produces ``fake_image`` and
         the generator is updated against the current discriminators;
      2. every discriminator is then updated on that same ``fake_image``.

    Sample strips from the first validation batch are written every
    ``config.SAMPLE_INTERVAL`` epochs, and checkpoints every
    ``config.CHECKPOINT_INTERVAL`` epochs. A final checkpoint and the full
    generator model are saved at the end.
    """
    print("Starting training...")

    # A fixed validation batch keeps sample images comparable across epochs.
    test_input, test_target = next(iter(val_dataset))

    generate_and_save_samples(generator, test_input, test_target, 0, config.SAMPLE_DIR)

    for epoch in range(1, epochs + 1):
        start_time = time.time()

        gen_loss_avg = tf.keras.metrics.Mean()
        disc_loss_avg = tf.keras.metrics.Mean()

        for input_image, real_image in tqdm(train_dataset, desc=f"Epoch {epoch}"):
            # Generator step first; it returns its training-mode output. A
            # separate generator(..., training=False) pass for the
            # discriminators would make BatchNorm use moving statistics, so the
            # discriminators would judge different fakes from those the
            # generator is optimised on (and cost an extra forward pass).
            gen_loss, adv_loss, feat_loss, pixel_loss, fake_image = train_generator_step(
                real_image, input_image)

            total_disc_loss = 0
            for disc, optimizer in zip(discriminators, disc_optimizers):
                total_disc_loss += train_discriminator_step(
                    real_image, input_image, fake_image, disc, optimizer)

            gen_loss_avg.update_state(gen_loss)
            disc_loss_avg.update_state(total_disc_loss / len(discriminators))

        epoch_time = time.time() - start_time

        print(f"Epoch {epoch}/{epochs} - "
              f"Gen Loss: {gen_loss_avg.result():.4f}, "
              f"Disc Loss: {disc_loss_avg.result():.4f}, "
              f"Time: {epoch_time:.2f}s")

        if epoch % config.SAMPLE_INTERVAL == 0:
            generate_and_save_samples(generator, test_input, test_target, epoch, config.SAMPLE_DIR)

        if epoch % config.CHECKPOINT_INTERVAL == 0:
            save_model_checkpoint(generator, discriminators, epoch, config.CHECKPOINT_DIR)

    # Save final model
    save_model_checkpoint(generator, discriminators, epochs, config.CHECKPOINT_DIR)
    generator.save(os.path.join(config.OUTPUT_DIR, 'final_generator.h5'))

    print("Training completed!")

In [None]:
# Cell 15: Start Training
def _save_interrupted_generator():
    """Best-effort save of the current generator so partial training is not lost."""
    try:
        generator.save(os.path.join(config.OUTPUT_DIR, 'interrupted_generator.h5'))
        print("Saved interrupted model")
    except Exception as save_error:
        print(f"Could not save interrupted model: {save_error}")

if __name__ == "__main__":
    try:
        print("Starting Pix2PixHD training...")
        train_pix2pixhd(train_dataset, val_dataset, config.EPOCHS)
    except KeyboardInterrupt:
        # Stopping the kernel raises KeyboardInterrupt, which `except Exception`
        # does not catch, so the model must be saved explicitly here as well.
        print("Training interrupted by user")
        _save_interrupted_generator()
    except Exception as e:
        print(f"Training error: {e}")
        import traceback
        traceback.print_exc()
        _save_interrupted_generator()