In [1]:
import pandas as pd
import numpy as np

In [19]:
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import MiniBatchKMeans

In [20]:
def reduce_img_colours(img_path, compression = 1, num_colours = 5):
    
    img = Image.open(img_path)
    
    new_size = (np.array(img.size) * compression).astype(int)
    
    img_array = np.array(img.resize(new_size))
    
    pixels = img_array.reshape((-1, 3))
    
    kmeans = MiniBatchKMeans(num_colours).fit(pixels)
    
    new_pixels = kmeans.cluster_centers_[kmeans.labels_]
    
    new_img = new_pixels.reshape((new_size[1], new_size[0], 3))
    
    print(f'# Old colours: {len(np.unique(pixels, axis = 0))}')
    print(f'# New colours: {num_colours}')
    
    return new_img

In [26]:
# https://github.com/google-research-datasets/conceptual-captions/blob/master/LICENSE
# https://ai.google.com/research/ConceptualCaptions/download
data = pd.read_csv('../raw_data/Validation_GCC-1.1.0-Validation.tsv', sep = '\t', header = None)

In [27]:
data.columns = ['caption', 'url']

In [174]:
import requests
from PIL import Image
import PIL
import io
import torch
import clip

def get_pil_image(url):
    
    try:
        response = requests.get(url)
        if response.status_code == 200:
                return Image.open(io.BytesIO(response.content))
    
    except requests.exceptions.RequestException:
        return None
    except PIL.UnidentifiedImageError:
        return None
#     except RuntimeError:
#         return None


