In [12]:
import numpy as np
from neo4j.v1 import GraphDatabase
from sklearn.cluster import KMeans
import os
import sys
import caffe
from sklearn.neighbors import NearestNeighbors
from skimage import io
import cv2
import matplotlib.pyplot as plt
#Widgets
import ipywidgets as widgets
from IPython.display import display
%matplotlib inline

In [13]:
# Options for the category dropdown, as (label shown to the user, value passed
# to the callback) pairs. The value is what ends up as the `category` parameter
# of the database queries, so the second element must not be edited casually.
# Labels are not unique: several labels map to more than one value.
# The exact duplicate ('каиш', 'belt') pair that used to appear twice is listed once.
category = [
('jeans','jeans'),
('панталони', 'hose'),
('каиш', 'gürtel'),
('јакна', 'jacke'),
('чино панталони', 'chino'),
('кожа', 'leder'),
('џемпер', 'sweatshirt'),
('сникерс патики', 'sneakers'),
('дуксер', 'kapuze'),
('пуловер', 'pullover'),
('ремен', 'armband'),
('деним', 'denim'),
('кошула', 'hemd'),
('маица', 'shirt'),
('наочари за сонце', 'sonnenbrille'),
('зимски дуксер', 'sweatjacke'),
('зимска јакна', 'winterjacke'),
('каиш', 'belt'),
('капа', 'mütze'),
('патики за трчање', 'laufschuhe'),
('бајкер', 'biker'),
('кардиган', 'strickjake'),
('бермуди, гаќи', 'shorts'),
('beanie', 'beanie'),
('капа', 'cap'),
('parka', 'parka'),
('плетени џемпери', 'strickpullover'),
('капут', 'mantel'),
('дуксер', 'hoody'),
('кратка маица', 'tee'),
('кратка маица', 't-shirt'),
('долги ракави', 'longsleeve'),
('quilted јакна', 'steppjacke'),
('карго панталони', 'cargo'),
('skinny панталони', 'skinny'),
('сникерс', 'turnschuhe'),
('волна', 'fleece'),
('маица со долги ракави', 'langarmshirt'),
('елек - vest', 'weste'),
('накит', 'schmuck'),
('За џогирање', 'jogginghose'),
# NOTE(review): 'für' is kept verbatim because it is the stored value; confirm it is the intended category name.
('крзно', 'für'),
('панталони', 'hoses'),
('кондури', 'schuhe'),
('чизми', 'boots'),
('кардиган', 'cardigan'),
('поло маици', 'poloshirt'),
('часовник', 'uhren'),
('спортска јакна', 'sportjacke'),
('долго палто', 'trenchcoat'),
('долга маица', 'longshirt') ]

In [14]:
def find_descriptor(net, href):
    """Compute the embedding of the image at ``href`` with the given network.

    Parameters
    ----------
    net
        Caffe network exposing a ``data`` input blob and an ``embedding`` output.
    href
        Location of the image; handed to ``io.imread``.

    Returns
    -------
    list or None
        The values of the first row of the ``embedding`` output, or ``None``
        when the image could not be read (a message is printed in that case).
    """
    # Read (download) the image. Only ordinary errors are treated as
    # "link unavailable"; a bare except would also swallow KeyboardInterrupt.
    try:
        img = io.imread(href)
    except Exception:
        print('Link unavailable: ', href)
        return None

    # Preprocess. No RGB->BGR conversion is applied here: the original author
    # noted that the linked images appear to be BGR already.
    img = cv2.resize(img, (224, 224))

    # Transformer: move channels first (input has shape 1,3,224,224) and subtract the mean.
    transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
    transformer.set_transpose('data', (2, 0, 1))
    transformer.set_mean('data', np.array([104, 117, 123]))
    img = transformer.preprocess('data', img)

    # Feed to network
    net.blobs['data'].data[...] = img
    out = net.forward()

    # Copy the embedding into a plain list so the descriptor can be stored as an array.
    return list(out['embedding'][0])

