In [1]:
import keras
from keras.initializers import Constant
from keras.utils import plot_model
from keras.models import Model
from keras.layers import Input, Activation, Dense, Flatten, Add, Concatenate
from keras.layers.convolutional import Conv2D
from keras.layers.pooling import MaxPooling2D
from keras.layers.normalization import BatchNormalization
from keras.layers import Lambda, UpSampling2D
from keras.layers import ReLU, LeakyReLU, PReLU
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
import keras.backend as K
# Report the active Keras data layout; everything below assumes 'channels_first'
# (Inputs are built as (channels, H, W) and normalization layers use axis=1).
print(K.image_data_format())
# Force the training-mode learning phase globally so BatchNormalization normalizes with
# batch statistics, mirroring the PyTorch model being run with .train(True).
# NOTE(review): presumably this must run before any Keras model is built.
K.set_learning_phase(1) #for batchnorm

import tensorflow as tf
#tf.enable_eager_execution()

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
#torch.backends.cudnn.enabled = True

def tt(x):
    """Build a new float32 torch tensor from array-like ``x`` (list, NumPy array, scalar)."""
    as_tensor = torch.tensor(x)
    return as_tensor.to(dtype=torch.float32)

def npa(x):
    """Return tensor ``x`` as a NumPy array, detached from the autograd graph.

    ``.cpu()`` makes this work for CUDA tensors too (e.g. after ``pmodel.cuda()``);
    for a tensor already on the CPU it returns the same tensor, so no copy is made.
    """
    return x.detach().cpu().numpy()

# Generate pictures with different lights to verify that pmodel.train(True) is actually relighting
# and always looks better. Check how to reproduce pmodel.train(False) in keras.

# Save keras model as single file and test loading it.
# Compare speeds of keras and pytorch.
# Generate animation?

# pytorch is channels_first. now keras too (keras config).
# In keras batch normalization, use axis=1 if channels_first.
# Keras batchnorm: K.set_learning_phase(1)
# Pytorch batchnorm: don't use model.train(False)

# The pytorch network specifies values for mean and variance in some batchnorms, but I didn't translate them to Keras because they don't seem to produce any change.

Using TensorFlow backend.


channels_first


In [2]:

#keras.layers.BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001,
#                                center=True, scale=True, beta_initializer='zeros',
#                                gamma_initializer='ones', moving_mean_initializer='zeros',
#                                moving_variance_initializer='ones', beta_regularizer=None,
#                                gamma_regularizer=None, beta_constraint=None, gamma_constraint=None)

#nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#bnopt = {'momentum':0.1, 'scale':False, 'moving_mean_initializer':Constant(value=0.5),'moving_variance_initializer':Constant(value=0.25)}

# Shared keyword arguments for the Keras normalization layers:
# axis=1 because the model is channels_first, epsilon=1e-5 to match nn.BatchNorm2d's eps,
# center=True so a learnable beta (PyTorch's BN bias) exists.
bnopt = dict(axis=1, epsilon=1e-5, center=True)

def pconv3X3(in_planes, out_planes, stride=1):
    """PyTorch 3x3 convolution without bias; padding=1 keeps the spatial size at stride 1."""
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)

def kconv3x3(out_planes, stride=1, name=None):
    """Keras counterpart of ``pconv3X3``: 3x3 conv, 'same' padding, no bias.

    The input channel count is not a parameter; Keras infers it when the layer is called.
    """
    return Conv2D(filters=out_planes, kernel_size=3, strides=stride,
                  padding='same', use_bias=False, name=name)

class pBasicBlock(nn.Module):
    """PyTorch residual block: conv-norm-ReLU-conv-norm, plus a skip connection, then ReLU.

    Args:
        inplanes: number of input channels.
        outplanes: number of output channels.
        batchNorm_type: 0 -> nn.BatchNorm2d, any other value -> nn.InstanceNorm2d.
        stride, downsample: accepted for interface compatibility but unused;
            both 3x3 convolutions always use stride 1.
    """
    def __init__(self, inplanes, outplanes, batchNorm_type=0, stride=1, downsample=None):
        super(pBasicBlock, self).__init__()
        self.inplanes = inplanes
        self.outplanes = outplanes
        self.conv1 = pconv3X3(inplanes, outplanes, 1)
        self.conv2 = pconv3X3(outplanes, outplanes, 1)
        norm_cls = nn.BatchNorm2d if batchNorm_type == 0 else nn.InstanceNorm2d
        self.bn1 = norm_cls(outplanes)
        self.bn2 = norm_cls(outplanes)
        # The 1x1 projection is always registered (so it is always among the module's
        # parameters), but forward() only uses it when the channel counts differ.
        self.shortcuts = nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, bias=False)

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(hidden))
        residual = x if self.inplanes == self.outplanes else self.shortcuts(x)
        out += residual
        return F.relu(out)

def kBasicBlock(inplanes, outplanes, batchNorm_type=0, name=None):
    """Keras functional counterpart of ``pBasicBlock`` (channels_first).

    The layer names ('conv1', 'conv2', 'bn1'/'bn2' or 'in1'/'in2', 'shortcuts') are the
    ones ``pbblock_to_kbblock`` looks up when copying PyTorch weights.

    Args:
        inplanes: number of input channels.
        outplanes: number of output channels.
        batchNorm_type: 0 -> BatchNormalization; any other value -> InstanceNormalization,
            matching the ``else`` branch of ``pBasicBlock`` (previously values other than
            0/1 silently produced no normalization at all).
        name: optional Keras model name.

    Returns:
        A ``keras.Model`` mapping (inplanes, H, W) -> (outplanes, H, W).
    """
    def _norm(tensor, idx):
        # Same dispatch as pBasicBlock: 0 -> batch norm, anything else -> instance norm.
        if batchNorm_type == 0:
            return BatchNormalization(**bnopt, name='bn%d' % idx)(tensor)
        return InstanceNormalization(**bnopt, name='in%d' % idx)(tensor)

    x = Input(shape=(inplanes, None, None))
    out = kconv3x3(outplanes, 1, name='conv1')(x)
    out = _norm(out, 1)
    out = Activation('relu')(out)
    out = kconv3x3(outplanes, 1, name='conv2')(out)
    out = _norm(out, 2)
    if inplanes != outplanes:
        # 1x1 projection so the residual has `outplanes` channels; only built when needed.
        shortcut = Conv2D(outplanes, kernel_size=1, strides=1, use_bias=False, name='shortcuts')(x)
    else:
        shortcut = x
    out = Add()([out, shortcut])
    out = Activation('relu')(out)
    return Model(inputs=x, outputs=out, name=name)

