# Federated Learning - Text Generation

*Course: Machine Learning Projects with TensorFlow 2.0 by Vlad Sebastian Ionescu*

*Tutorial: https://www.tensorflow.org/federated/tutorials/federated_learning_for_text_generation#load_and_preprocess_the_federated_shakespeare_data*

Federated learning means training a model in a distributed way, across many clients, and combining the results on a central server.

In [1]:
import nest_asyncio
import tensorflow_federated as tff

In [2]:
import collections
import functools
import os
import time
import numpy as np
import tensorflow as tf

# Patch asyncio so nested event loops are allowed
# (presumably needed for TFF inside Jupyter's already-running loop).
nest_asyncio.apply()
tf.compat.v1.enable_v2_behavior()

# Seed every random source the notebook relies on: NumPy, and TensorFlow's
# global generator (dataset shuffling and categorical sampling are TF ops).
np.random.seed(0)
tf.random.set_seed(0)

# Smoke test: a trivial federated computation confirms that TFF is usable.
# Kept as the last expression so its value is displayed as the cell output.
tff.federated_computation(lambda: 'Hello, World!')()

b'Hello, World!'

## 1. Data Processing

In [3]:
# Fixed ASCII vocabulary: the characters that occur in the works of
# Shakespeare and Dickens.
vocab = list('dhlptx@DHLPTX $(,048cgkoswCGKOSW[_#\'/37;?bfjnrvzBFJNRVZ"&*.26:\naeimquyAEIMQUY]!%)-159\r')

# Two-way lookup between characters and their position in `vocab`:
# a dict for char -> index and an array for index -> char.
char2idx = dict(zip(vocab, range(len(vocab))))
idx2char = np.array(vocab)

In [4]:
# Load the Shakespeare simulation dataset bundled with TFF as a (train, test) pair.
# Per-client datasets are later built from `train_data` by client id.
train_data, test_data = tff.simulation.datasets.shakespeare.load_data()

In [5]:
# Build the dataset of a single client. The client id names the play
# ("The Tragedy of King Lear") and the character ("King").
raw_example_dataset = train_data.create_tf_dataset_for_client(
    'THE_TRAGEDY_OF_KING_LEAR_KING')

# Each element is an OrderedDict with a single key, 'snippets', holding the
# text (a dict is used to allow for future extensions). Show the first two.
for example in raw_example_dataset.take(2):
    print(example['snippets'])

