In [None]:
import os
from pathlib import Path

import torch
from PIL import Image
import requests
import matplotlib.pyplot as plt
import numpy as np
from einops import rearrange
from transformers import AutoImageProcessor, AutoModel
from ncut_pytorch import ncut_fn, tsne_color

In [None]:
def seed_everything(seed: int = 42):
    """Seed every random number generator this notebook may touch.

    Seeds Python's built-in ``random`` module, NumPy's global RNG and
    PyTorch (``torch.manual_seed`` plus ``torch.cuda.manual_seed_all``), and
    switches cuDNN to deterministic, non-benchmarking mode.

    Args:
        seed: Value used to seed every generator.
    """
    # Local import: the notebook's import cell does not bring in `random`.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning for run-to-run repeatability.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_everything(123)

In [None]:
# Use the first GPU in half precision when CUDA is present; otherwise fall
# back to full precision on the CPU.
if torch.cuda.is_available():
    DEVICE, DTYPE = torch.device("cuda:0"), torch.float16
else:
    DEVICE, DTYPE = torch.device("cpu"), torch.float32

# Side length (in pixels) written into the processor's crop size below.
IMAGE_RESOLUTION = 224

# Preprocessing pipeline and backbone for the "facebook/dino-vitb16" checkpoint.
processor = AutoImageProcessor.from_pretrained("facebook/dino-vitb16")
processor.crop_size = dict(height=IMAGE_RESOLUTION, width=IMAGE_RESOLUTION)
model = AutoModel.from_pretrained("facebook/dino-vitb16")
model = model.to(DEVICE).to(DTYPE)

In [None]:
import h5py

data_dir = Path("hw3 solo datasets")

# Map each dataset file to a short key: the third underscore-separated token
# of its file name (the image file is looked up under "img" below).
data_paths = {}
for fname in sorted(os.listdir(data_dir)):  # sorted -> deterministic if keys collide
    name_parts = fname.split("_")
    if len(name_parts) < 3:
        # Stray entries such as ".DS_Store" or ".ipynb_checkpoints" do not
        # follow the naming scheme; skip them instead of raising IndexError.
        continue
    data_paths[name_parts[2]] = f"{data_dir}/{fname}"

if "img" not in data_paths:
    raise FileNotFoundError(
        f"No '*_*_img_*' file found in {data_dir!s}; available keys: {sorted(data_paths)}"
    )

# Read the whole "data" dataset into memory as a float tensor.
# NOTE(review): later cells treat each entry as a (C, H, W) image — confirm.
with h5py.File(data_paths["img"], "r") as f:
    images_raw = torch.tensor(np.array(f["data"])).to(torch.float)

In [None]:
def crop_to_content(img: torch.Tensor) -> torch.Tensor:
    """Crop a channel-first image to the bounding box of its non-zero pixels.

    A pixel counts as content when the sum of absolute values across the
    first (channel) dimension is greater than zero.

    Args:
        img: Image tensor indexed as (channels, rows, columns).

    Returns:
        The slice of ``img`` covering the smallest row/column range that
        contains every non-zero pixel. An image that is entirely zero has no
        such box and is returned unchanged.
    """
    content_mask = img.abs().sum(dim=0) > 0
    row_idx = torch.where(content_mask.any(dim=1))[0]
    col_idx = torch.where(content_mask.any(dim=0))[0]

    if row_idx.numel() == 0:
        # Nothing but padding: there is no first/last row to index.
        return img

    top, bottom = row_idx[0].item(), row_idx[-1].item()
    left, right = col_idx[0].item(), col_idx[-1].item()
    # Bounds are inclusive, hence the +1 on the slice ends.
    return img[:, top:bottom + 1, left:right + 1]


# One cropped image per entry of `images_raw`, in the same order, so dataset
# indices stay valid for later selection.
non_padded_images = [crop_to_content(img) for img in images_raw]

