# TensorFlow Federated Text Classification

This notebook attempts to simulate a federated learning workflow using a centralized dataset.

The dataset is split equally amongst the clients. Note that the current split was not designed to produce i.i.d. client datasets; if that property is required, it still needs to be added to the flow.

This notebook is based on the TFF example notebook `federated_learning_for_image_classification.ipynb`.
https://www.tensorflow.org/federated/tutorials/federated_learning_for_image_classification

## Loading packages and checking if TFF has been loaded.

In [1]:
import string
import re
import math
import collections
import nest_asyncio

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
import tensorflow_datasets as tfds

from tensorflow.keras import layers
from tensorflow.keras.layers.experimental.preprocessing import TextVectorization
from tensorflow.keras.wrappers.scikit_learn import KerasClassifier
from matplotlib import pyplot as plt

# Define re-usable print separators (horizontal rules of different
# characters and widths used to structure the printed output below).
sep = '-' * 25
sep_2 = '#' * 10
sep_3 = '-' * 50
sep_4 = '=' * 100

In [2]:
nest_asyncio.apply()
%load_ext tensorboard
np.random.seed(0)

tff.federated_computation(lambda: 'Testing TFF.')()
assert tf.executing_eagerly() == True

## Setting federated learning parameters.

In [3]:
# The number of clients whose locally tuned weights are included in each
# federated model update.
NUM_CLIENTS = 10

# These parameters were used by the image classification notebook (which
# works on a tff.simulation dataset).
# NOTE(review): only PREFETCH_BUFFER is referenced by the code visible in this
# notebook; NUM_EPOCHS and SHUFFLE_BUFFER appear unused — confirm.
NUM_EPOCHS = 5
SHUFFLE_BUFFER = 100
PREFETCH_BUFFER = 10

## Loading the tf.data.Dataset

In [4]:
#dataset_name = "snli"
#dataset_name = "imdb_reviews"
#dataset_name = "sentiment140"

# The `ag_news_subset` dataset has 3 features (`title`, `description`, `label`);
# `label` has 4 classes. Loaded with `as_supervised=True`, every element is an
# (input, label) pair.
dataset_name = "ag_news_subset"

train_data, test_data = tfds.load(
    name=dataset_name,
    split=["train", "test"],
    as_supervised=True,
    shuffle_files=True,
    with_info=False,
)

# Both splits must come back as tf.data.Dataset objects.
assert all(
    isinstance(split, tf.data.Dataset) for split in (train_data, test_data)
)
print('Using dataset:', dataset_name)
print('train_data.element_spec:', train_data.element_spec)
print('type(train_data):', type(train_data))

Using dataset: ag_news_subset
train_data.element_spec: (TensorSpec(shape=(), dtype=tf.string, name=None), TensorSpec(shape=(), dtype=tf.int64, name=None))
type(train_data): <class 'tensorflow.python.data.ops.dataset_ops._OptionsDataset'>


### Extracting dataset as list.

In [5]:
# Materialise the whole training split in memory as a list of
# (input, label) tensor pairs; later cells slice it up per client.
train_data_list = list(train_data)

### Setting number of rows in dataset and batch size per client.

In [6]:
# Total number of training examples that were materialised above.
NUM_ROWS = len(train_data_list)
# Rows handed to each client: an equal share, rounded down.
BATCH_SIZE = NUM_ROWS // NUM_CLIENTS

## Extracting features from rows in train dataset (2 methods)

### 1. Using Generator

#### Defining a function that returns row values extracted from the train dataset.

In [7]:
def train_data_gen():
    """Yield every training row as an (input, label) pair.

    Reads from the module-level `train_data_list`.
    """
    for text, target in train_data_list:
        yield text, target

In [8]:
# Wrap the Python generator in a tf.data.Dataset (client datasets are later
# created from it). Each element is a scalar string paired with a scalar
# int64 label, matching the element_spec of the loaded dataset.
generator_train_data = tf.data.Dataset.from_generator(
    train_data_gen, 
    output_types=(tf.string, tf.int64), 
    output_shapes=((), ())
)