tf.Tensor(b"Live regist'red upon our brazen tombs,\nAnd then grace us in the disgrace of death;\nWhen, spite of cormorant devouring Time,\nTh' endeavour of this present breath may buy\nThat honour which shall bate his scythe's keen edge,\nAnd make us heirs of all eternity.\nTherefore, brave conquerors- for so you are\nThat war against your own affections\nAnd the huge army of the world's desires-\nOur late edict shall strongly stand in force:\nNavarre shall be the wonder of the world;\nOur court shall be a little Academe,\nStill and contemplative in living art.\nYou three, Berowne, Dumain, and Longaville,\nHave sworn for three years' term to live with me\nMy fellow-scholars, and to keep those statutes\nThat are recorded in this schedule here.\nYour oaths are pass'd; and now subscribe your names,\nThat his own hand may strike his honour down\nThat violates the smallest branch herein.\nIf you are arm'd to do as sworn to do,\nSubscribe to your deep oaths, and keep it too.\nYour oath is pa

## 2. Text Generation

In [6]:
# Input pre-processing parameters
# Length of each input/target sequence; `preprocess` first forms chunks of SEQ_LENGTH + 1 ids.
SEQ_LENGTH = 100
# Number of sequences per minibatch (assigned again, with the same value, before training).
BATCH_SIZE = 8
BUFFER_SIZE = 100  # For dataset shuffling

In [7]:
# Static lookup table from single-character strings to int64 indexes,
# built from the vocab loaded above: vocab[i] maps to i.
# NOTE: a key that is not in `vocab` resolves to `default_value`, i.e. index 0,
# which is also the index of the first vocab entry.
vocab_ids = tf.constant(list(range(len(vocab))), dtype=tf.int64)
vocab_initializer = tf.lookup.KeyValueTensorInitializer(
    keys=vocab, values=vocab_ids)
table = tf.lookup.StaticHashTable(vocab_initializer, default_value=0)


def to_ids(x):
    """Convert one dataset element into a flat tensor of vocab indexes.

    Args:
        x: mapping whose 'snippets' entry holds a single string; it is
            reshaped to shape [1] before being split.

    Returns:
        The result of looking up every byte of the string in `table`
        (one id per byte).
    """
    snippet = tf.reshape(x['snippets'], shape=[1])
    # bytes_split gives a ragged result; `.values` is its flat list of bytes.
    return table.lookup(tf.strings.bytes_split(snippet).values)


def split_input_target(chunk):
    """Split a batch of sequences into next-character (input, target) pairs.

    Args:
        chunk: batch of sequences with the sequence positions along axis 1
            (`preprocess` passes minibatches of SEQ_LENGTH + 1 ids).

    Returns:
        A tuple `(input_text, target_text)`: `chunk` without its last
        position and `chunk` without its first position, so that
        `target_text[:, i]` is the element following `input_text[:, i]`.
    """
    # Slicing along axis 1 gives the same result as mapping a per-row slice
    # with tf.map_fn, without running a loop over the rows.
    input_text = chunk[:, :-1]
    target_text = chunk[:, 1:]
    return (input_text, target_text)


def preprocess(dataset):
    """Turn a raw client dataset into shuffled (input, target) minibatches.

    Args:
        dataset: dataset whose elements are accepted by `to_ids`.

    Returns:
        A dataset of `(input, target)` tuples produced by
        `split_input_target`, each of length SEQ_LENGTH, in minibatches of
        BATCH_SIZE sequences.
    """
    # Map ASCII chars to int64 indexes using the vocab, then split into
    # individual chars.
    char_ids = dataset.map(to_ids).unbatch()
    # Form example sequences of SEQ_LENGTH + 1, dropping a short tail.
    sequences = char_ids.batch(SEQ_LENGTH + 1, drop_remainder=True)
    # Shuffle and form full minibatches.
    minibatches = sequences.shuffle(BUFFER_SIZE).batch(
        BATCH_SIZE, drop_remainder=True)
    # Finally split each minibatch into (input, target) tuples.
    return minibatches.map(split_input_target)


In [8]:
# Preprocess the single-client example dataset and print the structure of its
# elements; `create_tff_model` later reuses this spec as the model's input_spec.
example_dataset = preprocess(raw_example_dataset)
print(example_dataset.element_spec)

(TensorSpec(shape=(8, 100), dtype=tf.int64, name=None), TensorSpec(shape=(8, 100), dtype=tf.int64, name=None))


In [9]:
def load_model(batch_size):
    """Download (via the Keras file cache) and load a pre-trained model.

    Args:
        batch_size: which saved variant to load; only 1 and 8 have a URL.

    Returns:
        The Keras model loaded from the downloaded file, not compiled.

    Raises:
        AssertionError: if `batch_size` has no entry in the URL table.
    """
    model_urls = {
        1: 'https://storage.googleapis.com/tff-models-public/dickens_rnn.batch1.kerasmodel',
        8: 'https://storage.googleapis.com/tff-models-public/dickens_rnn.batch8.kerasmodel'}
    assert batch_size in model_urls, 'batch_size must be in ' + str(model_urls.keys())
    origin = model_urls[batch_size]
    # The local file is named after the last path component of the URL.
    cached_path = tf.keras.utils.get_file(os.path.basename(origin), origin=origin)
    return tf.keras.models.load_model(cached_path, compile=False)

In [10]:
def generate_text(model, start_string, num_generate=1000, temperature=1.0):
    """Sample text character by character from `model`, seeded with a prompt.

    Args:
        model: callable with a `reset_states()` method. It is called on an
            id tensor with a leading dimension of 1, and its output (after
            dropping that dimension) is used as per-position logits.
        start_string: non-empty prompt; every character must be a key of
            `char2idx`.
        num_generate: number of characters to generate (default 1000).
        temperature: positive scaling applied to the logits before sampling.
            Low temperature --> more predictable text,
            high temperature --> more surprising text (default 1.0).

    Returns:
        `start_string` followed by the generated characters.

    Raises:
        ValueError: if `start_string` is empty or `temperature` is not
            positive.
        KeyError: if `start_string` contains a character missing from
            `char2idx`.
    """
    # Fail early with a clear message instead of deep inside TensorFlow.
    if not start_string:
        raise ValueError('start_string must contain at least one character')
    if temperature <= 0:
        raise ValueError('temperature must be positive')

    # Converting the start_string to numbers (vectorizing)
    input_eval = [char2idx[s] for s in start_string]
    input_eval = tf.expand_dims(input_eval, 0)

    text_generated = []

    model.reset_states()
    for _ in range(num_generate):
        predictions = model(input_eval)
        predictions = tf.squeeze(predictions, 0)

        # Using a categorical distribution to predict the char returned by the model;
        # only the sample drawn for the last position is kept.
        predictions = predictions / temperature
        predicted_id = tf.random.categorical(predictions, num_samples=1)[-1, 0].numpy()

        # We pass the predicted char to the model along with the previous hidden state
        input_eval = tf.expand_dims([predicted_id], 0)

        text_generated.append(idx2char[predicted_id])

    return (start_string + ''.join(text_generated))

In [11]:
# Text generation requires a batch_size=1 model:
# generate_text feeds the model one sequence at a time (leading dimension 1).
keras_model_batch1 = load_model(batch_size=1)
print(generate_text(keras_model_batch1, 'What of TensorFlow Federated, you ask? '))

What of TensorFlow Federated, you ask? S.

"So frightfully for fave to come or not trace, and touch the crime as
got up interest; and what we supportable I had, and he cannot
sure not that silence. It had a ridier wn the village, up of Saint Antoine, and
to be the object that could not conseng to the lighting up into the dess, "and how can you say he was,
it seemed after you, I am glided whom the
straite had made his three last shining upon it--"I must foin. Proof should keep in Suspicion
 thoughtfully and as me! but only to light I am highly
for that Defarges, and a few Frame ago. From the gendral, showed him greatly.




IX. Two

magn from our mut of the Project Gutenberg time when he had wandered away uarted in his sinister family; it may be madness me.
Le very good what midday, and brought him his hand her
lips to himself as he wore, whisper itself to the coachman, and the distant streets, Mr.
Lorry's eyes gradually sound, like a dream, Charles Darnay swoonspend out the air, lagaga

In [12]:
# Assigns BATCH_SIZE again (same value as in the pre-processing parameters cell).
BATCH_SIZE = 8  # The training and eval batch size for the rest of this tutorial.
# load_model returns an uncompiled model (compile=False), so the loss is
# attached here. No optimizer is passed to compile().
keras_model = load_model(batch_size=BATCH_SIZE)
keras_model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))

