In [1]:
import numpy as np
import cv2
import os
from sklearn import mixture as mix
from sklearn.externals import joblib
import fnmatch
import re

In [2]:
def crop_center(img, crop_size):
    """Return the central ``crop_size`` x ``crop_size`` window of ``img``.

    The window's top-left corner sits at
    ``(rows // 2 - crop_size // 2, cols // 2 - crop_size // 2)``.
    Only the first two axes are sliced; any further axes are passed through.

    Args:
        img: array exposing ``shape`` with at least two axes.
        crop_size: side length of the square window, in pixels.

    Returns:
        The sliced window of ``img``.
    """
    half = crop_size // 2
    top = img.shape[0] // 2 - half
    left = img.shape[1] // 2 - half
    return img[top:top + crop_size, left:left + crop_size]

In [3]:
def normalize_patches(patches, shuffle=True):
    """Rescale every patch (row) to the [0, 1] range and return it as float32.

    Each row has its mean subtracted, is shifted so its minimum is 0 and is
    then divided by its maximum. A constant row (0 / 0) becomes all zeros.

    Args:
        patches: 2-D array-like with one flattened patch per row.
        shuffle: when True, the rows of the result are put in random order
            using the global ``np.random`` state.

    Returns:
        A new float32 array with the same shape as ``patches``. The input
        array is neither reordered nor overwritten.
    """
    # np.array always copies, so shuffling and the in-place arithmetic below
    # only ever touch this private float32 buffer.
    normalized = np.array(patches, dtype=np.float32)
    if shuffle:
        np.random.shuffle(normalized)
    # Work one row at a time (each `patch` is a view into `normalized`) so no
    # full-size temporaries or masks are allocated.
    # Constant rows divide 0 by 0; that NaN is expected and zeroed right
    # after, so silence the warning instead of flooding the output.
    with np.errstate(divide='ignore', invalid='ignore'):
        for patch in normalized:
            patch -= np.mean(patch)
            patch -= np.min(patch)
            patch /= np.max(patch)
            patch[np.isnan(patch)] = 0
    return normalized

In [4]:
def generate_patches(processed_set, patches_per_image, b, seed=None):
    """Sample random ``b`` x ``b`` patches from every image in a directory.

    Every entry of ``processed_set`` is read with ``cv2.imread(path, 0)`` and
    ``patches_per_image`` windows are cut at uniformly random positions. The
    flattened windows are collected in a uint8 array and passed through
    ``normalize_patches``.

    Args:
        processed_set: directory containing the images.
        patches_per_image: number of patches drawn from each image.
        b: patch side length in pixels.
        seed: value handed to ``np.random.seed`` before sampling.

    Returns:
        The result of ``normalize_patches`` on an array of shape
        ``(patches_per_image * number_of_images, b * b)``.

    Raises:
        IOError: if ``cv2.imread`` cannot decode an entry of the directory.
        ValueError: if an image is smaller than ``b`` along either axis.
    """
    # os.listdir returns entries in arbitrary order; sorting makes the patches
    # drawn for a given seed reproducible.
    imagenames = sorted(os.listdir(processed_set))
    total_num_of_patches = patches_per_image * len(imagenames)
    patches = np.empty((total_num_of_patches, b * b), np.uint8)
    np.random.seed(seed)
    j = 0
    for name in imagenames:
        image_path = os.path.join(processed_set, name)
        img = cv2.imread(image_path, 0)
        if img is None:
            # imread signals failure by returning None instead of raising.
            raise IOError('Could not read image: ' + image_path)
        height, width = img.shape[:2]
        if height < b or width < b:
            raise ValueError('Image ' + image_path + ' is smaller than the '
                             + str(b) + 'x' + str(b) + ' patch size')
        for _ in range(patches_per_image):
            # Bound rows and columns separately so non-square images never
            # yield a truncated window.
            m = np.random.randint(0, height - b + 1)
            n = np.random.randint(0, width - b + 1)
            patches[j] = img[m:m + b, n:n + b].reshape(b * b)
            j += 1
    return normalize_patches(patches)