print('type(generator_train_data):', type(generator_train_data))

type(generator_train_data): <class 'tensorflow.python.data.ops.dataset_ops.FlatMapDataset'>


### 2. Using Tensor Slices

In [9]:
# Split the rows into two parallel lists (inputs and labels).
# The already-materialised `train_data_list` is reused instead of iterating
# `train_data` a second time: that avoids a second pass over the dataset and
# guarantees both extraction methods see the rows in the same order (with
# `shuffle_files=True` a fresh iteration is not guaranteed to repeat it).
description = [text for text, _ in train_data_list]
label = [target for _, target in train_data_list]

# Create dataset from tensor slices.
tensorslice_train_data = tf.data.Dataset.from_tensor_slices((description, label))
print('type(tensorslice_train_data):', type(tensorslice_train_data))

type(tensorslice_train_data): <class 'tensorflow.python.data.ops.dataset_ops.TensorSliceDataset'>


## Using extracted data to create Client datasets.

### from TensorSliceDataset data.

In [10]:
def create_client_dict(dataset, name,
                       show=True, num_clients=NUM_CLIENTS,
                       batch_size=BATCH_SIZE, num_rows=NUM_ROWS):
    """Split `dataset` into per-client datasets.

    The dataset is cut into consecutive batches of `batch_size` rows; the
    i-th batch becomes the dataset of client ``"client_<i>"``.

    Args:
        dataset: `tf.data.Dataset` holding the extracted rows.
        name: Label for the dataset, used only in the printed summary.
        show: When true, print the dataset name, row count and batch size.
        num_clients: Number of client datasets to create.
        batch_size: Number of rows assigned to each client.
        num_rows: Total row count, used only in the printed summary.

    Returns:
        A dict mapping ``"client_<i>"`` to a `tf.data.Dataset` built from
        the i-th batch.

    Raises:
        IndexError: If `dataset` yields fewer than `num_clients` batches.
    """
    if show:
        print('Name of dataset:', name)
        print("Num of rows:", num_rows)
        print("Client batch size:", batch_size)
        print(sep)

    client_dict = {}
    # Only the first `num_clients` batches are used, so stop reading the
    # source there instead of materialising every batch. `max(..., 0)` keeps
    # a negative count from meaning "take everything".
    client_batches = dataset.batch(batch_size).take(max(num_clients, 0))
    for i, batch in enumerate(client_batches):
        client_dict["client_" + str(i)] = tf.data.Dataset.from_tensor_slices(batch)

    if len(client_dict) < num_clients:
        raise IndexError(
            "dataset '{}' produced only {} batch(es) of size {}, "
            "but {} clients were requested".format(
                name, len(client_dict), batch_size, num_clients))
    return client_dict

In [11]:
def batch_format_fn(element, show=False):
    """Reshape one element's features into column vectors.

    Args:
        element: Mapping with the keys 'description' and 'label'.
        show: When true, print the raw feature values first.

    Returns:
        An `OrderedDict` with `x` (the description) and `y` (the label), each
        reshaped to `[-1, 1]`.

    NOTE(review): the datasets built in this notebook are loaded with
    `as_supervised=True` and therefore yield tuples, not mappings; this
    function would need adapting before `preprocess` is switched back on.
    """
    if show:
        print('preprocess:', element['description'], element['label'])
    column_shape = [-1, 1]
    features = tf.reshape(element['description'], column_shape)
    targets = tf.reshape(element['label'], column_shape)
    return collections.OrderedDict(x=features, y=targets)

def preprocess(dataset):
    """Apply `batch_format_fn` to every element and prefetch the results."""
    formatted = dataset.map(batch_format_fn)
    return formatted.prefetch(PREFETCH_BUFFER)

