# CLIP for selection suggestions

CLIP is a model for comparing natural language with images. In this notebook, I want to check if it can be used as part of our selection UI. This is a qualitative study.

The CLIP model outputs embeddings (e_text, e_image) for a (text, image) pair. The dot product between the normalized e_text and e_image (i.e. their cosine similarity) is the similarity score. For a given text prompt and a graphic, I'll show the best-matching nodes in the graphic's tree. I'll start out with ground truth annotations for graphics.

In [None]:
from vectorrvnn.utils import *
from vectorrvnn.data import *
from vectorrvnn.utils import *
from vectorrvnn.baselines import *
from vectorrvnn.trainutils import *
from vectorrvnn.interfaces import *

# Settings for the grouping model, written as command-line style flags.
# Each tuple is one flag followed by its value(s); the tuples are flattened,
# in order, into the token list handed to the parser through `testing=`.
groupingFlags = [
    ('--backbone', 'resnet18'),
    ('--checkpoints_dir', '../results'),
    ('--dataroot', '../data/All'),
    ('--embedding_size', '64'),
    ('--hidden_size', '128', '128'),
    ('--load_ckpt', 'expt1/training_end.pth'),
    ('--loss', 'cosineSimilarity'),
    ('--modelcls', 'ThreeBranch'),
    ('--name', 'test'),
    ('--sim_criteria', 'negativeCosineSimilarity'),
    ('--device', 'cuda:0'),
    ('--phase', 'test'),
    ('--temperature', '0.1'),
    ('--use_layer_norm', 'true'),
    ('--seed', '0'),
]
opts = Options().parse(
    testing=[token for flag in groupingFlags for token in flag]
)

# Seed before building anything so the run is repeatable.
setSeed(opts)

# buildData yields five values; only the last one is kept.
_, _, _, _, data = buildData(opts)

# The grouper later supplies the tree over each graphic.
grouper = buildModel(opts)


In [None]:
DATA_DIR = '../data/PublicDomainVectors'

# Keep the first 800 svg paths found under DATA_DIR.
svgFiles = [path for path in allfiles(DATA_DIR) if path.endswith('svg')][:800]

# Load each file and filter out graphics with too many paths (40 or more).
publicDomain = [
    graphic
    for graphic in map(SVGData, svgFiles)
    if graphic.nPaths < 40
]

In [None]:
# Device used for CLIP; the tokenized prompt and the image batch in
# matchingNode are moved to this same device.
device = "cpu"
# Load the ViT-B/32 CLIP checkpoint together with its image preprocessing
# function, which matchingNode applies to every rasterized node.
model, preprocess = clip.load("ViT-B/32", device=device)
# NOTE(review): presumably forces float32 weights for CPU inference — confirm.
model = model.float()

def matchingNode (text, tree, topK=3) :
    """
    Show the nodes of `tree` that CLIP ranks as the best match for `text`.

    Each node's 'pathSet' is cut out of `tree.doc` with `subsetSvg`,
    rasterized at 256x256, converted to an 8-bit RGBA image and scored
    against the prompt with the module-level CLIP `model`. The scores are
    soft-maxed across nodes and the highest ranked nodes are drawn side by
    side, best match first.

    Parameters
    ----------
    text : str
        Natural language prompt describing the part to select.
    tree : object
        Hierarchy over a graphic. It must expose `doc` and `nodes`, where
        iterating `nodes` yields node ids and `nodes[id]['pathSet']` is the
        set of paths under that node.
    topK : int, optional
        Maximum number of matches to display (default 3). Fewer are shown
        when the tree has fewer nodes.

    Returns
    -------
    None. The function only prints a caption and shows a figure.

    Raises
    ------
    ValueError
        If `topK` is smaller than 1.
    """
    if topK < 1 :
        raise ValueError("topK must be at least 1")
    nodeIds = list(tree.nodes)
    # Never ask for more panels than there are nodes to show.
    nShown = min(topK, len(nodeIds))
    if nShown == 0 :
        print("No nodes to match for prompt -", text)
        return
    pp_text = clip.tokenize([text]).to(device)
    pathSets = [tree.nodes[n]['pathSet'] for n in nodeIds]
    subdocs = [subsetSvg(tree.doc, ps) for ps in pathSets]
    rasters = [rasterize(sd, 256, 256) for sd in subdocs]
    # assumes rasterize returns float RGBA values in [0, 1] — TODO confirm
    bit8 = [(r * 255).astype(np.uint8) for r in rasters]
    images = [Image.fromarray(b, 'RGBA') for b in bit8]
    pp_image = torch.stack([preprocess(im) for im in images]).to(device)
    with torch.no_grad() :
        logits_per_image, _ = model(pp_image, pp_text)
        # Softmax over dim 0 compares the nodes with each other for the
        # single prompt.
        probs = logits_per_image.softmax(dim=0).cpu().numpy()
        probs = probs.reshape(-1)
    # Indices of the highest scores, best first.
    best = probs.argsort()[::-1][:nShown]
    # squeeze=False keeps a 2D array of axes even when nShown is 1.
    _, axes = plt.subplots(1, nShown, squeeze=False)
    print("Showing top", nShown, "matches for prompt -", text)
    for ax, idx in zip(axes[0], best) :
        ax.imshow(bit8[idx])
    plt.show()

In [None]:
dataId = 203
selectedGraphic = publicDomain[dataId]

# The full graphic first, as a reference for the matches below.
print("Showing whole graphic")
plt.imshow(rasterize(selectedGraphic.doc, 256, 256))
plt.show()

# Then the best matching nodes of the tree the grouper builds for it.
selectedTree = grouper.containmentGuidedTree(selectedGraphic)
matchingNode("The polar bear's head", selectedTree)