# Data Loading

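This notebook reads images from the PAX5 (nuclear) and CD3 (membrane) marker channels. It computes a DINOv2 (`facebook/dinov2-base`) CLS-token embedding for each image and saves the results to `../data/marker_embeddings_subset.parquet`.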

In [1]:
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from PIL import Image
from tqdm.auto import tqdm
from transformers import AutoImageProcessor, AutoModel

# --- Config ---
MODEL_NAME = "facebook/dinov2-base"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMAGE_DIR = Path("../data/images")  # Where the marker folders are
MARKERS = ["PAX5", "CD3"]

# --- DINOv2 Setup ---
processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
model.eval()  # inference only: disable dropout etc.

# --- Optional: Compile for faster inference ---
# Best-effort: if torch.compile is not supported here, keep the eager model, but
# report why instead of swallowing the error silently.
# NOTE(review): torch.compile compiles lazily, so backend failures may only surface
# on the first model(...) call rather than in this try block.
try:
    model = torch.compile(model)
except Exception as exc:
    warnings.warn(f"torch.compile unavailable, running the model in eager mode: {exc}")

# --- Helper: Load and preprocess images ---
def load_images(marker: str, pattern: str = "*.png") -> list[tuple[str, dict]]:
    """Load and preprocess every image of one marker channel.

    Images are read from ``IMAGE_DIR / <marker>`` (any "/" in the marker name is
    replaced by "_" to form the folder name), converted to RGB, and passed through
    the DINOv2 ``processor`` one image at a time.

    Args:
        marker: marker/channel name, e.g. "PAX5".
        pattern: glob pattern selecting the image files inside the marker folder.

    Returns:
        ``(image_path, processor_output)`` pairs in sorted path order, where
        ``processor_output`` is the processor's dict-like result for
        ``return_tensors="pt"`` on a single image. Empty if no files match.
    """
    marker_path = IMAGE_DIR / marker.replace("/", "_")
    images = []

    for path in sorted(marker_path.glob(pattern)):
        # Use a context manager so the file handle is released immediately;
        # convert() returns a new image that remains valid after the file is closed.
        with Image.open(path) as img:
            image = img.convert("RGB")
        inputs = processor(image, return_tensors="pt")
        images.append((str(path), inputs))

    return images

# --- Helper: Batch and run through model ---
def get_embeddings(image_batches: list[tuple[str, dict]], batch_size: int = 32) -> list[dict]:
    """Run preprocessed images through the model in batches and collect CLS embeddings.

    Args:
        image_batches: ``(image_path, processor_output)`` pairs, e.g. from ``load_images``.
            Each processor output's tensors are concatenated along dim 0 to form a batch
            (assumes every entry holds a single image with the same keys — TODO confirm).
        batch_size: number of images per forward pass; must be >= 1.

    Returns:
        One dict per input image, in input order, with keys ``"image_path"`` and
        ``"embedding"`` (a float32 numpy array taken from position 0 of
        ``last_hidden_state``, i.e. the CLS token).

    Raises:
        ValueError: if ``batch_size`` < 1 (a negative step would silently yield nothing).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    results = []
    autocast_device = "cuda" if DEVICE.type == "cuda" else "cpu"

    for start in tqdm(range(0, len(image_batches), batch_size), desc="Embedding"):
        batch = image_batches[start:start + batch_size]
        paths = [path for path, _ in batch]
        input_batch = {
            key: torch.cat([inputs[key] for _, inputs in batch]).to(DEVICE)
            for key in batch[0][1]
        }

        with torch.no_grad(), torch.autocast(autocast_device):
            outputs = model(**input_batch)

        # Use CLS token representation. Cast to float32 before .numpy(): under autocast
        # the hidden state may come back in a reduced precision (bfloat16 on CPU), which
        # NumPy cannot represent.
        cls_embeddings = outputs.last_hidden_state[:, 0].float().cpu().numpy()

        for path, emb in zip(paths, cls_embeddings):
            results.append({
                "image_path": path,
                "embedding": emb,
            })

    return results

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [2]:
DEVICE  # sanity check: the config cell picks "cuda" whenever torch.cuda.is_available()

device(type='cuda')

In [3]:
# --- Main: embed every image of each configured marker, tagging each record with its marker ---
all_embeddings = []

for marker in MARKERS:
    print(f"Processing marker: {marker}")
    images = load_images(marker)
    if len(images) == 0:
        # Skip markers whose folder has no images instead of failing on an empty batch.
        print(f"No images found for {marker}")
        continue
    embeddings = [{**record, "marker": marker} for record in get_embeddings(images)]
    all_embeddings += embeddings

Processing marker: PAX5


Embedding:   0%|          | 0/12 [00:00<?, ?it/s]

Processing marker: CD3


Embedding:   0%|          | 0/12 [00:00<?, ?it/s]

In [7]:
# --- Convert to DataFrame and persist as Parquet ---
EMBEDDINGS_PATH = Path("../data/marker_embeddings_subset.parquet")

if not all_embeddings:
    # An empty frame has no 'image_path' column; fail with a clear message instead.
    raise RuntimeError(f"No embeddings were computed; check IMAGE_DIR ({IMAGE_DIR}) and MARKERS.")

df = pd.DataFrame(all_embeddings)
# File stem, e.g. "sample_001"; the same stem can occur under several markers.
df['sample_id'] = df['image_path'].apply(lambda p: Path(p).stem)
df['embedding'] = df['embedding'].apply(np.asarray)
assert len(df) == len(all_embeddings)

# Create the output folder if needed, then write (pyarrow stores each array as a list).
EMBEDDINGS_PATH.parent.mkdir(parents=True, exist_ok=True)
df.to_parquet(EMBEDDINGS_PATH, engine="pyarrow", compression="zstd")

print(f"Done. Embeddings saved to {EMBEDDINGS_PATH}")

Done. Embeddings saved to ../data/marker_embeddings_subset.parquet


In [8]:
df  # one row per (marker, image): image_path, embedding, marker, sample_id

Unnamed: 0,image_path,embedding,marker,sample_id
0,../data/images/PAX5/sample_001.png,"[4.1601596, 0.19707349, -0.39906776, 0.4187315...",PAX5,sample_001
1,../data/images/PAX5/sample_002.png,"[2.1713753, -0.18894142, -1.1794543, 0.0642976...",PAX5,sample_002
2,../data/images/PAX5/sample_003.png,"[2.3522525, 1.3800017, -1.7806236, -1.49976, -...",PAX5,sample_003
3,../data/images/PAX5/sample_006.png,"[3.3333764, 0.9943014, -0.4116209, 1.0857946, ...",PAX5,sample_006
4,../data/images/PAX5/sample_007.png,"[3.0457215, 1.0350839, -0.34904912, 1.10828, -...",PAX5,sample_007
...,...,...,...,...
729,../data/images/CD3/sample_400.png,"[1.2264028, 0.055597667, -0.92955977, -2.60830...",CD3,sample_400
730,../data/images/CD3/sample_401.png,"[0.7608712, 1.8352201, 0.3570453, 0.7371291, -...",CD3,sample_401
731,../data/images/CD3/sample_402.png,"[1.0764847, -0.42349392, -0.00124115, 0.235760...",CD3,sample_402
732,../data/images/CD3/sample_403.png,"[0.8793458, 0.7496702, 0.10183122, -0.8387582,...",CD3,sample_403