def square_image(pil_img):
    
    short_side = min(pil_img.size)
    
    img_width, img_height = pil_img.size
    return pil_img.crop(((img_width - short_side) // 2,
                         (img_height - short_side) // 2,
                         (img_width + short_side) // 2,
                         (img_height + short_side) // 2))

def simplify_image(pil_img, num_colors = 8):

    img = pil_img.quantize(colors=num_colors, method=2)
    
    sq_img = square_image(img).resize((256, 256), Image.LANCZOS)
    
    return sq_img
    
def save_img(pil_img, caption, destination_folder):
    
    file_name = f"{destination_folder}/{caption.replace('.', '')}.png"
    
    #print('saving to ', file_name)
    
    pil_img.save(file_name)


model, preprocess = clip.load("ViT-B/32", device=device)


def clip_encode_image(pil_img):
    
    device = "cuda" if torch.cuda.is_available() else "cpu"
    
    image = preprocess(pil_img).unsqueeze(0).to(device)

    with torch.no_grad():
        image_features = model.encode_image(image)

    return image_features



In [175]:
from tqdm import tqdm

#features_list = []
#image_path_list = []

for i, row in tqdm(data[1100:2000].iterrows()):
    
    url, caption = row['url'], row['caption']
    
    #print(url)

    img = get_pil_image(url)
    
    if img:

        features_list.append(clip_encode_image(img)[0])

        simple_img = simplify_image(img, num_colors = 256)

        save_img(simple_img, caption, destination_folder = '../raw_data/images')
        
        image_path_list.append(f"{caption.replace('.', '')}.png")
    if len(features_list) % 10 == 0:
        print(len(features_list))
        
        

#image = get_simplified_image(url, caption, destination_folder = '../raw_data/images', num_colors = 256)

7it [00:09,  1.36s/it]

160


19it [01:01,  6.98s/it]

170


30it [01:18,  1.35s/it]

180


41it [01:39,  2.02s/it]

190


53it [07:50,  6.38s/it] 

200


65it [08:12,  2.62s/it]

210


77it [08:37,  1.60s/it]

220


91it [10:09, 12.15s/it]

230


102it [10:52,  3.73s/it]

240


112it [11:02,  1.07s/it]

250


123it [12:30, 12.06s/it]

260


134it [12:52,  1.83s/it]

270


146it [13:09,  1.69s/it]

280


158it [13:38,  2.08s/it]

290


172it [15:15,  6.90s/it]

300


183it [15:35,  1.58s/it]

310


195it [15:57,  1.85s/it]

320


209it [16:20,  1.37s/it]

330
330


210it [16:24,  2.16s/it]

330


223it [16:43,  1.72s/it]

340


236it [17:02,  1.40s/it]

350


246it [17:18,  1.18s/it]

360


257it [17:34,  1.80s/it]

370


270it [17:55,  1.69s/it]

380


271it [17:58,  2.35s/it]

380


272it [18:00,  2.12s/it]

380


284it [19:27,  5.08s/it]

390


294it [19:46,  1.35s/it]

400


306it [20:09,  2.43s/it]

410


322it [20:38,  1.71s/it]

420


334it [21:00,  2.23s/it]

430


335it [21:01,  1.82s/it]

430


336it [21:02,  1.53s/it]

430


350it [21:18,  1.52s/it]

440


363it [21:30,  1.23it/s]

450


375it [22:29,  2.30s/it]

460


386it [22:45,  1.97s/it]

470


396it [23:06,  1.91s/it]

480


411it [23:22,  1.08s/it]

490


423it [24:24,  4.33s/it]

500


436it [24:43,  1.35s/it]

510


437it [24:44,  1.12s/it]

510


449it [28:57, 18.49s/it]

520


450it [29:01, 14.12s/it]

520


464it [29:24,  1.43s/it]

530


476it [29:43,  1.52s/it]

540


489it [31:27, 17.73s/it]

550


501it [31:46,  2.05s/it]

560


513it [32:10,  2.19s/it]

570


523it [32:32,  1.91s/it]

580


523it [52:15,  5.99s/it]


KeyboardInterrupt: 

In [176]:
image_embeddings = np.stack(features_list)

In [177]:
image_embeddings.shape

X_image_embeddings = torch.from_numpy(image_embeddings)
X_image_embeddings

tensor([[ 0.0410, -0.1044, -0.3145,  ...,  0.4066,  0.1564,  0.1414],
        [-0.0344,  0.1190,  0.2218,  ...,  0.6305, -0.3967, -0.0071],
        [-0.2324,  0.0697,  0.0915,  ...,  0.3114,  0.0565, -0.3759],
        ...,
        [-0.3113, -0.0241, -0.1355,  ...,  0.7593, -0.2546, -0.2554],
        [ 0.0336, -0.1806, -0.0556,  ...,  0.2804,  0.2809,  0.1955],
        [-0.6176, -0.1579, -0.5416,  ...,  0.7598,  0.0930,  0.0875]])

In [163]:
# mnist = pymde.datasets.MNIST()
# mnist.data

In [178]:
import pymde

mde = pymde.preserve_neighbors(X_image_embeddings, embedding_dim=2, verbose=True)
embedding = mde.embed(verbose=True)

Jul 02 01:25:39 AM: Computing 5-nearest neighbors, with max_distance=None
Jul 02 01:25:39 AM: Exact nearest neighbors by brute force 
Jul 02 01:25:39 AM: Computing quadratic initialization.
Jul 02 01:25:39 AM: Fitting a centered embedding into R^2, for a graph with 580 items and 4213 edges.
Jul 02 01:25:39 AM: `embed` method parameters: eps=1.0e-05, max_iter=300, memory_size=10
Jul 02 01:25:39 AM: iteration 000 | distortion 0.529098 | residual norm 0.0905655 | step length 0.51596 | percent change 0.137199
Jul 02 01:25:39 AM: iteration 030 | distortion 0.280032 | residual norm 0.00242741 | step length 1 | percent change 0.989663
Jul 02 01:25:39 AM: iteration 060 | distortion 0.277166 | residual norm 0.00110947 | step length 1 | percent change 0.894307
Jul 02 01:25:40 AM: iteration 090 | distortion 0.275297 | residual norm 0.000927837 | step length 1 | percent change 0.513152
Jul 02 01:25:40 AM: iteration 120 | distortion 0.274814 | residual norm 0.000516245 | step length 1 | percent cha

In [179]:
image_path_list

['author : a life in photography -- in pictures.png',
 'the player staring intently at a computer screen .png',
 'the - bedroom stone cottage can sleep people.png',
 'party in the park under cherry blossoms.png',
 'a man holds what is believed to be some of the debris that caused damage to vehicles monday afternoon after airliner returned to airport following problems after take off .png',
 "where 's the best place to show off your nails ? right in front of the castle , of course !.png",
 'that combines elements of a simple vegetable and dish.png',
 'transformers : till all are issue # 4b.png',
 'illustration of a little girl taking a bath.png',
 'tv police procedural is filming on the street this week .png',
 'the new terminal on island with quiet sea and setting sun.png',
 'what makeup to wear to a job interview.png',
 'the dentist drill the tooth with a turbine.png',
 'pair of new bright orange modern sneakers isolated on a white background.png',
 'ask industry to do in our family r

In [191]:

Y_coords = embedding.numpy()[:, 1]

In [219]:
def embedding_to_coords(embedding, axis):
    
    if axis == 'X':
        col = 0
    elif axis == 'Y':
        col = 1
    else:
        print("coord must be X or Y")
    
    coords = embedding.numpy()[:, col]

    from_one = coords - min(coords)

    return [int(coord) for coord in (from_one / max(from_one) * 1000)]

In [221]:
X_coords = embedding_to_coords(embedding, 'X')
type(X_coords[0])

int

In [222]:
import json

X_coords = embedding_to_coords(embedding, 'X')
Y_coords = embedding_to_coords(embedding, 'Y')

map_json = {
    "X_coords": X_coords,
    "Y_coords": Y_coords,
    "image_paths": image_path_list
}
map_json['image_paths'][0]
with open('../raw_data/latent_space_map.json', 'w') as f:
    json.dump(map_json, f)