def find_nearest_centroid(session, queried_descriptor, category, gender, nearest=3):
    """Return the descriptor of the ``nearest``-th closest centroid.

    Parameters
    ----------
    session
        Open database session; only ``session.run`` is used.
    queried_descriptor
        Descriptor to compare against the stored centroid descriptors.
    category, gender
        Values the centroid's ``category`` and ``gender`` properties must equal.
    nearest
        0-based rank of the centroid to return (0 is the closest). At most five
        centroids are ranked; a larger rank is capped at the furthest ranked one.

    Returns
    -------
    The matching centroid descriptor, or ``None`` when the query yields no centroid.
    """
    # Clauses are separated by explicit spaces so the concatenated query does
    # not fuse tokens such as ')WHERE' or '}RETURN'.
    query = 'MATCH (a:Centroid) ' \
            'WHERE a.category = {category} AND a.gender = {gender} ' \
            'RETURN a.descriptor as descriptor'
    result = session.run(query, category=category, gender=gender)

    if result.peek() is None:
        return None

    descriptors = [record['descriptor'] for record in result]

    # NearestNeighbors rejects n_neighbors larger than the number of fitted
    # samples, so never ask for more neighbours than there are centroids, and
    # keep the requested rank inside what was actually ranked.
    n_neighbors = min(5, len(descriptors))
    rank = min(nearest, n_neighbors - 1)

    NN = NearestNeighbors(n_neighbors=n_neighbors, radius=10000)
    NN.fit(descriptors)
    index = NN.kneighbors([queried_descriptor], return_distance=False)[0, rank]

    return descriptors[index]
    
def get_links(links, indexes):
    """Pick the entries of ``links`` addressed by the first row of ``indexes``.

    ``indexes`` is a 2-D integer array; positions ``0 .. indexes.size - 1`` of
    its first row are looked up in ``links``, in order.
    """
    return [links[indexes[0, position]] for position in range(indexes.size)]

def find_descriptor_from_id(session, id):
    """Look up the stored descriptor and link of the item with the given id.

    Parameters
    ----------
    session
        Open database session; only ``session.run`` is used.
    id
        Value the item's ``id`` property must equal.

    Returns
    -------
    tuple
        ``(descriptor, link)`` with the descriptor converted by ``toArray``,
        or ``(None, None)`` when no usable record is found.
    """
    # Clauses are separated by explicit spaces so the concatenated query does
    # not fuse tokens such as ')WHERE' or '}RETURN'.
    query = 'MATCH (a:Item) ' \
            'WHERE a.id = {id} ' \
            'RETURN a.descriptor as descriptor, a.link as link'
    record = session.run(query, id=id).single()

    # No item with this id.
    if record is None:
        return None, None

    # A missing field or a descriptor that is not a parsable string counts as
    # "not found"; anything else is a real error and is allowed to propagate
    # (the previous bare except hid those, including KeyboardInterrupt).
    try:
        descriptor = toArray(record['descriptor'])
        link = record['link']
    except (KeyError, TypeError, ValueError):
        return None, None

    return descriptor, link

#Parses a descriptor stored as a string into a float32 numpy array
def toArray(descriptor):
    """Drop the first and last character and parse the comma-separated rest as float32."""
    inner = descriptor[1:-1]
    return np.fromstring(inner, dtype=np.float32, sep=',')
    

def plotItems(item, matches, UI=True):
    """Show the query image and its matches, or print them when ``UI`` is false.

    With ``UI`` true the query image and every image in ``matches`` are read,
    resized to 224x224 and drawn in their own figure. Otherwise ``item`` and
    ``matches`` are only printed.
    """
    if not UI:
        print('Query: {}'.format(item))
        print('Matches: {}'.format(matches))
        return

    fig, axis = plt.subplots()
    # NOTE(review): the query image is read from the notebook-level `link`,
    # not from the `item` argument — confirm this is intended.
    query_img = cv2.resize(io.imread(link), (224, 224))
    axis.imshow(query_img)

    plt.text(0, 300, 'Recommended:', fontsize=25)
    for match in matches:
        match_img = cv2.resize(io.imread(match), (224, 224))
        fig, axis = plt.subplots()
        axis.imshow(match_img)

def plotItem(item):
    """Close the current figure and draw one image, resized to 224x224, in a new one."""
    plt.close()
    figure, axis = plt.subplots()
    # NOTE(review): the image is read from the notebook-level `link`; the
    # `item` argument is not used — confirm this is intended.
    resized = cv2.resize(io.imread(link), (224, 224))
    axis.imshow(resized)

