In [None]:
from sdhelper import SD
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoProcessor, CLIPModel, pipeline
from tqdm.notebook import tqdm, trange
from piqa.ssim import ssim, SSIM
from skimage.metrics import structural_similarity as sk_ssim
from collections import defaultdict

In [None]:
# Instantiate the sdhelper `SD` wrapper for 'SDXL-Turbo' with its progress bar disabled.
# `sd.device` is reused below to place the CLIP and SAM models.
sd = SD('SDXL-Turbo', disable_progress_bar=True)

In [None]:
class ClipEmbed:
    """Callable wrapper that maps images to CLIP image-feature vectors.

    Loads the ``openai/clip-vit-base-patch32`` model together with its
    processor and moves the model to ``device``.
    """

    def __init__(self, device):
        checkpoint = "openai/clip-vit-base-patch32"
        self.model = CLIPModel.from_pretrained(checkpoint)
        self.processor = AutoProcessor.from_pretrained(checkpoint)
        self.model.to(device)

    def __call__(self, images):
        """Return the model's ``get_image_features`` output for ``images``.

        The images are run through the processor, moved to the model's device
        and embedded without tracking gradients.
        """
        with torch.no_grad():
            batch = self.processor(images=images, return_tensors="pt")
            batch = batch.to(self.model.device)
            return self.model.get_image_features(**batch)

# Shared CLIP embedder, created on the same device as `sd`.
clip = ClipEmbed(sd.device)

In [None]:
# Segment Anything (facebook/sam-vit-huge) mask-generation pipeline on the same device as `sd`.
# Used below both for its masks and, via `preprocess`, for its image embeddings.
sam_pipe = pipeline(task="mask-generation", model="facebook/sam-vit-huge", device=sd.device)  # MaskGenerationPipeline

In [None]:
# Number of images to generate; the accumulated similarity maps are divided by this value afterwards.
num_images = 1
# Positions whose representations are extracted; defaults to every position `sd` offers.
extract_positions = sd.available_extract_positions
# extract_positions = ['down_blocks[1]','down_blocks[2]','mid_block','up_blocks[0]','up_blocks[1]']

In [None]:
def make_patches(image, num_patches: int, patch_size: int):
    '''Cut an image into a grid of overlapping, fixed-size patches.

    The image is divided into ``num_patches x num_patches`` tiles of side
    ``image.shape[0] // num_patches`` and, for every tile, a window of
    ``patch_size`` pixels centred on that tile is extracted. Windows that
    would reach outside the image (the ones along the border) are returned as
    all-zero placeholders so that the result stays a regular grid.

    Args:
        image: array-like of shape (H, W, C). The tile size is derived from H
            only, so a square image is assumed.
        num_patches: number of tiles per side; must divide H.
        patch_size: side length of every patch; must be even and larger than
            the tile size.

    Returns:
        torch.Tensor of shape
        ``(num_patches, num_patches, patch_size, patch_size, C)`` with the
        dtype of ``image``.

    Raises:
        AssertionError: if one of the preconditions above is violated.
    '''
    image = np.asarray(image)
    # validate the rank first so the later shape accesses are meaningful
    assert len(image.shape) == 3, f'Expected 3D image, got {len(image.shape)}D'
    assert image.shape[0] % num_patches == 0, f'Image size {image.shape[0]} not divisible by num_patches {num_patches}'
    assert patch_size % 2 == 0, f'Patch size {patch_size} must be even'
    tile_size = image.shape[0] // num_patches
    assert patch_size > tile_size, f'Patch size {patch_size} must be larger than tile size {tile_size}'

    height, width, channels = image.shape
    margin = (patch_size - tile_size) // 2  # how far a patch extends beyond its tile on each side

    # Preallocate the full grid: out-of-bounds slots simply keep their zeros,
    # which guarantees a rectangular result in the image's own dtype.
    grid = np.zeros((num_patches, num_patches, patch_size, patch_size, channels), dtype=image.dtype)
    for i in range(num_patches):
        x_start = i * tile_size - margin
        x_end = x_start + patch_size
        if x_start < 0 or x_end > height:
            continue  # whole row of patches leaves the image
        for j in range(num_patches):
            y_start = j * tile_size - margin
            y_end = y_start + patch_size
            if y_start < 0 or y_end > width:
                continue  # patch leaves the image; keep the zero placeholder
            grid[i, j] = image[x_start:x_end, y_start:y_end, :]

    return torch.from_numpy(grid)

