# Script for making the `coco_humans` sample dataset

First, load a pretrained CLIP model (ViT-L/14), running on the GPU if one is available and on the CPU otherwise.

In [None]:
import torch
import clip
from PIL import Image
import requests
from baukit import save_image_set
from baukit import ImageFolderSet, show, move_to
import torchvision

# Inference only: switch autograd off globally for everything that follows.
torch.set_grad_enabled(False)

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# clip.load gives back the model plus the image preprocessing transform that
# is later passed to the dataset.
model, preprocess = clip.load("ViT-L/14", device=device)

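Next, load the COCO 2017 dataset, using the training split. We build two views of the same image folder: one with CLIP's full preprocessing, used for scoring, and one with only the first two preprocessing transforms, used for previewing and saving images.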

In [None]:
input_image_dir = '/share/data/datasets/coco/images/train2017'

# Scoring view: full CLIP preprocessing. identification=True presumably makes
# each item carry its dataset index (the scoring loop unpacks indexes from
# each batch) -- confirm against baukit's ImageFolderSet.
ds = ImageFolderSet(input_image_dir, identification=True, transform=preprocess)

# Viewing/saving view: only the first two steps of CLIP's preprocess pipeline.
# NOTE(review): presumably resize + center-crop, so the output is still a
# displayable image -- confirm against the transform returned by clip.load.
ds_cropped = ImageFolderSet(
    input_image_dir,
    transform=torchvision.transforms.Compose(preprocess.transforms[:2]),
)

Precompute CLIP text features for two positive prompts and two negative prompts.

In [None]:
# Tokenize both prompt groups on the model's device, then embed each group
# with CLIP's text encoder (one feature row per prompt).
positive_text = clip.tokenize(["a picture of a woman or a man", "a professional at work"]).to(device)
negative_text = clip.tokenize(["a picture of some junk", "a cluttered mess"]).to(device)
positive_features, negative_features = map(model.encode_text, (positive_text, negative_text))

Demo of the CLIP tokenizer: decode the first positive prompt's token ids back into text.

In [None]:
# CLIP's byte-level BPE tokenizer, used here only to turn token ids back into text.
tokenizer = clip.simple_tokenizer.SimpleTokenizer()
# clip.tokenize zero-pads every prompt to the full context length, and id 0
# decodes to a literal "!" -- decoding the whole row would print a long run of
# "!!!!" after the prompt. Cut the row just after the end-of-text token, which
# has the largest id in the vocabulary (this is also how CLIP's encode_text
# finds it, via argmax).
tokenizer.decode(positive_text[0, :int(positive_text[0].argmax()) + 1].tolist())

Collect the 2000 top-scoring images. Each image's score is the sum of its CLIP cosine similarities to the positive prompts, minus the sum of its similarities to the negative prompts.

In [None]:
from torch.utils.data import DataLoader
from torch.nn.functional import cosine_similarity
from baukit import TopK, pbar

# Running top-2000 tracker: records each score alongside the dataset index of
# the image it belongs to.
stat = TopK(k=2000)
for batch in pbar(DataLoader(ds, batch_size=100, pin_memory=True,
                             num_workers=30)):
    [[images, indexes]] = move_to(device, batch)
    image_features = model.encode_image(images)
    # Per-image score components: summed cosine similarity to each prompt
    # group, computed one prompt at a time and accumulated.
    sim, neg = (
        sum(cosine_similarity(image_features, f[None], dim=1) for f in text_features)
        for text_features in (positive_features, negative_features)
    )
    stat.add(sim - neg, indexes)

    

    
    

In [None]:
# Preview the first 12 entries of the top-k index list using the crop-only
# dataset view, each image capped at 120px wide.
show(show.WRAP, [
    [show.style(maxWidth=120), ds_cropped[idx]]
    for idx in stat.topk()[1][:12]
])

In [None]:
import os, random
from baukit import WorkerPool, save_image_set

os.makedirs('coco_humans', exist_ok=True)

# Fixed seed for the shuffle below, so re-running the notebook writes the same
# image to the same filename (the original used the unseeded global RNG).
SHUFFLE_SEED = 1

# Evaluate stat.topk() once and reuse it. list() keeps the same per-element
# values the original loop iterated over, and gives random.sample a real
# sequence.
top_indices = list(stat.topk()[1])
# Shuffle the selection so saved filenames are not ordered by CLIP score.
# Sampling the full length is a shuffled copy; shuffling indices rather than
# loaded images yields the same kind of permutation.
shuffled_indices = random.Random(SHUFFLE_SEED).sample(top_indices, len(top_indices))
# ds_cropped items are indexed with [0] to get the image, as before.
save_image_set([ds_cropped[i][0] for i in shuffled_indices],
               'coco_humans/image_{0}.jpg')

In [None]:
# Development helper: re-execute baukit's workerpool and imgsave modules so that
# edits to their source take effect without restarting the kernel.
# NOTE(review): importlib.reload() does not rebind names imported earlier with
# `from baukit import ...` (e.g. save_image_set / WorkerPool, if they are defined
# in these modules); those names keep pointing at the pre-reload objects until
# their import lines are re-run.
from importlib import reload
import baukit.workerpool, baukit.imgsave
reload(baukit.workerpool)
reload(baukit.imgsave)

In [None]:
# Display the repr of the baukit.workerpool module, which includes the file
# path it was loaded from.
# NOTE(review): looks like a leftover interactive check; it does not affect
# the saved dataset.
baukit.workerpool

In [None]:
# Enable IPython's autoreload extension; mode 2 reloads all modules before
# executing each subsequent cell, so edits to baukit are picked up live.
# NOTE(review): this only affects cells run after this one. For a clean
# Restart & Run All it belongs in the first code cell, before the imports.
%load_ext autoreload
%autoreload 2
