# BTC Change Detection Inference Pipeline

This notebook demonstrates how to perform change detection inference using the BTC (Be The Change) model. 

The pipeline includes:
1. Loading and converting TIFF images to PNG
2. Preprocessing images with normalization
3. Loading the pre-trained BTC model checkpoint
4. Running inference to detect changes
5. Visualizing results at each step

In [1]:
# Import required libraries
import sys
import os
from pathlib import Path
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
import cv2
from torchvision import transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Add the BTC directory to Python path to import BTC modules
btc_path = Path(__file__).parent.parent / "BTC"
sys.path.append(str(btc_path))

# Import BTC-specific modules
from models.finetune_framework import FinetuneFramework
from torchmetrics import MetricCollection
from torchmetrics.classification import (
    BinaryF1Score,
    BinaryRecall,
    BinaryPrecision,
    BinaryJaccardIndex,
)

print("All libraries imported successfully!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

  from .autonotebook import tqdm as notebook_tqdm


NameError: name '__file__' is not defined

## Configuration

Set up the paths and parameters for the inference pipeline.

In [None]:
# Configuration parameters (from BTC-B.yaml)
CONFIG = {
    'img_size': 256,
    'normalize_mean': [0.485, 0.456, 0.406],  # ImageNet mean
    'normalize_std': [0.229, 0.224, 0.225],   # ImageNet std
    'model_checkpoint': 'blaz-r/BTC-B_oscd96',
    'device': 'cuda' if torch.cuda.is_available() else 'cpu'
}

# Input image paths - MODIFY THESE PATHS TO YOUR IMAGES
image_a_path = "path/to/your/image_a.tiff"  # Before image
image_b_path = "path/to/your/image_b.tiff"  # After image

# Output directory for converted PNG images
output_dir = Path("./converted_images")
output_dir.mkdir(exist_ok=True)

print("Configuration loaded:")
for key, value in CONFIG.items():
    print(f"  {key}: {value}")
print(f"\nOutput directory: {output_dir}")
print(f"Device: {CONFIG['device']}")

## Step 1: Convert TIFF Images to PNG

First, we'll load the TIFF images and convert them to PNG format for easier handling.

In [None]:
def convert_tiff_to_png(tiff_path, output_dir, filename):
    """
    Convert TIFF image to PNG format and resize to 256x256 if needed.
    
    Args:
        tiff_path: Path to input TIFF file
        output_dir: Directory to save PNG file
        filename: Output filename (without extension)
    
    Returns:
        Path to converted PNG file
    """
    try:
        # Load TIFF image
        img = Image.open(tiff_path)
        print(f"Original image size: {img.size}")
        print(f"Original image mode: {img.mode}")
        
        # Convert to RGB if necessary
        if img.mode != 'RGB':
            img = img.convert('RGB')
            print(f"Converted to RGB mode")
        
        # Resize to 256x256 if not already
        if img.size != (CONFIG['img_size'], CONFIG['img_size']):
            img = img.resize((CONFIG['img_size'], CONFIG['img_size']), Image.LANCZOS)
            print(f"Resized to {CONFIG['img_size']}x{CONFIG['img_size']}")
        
        # Save as PNG
        png_path = output_dir / f"{filename}.png"
        img.save(png_path, 'PNG')
        print(f"Saved PNG to: {png_path}")
        
        return png_path, np.array(img)
    
    except Exception as e:
        print(f"Error converting {tiff_path}: {e}")
        return None, None

# Convert both images
print("Converting Image A (before):")
png_a_path, img_a_array = convert_tiff_to_png(image_a_path, output_dir, "image_a")

print("\nConverting Image B (after):")
png_b_path, img_b_array = convert_tiff_to_png(image_b_path, output_dir, "image_b")

In [None]:
# Visualize the converted PNG images
if img_a_array is not None and img_b_array is not None:
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    
    axes[0].imshow(img_a_array)
    axes[0].set_title('Image A (Before) - PNG')
    axes[0].axis('off')
    
    axes[1].imshow(img_b_array)
    axes[1].set_title('Image B (After) - PNG')
    axes[1].axis('off')
    
    plt.tight_layout()
    plt.show()
    
    print(f"Image A shape: {img_a_array.shape}")
    print(f"Image B shape: {img_b_array.shape}")
    print(f"Image A pixel range: [{img_a_array.min()}, {img_a_array.max()}]")
    print(f"Image B pixel range: [{img_b_array.min()}, {img_b_array.max()}]")
else:
    print("Error: Could not load images. Please check the file paths.")

## Step 2: Image Preprocessing and Normalization

Now we'll apply the same preprocessing pipeline used by the BTC model:
1. Normalize pixel values to [0, 1] range
2. Apply ImageNet normalization (subtract mean, divide by std)
3. Convert to PyTorch tensors

In [None]:
def preprocess_image(image_array, config):
    """
    Preprocess image according to BTC model requirements.
    
    Args:
        image_array: numpy array of shape (H, W, 3) with values in [0, 255]
        config: configuration dictionary with normalization parameters
    
    Returns:
        preprocessed tensor of shape (1, 3, H, W)
    """
    # Step 1: Convert to float and normalize to [0, 1]
    image_float = image_array.astype(np.float32) / 255.0
    print(f"After [0,1] normalization - range: [{image_float.min():.3f}, {image_float.max():.3f}]")
    
    # Step 2: Apply Albumentations transforms (same as BTC training)
    transform = A.Compose([
        A.Normalize(
            mean=config['normalize_mean'], 
            std=config['normalize_std']
        ),
        ToTensorV2()
    ])
    
    # Apply transforms
    transformed = transform(image=image_float)
    tensor = transformed['image']
    
    print(f"After ImageNet normalization - range: [{tensor.min():.3f}, {tensor.max():.3f}]")
    print(f"Tensor shape: {tensor.shape}")
    
    # Add batch dimension
    tensor = tensor.unsqueeze(0)  # Shape: (1, 3, H, W)
    
    return tensor, image_float

# Preprocess both images
print("Preprocessing Image A:")
tensor_a, normalized_a = preprocess_image(img_a_array, CONFIG)

print("\nPreprocessing Image B:")
tensor_b, normalized_b = preprocess_image(img_b_array, CONFIG)

print(f"\nFinal tensor shapes:")
print(f"Tensor A: {tensor_a.shape}")
print(f"Tensor B: {tensor_b.shape}")

In [None]:
# Visualize preprocessing steps
fig, axes = plt.subplots(2, 3, figsize=(15, 10))

# Original images
axes[0, 0].imshow(img_a_array)
axes[0, 0].set_title('Original Image A')
axes[0, 0].axis('off')

axes[1, 0].imshow(img_b_array)
axes[1, 0].set_title('Original Image B')
axes[1, 0].axis('off')

# Normalized to [0,1]
axes[0, 1].imshow(normalized_a)
axes[0, 1].set_title('Image A - Normalized [0,1]')
axes[0, 1].axis('off')

axes[1, 1].imshow(normalized_b)
axes[1, 1].set_title('Image B - Normalized [0,1]')
axes[1, 1].axis('off')

# After ImageNet normalization (denormalized for visualization)
def denormalize_for_viz(tensor, mean, std):
    """Denormalize tensor for visualization"""
    tensor_copy = tensor.clone()
    for i, (m, s) in enumerate(zip(mean, std)):
        tensor_copy[i] = tensor_copy[i] * s + m
    return torch.clamp(tensor_copy, 0, 1)

viz_a = denormalize_for_viz(tensor_a[0], CONFIG['normalize_mean'], CONFIG['normalize_std'])
viz_b = denormalize_for_viz(tensor_b[0], CONFIG['normalize_mean'], CONFIG['normalize_std'])

axes[0, 2].imshow(viz_a.permute(1, 2, 0))
axes[0, 2].set_title('Image A - After ImageNet Norm\n(denormalized for viz)')
axes[0, 2].axis('off')

axes[1, 2].imshow(viz_b.permute(1, 2, 0))
axes[1, 2].set_title('Image B - After ImageNet Norm\n(denormalized for viz)')
axes[1, 2].axis('off')

plt.tight_layout()
plt.show()

## Step 3: Load Pre-trained BTC Model

Download and load the pre-trained BTC-B model checkpoint from HuggingFace.

In [None]:
# Load the pre-trained BTC model
print(f"Loading model checkpoint: {CONFIG['model_checkpoint']}")
print("This may take a few minutes for the first time...")

try:
    # Create metrics collection (required for model loading)
    metrics = MetricCollection({
        "F1": BinaryF1Score(),
        "Recall": BinaryRecall(),
        "Precision": BinaryPrecision(),
        "cIoU": BinaryJaccardIndex(),
    })
    
    # Load the model from HuggingFace
    model = FinetuneFramework.from_pretrained(
        CONFIG['model_checkpoint'],
        metrics=metrics,
        logger=None
    )
    
    # Move model to appropriate device
    model = model.to(CONFIG['device'])
    model.eval()  # Set to evaluation mode
    
    print(f"✓ Model loaded successfully!")
    print(f"✓ Model moved to device: {CONFIG['device']}")
    print(f"✓ Model set to evaluation mode")
    
    # Print model info
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    
except Exception as e:
    print(f"❌ Error loading model: {e}")
    print("Please check your internet connection and model checkpoint name.")

## Step 4: Run Change Detection Inference

Perform inference using the loaded model to detect changes between the two images.

In [None]:
# Prepare input batch for the model
def prepare_batch(tensor_a, tensor_b, device):
    """
    Prepare input batch for BTC model inference.
    
    Args:
        tensor_a: preprocessed tensor for image A
        tensor_b: preprocessed tensor for image B
        device: target device
    
    Returns:
        batch dictionary ready for model input
    """
    batch = {
        'imageA': tensor_a.to(device),
        'imageB': tensor_b.to(device)
    }
    return batch

# Run inference
print("Running change detection inference...")

try:
    with torch.no_grad():  # Disable gradient computation for inference
        # Prepare input batch
        batch = prepare_batch(tensor_a, tensor_b, CONFIG['device'])
        
        print(f"Input shapes:")
        print(f"  Image A: {batch['imageA'].shape}")
        print(f"  Image B: {batch['imageB'].shape}")
        
        # Forward pass through the model
        start_time = torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None
        end_time = torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None
        
        if torch.cuda.is_available():
            start_time.record()
        
        # Model inference
        output = model(batch)
        
        if torch.cuda.is_available():
            end_time.record()
            torch.cuda.synchronize()
            inference_time = start_time.elapsed_time(end_time)
            print(f"Inference time: {inference_time:.2f} ms")
        
        print(f"✓ Inference completed successfully!")
        print(f"Output shape: {output.shape}")
        print(f"Output range: [{output.min():.4f}, {output.max():.4f}]")
        
        # Apply sigmoid to get probabilities
        probabilities = torch.sigmoid(output)
        print(f"Probability range: [{probabilities.min():.4f}, {probabilities.max():.4f}]")
        
        # Create binary mask (threshold at 0.5)
        binary_mask = (probabilities > 0.5).float()
        
        # Move results to CPU for visualization
        output_cpu = output.cpu().squeeze()
        prob_cpu = probabilities.cpu().squeeze()
        mask_cpu = binary_mask.cpu().squeeze()
        
        print(f"Results moved to CPU for visualization")
        
except Exception as e:
    print(f"❌ Error during inference: {e}")
    import traceback
    traceback.print_exc()

## Step 5: Visualize Results

Display the inference results including the original images, probability map, and binary change mask.

In [None]:
# Comprehensive visualization of results
fig, axes = plt.subplots(2, 3, figsize=(18, 12))

# Row 1: Original images and probability map
axes[0, 0].imshow(img_a_array)
axes[0, 0].set_title('Image A (Before)', fontsize=14, fontweight='bold')
axes[0, 0].axis('off')

axes[0, 1].imshow(img_b_array)
axes[0, 1].set_title('Image B (After)', fontsize=14, fontweight='bold')
axes[0, 1].axis('off')

# Probability map (heatmap)
prob_map = axes[0, 2].imshow(prob_cpu.numpy(), cmap='hot', vmin=0, vmax=1)
axes[0, 2].set_title('Change Probability Map', fontsize=14, fontweight='bold')
axes[0, 2].axis('off')
plt.colorbar(prob_map, ax=axes[0, 2], fraction=0.046, pad=0.04)

# Row 2: Binary mask, overlay, and statistics
axes[1, 0].imshow(mask_cpu.numpy(), cmap='gray', vmin=0, vmax=1)
axes[1, 0].set_title('Binary Change Mask\n(Threshold = 0.5)', fontsize=14, fontweight='bold')
axes[1, 0].axis('off')

# Overlay change mask on Image B
overlay = img_b_array.copy()
change_pixels = mask_cpu.numpy() > 0.5
overlay[change_pixels] = [255, 0, 0]  # Red for changes
axes[1, 1].imshow(overlay)
axes[1, 1].set_title('Changes Overlaid on Image B\n(Red = Change)', fontsize=14, fontweight='bold')
axes[1, 1].axis('off')

# Statistics and summary
total_pixels = mask_cpu.numel()
changed_pixels = torch.sum(mask_cpu).item()
change_percentage = (changed_pixels / total_pixels) * 100

stats_text = f"""
Change Detection Results:

Total pixels: {total_pixels:,}
Changed pixels: {changed_pixels:,}
Change percentage: {change_percentage:.2f}%

Model output range: 
  Min: {output_cpu.min():.4f}
  Max: {output_cpu.max():.4f}

Probability range:
  Min: {prob_cpu.min():.4f}
  Max: {prob_cpu.max():.4f}

Mean probability: {prob_cpu.mean():.4f}
"""

axes[1, 2].text(0.05, 0.95, stats_text, transform=axes[1, 2].transAxes, 
                fontsize=12, verticalalignment='top', fontfamily='monospace',
                bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgray", alpha=0.8))
axes[1, 2].set_xlim(0, 1)
axes[1, 2].set_ylim(0, 1)
axes[1, 2].axis('off')
axes[1, 2].set_title('Summary Statistics', fontsize=14, fontweight='bold')

plt.tight_layout()
plt.show()

print("="*60)
print("CHANGE DETECTION COMPLETED SUCCESSFULLY!")
print("="*60)
print(f"✓ Images processed: {CONFIG['img_size']}x{CONFIG['img_size']} pixels")
print(f"✓ Model used: {CONFIG['model_checkpoint']}")
print(f"✓ Device: {CONFIG['device']}")
print(f"✓ Changed pixels detected: {changed_pixels:,} ({change_percentage:.2f}%)")
print("="*60)

## Optional: Save Results

Save the generated masks and visualizations to disk for later use.

In [None]:
# Save results to disk
save_results = True  # Set to False if you don't want to save

if save_results:
    results_dir = Path("./results")
    results_dir.mkdir(exist_ok=True)
    
    try:
        # Save probability map as grayscale image
        prob_img = Image.fromarray((prob_cpu.numpy() * 255).astype(np.uint8))
        prob_path = results_dir / "probability_map.png"
        prob_img.save(prob_path)
        print(f"✓ Probability map saved to: {prob_path}")
        
        # Save binary mask
        mask_img = Image.fromarray((mask_cpu.numpy() * 255).astype(np.uint8))
        mask_path = results_dir / "binary_mask.png"
        mask_img.save(mask_path)
        print(f"✓ Binary mask saved to: {mask_path}")
        
        # Save overlay image
        overlay_img = Image.fromarray(overlay.astype(np.uint8))
        overlay_path = results_dir / "overlay_result.png"
        overlay_img.save(overlay_path)
        print(f"✓ Overlay result saved to: {overlay_path}")
        
        # Save raw model output as numpy array
        np.save(results_dir / "raw_output.npy", output_cpu.numpy())
        print(f"✓ Raw model output saved to: {results_dir / 'raw_output.npy'}")
        
        # Save statistics to text file
        with open(results_dir / "statistics.txt", "w") as f:
            f.write(f"Change Detection Results\n")
            f.write(f"{'='*30}\n")
            f.write(f"Model: {CONFIG['model_checkpoint']}\n")
            f.write(f"Image size: {CONFIG['img_size']}x{CONFIG['img_size']}\n")
            f.write(f"Total pixels: {total_pixels:,}\n")
            f.write(f"Changed pixels: {changed_pixels:,}\n")
            f.write(f"Change percentage: {change_percentage:.2f}%\n")
            f.write(f"Model output range: [{output_cpu.min():.4f}, {output_cpu.max():.4f}]\n")
            f.write(f"Probability range: [{prob_cpu.min():.4f}, {prob_cpu.max():.4f}]\n")
            f.write(f"Mean probability: {prob_cpu.mean():.4f}\n")
        
        print(f"✓ Statistics saved to: {results_dir / 'statistics.txt'}")
        print(f"\nAll results saved in: {results_dir.absolute()}")
        
    except Exception as e:
        print(f"❌ Error saving results: {e}")
else:
    print("Skipping save results (save_results = False)")