In [12]:
# The preprocessing has been commented out as this was pertaining to re-shaping image data.
# (There could be mapping functions used which are necessary to load the data correctly, 
# please refer to the image classification counterpart function)
def make_federated_data(client_data, client_ids):
    """Collect the datasets of the given clients into a list.

    Args:
        client_data: Mapping from client id to that client's dataset.
        client_ids: Iterable of client ids to include, in output order.

    Returns:
        A list with one dataset per id in `client_ids`.
    """
    federated = []
    for client_id in client_ids:
        # The dataset is passed through as-is; wrapping it in `preprocess`
        # is currently disabled.
        federated.append(client_data[client_id])
    return federated

## Create client dataset dictionaries from the extracted data (using the 2 methods mentioned above).

In [13]:
# Build one client dictionary per extraction method.
client_datasets_tensor = create_client_dict(tensorslice_train_data, name='tensorslice')
client_datasets_generator = create_client_dict(generator_train_data, name='generator')

# Check if dict keys are the same.
assert client_datasets_tensor.keys() == client_datasets_generator.keys()
# Client ids shared by both dictionaries (a dict keys view, not a list).
client_list = client_datasets_tensor.keys()

# Check the datasets: for every client both methods must give the same
# dataset type and the same number of rows, and that number must equal the
# per-client batch size.
for x in client_list:
    assert type(client_datasets_tensor[x]) == type(client_datasets_generator[x])
    assert len(client_datasets_tensor[x]) == len(client_datasets_generator[x])
    assert len(client_datasets_tensor[x]) == BATCH_SIZE

Name of dataset: tensorslice
Num of rows: 120000
Client batch size: 12000
-------------------------
Name of dataset: generator
Num of rows: 120000
Client batch size: 12000
-------------------------


In [14]:
def peek_client_data(client_datasets, num_clients=1):
    """Print the first row of the first `num_clients` client datasets.

    For each inspected client the row count is printed, followed by the first
    row, its input and label, and both values reshaped to `[-1, 1]` together
    with their types.

    Args:
        client_datasets: Mapping from client id to that client's dataset.
        num_clients: How many clients (in mapping order) to inspect.
    """
    column_shape = [-1, 1]
    for position, client_id in enumerate(client_datasets.keys()):
        if position == num_clients:
            break
        client_rows = client_datasets[client_id]
        print(client_id, 'has', len(list(client_rows)), 'rows.\n')
        for first_row in client_rows:
            features, target = first_row
            print('1st row:', first_row)
            print(sep)

            # Input: raw value, then the column-vector view.
            print('x:', features)
            print('type(x):', type(features))
            print(sep_2)
            features_col = tf.reshape(features, column_shape)
            print('reshaped x:', features_col)
            print('type(reshaped x):', type(features_col))
            print(sep)

            # Label: raw value, then the column-vector view.
            print('y:', target)
            print('type(y):', type(target))
            print(sep_2)
            target_col = tf.reshape(target, column_shape)
            print('reshaped y:', target_col)
            print('type(reshaped y):', type(target_col))
            print(sep_4, '\n')

            # Only the first row is of interest.
            break

In [15]:
# Uncomment and run to print the 1st row values (w/ reshaping) of the 1st client.
#peek_client_data(client_datasets_tensor)
#peek_client_data(client_datasets_generator)

## Preprocessing the input data

In [16]:
# Generate federated datasets (one list of client datasets per extraction
# method) from the client dataset dictionaries created above.
federated_train_data_tensor = make_federated_data(client_datasets_tensor, client_list)
federated_train_data_generator = make_federated_data(client_datasets_generator, client_list)

