### @ Mallory Wittwer
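This notebook pre-computes features that would otherwise be slow to produce at training time:
- two-point correlation maps, used for structural texture parameter estimation;
- K-Means centroid colors, used for texture base color estimation.

The pre-computed features are saved as NumPy `.npy` files.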

In [20]:
import os
import numpy as np
from PIL import Image
import tensorflow as tf
import piexif
import piexif.helper
import json
import glob
from sklearn.pipeline import Pipeline
from sklearn.cluster import KMeans

from pymks import (
    PrimitiveTransformer,
    TwoPointCorrelation,
)

import warnings
warnings.filterwarnings('ignore')

def read_metaImage(root):
    '''
    Return the JSON metadata embedded in a JPG file's EXIF UserComment tag.

    Parameters
    ----------
    root : str
        Path to the image file.

    Returns
    -------
    The object decoded from the JSON text stored in the UserComment tag.
    '''
    raw_comment = piexif.load(root)["Exif"][piexif.ExifIFD.UserComment]
    return json.loads(piexif.helper.UserComment.load(raw_comment))

class MultiDataGenerator(tf.keras.utils.Sequence):
    '''
    Keras ``Sequence`` yielding batches of rendered texture images and their labels.

    Has an argument "subset" that enables to yield only a single texture class (for ex. "Brick").

    Parameters
    ----------
    root : str
        Glob pattern matching the image files (e.g. ``.../renders/training/*.jpg``).
        Matches are sorted so that the batch order is reproducible across runs.
    batch_size : int
        Images per batch. A trailing partial batch is dropped (see ``__len__``).
    subset : str or None
        If given, only images whose texture class equals ``subset`` are kept and each
        label is the first value stored under the image's ``texture_params`` metadata
        (float32, shape ``(batch, n_params)``). Otherwise each label is the integer
        class index of the image's texture (uint8, shape ``(batch,)``).
    res : tuple of int
        ``(rx, ry)``; every image must load as a uint8 array of shape ``(rx, ry, 3)``.
    '''
    def __init__(self, root, batch_size=50, subset=None, res=(150,150)):
        # Texture class name -> [class index used as label, number of texture parameters].
        self.d = {'Brick':[0,1], 'Checker':[1,1], 'Magic':[2,2], 'Noise':[3,2], 'Wave':[4,2]}
        self.root = root
        self.bs = batch_size
        self.rx, self.ry = res
        # glob order depends on the filesystem; sort so batches (and lim-truncated
        # pre-computed sets) are reproducible.
        self.files = sorted(glob.glob(root))
        if subset:
            class_names = self._get_class_names()
            self.files = np.array([f for f, name in zip(self.files, class_names) if name == subset])
        self.n = len(self.files)
        self.subset = subset
        print(f'Initialized data generator. Found {self.n} files.')

    def __len__(self):
        '''Number of full batches; the last partial batch is dropped.'''
        return self.n // self.bs

    def __getitem__(self, idx):
        '''Return batch ``idx`` as ``(images, labels)``.'''
        batch = self.files[idx*self.bs:(idx+1)*self.bs]
        X, y = self._data_generation(batch)
        return X, y

    def _data_generation(self, batch):
        '''Load the images listed in ``batch`` (file paths) and build their labels.'''
        data = np.empty((len(batch), self.rx, self.ry, 3), dtype=np.uint8)

        if self.subset:
            labels = np.empty((len(batch), self.d[self.subset][1]), dtype=np.float32)
        else:
            labels = np.empty(len(batch), dtype=np.uint8)

        for k, file in enumerate(batch):
            # The context manager closes the file handle; a bare Image.open() leaks it.
            with Image.open(file) as img:
                data[k] = np.asarray(img, dtype=np.uint8)

            texture_params = read_metaImage(file).get('texture_params')
            if self.subset:
                labels[k] = list(texture_params.values())[0]
            else:
                labels[k] = self.d[list(texture_params.keys())[0]][0]

        return data, labels

    def _get_class_names(self):
        '''Texture class name of every file (first key of its ``texture_params`` metadata).'''
        return [list(read_metaImage(file).get('texture_params').keys())[0] for file in self.files]
    
class TwoPtsCorrelator():
    '''
    Computes Two-points correlation feature maps.

    A pymks pipeline discretizes the input into 2 local states over [0, 1]
    (``PrimitiveTransformer``), then computes the (0, 1) and (1, 1) two-point
    correlations with periodic boundaries (``TwoPointCorrelation``).
    '''
    def __init__(self, ctf=150):
        '''
        Parameters
        ----------
        ctf : int
            Passed as ``cutoff`` to ``TwoPointCorrelation``.
        '''
        self.preprocessor = Pipeline(steps=[
            ("discritize", PrimitiveTransformer(n_state=2, min_=0.0, max_=1.0)),
            ("correlations", TwoPointCorrelation(periodic_boundary=True, cutoff=ctf, correlations=[(0, 1), (1, 1)])),
        ])
        
    def transform(self, X):
        '''
        Return a ``(h, w, 2)`` float64 feature map for one image ``X``.

        Channel 0 is the mean of ``X`` over axis 2, divided by 255 (grayscale in [0, 1]
        for uint8 input). Channel 1 holds the first correlation channel returned by the
        pipeline. It is written into the top-left corner and left as zeros elsewhere
        if the correlation output is smaller than the image.

        Assumes ``X`` is a single ``(h, w, channels)`` image — TODO confirm for all callers.
        '''
        gray_X = np.mean(X, axis=2) / 255.0
        container = np.zeros((gray_X.shape[0], gray_X.shape[1], 2), dtype=np.float64)
        # NOTE(review): pymks transformers appear to treat axis 0 as the sample axis. The
        # 2-D ``gray_X`` may therefore be processed as ``h`` independent 1-D rows rather
        # than as one 2-D microstructure. Confirm this is intended (``gray_X[None]`` would
        # request a single 2-D map).
        # ``[..., 0]`` keeps the first requested correlation, presumably (0, 1).
        twoPts = self.preprocessor.transform(gray_X).compute()[...,0] # Has two channels (channel 2 useless)
        container[...,0] = gray_X
        # Slice to the correlation output's extent in case it is smaller than the image.
        container[...,1][:twoPts.shape[0], :twoPts.shape[1]] = twoPts
        return container
    
class ColorGenerator(MultiDataGenerator):
    '''
    Yields centroid colors of images via K-Means clustering.

    Each image is reduced to two RGB centroids, scaled to [0, 1], by K-Means on a random
    sample of its pixels. The centroids are ordered so the first one best matches the
    first color stored in the image's ``material_params`` metadata (presumably two RGB
    colors in [0, 1] — confirm against the renderer). Batches are
    ``(X, y)`` of shape ``(batch, 3)``: the first centroid and the first label color.

    K-Means results are cached per file: row ``i`` of ``self.rgbs_data`` holds the raw
    centroids of ``self.files[i]``. Each image is therefore clustered at most once.
    ``self.data_is_baked`` becomes True once every file has been cached.
    '''
    def __init__(self, root, batch_size=50, subset=None, res=(150,150)):
        # Forward subset and res instead of silently replacing them with the defaults.
        MultiDataGenerator.__init__(self, root, batch_size, subset=subset, res=res)
        self.data_is_baked = False
        self.rgbs_data = np.empty((self.n, 6), dtype=np.float32)
        # Global row of each file in self.rgbs_data, and which rows are already computed.
        # The cache must be keyed by file: _data_generation only knows in-batch positions.
        self._file_index = {file: i for i, file in enumerate(self.files)}
        self._is_cached = np.zeros(self.n, dtype=bool)

    def _data_generation(self, batch):
        '''Return ``(first centroid, first label color)`` arrays for the files in ``batch``.'''
        data = np.empty((len(batch), 6), dtype=np.float32)
        labels = np.empty((len(batch), 6), dtype=np.float32)
        for k, file in enumerate(batch):
            i = self._file_index[file]
            if not self._is_cached[i]:
                with Image.open(file) as img:
                    self.rgbs_data[i] = self._extract_rgbs(np.asarray(img, dtype=np.uint8))
                self._is_cached[i] = True
            # Copy so the reordering below never mutates the cache.
            rgbs_data = self.rgbs_data[i].copy()
            rgbs_labels = np.asarray(list(read_metaImage(file).get('material_params').values())[:-2])
            # K-Means centroid order is arbitrary: swap the two colors when centroid 1 is
            # closer (MSE) to label color 2 than to label color 1.
            if self.mse(rgbs_data[0:3], rgbs_labels[3:6]) < self.mse(rgbs_data[0:3], rgbs_labels[0:3]):
                rgbs_data = np.concatenate([rgbs_data[3:6], rgbs_data[0:3]])
            data[k] = rgbs_data
            labels[k] = rgbs_labels

        self.data_is_baked = bool(self._is_cached.all())

        return data[:,:3], labels[:,:3]

    def mse(self, a, b):
        '''Mean squared error between two equally shaped arrays.'''
        return np.mean(np.square(a-b))

    def _extract_rgbs(self, im, n_samples=500):
        '''
        Return a ``(6,)`` array with the two K-Means RGB centroids of ``im``, scaled by 1/255.

        ``im`` must be an ``(rx, ry, 3)`` array. K-Means is fit on ``n_samples`` randomly
        chosen pixels to keep clustering cheap. Pixel sampling and K-Means are unseeded,
        so results can vary between runs.
        '''
        rx, ry, _ = im.shape
        # Copy so that the in-place shuffle never touches ``im``.
        pixels = im.reshape((rx*ry, 3)).copy()
        np.random.shuffle(pixels)
        kmeans = KMeans(n_clusters=2).fit(pixels[:n_samples])
        return (kmeans.cluster_centers_ / 255.0).ravel()
    
def precompute_correlations(gen, n_outputs=1, lim=None, correlator=None):
    '''
    Pre-compute two-point correlation features for the first ``lim`` batches of ``gen``.

    Parameters
    ----------
    gen : MultiDataGenerator-like
        Indexable batch generator exposing ``bs``, ``rx``, ``ry`` and ``__len__``.
        ``gen[k]`` returns ``(X, y)``, where ``X`` holds ``bs`` images and ``y`` is
        assignable to a ``(bs, n_outputs)`` block.
    n_outputs : int
        Number of label values per image.
    lim : int or None
        Maximum number of batches to pre-compute. ``None`` means every full batch.
        Values above ``len(gen)`` are clamped, so the returned arrays never contain
        uninitialized rows.
    correlator : object or None
        Object with a ``transform(image) -> (rx, ry, 2)`` method.
        Defaults to ``TwoPtsCorrelator(ctf=gen.rx)``.

    Returns
    -------
    tot_X : np.ndarray, float32, shape ``(n_batches * bs, rx, ry, 2)``
    tot_y : np.ndarray, float32, shape ``(n_batches * bs, n_outputs)``
    '''
    if correlator is None:
        correlator = TwoPtsCorrelator(ctf=gen.rx)
    n_batches = len(gen) if lim is None else min(lim, len(gen))
    bs = gen.bs
    tot_X = np.empty((n_batches * bs, gen.rx, gen.ry, 2), dtype=np.float32)
    tot_y = np.empty((n_batches * bs, n_outputs), dtype=np.float32)
    # Index batches directly: this stops exactly at n_batches and does not depend on
    # how (or whether) the generator implements iteration.
    for k in range(n_batches):
        X, y = gen[k]
        rows = slice(k * bs, (k + 1) * bs)
        tot_X[rows] = np.stack([correlator.transform(x) for x in X])
        tot_y[rows] = y
    return tot_X, tot_y

def precompute_colors(gen, lim=None):
    '''
    Collect the first ``lim`` batches of a color generator into two dense arrays.

    Parameters
    ----------
    gen : ColorGenerator-like
        Indexable batch generator exposing ``bs`` and ``__len__``. ``gen[k]`` returns
        ``(X, y)``, each assignable to a ``(bs, 3)`` block.
    lim : int or None
        Maximum number of batches to pre-compute. ``None`` means every full batch.
        Values above ``len(gen)`` are clamped, so the returned arrays never contain
        uninitialized rows.

    Returns
    -------
    tot_X, tot_y : np.ndarray, float32, shape ``(n_batches * bs, 3)``
    '''
    n_batches = len(gen) if lim is None else min(lim, len(gen))
    bs = gen.bs
    tot_X = np.empty((n_batches * bs, 3), dtype=np.float32)
    tot_y = np.empty((n_batches * bs, 3), dtype=np.float32)
    # Index batches directly: this stops exactly at n_batches and does not depend on
    # how (or whether) the generator implements iteration.
    for k in range(n_batches):
        X, y = gen[k]
        rows = slice(k * bs, (k + 1) * bs)
        tot_X[rows] = X
        tot_y[rows] = y
    return tot_X, tot_y

In [None]:
# Glob patterns of the rendered input images (these are read, not written).
# os.path.abspath resolves the relative paths against the current working directory.
training_root = os.path.abspath(os.path.join('renders', 'training', '*.jpg'))
validation_root = os.path.abspath(os.path.join('renders', 'validation', '*.jpg'))
test_root = os.path.abspath(os.path.join('renders', 'test', '*.jpg'))

### Pre-computing two-point correlations (for structural texture parameter estimation)

In [27]:
# Texture names and number of parameters to predict for each
subsets = {'Brick':1, 'Checker':1, 'Magic':2, 'Noise':2, 'Wave':2}

for subset, out in subsets.items():
    # np.save does not create missing directories, so create the output folder first.
    out_dir = os.path.abspath(os.path.join(os.getcwd(), 'point_correlations', subset))
    os.makedirs(out_dir, exist_ok=True)

    # Training set (at most 40 batches of 50 images)
    xtr, ytr = precompute_correlations(
        gen=MultiDataGenerator(training_root, subset=subset), n_outputs=out, lim=40)
    np.save(os.path.join(out_dir, 'xtr.npy'), xtr)
    np.save(os.path.join(out_dir, 'ytr.npy'), ytr)

    # Validation set (batch_size=1 so no trailing partial batch is dropped)
    xval, yval = precompute_correlations(
        gen=MultiDataGenerator(validation_root, batch_size=1, subset=subset), n_outputs=out)
    np.save(os.path.join(out_dir, 'xval.npy'), xval)
    np.save(os.path.join(out_dir, 'yval.npy'), yval)

    # Test set (batch_size=1 so no trailing partial batch is dropped)
    xte, yte = precompute_correlations(
        gen=MultiDataGenerator(test_root, batch_size=1, subset=subset), n_outputs=out)
    np.save(os.path.join(out_dir, 'xte.npy'), xte)
    np.save(os.path.join(out_dir, 'yte.npy'), yte)

Initialized data generator. Found 2021 files.
(200, 150, 150, 2)


### Pre-computing centroid colors by K-Means (for texture base color estimation)

In [24]:
# np.save does not create missing directories, so create the output folder first.
color_dir = os.path.abspath(os.path.join(os.getcwd(), 'color_translation'))
os.makedirs(color_dir, exist_ok=True)

# Training set (at most 40 batches of 50 images)
xtr, ytr = precompute_colors(gen=ColorGenerator(training_root), lim=40)
np.save(os.path.join(color_dir, 'xtr.npy'), xtr)
np.save(os.path.join(color_dir, 'ytr.npy'), ytr)

# Validation set
xval, yval = precompute_colors(gen=ColorGenerator(validation_root))
np.save(os.path.join(color_dir, 'xval.npy'), xval)
np.save(os.path.join(color_dir, 'yval.npy'), yval)

# Test set
xte, yte = precompute_colors(gen=ColorGenerator(test_root))
np.save(os.path.join(color_dir, 'xte.npy'), xte)
np.save(os.path.join(color_dir, 'yte.npy'), yte)

Initialized data generator. Found 10200 files.
Done!
