In [1]:
import numpy as np
import cv2
import os
from sklearn import mixture as mix
from sklearn.externals import joblib

In [2]:
def crop_center(img, crop_size):
    """Slice a crop_size x crop_size window out of the middle of `img`.

    The window starts `crop_size // 2` rows/columns before the centre
    index (`shape // 2`) along each of the first two axes and spans
    `crop_size` entries from there.

    No bounds checking is done: when an axis is shorter than `crop_size`
    the computed start index is negative and normal negative-index slicing
    rules apply, so callers should make sure the image is large enough.

    Parameters
    ----------
    img : array-like with a `.shape` and 2-D slicing (e.g. a numpy image)
    crop_size : int
        Side length of the square window.

    Returns
    -------
    The sliced window, `img[top:top + crop_size, left:left + crop_size]`.
    """
    half = crop_size // 2
    top = img.shape[0] // 2 - half
    left = img.shape[1] // 2 - half
    return img[top:top + crop_size, left:left + crop_size]

In [3]:
def normalize_patches(patches, shuffle=True):
    """Return a float32 copy of `patches` with every patch rescaled to [0, 1].

    Each patch (one entry along the first axis) has its mean removed, is
    shifted so its minimum is 0 and is then divided by its maximum.
    Patches with no dynamic range (constant patches), or whose range is
    not finite, come out as all zeros.

    Parameters
    ----------
    patches : array-like, shape (n_patches, ...)
        Raw patch values. The input itself is never modified.
    shuffle : bool, default True
        If True, the patches are permuted along the first axis using the
        global `np.random` state (seed it beforehand for reproducibility).

    Returns
    -------
    numpy.ndarray of dtype float32 with the same shape as `patches`.
    """
    # Always work on a private float32 copy so the caller's array is neither
    # shuffled nor overwritten.
    patches = np.array(patches, dtype=np.float32)
    if shuffle:
        np.random.shuffle(patches)
    # Normalise patch by patch, in place: doing it on the whole array at once
    # would allocate several full-size temporaries.
    for patch in patches:  # each `patch` is a writable view into `patches`
        patch -= np.mean(patch)
        patch -= np.min(patch)
        peak = np.max(patch)
        if np.isfinite(peak) and peak > 0:
            patch /= peak
        else:
            # Constant (or non-finite) patch: dividing would give 0/0 = NaN,
            # so define the result as zeros directly.
            patch[...] = 0
    return patches

In [4]:
def train(processed_set, n_components = 200, b = 8, \
          patches_per_image = 500, image_size = 512, \
          batch_size = 10000, save_interval = 20, seed=None):
    """Fit a Gaussian mixture on random b x b patches from a directory of images.

    Every readable image in `processed_set` is loaded as grayscale,
    centre-cropped to `image_size` x `image_size` if needed, and
    `patches_per_image` patches are sampled at uniformly random positions.
    The patches are normalised (and shuffled) by `normalize_patches`, then
    fed to one warm-started `GaussianMixture` in consecutive batches of
    `batch_size`, followed by one fit on the leftover patches, if any.

    Models are pickled under '../Models/<image_size>/':
    'GMM_<processing>_<patches seen>.pkl' every `save_interval` batches and
    'GMM_<processing>_final.pkl' at the end. <processing> is the name of the
    parent directory of `processed_set` ('ORI' for '../DataSet/ORI/TRN/').

    Parameters
    ----------
    processed_set : str
        Directory holding the training images, laid out as
        '<root>/<processing>/<split>/'.
    n_components : int
        Number of mixture components.
    b : int
        Patch side length in pixels.
    patches_per_image : int
        Number of patches sampled from each image.
    image_size : int
        Side length of the (cropped) square region patches are drawn from.
    batch_size : int
        Number of patches passed to each `fit` call.
    save_interval : int
        A checkpoint is written every `save_interval` batches.
    seed : int or None
        Seed for the global `np.random` state, which drives the patch
        positions and the shuffle.

    Raises
    ------
    ValueError
        If `b` exceeds `image_size`, or if no usable image is found.
    """
    import warnings

    if b > image_size:
        raise ValueError('patch size b=%d exceeds image_size=%d' % (b, image_size))

    # Sorted so that a given seed always yields the same patches, regardless
    # of the order in which the OS lists the directory.
    imagenames = sorted(os.listdir(processed_set))
    patches = np.empty((patches_per_image * len(imagenames), b * b), np.uint8)
    np.random.seed(seed)
    j = 0
    for name in imagenames:
        path = os.path.join(processed_set, name)
        img = cv2.imread(path, 0)  # flag 0: load as single-channel grayscale
        if img is None:
            # cv2.imread signals an unreadable / non-image file with None.
            warnings.warn('skipping %r: not a readable image' % path)
            continue
        if min(img.shape[:2]) < image_size:
            # crop_center does no bounds checking; a smaller image would
            # yield a wrong (or empty) crop.
            warnings.warn('skipping %r: smaller than %d pixels' % (path, image_size))
            continue
        if any(dim != image_size for dim in img.shape):
            img = crop_center(img, image_size)
        for _ in range(patches_per_image):
            m = np.random.randint(0, image_size - b + 1)
            n = np.random.randint(0, image_size - b + 1)
            patches[j] = img[m:m + b, n:n + b].reshape(b * b)
            j += 1

    # Drop the uninitialised rows reserved for images that were skipped.
    patches = patches[:j]
    total_num_of_patches = j
    if total_num_of_patches == 0:
        raise ValueError('no usable images found in %r' % processed_set)

    # warm_start=True makes each fit() start from the previous solution, so
    # successive batches keep refining the same model.
    GMM = mix.GaussianMixture(n_components=n_components, warm_start=True)
    Check_dir = '../Models/' + str(image_size) + '/'
    if not os.path.exists(Check_dir):
        os.makedirs(Check_dir)
    # Name of the processing = parent directory of the split directory.
    # Equivalent to the former processed_set[11:-5] for '../DataSet/<X>/TRN/'
    # but independent of the prefix/suffix lengths and of a trailing slash.
    processing = os.path.basename(os.path.dirname(os.path.normpath(processed_set)))
    patches = normalize_patches(patches)

    # Floor division: the number of *complete* batches must be an integer.
    for i in range(total_num_of_patches // batch_size):
        start = i * batch_size
        GMM.fit(patches[start:start + batch_size])
        if i % save_interval == 0:
            joblib.dump(GMM, Check_dir + 'GMM_' + processing + '_' + str((i+1) * batch_size) + '.pkl')

    # Fit the leftover patches only when there are some: with a remainder of
    # 0 the slice patches[-0:] would be the WHOLE array, not an empty one.
    remainder = total_num_of_patches % batch_size
    if remainder:
        GMM.fit(patches[-remainder:])
    joblib.dump(GMM, Check_dir + 'GMM_' + processing + '_final.pkl')

In [None]:
# Train one mixture model per (crop size, processing) combination.
# Each processing name is expanded to its training directory,
# '../DataSet/<name>/TRN/', which is what train() expects as input.
image_sizes = [16, 32, 512]
Processings_TRN = [
    '../DataSet/' + proc + '/TRN/'
    for proc in ['ORI', 'GF', 'JPG', 'MF', 'RS', 'USM', 'WGN']
]
for image_size in image_sizes:
    for processed_set in Processings_TRN:
        train(processed_set, image_size=image_size)

  # Remove the CWD from sys.path while we load stuff.
