In [5]:
import os
import sys
from pathlib import Path

# Directory containing the model-wrapper modules imported below (util, utils, croco_models).
# The absolute default only exists on the original machine; set PROBE_MODELS_DIR to override it.
dst_path = os.environ.get(
    "PROBE_MODELS_DIR",
    '/p/openvocabdustr/probing_midlevel_vision/code/probing-mid-level-vision/evals/models',
)

# Change to the destination directory so relative paths used later (e.g. "assets/...") resolve here.
os.chdir(dst_path)

# Make the local modules importable; the membership check keeps re-runs of this cell
# from appending the same directory to sys.path again and again.
if dst_path not in sys.path:
    sys.path.append(dst_path)

# Now you can import your modules as absolute imports
import torch
import torch.nn.functional as F
import torch.nn as nn
from torchvision.transforms import ToTensor, Normalize, Compose
from PIL import Image

# Replace relative imports with absolute imports
from util import load_checkpoint, prepare_state_dict
from utils import center_padding
import torchvision

# Import CroCoNet from the croco_model
from croco_models.croco import CroCoNet

In [60]:
# Registry of downloadable checkpoints, keyed by the `model_name` accepted by CroCoV2.
# Each entry is passed as keyword arguments to `load_checkpoint(url=..., filename=...)`.
# NOTE(review): "CroCo.pth" looks like the original CroCo release rather than a CroCo v2
# checkpoint despite the "croco_v2" key — confirm which weights are intended.
checkpoints = dict(
    croco_v2=dict(
        url="https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo.pth",
        filename="CroCo.pth",
    ),
)


