In [1]:
import numpy as np
import cv2
import os
from sklearn import mixture as mix
from sklearn.externals import joblib
import fnmatch
import re

In [2]:
def crop_center(img, crop_size):
    """Return the central ``crop_size`` x ``crop_size`` window of ``img``.

    Only the first two axes are cropped; any trailing axes (e.g. channels) are
    kept. The window's top-left corner is ``(dim // 2 - crop_size // 2)`` along
    each of the first two axes.

    Parameters:
        img: array with at least two dimensions.
        crop_size: side length of the square window, in pixels.

    Returns:
        A view of ``img`` of shape ``(crop_size, crop_size) + img.shape[2:]``.

    Raises:
        ValueError: if ``crop_size`` is negative or larger than either of the
            first two dimensions. Without this check the start offset becomes
            negative and the slice silently returns a smaller, wrapped window.
    """
    height, width = img.shape[:2]
    if crop_size < 0 or crop_size > height or crop_size > width:
        raise ValueError('crop_size %d does not fit in an image of shape %r'
                         % (crop_size, img.shape))
    top = height // 2 - crop_size // 2
    left = width // 2 - crop_size // 2
    return img[top:top + crop_size, left:left + crop_size]

In [3]:
def normalize_patches(patches, shuffle=True):
    """Min-max normalise every row (one flattened patch) of ``patches`` to [0, 1].

    Parameters:
        patches: 2-D array with one flattened patch per row. When ``shuffle``
            is true its rows are permuted IN PLACE (the caller's array is
            modified) before normalisation.
        shuffle: whether to randomly permute the rows first, using the global
            ``np.random`` state.

    Returns:
        A new float32 array of the same shape where each row is
        ``(row - row.min()) / (row.max() - row.min())``. Constant rows (zero
        range) and rows containing NaN become all zeros.
    """
    if shuffle:
        np.random.shuffle(patches)
    normalized = np.array(patches, dtype=np.float32)  # always a fresh copy
    # The reductions use keepdims, so they only allocate an (N, 1) temporary.
    # All element-wise work is done in place, which avoids the full-size
    # intermediates a naive vectorised version would create. Subtracting the
    # mean before the minimum (as before) is mathematically a no-op.
    normalized -= normalized.min(axis=1, keepdims=True)
    ranges = normalized.max(axis=1, keepdims=True)
    # Constant patches have zero range: leave them at 0 instead of 0/0 = NaN.
    np.divide(normalized, ranges, out=normalized, where=ranges > 0)
    # A NaN anywhere in a row propagates through min() to the whole row;
    # zero such rows out, as the previous implementation did.
    normalized[np.isnan(normalized)] = 0
    return normalized

In [4]:
def generate_patches(processed_set, Check_dir, image_size, patches_per_image, b, seed=None):
    """Sample random ``b`` x ``b`` grayscale patches from every file in a directory.

    Parameters:
        processed_set: directory whose entries are all readable image files.
        Check_dir: unused; kept so existing callers keep working.
        image_size: images whose shape differs from ``image_size`` in any
            dimension are centre-cropped to ``image_size`` x ``image_size``.
        patches_per_image: number of patches drawn from each image.
        b: patch side length in pixels.
        seed: seed for the global ``np.random`` state (``None`` = unseeded).

    Returns:
        Array of shape ``(patches_per_image * n_images, b * b)``, passed through
        ``normalize_patches`` (rows shuffled and min-max normalised).

    Raises:
        IOError: if ``cv2.imread`` cannot decode a file.
        ValueError: if ``b`` exceeds ``image_size``, or an image cannot be
            cropped to ``image_size`` x ``image_size`` (previously these cases
            sampled out-of-bounds windows and failed with an obscure error).
    """
    if b > image_size:
        raise ValueError('patch size b=%d exceeds image_size=%d' % (b, image_size))
    # Sorted so a fixed ``seed`` yields the same patches regardless of the
    # (arbitrary) order in which the file system lists directory entries.
    imagenames = sorted(os.listdir(processed_set))
    patches = np.empty((patches_per_image * len(imagenames), b * b), np.uint8)
    np.random.seed(seed)
    j = 0
    for name in imagenames:
        path = os.path.join(processed_set, name)
        img = cv2.imread(path, 0)  # flag 0: load as single-channel grayscale
        if img is None:
            raise IOError('could not read image %r' % path)
        if any(dim != image_size for dim in img.shape):
            img = crop_center(img, image_size)
            if img.shape[:2] != (image_size, image_size):
                raise ValueError('image %r is smaller than image_size=%d' % (path, image_size))
        for _ in range(patches_per_image):
            # Top-left corner drawn so the whole b x b window stays inside the image.
            m = np.random.randint(0, image_size - b + 1)
            n = np.random.randint(0, image_size - b + 1)
            patches[j] = img[m:m + b, n:n + b].ravel()
            j += 1
    return normalize_patches(patches)