In [None]:
# one zero-filled (n*n, n*n) tensor per position, where n is the last entry of that position's shape
def create_map():
    '''Return a fresh dict mapping every position in `sd.representation_shapes` to a zero accumulator.'''
    accumulators = {}
    for pos, shape in sd.representation_shapes.items():
        # an entry whose first element is not an int is treated as a sequence of shapes; use its first shape
        primary = shape if isinstance(shape[0], int) else shape[0]
        side = primary[-1] ** 2
        accumulators[pos] = torch.zeros([side, side])
    return accumulators


# metric name -> {position -> accumulator}; accumulators for a metric are created on first access
maps = defaultdict(create_map)

# For every generated image and every extracted position, accumulate pairwise
# (n*n, n*n) similarity / distance maps between the n*n rows of the flattened
# representation (assumes the rows form a square n x n grid).
with torch.no_grad():
    for _ in trange(num_images):
        result = sd('a city scene', extract_positions=extract_positions)

        # Per-image inputs; none of them depends on the position, so they are computed once.
        img = np.array(result.result_image)
        # sam_segmentation = sam(result.result_image)
        sam_masks = torch.tensor(sam_pipe(result.result_image)['masks'])
        sam_embedding = next(sam_pipe.preprocess(result.result_image))['image_embeddings']  # has shape [1, 256, 64, 64]

        # NOTE: the image-based metrics below (L1_img, L2_img, Clip, SAM, SAM_Emb) depend only on
        # (img, n), not on `pos`; they are recomputed for every position.
        for pos, rep in tqdm(list(result.representations.items())):  # `rep`, not `repr`, to keep the builtin usable
            rep = rep[-1]  # use last step
            if isinstance(rep, tuple):
                # ignore skip connections
                rep = rep[0]
            # Flatten to (locations, channels); presumably `rep` is (1, C, H, W) - TODO confirm batch size 1.
            # Upcast once so that the cosine similarity is computed in float32 like the L1/L2 distances
            # (half-precision dot products and norm products can overflow to inf and end up as NaN).
            feats = rep.reshape((rep.shape[1], -1)).permute(1, 0).float()
            n = int(feats.shape[0]**0.5)
            patches = make_patches(img, n, 100)
            # TODO: edge patches are not valid and should be ignored

            # calculate cosine similarity
            norms = feats.norm(dim=1, keepdim=True)
            cs = torch.einsum('ij,kj->ik', feats, feats) / (norms @ norms.T)
            maps['CS'][pos] += cs.cpu()

            # calculate l1 distance
            maps['L1'][pos] += torch.cdist(feats, feats, p=1).cpu()

            # calculate l2 distance
            maps['L2'][pos] += torch.cdist(feats, feats, p=2).cpu()

            # calculate l1 distance between images
            # tiled_image = img.reshape(n,img.shape[0]//n,n,img.shape[0]//n,3).transpose(0,2,1,3,4).reshape(n*n,-1)
            # `patches` is already a tensor: move it with `.to` instead of re-wrapping it in torch.tensor
            torch_img_flat = patches.reshape(n*n, -1).to(sd.device).float()
            maps['L1_img'][pos] += torch.cdist(torch_img_flat, torch_img_flat, p=1).cpu()

            # calculate l2 distance between images
            maps['L2_img'][pos] += torch.cdist(torch_img_flat, torch_img_flat, p=2).cpu()

            # calculate clip similarity
            # NOTE(review): this is an unnormalised dot product of CLIP features, not a cosine
            # similarity - confirm that this is intended.
            # torch_img = torch_img_flat.reshape((n*n,*[img.shape[0]//n]*2,3)).permute(0,3,1,2)
            clip_embedding = clip(patches.reshape(n*n, *patches.shape[2:]).permute(0, 3, 1, 2).to(sd.device))
            maps['Clip'][pos] += torch.einsum('ij,kj->ik', clip_embedding, clip_embedding).cpu()

            # calculate ssim
            # try:
            #     torch_img_ = torch_img.to(sd.device, dtype=torch.float32) / 255
            #     ssim = SSIM().to(sd.device)
            #     maps['SSIM'][pos] += torch.stack([ssim(torch_img_, torch_img_[i].expand_as(torch_img_)) for i in range(n*n)], 0).cpu()
            # except:
            #     pass

            # sklearn ssim (sklearn)
            # maps['SSIM2'][pos] += torch.tensor([[sk_ssim(torch_img_.cpu().numpy()[i], torch_img_.cpu().numpy()[j], gaussian_weights=True, channel_axis=0, data_range=1) for j in range(n*n)] for i in trange(n*n)])

            # calculate segment anything mask based similarity
            # average every mask over each image tile, then correlate the per-tile mask coverage
            tile = img.shape[0] // n
            sam_classes = sam_masks.reshape(-1, n, tile, n, tile).float().mean((2, 4)).reshape(-1, n*n)
            maps['SAM'][pos] += torch.einsum('ji,jk->ik', sam_classes, sam_classes).cpu()

            # calculate segment anything embedding based similarity
            # average-pool the SAM embedding down to the n x n grid (requires n to divide its spatial size)
            emb_tile = sam_embedding.shape[-1] // n
            sam_embedding_ = sam_embedding.reshape(256, n, emb_tile, n, emb_tile).mean((2, 4)).reshape(256, n*n)
            maps['SAM_Emb'][pos] += torch.einsum('ji,jk->ik', sam_embedding_, sam_embedding_).cpu()


