In [None]:
!pip install torch torchvision transformers timm tifffile matplotlib opencv-python

In [None]:
import os
import torch
from torchvision import transforms
from transformers import AutoImageProcessor, AutoModelForImageToImage
from PIL import Image
import tifffile as tiff
import numpy as np
import matplotlib.pyplot as plt
import cv2

In [None]:
# Authenticate with the Hugging Face Hub.
# `login()` is used instead of shelling out to the CLI because the CLI prompts
# interactively, which a `!` shell escape cannot reliably service from a notebook.
from huggingface_hub import login

login()

In [None]:
# Load the Hugging Face MAXIM Dehazing model
from transformers import AutoImageProcessor, AutoModelForImageToImage

try:
    # Preferred public helper for reading the locally stored access token.
    from huggingface_hub import get_token
except ImportError:  # older huggingface_hub releases only expose the HfFolder helper
    from huggingface_hub import HfFolder

    get_token = HfFolder.get_token

# Single source of truth for the checkpoint so the processor and the model
# can never be loaded from two different repositories by accident.
MODEL_ID = "google/maxim-s2-dehazing-sots-outdoor"
# NOTE(review): confirm that this repository ships weights/config loadable through
# AutoModelForImageToImage; if it only provides a Keras/TensorFlow export the
# from_pretrained calls below will fail and a different loader is required.

# Token saved by the login step (None when the user is not logged in).
token = get_token()

# Use the loaded token when loading the processor and the model
processor = AutoImageProcessor.from_pretrained(MODEL_ID, token=token)
model = AutoModelForImageToImage.from_pretrained(MODEL_ID, token=token)

# Move model to GPU if available and make inference mode explicit.
# Assigning the result keeps the cell from echoing the full module repr as its output.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device).eval()

In [None]:
# TIFF image loader and preprocessor
def _to_uint8(array):
    """Rescale an array of any numeric dtype into the 0-255 ``uint8`` range.

    A plain ``np.uint8(array)`` cast wraps 16-bit samples modulo 256 and truncates
    0-1 floating point samples to zero, so the conversion is done explicitly:

    * ``uint8`` input is returned unchanged;
    * boolean input maps ``True`` to 255;
    * integer input whose values already lie in 0-255 is cast directly, otherwise
      it is scaled by the maximum value of its dtype;
    * floating point input in 0-1 is multiplied by 255, otherwise it is treated as
      already being on a 0-255 scale.

    Results are rounded and clipped to 0-255.
    """
    if array.dtype == np.uint8:
        return array
    if array.dtype == np.bool_:
        return array.astype(np.uint8) * 255

    # Non-finite samples carry no usable intensity; map them to black.
    values = np.nan_to_num(array.astype(np.float64), nan=0.0, posinf=0.0, neginf=0.0)
    if values.size == 0:
        return values.astype(np.uint8)

    lowest, highest = values.min(), values.max()
    if np.issubdtype(array.dtype, np.integer):
        fits_8bit = lowest >= 0 and highest <= 255
        scale = 1.0 if fits_8bit else 255.0 / np.iinfo(array.dtype).max
    else:
        unit_range = lowest >= 0.0 and highest <= 1.0
        scale = 255.0 if unit_range else 1.0
    return np.clip(np.rint(values * scale), 0, 255).astype(np.uint8)


def load_tiff_image(file_path):
    """Read a TIFF file and return it as an 8-bit, 3-channel PIL image.

    Args:
        file_path: Path of the TIFF file, passed to ``tiff.imread``.

    Returns:
        The ``PIL.Image`` built by ``Image.fromarray`` from an ``(H, W, 3)``
        ``uint8`` array.

    Raises:
        ValueError: If the decoded array is not 2-D or 3-D, or if its channel
            axis has exactly two entries.

    Layout handling:
        * 2-D arrays are replicated into three identical channels.
        * 3-D arrays whose first axis has 1, 3 or 4 entries while the last axis
          does not are treated as channel-first and transposed to channel-last.
        * A single channel is replicated to three; channels beyond the third
          (for example alpha) are dropped.
    """
    img = np.asarray(tiff.imread(file_path))

    if img.ndim == 2:
        img = np.stack([img] * 3, axis=-1)
    elif img.ndim == 3:
        channel_counts = (1, 3, 4)
        # Heuristic: only transpose when the first axis looks like channels and the
        # last axis does not, so an already channel-last image is never scrambled.
        if img.shape[0] in channel_counts and img.shape[-1] not in channel_counts:
            img = np.transpose(img, (1, 2, 0))

        if img.shape[-1] == 1:
            img = np.repeat(img, 3, axis=-1)
        elif img.shape[-1] > 3:
            img = img[:, :, :3]
        elif img.shape[-1] != 3:
            raise ValueError(
                f"Unsupported channel count {img.shape[-1]} in TIFF array of shape {img.shape}"
            )
    else:
        raise ValueError(f"Unsupported TIFF array shape {img.shape}; expected a 2-D or 3-D image")

    return Image.fromarray(_to_uint8(img))

