## Setup

In [None]:
from sd3helper import SD3
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageOps
from datasets import load_dataset
import json

In [None]:
# Load the model: instantiate the SD3 helper with its default arguments.
# `sd` is used by the `get_repr` helpers further down.
sd = SD3()

In [None]:
# Load data: the 'train' split of the 'data' configuration of 0jl/SPair-71k.
# The cells below read the 'img', 'segmentation' and (JSON-encoded) 'annotation'
# fields of each sample.
data = load_dataset('0jl/SPair-71k', 'data', split='train')

In [None]:
# Select images: choose the object category to browse (toggle by (un)commenting).
# category = 'cat'
category = 'dog'
# Start one before the first sample: the search cell increments the index
# before it reads a sample, so its first run begins at index 0.
image_index = -1


In [None]:
# Cell to find good candidates.
# Execute this cell multiple times: each run advances `image_index` to the next
# sample whose annotation category equals `category`, so you can pick images
# with good segmentations.

image_index += 1
while image_index < len(data):
    x = data[image_index]
    cat = json.loads(x['annotation'])['category']
    if cat == category:
        break
    image_index += 1
else:
    # Ran past the last sample without a match: fail with a clear message
    # instead of an opaque out-of-range error from the dataset.
    raise IndexError(f'no more {category!r} images in the dataset')

# show image and segmentation side by side (segmentation binarised: > 0 is foreground)
plt.subplot(1, 2, 1)
plt.imshow(x['img'])
plt.axis('off')
plt.subplot(1, 2, 2)
plt.imshow(np.array(x['segmentation']) > 0, cmap='gray')
plt.axis('off')
plt.show()
print('category:', cat)
print('image index:', image_index)

In [None]:
# Hand-picked dataset indices per category (presumably chosen with the search
# cell above). The first index of the active list is the color source image
# used by the plotting cells.
image_indices_cat = [170, 177, 243, 257, 267, 283]
image_indices_dog = [390, 51, 120, 131, 142, 353, 386]
image_indices = image_indices_cat  # change this (keep it consistent with `center`)


In [None]:
# Preview the first selected image; the trailing expression makes the notebook
# echo its `.size` (the plotting code treats it as (width, height)).
plt.imshow(data[image_indices[0]]['img'])
data[image_indices[0]]['img'].size


In [None]:
# Reference point (x, y), in pixel coordinates of the first selected image,
# around which the circular color map is laid out by the plotting functions.
center_cat = (280,240)
center_dog = (280,200)
center = center_cat  # change this (keep it consistent with `image_indices`)


In [None]:
# Echo the size of the first selected image again (presumably as a reference
# when picking `center`).
data[image_indices[0]]['img'].size

In [None]:
# Color Transfer Plot
def get_repr(img: "Image.Image", *, step=950, prompt='A photo of a cat', max_side=1024):
    """Return the model representation of ``img``, averaged over its first axis.

    The image is rescaled with a single factor for both axes (aspect ratio
    kept) so that its longer side is about ``max_side`` pixels, then passed to
    ``sd.get_repr``.

    Args:
        img: PIL image to encode.
        step: forwarded to ``sd.get_repr`` as ``step``.
        prompt: forwarded to ``sd.get_repr`` as ``prompt``. The default mentions
            a cat; pass a matching prompt when working with another category.
        max_side: target length in pixels of the longer image side.

    Returns:
        ``sd.get_repr(...).mean(dim=0)``. ``plot_color_transfer`` indexes the
        result as (height, width, channels).
    """
    factor = max_side / max(img.size)
    # int() truncates, so a side may end up one pixel short of the exact scale.
    new_size = (int(img.size[0]*factor), int(img.size[1]*factor))
    reprs = sd.get_repr(img.resize(new_size), step=step, prompt=prompt)
    return reprs.mean(dim=0)