In [5]:
def initialize_GMM(n_components, processed_set, Check_dir, image_size, patches_per_image, b):
    """Create or resume a GaussianMixture and load (or build) its training patches.

    Checkpoint files in ``Check_dir`` are named ``GMM_<processing>_<N>.pkl``,
    where N is the number written into the name when saving (``train`` uses the
    count of patches fitted so far). The file ``GMM_<processing>_final.pkl``
    marks a finished model.

    Parameters:
        n_components: number of mixture components for a fresh model.
        processed_set: training directory. ``processed_set[11:-5]`` is used as
            the processing tag, which assumes the ``'../DataSet/<tag>/TRN/'``
            layout built in the last cell.
        Check_dir: checkpoint directory (created if missing).
        image_size, patches_per_image, b: forwarded to ``generate_patches``
            when no cached ``Patches_<processing>.pkl`` exists.

    Returns:
        ``(None, None, None)`` if a ``final`` checkpoint exists. Otherwise
        ``(GMM, initial_i, patches)``, where ``initial_i`` is the N of the
        newest numbered checkpoint (0 for a fresh model) and ``patches`` is the
        cached or newly generated patch array.
    """
    if not os.path.exists(Check_dir):
        os.makedirs(Check_dir)
    processing = processed_set[11:-5]
    # Escaped and end-anchored so the tag and '.pkl' are matched literally.
    ckpt_re = re.compile(re.escape('GMM_' + processing + '_') + r'(.+)\.pkl$')

    numbered = []
    for fname in os.listdir(Check_dir):
        match = ckpt_re.search(fname)
        if match is None:
            continue
        tag = match.group(1)
        if tag == 'final':
            # Training for this processing already completed.
            return None, None, None
        if tag.isdigit():
            numbered.append((int(tag), fname))

    if numbered:
        # Take the checkpoint with the largest N. The previous code used the
        # last os.listdir() entry, whose order is arbitrary; even lexicographic
        # order would rank '90000' above '100000'.
        initial_i, last_checkpoint = max(numbered)
        # NOTE: joblib.load unpickles, so Check_dir must only hold trusted files.
        GMM = joblib.load(os.path.join(Check_dir, last_checkpoint))
    else:
        initial_i = 0
        GMM = mix.GaussianMixture(n_components=n_components, warm_start=True)

    patches_path = os.path.join(Check_dir, 'Patches_' + processing + '.pkl')
    if os.path.isfile(patches_path):
        patches = joblib.load(patches_path)
    else:
        patches = generate_patches(processed_set, Check_dir, image_size, patches_per_image, b)
        joblib.dump(patches, patches_path)
    return GMM, initial_i, patches

In [6]:
def train(processed_set, n_components=200, b=8,
          patches_per_image=500, image_size=512,
          batch_size=10000, save_interval=20):
    """Fit, or resume fitting, a GaussianMixture on patches from ``processed_set``.

    Patches are passed to ``GMM.fit`` in consecutive full chunks of
    ``batch_size`` rows, then one final partial chunk if there is one. A fresh
    model from ``initialize_GMM`` is built with ``warm_start=True``, so each
    ``fit`` call continues from the previous parameters.

    Checkpoints go to ``../Models/<image_size>/``:
    ``GMM_<processing>_<N>.pkl`` after every ``save_interval``-th batch (N is
    the number of patches fitted so far), and ``GMM_<processing>_final.pkl``
    at the end.

    Returns:
        0 if a final checkpoint already exists (nothing to do), otherwise None.
    """
    Check_dir = '../Models/' + str(image_size) + '/'
    GMM, initial_i, patches = initialize_GMM(n_components, processed_set,
                                             Check_dir, image_size, patches_per_image, b)
    if GMM is None:
        # already trained
        return 0
    total_num_of_patches = patches.shape[0]
    processing = processed_set[11:-5]
    # '//' keeps this an int under Python 3 as well.
    n_full_batches = total_num_of_patches // batch_size
    # initial_i is the patch count N taken from the resumed checkpoint name
    # (0 when starting fresh). Checkpoint N is written right after fitting
    # batch N // batch_size - 1, so resume at batch N // batch_size. Using N
    # itself as a batch index skipped every remaining batch.
    start_batch = initial_i // batch_size
    for i in range(start_batch, n_full_batches):
        GMM.fit(patches[i * batch_size:(i + 1) * batch_size])
        if i % save_interval == 0:
            joblib.dump(GMM, Check_dir + 'GMM_' + processing + '_' + str((i + 1) * batch_size) + '.pkl')
    remainder = total_num_of_patches % batch_size
    # Guard required: patches[-0:] is the WHOLE array, which used to trigger a
    # refit on every patch whenever the total was a multiple of batch_size.
    # NOTE(review): GaussianMixture.fit is expected to reject inputs with fewer
    # samples than n_components, so a tiny remainder may raise. Confirm.
    if remainder:
        GMM.fit(patches[-remainder:])
    joblib.dump(GMM, Check_dir + 'GMM_' + processing + '_final.pkl')

In [None]:
# Train one GMM per (image size, processing variant) pair. Each variant's
# training images live under ../DataSet/<variant>/TRN/.
image_sizes = [16, 32, 512]
Processings_TRN = ['../DataSet/%s/TRN/' % proc
                   for proc in ['ORI', 'GF', 'JPG', 'MF', 'RS', 'USM', 'WGN']]
for image_size in image_sizes:
    for processed_set in Processings_TRN:
        train(processed_set, image_size=image_size)