In [6]:
import os
import glob
import numpy as np 
import torch.nn as nn 
import torch.nn.functional as F
from tqdm.notebook import tqdm
import torchvision.models as models
from torchvision import datasets, transforms as T
from PIL import Image

In [7]:
# Pretrained ResNeXt-50 (32x4d) from torchvision, used in this notebook as an image
# feature extractor rather than as a classifier.
feature_generator = models.resnext50_32x4d(pretrained=True)
# Per-channel mean/std normalisation, applied after ToTensor() in the pipeline below.
# NOTE(review): these look like the usual ImageNet statistics — confirm they match the weights.
normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])
# Preprocessing pipeline: resize, centre-crop to 224, convert to a tensor, then normalise.
transform = T.Compose([T.Resize(256),
                       T.CenterCrop(224),
                       T.ToTensor(),
                       normalize])
# To generate features, disable the last layer: an empty nn.Sequential returns its
# input unchanged, so the model outputs the activations that previously fed `fc`.
feature_generator.fc = nn.Sequential()
# Switch to evaluation mode; binding the result to `_` keeps the cell from echoing the model repr.
_ = feature_generator.eval()

In [3]:
import difflib
import ujson as json
from typing import Iterable, Dict, Any, Tuple, List

# Root folder of the painting images; the notebook globs "*/*.jpg" beneath it, and
# locate_painting treats the parent folder name as the author.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
src = '/Users/jm/data/wikiart/images'
# Folder of per-author metadata; locate_painting opens "<author>.json" inside it.
meta = '/Users/jm/data/wikiart/meta/'
# Metadata fields (besides "tags") whose values locate_painting adds to the tag list.
CODED_TAGS = ['style', 'genre', 'material']


def _split_tags(raw: str) -> List[str]:
    """Lower-case a comma-separated tag string and return its non-empty, stripped parts."""
    return [tag for tag in (part.strip() for part in raw.lower().split(',')) if tag]


def locate_painting(meta_folder: str, full_path: str) -> Tuple[List[str], str]:
    """Look up the metadata record of an image file and collect its tags.

    The author is taken from the image's parent folder name and the painting
    name from the file name (with ".jpg" removed). The author's records are
    read from ``<meta_folder>/<author>.json`` and the record whose ``title``
    is the closest fuzzy match to the painting name is selected.

    Args:
        meta_folder: Folder holding one ``<author>.json`` file per author.
        full_path: Path of the image, laid out as ``.../<author>/<name>.jpg``.

    Returns:
        ``(tags, image_url)`` for the matched record, where ``tags`` merges the
        record's ``tags`` field with every field listed in ``CODED_TAGS``
        (all lower-cased). Returns ``(None, None)`` when the file name contains
        "untitled" or when no title is similar enough (cutoff 0.7).

    Raises:
        OSError: If the author's metadata file cannot be opened.
        KeyError: If the matched record has no ``image`` field.
    """
    author_name = os.path.basename(os.path.dirname(full_path))
    name = os.path.basename(full_path).replace(".jpg", "")
    # "untitled" works cannot be told apart by title, so they are skipped.
    if 'untitled' in name.lower():
        return None, None

    # Context manager so the metadata file handle is closed deterministically.
    metadata_path = os.path.join(meta_folder, author_name + ".json")
    with open(metadata_path, 'r', encoding='utf-8') as fh:
        paintings = json.load(fh)

    # Keep titles and records in two parallel lists built from the SAME filtered
    # sequence; indexing the unfiltered list with a position taken from the
    # filtered titles would select the wrong record whenever a record has no title.
    titled = [entry for entry in paintings
              if isinstance(entry, dict) and isinstance(entry.get('title'), str)]
    titles = [entry['title'] for entry in titled]
    match = difflib.get_close_matches(name, titles, n=1, cutoff=0.7)
    if not match:
        return None, None
    target_painting = titled[titles.index(match[0])]

    # Free-form tags first, then the coded fields, all normalised the same way.
    ensemble_tags = []
    for field in ['tags'] + list(CODED_TAGS):
        value = target_painting.get(field, None)
        if value:
            ensemble_tags.extend(_split_tags(value))
    return ensemble_tags, target_painting['image']


