In [None]:
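# Adapted from the TensorFlow neural style transfer tutorial:
# https://www.tensorflow.org/tutorials/generative/style_transfer
# Here layers of a pretrained BERT text encoder stand in for the tutorial's VGG image layers.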

In [None]:
!pip install --quiet tensorflow-text

In [None]:
import numpy as np  # NOTE(review): np is not referenced in the visible cells
import tensorflow as tf
# Not referenced by name: imported so that the TF Text ops used by the
# TF Hub BERT preprocessing model (loaded below) are registered.
import tensorflow_text
import tensorflow_hub as hub
# Record which TensorFlow version produced the saved outputs.
print(tf.__version__)

2.3.0


In [None]:
# Text preprocessing model matched to the BERT encoder below: turns raw
# strings into the dict of tensors the encoder consumes.
preprocessor = hub.KerasLayer(
    handle="https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/1")

# Uncased BERT encoder (L-12_H-768_A-12 in the handle). Frozen: it is only
# used as a feature extractor here.
encoder = hub.KerasLayer(
    handle="https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3",
    trainable=False)





















In [None]:
# Example input text to run through the feature extractor below.
sentence = "This is a nice sentence."





In [None]:
def gram_matrix(input_tensor, mask=None):
  """Location-averaged Gram matrix of a rank-3 feature tensor.

  Text adaptation of the image style-transfer Gram matrix. In the image version
  the features are (batch, height, width, channels) and the average runs over
  the H*W spatial locations. Here the einsum fixes rank 3, so the "locations"
  are the middle axis (one feature vector per position).

  Args:
    input_tensor: float tensor of shape (batch, locations, channels).
    mask: optional (batch, locations) tensor with 1 for locations to include
      and 0 for locations to drop (e.g. padding, such as the BERT
      preprocessor's 'input_mask'). Defaults to None, which includes every
      location.

  Returns:
    Tensor of shape (batch, channels, channels) in the dtype of `input_tensor`.
    It is the sum, over the included locations, of the outer product of each
    feature vector with itself, divided by the number of included locations.
  """
  dtype = input_tensor.dtype
  if mask is None:
    features = input_tensor
    # Only axis 1 counts locations. The image code used shape[1]*shape[2]
    # (= H*W for rank-4 input); on rank-3 input that also multiplied in the
    # channel axis, which shrank every entry by a factor of `channels`.
    num_locations = tf.cast(tf.shape(input_tensor)[1], dtype)
  else:
    weights = tf.cast(mask, dtype)
    # Zero out dropped locations so they contribute nothing to the sum.
    features = input_tensor * weights[:, :, tf.newaxis]
    # Per-example count of included locations, clamped so an all-zero mask
    # does not divide by zero, then shaped to broadcast over (channels, channels).
    num_locations = tf.maximum(tf.reduce_sum(weights, axis=1), 1.0)
    num_locations = num_locations[:, tf.newaxis, tf.newaxis]
  result = tf.linalg.einsum('bic,bid->bcd', features, features)
  return result / num_locations

In [None]:
class StyleContentModel(tf.keras.models.Model):
  """Extracts BERT "content" and "style" features for a batch of strings.

  Text analogue of the style-transfer tutorial's VGG extractor: selected
  entries of the encoder's 'encoder_outputs' list play the role of VGG layers.

  Args:
    style_layers: indices into 'encoder_outputs' whose Gram matrices are
      returned as style features.
    content_layers: indices into 'encoder_outputs' returned unchanged as
      content features.
    preprocessor_layer: optional text preprocessing layer. Defaults to the
      module-level `preprocessor`.
    encoder_layer: optional encoder layer whose output dict has an
      'encoder_outputs' list. Defaults to the module-level `encoder`.
  """

  def __init__(self, style_layers, content_layers,
               preprocessor_layer=None, encoder_layer=None):
    super().__init__()
    # Fall back to the module-level hub layers from an earlier cell only when
    # no layer is passed explicitly. Those globals are then looked up only if needed.
    self.preprocessor = (preprocessor_layer if preprocessor_layer is not None
                         else preprocessor)
    self.encoder = encoder_layer if encoder_layer is not None else encoder
    self.style_layers = style_layers
    self.content_layers = content_layers

  def call(self, inputs):
    """Runs preprocessing + encoder and returns (content_outputs, style_outputs).

    Args:
      inputs: a batch of raw text strings, e.g. `[sentence]` (a list of str),
        as accepted by the preprocessing layer. This is not image data.

    Returns:
      A tuple `(content_outputs, style_outputs)`:
        content_outputs: list of the selected 'encoder_outputs' tensors
          (e.g. shape (1, 128, 768) for the sample sentence).
        style_outputs: list of `gram_matrix` results for the selected layers
          (e.g. shape (1, 768, 768)).
      Both lists follow encoder layer order. Out-of-range or duplicate indices
      select nothing extra.
    """
    preprocessed_input = self.preprocessor(inputs)
    outputs = self.encoder(preprocessed_input)
    layers = outputs['encoder_outputs']
    # Walk the encoder's layers (not the requested indices), so the outputs are
    # in layer order regardless of how style_layers/content_layers are ordered.
    style_outputs = [style_output for i, style_output in enumerate(layers)
                     if i in self.style_layers]
    content_outputs = [content_output for i, content_output in enumerate(layers)
                       if i in self.content_layers]

    style_outputs = [gram_matrix(style_output) for style_output in style_outputs]

    return content_outputs, style_outputs

In [None]:
# Indices into the encoder's 'encoder_outputs' list (the handle says L-12,
# i.e. 12 layers): the first five are used for style, the last (11) for content.
style_layers = list(range(5))
content_layers = [11]

extractor = StyleContentModel(style_layers, content_layers)
# The model takes a batch of strings, so wrap the single sentence in a list.
results = extractor([sentence])

In [None]:
# Summarize rather than dumping the full tensors (the raw repr prints
# 128x768 activations and 768x768 Gram matrices).
content_outputs, style_outputs = results
{
    'content': [tuple(t.shape) for t in content_outputs],
    'style': [tuple(t.shape) for t in style_outputs],
}

([<tf.Tensor: shape=(1, 128, 768), dtype=float32, numpy=
  array([[[-0.04140566,  0.02487258, -0.18559837, ..., -0.16935846,
            0.02828475,  0.6497009 ],
          [-0.6632701 , -0.01108311, -0.04537605, ..., -0.00759035,
            0.17624772,  0.4847691 ],
          [ 0.06166041, -0.19566756,  0.38863105, ..., -0.19084993,
           -0.00946588,  1.1805072 ],
          ...,
          [ 0.04975953, -0.24883376,  0.33861274, ...,  0.3099918 ,
           -0.16195074,  0.09633178],
          [ 0.07145615, -0.17512976,  0.4124881 , ...,  0.30505356,
           -0.20710379,  0.09815944],
          [-0.14201781, -0.13025676,  0.43745166, ...,  0.27715382,
           -0.2818718 ,  0.13325058]]], dtype=float32)>],
 [<tf.Tensor: shape=(1, 768, 768), dtype=float32, numpy=
  array([[[ 5.9883347e-05,  4.9552327e-05, -5.1466439e-05, ...,
           -3.3446533e-05, -5.7796283e-06,  7.9453451e-07],
          [ 4.9552327e-05,  1.6897061e-04, -2.0542578e-04, ...,
           -1.4332432e-04, 