In [None]:
# Hand-picked positions in `non_padded_images`; only the first 20 are kept,
# which is exactly enough to fill the 4x5 preview grid below.
selected_indices = [0, 3, 4, 6, 24, 26, 27, 29, 37, 44, 51, 52, 53, 66, 71, 72, 73, 76, 85, 92, 93, 95, 97]
images_cat = [non_padded_images[i] for i in selected_indices[:20]]

fig, axs = plt.subplots(4, 5, figsize=(15, 15))
# Walk the grid row by row, pairing each axis with one selected image.
for ax, img in zip(axs.flat, images_cat):
    # Channel-first tensor -> channel-last uint8 array for imshow.
    rgb = img.permute(1, 2, 0).numpy().astype(np.uint8)
    ax.imshow(rgb)
    ax.axis('off')
plt.tight_layout()
plt.show()

In [None]:
# Channel-last uint8 arrays, the form handed to the image processor.
images_uint8 = [img.permute(1, 2, 0).numpy().astype(np.uint8) for img in images_cat]

# Preprocess the whole list into one batch, then move it to the model's
# device and dtype.
batch_inputs = processor(images=images_uint8, return_tensors="pt", do_rescale=True)
batch_inputs = batch_inputs.to(DEVICE).to(DTYPE)
print(f"Batch shape: {batch_inputs['pixel_values'].shape}")

# Float32 copies scaled into [0, 1]; these are what later cells plot.
images = [arr.astype(np.float32) / 255.0 for arr in images_uint8]


In [None]:
# Single forward pass with autograd disabled, asking the model to return
# the hidden state of every layer as well as the final one.
with torch.no_grad():
    outputs = model(
        **batch_inputs,
        output_hidden_states=True,
        interpolate_pos_encoding=True,
    )

# Final-layer token features; the first token is treated as CLS downstream.
last_hidden_state = outputs.last_hidden_state
# Tuple of per-layer token features.
# NOTE(review): Hugging Face models usually put the embedding output first,
# so the count printed below may be "layers + 1" — confirm.
hidden_states = outputs.hidden_states

print(f"Shape of last hidden state: {last_hidden_state.shape}")
print(f"Number of hidden layers: {len(hidden_states)}")

In [None]:
# Colour the ViT patch tokens of the last five hidden states with NCut + t-SNE
# and show them underneath the corresponding input images.
for target_layer in [-5, -4, -3, -2, -1]:
    # Keep only patch tokens: the first token (CLS) is dropped.
    patch_features = hidden_states[target_layer][:, 1:, :]

    # NCut is run once on the patches of all images together, then the
    # resulting eigenvectors are mapped to RGB colours.
    b, n, _ = patch_features.shape
    patch_features = rearrange(patch_features, 'b n d -> (b n) d')
    eigvecs, _ = ncut_fn(patch_features, n_eig=100, d_gamma=0.05)
    colors_rgb = tsne_color(eigvecs)

    # The patch grid is assumed to be square; fail early with a clear message
    # instead of an opaque reshape error if it is not.
    h = w = int(np.sqrt(n))
    if h * w != n:
        raise ValueError(f"Expected a square patch grid, got {n} patches per image")
    print(f"Patch grid size: {h}x{w}; Layer: {target_layer}")

    # Fold the flat per-patch colours back into one (h, w, c) grid per image.
    color_grid = rearrange(colors_rgb.cpu(), '(b h w) c -> b h w c', b=b, h=h, w=w)

    # Top row: input images. Bottom row: NCut colouring.
    num_images_to_show = min(8, b)
    # squeeze=False keeps `axs` two-dimensional even when a single image is
    # shown, so the axs[row, col] indexing below always works.
    fig, axs = plt.subplots(2, num_images_to_show, figsize=(15, 10), squeeze=False)
    for i in range(num_images_to_show):
        axs[0, i].imshow(images[i])
        axs[0, i].set_title(f'Original Image {i+1}')
        axs[0, i].axis('off')

        axs[1, i].imshow(color_grid[i].numpy())
        axs[1, i].set_title(f'NCut Visualization {i+1}')
        axs[1, i].axis('off')
    plt.tight_layout()
    plt.show()
    # Release the figure so figures do not pile up across loop iterations.
    plt.close(fig)

