In [1]:
import numpy as np
import torch
from sklearn.decomposition import PCA

In [2]:
from PIL import Image
from pathlib import Path
from tqdm import tqdm

In [3]:
import clip

## Functional Version

In [8]:
# functional version of the original implementation further down in this notebook
def create_mosaic(all_img_paths, dim, pre_crop, img_size, out_name):
    """Arrange images into a square mosaic ordered by CLIP-embedding similarity.

    Every image is embedded with CLIP ViT-B/32 and the embeddings are reduced
    to two principal components.  The images are then laid out on a
    ``dim`` x ``dim`` grid: the first component selects the column and the
    second component orders the tiles from top to bottom inside that column.
    Each tile is resized to ``pre_crop`` x ``pre_crop`` and centre-cropped to
    ``img_size`` x ``img_size``.  For images that are not RGBA, pure-white
    pixels are made fully transparent.

    Args:
        all_img_paths: Sized, indexable collection of image paths.  Only the
            first ``dim * dim`` images (in first-component order) are placed;
            with fewer images the remaining cells stay transparent.
        dim: Number of tiles along each side of the mosaic.
        pre_crop: Side length in pixels every image is resized to before cropping.
        img_size: Side length in pixels of a finished tile.
        out_name: Destination the RGBA mosaic is saved to.

    Raises:
        ValueError: If fewer than two image paths are supplied, because a
            two-component PCA cannot be fitted on fewer than two samples.
    """
    if len(all_img_paths) < 2:
        raise ValueError("create_mosaic needs at least two images to fit a 2-component PCA")

    # clip.load puts the model on CUDA by default when it is available, so the
    # device is chosen explicitly and every input batch is moved onto it.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load('ViT-B/32', device=device)
    # 512 is the embedding width this function allocates for ViT-B/32 features.
    embeddings = torch.zeros(len(all_img_paths), 512)
    print("Generating clip embeddings")
    with torch.no_grad():
        for idx, p in tqdm(list(enumerate(all_img_paths))):
            # preprocess fully consumes the image, so the file can be closed right after.
            with Image.open(p) as src:
                batch = preprocess(src).unsqueeze(0).to(device)
            # Bring the result back to a float32 CPU row regardless of model device/precision.
            embeddings[idx] = model.encode_image(batch).squeeze(0).float().cpu()

    X = embeddings.numpy()
    pca = PCA(n_components=2)
    pca_vals = pca.fit_transform(X)
    print(f"Explained variance by first 2 principal components: {pca.explained_variance_ratio_.sum()}")

    print("Sorting grid")
    # Order all images by the first component, cut that order into `dim`
    # consecutive chunks, then order every chunk by the second component.
    by_first = sorted(range(len(pca_vals)), key=lambda i: pca_vals[i][0])
    grid_sorted = [
        sorted(by_first[chunk * dim:(chunk + 1) * dim], key=lambda i: pca_vals[i][1])
        for chunk in range(dim)
    ]

    # Integer crop box, so the cropped tile is always exactly img_size wide
    # even when pre_crop - img_size is odd.
    crop_off = (pre_crop - img_size) // 2
    crop_box = (crop_off, crop_off, crop_off + img_size, crop_off + img_size)
    total_dim = dim * img_size
    main_img = np.zeros((total_dim, total_dim, 4), dtype=np.uint8)
    print("Creating Tiles")
    for col_idx, column in tqdm(list(enumerate(grid_sorted))):
        for row_idx, img_idx in enumerate(column):
            tile = _load_rgba_tile(all_img_paths[img_idx], pre_crop)
            main_img[
                row_idx * img_size:(row_idx + 1) * img_size,
                col_idx * img_size:(col_idx + 1) * img_size] = np.asarray(tile.crop(crop_box))
    Image.fromarray(main_img).save(out_name)


def _load_rgba_tile(img_path, size):
    """Open an image, resize it to ``size`` x ``size`` and return it as an RGBA PIL image."""
    with Image.open(img_path) as src:
        has_alpha = src.mode == "RGBA"
        # Greyscale/palette/etc. inputs are normalised to RGB so the
        # white-to-transparent step below always sees three colour channels.
        resized = (src if has_alpha else src.convert("RGB")).resize((size, size))
    if has_alpha:
        return resized
    rgb = np.asarray(resized)
    # Pixels whose R, G and B are all 255 become fully transparent black;
    # every other pixel is kept and made fully opaque.
    # technique from https://github.com/PWhiddy/PokemonRedExperiments/blob/master/MapWalkingVis.ipynb
    is_white = (rgb == 255).all(axis=2, keepdims=True)
    opaque = np.full((size, size, 1), 255, dtype=np.uint8)
    rgba = np.where(is_white, np.uint8(0), np.concatenate([rgb, opaque], axis=2))
    return Image.fromarray(rgba.astype(np.uint8))

