In [1]:
import datetime
import tensorflow as tf
import models.transformer.fashion_encoder as fashion_enc
import official.transformer.v2.misc as misc
import models.transformer.metrics as metrics
import importlib

def parse_example(raw):
    """Decode one serialized SequenceExample and return its "features" sequence.

    Args:
        raw: A scalar string tensor holding a serialized SequenceExample.

    Returns:
        The parsed "features" feature list: one float32 vector of length 2048
        per sequence step. The "categories" feature list is parsed as well but
        is not returned.
    """
    sequence_spec = {
        "categories": tf.io.FixedLenSequenceFeature([], tf.int64),
        "features": tf.io.FixedLenSequenceFeature(2048, tf.float32),
    }
    parsed = tf.io.parse_single_sequence_example(
        raw, sequence_features=sequence_spec)
    # Index 0 holds the context features (none requested here); index 1 holds
    # the per-step sequence features.
    sequences = parsed[1]
    return sequences["features"]

def get_dataset(filenames=None):
    """Build a dataset of parsed "features" sequences from TFRecord files.

    Args:
        filenames: Optional list of TFRecord file paths to read. When omitted,
            the single file "output-000-5.tfrecord" is read, which is the
            behaviour this function had before the parameter existed.

    Returns:
        A tf.data.Dataset in which every element is the result of
        `parse_example` applied to one record.
    """
    if filenames is None:
        filenames = ["output-000-5.tfrecord"]
    raw_dataset = tf.data.TFRecordDataset(filenames)
    return raw_dataset.map(parse_example)

def duplicate(example):
    """Pair an example with itself so it serves as both input and target."""
    pair = (example, example)
    return pair
    
def loss(y_pred, y_true):
    """Return the sum of all element-wise products of y_pred and y_true.

    This equals the dot product of the two tensors taken over the whole batch
    and yields a single scalar.
    """
    return tf.math.reduce_sum(tf.math.multiply(y_pred, y_true))

def xentropy_loss(y_pred, y_true):
    """Cross-entropy over the similarities of every prediction with every target.

    Both tensors are flattened to a list of item vectors. Each predicted vector
    is scored against all target vectors with a dot product, and the score row
    is treated as logits whose correct class is the target at the same
    position. The per-item cross-entropies are summed into one scalar.

    No masking is applied: every row of the flattened batch, including any
    padding rows, counts as an item.

    Note that the argument order is (y_pred, y_true), the reverse of the Keras
    loss-function convention.

    Args:
        y_pred: Predicted vectors; the last axis is the feature dimension,
            which must be statically known.
        y_true: Target vectors with the same shape as y_pred.

    Returns:
        A scalar tensor: the summed softmax cross-entropy.
    """
    feature_dim = y_pred.shape[-1]
    # Reshape to a batch of (batch size * seq length, feature dim)
    pred_batch = tf.reshape(y_pred, [-1, feature_dim])
    true_batch = tf.reshape(y_true, [-1, feature_dim])
    # Use the dynamic shape: the static one is None whenever the batch or
    # sequence length is unknown (e.g. inside tf.function), and tf.eye cannot
    # be built from None.
    item_count = tf.shape(true_batch)[0]
    # Dot product of every prediction with all labels
    logits = tf.matmul(pred_batch, true_batch, transpose_b=True)
    # One-hot labels (the identity matrix): item i must match target i
    labels = tf.eye(item_count, dtype=logits.dtype)

    per_item = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
    return tf.reduce_sum(per_item)

In [2]:
# Smoke test: feed the same hand-written 3 x 2 x 5 batch as both prediction and
# target and print the resulting loss.
_sample_values = [
    [[0.25106466, 0.00863039, 0.77634084, 0.23125434, 0.39757788],
     [0.82858, 0.44247448, 0.9647038, 0.31152928, 0.7297981]],
    [[0.8109468, 0.4998417, 0.31938195, 0.03625381, 0.92246723],
     [0.03745341, 0.7016275, 0.58033264, 0.10916102, 0.7598084]],
    [[0.57786167, 0.48389828, 0.7559997, 0.51311195, 0.3556137],
     [0.06377017, 0.37903416, 0.69512844, 0.8320352, 0.3528793]],
]
a = tf.constant(_sample_values)
b = tf.constant(_sample_values)

l = xentropy_loss(a, b)
print(l)

tf.Tensor(9.154351, shape=(), dtype=float32)


In [13]:
def grad(model: tf.keras.Model, inputs, targets):
    """Run one forward pass in training mode and differentiate the loss.

    Args:
        model: The Keras model to evaluate.
        inputs: Batch passed to the model.
        targets: Batch the model output is compared against by `xentropy_loss`.

    Returns:
        A tuple (loss_value, gradients), where gradients lines up with
        `model.trainable_variables`.
    """
    with tf.GradientTape() as tape:
        predictions = model(inputs, training=True)
        loss_value = xentropy_loss(predictions, targets)
    gradients = tape.gradient(loss_value, model.trainable_variables)
    return loss_value, gradients

