In [None]:
from sdhelper import SD
from datasets import load_dataset
from transformers import pipeline
from transformers import SamModel, SamProcessor
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import torch
from semantic_correspondence import expand_and_resize, expand, concat_reprs
from PIL import Image
import json
import random

In [None]:
# Load the models used below: the SD helper (queried through sd.img2repr)
# and the SAM model together with its matching input processor.
SAM_CHECKPOINT = "facebook/sam-vit-base"  # one checkpoint id for weights and preprocessing

sd = SD('SDXL')
sam_processor = SamProcessor.from_pretrained(SAM_CHECKPOINT)
sam_model = SamModel.from_pretrained(SAM_CHECKPOINT)
sam_model = sam_model.to("cuda")

In [None]:
# Load both configurations of the SPair-71k dataset repo:
#   pairs    - default config; its 'category' feature supplies the category names
#   img_data - 'data' config; rows carry 'img', 'annotation' (JSON string) and 'segmentation'
SPAIR_REPO = '0jl/SPair-71k'

pairs = load_dataset(SPAIR_REPO)
img_data = load_dataset(SPAIR_REPO, 'data')

In [None]:
# Sanity-check SAM: prompt it with one annotated keypoint of a training image
# and show the downsampled predicted mask next to the image.
i = 0
sample = img_data['train'][i]  # fetch the row once instead of once per field
raw_image = sample['img']
kps = json.loads(sample['annotation'])['kps']

# Keep keypoint '0' as the prompt when it is usable. Keypoint entries appear
# to be None when not annotated (TODO confirm against the dataset card), so
# fall back to the first annotated one instead of passing None to SAM.
annotated_kps = [kp for kp in kps.values() if kp is not None]
if not annotated_kps:
    raise ValueError(f'sample {i} has no annotated keypoint to prompt SAM with')
prompt_kp = kps['0'] if kps.get('0') is not None else annotated_kps[0]
input_points = [[prompt_kp]]  # one image, one point; the point is plotted below as (x, y)

with torch.no_grad():
    # Follow the model's device rather than repeating a hard-coded device string.
    inputs = sam_processor(raw_image, input_points=input_points, return_tensors="pt").to(sam_model.device)
    outputs = sam_model(**inputs)
    masks = sam_processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    scores = outputs.iou_scores
print(f'IOU score: {scores[0][0].cpu()}')
print(f'Found {len(masks[0])} masks')
print(f'Predicted mask shape: {masks[0][0].shape}')

# Union of all masks predicted for the prompt, stored as an 8-bit image so
# that PIL can resize it.
candidate_masks = masks[0][0].permute(1, 2, 0).detach().cpu().numpy()
mask = Image.fromarray((candidate_masks.sum(axis=2) > 0).astype(np.uint8) * 255)