In [None]:
# Smoke-test create_mosaic: tile the PNG thumbnails into a 25 x 25 grid.
test_path = Path("shader-park-core.appspot.com/sculptureThumbnails")
all_test_paths = [*test_path.glob("*.png")]
create_mosaic(
    all_img_paths=all_test_paths,
    dim=25,
    pre_crop=512,
    img_size=256,
    out_name="test_out.png",
)

## Below is the original implementation

In [5]:
# Load the CLIP ViT-B/32 model together with its companion `preprocess`
# callable, which is applied to PIL images before they are encoded.
model, preprocess = clip.load("ViT-B/32")

100%|███████████████████████████████████████| 338M/338M [00:14<00:00, 24.8MiB/s]


In [48]:
# List the contents of the local shader-park-core.appspot.com directory.
!ls shader-park-core.appspot.com/

[34msculptureThumbnails[m[m


In [16]:
# Gather every thumbnail in the directory: JPEG files first, then PNG files.
path = Path("shader-park-core.appspot.com/sculptureThumbnails")
all_paths = [*path.glob("*.jpeg"), *path.glob("*.png")]

In [18]:
# One 512-wide row per image; the rows are filled in by the encoding loop.
embeddings = torch.zeros((len(all_paths), 512))

In [21]:
# Encode every image with CLIP.  Inputs are moved to the device the model
# lives on, because clip.load places the model on CUDA by default when it is
# available while `preprocess` returns CPU tensors.
device = next(model.parameters()).device
with torch.no_grad():
    # total= gives tqdm a length, so it can draw a real progress bar.
    for idx, p in tqdm(enumerate(all_paths), total=len(all_paths)):
        # preprocess fully consumes the image, so the file is closed right after.
        with Image.open(p) as src:
            img = preprocess(src).unsqueeze(0).to(device)
        # Store the result as a float32 CPU row regardless of model device/precision.
        embeddings[idx] = model.encode_image(img).squeeze().float().cpu()

1125it [01:08, 16.50it/s]


In [33]:
# Reduce the CLIP embeddings to two principal components and print the share
# of the total variance those two components explain.
X = embeddings.numpy()
pca = PCA(n_components=2)
pca_vals = pca.fit_transform(X)
print(np.sum(pca.explained_variance_ratio_))

0.16663381


In [42]:
# Print the shape of the projection, then display its first four rows.
print(np.shape(pca_vals))
pca_vals[:4, :]

(1125, 2)


array([[ 1.3595157 , -1.2727263 ],
       [-0.6688118 , -0.9418027 ],
       [ 0.60495144, -0.85944796],
       [ 0.29132822, -0.8286504 ]], dtype=float32)

In [45]:
# Pair every projected point with the index of the image it came from, so the
# image can still be looked up after sorting.
comps_with_idx = [
    dict(img_idx=img_idx, comps=point)
    for img_idx, point in enumerate(pca_vals)
]

In [47]:

# sort x: order all images by their first principal component
x_sorted = sorted(comps_with_idx, key=lambda entry: entry["comps"][0])

In [48]:
# sort y: cut the x-ordered list into `dim` consecutive chunks of `dim`
# entries and order every chunk by its second principal component
dim = 34
grid_sorted = [
    sorted(x_sorted[start:start + dim], key=lambda entry: entry["comps"][1])
    for start in range(0, dim * dim, dim)
]


In [56]:
# clip sorted version
# Tiles are placed according to grid_sorted: every entry of grid_sorted
# becomes one column of the mosaic, filled from top to bottom.

dim = 34
pre_crop = 512
img_size = 256
# Integer offset of the centred img_size x img_size window in the resized image.
crop_amt = (pre_crop - img_size) // 2
# left, top, right, bottom
crop_box = (crop_amt, crop_amt, crop_amt + img_size, crop_amt + img_size)
total_dim = dim*img_size
main_img = np.zeros((total_dim,total_dim,4), dtype=np.uint8)
# total= gives tqdm a length, so it can draw a real progress bar.
for y_idx, row in tqdm(enumerate(grid_sorted), total=len(grid_sorted)):
    for x_idx, comps in enumerate(row):
        img_idx = comps["img_idx"]
        # The context manager closes the file once the resized copy exists.
        with Image.open(all_paths[img_idx]) as img:
            has_alpha = img.mode == "RGBA"
            # Greyscale/palette inputs are converted to RGB so the array
            # assignment below always receives three colour channels.
            resized_img = (img if has_alpha else img.convert("RGB")).resize((pre_crop, pre_crop))
        if not has_alpha:
            # Start from an opaque white RGBA canvas and copy the RGB data in.
            new_img = np.ones((pre_crop,pre_crop,4), dtype=np.uint8)*255
            new_img[:,:,:3] = resized_img
            # from https://github.com/PWhiddy/PokemonRedExperiments/blob/master/MapWalkingVis.ipynb
            # Pure-white pixels are replaced with fully transparent ones.
            alpha_val = np.array([255, 255,  255, 255], dtype=np.uint8)
            alpha_mask = (new_img == alpha_val).all(axis=2).reshape(pre_crop,pre_crop,1)
            resized_img = Image.fromarray( np.where(alpha_mask, np.array([[[0,0,0,0]]]), new_img).astype(np.uint8) )
        cropped_img = resized_img.crop(crop_box)
        main_img[
            x_idx*img_size:(x_idx+1)*img_size,
            y_idx*img_size:(y_idx+1)*img_size] = np.asarray(cropped_img)
im = Image.fromarray(main_img)
im.save("grid_full_sorted_crop.png")

34it [00:38,  1.14s/it]


In [64]:
# full version
# Every image is placed at full img_size in directory-listing order.
dim = 34
img_size = 512
total_dim = dim*img_size
main_img = np.zeros((total_dim,total_dim,4), dtype=np.uint8)
all_paths = list(path.glob("*.jpeg")) + list(path.glob("*.png"))
for tile_idx, p in enumerate(tqdm(all_paths)):
    # Fill one column from top to bottom, then move on to the next column.
    y_idx, x_idx = divmod(tile_idx, dim)
    # The context manager closes the file once the resized copy exists.
    with Image.open(p) as img:
        has_alpha = img.mode == "RGBA"
        # Greyscale/palette inputs are converted to RGB so the array
        # assignment below always receives three colour channels.
        resized_img = (img if has_alpha else img.convert("RGB")).resize((img_size, img_size))
    if not has_alpha:
        # Start from an opaque white RGBA canvas and copy the RGB data in.
        new_img = np.ones((img_size,img_size,4), dtype=np.uint8)*255
        new_img[:,:,:3] = resized_img
        # from https://github.com/PWhiddy/PokemonRedExperiments/blob/master/MapWalkingVis.ipynb
        # Pure-white pixels are replaced with fully transparent ones.
        alpha_val = np.array([255, 255,  255, 255], dtype=np.uint8)
        alpha_mask = (new_img == alpha_val).all(axis=2).reshape(img_size,img_size,1)
        resized_img = np.where(alpha_mask, np.array([[[0,0,0,0]]]), new_img)
    main_img[
        x_idx*img_size:(x_idx+1)*img_size,
        y_idx*img_size:(y_idx+1)*img_size] = np.asarray(resized_img)
im = Image.fromarray(main_img)
im.save("grid_full.png")

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1125/1125 [00:40<00:00, 28.10it/s]


In [50]:
# Display the mosaic canvas side length in pixels (the tiling cells set total_dim = dim * img_size).
total_dim

17408

In [63]:
# cropped version
# Every image is resized to pre_crop, centre-cropped to img_size and placed
# in directory-listing order.
dim = 34
pre_crop = 512
img_size = 256
# Integer offset of the centred img_size x img_size window in the resized image.
crop_amt = (pre_crop - img_size) // 2
# left, top, right, bottom
crop_box = (crop_amt, crop_amt, crop_amt + img_size, crop_amt + img_size)
total_dim = dim*img_size
main_img = np.zeros((total_dim,total_dim,4), dtype=np.uint8)
all_paths = list(path.glob("*.jpeg")) + list(path.glob("*.png"))
for tile_idx, p in enumerate(tqdm(all_paths)):
    # Fill one column from top to bottom, then move on to the next column.
    y_idx, x_idx = divmod(tile_idx, dim)
    # The context manager closes the file once the resized copy exists.
    with Image.open(p) as img:
        has_alpha = img.mode == "RGBA"
        # Greyscale/palette inputs are converted to RGB so the array
        # assignment below always receives three colour channels.
        resized_img = (img if has_alpha else img.convert("RGB")).resize((pre_crop, pre_crop))
    if not has_alpha:
        # Start from an opaque white RGBA canvas and copy the RGB data in.
        new_img = np.ones((pre_crop,pre_crop,4), dtype=np.uint8)*255
        new_img[:,:,:3] = resized_img
        # from https://github.com/PWhiddy/PokemonRedExperiments/blob/master/MapWalkingVis.ipynb
        # Pure-white pixels are replaced with fully transparent ones.
        alpha_val = np.array([255, 255,  255, 255], dtype=np.uint8)
        alpha_mask = (new_img == alpha_val).all(axis=2).reshape(pre_crop,pre_crop,1)
        resized_img = Image.fromarray( np.where(alpha_mask, np.array([[[0,0,0,0]]]), new_img).astype(np.uint8) )
    cropped_img = resized_img.crop(crop_box)
    main_img[
        x_idx*img_size:(x_idx+1)*img_size,
        y_idx*img_size:(y_idx+1)*img_size] = np.asarray(cropped_img)
im = Image.fromarray(main_img)
im.save("grid_cropped.png")

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1125/1125 [00:38<00:00, 28.90it/s]