In [5]:
def initialize_GMM(n_components, processed_set, processing, Check_dir, patches_per_image, b):
    """Create a new mixture model or resume from the latest checkpoint.

    ``Check_dir`` is created if missing and searched for
    ``GMM_<processing>_<count>.pkl`` checkpoints. The patch array is loaded
    from ``Patches_<processing>.pkl`` in the same directory, or generated
    with ``generate_patches`` and saved there when the file does not exist.

    Args:
        n_components: number of mixture components for a fresh model.
        processed_set: image directory passed to ``generate_patches``.
        processing: name embedded in checkpoint and patch file names.
        Check_dir: directory holding checkpoints and the patch file.
        patches_per_image: passed to ``generate_patches``.
        b: patch side length, passed to ``generate_patches``.

    Returns:
        ``(GMM, initial_i, patches)`` where ``initial_i`` is the number
        parsed from the newest checkpoint's file name (0 for a fresh model),
        or ``(None, None, None)`` when ``GMM_<processing>_final.pkl`` exists.
    """
    if not os.path.exists(Check_dir):
        os.makedirs(Check_dir)

    if os.path.isfile(os.path.join(Check_dir, 'GMM_' + processing + '_final.pkl')):
        return None, None, None

    checkpoints = fnmatch.filter(os.listdir(Check_dir), '*GMM_' + processing + '_*.pkl')

    # Rank checkpoints by the number in their name. Sorting by name length
    # alone leaves same-length names in arbitrary listdir order.
    step_pattern = re.compile('GMM_' + re.escape(processing) + r'_(\d+)\.pkl$')
    numbered = []
    for filename in checkpoints:
        found = step_pattern.search(filename)
        if found:
            numbered.append((int(found.group(1)), filename))

    # NOTE: joblib.load unpickles arbitrary objects; Check_dir must only
    # contain trusted files.
    if numbered:
        initial_i, last_checkpoint = max(numbered)
        checkpoint_path = os.path.join(Check_dir, last_checkpoint)
        print('Loading from ' + checkpoint_path + '...')
        GMM = joblib.load(checkpoint_path)
    else:
        initial_i = 0
        GMM = mix.GaussianMixture(n_components=n_components, warm_start=True)

    # Load the (potentially large) patch file exactly once.
    patches_path = os.path.join(Check_dir, 'Patches_' + processing + '.pkl')
    if os.path.isfile(patches_path):
        patches = joblib.load(patches_path)
    else:
        print('Generating patches...')
        patches = generate_patches(processed_set, patches_per_image, b)
        joblib.dump(patches, patches_path)

    return GMM, initial_i, patches