In [13]:
def create_tff_model():
    """Wrap a fresh clone of `keras_model` as a TFF learning model.

    Returns:
        The result of `tff.learning.from_keras_model` for a clone of
        `keras_model`, with a sparse categorical cross-entropy loss on logits.
    """
    # TFF uses an `input_spec` so it knows the types and shapes that the
    # model expects; it is taken from the preprocessed example dataset.
    return tff.learning.from_keras_model(
        tf.keras.models.clone_model(keras_model),
        input_spec=example_dataset.element_spec,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))


In [14]:
# This command builds all the TensorFlow graphs and serializes them.
# The client optimizer is passed as a zero-argument factory.
# `learning_rate` is the documented SGD argument; `lr` is only a deprecated alias.
fed_avg = tff.learning.build_federated_averaging_process(
    model_fn=create_tff_model,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.5))

Instructions for updating:
If using Keras pass *_constraint arguments to layers.


Instructions for updating:
If using Keras pass *_constraint arguments to layers.






In [16]:
# Applies the nest_asyncio patch again (it is also applied in the setup cell).
nest_asyncio.apply()
NUM_ROUNDS = 5
# Starting state of the federated averaging process.
state = fed_avg.initialize()
for _ in range(NUM_ROUNDS):
    # Each round is given a one-element list of datasets:
    # the first 5 minibatches of the example client's data.
    state, metrics = fed_avg.next(state, [example_dataset.take(5)])
    # Prints the whole `metrics.train` value under the label "loss".
    print(f'loss={metrics.train}')

loss=<loss=4.401425838470459>
loss=<loss=4.270419120788574>
loss=<loss=4.158207893371582>
loss=<loss=4.049384117126465>
loss=<loss=3.96478271484375>