# zero out NaN entries, then divide the accumulated sums by the number of images
for metric_maps in maps.values():
    for accumulated in metric_maps.values():
        nan_mask = torch.isnan(accumulated)
        accumulated[nan_mask] = 0
        accumulated /= num_images

In [None]:
# print out the pearson correlation between the different metrics

def print_pearson():
    '''Print the Pearson correlation of every ordered pair of metrics at every extract position.'''
    for pos in extract_positions:
        # flatten each metric's map once for this position
        flat = {metric: per_pos[pos].flatten() for metric, per_pos in maps.items()}
        # every ordered pair, including a metric with itself
        for a_name, a in flat.items():
            for b_name, b in flat.items():
                pair = torch.stack((a, b))
                pearson = torch.corrcoef(pair)[0, 1]
                print(f'{pos:<15}: {a_name:^6} vs {b_name:^6}: {pearson.item():+.2f}')

# print_pearson()

In [None]:
# plot the correlation matrices

def plot_corr_matrices():
    '''Show one heatmap per extract position with the Pearson correlations between all metrics.

    Reads the module-level `maps` (metric name -> position -> pairwise map) and
    `extract_positions`. Each metric's map is flattened and treated as one
    variable; the resulting correlation matrix is drawn with a diverging
    colormap fixed to [-1, 1].
    '''
    metric_names = list(maps)
    num_metrics = len(metric_names)
    if num_metrics == 0:
        print('plot_corr_matrices: no metrics available, nothing to plot')
        return

    matrix = torch.zeros((len(extract_positions), num_metrics, num_metrics))
    # enumerate (rather than list.index) so a position listed twice fills its own slot
    for p, pos in enumerate(extract_positions):
        # rows are variables, columns are observations: one corrcoef call yields the full matrix
        stacked = torch.stack([maps[name][pos].flatten() for name in metric_names])
        # reshape covers the single-metric case, where corrcoef returns a scalar
        matrix[p] = torch.corrcoef(stacked).reshape(num_metrics, num_metrics)

    for p, pos in enumerate(extract_positions):
        fig, ax = plt.subplots(1, 1, figsize=(5, 5))
        cax = ax.matshow(matrix[p], cmap='coolwarm', vmin=-1, vmax=1)
        fig.colorbar(cax)
        ax.set_xticks(range(num_metrics))
        ax.set_xticklabels(metric_names)
        ax.set_yticks(range(num_metrics))
        ax.set_yticklabels(metric_names)
        ax.set_title('Correlation at position ' + pos)
        plt.show()