In [6]:
def train(processed_set, image_size, n_components = 200,
          b = 8, patches_per_image = 500, batch_size = 10000,
          save_interval = 20, models_root = '../Models/'):
    """Fit a mixture model on image patches in batches, with checkpoints.

    The model and patches come from ``initialize_GMM``. ``GMM.fit`` is called
    on consecutive slices of ``batch_size`` rows, then once more on the
    leftover rows if there are any. The model is dumped as
    ``GMM_<processing>_<rows seen>.pkl`` whenever the batch index is a
    multiple of ``save_interval`` and as ``GMM_<processing>_final.pkl`` at
    the end.

    Args:
        processed_set: image directory such as ``../DataSet/16/ORI/TRN/``.
            The name of its parent directory (``ORI`` here) is used as the
            processing name in file names and messages.
        image_size: used for the model sub-directory name and in messages.
        n_components: passed to ``initialize_GMM``.
        b: patch side length, passed to ``initialize_GMM``.
        patches_per_image: passed to ``initialize_GMM``.
        batch_size: number of patch rows per ``fit`` call.
        save_interval: a checkpoint is written every this many batches.
        models_root: directory under which ``<image_size>/`` is used for
            checkpoints.

    Returns:
        0 if ``initialize_GMM`` reports the model as already trained,
        otherwise None.
    """
    # The trailing '' keeps a path separator at the end of Check_dir.
    Check_dir = os.path.join(models_root, str(image_size), '')
    # '<...>/<processing>/<split>/' -> '<processing>'
    processing = os.path.basename(os.path.dirname(os.path.normpath(processed_set)))
    GMM, initial_i, patches = initialize_GMM(n_components, processed_set, processing,
                                             Check_dir, patches_per_image, b)
    if GMM is None:
        print(processing + ' with image size: ' + str(image_size) + ' already trained.')
        return 0

    total_num_of_patches = patches.shape[0]
    # Floor division keeps these integers under both Python 2 and Python 3.
    first_batch = initial_i // batch_size
    full_batches = total_num_of_patches // batch_size
    print('Training ' + processing + ' with image size: ' + str(image_size) +'...')
    for i in range(first_batch, full_batches):
        GMM.fit(patches[i * batch_size:(i + 1) * batch_size])
        if i % save_interval == 0:
            checkpoint_name = 'GMM_' + processing + '_' + str((i + 1) * batch_size) + '.pkl'
            print('Saving ' + checkpoint_name + '...')
            joblib.dump(GMM, os.path.join(Check_dir, checkpoint_name))

    # Fit the leftover rows only when there are some: with a remainder of 0,
    # patches[-0:] would be the whole array rather than an empty slice.
    leftover = total_num_of_patches % batch_size
    if leftover:
        GMM.fit(patches[-leftover:])
    print('Saving GMM_' + processing + '_final.pkl...')
    joblib.dump(GMM, os.path.join(Check_dir, 'GMM_' + processing + '_final.pkl'))

In [7]:
# Call train() once for every (image size, processing) combination.
Processings = ['ORI', 'GF', 'JPG', 'MF', 'RS', 'USM', 'WGN']
im_sizes = [16, 32, 512]
for im_size in im_sizes:
    # Training directories for this size: ../DataSet/<size>/<processing>/TRN/
    size_root = '../DataSet/' + str(im_size) + '/'
    Processings_TRN = [size_root + proc + '/TRN/' for proc in Processings]
    for processed_set in Processings_TRN:
        train(processed_set, image_size=im_size)
    print('Training for image size: ' + str(im_size) + ' completed.')
print('Training completed.')

ORI with image size: 16 already trained.
GF with image size: 16 already trained.
JPG with image size: 16 already trained.
MF with image size: 16 already trained.
RS with image size: 16 already trained.
USM with image size: 16 already trained.
WGN with image size: 16 already trained.
Training for image size: 16 completed.
ORI with image size: 32 already trained.
GF with image size: 32 already trained.
JPG with image size: 32 already trained.
MF with image size: 32 already trained.
RS with image size: 32 already trained.
USM with image size: 32 already trained.
WGN with image size: 32 already trained.
Training for image size: 32 completed.
ORI with image size: 512 already trained.
GF with image size: 512 already trained.
JPG with image size: 512 already trained.
MF with image size: 512 already trained.
Loading from ../Models/512/GMM_RS_10000.pkl...
Training RS with image size: 512...
Saving GMM_RS_210000.pkl...
Saving GMM_RS_410000.pkl...
Saving GMM_RS_610000.pkl...
Saving GMM_RS_810000.

  # Remove the CWD from sys.path while we load stuff.


Training USM with image size: 512...
Saving GMM_USM_10000.pkl...
Saving GMM_USM_210000.pkl...
Saving GMM_USM_410000.pkl...
Saving GMM_USM_610000.pkl...
Saving GMM_USM_810000.pkl...
Saving GMM_USM_1010000.pkl...
Saving GMM_USM_1210000.pkl...
Saving GMM_USM_final.pkl...
Generating patches...
Training WGN with image size: 512...
Saving GMM_WGN_10000.pkl...
Saving GMM_WGN_210000.pkl...
Saving GMM_WGN_410000.pkl...
Saving GMM_WGN_610000.pkl...
Saving GMM_WGN_810000.pkl...
Saving GMM_WGN_1010000.pkl...
Saving GMM_WGN_1210000.pkl...
Saving GMM_WGN_final.pkl...
Training for image size: 512 completed.
Training completed.