# transfer color to other images using cosine similarity
def plot_color_transfer(images, segmentations):
    """Paint the first image's segmented region and transfer the colors to the other images.

    The foreground of ``images[0]`` is colored with two synthetic color maps.
    For every foreground representation cell of each remaining image, the most
    similar cell of the first image (cosine similarity of the normalised
    representations) is looked up, and that cell's colors and image patch are
    copied over.

    The figure has one column per image and four rows: the image, the circular
    color map, the high-frequency color map, and the patch transfer (the first
    column shows the source image again in the last row).

    Reads the globals ``center`` ((x, y) in pixel coordinates of ``images[0]``)
    and ``get_repr`` (expected to return a (height, width, channels) tensor).

    Args:
        images: sequence of PIL images; ``images[0]`` is the color source.
        segmentations: segmentation images aligned with ``images``; values > 0
            count as foreground.
    """
    source_img = images[0]
    num_cols = len(images)

    # color source image
    base_repr = get_repr(source_img).to(dtype=torch.float32)
    base_repr = base_repr / torch.norm(base_repr, dim=2, keepdim=True)  # normalize now to avoid overflow in cosine similarity
    h_base, w_base = base_repr.shape[:2]  # get repr. height and width
    scaled_mask = np.array(segmentations[0].resize((w_base, h_base), Image.BILINEAR)) > 0
    rows, cols = np.nonzero(scaled_mask)

    # circular color map: the color depends on the angle around `center`,
    # which is rescaled from image pixels to representation cells
    center_row = center[1]*h_base/source_img.size[1]
    center_col = center[0]*w_base/source_img.size[0]
    angle = np.arctan2(rows - center_row, cols - center_col)
    color_matrix_circular = np.zeros((h_base, w_base, 3))
    color_matrix_circular[rows, cols] = np.stack([
        .5+.5*np.sin(angle+0),
        .5+.5*np.sin(angle+2*np.pi/3),
        .5+.5*np.sin(angle+4*np.pi/3),
    ], axis=-1)

    # high frequency color map: sinusoidal stripes along rows, columns and the diagonal
    color_matrix_highfreq = np.zeros((h_base, w_base, 3))
    color_matrix_highfreq[rows, cols] = np.stack([
        .5+.5*np.sin(2*np.pi*rows/h_base*10),
        .5+.5*np.sin(2*np.pi*cols/w_base*10),
        .5+.5*np.sin(2*np.pi*(rows+cols)/(h_base+w_base)*5),
    ], axis=-1)

    # loop invariants for the per-cell lookups below
    flat_circular = color_matrix_circular.reshape((-1, 3))
    flat_highfreq = color_matrix_highfreq.reshape((-1, 3))
    source_pixels = np.array(source_img)
    base_cell_w = source_img.size[0]/w_base  # source image pixels per repr. column
    base_cell_h = source_img.size[1]/h_base  # source image pixels per repr. row

    # plot base image
    plt.figure(figsize=(4*num_cols, 12))
    base_panels = (source_img, color_matrix_circular, color_matrix_highfreq, source_img)
    for row, panel in enumerate(base_panels):
        plt.subplot(4, num_cols, row*num_cols+1)
        plt.imshow(panel)
        plt.axis('off')

    # plot other images
    for plt_index, (img, seg) in enumerate(zip(images[1:], segmentations[1:]), 2):
        curr_repr = get_repr(img).to(dtype=torch.float32)
        curr_repr = curr_repr / torch.norm(curr_repr, dim=2, keepdim=True)  # normalize now to avoid overflow in cosine similarity
        h_curr, w_curr = curr_repr.shape[:2]  # get repr. height and width
        mask = np.array(seg.resize((w_curr, h_curr), Image.BILINEAR)) > 0
        curr_color_matrix_circular = np.zeros((h_curr, w_curr, 3))
        curr_color_matrix_highfreq = np.zeros((h_curr, w_curr, 3))
        img_transferred = np.array(img)
        # img_transferred = np.zeros_like(np.array(img))  # comment out for hiding the original image
        curr_cell_w = img.size[0]/w_curr  # current image pixels per repr. column
        curr_cell_h = img.size[1]/h_curr  # current image pixels per repr. row
        for i, j in np.argwhere(mask):
            # find most similar point in base_repr using cosine similarity
            similarity_matrix = (curr_repr[i, j, None, None, :] * base_repr).sum(dim=2)
            # similarity_matrix *= torch.tensor(scaled_mask).to(similarity_matrix.device)  # comment out to use all pixels (ignore mask)
            best = int(torch.argmax(similarity_matrix))  # flat index into the (h_base, w_base) grid
            curr_color_matrix_circular[i, j] = flat_circular[best]
            curr_color_matrix_highfreq[i, j] = flat_highfreq[best]

            # image patch of the source covered by the best matching cell
            best_row, best_col = divmod(best, w_base)
            img_patch = source_pixels[
                int(best_row*base_cell_h):int((best_row+1)*base_cell_h),
                int(best_col*base_cell_w):int((best_col+1)*base_cell_w),
            ]

            # target area in the current image: rows scale with the cell
            # height, columns with the cell width
            top, bottom = int(i*curr_cell_h), int((i+1)*curr_cell_h)
            left, right = int(j*curr_cell_w), int((j+1)*curr_cell_w)
            try:
                resized_patch = Image.fromarray(img_patch).resize((right-left, bottom-top), Image.BILINEAR)
                img_transferred[top:bottom, left:right] = np.array(resized_patch)
            except ValueError:
                # best effort: skip patches that cannot be resized or pasted
                # (e.g. empty patch or mismatching shapes)
                continue

        current_panels = (img, curr_color_matrix_circular, curr_color_matrix_highfreq, img_transferred)
        for row, panel in enumerate(current_panels):
            plt.subplot(4, num_cols, row*num_cols+plt_index)
            plt.imshow(panel)
            plt.axis('off')
    plt.tight_layout()
    plt.show()