# plot_corr_matrices()

In [None]:
def gaussian_kernel(distance, bandwidth):
    """Evaluate a Gaussian density with standard deviation `bandwidth` at `distance` (a tensor)."""
    normaliser = bandwidth * torch.sqrt(torch.tensor(2 * torch.pi))
    exponent = -0.5 * ((distance / bandwidth) ** 2)
    return (1 / normaliser) * torch.exp(exponent)

def gaussian_binning(x, y, bin_width, bandwidth):
    """
    Kernel-smoothed binning of `y` along `x` using PyTorch.

    Bin edges run from min(x) to max(x) in steps of `bin_width`. For every bin
    centre, all points are weighted with `gaussian_kernel` of their distance
    to that centre, and the weighted mean and standard deviation of `y` are
    computed.

    Parameters:
    - x: torch.Tensor, the x-values of the data points.
    - y: torch.Tensor, the y-values of the data points.
    - bin_width: float, the width of the bins.
    - bandwidth: float, the bandwidth for the Gaussian kernel.

    Returns:
    - bin_centers: torch.Tensor, the centers of the bins (on x's device).
    - weighted_means: torch.Tensor, the weighted means of the y-values.
    - weighted_stds: torch.Tensor, the weighted standard deviations of the y-values.
    """
    # work in float regardless of the input dtype
    x = x.float()
    y = y.float()

    # bin edges and the midpoints between consecutive edges
    edges = torch.arange(torch.min(x), torch.max(x) + bin_width, bin_width, device=x.device)
    bin_centers = 0.5 * (edges[:-1] + edges[1:])

    num_bins = len(bin_centers)
    weighted_means = torch.zeros(num_bins)
    weighted_stds = torch.zeros(num_bins)

    for idx in range(num_bins):
        # weight of every data point with respect to this bin centre
        kernel = gaussian_kernel(torch.abs(x - bin_centers[idx]), bandwidth)
        total = torch.sum(kernel)

        mean = torch.sum(kernel * y) / total
        spread = torch.sum(kernel * (y - mean) ** 2) / total

        weighted_means[idx] = mean
        weighted_stds[idx] = torch.sqrt(spread)

    return bin_centers, weighted_means, weighted_stds

In [None]:
def plot_similarity_comparison(similarities: list | None = None):
    for pos in extract_positions:
        num_items = len(maps) - 1 if similarities is None else len(similarities)
        plt.figure(figsize=(6*num_items, 5))
        i = 0
        for name, map in maps.items():
            if name == 'CS': continue
            if similarities is not None and name not in similarities: continue
            i += 1

            # setup plot
            plt.subplot(1, num_items, i)
            plt.title(name)
            plt.xlabel('Cosine similarity')

            # setup data
            cs = maps['CS'][pos].cuda()
            idx = cs.flatten().argsort()
            cossim_sorted = cs.flatten()[idx]
            y_sorted = map[pos].cuda().flatten()[idx]

            # plot heatmap
            plt.hist2d(cossim_sorted.cpu().numpy(), y_sorted.cpu().numpy(), bins=100, cmap='Greys', norm='log')
            # plt.hist2d(cossim_sorted.cpu().numpy(), y_sorted.cpu().numpy(), bins=100, cmap='Greys')

            # plot lines
            bin_centers, weighted_means, weighted_stds = gaussian_binning(cossim_sorted, y_sorted, 0.01, 0.02)
            bin_centers = bin_centers.cpu().numpy()
            weighted_means = weighted_means.cpu().numpy()
            weighted_stds = weighted_stds.cpu().numpy()
            plt.fill_between(bin_centers, weighted_means - weighted_stds, weighted_means + weighted_stds, alpha=0.5)
            plt.plot(bin_centers, weighted_means, 'k-')

            # set limits
            plt.ylim(y_sorted.min().item(), y_sorted.max().item())
            plt.xlim(-1, 1)

        # show plot
        plt.suptitle('Position ' + pos)
        plt.show()

# Plot every metric against the cosine similarity; pass a list of metric names to restrict the panels.
plot_similarity_comparison()
# plot_similarity_comparison(['Clip', 'SAM', 'SAM_Emb'])