In [5]:
import os
import glob
import numpy as np 
import torch.nn as nn 
import torch.nn.functional as F
from tqdm.notebook import tqdm
import torchvision.models as models
from torchvision import datasets, transforms as T
from PIL import Image

In [6]:
feature_generator = models.resnext50_32x4d(pretrained=True)
normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])
transform = T.Compose([T.Resize(256),
                       T.CenterCrop(224),
                       T.ToTensor(),
                       normalize])
# to generate features disable last layer 
feature_generator.fc = nn.Sequential()
_ = feature_generator.eval()

In [7]:
import difflib
import ujson as json
from typing import Iterable, Dict, Any, Tuple, List

src = '/Users/jm/data/wikiart/images'
meta = '/Users/jm/data/wikiart/meta/'
CODED_TAGS = ['style', 'genre', 'material']


def locate_painting(meta_folder: str, full_path: str) -> Tuple[List[str], str]:
    p = full_path.split(os.sep)
    author = p[-2]
    name = p[-1].replace(".jpg", "")
    if 'untitled' in name.lower():
        return None, None
    author = json.load(open(os.path.join(meta_folder, author + ".json"), 'r'))
    titles = [p['title'] for p in author if 'title' in p]
    match = difflib.get_close_matches(name, titles, n=1, cutoff=0.7)
    if not match:
        return None, None
    target_painting = author[titles.index(match[0])]
    # get tags here
    ensemble_tags = []
    tags = target_painting.get("tags", None)
    if tags:
        tags = tags.lower().split(', ')
        ensemble_tags.extend(tags)
    for subtag in CODED_TAGS:
        st = target_painting.get(subtag, None)
        if st:
            st = st.lower().strip()
            if ',' in st:
                st = st.split(', ')
                ensemble_tags.extend(st)
            else:
                ensemble_tags.append(st)
    return ensemble_tags, target_painting['image']


def stream_features(index: str, src_imgs: os.PathLike, src_meta: os.PathLike) -> Iterable[Dict[str, Any]]:
    max_count = 3
    for img_fn in tqdm(glob.glob(os.path.join(src_imgs, "*/*.jpg")),
                       desc='Generating features'):
        try:
            metadata, img_url = locate_painting(
                meta_folder=src_meta, full_path=img_fn)
            if not metadata:
                continue
            feature = generate_feature(img_fn)

            yield {
                "_index": index,
                "_source": {
                    "img_vector": feature.tolist()[0][:100],
                    "tags": metadata,
                    "url": img_url
                }
            }
            max_count -= 1
            if not max_count:
                break
        except KeyError:
            print(f"Failed to process: {os.path.basename(img_fn)}")


def generate_feature(img_path: os.PathLike):
    with Image.open(img_path) as img:
        img = transform(img)
        feature = feature_generator(
            img.unsqueeze(0))
        feature = F.normalize(feature).detach().cpu().numpy()
        return feature


In [14]:
from elasticsearch import Elasticsearch, helpers


def create_es_indx(indx: str, es: Elasticsearch, dims: int = 2048):
    dense_paintings = {
        "mappings": {
            "properties": {
                "img_vector": {
                    "type": "dense_vector",
                    "dims": dims
                },
                "tags": {
                    "type": "keyword"
                },
                "url": {
                    "type": "text"
                }
            }
        }
    }
    es.indices.delete(indx, ignore=[400, 404])
    es.indices.create(index=indx, body=dense_paintings)


es = Elasticsearch()
index_name = "paintings"
es.info(pretty=True)
create_es_indx(indx=index_name, es=es, dims=2048)
for ok, response in helpers.streaming_bulk(es, stream_features(index_name, src_imgs=src, src_meta=meta)):
    if not ok:
        print(response)


Generating features:   0%|          | 0/37868 [00:00<?, ?it/s]

In [15]:

def vector_query(query_vector, field_name: str = "img_vector", method="cosine"):
    #  we add +1 since Elasticsearch CANNOT rank negative scores
    if method == "cosine":
        score_fn = f"cosineSimilarity(params.query_vector, '{field_name}') + 1.0"
    else:
        score_fn = f"cosineSimilarity(params.query_vector, '{field_name}') + 1.0"
    print(score_fn)
    return {
        "size": 1,
        "query": {
            "script_score": {
                "query": {
                    "match_all": {}
                },
                "script": {
                    "source": score_fn,
                    "params": {
                        "query_vector": query_vector

                    }
                }
            }
        }
    }


def search_elasticsearch(index: str, query: List[float], method: str = "cosine") -> List[Dict[str, Any]]:
    q_format = vector_query(query_vector=query, method=method)
    results = es.search(index=index, body=q_format)
    print(results)
    hits = results['hits']['hits']
    return hits[1:]


In [18]:
from elasticsearch import logger as es_logger

LOGLEVEL = 50
es_logger.setLevel(LOGLEVEL)

In [16]:
# es.search(index=index_name, body={
#     "size": 1,
#     "query": {
#         "match": {
#             "tags": {
#                 "query": "realism"
#             }
#         }
#     }
# })


In [19]:
feat = generate_feature(glob.glob(os.path.join(src, "*/*.jpg"))[-1]).tolist()
search_elasticsearch(index=index_name, query=feat)
# locate_painting(meta, glob.glob(os.path.join(src, "*/*.jpg"))[-1])

cosineSimilarity(params.query_vector, 'img_vector') + 1.0


RequestError: RequestError(400, 'search_phase_execution_exception', 'runtime error')

In [12]:
feat = generate_feature('/Users/jm/data/wikiart/images/le-corbusier/Carton_pour_tapisserie_Marie_Cuttoli.jpg').tolist()

In [14]:
feat[0]

[0.5082176327705383,
 0.3644159734249115,
 0.2834455072879791,
 0.22682209312915802,
 0.14105096459388733,
 0.33090540766716003,
 0.14845611155033112,
 1.4157172441482544,
 0.4562191069126129,
 0.0037068966776132584,
 0.03438771888613701,
 0.04211169481277466,
 0.1622363179922104,
 0.21804344654083252,
 0.679964005947113,
 0.5121763944625854,
 0.32109010219573975,
 0.15980727970600128,
 0.033732034265995026,
 0.02695423737168312,
 0.7968424558639526,
 0.1073613092303276,
 0.31156569719314575,
 0.3088437020778656,
 0.026965521275997162,
 1.1004310846328735,
 0.606105625629425,
 1.131072998046875,
 0.4298565983772278,
 0.438693642616272,
 0.6242384910583496,
 0.4721740484237671,
 0.6600103378295898,
 0.20566165447235107,
 0.36688998341560364,
 0.41983261704444885,
 1.509124755859375,
 0.8685483336448669,
 0.8341764211654663,
 0.4517126679420471,
 0.28308236598968506,
 1.6400933265686035,
 0.09592751413583755,
 0.3007603883743286,
 0.35418686270713806,
 0.6821137070655823,
 0.290884822607