## Modules and repo

In [1]:
!pip install -q diffusers transformers accelerate

In [2]:
import sys
import torch
import torch.nn.functional as F
from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler, StableDiffusionXLPipeline
from transformers import CLIPTextModel, CLIPTokenizer
from typing import Tuple
from torch.utils.data import DataLoader
import os
import json
from datetime import datetime

!rm -rf semantic-correspondence

!git clone https://github.com/MarcotteS/semantic-correspondence.git
import sys
sys.path.append('/content/semantic-correspondence/src')

Flax classes are deprecated and will be removed in Diffusers v1.0.0. We recommend migrating to PyTorch classes or pinning your version of Diffusers.
Flax classes are deprecated and will be removed in Diffusers v1.0.0. We recommend migrating to PyTorch classes or pinning your version of Diffusers.


Cloning into 'semantic-correspondence'...
remote: Enumerating objects: 208, done.[K
remote: Counting objects: 100% (208/208), done.[K
remote: Compressing objects: 100% (147/147), done.[K
remote: Total 208 (delta 117), reused 112 (delta 46), pack-reused 0 (from 0)[K
Receiving objects: 100% (208/208), 1.33 MiB | 6.39 MiB/s, done.
Resolving deltas: 100% (117/117), done.


In [3]:
from analyzer import ResultsAnalyzer
from evaluation import CorrespondenceEvaluator,evaluate_model
from correspondence import CorrespondenceMatcher
from dataset import SPairDataset,denorm,draw_image_with_keypoints,visualize_sample,collate_fn_correspondence

## Loading weights

In [4]:
!hf auth login

WEIGHTS_PATH = "/content/sd-weights"

print("Downloading Stable Diffusion weights...")
pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float16,
        use_safetensors=True,
        low_cpu_mem_usage=True
    ).to("cuda")

pipe.save_pretrained(WEIGHTS_PATH)

print(f"Weights saved to {WEIGHTS_PATH}")


    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

    To log in, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Enter your token (input will not be visible): 
Add token as git credential? (Y/n) n
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/huggingface_hub/utils/_http.py", line 402, in hf_raise_for_status
    response.raise_for_status()
  File "/usr/local/lib/python3.12/dist-packages/requests

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model_index.json:   0%|          | 0.00/609 [00:00<?, ?B/s]

Fetching 19 files:   0%|          | 0/19 [00:00<?, ?it/s]

scheduler_config.json:   0%|          | 0.00/479 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/565 [00:00<?, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/472 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/575 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/737 [00:00<?, ?B/s]

text_encoder/model.safetensors:   0%|          | 0.00/492M [00:00<?, ?B/s]

text_encoder_2/model.safetensors:   0%|          | 0.00/2.78G [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/460 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

tokenizer_config.json:   0%|          | 0.00/725 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/642 [00:00<?, ?B/s]

vae/diffusion_pytorch_model.safetensors:   0%|          | 0.00/335M [00:00<?, ?B/s]

unet/diffusion_pytorch_model.safetensors:   0%|          | 0.00/10.3G [00:00<?, ?B/s]

vae_1_0/diffusion_pytorch_model.safetens(…):   0%|          | 0.00/335M [00:00<?, ?B/s]

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

`torch_dtype` is deprecated! Use `dtype` instead!


Weights saved to /content/sd-weights


## Stable Diffusion Extractor

In [5]:
class SDXLExtractor:
    """
    Feature extractor using Stable Diffusion XL U-Net features.
    SDXL is more recent than SD 2.1 (released July 2023).
    """

    def __init__(
        self,
        weights: str,
        model_name: str = "sdxl",
        timestep: int = 261,
        layer_name: str = "up_blocks.0",
        patch_size: int = 16,
    ):
        """
        Initialize SDXL feature extractor.

        Args:
            weights: Path to local directory containing SDXL weights
            model_name: Label for this model
            timestep: Denoising timestep (DIFT uses 261)
            layer_name: Which U-Net layer to extract from
            patch_size: Virtual patch size
        """
        print(f"Loading Stable Diffusion XL from: {weights}")

        if not os.path.exists(weights):
            raise ValueError(f"Weights not found: {weights}")

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = torch.float16 if torch.cuda.is_available() else torch.float32

        # Load SDXL pipeline components
        from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler
        from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

        print("Loading UNet...")
        self.unet = UNet2DConditionModel.from_pretrained(
            weights, subfolder="unet", torch_dtype=self.dtype, local_files_only=True
        ).to(self.device)

        print("Loading VAE...")
        self.vae = AutoencoderKL.from_pretrained(
            weights, subfolder="vae", torch_dtype=self.dtype, local_files_only=True
        ).to(self.device)

        print("Loading text encoders...")
        # SDXL uses TWO text encoders
        self.text_encoder = CLIPTextModel.from_pretrained(
            weights, subfolder="text_encoder", torch_dtype=self.dtype, local_files_only=True
        ).to(self.device)

        self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
            weights, subfolder="text_encoder_2", torch_dtype=self.dtype, local_files_only=True
        ).to(self.device)

        print("Loading tokenizers...")
        self.tokenizer = CLIPTokenizer.from_pretrained(
            weights, subfolder="tokenizer", local_files_only=True
        )
        self.tokenizer_2 = CLIPTokenizer.from_pretrained(
            weights, subfolder="tokenizer_2", local_files_only=True
        )

        print("Loading scheduler...")
        self.scheduler = DDPMScheduler.from_pretrained(
            weights, subfolder="scheduler", local_files_only=True
        )

        # Set to eval
        self.unet.eval()
        self.vae.eval()
        self.text_encoder.eval()
        self.text_encoder_2.eval()

        self.timestep = timestep
        self.layer_name = layer_name
        self.patch_size = patch_size
        self.model_name = model_name
        self.features = None

        print(f"✓ SDXL loaded! (timestep={timestep}, layer={layer_name})")

    def _register_hook(self):
        """Register forward hook to capture features."""
        def hook_fn(module, input, output):
            self.features = output

        layer = self.unet
        for name in self.layer_name.split('.'):
            layer = getattr(layer, name)

        return layer.register_forward_hook(hook_fn)

    @torch.no_grad()
    def extract(self, img: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
        """
        Extract features from image(s).

        Args:
            img: torch.Tensor [B, 3, H, W] or [3, H, W]
                Recommended size: 1024x1024 for SDXL (can use 512 too)

        Returns:
            features: [B, H_p*W_p, D]
            spatial_dims: (H_p, W_p)
        """
        if img.dim() == 3:
            img = img.unsqueeze(0)

        B, C, H, W = img.shape
        img = img.to(self.device, dtype=self.dtype)

        # Normalize to [-1, 1]
        if img.min() >= 0 and img.max() <= 1:
            img = img * 2.0 - 1.0

        # Encode to latent
        latents = self.vae.encode(img).latent_dist.sample()
        latents = latents * self.vae.config.scaling_factor

        # Prepare timestep
        t = torch.tensor([self.timestep], device=self.device).long().expand(B)

        # Add noise
        noise = torch.randn_like(latents)
        noisy_latents = self.scheduler.add_noise(latents, noise, t)

        # Get text embeddings (SDXL uses two encoders)
        text_embeddings = self._get_empty_text_embeddings(B)

        # Forward with hook
        handle = self._register_hook()

        # SDXL UNet expects additional arguments
        added_cond_kwargs = {
            "text_embeds": text_embeddings[1],  # pooled embeddings
            "time_ids": self._get_time_ids(B, H, W),
        }

        _ = self.unet(
            noisy_latents,
            t,
            encoder_hidden_states=text_embeddings[0],  # prompt embeddings
            added_cond_kwargs=added_cond_kwargs,
        ).sample

        handle.remove()

        # Process features
        features = self.features
        B, C_feat, H_feat, W_feat = features.shape
        features = features.permute(0, 2, 3, 1).reshape(B, H_feat * W_feat, C_feat)
        features = F.normalize(features, dim=-1)

        return features, (H_feat, W_feat)

    def _get_empty_text_embeddings(self, batch_size: int):
        """Get text embeddings from both SDXL encoders."""
        # Tokenize empty prompt for both tokenizers
        text_inputs = self.tokenizer(
            [""] * batch_size,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )

        text_inputs_2 = self.tokenizer_2(
            [""] * batch_size,
            padding="max_length",
            max_length=self.tokenizer_2.model_max_length,
            truncation=True,
            return_tensors="pt",
        )

        # Get embeddings from both encoders
        prompt_embeds = self.text_encoder(
            text_inputs.input_ids.to(self.device),
            output_hidden_states=True,
        )
        pooled_prompt_embeds = self.text_encoder_2(
            text_inputs_2.input_ids.to(self.device),
            output_hidden_states=True,
        )

        # SDXL concatenates hidden states from both encoders
        prompt_embeds = prompt_embeds.hidden_states[-2]
        pooled_prompt_embeds = pooled_prompt_embeds[0]

        return (prompt_embeds, pooled_prompt_embeds)

    def _get_time_ids(self, batch_size: int, height: int, width: int):
        """Get time IDs for SDXL conditioning."""
        # Original size, crops coords, target size
        time_ids = torch.tensor([[height, width, 0, 0, height, width]], device=self.device)
        time_ids = time_ids.expand(batch_size, -1)
        return time_ids


## Load dataset

In [6]:
!wget https://cvlab.postech.ac.kr/research/SPair-71k/data/SPair-71k.tar.gz
!tar -xzvf SPair-71k.tar.gz

[1;30;43mLe flux de sortie a été tronqué et ne contient que les 5000 dernières lignes.[0m
SPair-71k/ImageAnnotation/horse/2007_001724.json
SPair-71k/ImageAnnotation/horse/2008_004705.json
SPair-71k/ImageAnnotation/horse/2007_008596.json
SPair-71k/ImageAnnotation/horse/2008_008296.json
SPair-71k/ImageAnnotation/horse/2008_000765.json
SPair-71k/ImageAnnotation/horse/2008_001729.json
SPair-71k/ImageAnnotation/horse/2007_001420.json
SPair-71k/ImageAnnotation/horse/2009_003768.json
SPair-71k/ImageAnnotation/horse/2008_002665.json
SPair-71k/ImageAnnotation/horse/2009_003734.json
SPair-71k/ImageAnnotation/horse/2009_002957.json
SPair-71k/ImageAnnotation/horse/2008_008279.json
SPair-71k/ImageAnnotation/horse/2008_000141.json
SPair-71k/ImageAnnotation/horse/2007_000783.json
SPair-71k/ImageAnnotation/horse/2007_000799.json
SPair-71k/ImageAnnotation/horse/2008_002806.json
SPair-71k/ImageAnnotation/horse/2008_002686.json
SPair-71k/ImageAnnotation/horse/2010_001077.json
SPair-71k/ImageAnnotation/

In [7]:
image_size = 512  # SD native size

dataset = SPairDataset(
    datapath='.',
    split='test',
    img_size=image_size,
    category='all'
)

dataloader = DataLoader(
    dataset,
    batch_size=4,  # Colab GPU can handle 4-8
    shuffle=False,
    num_workers=2,  # Lower for Colab
    collate_fn=collate_fn_correspondence
)

Loading SPair-71k test annotations...


100%|██████████| 12234/12234 [00:02<00:00, 4785.56it/s]


## Evaluation

In [None]:
# Initialize extractor
extractor = SDXLExtractor(
    weights=WEIGHTS_PATH,
    model_name="sd-xl",
    timestep=261,
    layer_name="up_blocks.0"
)

# Create matcher
matcher = CorrespondenceMatcher(extractor)

# Evaluate
metrics = evaluate_model(matcher, dataloader)

Loading Stable Diffusion XL from: /content/sd-weights
Loading UNet...


In [None]:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

model_name = "sd-xl"

results = {
    'metrics': metrics,
    'model': model_name,
    'config': {
        'timestep': 261,
        'layer': 'up_blocks.0',
        'image_size': image_size
    },
    'timestamp': timestamp
}

with open(f'results_{model_name}_{image_size}_{timestamp}.json', 'w') as f:
    json.dump(results, f, indent=2)

print(f"Results saved!")

In [None]:
analyzer = ResultsAnalyzer(metrics)
summary = analyzer.create_summary_table(threshold=0.10)
print("\nPCK@0.10 Summary:")
print(summary)

analyzer.generate_report(save_dir=f'./results/sd-xl')