# x = np.random.rand(1,64,6,6)
# kmodel_ = kBasicBlock(64, 155, batchNorm_type=0)
# pmodel_ = pmodel.HG0.low1 #pBasicBlock(2, 4)

# kmodel_ = pbblock_to_kbblock(pmodel_, kmodel_)

# pout = npa(pmodel_(tt(x)))
# kout = kmodel_.predict(x)
# diff = np.abs(pout - kout)
# print(np.abs(pout).mean(), np.abs(kout).mean(), diff.mean())


In [3]:
def pconv2d_to_kconv2d(pconv, kconv): #assuming keras channels_first
    """Copy a PyTorch Conv2d's weight (and bias, if any) into a Keras Conv2D layer.

    The (2, 3, 1, 0) transpose turns PyTorch's (out, in, kh, kw) kernel layout into
    Keras' (kh, kw, in, out). Returns the (mutated) Keras layer.
    """
    params = [np.transpose(npa(pconv.weight), (2, 3, 1, 0))]
    if pconv.bias is not None:
        params.append(npa(pconv.bias))
    kconv.set_weights(params)
    return kconv

def pbatchnorm_to_kbatchnorm(player, klayer):
    """Copy a PyTorch BatchNorm2d's parameters and running statistics into a Keras BatchNormalization.

    Assumes the Keras layer was built with scale=True and center=True (as with ``bnopt``), so
    ``klayer.get_weights()`` is ``[gamma, beta, moving_mean, moving_variance]``.

    The running statistics are copied too, so the Keras layer also matches PyTorch when run
    in inference mode (learning phase 0); in training mode (``K.set_learning_phase(1)``)
    batch statistics are used instead. PyTorch layers without affine parameters or
    running statistics are handled by skipping the missing pieces.

    Returns the (mutated) Keras layer.
    """
    kweights = klayer.get_weights()
    if player.weight is not None:  # None when the PyTorch layer has affine=False
        kweights[0] = npa(player.weight)
    if player.bias is not None:
        kweights[1] = npa(player.bias)
    # running_mean/running_var are None when the PyTorch layer does not track running stats.
    if len(kweights) == 4 and getattr(player, 'running_mean', None) is not None:
        kweights[2] = npa(player.running_mean)
        kweights[3] = npa(player.running_var)
    klayer.set_weights(kweights)
    return klayer

def pbblock_to_kbblock(pmodel, kmodel): #fix bn1 and bn2 
    """Copy weights from a ``pBasicBlock`` into the matching ``kBasicBlock`` Keras model.

    Both 3x3 convolutions are always copied. The 1x1 'shortcuts' conv and the
    'bn1'/'bn2' batch-norm layers are copied only if the Keras model contains them;
    instance-norm layers ('in1'/'in2') are not copied here. Returns the Keras model.
    """
    present = [layer.name for layer in kmodel.layers]
    for conv_name in ('conv1', 'conv2'):
        pconv2d_to_kconv2d(getattr(pmodel, conv_name), kmodel.get_layer(conv_name))
    if 'shortcuts' in present:
        pconv2d_to_kconv2d(pmodel.shortcuts, kmodel.get_layer('shortcuts'))
    if all(bn_name in present for bn_name in ('bn1', 'bn2')):
        for bn_name in ('bn1', 'bn2'):
            pbatchnorm_to_kbatchnorm(getattr(pmodel, bn_name), kmodel.get_layer(bn_name))
    return kmodel

def pprelu_to_kprelu(pprelu, kprelu):
    """Copy PyTorch PReLU slope(s) into a Keras PReLU layer.

    Handles both the default single shared slope (``num_parameters=1``), which is
    broadcast to every entry of the Keras ``alpha``, and per-channel slopes
    (``num_parameters=C``), which are laid out along alpha's first axis (the channel
    axis for a channels_first PReLU with ``shared_axes=[2, 3]``). Previously only the
    first slope was used. A slope count that cannot broadcast raises ValueError.

    Returns the (mutated) Keras layer.
    """
    kalpha = kprelu.get_weights()[0]
    # (n,) -> (n, 1, ..., 1) so the slopes broadcast over alpha's non-channel axes.
    slopes = npa(pprelu.weight).reshape((-1,) + (1,) * (kalpha.ndim - 1))
    kprelu.set_weights([np.broadcast_to(slopes, kalpha.shape).astype(kalpha.dtype)])
    return kprelu

def plightnet_to_klightnet(pmodel, kmodel):
    """Copy all weights of a ``plightingNet`` into the Keras model built by ``klightingNet``.

    Layers are matched by name: each PyTorch attribute name equals the Keras layer name.
    Returns the Keras model.
    """
    for stage in ('predict', 'post'):
        for conv_name in (stage + '_FC1', stage + '_FC2'):
            pconv2d_to_kconv2d(getattr(pmodel, conv_name), kmodel.get_layer(conv_name))
        prelu_name = stage + '_relu1'
        pprelu_to_kprelu(getattr(pmodel, prelu_name), kmodel.get_layer(prelu_name))
    return kmodel


