In [2]:
import os

In [6]:
import cv2
import numpy as np

def get_image(fname):
    """Load an image file and prepare it as a single-image RGB batch.

    Parameters
    ----------
    fname : str
        Path of the image file to read.

    Returns
    -------
    numpy.ndarray or None
        The image converted from BGR to RGB, resized to 224x224 and given a
        leading batch axis, or ``None`` when ``cv2.imread`` could not read
        the file.
    """
    bgr = cv2.imread(fname)
    # cv2.imread reports an unreadable file by returning None rather than
    # raising, so the check must happen before cvtColor (which would raise
    # on a None input and make this guard unreachable).
    if bgr is None:
        return None

    img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (224, 224))
    # Add the leading batch axis so the result can be passed to predict().
    img = img[np.newaxis, :]

    return img

In [7]:
import glob
import os

def process_files(folder_name, start_index=0):
    """Extract one feature vector per image found under ``folder_name``.

    ``folder_name`` is expected to contain one sub-folder per product/class;
    the sub-folder's base name becomes the ``productTitle`` of every image
    inside it. Features come from the module-level ``model`` (anything with
    a ``predict`` method) applied to the batch returned by ``get_image``.

    Parameters
    ----------
    folder_name : str
        Root folder whose immediate children are the per-class folders.
    start_index : int, optional
        Kept for backward compatibility; it does not influence the result.

    Returns
    -------
    tuple
        ``(products, xb)`` where ``products`` is a list of dicts with the
        keys ``'productTitle'`` and ``'imageFileName'`` and ``xb`` is a
        float32 array with one row per entry of ``products`` (shape
        ``(0, 2048)`` when nothing could be processed).

    Notes
    -----
    Files that cannot be read or processed are reported on stdout and
    skipped, so one bad file does not abort the whole run. Folders and files
    are visited in sorted order so the row order is reproducible.
    """
    dim = 2048  # expected length of every feature vector / width of xb
    products = []
    # Rows are collected in a list and stacked once at the end; growing the
    # array with np.append on every image recopies the whole matrix each time.
    feature_rows = []

    for imgfolderpath in sorted(glob.glob(folder_name + '/*')):
        productTitle = os.path.basename(imgfolderpath)

        for file in sorted(glob.glob(imgfolderpath + '/*')):
            try:
                img = get_image(file)
                if img is None:
                    print("Something went wrong with the file: " + file
                          + " (could not be read as an image)")
                    continue

                # extract features
                features_batch = model.predict([img])

                # the Knn algorithm we'll use requires float32 rather than the default float64
                features = np.asarray(features_batch[0], dtype=np.float32)
                if features.shape != (dim,):
                    raise ValueError("unexpected feature shape " + str(features.shape))
            except Exception as exc:
                # Best effort: report and skip the file. Exception (not a bare
                # except) lets KeyboardInterrupt / SystemExit propagate.
                print("Something went wrong with the file: " + file
                      + " (" + repr(exc) + ")")
                continue

            products.append({'productTitle': productTitle, 'imageFileName': file})
            feature_rows.append(features)

    if feature_rows:
        xb = np.stack(feature_rows)
    else:
        xb = np.empty(shape=[0, dim], dtype=np.float32)
    return (products, xb)

In [4]:
# Fetch the dataset only when './dataset-resized' does not exist yet, so
# re-running this cell does not download it again.
if not os.path.exists('./dataset-resized'):
    # IPython shell escape: the zip is streamed from GitHub into `jar xv`,
    # which presumably unpacks it into the current working directory
    # (assumes `curl` and a JDK `jar` are on PATH — TODO confirm).
    !curl https://raw.githubusercontent.com/garythung/trashnet/master/data/dataset-resized.zip | jar xv

In [5]:
from keras.applications.resnet50 import ResNet50
# Feature extractor shared by the rest of the notebook: ResNet50 with
# ImageNet weights, no classification head, and average pooling, taking
# 224x224 three-channel inputs.
model = ResNet50(
    weights='imagenet',
    include_top=False,
    input_shape=(224, 224, 3),
    pooling="avg",
)

Using TensorFlow backend.


Instructions for updating:
If using Keras pass *_constraint arguments to layers.



In [8]:
# Imports used by this cell, grouped so its dependencies are visible at a glance.
import pandas as pd
import plotly.graph_objs as go
import sklearn
import sklearn.preprocessing
from plotly.offline import plot
from sklearn.manifold import TSNE

# One feature vector per image, plus the matching metadata records.
products, train_features = process_files('./dataset-resized')

# Row-wise normalisation of the feature vectors (sklearn's default norm is L2).
train_features = sklearn.preprocessing.normalize(train_features)

# Project the features to 3-D for plotting. t-SNE is stochastic, so the seed
# is fixed to make the embedding repeatable across Restart & Run All.
tsne = TSNE(n_components=3, random_state=42).fit_transform(train_features)

df = pd.DataFrame(products)
texts = df["imageFileName"]  # hover text: the source file of each point

# Integer code per class label, in order of first appearance.
li, uniques = pd.factorize(df['productTitle'])

# Spread the class codes over 0..99 for the colour scale. With a single
# class len(uniques) - 1 is 0, so the divisor is clamped to 1 to avoid a
# division by zero.
li = (li * 99) // max(len(uniques) - 1, 1)

scatter = go.Scatter3d(
    x=tsne[:,0],
    y=tsne[:,1],
    z=tsne[:,2],
    mode='markers',
    text=texts,
    marker=dict(
        size=8,
        color=li,                # set color to an array/list of desired values
        colorscale='Viridis',   # choose a colorscale
        opacity=0.8
    ))

fig = go.Figure(data=[scatter])

# Writes the interactive figure to '3d-scatter.html'.
plot(fig, filename='3d-scatter.html')





'3d-scatter.html'