def train(num_epochs, train_dataset: tf.data.Dataset, model: tf.keras.Model, optimizer):
    """Train `model` with a custom loop, logging summaries and saving checkpoints.

    If a checkpoint exists under './logs/tf_ckpts', the step counter, optimizer
    and model are restored from it before training starts; otherwise training
    begins from the current state.

    Args:
        num_epochs: Number of passes over `train_dataset`.
        train_dataset: Dataset yielding (inputs, targets) pairs.
        model: The Keras model to optimise.
        optimizer: Optimizer whose `apply_gradients` is called for every batch.

    Side effects:
        * Writes scalar summaries to 'logs/gradient_tape/<timestamp>/train':
          'loss' after every batch (running mean of the batch losses within
          the current epoch) and 'epoch_loss' after every epoch.
        * Every 50th epoch (starting with epoch 0) prints the epoch loss and
          saves a checkpoint, keeping at most the 3 newest.
        * Prints the final batch counter once training is finished.
    """
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    train_log_dir = 'logs/gradient_tape/' + current_time + '/train'
    train_summary_writer = tf.summary.create_file_writer(train_log_dir)

    ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=optimizer, model=model)
    manager = tf.train.CheckpointManager(ckpt, './logs/tf_ckpts', max_to_keep=3)

    if manager.latest_checkpoint:
        # Actually load the saved state; without this call the message below
        # would be printed while training silently restarted from scratch.
        ckpt.restore(manager.latest_checkpoint)
        print("Restored from {}".format(manager.latest_checkpoint))
    else:
        print("Initializing from scratch.")

    # Continue the summary step from the (possibly restored) global step.
    # On a fresh run ckpt.step is 1, the same starting value as before.
    batch_number = int(ckpt.step)

    for epoch in range(num_epochs):
        # Both metrics are re-created each epoch, so they average over the
        # batches of the current epoch only.
        epoch_loss_avg = tf.keras.metrics.Mean('epoch_loss')
        train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)

        # Training loop
        for x, y in train_dataset:
            # Optimize the model
            loss_value, grads = grad(model, x, y)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

            ckpt.step.assign_add(1)

            # Track progress
            epoch_loss_avg(loss_value)  # Add current batch loss
            train_loss(loss_value)

            with train_summary_writer.as_default():
                tf.summary.scalar('loss', train_loss.result(), step=batch_number)
            batch_number = batch_number + 1

        with train_summary_writer.as_default():
            tf.summary.scalar('epoch_loss', epoch_loss_avg.result(), step=epoch)

        if epoch % 50 == 0:
            print("Epoch {:03d}: Loss: {:.3f}".format(epoch, epoch_loss_avg.result()))
            save_path = manager.save()
            print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))

    print(batch_number)


In [10]:
# Reload the project modules so that edits made on disk are picked up without
# restarting the kernel.
importlib.reload(fashion_enc)
importlib.reload(metrics)

batch_size = 2
num_epoch = 2000

# Start from the "base" parameter set and override the entries below.
params = misc.get_model_params("base", 1)
params.update(dict(
    feature_dim=2048,
    dtype="float32",
    hidden_size=2048,
    extra_decode_length=0,
    num_hidden_layers=1,
    num_heads=2,
    max_length=10,
    default_batch_size=128,
    filter_size=1024,
))

# The meaning of the second positional argument (True) is defined by
# fashion_enc.create_model.
model = fashion_enc.create_model(params, True)
model.summary()

# Alternative Keras compile/fit path, currently disabled:
# model.compile('adam')
# 
# outfits = get_dataset()
# outfits = outfits.map(duplicate)
# outfits = outfits.padded_batch(batch_size, (params["max_length"], 2048))
# model.fit(outfits, callbacks=[tf.keras.callbacks.TensorBoard(), tf.keras.callbacks.ModelCheckpoint("./checkpoints/")])

Model: "transformer_v2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
encoder_stack_2 (EncoderStac multiple                  20986880  
Total params: 20,986,880
Trainable params: 20,986,880
Non-trainable params: 0
_________________________________________________________________
Model: "model_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
inputs (InputLayer)          [(None, None, 2048)]      0         
_________________________________________________________________
transformer_v2 (FashionEncod (None, None, 2048)        20986880  
Total params: 20,986,880
Trainable params: 20,986,880
Non-trainable params: 0
_________________________________________________________________


In [14]:
# Training run: 200 epochs of Adam over padded batches of two sequences each.
batch_size = 2
num_epoch = 200
optimizer = tf.optimizers.Adam()

# Pad each batch to its longest sequence (feature width fixed at 2048), then
# turn every batch into an (input, target) pair of identical tensors.
outfits = (
    get_dataset()
    .padded_batch(batch_size, (None, 2048))
    .map(duplicate)
)

train(num_epoch, outfits, model, optimizer)


Initializing from scratch.
Epoch 000: Loss: 15.888
Saved checkpoint for step 2: ./logs/tf_ckpts\ckpt-1
Epoch 050: Loss: 69.599
Saved checkpoint for step 52: ./logs/tf_ckpts\ckpt-2
Epoch 100: Loss: 33.691
Saved checkpoint for step 102: ./logs/tf_ckpts\ckpt-3
Epoch 150: Loss: 115.017
Saved checkpoint for step 152: ./logs/tf_ckpts\ckpt-4
201