In [None]:
# # Weight conversion of conv2D

# x = np.random.rand(1,2,6,6)
# W = np.random.rand(2,2,3,3) #np.ones((2,2,3,3))

# #pconv.weight = torch.nn.Parameter(tt(W))
# #pconv2d_to_kconv2d(pconv, kconv.layers[1])
# #kconv2d_to_pconv2d(kconv.layers[1], pconv)

# kconv = kconv3x3(2, 2)
# #print(kconv.weights[0].numpy().shape)
# #print(kconv.weights[0].numpy())
# #print(kconv.predict(x))

# pconv = pconv3X3(2, 2)
# #print(npa(pconv.weight).shape)
# #print(npa(pconv.weight))
# #print(npa(pconv(tt(x))))

# print(pconv.bias)

In [None]:
# # # # # Matching of BatchNorm

# def kbatchnorm(dim):
#     x = Input(shape=(dim,None,None))
#     out = BatchNormalization(**bnopt_bias)(x)
#     #out = InstanceNormalization(**bnopt_bias)(x)
#     return Model(inputs=x, outputs=out)

# # def pbatchnorm(dim):
# #     return nn.BatchNorm2d(dim)
# #     #return nn.InstanceNorm2d(dim)

# x = np.random.rand(1,155,100,100)

# kbn = kbatchnorm(155)
# pbn = pmodel.HG0.low1.bn2
# # pbn = pbatchnorm(2)
# # pbn.bias = nn.Parameter(tt([2.0, 2.0]))
# pbatchnorm_to_kbatchnorm(pbn, kbn.layers[1])

# pout = npa(pbn(tt(x))).flatten()
# kout = kbn.predict(x).flatten()
# diff = np.abs(pout - kout)
# print(np.abs(pout).mean(), np.abs(kout).mean(), diff.mean())


In [None]:
# # # Matching BasicBlock

# x = np.random.rand(1,2,6,6)

# kmodel = kBasicBlock(2, 4)
# pmodel = pBasicBlock(2, 4)

# pbblock_to_kbblock(pmodel, kmodel)

# pout = npa(pmodel(tt(x)))
# kout = kmodel.predict(x)
# diff = np.abs(pout - kout)
# print(np.abs(pout).mean(), np.abs(kout).mean(), diff.mean())

# # kmodel.summary()

In [4]:
# Define lightingNets and compare them

class plightingNet(nn.Module):
    """PyTorch lighting sub-network used as the innermost ``middleNet`` of the hourglass.

    The first ``ncInput`` channels of the incoming feature map are the lighting features.
    They are averaged spatially and mapped to a predicted lighting vector with ``ncOutput``
    entries. Separately, ``target_light`` is mapped back to ``ncInput`` channels, tiled over
    the spatial grid, and substituted for those first channels.
    """
    def __init__(self, ncInput, ncOutput, ncMiddle):
        super(plightingNet, self).__init__()
        self.ncInput = ncInput
        self.ncOutput = ncOutput
        self.ncMiddle = ncMiddle
        self.predict_FC1 = nn.Conv2d(self.ncInput,  self.ncMiddle, kernel_size=1, stride=1, bias=False)
        self.predict_relu1 = nn.PReLU()
        self.predict_FC2 = nn.Conv2d(self.ncMiddle, self.ncOutput, kernel_size=1, stride=1, bias=False)

        self.post_FC1 = nn.Conv2d(self.ncOutput,  self.ncMiddle, kernel_size=1, stride=1, bias=False)
        self.post_relu1 = nn.PReLU()
        self.post_FC2 = nn.Conv2d(self.ncMiddle, self.ncInput, kernel_size=1, stride=1, bias=False)
        self.post_relu2 = nn.ReLU()

    def forward(self, innerFeat, target_light):
        """Predict lighting from ``innerFeat`` and inject ``target_light`` in its place.

        Args:
            innerFeat: (N, C, H, W) feature map with C >= ncInput. It is NOT modified.
            target_light: (N or 1, ncOutput, 1, 1) target lighting (1x1 in the example calls).

        Returns:
            (features, light): ``features`` has innerFeat's shape, with the first ncInput
            channels replaced; ``light`` is the predicted lighting, (N, ncOutput, 1, 1).
        """
        x = innerFeat[:, 0:self.ncInput, :, :]  # lighting feature
        _, _, row, col = x.shape
        # predict lighting
        feat = x.mean(dim=(2, 3), keepdim=True)
        light = self.predict_relu1(self.predict_FC1(feat))
        light = self.predict_FC2(light)
        # get back the feature space
        upFeat = self.post_relu1(self.post_FC1(target_light))
        upFeat = self.post_relu2(self.post_FC2(upFeat))
        upFeat = upFeat.repeat((1, 1, row, col))
        # Broadcast a batch-1 target light over the batch (slice assignment used to do this implicitly).
        upFeat = upFeat.expand(innerFeat.shape[0], -1, -1, -1)
        # Build a new tensor instead of writing into innerFeat in place: the caller's tensor is
        # left untouched and autograd does not see an in-place edit of a saved activation.
        outFeat = torch.cat((upFeat, innerFeat[:, self.ncInput:, :, :]), dim=1)
        return outFeat, light