def get_descriptor(session, net, item):
    """Resolve ``item`` (an item id or an image location) to a descriptor and link.

    Parameters
    ----------
    session
        Database session, used when ``item`` is an id.
    net
        Network used to compute the descriptor when ``item`` is not an id.
    item : str
        A string of digits is treated as an item id and looked up in the
        database; anything else is treated as an image location.

    Returns
    -------
    tuple
        ``(descriptor, link)``.

    Raises
    ------
    Exception
        When no descriptor could be obtained for ``item``.
    """
    if item.isdigit():
        item = int(item)
        descriptor, link = find_descriptor_from_id(session, item)
    else:
        descriptor = find_descriptor(net, item)
        link = item
        # Convert only a real descriptor: np.array(None, dtype=np.float32)
        # yields a NaN array, which would slip past the None check below.
        if descriptor is not None:
            descriptor = np.array(descriptor, dtype=np.float32)

    if descriptor is None:
        raise Exception('Could not find descriptor for item: {}'.format(item))
    return descriptor, link

#Because some categories are intermixed, using the first cluster gives us items from the same category as ours,
#we try to avoid that by using the nearest variable which tells us which nth nearest cluster to use
def find_matches(session, descriptor, category, gender, nearest=1):
    """Return links of items matching ``descriptor`` within a category and gender.

    ``nearest`` is the 1-based rank of the centroid to search around and is
    kept within 1..5. Without the lower bound, ``nearest - 1`` turned negative
    and indexed the ranking from its end.
    """
    if nearest > 5:
        print('nearest can\'t be bigger than 5, setting to 5!')
        nearest = 5
    elif nearest < 1:
        print('nearest can\'t be smaller than 1, setting to 1!')
        nearest = 1

    # Some categories don't have centroids: centroid is then None and knn
    # searches all items of the category instead.
    centroid = find_nearest_centroid(session, descriptor, category, gender, nearest=nearest-1)
    items = knn(session, descriptor, centroid, category, gender)
    return items

def find_random(session, category, gender, num=5):
    """Return ``num`` randomly chosen item links for a category and gender.

    Links are drawn with replacement, so the same link can appear more than
    once. An empty list is returned when no item matches; previously that case
    raised from ``np.random.randint`` because the range was empty.
    """
    # Clauses are separated by explicit spaces so '}RETURN' is not fused.
    query = 'MATCH (c:Category) <-[:HAS_CATEGORY]- (a:Item) -[:HAS_CATEGORY]-> (b:Category) ' \
            'WHERE c.name = {category} AND b.name = {gender} ' \
            'RETURN a'
    result = session.run(query, category=category, gender=gender)

    links = [record['a']['link'] for record in result]
    if not links:
        return []

    indexes = np.random.randint(0, len(links), num)
    return [links[index] for index in indexes]

def find_ids(session):
    """Run the distinct item-id query and return the result's ``values()``."""
    statement = 'MATCH (a:Item)' \
                'RETURN DISTINCT a.id'
    return session.run(statement).values()

In [15]:
def knn(session, queried_descriptor, centroid, category, gender, neighbours=5):
    """Return links of the items closest to ``queried_descriptor``.

    Parameters
    ----------
    session
        Open database session; only ``session.run`` is used.
    queried_descriptor
        Descriptor to rank the candidate items against.
    centroid
        Descriptor of the centroid whose items are searched. When ``None``,
        all items having both the category and the gender are searched.
    category, gender
        Category and gender the candidates must belong to.
    neighbours
        Maximum number of links to return.

    Returns
    -------
    list
        Up to ``neighbours`` links, closest first; empty when there are no candidates.
    """
    # Clauses are separated by explicit spaces so the concatenated queries do
    # not fuse tokens such as ')WHERE' or '}RETURN'.
    if centroid is not None:
        query = 'MATCH (a:Item) -[:GRAVITATES]-> (b:Centroid) ' \
                'WHERE b.category = {category} AND b.gender = {gender} AND b.descriptor = {centroid} ' \
                'RETURN a'
        result = session.run(query, category=category, gender=gender, centroid=centroid)
    else:
        query = 'MATCH (c:Category) <-[:HAS_CATEGORY]- (a:Item) -[:HAS_CATEGORY]-> (b:Category) ' \
                'WHERE c.name = {category} AND b.name = {gender} ' \
                'RETURN a'
        result = session.run(query, category=category, gender=gender)

    descriptors = []
    links = []
    for record in result:
        node = record['a']
        descriptors.append(toArray(node['descriptor']))
        links.append(node['link'])

    # Nothing to rank: NearestNeighbors cannot be fitted on an empty set.
    if not descriptors:
        return []

    # NearestNeighbors rejects n_neighbors larger than the number of fitted
    # samples, so return fewer matches when fewer candidates exist.
    n_neighbors = min(neighbours, len(descriptors))
    NN = NearestNeighbors(n_neighbors=n_neighbors, radius=10000)
    NN.fit(descriptors)
    indexes = NN.kneighbors([queried_descriptor], return_distance=False)

    return get_links(links, indexes)