In [None]:
import sys
# Make the project's `src` directory importable; guard so re-running the cell
# does not keep appending the same entry.
if './src' not in sys.path:
    sys.path.append('./src')
from model import SOLO

################################################################################################################
# TODO: Load the model checkpoint you want to visualize here
################################################################################################################
# Use the notebook-wide DEVICE (rather than a hard-coded .cuda()) so this cell
# also runs on CPU-only machines and matches where the inputs are sent below.
solo_model: SOLO = SOLO.load_from_checkpoint("checkpoints/best_resnet_epoch=35.ckpt", strict=False).to(DEVICE)
################################################################################################################
# Feature extraction should run in inference mode: eval() stops layers such
# as BatchNorm/Dropout (if the backbone has any) from using or updating
# batch-dependent statistics.
solo_model.eval()

# Per-channel normalisation constants.
# NOTE(review): presumably the statistics the backbone was trained with —
# confirm against the training transforms.
mean = torch.tensor([0.485, 0.456, 0.406])
std = torch.tensor([0.229, 0.224, 0.225])

# `images` are channel-last float arrays in [0, 1]; normalise over the last
# (channel) axis, then switch to channel-first (C, H, W).
normalized_images = [
    ((torch.tensor(img).to(torch.float32) - mean) / std).permute(2, 0, 1)
    for img in images
]

# Run the backbone one image at a time (batch of 1), since the cropped images
# are not resized to a common shape here.
backbone_outputs = []
with torch.no_grad():
    for image_chw in normalized_images:
        backbone_output = solo_model.backbone(image_chw.unsqueeze(0).to(DEVICE))
        backbone_outputs.append(backbone_output)


In [None]:
# For every feature level returned by the backbone, colour the spatial
# features of all images with NCut + t-SNE and show the first few images.
for target_layer in backbone_outputs[0].keys():
    print(f"Backbone layer: {target_layer}, shape: {backbone_outputs[0][target_layer].shape}")

    # Flatten each image's (1, d, h, w) feature map into (h*w, d) rows.
    # Spatial sizes are read per image below, so the rows are concatenated
    # rather than stacked.
    patch_features_all = torch.cat(
        [
            rearrange(backbone_output[target_layer], 'b d h w -> (b h w) d')
            for backbone_output in backbone_outputs
        ],
        dim=0,
    )

    # NCut is run once on the features of all images together.
    b = len(backbone_outputs)
    eigvecs, _ = ncut_fn(patch_features_all, n_eig=100, d_gamma=0.05)
    colors_rgb = tsne_color(eigvecs)
    num_images_to_show = min(8, b)

    # squeeze=False keeps `axs` two-dimensional even when a single image is
    # shown, so the axs[row, col] indexing below always works.
    fig, axs = plt.subplots(2, num_images_to_show, figsize=(15, 10), squeeze=False)
    # Running offset of the current image's rows inside `colors_rgb`.
    idx = 0
    for i in range(num_images_to_show):
        h, w = backbone_outputs[i][target_layer].shape[2], backbone_outputs[i][target_layer].shape[3]
        colors_rgb_i = colors_rgb[idx:idx + h * w]
        color_grid = rearrange(colors_rgb_i.cpu(), '(h w) c -> h w c', h=h, w=w)
        idx += h * w

        # Top row: input image.
        axs[0, i].imshow(images[i])
        axs[0, i].set_title(f'Original Image {i+1}')
        axs[0, i].axis('off')

        # Bottom row: NCut colouring at this feature level's resolution.
        axs[1, i].imshow(color_grid.numpy())
        axs[1, i].set_title(f'NCut Visualization {i+1}')
        axs[1, i].axis('off')
    plt.tight_layout()
    plt.show()
    # Release the figure so figures do not pile up across loop iterations.
    plt.close(fig)