In [36]:
from matplotlib import image 
from tqdm import tqdm 
import numpy as np

import pathlib 


# Class name -> integer code. Codes start at 1 and follow the order in which
# the names are listed below.
label_to_int_mapping = dict(
    zip(
        ('cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash'),
        range(1, 7),
    )
)

# Integer code -> class name. NOTE: the keys are the *string* form of each code
# ('1' .. '6'), not ints, so a numeric label must be converted with str(int(...))
# before it is looked up here.
int_to_label_mapping = {
    str(code): name
    for code, name in enumerate(
        ('cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash'),
        start=1,
    )
}

def load_image_files(train_data_path, test_data_path, assertion=False):
    """List the ``*.jpg`` files directly inside the train and test directories.

    Parameters
    ----------
    train_data_path, test_data_path : pathlib.Path
        Directories to search. Only direct children are matched; the glob is
        not recursive.
    assertion : bool, default False
        When True, assert that exactly 1766 training files and 761 test files
        were found.

    Returns
    -------
    (train_data_files, test_data_files) : tuple of list of pathlib.Path
        Both lists are sorted, so the row order of anything built from them
        is the same on every run and every machine.

    Raises
    ------
    AssertionError
        If ``assertion`` is True and a file count differs from the expected one.
    """
    # glob() yields entries in whatever order the filesystem returns them;
    # sorting makes the downstream X / y row order reproducible.
    train_data_files = sorted(train_data_path.glob('*.jpg'))
    test_data_files = sorted(test_data_path.glob('*.jpg'))
    if assertion:
        assert 1766 == len(train_data_files), (
            'expected 1766 training images, found %d' % len(train_data_files)
        )
        assert 761 == len(test_data_files), (
            'expected 761 test images, found %d' % len(test_data_files)
        )
    return train_data_files, test_data_files

def flatten_and_normalize(img):
    """Return a 1-D copy of ``img`` with every element divided by 255.0.

    NOTE(review): the 255.0 divisor presumably assumes 8-bit pixel values in
    0..255; confirm for inputs that are already floating point.
    """
    flat_pixels = img.flatten()
    return flat_pixels / 255.0

def grey_scale_image(img):
    """Collapse the last axis of ``img`` into a single weighted-sum channel.

    Only the first three entries along the last axis are used (any further
    entry, e.g. a fourth channel, is ignored); they are weighted by
    0.2989, 0.5870 and 0.1140 respectively and summed.
    """
    channel_weights = [0.2989, 0.5870, 0.1140]
    first_three_channels = img[..., :3]
    return np.dot(first_three_channels, channel_weights)

def preprocess_data(image_files, col_dim, use_grey_scale):
    """Load every image into one feature row and derive its label from its path.

    Parameters
    ----------
    image_files : sequence of path-like
        Image files to read with ``matplotlib.image.imread``. Must support
        ``len()``.
    col_dim : int
        Length of one flattened image; every image must flatten to exactly
        this many values or the row assignment fails.
    use_grey_scale : bool
        If True, each image is passed through ``grey_scale_image`` before it
        is flattened.

    Returns
    -------
    X : numpy.ndarray, shape (len(image_files), col_dim)
        One flattened image per row, divided by 255.0.
    y : numpy.ndarray, shape (len(image_files),)
        Float array holding the ``label_to_int_mapping`` code of each file.
    """
    # Class names are searched for in this order inside str(image_file);
    # the first hit wins. A path containing none of them is labelled 'trash'.
    label_precedence = ('cardboard', 'glass', 'metal', 'paper', 'plastic')

    X = np.zeros((len(image_files), col_dim))
    y = np.zeros((len(image_files),))

    # tqdm wraps the sequence itself rather than the enumerate() generator so
    # it can see the length and show a real progress bar with total and ETA.
    for example_idx, image_file in enumerate(tqdm(image_files)):
        img = image.imread(image_file)
        if use_grey_scale:
            img = grey_scale_image(img)
        X[example_idx] = flatten_and_normalize(img)

        path_text = str(image_file)
        label = next(
            (name for name in label_precedence if name in path_text),
            'trash',
        )
        y[example_idx] = label_to_int_mapping[label]
    return X, y
    