# Sanity checks: one dataset per client, and both methods agree on the count,
# the printed representation and the element_spec of the first client dataset.
assert len(federated_train_data_tensor) == NUM_CLIENTS
assert len(federated_train_data_tensor) == len(federated_train_data_generator)
assert format(federated_train_data_tensor[0]) == format(federated_train_data_generator[0])
assert federated_train_data_tensor[0].element_spec == federated_train_data_generator[0].element_spec
# The string literal below is a disabled check kept for reference; it is not
# executed as code.
# NOTE(review): `==` on two Dataset objects presumably compares identity, not
# contents, which would explain the AssertionError it mentions — confirm.
"""
# Throws `AssertionError` but upon printing they seem same.
# I suspect this could also be causing an error.

assert federated_train_data_tensor[0] == federated_train_data_generator[0]
"""
# Show the first client dataset from each method, then a short summary.
print('federated_train_data_tensor[0]:', federated_train_data_tensor[0])
print('federated_train_data_generator[0]:', federated_train_data_generator[0])
print(sep)
print('Number of client datasets: {l}'.format(l=len(federated_train_data_tensor)))
print('First dataset: {d}'.format(d=federated_train_data_tensor[0]))
print('element_spec: {d}'.format(d=federated_train_data_tensor[0].element_spec))

federated_train_data_tensor[0]: <TensorSliceDataset shapes: ((), ()), types: (tf.string, tf.int64)>
federated_train_data_generator[0]: <TensorSliceDataset shapes: ((), ()), types: (tf.string, tf.int64)>
-------------------------
Number of client datasets: 10
First dataset: <TensorSliceDataset shapes: ((), ()), types: (tf.string, tf.int64)>
element_spec: (TensorSpec(shape=(), dtype=tf.string, name=None), TensorSpec(shape=(), dtype=tf.int64, name=None))


## Creating a model with Keras

Created using reference https://keras.io/examples/nlp/text_classification_from_scratch/

In [17]:
def custom_standardization(input_data):
    """Standardise raw text for the TextVectorization layer.

    Lower-cases the input, replaces every "<br />" with a space and removes
    all characters listed in `string.punctuation`.
    """
    punctuation_pattern = "[%s]" % re.escape(string.punctuation)
    text = tf.strings.lower(input_data)
    text = tf.strings.regex_replace(text, "<br />", " ")
    return tf.strings.regex_replace(text, punctuation_pattern, "")

def create_keras_model(num_classes=4):
    """Build the (uncompiled) text classification model.

    The model takes a raw string, vectorises it into a fixed-length integer
    sequence, embeds it, applies two strided 1-D convolutions followed by
    global max pooling and a dense layer, and ends in a softmax over
    `num_classes` classes.

    Args:
        num_classes: Number of target classes. Defaults to 4, the number of
            `label` classes in the `ag_news_subset` dataset used here.

    Returns:
        A `tf.keras.Sequential` model whose output is a probability
        distribution over `num_classes` classes, as expected by
        `SparseCategoricalCrossentropy` with its default settings.

    NOTE(review): the TextVectorization layer is never adapted to a
    vocabulary in this notebook, and it is the likely source of the
    "TrackableWeightHandler" TypeError recorded when the federated averaging
    process is built — consider vectorising the text before it reaches the
    model. TODO confirm.
    """
    max_features = 20000
    embedding_dim = 128
    sequence_length = 500

    return tf.keras.models.Sequential([
        tf.keras.layers.InputLayer(input_shape=(1,), dtype=tf.string),
        tf.keras.layers.experimental.preprocessing.TextVectorization(
            standardize=custom_standardization,
            max_tokens=max_features,
            output_mode="int",
            output_sequence_length=sequence_length,),
        tf.keras.layers.Embedding(max_features, embedding_dim),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Conv1D(128, 7, padding="valid", activation="relu", strides=3),
        tf.keras.layers.Conv1D(128, 7, padding="valid", activation="relu", strides=3),
        tf.keras.layers.GlobalMaxPooling1D(),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dropout(0.5),
        # One logit per class, normalised by the softmax below. The previous
        # single sigmoid unit followed by a softmax always produced 1.0 and
        # could not represent more than one class.
        tf.keras.layers.Dense(num_classes),
        tf.keras.layers.Softmax(),
    ])