In [None]:
# Run inference using the model
def dehaze_image(image):
    """Run the globally loaded model on ``image`` and return the restored picture.

    Uses the module-level ``processor``, ``model`` and ``device``.

    Args:
        image: Input accepted by ``processor(images=...)``.

    Returns:
        The restored image. When the processor offers a ``post_process`` method its
        first result is returned; otherwise a ``PIL.Image`` is built from
        ``outputs.reconstruction``.
    """
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    # Keep the processor-provided decoding when it exists.
    post_process = getattr(processor, "post_process", None)
    if callable(post_process):
        return post_process(outputs, output_type="pil")[0]

    # Fallback for processors without `post_process`.
    # Assumes `outputs.reconstruction` is (batch, channels, H, W) with values in 0-1,
    # as documented for Transformers image-to-image outputs. TODO confirm for this checkpoint.
    restored = outputs.reconstruction[0].detach().float().cpu().clamp(0.0, 1.0)
    restored = (restored.permute(1, 2, 0) * 255.0).round().to(torch.uint8).numpy()
    return Image.fromarray(restored)

In [None]:
# Visualize attention map (last layer average)
def visualize_attention(image, model, processor):
    """Overlay the head-averaged last-layer attention of ``model`` on ``image``.

    Args:
        image: PIL image to explain; its ``size`` sets the overlay resolution.
        model: Model called with ``output_attentions=True``.
        processor: Image processor used to build the model inputs.

    Returns:
        A ``PIL.Image`` blending the RGB input (weight 0.6) with a JET heat map
        (weight 0.4).

    Raises:
        ValueError: If the model output carries no attention maps, or if the
            number of non-first tokens is not a perfect square.

    Note:
        The first token is treated as a [CLS]-style summary token and the
        remaining tokens as a square patch grid. That layout is an assumption
        about the model architecture; verify it for the loaded checkpoint.
    """
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)

    attentions = getattr(outputs, "attentions", None)
    if not attentions:
        raise ValueError(
            "Model output contains no attention maps; this architecture may not expose attentions."
        )

    # Extract attention from last layer and average over heads (dim=1), first sample only.
    avg_attention = attentions[-1].mean(dim=1)[0]

    # Take attention from the first token to all other tokens and reshape to a grid.
    attn_weights = avg_attention[0, 1:]
    num_tokens = attn_weights.shape[0]
    num_patches = int(num_tokens ** 0.5)
    if num_patches * num_patches != num_tokens:
        raise ValueError(
            f"Cannot arrange {num_tokens} attention tokens into a square patch grid."
        )
    # .float() keeps cv2.resize working if the model runs in half precision.
    attn_map = attn_weights.reshape(num_patches, num_patches).float().cpu().numpy()

    # Resize attention map to image size (PIL size and cv2 dsize are both (width, height)).
    attn_map = cv2.resize(attn_map, image.size)

    # Min-max normalise; a constant map would divide by zero, so render it as all zeros.
    value_range = attn_map.max() - attn_map.min()
    if value_range > 0:
        attn_map = (attn_map - attn_map.min()) / value_range
    else:
        attn_map = np.zeros_like(attn_map)

    # Overlay attention on image.
    image_np = np.array(image.convert("RGB"), dtype=np.uint8)
    heatmap = (attn_map * 255).astype(np.uint8)
    heatmap_colored = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    # applyColorMap yields BGR; convert so it matches the RGB image it is blended with.
    heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)
    overlay = cv2.addWeighted(image_np, 0.6, heatmap_colored, 0.4, 0)
    return Image.fromarray(overlay)

In [None]:
# Main pipeline
def run_pipeline(input_path, output_dir="outputs"):
    """Dehaze one TIFF image, build its attention overlay, and save all results.

    Writes ``<name>_original.png``, ``<name>_dehazed.png`` and
    ``<name>_attention.png`` into ``output_dir``, where ``<name>`` is the input
    file name without its extension. The directory is created if missing.

    Args:
        input_path: Path to the input .tiff image.
        output_dir: Directory to save outputs.

    Returns:
        Tuple ``(image, dehazed, attention_overlay)`` of the saved images.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Load and process image
    image = load_tiff_image(input_path)
    dehazed = dehaze_image(image)
    attention_overlay = visualize_attention(image, model, processor)

    # Save outputs
    input_name = os.path.splitext(os.path.basename(input_path))[0]
    image.save(os.path.join(output_dir, f"{input_name}_original.png"))
    dehazed.save(os.path.join(output_dir, f"{input_name}_dehazed.png"))
    attention_overlay.save(os.path.join(output_dir, f"{input_name}_attention.png"))

    print("Processing complete. Outputs saved to:", output_dir)
    return image, dehazed, attention_overlay


def main(argv=None):
    """Command-line entry point.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read ``sys.argv``.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Single Image Dehazing with XAI")
    parser.add_argument("--input", type=str, required=True, help="Path to input .tiff image")
    parser.add_argument("--output_dir", type=str, default="outputs", help="Directory to save outputs")
    args = parser.parse_args(argv)

    run_pipeline(args.input, args.output_dir)


if __name__ == "__main__":
    import sys

    if "ipykernel" in sys.modules:
        # Inside a notebook __name__ is "__main__" too, but sys.argv holds the kernel's
        # own flags, so argparse would abort the cell. Point the user at the function instead.
        print(
            "Notebook session detected: call "
            "run_pipeline('<path to .tiff>', output_dir='outputs') to process an image."
        )
    else:
        main()