Skip to content

Commit c4501c0

Browse files
simplify the network and add a transformation encoder
1 parent 3787ab3 commit c4501c0

File tree

3 files changed

+40
-72
lines changed

3 files changed

+40
-72
lines changed

NN/FaceMeshEncoder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import tensorflow as tf
2-
from NN.Utils import sMLP, CRMLBlock
2+
from NN.Utils import sMLP, CFusingBlock
33
from NN.CCoordsEncodingLayer import CCoordsEncodingLayer
44
from Core.Utils import FACE_MESH_INVALID_VALUE
55

@@ -15,7 +15,7 @@ def __init__(self, latentSize, **kwargs):
1515
self._sMLP2 = sMLP(sizes=[latentSize], activation='relu', name='FaceMeshEncoder/sMLP-2')
1616

1717
self._RML = [
18-
CRMLBlock(
18+
CFusingBlock(
1919
mlp=sMLP(
2020
sizes=[latentSize * 2] * 3,
2121
activation='relu', name=f'FaceMeshEncoder/RML-{i}/mlp'

NN/Utils.py

Lines changed: 5 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -227,72 +227,17 @@ def call(self, x, training=None):
227227
quantized = 0.5 + tf.clip_by_value(quantized, self._minValue, self._maxValue)
228228
return x + tf.stop_gradient(quantized - x)
229229
####################################
230-
class CResidualMultiplicativeLayer(tf.keras.layers.Layer):
231-
def __init__(self, eps=1e-8, headsN=1, **kwargs):
232-
super().__init__(**kwargs)
233-
self._eps = eps
234-
self._scale = tf.Variable(
235-
initial_value=tf.random.normal((1, ), mean=0.0, stddev=0.1),
236-
trainable=True, dtype=tf.float32,
237-
name=self.name + '/_scale'
238-
)
239-
self._headsN = headsN
240-
self._normalization = None
241-
return
242-
243-
@property
244-
def scale(self): return tf.nn.sigmoid(self._scale) * (1.0 - 2.0 * self._eps) + self._eps # [eps, 1 - eps]
245-
246-
def _SMNormalization(self, xhat):
247-
xhat = tf.nn.softmax(xhat, axis=-1)
248-
xhat = xhat - tf.reduce_mean(xhat, axis=-1, keepdims=True)
249-
rng = tf.reduce_max(tf.abs(xhat), axis=-1, keepdims=True)
250-
return 1.0 + tf.math.divide_no_nan(xhat, rng * self.scale) # [1 - scale, 1 + scale]
251-
252-
def _HeadwiseNormalizationNoPadding(self, xhat):
253-
shape = tf.shape(xhat)
254-
# reshape [B, ..., N * headsN] -> [B, ..., headsN, N], apply normalization, reshape back
255-
xhat = tf.reshape(xhat, tf.concat([shape[:-1], [self._headsN, shape[-1] // self._headsN]], axis=-1))
256-
xhat = self._SMNormalization(xhat)
257-
xhat = tf.reshape(xhat, shape)
258-
return xhat
259-
260-
def _HeadwiseNormalizationPadded(self, lastChunk):
261-
def F(xhat):
262-
mainPart = self._HeadwiseNormalizationNoPadding(xhat[..., :-lastChunk])
263-
tailPart = self._SMNormalization(xhat[..., -lastChunk:])
264-
return tf.concat([mainPart, tailPart], axis=-1)
265-
return F
266-
267-
def build(self, input_shapes):
268-
_, xhatShape = input_shapes
269-
self._normalization = self._SMNormalization
270-
if 1 < self._headsN:
271-
assert 1 < (xhatShape[-1] // self._headsN), "too few channels for headsN"
272-
273-
lastChunk = xhatShape[-1] % self._headsN
274-
self._normalization = self._HeadwiseNormalizationPadded(lastChunk) if 0 < lastChunk else self._HeadwiseNormalizationNoPadding
275-
pass
276-
return super().build(input_shapes)
277-
278-
def call(self, x):
279-
x, xhat = x
280-
# return (tf.nn.relu(x) + self._eps) * (self._normalization(xhat) + self._eps) # more general/stable version
281-
# with SM normalization, relu and addition are redundant
282-
return x * self._normalization(xhat)
283-
####################################
284-
class CRMLBlock(tf.keras.Model):
285-
def __init__(self, mlp=None, RML=None, **kwargs):
230+
class CFusingBlock(tf.keras.Model):
231+
def __init__(self, mlp=None, **kwargs):
286232
super().__init__(**kwargs)
287233
if mlp is None: mlp = lambda x: x
288234
self._mlp = mlp
289-
if RML is None: RML = CResidualMultiplicativeLayer()
290-
self._RML = RML
291235
return
292236

293237
def build(self, input_shapes):
294238
xShape = input_shapes[0]
295239
self._lastDense = L.Dense(xShape[-1], activation='relu', name='%s/LastDense' % self.name)
240+
self._combiner = L.Dense(xShape[-1], activation='relu', name='%s/Combiner' % self.name)
296241
return super().build(input_shapes)
297242

298243
def call(self, x):
@@ -301,7 +246,8 @@ def call(self, x):
301246
xhat = self._mlp(xhat)
302247
xhat = self._lastDense(xhat)
303248
x0 = x[0]
304-
return self._RML([x0, xhat])
249+
x = tf.concat([x0, xhat], axis=-1)
250+
return self._combiner(x)
305251
####################################
306252
# Hacky way to provide the same optimizer for all models
307253
def createOptimizer(config=None):

NN/networks.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def Face2StepModel(pointsN, eyeSize, latentSize, embeddingsSize):
5656
# we need to combine them with each other and with encodedP
5757
combined = encodedP # start with the face features
5858
for i, EFeat in enumerate(encodedEFList):
59-
combined = CResidualMultiplicativeLayer(name='F2S/ResMul-%d' % i)([
59+
combined = CFusingBlock(name='F2S/ResMul-%d' % i)([
6060
combined,
6161
sMLP(sizes=[latentSize] * 1, activation='relu', name='F2S/MLP-%d' % i)(
6262
L.Concatenate(-1)([combined, encodedP, EFeat, embeddings])
@@ -94,7 +94,7 @@ def Step2LatentModel(latentSize, embeddingsSize):
9494
temporal = sMLP(sizes=[latentSize] * 1, activation='relu')(
9595
L.Concatenate(-1)([stepsData, encodedT, embeddings])
9696
)
97-
temporal = CResidualMultiplicativeLayer()([stepsData, temporal])
97+
temporal = CFusingBlock()([stepsData, temporal])
9898
intermediate['S2L/enc0'] = temporal
9999
# # # # # # # # # # # # # # # # # # # # # # # # # # # # #
100100
for blockId in range(3):
@@ -104,14 +104,14 @@ def Step2LatentModel(latentSize, embeddingsSize):
104104
temp = sMLP(sizes=[latentSize] * 1, activation='relu')(
105105
L.Concatenate(-1)([temporal, temp])
106106
)
107-
temporal = CResidualMultiplicativeLayer()([temporal, temp])
107+
temporal = CFusingBlock()([temporal, temp])
108108
intermediate['S2L/ResLSTM-%d' % blockId] = temporal
109109
continue
110110
# # # # # # # # # # # # # # # # # # # # # # # # # # # # #
111111
latent = sMLP(sizes=[latentSize] * 1, activation='relu')(
112112
L.Concatenate(-1)([stepsData, temporal, encodedT, encodedT])
113113
)
114-
latent = CResidualMultiplicativeLayer()([stepsData, latent])
114+
latent = CFusingBlock()([stepsData, latent])
115115
return tf.keras.Model(
116116
inputs={
117117
'latent': latents,
@@ -185,6 +185,35 @@ def Face2LatentModel(
185185
IP = lambda x: IntermediatePredictor()(x) # own IntermediatePredictor for each output
186186
res['intermediate'] = {k: IP(x) for k, x in intermediate.items()}
187187
res['result'] = IP(res['latent'])
188+
###################################
189+
# TODO: figure out whether this is helpful or not
190+
# branch for global coordinates transformation
191+
# predict shift, rotation, scale
192+
emb = L.Concatenate(-1)([userIdEmb, placeIdEmb, screenIdEmb])
193+
emb = sMLP(sizes=[64, 64, 64, 64, 32], activation='relu')(emb[:, 0])
194+
shift = L.Dense(2, name='GlobalShift')(emb)[:, None]
195+
rotation = L.Dense(1, name='GlobalRotation', activation='sigmoid')(emb)[:, None] * np.pi
196+
scale = L.Dense(2, name='GlobalScale')(emb)[:, None]
197+
198+
shifted = res['result'] + shift - 0.5 # [0.5, 0.5] -> [0, 0]
199+
# Rotation matrix components
200+
cos_rotation = L.Lambda(lambda x: tf.cos(x))(rotation)
201+
sin_rotation = L.Lambda(lambda x: tf.sin(x))(rotation)
202+
rotation_matrix = L.Lambda(lambda x: tf.stack([x[0], x[1]], axis=-1))([cos_rotation, sin_rotation])
203+
204+
# Apply rotation
205+
rotated = L.Lambda(
206+
lambda x: tf.einsum('isj,iomj->isj', x[0], x[1])
207+
)([shifted, rotation_matrix]) + 0.5 # [0, 0] -> [0.5, 0.5] back
208+
209+
# Apply scale
210+
scaled = rotated * scale
211+
def clipWithGradient(x):
212+
res = tf.clip_by_value(x, 0.0, 1.0)
213+
return x + tf.stop_gradient(res - x)
214+
215+
res['result'] = L.Lambda(clipWithGradient)(scaled)
216+
###################################
188217

189218
main = tf.keras.Model(inputs=inputs, outputs=res)
190219
return {
@@ -195,13 +224,6 @@ def Face2LatentModel(
195224
}
196225

197226
if __name__ == '__main__':
198-
# autoencoder = FaceAutoencoderModel(latentSize=64, means={
199-
# 'points': np.zeros((478, 2), np.float32),
200-
# 'left eye': np.zeros((32, 32), np.float32),
201-
# 'right eye': np.zeros((32, 32), np.float32),
202-
# })['main']
203-
# autoencoder.summary(expand_nested=True)
204-
205227
X = Face2LatentModel(steps=5, latentSize=64,
206228
embeddings={
207229
'userId': 1, 'placeId': 1, 'screenId': 1, 'size': 64

0 commit comments

Comments
 (0)