def klightingNet(ncInput, ncOutput, ncMiddle, name=None): #27, 9, 128 #len(zs), #SH coeff, ...
    """Keras functional counterpart of ``plightingNet`` (channels_first).

    Inputs:  [innerFeat (C, H, W) with C >= ncInput, target_light (ncOutput, 1, 1)].
    Outputs: [innerFeat with its first ncInput channels replaced by the tiled target-light
              embedding, predicted lighting (ncOutput, 1, 1)].

    Layer names match ``plightingNet`` attribute names so ``plightnet_to_klightnet`` can copy
    the weights. The spatial size H, W is read at run time, so any input size works.
    """
    innerFeat = Input(shape=(None, None, None)) #155,H,W
    # Lighting features: the first ncInput channels.
    x = Lambda(lambda x_: K.slice(x_, (0,0,0,0), (-1,ncInput,-1,-1)))(innerFeat) #27,H,W

    # Predict the lighting from the spatially averaged lighting features.
    feat = Lambda(lambda x_: K.mean(x_, axis=(2,3), keepdims=True))(x) #27,1,1
    pred_light = Conv2D(ncMiddle, kernel_size=1, strides=1, use_bias=False, name='predict_FC1')(feat) #128,1,1
    pred_light = PReLU(shared_axes=[2,3], name='predict_relu1')(pred_light) #shared_axes solves bug in tf 1.13. not sure of side-effect.
    pred_light = Conv2D(ncOutput, kernel_size=1, strides=1, use_bias=False, name='predict_FC2')(pred_light) #9,1,1

    # Map the requested lighting back to ncInput feature channels.
    target_light = Input(shape=(ncOutput, None, None)) #9,1,1
    upFeat = Conv2D(ncMiddle, kernel_size=1, strides=1, use_bias=False, name='post_FC1')(target_light) #128,1,1
    upFeat = PReLU(shared_axes=[2,3], name='post_relu1')(upFeat) #shared_axes solves bug in tf 1.13. not sure of side-effect.
    upFeat = Conv2D(ncInput, kernel_size=1, strides=1, use_bias=False, name='post_FC2')(upFeat) #27,1,1
    upFeat = ReLU()(upFeat)
    # Tile over innerFeat's spatial size, read at run time. This used to be hard-coded to
    # (32, 32), which only fits a 512x512 image passed through the four hourglass levels.
    upFeat = Lambda(
        lambda ts: K.tile(ts[0], K.stack([1, 1, K.shape(ts[1])[2], K.shape(ts[1])[3]]))
    )([upFeat, innerFeat]) #27,H,W

    # Remaining (non-lighting) channels pass through unchanged.
    x2 = Lambda(lambda x_: K.slice(x_, (0,ncInput,0,0), (-1,-1,-1,-1)))(innerFeat) #128,H,W
    xout = Concatenate(axis=1)([upFeat, x2]) #155,H,W
    return Model(inputs=[innerFeat, target_light], outputs=[xout, pred_light], name=name) #(155, H, W), (9, 1, 1)

# ninp, nout, nmid = 27, 9, 128
# row, col = 32, 32
# x = np.random.rand(1, 155, row, col)
# tlight = np.ones((1, nout, 1, 1))

# plnet = plightingNet(ninp, nout, nmid)
# pout = plnet(tt(x), tt(tlight))

# klnet = klightingNet(ninp, nout, nmid)
# klnet = plightnet_to_klightnet(plnet, klnet)
# kout = klnet.predict([x, tlight])

# for i in range(2):
#     x1 = npa(pout[i])
#     x2 = kout[i]
#     diff = np.abs(x1-x2)
#     print(x1.mean(), x2.mean(), diff.mean())

# print(pout[0].shape, pout[1].shape)
# print(kout[0].shape, kout[1].shape)


In [5]:
# Define HourglassBlocks and compare them

class pHourglassBlock(nn.Module):
    """One level of the PyTorch hourglass.

    Upper branch: a same-width residual block (instance norm) at full resolution.
    Lower branch: 2x max-pool -> residual block inplane->mid_plane (batch norm) ->
    ``middleNet`` (the next level, or the lighting net at the bottom) ->
    residual block mid_plane->inplane (instance norm) -> 2x nearest upsampling.
    The two branches are summed; ``middleNet``'s second output is passed through.
    """
    def __init__(self, inplane, mid_plane, middleNet):
        super(pHourglassBlock, self).__init__()
        # upper branch
        self.upper = pBasicBlock(inplane, inplane, batchNorm_type=1)
        # lower branch (submodules created in the same order as before)
        self.downSample = nn.MaxPool2d(kernel_size=2, stride=2)
        self.low1 = pBasicBlock(inplane, mid_plane)
        self.middle = middleNet
        self.low2 = pBasicBlock(mid_plane, inplane, batchNorm_type=1)
        self.upSample = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, x, light):
        upper_out = self.upper(x)
        lower = self.low1(self.downSample(x))
        lower, middle_out = self.middle(lower, light)
        lower = self.upSample(self.low2(lower))
        return lower + upper_out, middle_out

# We assume skip_count=0, and don't implement count
def kHourglassBlock(inplane, mid_plane, middleNet, name=None): #64, 155, klnet(27,9,128)
    """Keras functional counterpart of ``pHourglassBlock`` (skip_count=0 only).

    Args:
        inplane: channels at this level's resolution.
        mid_plane: channels handed to ``middleNet`` at half resolution.
        middleNet: Keras model taking ``[features, target_light]`` and returning
            ``[features, predicted_light]``; either a ``klightingNet`` or another hourglass block.
        name: Keras model name.

    Returns:
        Model mapping [x (inplane, H, W), target_light] -> [out (inplane, H, W), predicted light].
    """
    # Target-light channel count is read from middleNet's second input instead of being
    # hard-coded to 9, so the block also works with a non-gray lighting net.
    ncOutLight = K.int_shape(middleNet.inputs[1])[1]
    x = Input(shape=(inplane, None, None)) #64,64,64
    target_light = Input(shape=(ncOutLight, None, None)) #9,1,1 for the gray model
    out_upper = kBasicBlock(inplane, inplane, batchNorm_type=1, name='upper')(x) #64,64,64
    out_lower = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), name='downsample')(x) #64,32,32
    out_lower = kBasicBlock(inplane, mid_plane, name='low1')(out_lower) #155,32,32
    out_lower, pred_light = middleNet([out_lower, target_light]) #(155,32,32), (9,1,1)
    out_lower = kBasicBlock(mid_plane, inplane, batchNorm_type=1, name='low2')(out_lower) #64,32,32
    out_lower = UpSampling2D(size=(2, 2), name='upsample')(out_lower) #64,64,64
    out = Add()([out_lower, out_upper]) #64,64,64
    return Model(inputs=[x, target_light], outputs=[out, pred_light], name=name) #(64,64,64), (9,1,1)

