In [8]:
import tensorflow as tf
import models.transformer.fashion_encoder as fashion_enc
import official.transformer.v2.misc as misc
import importlib

def parse_example(raw):
    """Decode one serialized SequenceExample and return its "features" sequence.

    Args:
        raw: a serialized record, as produced by the TFRecord dataset that
            ``get_dataset`` maps this function over.

    Returns:
        The parsed "features" sequence feature (float32, declared with
        shape 2048 per sequence step).
    """
    sequence_spec = {
        # Declared in the spec and therefore parsed, but not returned.
        "categories": tf.io.FixedLenSequenceFeature([], tf.int64),
        "features": tf.io.FixedLenSequenceFeature(2048, tf.float32),
    }
    parsed = tf.io.parse_single_sequence_example(
        raw, sequence_features=sequence_spec)
    # Element 0 holds the context features, element 1 the sequence features.
    sequence_part = parsed[1]
    return sequence_part["features"]

def get_dataset(filenames=None):
    """Build a dataset of parsed "features" sequences from TFRecord files.

    Args:
        filenames: optional list of TFRecord paths to read. When omitted,
            the single file ``output-000-5.tfrecord`` is used.

    Returns:
        A ``tf.data.Dataset`` whose elements are whatever ``parse_example``
        returns for each record.
    """
    if filenames is None:
        filenames = ["output-000-5.tfrecord"]
    raw_dataset = tf.data.TFRecordDataset(filenames)
    return raw_dataset.map(parse_example)

def loss(y_pred, y_true):
    """Return the sum of the element-wise product of ``y_pred`` and ``y_true``.

    This is a single dot product taken over every element of the batch.
    """
    return tf.math.reduce_sum(tf.math.multiply(y_pred, y_true))

def xentropy_loss(y_pred, y_true):
    """Softmax cross-entropy of each prediction against every target in the batch.

    Both inputs are flattened to ``(items, feature_dim)``. Row ``i`` of the
    logits holds the dot products of prediction ``i`` with all targets, and
    the correct class for row ``i`` is target ``i`` (identity-matrix labels).

    Args:
        y_pred: predictions; the last axis is taken as the feature dimension.
        y_true: targets, reshaped with the same feature dimension as ``y_pred``.

    Returns:
        Scalar tensor: the cross-entropy summed over all items.
    """
    # Last axis instead of a hard-coded index 2, so rank-2 inputs work as well.
    feature_dim = y_pred.shape[-1]
    # Reshape to (batch size * seq length, feature dim)
    pred_batch = tf.reshape(y_pred, [-1, feature_dim])
    true_batch = tf.reshape(y_true, [-1, feature_dim])
    # Use the dynamic shape: the static first dimension is None whenever the
    # batch size is unknown at trace time, which tf.eye cannot accept.
    item_count = tf.shape(true_batch)[0]
    # Dot product of every prediction with all labels
    logits = tf.matmul(pred_batch, true_batch, transpose_b=True)
    # One-hot labels (the identity matrix); the dtype must match the logits.
    labels = tf.eye(item_count, dtype=logits.dtype)

    return tf.reduce_sum(
        tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits))

In [14]:
# Smoke test for xentropy_loss: `a` and `b` hold identical values in a
# 3 x 2 x 5 layout, so every prediction is paired with a copy of itself.
a = tf.constant([
    [[0.25106466, 0.00863039, 0.77634084, 0.23125434, 0.39757788],
     [0.82858, 0.44247448, 0.9647038, 0.31152928, 0.7297981]],
    [[0.8109468, 0.4998417, 0.31938195, 0.03625381, 0.92246723],
     [0.03745341, 0.7016275, 0.58033264, 0.10916102, 0.7598084]],
    [[0.57786167, 0.48389828, 0.7559997, 0.51311195, 0.3556137],
     [0.06377017, 0.37903416, 0.69512844, 0.8320352, 0.3528793]],
])
b = tf.constant([
    [[0.25106466, 0.00863039, 0.77634084, 0.23125434, 0.39757788],
     [0.82858, 0.44247448, 0.9647038, 0.31152928, 0.7297981]],
    [[0.8109468, 0.4998417, 0.31938195, 0.03625381, 0.92246723],
     [0.03745341, 0.7016275, 0.58033264, 0.10916102, 0.7598084]],
    [[0.57786167, 0.48389828, 0.7559997, 0.51311195, 0.3556137],
     [0.06377017, 0.37903416, 0.69512844, 0.8320352, 0.3528793]],
])