In [16]:
# Connection settings. Set NEO4J_URI / NEO4J_PASSWORD in the environment to
# override them without editing the notebook; the fallbacks are the values
# that used to be hard-coded here, so an existing setup keeps working.
# NOTE(review): the fallback password is still stored in the notebook — consider removing it.
uri = os.environ.get('NEO4J_URI', 'bolt://0.0.0.0:7472')
password = os.environ.get('NEO4J_PASSWORD', 'starlight')
nearest = 1

driver = GraphDatabase.driver(uri, auth=('neo4j', password))
session = driver.session()

# All distinct item ids, used to pick a random item.
ids = find_ids(session)
# Log of the recommendation rounds; the recommendation callback appends 1 or 2 per successful round.
kn = []

#Initialize net
net = caffe.Net('../deploy_googlenet-siamese.prototxt',
                '../googlenet-siamese-final.caffemodel',
                caffe.TEST)


# Искористи го копчето за да генерираш облека

In [17]:
# button = widgets.Button(description="Генерирај облека")
# display(button)

descriptor = None
link = None

# widgets.interact_manual.opts['manual_name'] = 'Генерирај облека'

@widgets.interact_manual()
def generate_apparel():
    global descriptor, link
    pos = np.random.randint(0, len(ids))
    item = str(ids[pos][0])
    descriptor, link = get_descriptor(session, net, item)
    %clear
    plt.close()
    img = io.imread(link)
    f, ax = plt.subplots()
    ax.imshow(img)

# button.on_click(generate_apparel)


interactive(children=(Button(description='Run Interact', style=ButtonStyle()), Output()), _dom_classes=('widge…

# Одбери каква категорија на облека сакаш, и одреди која колона е посоодветна (ако добиете чудни резултати променете го cluster)

In [18]:
def plotApparel(label, matches):
    """Draw every image in ``matches``, resized to 224x224, in its own figure.

    ``label`` is accepted but not used.
    """
    for match in matches:
        picture = cv2.resize(io.imread(match), (224, 224))
        fig, axis = plt.subplots()
        axis.imshow(picture)

def plotApp(axs, matches, matches2):
    """Draw ``matches`` down column 0 and ``matches2`` down column 1 of ``axs``.

    Each image is read and resized to 224x224; the i-th image of a group goes
    to ``axs[i, column]``.
    """
    for column, group in enumerate((matches, matches2)):
        for row in range(len(group)):
            picture = cv2.resize(io.imread(group[row]), (224, 224))
            axs[row, column].imshow(picture)
    

@widgets.interact_manual(
    category=category,
    gender=[('унисекс', 'unisex'), ('женска', 'damen'), ('машка', 'herren')],
    cluster=[3,4,5]
)
def f(category, gender, cluster):
    nearest = cluster
    %clear
    plt.close()
    fig, axs = plt.subplots(5, 2, figsize=(25, 25))
    if np.random.randint(2) == 0:
        try:
            matches = find_matches(session, descriptor, category, gender, nearest)
            matches2 = find_random(session, category, gender)
            plotApp(axs, matches, matches2)
            kn.append(1)
        except:
            print('Обиди се повторно')
    else:
        try:
            matches = find_random(session, category, gender)
            matches2 = find_matches(session, descriptor, category, gender, nearest)
            plotApp(axs, matches, matches2)
            kn.append(2)
        except:
            print('Обиди се повторно')

interactive(children=(Dropdown(description='category', options=(('jeans', 'jeans'), ('панталони', 'hose'), ('к…

In [19]:
# List every recorded round: 1-based round number followed by the value stored in kn.
for i, recorded in enumerate(kn):
    print(i + 1, recorded)