def stream_features(index: str, src_imgs: os.PathLike, src_meta: os.PathLike,
                    max_count: int = 5000) -> Iterable[Dict[str, Any]]:
    """Yield one bulk-index action per painting that has both metadata and a feature.

    Every ``<src_imgs>/*/*.jpg`` file is matched against the metadata with
    ``locate_painting`` and embedded with ``generate_feature``. Images with no
    (or empty) metadata are skipped, and images that raise during processing
    are reported and skipped.

    Args:
        index: Name of the target index, stored under ``_index`` of each action.
        src_imgs: Root folder of the images.
        src_meta: Folder of the per-author metadata files.
        max_count: Maximum number of actions to yield; ``None`` means no limit.

    Yields:
        Dicts with ``_index`` and a ``_source`` holding ``img_vector`` (flat
        list of floats), ``tags`` and ``url``.
    """
    if max_count is not None and max_count <= 0:
        return
    emitted = 0
    for img_fn in tqdm(glob.glob(os.path.join(src_imgs, "*/*.jpg")),
                       desc='Generating features'):
        try:
            metadata, img_url = locate_painting(
                meta_folder=src_meta, full_path=img_fn)
            if not metadata:
                continue
            # generate_feature returns one row per image; keep the single row.
            vector = generate_feature(img_fn).tolist()[0]
        except Exception as exc:
            # `Exception` (not a bare except) so KeyboardInterrupt and
            # GeneratorExit still propagate; report the cause and move on.
            print(f"Failed to process: {os.path.basename(img_fn)} ({exc!r})")
            continue

        # The yield lives outside the try block: closing the generator raises
        # GeneratorExit at this point and it must not be swallowed.
        yield {
            "_index": index,
            "_source": {
                "img_vector": vector,
                "tags": metadata,
                "url": img_url
            }
        }
        emitted += 1
        if max_count is not None and emitted >= max_count:
            break


def generate_feature(img_path: os.PathLike):
    """Embed one image with the module-level ``feature_generator``.

    The image is converted to RGB, run through the module-level ``transform``
    and the model, and the resulting vector is L2-normalised.

    Args:
        img_path: Path of the image file to embed.

    Returns:
        A NumPy array with one row per image (a single row here), each row
        having unit L2 norm.
    """
    # Function-scope import: the notebook only binds `nn` and `F`, not `torch` itself.
    import torch

    with Image.open(img_path) as img:
        # Grayscale / RGBA / CMYK files would otherwise reach the 3-channel
        # normalisation and the model with the wrong number of channels.
        batch = transform(img.convert("RGB")).unsqueeze(0)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        feature = feature_generator(batch)

    # normalize for easier ES search (less compute on script side)
    return F.normalize(feature).detach().cpu().numpy()


In [8]:
from elasticsearch import Elasticsearch, helpers


def create_es_indx(indx: str, es: Elasticsearch, dims: int = 2048):
    """Drop and recreate the index that stores the painting vectors.

    WARNING: any existing index called ``indx`` is deleted first.

    Args:
        indx: Name of the index to (re)create.
        es: Elasticsearch client used for both calls.
        dims: Dimensionality declared for the ``img_vector`` dense-vector field.
    """
    dense_paintings = {
        "mappings": {
            "properties": {
                "img_vector": {
                    "type": "dense_vector",
                    "dims": dims
                },
                "tags": {
                    "type": "keyword"
                },
                "url": {
                    "type": "text"
                }
            }
        }
    }
    # Keyword `index=` matches the create call below and is the form newer
    # client releases require. 400/404 responses are ignored so that a
    # missing index does not abort the rebuild.
    es.indices.delete(index=indx, ignore=[400, 404])
    es.indices.create(index=indx, body=dense_paintings)


# Client built with the constructor's default arguments.
# NOTE(review): presumably this targets a local node — confirm host/port for your setup.
es = Elasticsearch()
# Index name passed to the indexing code below and to the search cells.
index_name = "paintings"
# Connectivity check; as the last expression, its result is the cell output.
es.info(pretty=True)
# Indexing step, currently disabled. Uncommenting it drops and recreates the index
# (see create_es_indx) and bulk-loads the actions produced by stream_features:
# create_es_indx(indx=index_name, es=es, dims=2048)
# for ok, response in helpers.streaming_bulk(es, stream_features(index_name, src_imgs=src, src_meta=meta)):
#     if not ok:
#         print(response)


