# EVA Probe Transform Visualizer

This notebook is designed to be executed from inside the [`eva-probe`](https://github.com/MedARC-AI/eva-probe) repository. It demonstrates how to:

1. Load a handful of raw images from disk.
2. Instantiate the image transformations defined by the project configuration.
3. Inspect the transformation pipeline.
4. Visualize the original and transformed versions of each sample.

> **Tip:** Before running the notebook, make sure you have installed the repository's dependencies (see the project's `README.md`) and that any paths configured below point to existing files/directories on your machine.


## 1. Imports

This cell gathers all the Python packages we need. The `omegaconf` and `hydra` ecosystem is commonly used in the project for configuration, so we rely on it to parse the transform definitions. Feel free to add extra imports if your environment requires them.


In [None]:
from __future__ import annotations

import importlib
import inspect
import sys
from dataclasses import asdict, is_dataclass
from pathlib import Path
from typing import Any, Iterable

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

try:
    import torch
except ImportError as exc:  # pragma: no cover - torch should be available but we fail loudly otherwise
    raise ImportError("PyTorch is required to run this notebook. Please install it before proceeding.") from exc

try:
    from omegaconf import OmegaConf
except ImportError as exc:  # pragma: no cover - omegaconf should be available but we fail loudly otherwise
    raise ImportError(
        "omegaconf is required to parse the EVA Probe configuration. Install it with `pip install omegaconf`."
    ) from exc
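
If you add extra imports, it can help to check up front which packages are missing instead of failing one import at a time. The following is a minimal sketch using only the standard library; the `missing_packages` helper and the placeholder package name are hypothetical and not part of EVA Probe.

```python
from __future__ import annotations

from importlib.util import find_spec


def missing_packages(names: list[str]) -> list[str]:
    """Return the top-level package names that cannot be found in the current environment."""
    # find_spec returns None for a top-level module that is not installed.
    return [name for name in names if find_spec(name) is None]


# "json" ships with Python; the second name is a deliberately non-existent placeholder.
missing = missing_packages(["json", "a_package_that_does_not_exist_xyz"])
print(f"Missing packages: {missing}")
```

In this notebook you would pass names such as `"torch"`, `"omegaconf"`, `"matplotlib"`, and `"PIL"`.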


## 2. Configure paths

Update the variables in the next cell so they match your local checkout. The defaults assume the notebook is saved in the repository's `notebooks/` folder.

* `REPO_ROOT`: location of the `eva-probe` repository.
* `RAW_IMAGE_ROOT`: directory containing the images you want to inspect.
* `TRANSFORM_CONFIG_PATH`: YAML file (for example a Hydra/OmegaConf config) that defines the transform pipeline you are interested in.
* `TRANSFORM_CONFIG_KEY`: dotted path inside the YAML pointing to the transform specification (for example `datamodule.train_transforms`).
* `NUM_SAMPLES`: how many images to visualize.
* `IMAGE_EXTENSIONS`: file extensions that will be considered when scanning for sample images.
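
The default `REPO_ROOT` depends on the kernel's working directory. If the notebook may be launched from elsewhere, one option is to walk up the directory tree until a marker file is found. This is a sketch under assumptions: the `find_repo_root` helper and the marker names (`pyproject.toml`, `.git`) are hypothetical, so adjust them to whatever actually sits at the root of your checkout.

```python
from __future__ import annotations

import tempfile
from pathlib import Path


def find_repo_root(start: Path, markers: tuple[str, ...] = ("pyproject.toml", ".git")) -> Path:
    """Walk upwards from `start` and return the first directory containing one of `markers`."""
    start = start.resolve()
    for candidate in (start, *start.parents):
        if any((candidate / marker).exists() for marker in markers):
            return candidate
    raise FileNotFoundError(f"No directory containing {markers} found above {start}")


# Demonstration on a throwaway directory tree (not a real checkout).
with tempfile.TemporaryDirectory() as tmp:
    fake_root = Path(tmp).resolve()
    (fake_root / "pyproject.toml").touch()
    nested = fake_root / "notebooks" / "drafts"
    nested.mkdir(parents=True)
    found_root_matches = find_repo_root(nested) == fake_root

print(f"Root located correctly: {found_root_matches}")
```

With such a helper, `REPO_ROOT = find_repo_root(Path.cwd())` would no longer depend on how deeply the notebook is nested.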


In [None]:
# --- User editable section ----------------------------------------------------
REPO_ROOT = Path.cwd().resolve().parents[0]  # adjust if the notebook is moved elsewhere
RAW_IMAGE_ROOT = REPO_ROOT / "data" / "samples"  # change to the folder that contains your raw images
TRANSFORM_CONFIG_PATH = REPO_ROOT / "configs" / "data" / "transforms" / "default.yaml"  # adjust to your config file
TRANSFORM_CONFIG_KEY = "transforms.train"  # dotted path inside the YAML pointing to the transform definition
NUM_SAMPLES = 4
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff")
# -----------------------------------------------------------------------------

if not REPO_ROOT.exists():
    raise FileNotFoundError(f"REPO_ROOT does not exist: {REPO_ROOT}")

if not RAW_IMAGE_ROOT.exists():
    raise FileNotFoundError(
        f"RAW_IMAGE_ROOT does not exist: {RAW_IMAGE_ROOT}.\n"
        "Please update the path so it points to a directory with sample images."
    )

if not TRANSFORM_CONFIG_PATH.exists():
    raise FileNotFoundError(
        f"TRANSFORM_CONFIG_PATH does not exist: {TRANSFORM_CONFIG_PATH}.\n"
        "Update it so it references a valid EVA Probe transform configuration file."
    )

if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
    print(f"Added {REPO_ROOT} to PYTHONPATH")


## 3. Helpers to instantiate and display transforms

The EVA Probe project typically describes transforms as Hydra/OmegaConf configuration nodes. The helper functions below convert those nodes into plain dictionaries, instantiate the callables they describe, and provide a few utilities for inspecting and displaying the results.
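
To make the `_target_` convention concrete, here is a condensed, self-contained sketch of how a Hydra-style dictionary can be turned into an object. It is a simplified stand-in for illustration, not the project's or Hydra's implementation, and it uses `datetime.timedelta` as the target only because it is available in the standard library; the `build` name is hypothetical.

```python
from __future__ import annotations

import importlib
from typing import Any


def build(spec: Any) -> Any:
    """Recursively instantiate dictionaries that carry a `_target_` import path."""
    if isinstance(spec, dict) and "_target_" in spec:
        module_path, _, attr_name = spec["_target_"].rpartition(".")
        factory = getattr(importlib.import_module(module_path), attr_name)
        # Every remaining key becomes a keyword argument of the target.
        kwargs = {k: build(v) for k, v in spec.items() if k != "_target_"}
        return factory(**kwargs)
    if isinstance(spec, dict):
        return {k: build(v) for k, v in spec.items()}
    if isinstance(spec, list):
        return [build(v) for v in spec]
    return spec


# Equivalent to calling datetime.timedelta(hours=1, minutes=30).
delta = build({"_target_": "datetime.timedelta", "hours": 1, "minutes": 30})
print(delta.total_seconds())
```

A transform config works the same way, except that the target would be a transform class and nested entries would describe its sub-transforms.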


In [None]:
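def to_plain_python(value: Any) -> Any:
    """Recursively convert OmegaConf nodes, dataclasses, and other containers to plain Python objects."""
    if value is None:
        return None
    if OmegaConf.is_config(value):
        # DictConfig/ListConfig are not dict/list subclasses, so convert them (and resolve interpolations) first.
        return to_plain_python(OmegaConf.to_container(value, resolve=True))
    if isinstance(value, (str, int, float, bool)):
        return value
    if isinstance(value, Path):
        return str(value)
    if isinstance(value, dict):
        return {k: to_plain_python(v) for k, v in value.items()}
    if isinstance(value, (list, tuple, set)):
        iterable_type = type(value)
        return iterable_type(to_plain_python(v) for v in value)
    if is_dataclass(value) and not isinstance(value, type):
        # is_dataclass is also True for dataclass types; asdict only accepts instances.
        return to_plain_python(asdict(value))
    return value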


def instantiate_from_config(config: Any) -> Any:
    """Instantiate Python objects from a Hydra-style configuration.

    The function understands dictionaries that contain a `_target_` or `target` key. Nested dictionaries and lists are
    recursively processed, which allows Compose-style pipelines to be created on the fly.
    """
    if isinstance(config, dict):
        # Prefer Hydra's `_target_`; fall back to a plain `target` key. Picking the key
        # explicitly keeps the choice deterministic when both are present.
        target_key = next((k for k in ("_target_", "target") if k in config), None)
        if target_key is not None:
            target = config[target_key]
            module_path, _, attr_name = target.rpartition('.')
            if not module_path:
                raise ValueError(f"Invalid target specification: {target}")
            module = importlib.import_module(module_path)
            factory = getattr(module, attr_name)
            kwargs = {
                k: instantiate_from_config(v)
                for k, v in config.items()
                if k != target_key
            }
            return factory(**kwargs)
        return {k: instantiate_from_config(v) for k, v in config.items()}
    if isinstance(config, list):
        return [instantiate_from_config(v) for v in config]
    return config


def flatten_transform(transform: Any) -> list[Any]:
    """Return a flat list of sub-transforms for inspection purposes."""
    if transform is None:
        return []
    if hasattr(transform, 'transforms') and isinstance(transform.transforms, Iterable):
        flattened: list[Any] = []
        for sub in transform.transforms:
            flattened.extend(flatten_transform(sub))
        return flattened
    return [transform]


def apply_transform_to_pil(image: Image.Image, transform: Any) -> Image.Image:
    """Apply a transform to a PIL image and convert the result back to PIL for display."""
    if transform is None:
        return image

    # Try PIL / tensor based transforms first
    try:
        transformed = transform(image)
    except TypeError:
        # Albumentations-style API
        transformed = transform(image=np.array(image))
        if isinstance(transformed, dict) and 'image' in transformed:
            transformed = transformed['image']

    if isinstance(transformed, Image.Image):
        return transformed

    if isinstance(transformed, np.ndarray):
        array = transformed
        if array.ndim == 2:
            array = np.stack([array] * 3, axis=-1)
        if array.dtype != np.uint8:
            if array.max() <= 1.0:
                array = np.clip(array, 0.0, 1.0)
                array = (array * 255.0).astype(np.uint8)
            else:
                array = np.clip(array, 0, 255).astype(np.uint8)
        return Image.fromarray(array)

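    if torch.is_tensor(transformed):
        tensor = transformed.detach().cpu()
        if tensor.ndim == 3 and tensor.shape[0] in {1, 3}:
            tensor = tensor.permute(1, 2, 0)  # CHW -> HWC
        array = tensor.numpy()
        if array.ndim == 3 and array.shape[-1] == 1:
            array = array[..., 0]  # PIL cannot build an image from an (H, W, 1) array
        if array.ndim == 2:
            array = np.stack([array] * 3, axis=-1)
        if array.dtype != np.uint8:
            # Assumes float values in [0, 1]; normalized tensors are clipped, not de-normalized.
            array = (np.clip(array.astype(np.float32), 0.0, 1.0) * 255.0).astype(np.uint8)
        return Image.fromarray(array)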

    raise TypeError(
        'Unsupported transform output type. Got '
        f"{type(transformed)}; expected PIL.Image.Image, numpy.ndarray, or torch.Tensor."
    )


def describe_transform(transform: Any) -> None:
    """Print a readable description of the transform pipeline."""
    flattened = flatten_transform(transform)
    print('Transform pipeline:')
    for idx, sub in enumerate(flattened, start=1):
        qualname = sub.__class__.__name__ if not inspect.isfunction(sub) else sub.__name__
        print(f"  {idx:02d}. {qualname}")
        if hasattr(sub, '__dict__'):
            params = {
                k: v
                for k, v in vars(sub).items()
                if not k.startswith('_') and not inspect.ismethod(v) and not inspect.isfunction(v)
            }
            if params:
                for key, value in params.items():
                    print(f"        - {key}: {value}")
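
The flattening step is easiest to see on a toy pipeline. The sketch below uses hypothetical stand-in classes (`ToyCompose`, `ToyStep`) rather than real transform classes, and a simplified `flatten` that only treats list or tuple `transforms` attributes as containers.

```python
from __future__ import annotations

from typing import Any


class ToyCompose:
    """Hypothetical stand-in for a Compose-style container exposing `.transforms`."""

    def __init__(self, transforms: list[Any]) -> None:
        self.transforms = list(transforms)


class ToyStep:
    """Hypothetical leaf transform identified only by a name."""

    def __init__(self, name: str) -> None:
        self.name = name


def flatten(transform: Any) -> list[Any]:
    """Depth-first flattening of nested containers into a list of leaf transforms."""
    if transform is None:
        return []
    subs = getattr(transform, "transforms", None)
    if isinstance(subs, (list, tuple)):
        flat: list[Any] = []
        for sub in subs:
            flat.extend(flatten(sub))
        return flat
    return [transform]


pipeline = ToyCompose(
    [ToyStep("resize"), ToyCompose([ToyStep("crop"), ToyStep("flip")]), ToyStep("to_tensor")]
)
names = [step.name for step in flatten(pipeline)]
print(names)
```

Nested containers disappear from the result, so the printed order matches the order in which the leaf transforms would run.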


## 4. Load the transform configuration

This step reads the YAML file, resolves the section that contains the transform specification, and instantiates the actual pipeline. The resulting object is ready to be applied to raw images.


In [None]:
config = OmegaConf.load(TRANSFORM_CONFIG_PATH)
plain_config = to_plain_python(config)

transform_cfg = plain_config
for key in TRANSFORM_CONFIG_KEY.split('.'):
    if key not in transform_cfg:
        raise KeyError(
            f"Could not find the key '{key}' inside the transform configuration.\n"
            "Double-check TRANSFORM_CONFIG_KEY or inspect the YAML structure."
        )
    transform_cfg = transform_cfg[key]

transform_pipeline = instantiate_from_config(transform_cfg)

describe_transform(transform_pipeline)
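
The dotted-key lookup used above can be sketched on a plain dictionary. The config contents here are invented for illustration (`some.module.Compose` is a placeholder, not a real import path), and `resolve_dotted_key` is a hypothetical helper name.

```python
from __future__ import annotations

from typing import Any, Mapping


def resolve_dotted_key(config: Mapping[str, Any], dotted_key: str) -> Any:
    """Follow a dotted path such as 'transforms.train' through nested mappings."""
    node: Any = config
    walked: list[str] = []
    for part in dotted_key.split("."):
        if not isinstance(node, Mapping) or part not in node:
            location = ".".join(walked) or "<root>"
            raise KeyError(f"'{part}' not found under '{location}'")
        node = node[part]
        walked.append(part)
    return node


# Invented example mirroring the default TRANSFORM_CONFIG_KEY of "transforms.train".
example_config = {
    "transforms": {
        "train": {"_target_": "some.module.Compose", "transforms": []},
        "val": None,
    }
}
train_spec = resolve_dotted_key(example_config, "transforms.train")
print(train_spec["_target_"])
```

Reporting the part of the path that was already resolved makes it easier to spot which level of the YAML is misspelled.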


## 5. Gather sample images

We now collect a few sample image paths. If your dataset is nested in class-specific subfolders (e.g. ImageNet-style layout), the glob pattern below will still find the images recursively.


In [None]:
image_paths = [
    path
    for extension in IMAGE_EXTENSIONS
    for path in sorted(RAW_IMAGE_ROOT.rglob(f"*{extension}"))
]

if not image_paths:
    raise RuntimeError(
        f"No images with extensions {IMAGE_EXTENSIONS} were found under {RAW_IMAGE_ROOT}.\n"
        "Verify that RAW_IMAGE_ROOT points to the correct location or adjust IMAGE_EXTENSIONS."
    )

sample_paths = image_paths[:NUM_SAMPLES]
print(f"Found {len(image_paths)} image(s); displaying the first {len(sample_paths)} sample(s).")
for idx, path in enumerate(sample_paths, 1):
    print(f"  {idx:02d}. {path}")
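
Note that globbing by `*{extension}` is case-sensitive on most Linux filesystems and groups results by extension rather than by path. If that matters for your data, one alternative is to scan once and compare lowercase suffixes. This is a minimal sketch; `find_images` is a hypothetical helper and the files are created in a temporary directory purely for demonstration.

```python
from __future__ import annotations

import tempfile
from pathlib import Path


def find_images(root: Path, extensions: tuple[str, ...]) -> list[Path]:
    """Recursively collect files whose suffix matches `extensions`, ignoring case."""
    wanted = {ext.lower() for ext in extensions}
    return sorted(p for p in root.rglob("*") if p.is_file() and p.suffix.lower() in wanted)


# Demonstration on throwaway files in an ImageNet-style nested layout.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "a").mkdir()
    (root / "b").mkdir()
    (root / "a" / "img_1.png").touch()
    (root / "b" / "IMG_2.JPG").touch()  # upper-case suffix is still matched
    (root / "notes.txt").touch()        # non-image files are skipped
    found_names = [p.name for p in find_images(root, (".jpg", ".png"))]

print(found_names)
```

Because the results are sorted by full path, the first `NUM_SAMPLES` entries are stable across runs and not biased toward one file extension.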


## 6. Visualize original vs. transformed images

The final cell applies the transform pipeline to each sample and displays the original and transformed images side by side. If your transforms produce tensors, the values are clipped to [0, 1] and scaled to 8-bit for display, so mean/std-normalized tensors will appear with shifted colors unless you undo the normalization first. Adjust the logic if your pipeline outputs additional data (masks, depth maps, etc.).


In [None]:
# squeeze=False keeps `axes` two-dimensional even when there is a single sample.
fig, axes = plt.subplots(len(sample_paths), 2, figsize=(10, 4 * len(sample_paths)), squeeze=False)

for row, image_path in zip(axes, sample_paths):
    original = Image.open(image_path).convert('RGB')
    transformed = apply_transform_to_pil(original, transform_pipeline)

    row[0].imshow(original)
    row[0].set_title(f"Original\n{image_path.name}")
    row[0].axis('off')

    row[1].imshow(transformed)
    row[1].set_title('Transformed')
    row[1].axis('off')

plt.tight_layout()
plt.show()
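
If the pipeline ends with mean/std normalization, clipping alone distorts the colors. A possible remedy is to invert the normalization before display. The sketch below works on NumPy arrays in channel-first layout; the `denormalize` helper is hypothetical, and the mean/std values are the commonly used ImageNet statistics, which are an assumption here, so substitute whatever your configuration actually specifies.

```python
from __future__ import annotations

import numpy as np

# Assumed statistics (common ImageNet values); replace with those from your config.
ASSUMED_MEAN = (0.485, 0.456, 0.406)
ASSUMED_STD = (0.229, 0.224, 0.225)


def denormalize(chw: np.ndarray, mean: tuple[float, ...], std: tuple[float, ...]) -> np.ndarray:
    """Invert per-channel (x - mean) / std on a CHW array and return an HWC uint8 image."""
    chw = np.asarray(chw, dtype=np.float32)
    mean_arr = np.asarray(mean, dtype=np.float32).reshape(-1, 1, 1)
    std_arr = np.asarray(std, dtype=np.float32).reshape(-1, 1, 1)
    restored = chw * std_arr + mean_arr
    hwc = np.transpose(restored, (1, 2, 0))
    return (np.clip(hwc, 0.0, 1.0) * 255.0).round().astype(np.uint8)


# Round trip on a tiny constant image: normalize, then undo it.
original = np.full((3, 2, 2), 0.2, dtype=np.float32)
mean_arr = np.asarray(ASSUMED_MEAN, dtype=np.float32).reshape(-1, 1, 1)
std_arr = np.asarray(ASSUMED_STD, dtype=np.float32).reshape(-1, 1, 1)
normalized = (original - mean_arr) / std_arr
restored = denormalize(normalized, ASSUMED_MEAN, ASSUMED_STD)
print(restored.shape, restored.dtype)
```

For a PyTorch tensor you could pass `tensor.detach().cpu().numpy()` to the same function and hand the result to `Image.fromarray`.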