In [18]:
def model_fn():
    """Build a fresh Keras model and wrap it for use with TFF.

    A new Keras model is created on every call. The input spec is taken from
    the first client dataset in `federated_train_data_tensor`.

    Returns:
        The result of `tff.learning.from_keras_model` for the new model, with
        sparse categorical cross-entropy loss and accuracy metric.
    """
    keras_model = create_keras_model()
    # `summary()` prints the table itself and returns None, so it must not be
    # wrapped in `print` (that would emit an extra "None" line).
    keras_model.summary()
    return tff.learning.from_keras_model(
        keras_model,
        input_spec=federated_train_data_tensor[0].element_spec,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

## Training the model on federated data

In [19]:
# Build the federated averaging process from `model_fn`, using plain SGD with
# learning rate 0.02 for the client optimizer and 1.0 for the server optimizer.
# NOTE(review): the recorded output of this cell ends in a TypeError
# ("Expected ...Variable, found ...TrackableWeightHandler"), so none of the
# cells below have been executed in this saved notebook.
iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0)
)

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
text_vectorization (TextVect (None, 500)               0         
_________________________________________________________________
embedding (Embedding)        (None, 500, 128)          2560000   
_________________________________________________________________
dropout (Dropout)            (None, 500, 128)          0         
_________________________________________________________________
conv1d (Conv1D)              (None, 165, 128)          114816    
_________________________________________________________________
conv1d_1 (Conv1D)            (None, 53, 128)           114816    
_________________________________________________________________
global_max_pooling1d (Global (None, 128)               0         
_________________________________________________________________
dense (Dense)                (None, 128)               1

TypeError: Expected tensorflow.python.ops.variables.Variable, found tensorflow.python.keras.engine.base_layer_utils.TrackableWeightHandler.

In [None]:
# Display the type signature of the process's `initialize` computation.
str(iterative_process.initialize.type_signature)

In [None]:
# Create the initial server state that the training rounds start from.
state = iterative_process.initialize()

In [None]:
# SERVER_STATE, FEDERATED_DATA -> SERVER_STATE, TRAINING_METRICS
# The client datasets live in `federated_train_data_tensor`; the name
# `federated_train_data` used previously is not defined anywhere in the notebook.
state, metrics = iterative_process.next(state, federated_train_data_tensor)
print('round  1, metrics={}'.format(metrics))

In [None]:
# Run one more round on the same client datasets
# (`federated_train_data_tensor`; `federated_train_data` is never defined).
# NOTE(review): this cell repeats the previous one, so the printed label still
# reads "round  1" although the state has already advanced by one round.
state, metrics = iterative_process.next(state, federated_train_data_tensor)
print('round  1, metrics={}'.format(metrics))

In [None]:
# Continue training; round labels run from 2 up to NUM_ROUNDS - 1.
NUM_ROUNDS = 11
for round_num in range(2, NUM_ROUNDS):
    # Train on `federated_train_data_tensor` (the previously referenced
    # `federated_train_data` is not defined in this notebook).
    state, metrics = iterative_process.next(state, federated_train_data_tensor)
    print('round {:2d}, metrics={}'.format(round_num, metrics))

## Displaying model metrics in TensorBoard

In [None]:
# Directory that receives the scalar summaries, and a writer bound to it.
logdir = "./.tmp/logs/scalars/training/"
summary_writer = tf.summary.create_file_writer(logdir)
# Start again from a fresh server state so the logged run begins at round 1.
state = iterative_process.initialize()

In [None]:
# Train for NUM_ROUNDS - 1 rounds and log every training metric of each round
# as a scalar summary, using the round number as the step.
with summary_writer.as_default():
    for round_num in range(1, NUM_ROUNDS):
        # Train on `federated_train_data_tensor` (the previously referenced
        # `federated_train_data` is not defined in this notebook).
        state, metrics = iterative_process.next(state, federated_train_data_tensor)
        for name, value in metrics['train'].items():
            tf.summary.scalar(name, value, step=round_num)

In [None]:
# List the files written to the log directory, then open TensorBoard on it.
!ls {logdir}
%tensorboard --logdir {logdir} --port=0

# Uncomment to delete the logged summaries:
# !rm -R ./.tmp/logs/scalars/*