def phgblock_to_khgblock(pmodel, kmodel):
    """Recursively copy a ``pHourglassBlock`` and everything nested in it into a ``kHourglassBlock``.

    The three residual blocks are matched by name ('upper', 'low1', 'low2'). If the Keras
    model contains a layer named 'klnet', it is treated as the lighting net. Otherwise, the
    first layer whose name contains 'HG' is treated as the nested hourglass block and
    converted recursively. Returns the Keras model.
    """
    names = [layer.name for layer in kmodel.layers]
    for branch in ('upper', 'low1', 'low2'):
        pbblock_to_kbblock(getattr(pmodel, branch), kmodel.get_layer(branch))
    if 'klnet' in names:
        plightnet_to_klightnet(pmodel.middle, kmodel.get_layer('klnet'))
        return kmodel
    nested_name = [n for n in names if 'HG' in n][0]
    phgblock_to_khgblock(pmodel.middle, kmodel.get_layer(nested_name))
    return kmodel

# ninp, nout, nmid = 27, 9, 128
# row, col = 64, 64
# inplane, mid_plane = 64, 155
# klnet_input = [np.random.rand(1, mid_plane, row//2, col//2), np.ones((1, nout, 1, 1))] #(155,32,32), (9,1,1)
# hgnet_input = [np.random.rand(1, inplane, row, col), np.ones((1, nout, 1, 1))] #(64,64,64), (9,1,1)

# plnet_ = pmodel.light #plightingNet(ninp, nout, nmid)
# pmodel_ = pmodel.HG0 #pHourglassBlock(inplane, mid_plane, plnet)
# pout = pmodel_(tt(hgnet_input[0]), tt(hgnet_input[1]))

# klnet_ = klightingNet(ninp, nout, nmid, name='klnet')
# klnet_ = plightnet_to_klightnet(plnet_, klnet_)
# kmodel_ = kHourglassBlock(inplane, mid_plane, klnet_)
# kmodel_ = phgblock_to_khgblock(pmodel_, kmodel_)
# kout = kmodel_.predict(hgnet_input)

# for i in range(2):
#     x1 = npa(pout[i])
#     x2 = kout[i]
#     diff = np.abs(x1-x2)
#     print(x1.mean(), x2.mean(), diff.mean())

# print(pout[0].shape, pout[1].shape)
# print(kout[0].shape, kout[1].shape)


In [6]:
# Define HourglassNets and compare them

class pHourglassNet(nn.Module):
    """PyTorch relighting network: pre-conv stem, four nested hourglass levels, conv head.

    ``forward(x, target_light)`` takes a 1-channel image (``pre_conv`` has 1 input channel)
    and a target lighting tensor. It returns a 1-channel image passed through a sigmoid,
    together with the lighting predicted by the innermost ``plightingNet``.

    NOTE: the attribute names define the ``state_dict`` keys (used by ``load_state_dict``) and
    are looked up by the Keras conversion helpers. Submodules are created in a fixed order,
    which determines how the random initialisation is drawn. Change neither.
    """
    def __init__(self, baseFilter = 16, gray=True):
        super(pHourglassNet, self).__init__()
        self.ncLight = 27   # number of channels for input to lighting network
        self.baseFilter = baseFilter

        # number of channels for output of lighting network
        if gray:
            self.ncOutLight = 9  # gray: channel is 1
        else:
            self.ncOutLight = 27  # color: channel is 3

        self.ncPre = self.baseFilter  # number of channels for pre-convolution
        # number of channels at each hourglass level (HG3 outermost, HG0 innermost);
        # the innermost width also carries the ncLight lighting channels
        self.ncHG3 = self.baseFilter
        self.ncHG2 = 2*self.baseFilter
        self.ncHG1 = 4*self.baseFilter
        self.ncHG0 = 8*self.baseFilter + self.ncLight

        # 5x5 stem; padding=2 keeps the spatial size
        self.pre_conv = nn.Conv2d(1, self.ncPre, kernel_size=5, stride=1, padding=2)
        self.pre_bn = nn.BatchNorm2d(self.ncPre)

        # Nested hourglass: each level's middleNet is the previous level; the lighting
        # net is registered both as `light` and as HG0's middle (the same module object).
        self.light = plightingNet(self.ncLight, self.ncOutLight, 128)
        self.HG0 = pHourglassBlock(self.ncHG1, self.ncHG0, self.light)
        self.HG1 = pHourglassBlock(self.ncHG2, self.ncHG1, self.HG0)
        self.HG2 = pHourglassBlock(self.ncHG3, self.ncHG2, self.HG1)
        self.HG3 = pHourglassBlock(self.ncPre, self.ncHG3, self.HG2)

        # Output head: 3x3 conv then two 1x1 convs, each followed by batch norm
        self.conv_1 = nn.Conv2d(self.ncPre, self.ncPre, kernel_size=3, stride=1, padding=1)
        self.bn_1 = nn.BatchNorm2d(self.ncPre) 
        self.conv_2 = nn.Conv2d(self.ncPre, self.ncPre, kernel_size=1, stride=1, padding=0)
        self.bn_2 = nn.BatchNorm2d(self.ncPre) 
        self.conv_3 = nn.Conv2d(self.ncPre, self.ncPre, kernel_size=1, stride=1, padding=0)
        self.bn_3 = nn.BatchNorm2d(self.ncPre)

        # Projection to a single output channel
        self.output = nn.Conv2d(self.ncPre, 1, kernel_size=1, stride=1, padding=0)

    def forward(self, x, target_light): #, skip_count):
        """Return (relit image in [0, 1], predicted lighting) for image ``x`` and ``target_light``."""
        feat = self.pre_conv(x)
        feat = F.relu(self.pre_bn(feat))
        feat, out_light = self.HG3(feat, target_light) #, 0, skip_count)
        feat = F.relu(self.bn_1(self.conv_1(feat)))
        feat = F.relu(self.bn_2(self.conv_2(feat)))
        feat = F.relu(self.bn_3(self.conv_3(feat)))
        out_img = self.output(feat)
        out_img = torch.sigmoid(out_img)
        return out_img, out_light