def get_preprocessed_train_test_images(use_grey_scale=True, assertion=False, data_dir='../data'):
    """Load the train and test images and turn them into feature/label arrays.

    Parameters
    ----------
    use_grey_scale : bool, default True
        If True, images are converted with ``grey_scale_image`` before being
        flattened.
    assertion : bool, default False
        If True, run the sanity checks: the file-count assertions in
        ``load_image_files`` and a check that every train and test image has
        the same shape as the first training image. This decodes every image
        an extra time.
    data_dir : str or path-like, default '../data'
        Directory containing the ``train`` and ``test`` sub-directories.

    Returns
    -------
    X_train, X_test, y_train, y_test : numpy.ndarray
        Outputs of ``preprocess_data`` for the train and test files.

    Raises
    ------
    AssertionError
        If ``assertion`` is True and a check fails.
    IndexError
        If no training image is found.
    """
    base = pathlib.Path(data_dir)
    train_data_path = base / 'train'
    test_data_path = base / 'test'
    # Forward the flag; without it the file-count checks could never run.
    train_data_files, test_data_files = load_image_files(
        train_data_path, test_data_path, assertion=assertion
    )

    # The first training image defines the reference shape; decode it once.
    img0 = image.imread(train_data_files[0])
    orig_size = img0.shape

    if assertion:
        for train_data_file in tqdm(train_data_files):
            assert orig_size == image.imread(train_data_file).shape

        # Check the *test* files here (previously the train files were
        # checked a second time and the test set was never validated).
        for test_data_file in tqdm(test_data_files):
            assert orig_size == image.imread(test_data_file).shape

    grey0 = grey_scale_image(img0)
    if use_grey_scale:
        col_dim = grey0.flatten().shape[0]
    else:
        col_dim = img0.flatten().shape[0]
    # Kept from the original: report the greyscale shape of the first image.
    print(grey0.shape)

    X_train, y_train = preprocess_data(train_data_files, col_dim, use_grey_scale)
    X_test, y_test = preprocess_data(test_data_files, col_dim, use_grey_scale)
    return X_train, X_test, y_train, y_test

In [37]:
# Build the train/test feature matrices and label vectors; the keyword values
# below are the function's defaults, spelled out for the reader.
(X_train, X_test, y_train, y_test) = get_preprocessed_train_test_images(
    use_grey_scale=True,
    assertion=False,
)

(384, 512)


1766it [00:26, 65.59it/s]
761it [00:13, 56.82it/s]


In [38]:
from sklearn.neighbors import KNeighborsClassifier

(0.4060446780551905, 0.7429218573046432)

In [39]:
# Sweep the neighbour count from 1 to 19 and print, for each value:
#   k, accuracy on the test split, accuracy on the training split.
for i in range(1, 20):
    knni = KNeighborsClassifier(n_neighbors=i)
    knni.fit(X_train, y_train)
    print(
        i,
        knni.score(X_test, y_test),
        knni.score(X_train, y_train),
    )

1 0.4126149802890933 1.0
2 0.4060446780551905 0.7429218573046432
3 0.4021024967148489 0.6183465458663646
4 0.39816031537450725 0.5928652321630804
5 0.39421813403416556 0.5543601359003397
6 0.41130091984231276 0.5328425821064553
7 0.39421813403416556 0.5101925254813137
8 0.3797634691195795 0.4898074745186863
9 0.38896189224704336 0.47565118912797283
10 0.37056504599211565 0.46772366930917325
11 0.37582128777923784 0.4524348810872027
12 0.38370565045992117 0.45016987542468856
13 0.3771353482260184 0.43941109852774635
14 0.37056504599211565 0.43714609286523215
15 0.3797634691195795 0.4331823329558324
16 0.3731931668856767 0.42638731596828994
17 0.37582128777923784 0.42638731596828994
18 0.37844940867279897 0.42129105322763305
19 0.36662286465177396 0.41336353340883353