# Full run over all selected images (disabled):
# plot_color_transfer([data[i]['img'] for i in image_indices], [data[i]['segmentation'] for i in image_indices])

# Sanity check on a single sample: transfer from the first selected image to
# itself and to its horizontally mirrored copy.
x = data[image_indices[0]]
plot_color_transfer(
    [x['img'], x['img'], ImageOps.mirror(x['img'])],
    [x['segmentation'], x['segmentation'], ImageOps.mirror(x['segmentation'])],
)

In [None]:
def get_repr(img: "Image.Image", pos=None):
    """Return the model representation of ``img`` restricted to ``pos``, channels first.

    The image is rescaled with a single factor for both axes so that its
    longer side is about 1024 pixels, then passed to ``sd.get_repr``. The
    entries selected by ``pos`` are averaged over the first axis and the
    result is permuted with ``(2, 0, 1)``.

    NOTE: this redefines ``get_repr`` from the earlier cell with a different
    signature and axis order (that version has no ``permute``); re-run the
    earlier cell before calling ``plot_color_transfer`` again.

    Args:
        img: PIL image to encode.
        pos: index applied to the first axis of ``sd.get_repr``'s result
            (presumably a list of positions/layers - confirm against
            ``sd3helper``). Defaults to ``[0]``.

    Returns:
        The averaged representation; ``plot_color_transfer_over_pos`` indexes
        it as (channels, height, width).
    """
    if pos is None:
        # Must stay a list: a list index keeps the first axis for the mean
        # below, whereas a tuple such as (0,) would drop it.
        pos = [0]
    factor = 1024 / max(img.size)
    img = img.resize((int(img.size[0]*factor), int(img.size[1]*factor)))
    return sd.get_repr(img, step=950, prompt='A photo of a cat')[pos].mean(dim=0).permute(2, 0, 1)