def kHourglassNet(baseFilter=16, name=None):
    """Keras functional counterpart of ``pHourglassNet`` (gray model, channels_first).

    Inputs:  [image (1, H, W), target light (9, 1, 1)].
    Outputs: [relit image (1, H, W) after a sigmoid, lighting predicted by the innermost net].

    Layer names mirror ``pHourglassNet`` attribute names so ``phgnet_to_khgnet`` can copy the
    weights. The BatchNormalization layers use ``bnopt``; the previous ``bnopt_bias`` was never
    defined, so building the model raised NameError.
    """
    ncPre = baseFilter
    ncLight = 27      # channels fed to the lighting network
    ncOutLight = 9    # lighting coefficients (gray model)
    ncHG3 = baseFilter
    ncHG2 = 2*baseFilter
    ncHG1 = 4*baseFilter
    ncHG0 = 8*baseFilter + ncLight

    x = Input(shape=(1, None, None)) #1,512,512
    target_light = Input(shape=(ncOutLight, None, None)) #9,1,1
    # Explicit 2-pixel zero padding + 'valid' 5x5 conv reproduces PyTorch's padding=2.
    feat = keras.layers.ZeroPadding2D(padding=(2, 2))(x)
    feat = Conv2D(baseFilter, kernel_size=5, strides=1, padding='valid', name='pre_conv')(feat)
    feat = BatchNormalization(**bnopt, name='pre_bn')(feat)
    feat = Activation('relu')(feat)

    # Nested hourglass, innermost (lighting net) first.
    klnet = klightingNet(ncLight, ncOutLight, 128, name='klnet')
    HG0 = kHourglassBlock(ncHG1, ncHG0, klnet, name='HG0')
    HG1 = kHourglassBlock(ncHG2, ncHG1, HG0, name='HG1')
    HG2 = kHourglassBlock(ncHG3, ncHG2, HG1, name='HG2')
    HG3 = kHourglassBlock(ncPre, ncHG3, HG2, name='HG3')
    feat, out_light = HG3([feat, target_light])

    feat = Conv2D(ncPre, kernel_size=3, strides=1, padding='same', name='conv_1')(feat)
    feat = BatchNormalization(**bnopt, name='bn_1')(feat)
    feat = Activation('relu')(feat)

    feat = Conv2D(ncPre, kernel_size=1, strides=1, padding='valid', name='conv_2')(feat)
    feat = BatchNormalization(**bnopt, name='bn_2')(feat)
    feat = Activation('relu')(feat)

    feat = Conv2D(ncPre, kernel_size=1, strides=1, padding='valid', name='conv_3')(feat)
    feat = BatchNormalization(**bnopt, name='bn_3')(feat)
    feat = Activation('relu')(feat)

    out_img = Conv2D(1, kernel_size=1, strides=1, padding='valid', name='output')(feat)
    out_img = Activation('sigmoid')(out_img)
    return Model(inputs=[x, target_light], outputs=[out_img, out_light], name=name)

def phgnet_to_khgnet(pmodel, kmodel):
    """Copy every weight of a ``pHourglassNet`` into the Keras model built by ``kHourglassNet``.

    Layers are matched by name (PyTorch attribute name == Keras layer name). The hourglass
    body is handled recursively by ``phgblock_to_khgblock``. Returns the Keras model.
    """
    pconv2d_to_kconv2d(pmodel.pre_conv, kmodel.get_layer('pre_conv'))
    pbatchnorm_to_kbatchnorm(pmodel.pre_bn, kmodel.get_layer('pre_bn'))
    phgblock_to_khgblock(pmodel.HG3, kmodel.get_layer('HG3'))
    head = (('conv_1', 'bn_1'), ('conv_2', 'bn_2'), ('conv_3', 'bn_3'))
    for conv_name, bn_name in head:
        pconv2d_to_kconv2d(getattr(pmodel, conv_name), kmodel.get_layer(conv_name))
        pbatchnorm_to_kbatchnorm(getattr(pmodel, bn_name), kmodel.get_layer(bn_name))
    pconv2d_to_kconv2d(pmodel.output, kmodel.get_layer('output'))
    return kmodel


In [None]:

# inputL = np.random.rand(1, 1, 512, 512)
# sh = np.random.rand(1, 9, 1, 1)

# pmodel = pHourglassNet()
# pmodel.load_state_dict(torch.load(os.path.join(modelFolder, 'trained_model_03.t7')))
# pout = pmodel(tt(inputL), tt(sh))

# kmodel = kHourglassNet(name='hgnet')
# kmodel = phgnet_to_khgnet(pmodel, kmodel)
# kout = kmodel.predict([inputL, sh])

