In [79]:
import cv2
import numpy as np

def cropCordinates(image):
    """Return the bounds of the largest centered square inside ``image``.

    Parameters
    ----------
    image : object with a ``shape`` attribute
        Only ``image.shape[:2]`` is read and interpreted as (height, width).

    Returns
    -------
    tuple of int
        ``(start_row, start_col, end_row, end_col)``; the end values are
        exclusive, so they can be used directly as slice bounds.
    """
    height, width = image.shape[:2]
    side = min(height, width)

    # Floor division keeps the arithmetic integral on both Python 2 and 3.
    # With plain `/` on Python 2 ints, an odd size difference made the end
    # bound land one pixel too far, producing a non-square crop.
    start_row = int((height - side) // 2)
    start_col = int((width - side) // 2)

    # Deriving each end from its start guarantees exactly `side` pixels along
    # both axes; the shorter axis gets start 0 and end equal to its full length.
    return (start_row, start_col, start_row + int(side), start_col + int(side))
    
def crop(image):
    """Slice ``image`` down to the region reported by ``cropCordinates``.

    The first two axes are sliced with the returned row and column bounds;
    any further axes are left untouched.
    """
    top, left, bottom, right = cropCordinates(image)

    row_span = slice(top, bottom)
    col_span = slice(left, right)
    return image[row_span, col_span]
    
def convertTo32by32(image):
    """Crop ``image`` with ``crop`` and resize the result to 32x32.

    Resizing is done by ``cv2.resize`` with ``cv2.INTER_AREA`` interpolation.
    """
    target_size = (32, 32)
    square = crop(image)

    # For a visual check while debugging, the returned array can be shown
    # with cv2.imshow / cv2.waitKey.
    return cv2.resize(square, target_size, interpolation=cv2.INTER_AREA)

def swapChannel(image):
    """Move axis 2 of ``image`` to the front, keeping the other axes in order.

    For an array laid out as (rows, cols, channels) -- presumably what the
    callers in this notebook pass; confirm -- the result is
    (channels, rows, cols). The returned array is a view, not a copy.
    """
    # Rolling axis 2 back to position 0 is the same permutation as swapping
    # axes (0, 2) and then (1, 2).
    return np.rollaxis(image, 2)

def prepareImage(image):
    """Run the full preprocessing chain on ``image``.

    The image is first passed through ``convertTo32by32`` and the result is
    then reordered by ``swapChannel``.
    """
    thumbnail = convertTo32by32(image)
    return swapChannel(thumbnail)

In [None]:
# Sanity check: preprocess two sample images and show the resulting shapes.
# Single-argument print(...) gives the same output on Python 2 and also runs
# on Python 3, unlike the bare print statement.
for image_path in ('images/frog-rect.jpg', 'images/frog-vert.png'):
    print(prepareImage(cv2.imread(image_path)).shape)

In [80]:
from keras.models import load_model

# Load the saved Keras model once at notebook scope; predict() below reads
# this global on every call instead of reloading the file.
model = load_model('larger-cnn.h5')

def predict(image):
    """Classify ``image`` with the notebook-level ``model``.

    The image is preprocessed by ``prepareImage``, cast to float32, divided
    by 255.0 and wrapped in a batch of one before being passed to
    ``model.predict``.

    Returns
    -------
    tuple
        ``(label, score)`` where ``label`` is the entry of the label list at
        the position of the largest model output, and ``score`` is that
        output divided by the sum of all outputs.

    NOTE(review): the cells in this notebook pass ``cv2.imread`` results,
    which OpenCV documents as BGR-ordered. If the model was trained on RGB
    data the channels need reversing first -- confirm against the training
    pipeline.
    """
    # presumably in the same order as the model's output units -- confirm
    labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

    pixels = prepareImage(image).astype('float32') / 255.0
    batch = np.expand_dims(pixels, axis=0)

    # Only the first (and only) row of the batch output is of interest.
    scores = list(model.predict(batch, verbose=0)[0])
    top_score = max(scores)
    top_label = labels[scores.index(top_score)]

    return (top_label, top_score / sum(scores))

In [81]:
# Classify the three airplane sample images; each printed line is the
# (label, score) tuple returned by predict(). print(...) with one argument
# works identically on Python 2 and Python 3.
for image_path in ('images/airplane-01.jpg',
                   'images/airplane-02.jpg',
                   'images/airplane-03.jpg'):
    print(predict(cv2.imread(image_path)))

('airplane', 0.99978578820484643)
('airplane', 0.99999997787083617)
('airplane', 0.99999799802132305)


In [82]:
# Classify the three automobile sample images; each printed line is the
# (label, score) tuple returned by predict(). print(...) with one argument
# works identically on Python 2 and Python 3.
for image_path in ('images/automobile-01.jpg',
                   'images/automobile-02.jpg',
                   'images/automobile-03.jpg'):
    print(predict(cv2.imread(image_path)))

('automobile', 0.99999504135386508)
('automobile', 0.96686182620006944)
('automobile', 0.99989117853252396)


In [83]:
# Classify the three bird sample images; each printed line is the
# (label, score) tuple returned by predict(). print(...) with one argument
# works identically on Python 2 and Python 3.
for image_path in ('images/bird-01.jpg',
                   'images/bird-02.png',
                   'images/bird-03.jpg'):
    print(predict(cv2.imread(image_path)))

('bird', 0.99998036860612438)
('airplane', 0.72707449110952282)
('bird', 0.99984114316131267)


In [84]:
# Classify the three cat sample images; each printed line is the
# (label, score) tuple returned by predict(). print(...) with one argument
# works identically on Python 2 and Python 3.
for image_path in ('images/cat-01.jpg',
                   'images/cat-02.jpg',
                   'images/cat-03.jpg'):
    print(predict(cv2.imread(image_path)))

('cat', 0.96092851117593769)
('cat', 0.98011898653787266)
('cat', 0.96463835278110266)


In [85]:
# Classify the three deer sample images; each printed line is the
# (label, score) tuple returned by predict(). print(...) with one argument
# works identically on Python 2 and Python 3.
for image_path in ('images/deer-01.jpg',
                   'images/deer-02.jpg',
                   'images/deer-03.jpg'):
    print(predict(cv2.imread(image_path)))

('deer', 0.98758072425640064)
('deer', 0.9981710374422833)
('bird', 0.60071899639226489)


In [86]:
# Classify the three dog sample images; each printed line is the
# (label, score) tuple returned by predict(). print(...) with one argument
# works identically on Python 2 and Python 3.
for image_path in ('images/dog-01.jpg',
                   'images/dog-02.jpg',
                   'images/dog-03.jpg'):
    print(predict(cv2.imread(image_path)))

('cat', 0.77759511110515767)
('dog', 0.57831404181058466)
('cat', 0.87355838424613552)


In [87]:
# Classify the three frog sample images; each printed line is the
# (label, score) tuple returned by predict(). print(...) with one argument
# works identically on Python 2 and Python 3.
for image_path in ('images/frog-01.jpg',
                   'images/frog-02.jpg',
                   'images/frog-03.jpg'):
    print(predict(cv2.imread(image_path)))

('truck', 0.47710126938327874)
('automobile', 0.8800995811845379)
('automobile', 0.99985671153420963)


In [88]:
# Classify the three horse sample images; each printed line is the
# (label, score) tuple returned by predict(). print(...) with one argument
# works identically on Python 2 and Python 3.
for image_path in ('images/horse-01.jpg',
                   'images/horse-02.jpg',
                   'images/horse-03.jpg'):
    print(predict(cv2.imread(image_path)))

('airplane', 0.88880101316282689)
('airplane', 0.99923741391610488)
('dog', 0.3588241987109177)


In [89]:
# Classify the three ship sample images; each printed line is the
# (label, score) tuple returned by predict(). print(...) with one argument
# works identically on Python 2 and Python 3.
for image_path in ('images/ship-01.jpg',
                   'images/ship-02.jpg',
                   'images/ship-03.jpg'):
    print(predict(cv2.imread(image_path)))

('cat', 0.72194112218114881)
('ship', 0.9768867681037352)
('bird', 0.40186865230858104)


In [90]:
# Classify the three truck sample images; each printed line is the
# (label, score) tuple returned by predict(). print(...) with one argument
# works identically on Python 2 and Python 3.
for image_path in ('images/truck-01.jpg',
                   'images/truck-02.jpg',
                   'images/truck-03.jpg'):
    print(predict(cv2.imread(image_path)))

('truck', 0.9999999626801489)
('truck', 0.99999999997507572)
('truck', 0.99999743590892998)