l = xentropy_loss(a, b)
print(l)

tf.Tensor(
[[0.8773598  1.322979   0.8309996  0.7933215  0.9962138  0.891647  ]
 [1.322979   2.4426377  1.8857194  1.4898481  1.8416088  1.4078786 ]
 [0.8309996  1.8857194  1.8617413  1.2712791  1.2985845  0.81886685]
 [0.7933215  1.4898481  1.2712791  1.4196948  1.1261005  1.0306814 ]
 [0.9962138  1.8416088  1.2985845  1.1261005  1.5293622  1.2981972 ]
 [0.891647   1.4078786  0.81886685 1.0306814  1.2981972  1.4477434 ]], shape=(6, 6), dtype=float32)
tf.Tensor(
[[1. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 1.]], shape=(6, 6), dtype=float32)
tf.Tensor(9.154351, shape=(), dtype=float32)


In [None]:
def grad(model: tf.keras.Model, inputs, targets):
  """Run one forward pass and return the loss together with its gradients.

  Args:
    model: the Keras model to differentiate.
    inputs: batch fed to the model.
    targets: batch passed to ``xentropy_loss`` as ``y_true``.

  Returns:
    ``(loss_value, gradients)``, where ``gradients`` is ordered like
    ``model.trainable_variables``.
  """
  with tf.GradientTape() as tape:
    # `training=True` is an argument of the model call, not of the loss:
    # xentropy_loss only accepts (y_pred, y_true) and would raise a TypeError.
    predictions = model(inputs, training=True)
    loss_value = xentropy_loss(predictions, targets)
  return loss_value, tape.gradient(loss_value, model.trainable_variables)

def train(epoch, train_dataset: tf.data.Dataset, model: tf.keras.Model, optimizer):
  """Run one pass over ``train_dataset``, updating ``model`` after every batch.

  Args:
    epoch: epoch index; a summary line is printed when it is a multiple of 50.
    train_dataset: dataset yielding ``(inputs, targets)`` pairs.
    model: Keras model being optimised.
    optimizer: object whose ``apply_gradients`` receives (gradient, variable)
      pairs.
  """
  mean_loss = tf.keras.metrics.Mean()
  # NOTE(review): SparseCategoricalAccuracy is normally fed integer class ids
  # as y_true -- confirm it is meaningful for the targets used with grad().
  accuracy = tf.keras.metrics.SparseCategoricalAccuracy()

  for batch_inputs, batch_targets in train_dataset:
    # One optimisation step.
    batch_loss, gradients = grad(model, batch_inputs, batch_targets)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    # Accumulate the running metrics. This second forward pass runs after the
    # weights were updated; training=True only matters for layers that behave
    # differently during training and inference (e.g. Dropout).
    mean_loss(batch_loss)
    predictions = model(batch_inputs, training=True)
    accuracy(batch_targets, predictions)

  if epoch % 50 != 0:
    return

  summary = "Epoch {:03d}: Loss: {:.3f}, Accuracy: {:.3%}".format(
      epoch, mean_loss.result(), accuracy.result())
  print(summary)


In [None]:
# Re-import the encoder module so that edits made to it on disk are picked up
# without restarting the kernel.
importlib.reload(fashion_enc)

# Fetch the parameters registered under "base".
# NOTE(review): the second argument (1) is presumably a device/GPU count --
# confirm against misc.get_model_params.
params = misc.get_model_params("base", 1)

# Override individual entries for this experiment.
# NOTE(review): feature_dim is 512 here, while parse_example declares
# 2048-wide "features" vectors -- confirm the two are meant to differ.
params.update({
    "feature_dim": 512,
    "dtype": "float32",
    "hidden_size": 512,
    "extra_decode_length": 0,
    "num_hidden_layers": 1,
    "num_heads": 2,
    "max_length": 10,
    "default_batch_size": 128,
    "filter_size": 1024
})

# NOTE(review): the meaning of the positional False is not visible in this
# notebook -- confirm against fashion_enc.create_model.
model = fashion_enc.create_model(params, False)
model.summary()