{'name': 'd6be162e9295',
 'cluster_name': 'docker-cluster',
 'cluster_uuid': 'lcLs5bEPSq6zQ6hfAas94A',
 'version': {'number': '7.13.4',
  'build_flavor': 'default',
  'build_type': 'docker',
  'build_hash': 'c5f60e894ca0c61cdbae4f5a686d9f08bcefc942',
  'build_date': '2021-07-14T18:33:36.673943207Z',
  'build_snapshot': False,
  'lucene_version': '8.8.2',
  'minimum_wire_compatibility_version': '6.8.0',
  'minimum_index_compatibility_version': '6.0.0-beta1'},
 'tagline': 'You Know, for Search'}

In [20]:
import ipywidgets as ipw


def vector_query(query_vector, field_name: str = "img_vector", n: int = 10):
    """Build a ``script_score`` search body ranking documents by cosine similarity.

    Args:
        query_vector: Vector handed to the script as ``params.query_vector``.
        field_name: Name of the vector field the script compares against.
        n: Value of the ``size`` entry of the search body.

    Returns:
        The search body as a plain dict.
    """
    # The script adds 1.0 to the similarity because Elasticsearch cannot rank
    # negative scores.
    similarity_script = {
        "source": f"cosineSimilarity(params.query_vector, '{field_name}') + 1.0",
        "params": {"query_vector": query_vector},
    }
    # Score every document (match_all) with the script above.
    scored_match_all = {
        "script_score": {
            "query": {"match_all": {}},
            "script": similarity_script,
        }
    }
    return {"size": n, "query": scored_match_all}


def search_elasticsearch(index: str, query: List[float], n: int = 10,
                         skip_first: bool = True) -> List[Dict[str, Any]]:
    """Return the hits most similar to ``query`` using the module-level ``es`` client.

    Args:
        index: Name of the index to search.
        query: Query vector, passed to ``vector_query``.
        n: Number of hits to return.
        skip_first: When true (the default, matching the previous behaviour of
            always dropping the top hit), the best-scoring hit is discarded.
            NOTE(review): this presumably assumes the query image is itself
            indexed and therefore ranks first — pass ``False`` when it is not.

    Returns:
        Up to ``n`` raw hit dicts, in the order Elasticsearch returned them.
    """
    # Ask for one extra hit when the top one is going to be discarded, so the
    # caller still receives `n` results instead of `n - 1`.
    fetch_size = n + 1 if skip_first else n
    q_format = vector_query(query_vector=query, n=fetch_size)
    results = es.search(index=index, body=q_format)
    hits = results['hits']['hits']
    return hits[1:] if skip_first else hits


def display_query_result(query_result: List[Any]):
    """Render search hits as a vertical stack of captioned image widgets.

    Args:
        query_result: Hits whose ``_source`` holds a ``tags`` list and a ``url``.

    Returns:
        An ``ipw.VBox`` with one ``VBox`` (image above its tag caption) per hit.
    """
    def _card(hit):
        # One widget pair per hit: the remote image and its comma-joined tags.
        tags, url = hit["_source"]["tags"], hit["_source"]["url"]
        picture = ipw.Image.from_url(url)
        picture.height, picture.width = 320, 160
        caption = ipw.Label(", ".join(tags))
        return ipw.VBox([picture, caption])

    return ipw.VBox([_card(hit) for hit in query_result])


In [25]:
# Use the last path returned by glob as the query image.
# NOTE(review): glob does not guarantee a sorted order, so "last" may differ between
# runs/machines, and [-1] raises IndexError when no image matches.
query_img = glob.glob(os.path.join(src, "*/*.jpg"))[-1]
# generate_feature returns one row per image; take the single row as a flat list.
feat = generate_feature(query_img).tolist()[0]
# search_elasticsearch discards the top hit before returning the matches.
qr = search_elasticsearch(index=index_name, query=feat)
# Query image widget, sized like the result images in display_query_result.
qimg = ipw.Image.from_file(query_img)
qimg.height = 320
qimg.width = 160
vbox = display_query_result(qr)

# Last expression of the cell, so it is what gets displayed:
# the query image, a caption, then the stacked results.
ipw.VBox([qimg, ipw.Label("QueryImage^"), vbox])




VBox(children=(Image(value=b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x01\x00`\x00`\x00\x00\xff\xdb\x00C\x00\x…