class CroCoV2(nn.Module):
    """Wrap a pretrained CroCo network and expose intermediate decoder features.

    Both images are embedded by the CroCo encoder; the decoder blocks then
    process the (optionally masked) first image while attending to the second.
    The outputs of the decoder blocks listed in ``self.multilayers`` are
    returned by ``forward``.

    Args:
        model_name: key into the module-level ``checkpoints`` dict.
        layer: accepted for interface compatibility; not used (the extracted
            blocks are fixed by ``multilayers`` in ``__init__``).
        output: tag stored in ``self.output`` and used in ``self.checkpoint_name``.
        return_multilayer: if True, extract decoder blocks 0, 2, 4 and 6;
            otherwise only block 6.
        add_norm: if True, each extracted feature map goes through its own
            ``nn.BatchNorm2d`` and is returned channels-last as (B, h, w, C).
    """

    def __init__(
        self,
        model_name="croco_v2",
        layer=-1,
        output="dense",
        return_multilayer=False,
        add_norm=False,  # Add flag to control batch normalization
    ):
        super().__init__()

        # Load the pretrained network (eval mode) within __init__
        self.model = self.load_model(model_name)

        self.output = output
        self.checkpoint_name = f"{model_name}_{output}"
        # NOTE(review): assumes the checkpoint's patch embedding uses 16x16 patches.
        self.patch_size = 16
        self.add_norm = add_norm

        # Channel width of the decoder tokens. decoder_embed projects encoder tokens
        # into the decoder, so its output width is the feature size; fall back to
        # 512 if the layer does not expose `out_features`.
        feat_dim = getattr(self.model.decoder_embed, "out_features", 512)

        # Decoder block indices whose outputs are extracted
        multilayers = [0, 2, 4, 6]

        if return_multilayer:
            self.feat_dim = [feat_dim] * len(multilayers)
            self.multilayers = multilayers
        else:
            self.feat_dim = feat_dim
            self.multilayers = [multilayers[-1]]
        self.layer = "-".join(str(_x) for _x in self.multilayers)

        # One BatchNorm2d per extracted layer (only applied when add_norm is True)
        self.batchnorms = nn.ModuleList(
            [nn.BatchNorm2d(feat_dim) for _ in self.multilayers]
        )

    def load_model(self, model_name: str):
        """Build a CroCoNet from the named checkpoint and return it in eval mode.

        Raises:
            AssertionError: if ``model_name`` is not a key of ``checkpoints``.
        """
        assert model_name in checkpoints.keys(), f"Invalid model: {model_name}"
        ckpt = load_checkpoint(**checkpoints[model_name])

        # Constructor kwargs are stored alongside the weights when available
        model = CroCoNet(**ckpt.get("croco_kwargs", {}))
        model.load_state_dict(ckpt["model"], strict=True)
        return model.eval()

    def forward(self, image1, image2, do_mask=True):
        """Run CroCo on an image pair and collect the selected decoder features.

        Args:
            image1: first image batch, assumed (B, 3, H, W) with H and W
                multiples of ``self.patch_size``.
            image2: second (reference) image batch, same shape as ``image1``.
            do_mask: forwarded to ``_encode_image`` for the first image. The
                default True keeps the original pretraining-style masking; pass
                False to feed every patch of ``image1`` to the decoder.

        Returns:
            list with one tensor per entry of ``self.multilayers``. With
            ``add_norm`` each tensor is (B, H // patch_size, W // patch_size, C);
            otherwise it is the decoder block output as produced (presumably
            (B, N, C) tokens).

        Everything runs under ``torch.inference_mode``, so the returned tensors
        carry no autograd history. The BatchNorm layers follow this wrapper's
        train/eval mode.
        """
        with torch.inference_mode():
            # encoder of the (optionally masked) first image
            feat1, pos1, mask1 = self.model._encode_image(image1, do_mask=do_mask)
            # encoder of the second (reference) image, never masked
            feat2, pos2, _ = self.model._encode_image(image2, do_mask=False)
            # project encoder tokens into the decoder width
            visf1 = self.model.decoder_embed(feat1)
            f2 = self.model.decoder_embed(feat2)

            B, Nenc, C = visf1.size()
            if mask1 is None:  # downstreams
                f1_ = visf1
            else:  # pretraining: scatter visible tokens, fill the rest with mask tokens
                Ntotal = mask1.size(1)
                f1_ = self.model.mask_token.repeat(B, Ntotal, 1).to(dtype=visf1.dtype)
                f1_[~mask1] = visf1.view(B * Nenc, C)
            # add positional embedding
            if self.model.dec_pos_embed is not None:
                f1_ = f1_ + self.model.dec_pos_embed
                f2 = f2 + self.model.dec_pos_embed

            # Patch grid of image1, used to fold the token sequence back into a map
            grid_h = image1.shape[-2] // self.patch_size
            grid_w = image1.shape[-1] // self.patch_size

            outputs = []
            out, out2 = f1_, f2
            for idx, blk in enumerate(self.model.dec_blocks):
                out, out2 = blk(out, out2, pos1, pos2)
                if idx not in self.multilayers:
                    continue
                if self.add_norm:
                    norm = self.batchnorms[self.multilayers.index(idx)]
                    # (B, N, C) -> (B, C, h, w) for BatchNorm2d, then back to channels-last
                    spatial = out.reshape(B, grid_h, grid_w, out.shape[-1]).permute(0, 3, 1, 2)
                    outputs.append(norm(spatial).permute(0, 2, 3, 1))
                else:
                    outputs.append(out)

        return outputs

    def process_images(self, image_paths, device):
        """Load images from disk as ImageNet-normalized (1, 3, H, W) tensors on ``device``."""
        imagenet_mean = [0.485, 0.456, 0.406]
        imagenet_std = [0.229, 0.224, 0.225]
        trfs = Compose([ToTensor(), Normalize(mean=imagenet_mean, std=imagenet_std)])

        images = []
        for path in image_paths:
            # convert() returns an independent copy, so the file can be closed right away
            with Image.open(path) as img:
                rgb = img.convert("RGB")
            images.append(trfs(rgb).to(device).unsqueeze(0))
        return images

    def decode_output(self, output, image1, mask, device):
        """Undo ImageNet normalization and build the masked view of ``image1``.

        Args:
            output: reconstructed image tensor in normalized space
                (assumed (B, 3, H, W) — TODO confirm against the caller).
            image1: the normalized first input image.
            mask: per-patch mask; patches where it is True are blanked out.
            device: device for the mean/std constants.

        Returns:
            (decoded_image, masked_input_image), both de-normalized.
        """
        imagenet_mean_tensor = (
            torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
        )
        imagenet_std_tensor = (
            torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
        )

        decoded_image = output * imagenet_std_tensor + imagenet_mean_tensor
        input_image = image1 * imagenet_std_tensor + imagenet_mean_tensor
        # Expand the per-patch mask to pixel resolution via patchify/unpatchify
        image_masks = self.model.unpatchify(
            self.model.patchify(torch.ones_like(image1)) * mask[:, :, None]
        )
        masked_input_image = (1 - image_masks) * input_image

        return decoded_image, masked_input_image

    def visualize(self, image1, image2, decoded_image, masked_input_image):
        """Tile reference, masked, decoded and original images side by side as a PIL image.

        Images are concatenated along width and batch items are stacked
        vertically; values are clamped to [0, 1].
        """
        visualization = torch.cat(
            (image2, masked_input_image, decoded_image, image1), dim=3
        )
        B, C, H, W = visualization.shape
        visualization = visualization.permute(1, 0, 2, 3).reshape(C, B * H, W)
        return torchvision.transforms.functional.to_pil_image(
            torch.clamp(visualization, 0, 1)
        )

In [61]:
# Prefer the first GPU when one is visible; otherwise fall back to CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
model = CroCoV2(add_norm=True, output="dense").to(device)

In [62]:

# Load the image pair; the paths are relative to the working directory set by
# os.chdir in the setup cell.
image_paths = ["assets/Chateau1.png", "assets/Chateau2.png"]
images = model.process_images(image_paths, device)
image1, image2 = images

# Forward pass: call the module itself rather than .forward() so that any
# registered nn.Module hooks run.
features = model(image1, image2)

In [65]:
# Display the shape of the first extracted feature map.
features[0].size()

torch.Size([1, 14, 14, 512])