# Shrink the mask by a fixed factor; the prompt point is scaled by the same
# factor so it lands on the right spot of the smaller mask.
downscale = 8
mask = mask.resize((mask.size[0] // downscale, mask.size[1] // downscale))
mask = np.array(mask) > 0

fig, (ax_mask, ax_img) = plt.subplots(1, 2)
ax_mask.imshow(mask, cmap='gray')
ax_mask.scatter(prompt_kp[0] / downscale, prompt_kp[1] / downscale, c='r')
ax_mask.axis('off')
ax_img.imshow(raw_image)
ax_img.axis('off')
plt.show()

In [None]:
def _rescale_jointly(arrays):
    """Shift and scale all arrays with one shared minimum and maximum into [0, 1].

    Empty arrays are ignored when computing the range and returned unchanged;
    when every value is identical the arrays are only shifted (to zero)
    instead of being divided by zero.
    """
    non_empty = [a for a in arrays if a.size]
    if not non_empty:
        return list(arrays)
    lowest = min(a.min() for a in non_empty)
    shifted = [a - lowest for a in arrays]
    highest = max(s.max() for s in shifted if s.size)
    if highest == 0:
        return shifted
    return [s / highest for s in shifted]


def plot_pca(imgs: list, pos = None, step = 50, size = 512, segmentations = None):
    """Plot images above a 3-component PCA view of their SD representations.

    Row 0 shows the input images. Row 1 shows, per image, the representation
    projected onto three principal components fitted on all pixels of all
    images and displayed as RGB. When ``segmentations`` is given, row 2 shows
    a second PCA that is fitted only on the pixels selected by the masks
    (pixels outside the mask stay black).

    Args:
        imgs: Images to analyse; each is passed to ``expand`` and ``imshow``.
        pos: Positions forwarded to ``sd.img2repr`` and ``concat_reprs``.
            Defaults to ``['up_blocks[1]']``.
        step: Forwarded to ``sd.img2repr``.
        size: Forwarded to ``expand``.
        segmentations: Optional masks, one per image, each offering a
            PIL-style ``resize((width, height))``. Non-zero pixels of the
            resized mask count as foreground. Masks are assumed to be
            single-channel - TODO confirm for the dataset in use.

    Raises:
        ValueError: If ``imgs`` is empty or fewer segmentations than images
            are supplied.
    """
    # Avoid a shared mutable default; the effective default is unchanged.
    if pos is None:
        pos = ['up_blocks[1]']
    if not imgs:
        raise ValueError('plot_pca needs at least one image')
    if segmentations is not None:
        segmentations = list(segmentations)
        if len(segmentations) < len(imgs):
            raise ValueError('plot_pca needs one segmentation per image')

    expanded = [expand(img, size) for img in imgs]
    # The code below treats axis 0 of each map as the feature axis and axes
    # 1 and 2 as the spatial grid (rows, columns).
    feature_maps = [concat_reprs(sd.img2repr(img, pos, step), pos) for img in expanded]
    # Flatten the grid: one row per pixel, one column per feature.
    pixel_feats = [fm.reshape(fm.shape[0], -1).T for fm in feature_maps]

    full_pca = PCA(n_components=3)
    full_pca.fit(torch.cat(pixel_feats, 0).numpy())
    full_views = _rescale_jointly([full_pca.transform(pf.numpy()) for pf in pixel_feats])

    if segmentations is not None:
        # PIL's resize takes (width, height), i.e. (columns, rows) of the map.
        masks = [
            np.array(seg.resize((fm.shape[2], fm.shape[1]))) != 0
            for seg, fm in zip(segmentations, feature_maps)
        ]
        # Boolean indexing picks foreground pixels in row-major order, the
        # same order np.argwhere would list them, yielding (n_pixels, features).
        # An empty mask gives a (0, features) array that still concatenates.
        fg_feats = [fm.numpy()[:, m].T for fm, m in zip(feature_maps, masks)]
        fg_pca = PCA(n_components=3)
        fg_pca.fit(np.concatenate(fg_feats, 0))
        fg_views = _rescale_jointly([
            fg_pca.transform(ff) if len(ff) else np.zeros((0, 3))
            for ff in fg_feats
        ])

    n_rows = 2 + (segmentations is not None)
    # Scale the height with the row count so the optional third row is not
    # squeezed into the space of two (two rows keep the previous height of 10).
    fig, ax = plt.subplots(n_rows, len(imgs), squeeze=False, figsize=(len(imgs)*5, 5*n_rows))
    for i, img in enumerate(imgs):
        ax[0, i].imshow(img)
        ax[1, i].imshow(full_views[i].reshape(*feature_maps[i].shape[1:], 3))
        if segmentations is not None:
            canvas = np.zeros((*masks[i].shape, 3))
            canvas[masks[i]] = fg_views[i]  # same row-major order as the gather above
            ax[2, i].imshow(canvas)
        for axis in ax[:, i]:
            axis.axis('off')

    plt.tight_layout()
    plt.show()


# For each category name, collect the first six matching training images
# (fewer if the split runs out) with their segmentations and plot them.
for category in pairs['train'].features['category'].names:
    imgs, segmentations = [], []
    for sample in img_data['train']:
        annotation = json.loads(sample['annotation'])
        if annotation['category'] != category:
            continue

        imgs.append(sample['img'])

        # Disabled alternative: derive the mask with SAM, prompted by the
        # first keypoint that is not None.
        # with torch.no_grad():
        #     kp = [kp for kp in annotation['kps'].values() if kp is not None][0]
        #     inputs = sam_processor(sample['img'], input_points=[[kp]], return_tensors="pt").to("cuda")
        #     outputs = sam_model(**inputs)
        #     masks = sam_processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
        # mask = Image.fromarray((masks[0][0].permute(1,2,0).detach().cpu().numpy().sum(axis=2)>0).astype(np.uint8)*255)
        # segmentations.append(mask)

        # Active path: use the segmentation stored with the sample.
        segmentations.append(sample['segmentation'])

        # The count only changes when a sample matched, so checking here is enough.
        if len(imgs) == 6:
            break

    plot_pca(
        imgs,
        pos=['up_blocks[0]', 'up_blocks[1]'],
        step=50,
        size=1024,
        segmentations=segmentations,
    )

In [None]:
# Spot-check segmentations: for four randomly drawn training samples, print
# the mask's array shape and show the binarised mask next to the image.
train_split = img_data['train']
for _ in range(4):
    # randrange excludes its upper bound. randint(0, len(...)) includes it and
    # could therefore pick an index one past the last sample (IndexError).
    i = random.randrange(len(train_split))
    x = train_split[i]
    seg = x['segmentation']
    seg_array = np.array(seg)  # convert once; used for both the shape and the plot
    print(seg_array.shape)

    fig, (ax_seg, ax_img) = plt.subplots(1, 2)
    ax_seg.imshow(seg_array > 0)
    ax_seg.axis('off')
    ax_img.imshow(x['img'])
    ax_img.axis('off')
    plt.show()