# for i in range(2):
#     x1 = npa(pout[i])
#     x2 = kout[i]
#     diff = np.abs(x1-x2)
#     print(x1.mean(), x2.mean(), diff.mean())


In [41]:
import cv2, os

modelFolder = 'trained_model/'
lightFolder = 'data/example_light/'
saveFolder = 'result'
imgPath = 'data/online_guy.jfif'

pmodel = pHourglassNet()
# map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only machine.
pmodel.load_state_dict(torch.load(os.path.join(modelFolder, 'trained_model_03.t7'), map_location='cpu'))
pmodel.train(True) #best results with True
# pmodel.cuda()

img = cv2.imread(imgPath)
if img is None:  # cv2.imread returns None (no exception) for a missing/unreadable file
    raise FileNotFoundError('could not read image: ' + imgPath)
row, col, _ = img.shape
img = cv2.resize(img, (512, 512))
Lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)

# L channel scaled to [0, 1], shaped (1, 1, 512, 512) for the network.
inputL = Lab[:, :, 0].astype(np.float32) / 255.0
inputL = inputL[None, None, ...]

os.makedirs(saveFolder, exist_ok=True)  # cv2.imwrite fails silently if the folder is missing

for i in range(7):
    sh = np.loadtxt(os.path.join(lightFolder, 'rotate_light_{:02d}.txt'.format(i)))[0:9]
    sh = 0.7 * sh
    sh = np.reshape(sh, (1, 9, 1, 1)).astype(np.float32)

    with torch.no_grad():  # inference only; no autograd graph needed
        pimg, psh = pmodel(tt(inputL), tt(sh))
    pimg = pimg[0].cpu().numpy()
    pimg = pimg.transpose((1,2,0))
    pimg = np.squeeze(pimg)
    pimg = (pimg*255.0).astype(np.uint8)
    # Replace the L channel with the relit one and convert back at the original resolution.
    Lab[:,:,0] = pimg
    resultLab = cv2.cvtColor(Lab, cv2.COLOR_LAB2BGR)
    resultLab = cv2.resize(resultLab, (col, row))
    cv2.imwrite(os.path.join(saveFolder, 'pytorch_online_guy_{:02d}.jpg'.format(i)), resultLab)


<Figure size 288x720 with 0 Axes>

In [69]:
# Inspect the `trainable` flag of one nested Keras layer after `set_trainable` was applied.
# NOTE(review): relies on `kmodel` already existing in the kernel; this cell was executed after
# the In [67] cell below, so it fails when the notebook is run top to bottom on a fresh kernel.
kmodel.layers[6].layers[2].trainable

False

In [67]:

def set_trainable(model, val=True):
    """Set ``trainable = val`` on every layer of ``model``, descending into nested models.

    Any layer that itself exposes a ``layers`` attribute is treated as a nested model and
    processed recursively. ``model``'s own flag is left untouched.
    """
    for sublayer in model.layers:
        sublayer.trainable = val
        is_container = 'layers' in dir(sublayer)
        if is_container:
            set_trainable(sublayer, val)

# saveFolder = 'keras/results/no_bn'
# Build the Keras model from the converted PyTorch weights unless an earlier run already did.
# Before, `kmodel` existed only as leftover kernel state, so this cell failed on a fresh kernel.
# (The global K.set_learning_phase(1) from the first cell still applies.)
if 'kmodel' not in globals():
    kmodel = phgnet_to_khgnet(pmodel, kHourglassNet(name='hgnet'))
#kmodel.trainable = True
set_trainable(kmodel, False)

for i in range(1):
    sh = np.loadtxt(os.path.join(lightFolder, 'rotate_light_{:02d}.txt'.format(i)))[0:9]
    sh = 0.7 * sh
    sh = np.reshape(sh, (1, 9, 1, 1)).astype(np.float32)

    kimg, ksh = kmodel.predict([inputL, sh])
    kimg = kimg[0].transpose((1,2,0))
    kimg = np.squeeze(kimg)
    kimg = (kimg*255.0).astype(np.uint8)
    # Replace the L channel with the relit one and convert back at the original resolution.
    Lab[:,:,0] = kimg
    resultLab = cv2.cvtColor(Lab, cv2.COLOR_LAB2BGR)
    resultLab = cv2.resize(resultLab, (col, row))
    print(resultLab.mean())
    #cv2.imwrite(os.path.join(saveFolder, 'keras_online_guy_{:02d}.jpg'.format(i)), resultLab)

117.78985133333333


In [32]:
#kmodel.save('trained_model_03.h5')
# kmodel_json = kmodel.to_json()
# with open('trained_model_03.json', 'w') as json_file:
#     json_file.write(kmodel_json)
# kmodel.save_weights("trained_model_03.weights.h5")

from keras.models import load_model
from keras.models import model_from_json

# model = load_model('trained_model_03.h5')

# Rebuild the architecture from JSON, then load the separately saved weights.
# `with` closes the file even if reading raises.
with open('trained_model_03.json', 'r') as json_file:
    loaded_model_json = json_file.read()
# NOTE(review): the recorded run of this cell raised "IndexError: list index out of range";
# the cause was not determined in this notebook.
model = model_from_json(loaded_model_json, custom_objects={'InstanceNormalization':InstanceNormalization})
model.load_weights("trained_model_03.weights.h5")


IndexError: list index out of range

In [None]:
# odict = torch.load(os.path.join(modelFolder, 'trained_model_03.t7'))
# keys = odict.keys()
# # set([key.split('.')[-1] for key in keys])
# for key in keys:
#     if 'HG0.low1' in key: #key.split('.')[-1]:
#         print(key, odict[key].shape) #odict[key])

# print(K.eval(kbn.layers[1].moving_mean))
# print(pbn.running_mean)