# transfer color to other images using cosine similarity
def plot_color_transfer_over_pos(images, segmentations, pos=None):
    """Plot the circular color transfer once per entry of ``pos``.

    The first figure row shows the raw images. Each following row corresponds
    to one entry ``p`` of ``pos``: the first column shows the circular color
    map painted on the foreground of ``images[0]``; the other columns show
    the colors transferred to the remaining images. Each foreground cell
    takes the color of the most similar foreground cell of ``images[0]``
    (cosine similarity of the normalised representations from
    ``get_repr(img, p)``).

    Reads the globals ``center`` ((x, y) in pixel coordinates of ``images[0]``)
    and ``get_repr`` (expected to return a (channels, height, width) tensor).

    Args:
        images: sequence of PIL images; ``images[0]`` is the color source.
        segmentations: segmentation images aligned with ``images``; values > 0
            count as foreground.
        pos: sequence of ``pos`` arguments for ``get_repr``, one figure row
            each. Defaults to ``[[0], [1], ..., [23]]``.
    """
    if pos is None:
        pos = [[x] for x in range(24)]
    num_rows = len(pos) + 1
    # size the grid by the images actually passed in, not by a notebook global
    num_cols = len(images)

    # create figure for plotting
    plt.figure(figsize=(3*num_cols, num_rows*2))

    # plot raw images
    for col, raw_img in enumerate(images, 1):
        plt.subplot(num_rows, num_cols, col)
        plt.imshow(raw_img)
        plt.axis('off')
        if col == 1:
            plt.text(-0.1, 0.5, 'Original Image', va='center', ha='right', fontsize=12, transform=plt.gca().transAxes)

    # plot other repr
    for pos_index, p in enumerate(pos, 1):
        # color source image
        base_repr = get_repr(images[0], p).to(dtype=torch.float32)
        base_repr = base_repr / torch.norm(base_repr, dim=0)
        h_base, w_base = base_repr.shape[1], base_repr.shape[2]
        scaled_mask = np.array(segmentations[0].resize((w_base, h_base), Image.BILINEAR)) > 0

        # circular color map: the color depends on the angle around `center`,
        # which is rescaled from image pixels to representation cells
        rows, cols = np.nonzero(scaled_mask)
        center_row = center[1]*h_base/images[0].size[1]
        center_col = center[0]*w_base/images[0].size[0]
        angle = np.arctan2(rows - center_row, cols - center_col)
        color_matrix = np.zeros((h_base, w_base, 3))
        color_matrix[rows, cols] = np.stack([
            .5+.5*np.sin(angle+0),
            .5+.5*np.sin(angle+2*np.pi/3),
            .5+.5*np.sin(angle+4*np.pi/3),
        ], axis=-1)

        # loop invariants for the per-cell lookups below
        flat_colors = color_matrix.reshape((-1, 3))
        background = ~torch.tensor(scaled_mask).to(base_repr.device)

        # plot base repr
        plt.subplot(num_rows, num_cols, pos_index*num_cols+1)
        plt.imshow(color_matrix)
        plt.axis('off')
        plt.text(-0.1, 0.5, str(p), va='center', ha='right', fontsize=12, transform=plt.gca().transAxes)

        for plt_index, (img, seg) in enumerate(zip(images[1:], segmentations[1:]), 2):
            curr_repr = get_repr(img, pos=p).to(dtype=torch.float32)
            curr_repr = curr_repr / torch.norm(curr_repr, dim=0)
            h_curr, w_curr = curr_repr.shape[1], curr_repr.shape[2]
            mask = np.array(seg.resize((w_curr, h_curr), Image.BILINEAR)) > 0
            curr_color_matrix = np.zeros((h_curr, w_curr, 3))
            for i, j in np.argwhere(mask):
                # find most similar point in base_repr using cosine similarity
                similarity_matrix = (curr_repr[:, i, j, None, None] * base_repr).sum(dim=0)
                # Restrict the match to the source foreground. Filling with
                # -inf (rather than multiplying by a 0/1 mask) keeps background
                # cells from winning when every foreground similarity is negative.
                # Remove this line to use all pixels (ignore mask).
                similarity_matrix = similarity_matrix.masked_fill(background, float('-inf'))
                best = int(torch.argmax(similarity_matrix))  # flat index into the (h_base, w_base) grid
                curr_color_matrix[i, j] = flat_colors[best]
            plt.subplot(num_rows, num_cols, pos_index*num_cols+plt_index)
            plt.imshow(curr_color_matrix)
            plt.axis('off')
    plt.tight_layout()
    plt.show()

# Full run over all selected images (disabled):
# plot_color_transfer_over_pos([data[i]['img'] for i in image_indices], [data[i]['segmentation'] for i in image_indices])

# Reuses the sample `x` loaded in the previous cell; uncomment to reload it here.
# x = data[image_indices[0]]
# Sanity check: transfer from the sample to its horizontally mirrored copy.
plot_color_transfer_over_pos(
    [x['img'], ImageOps.mirror(x['img'])],
    [x['segmentation'], ImageOps.mirror(x['segmentation'])],
)