# print(K.eval(kbn.layers[1].moving_variance))
# print(pbn.running_var)

#dir(kbn.layers[1])
#  'moving_mean',
#  'moving_mean_initializer',
#  'moving_variance',
#  'moving_variance_initializer',

In [None]:
# import sys
# sys.path.append('model')
# #import defineHourglass_512_gray_skip as pyt
# my_network = pyt.HourglassNet()
# my_network.load_state_dict(torch.load('trained_model/trained_model_03.t7'))


In [None]:
# class KBasicBlock(keras.Model):
#     def __init__(self, inplanes, outplanes, batchNorm_type=0):
#         super(KBasicBlock, self).__init__()
#         self.inplanes = inplanes
#         self.outplanes = outplanes
#         self.conv1 = kconv3x3(inplanes, outplanes, 1)
#         self.conv2 = kconv3x3(outplanes, outplanes, 1)
#         self.shortcuts = Conv2D(outplanes, kernel_size=3, strides=1, use_bias=False)
#         if batchNorm_type == 0:
#             self.bn1 = BatchNormalization(**bnopt)
#             self.bn2 = BatchNormalization(**bnopt)
#         else: #FIX
#             self.bn1 = BatchNormalization(**bnopt)
#             self.bn2 = BatchNormalization(**bnopt)

#     def call(self, x):
#         out = self.conv1(x)
#         out = self.bn1(out)
#         out = Activation('relu')(out)
#         out = self.conv2(out)
#         out = self.bn2(out)
#         aux = self.shortcuts(x)
#         if self.inplanes != self.outplanes:
#             out = Add()([out, aux]) #self.shortcuts(x)
#         else:
#             out = Add()([out, x])
#         out = Activation('relu')(out)
#         return out

# class KHourglassBlock(keras.Model):
#     def __init__(self, inplane, mid_plane, middleNet):
#         super(KHourglassBlock, self).__init__()
#         # upper branch
#         self.upper = kBasicBlock(inplane, inplane, batchNorm_type=1) #FIX
#         # lower branch
#         self.downSample = MaxPooling2D(pool_size=(2, 2), strides=(2, 2))
#         self.low1 = kBasicBlock(inplane, mid_plane)
#         self.middle = middleNet
#         self.low2 = kBasicBlock(mid_plane, inplane, batchNorm_type=1) #FIX
#         self.upsample = UpSampling2D(size=(2, 2))

#     def call(self, inputs):
#         x = inputs[0]
#         target_light = inputs[1]
#         out_upper = self.upper(x)
#         out_lower = self.downSample(x)
#         out_lower = self.low1(out_lower)
#         out_lower, pred_light = self.middle([out_lower, target_light])
#         out_lower = self.low2(out_lower)
#         out_lower = self.upsample(out_lower)
#         out = Add()([out_lower, out_upper])
#         return out, pred_light

#conc = Lambda(lambda x_: Concatenate(axis=1)([x_[0], x_[1]]))([out_lower, light])
#out_lower, out_middle = middleNet(Concatenate(mode='ave')([out_lower, light]))
#out_lower, out_middle = middleNet(conc)
#out_lower, out_middle = middleNet.call([out_lower, light])
#out_lower, out_middle = Lambda(middleNet)((out_lower, light)) 

# class KlightingNet(keras.Model):
#     def __init__(self, ncInput, ncOutput, ncMiddle):
#         super(KlightingNet, self).__init__()

#         self.ncInput = ncInput
#         self.ncOutput = ncOutput
#         self.ncMiddle = ncMiddle

#         self.slicer = Lambda(lambda x_: K.slice(x_, (0,0,0,0), (-1,self.ncInput,-1,-1)))
#         self.predict_mean = Lambda(lambda x_: K.mean(x_, axis=(2,3), keepdims=True))
#         self.predict_FC1 = Conv2D(ncMiddle, kernel_size=1, strides=1, use_bias=False, name='predict_FC1')
#         self.predict_relu1 = PReLU(name='predict_relu1')
#         self.predict_FC2 = Conv2D(ncOutput, kernel_size=1, strides=1, use_bias=False, name='predict_FC2')

#         self.post_FC1 = Conv2D(ncMiddle, kernel_size=1, strides=1, use_bias=False, name='post_FC1')
#         self.post_relu1 = PReLU(shared_axes=[2,3], name='post_relu1')
#         self.post_FC2 = Conv2D(ncInput, kernel_size=1, strides=1, use_bias=False, name='post_FC2')
#         self.post_relu2 = ReLU()
#         self.post_tiler = Lambda(lambda x_: K.tile(x_, (1, 1, row, col)))

#         self.slicer_x1 = Lambda(lambda x_: x_[:, 0:self.ncInput, :, :])
#         self.slicer_x2 = Lambda(lambda x_: x_[:, self.ncInput:, :, :])

#     def call(self, inputs):
#         innerFeat = inputs[0]
#         x = self.slicer(innerFeat)
#         row, col = K.shape(x)[2], K.shape(x)[3]
#         feat  = self.predict_mean(x)
#         light = self.predict_FC1(feat)
#         light = self.predict_relu1(light)
#         light = self.predict_FC2(light)

#         target_light = inputs[1]
#         upFeat = self.post_FC1(target_light)
#         upFeat = self.post_relu1(upFeat) #shared_axes solves bug in tf 1.13. not sure of side-effect.
#         upFeat = self.post_FC2(upFeat)
#         upFeat = self.post_relu2(upFeat)
#         upFeat = self.post_tiler(upFeat)

#         #x1 = self.slicer_x1(innerFeat)
#         x2 = self.slicer_x2(innerFeat)
#         #print(upFeat.shape, x2.shape)
#         xout = Concatenate(axis=1)([upFeat, x2])
#         return xout, light

