##### Copyright 2020 The TensorFlow Authors.

In [1]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# TensorFlow Constrained Optimization Example Using CelebA Dataset

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://www.tensorflow.org/responsible_ai/fairness_indicators/tutorials/Fairness_Indicators_TFCO_CelebA_Case_Study"><img src="https://www.tensorflow.org/images/tf_logo_32px.png" />View on TensorFlow.org</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/fairness-indicators/blob/master/g3doc/tutorials/Fairness_Indicators_TFCO_CelebA_Case_Study.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/fairness-indicators/tree/master/g3doc/tutorials/Fairness_Indicators_TFCO_CelebA_Case_Study.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View on GitHub</a>
  </td>
  <td>
    <a href="https://storage.googleapis.com/tensorflow_docs/fairness-indicators/g3doc/tutorials/Fairness_Indicators_TFCO_CelebA_Case_Study.ipynb"><img src="https://www.tensorflow.org/images/download_logo_32px.png" />Download notebook</a>
  </td>
</table>

This notebook demonstrates an easy way to create and optimize constrained problems using the TensorFlow Constrained Optimization (TFCO) library. This method can be useful in improving models when we find that they’re not performing equally well across different slices of our data, which we can identify using [Fairness Indicators](https://www.tensorflow.org/responsible_ai/fairness_indicators/guide). The second of Google’s AI principles states that our technology should avoid creating or reinforcing unfair bias, and we believe this technique can help improve model fairness in some situations. In particular, this notebook will:


*   Train a simple, *unconstrained* neural network model to detect a person's smile in images using [`tf.keras`](https://www.tensorflow.org/guide/keras) and the large-scale CelebFaces Attributes ([CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html)) dataset.
*   Evaluate model performance against a commonly used fairness metric across age groups, using Fairness Indicators.
*   Set up a simple constrained optimization problem to achieve fairer performance across age groups.
*   Retrain the now *constrained* model and evaluate performance again, ensuring that our chosen fairness metric has improved.
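The constrained step in the list above follows a general pattern: minimize a loss subject to inequality constraints. As a rough, library-free illustration of the underlying Lagrangian idea (not TFCO's implementation; TFCO provides its own Lagrangian and proxy-Lagrangian optimizers), the sketch below solves a one-dimensional toy problem, minimizing (x - 3)² subject to x ≤ 1, by gradient descent on x and projected gradient ascent on a Lagrange multiplier. The objective, constraint, step size, and step count are all illustrative assumptions.

```python
def solve_toy_constrained_problem(steps=2000, lr=0.05):
    """Gradient descent-ascent on L(x, lam) = (x - 3)**2 + lam * (x - 1).

    Toy problem (an illustrative assumption, not taken from TFCO):
        minimize (x - 3)**2  subject to  x - 1 <= 0.
    """
    x, lam = 0.0, 0.0
    for _ in range(steps):
        grad_x = 2.0 * (x - 3.0) + lam  # dL/dx
        violation = x - 1.0             # dL/dlam, positive when the constraint is violated
        x -= lr * grad_x                      # descend on the "model parameter"
        lam = max(0.0, lam + lr * violation)  # ascend on the multiplier, keep it >= 0
    return x, lam


x, lam = solve_toy_constrained_problem()
print(f"x = {x:.3f}, lambda = {lam:.3f}")  # close to x = 1, lambda = 4
```

The unconstrained minimum is x = 3; the constraint pushes the solution to the boundary x = 1, where the multiplier settles near 4 (from 2(x - 3) + λ = 0). In the fairness setting, the constraint would instead bound a quantity such as a per-group rate.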

Last updated: 3/11 Feb 2020

# Installation
This notebook was created in [Colaboratory](https://research.google.com/colaboratory/faq.html), connected to the Python 3 Google Compute Engine backend. If you wish to host this notebook in a different environment, then you should not experience any major issues provided you include all the required packages in the cells below.

Note that the very first time you run the pip installs, you may be asked to restart the runtime because of preinstalled, out-of-date packages. Once you restart, the correct packages will be used.

In [2]:
#@title Pip installs
!pip install -q -U pip==20.2

!pip install git+https://github.com/google-research/tensorflow_constrained_optimization
!pip install -q tensorflow-datasets tensorflow
!pip install fairness-indicators \
  "absl-py==0.12.0" \
  "apache-beam<3,>=2.38" \
  "avro-python3==1.9.1" \
  "pyzmq==17.0.0"


Collecting git+https://github.com/google-research/tensorflow_constrained_optimization
  Cloning https://github.com/google-research/tensorflow_constrained_optimization to /tmpfs/tmp/pip-req-build-rjv_6rax
Building wheels for collected packages: tfco-nightly
  Building wheel for tfco-nightly (setup.py) ... done
  Created wheel for tfco-nightly: filename=tfco_nightly-0.3.dev20220602-py3-none-any.whl size=194868 sha256=e919d36353c3d8e6474a1bc644bc45cfa930102173b8ab1429f6e08c67ae8a64
  Stored in directory: /tmpfs/tmp/pip-ephem-wheel-cache-xdg39ygr/wheels/33/21/f5/14625dcd44c01ce5ccc7917fa8b0833fb074cb8d70e71d2ce7
Successfully built tfco-nightly
Installing collected packages: tfco-nightly
Successfully installed tfco-nightly-0.3.dev20220602
You should consider upgrading via the '/tmpfs/src/tf_docs_env/bin/python3.9 -m pip install --upgrade pip' command.

Collecting fairness-indicators
  Using cached fairness_indicators-0.39.0-py3-none-any.whl (24 kB)
Collecting absl-py==0.12.0
  Downloading absl_py-0.12.0-py3-none-any.whl (129 kB)
Collecting apache-beam<3,>=2.38
  Using cached apache_beam-2.39.0-cp39-cp39-manylinux2010_x86_64.whl (11.3 MB)
Collecting avro-python3==1.9.1
  Downloading avro-python3-1.9.1.tar.gz (36 kB)
Collecting pyzmq==17.0.0
  Downloading pyzmq-17.0.0.tar.gz (988 kB)
Collecting witwidget<2,>=1.4.4
  Using cached witwidget-1.8.1-py3-none-any.whl (1.5 MB)
Collecting tensorflow-data-validation<1.9.0,>=1.8.0
  Using cached tensorflow_data_validation-1.8.0-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.4 MB)
Collecting tensorflow-model-analysis<0.40,>=0.39
  Using cached tensorflow_model_analysis-0.39.0-py3-none-any.whl (1.8 MB)
Collecting proto-plus<2,>=1.7.1
  Using cached proto_plus-1.20.5-py3-none-any.whl (46 kB)
Processing /home/kbuilder/.cache/pip/wheels/4f/0b/ce/75d96dd714b15e51cb66db631183ea3844e0c4a6d19741a149/dill-0.3.1.1-py3-none-any.whl
Collecting pymongo<4.0.0,>=3.8.0
  Downloading pymongo-3.12.3-cp39-cp39-manylinux2014_x86_64.whl (532 kB)
Collecting httplib2<0.20.0,>=0.8
  Using cached httplib2-0.19.1-py3-none-any.whl (95 kB)
Collecting fastavro<2,>=0.23.6
  Using cached fastavro-1.4.12-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.6 MB)
Collecting hdfs<3.0.0,>=2.1.0
  Using cached hdfs-2.7.0-py3-none-any.whl (34 kB)
Collecting numpy<1.23.0,>=1.14.3
  Using cached numpy-1.22.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (16.8 MB)
Collecting pyarrow<8.0.0,>=0.15.1
  Downloading pyarrow-7.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (26.7 MB)
Collecting orjson<4.0
  Downloading orjson-3.6.9-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (256 kB)
Collecting cloudpickle<3,>=2.0.0
  Using cached cloudpickle-2.1.0-py3-none-any.whl (25 kB)
Processing /home/kbuilder/.cache/pip/wheels/4a/6c/a6/ffdd136310039bf226f2707a9a8e6857be7d70a3fc061f6b36/crcmod-1.7-cp39-cp39-linux_x86_64.whl
Collecting google-api-python-client>=1.7.8
  Using cached google_api_python_client-2.49.0-py2.py3-none-any.whl (8.5 MB)
Collecting oauth2client>=4.1.3
  Using cached oauth2client-4.1.3-py2.py3-none-any.whl (98 kB)
Processing /home/kbuilder/.cache/pip/wheels/de/2b/b1/c541160670d70f4b08c4786f4e155337d4baeaa3e01d9d1400/pyfarmhash-0.3.2-cp39-cp39-linux_x86_64.whl
Collecting tfx-bsl<1.9,>=1.8.0
  Using cached tfx_bsl-1.8.0-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (19.2 MB)
Collecting joblib<0.15,>=0.12
  Using cached joblib-0.14.1-py2.py3-none-any.whl (294 kB)
Collecting ipython<8,>=7
  Using cached ipython-7.34.0-py3-none-any.whl (793 kB)
Collecting pyparsing<3,>=2.4.2
  Using cached pyparsing-2.4.7-py2.py3-none-any.whl (67 kB)
Processing /home/kbuilder/.cache/pip/wheels/70/4a/46/1309fc853b8d395e60bafaf1b6df7845bdd82c95fd59dd8d2b/docopt-0.6.2-py2.py3-none-any.whl
Collecting uritemplate<5,>=3.0.1
  Using cached uritemplate-4.1.1-py2.py3-none-any.whl (10 kB)
Collecting google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0dev,>=1.31.5
  Using cached google_api_core-2.8.1-py3-none-any.whl (114 kB)
Collecting google-auth-httplib2>=0.1.0
  Using cached google_auth_httplib2-0.1.0-py2.py3-none-any.whl (9.3 kB)
Collecting tensorflow-serving-api!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.*,!=2.4.*,!=2.5.*,!=2.6.*,!=2.7.*,<3,>=1.15
  Using cached tensorflow_serving_api-2.8.0-py2.py3-none-any.whl (37 kB)
Building wheels for collected packages: avro-python3, pyzmq
  Building wheel for avro-python3 (setup.py) ... done
  Created wheel for avro-python3: filename=avro_python3-1.9.1-py3-none-any.whl size=43179 sha256=78e1163ce9085580b7e62b383d5df787fe11f8f31b74df86ea8216d5bd249d7b
  Stored in directory: /home/kbuilder/.cache/pip/wheels/b8/96/7a/dfe7f817902cd7134d4218ff0e86b7e36671772a1bc37c4ef2
  Building wheel for pyzmq (setup.py) ... error
  ERROR: Command errored out with exit status 1:
   command: /tmpfs/src/tf_docs_env/bin/python3.9 -u -c 'import sys, setuptools, tokenize; sys.argv[0] = '"'"'/tmpfs/tmp/pip-install-waff0y9w/pyzmq/setup.py'"'"'; __file__='"'"'/tmpfs/tmp/pip-install-waff0y9w/pyzmq/setup.py'"'"';f=getattr(tokenize, '"'"'open'"'"', open)(__file__);code=f.read().replace('"'"'\r\n'"'"', '"'"'\n'"'"');f.close();exec(compile(code, __file__, '"'"'exec'"'"'))' bdist_wheel -d /tmpfs/tmp/pip-wheel-pongh8ek
       cwd: /tmpfs/tmp/pip-install-waff0y9w/pyzmq/
  Complete output (528 lines):
  running bdist_wheel
  running build
  running build_py
  creating build
  creating build/lib.linux-x86_64-cpython-39
  creating build/lib.linux-x86_64-cpython-39/zmq
  copying zmq/error.py -> build/lib.linux-x86_64-cpython-39/zmq
  copying zmq/_future.py -> build/lib.linux-x86_64-cpython-39/zmq
  copying zmq/decorators.py -> build/lib.linux-x86_64-cpython-39/zmq
  copying zmq/__init__.py -> build/lib.l

Successfully built avro-python3
Failed to build pyzmq
DEPRECATION: Could not build wheels for pyzmq which do not use PEP 517. pip will fall back to legacy 'setup.py install' for these. pip 21.0 will remove support for this functionality. A possible replacement is to fix the wheel build issue reported above. You can find discussion regarding this at https://github.com/pypa/pip/issues/8368.
Installing collected packages: uritemplate, google-api-core, pyparsing, httplib2, google-auth-httplib2, google-api-python-client, absl-py, oauth2client, witwidget, numpy, pyfarmhash, proto-plus, dill, pymongo, fastavro, docopt, hdfs, pyarrow, orjson, cloudpickle, crcmod, apache-beam, tensorflow-serving-api, tfx-bsl, joblib, tensorflow-data-validation, ipython, tensorflow-model-analysis, fairness-indicators, avro-python3, pyzmq
  Attempting uninstall: pyparsing
    Found existing installation: pyparsing 3.0.9
    Uninstalling pyparsing-3.0.9:
      Successfully uninstalled pyparsing-3.0.9
  Attempting uninstall: absl-py
    Found existing installation: absl-py 1.1.0
    Uninstalling absl-py-1.1.0:
      Successfully uninstalled absl-py-1.1.0
  Attempting uninstall: numpy
    Found existing installation: numpy 1.23.0rc2
    Uninstalling numpy-1.23.0rc2:
      Successfully uninstalled numpy-1.23.0rc2
  Attempting uninstall: dill
    Found existing installation: dill 0.3.5.1
    Uninstalling dill-0.3.5.1:
      Successfully uninstalled dill-0.3.5.1
  Attempting uninstall: joblib
    Found existing installation: joblib 1.1.0
    Uninstalling joblib-1.1.0:
      Successfully uninstalled joblib-1.1.0
  Attempting uninstall: ipython
    Found existing installation: ipython 8.4.0
    Uninstalling ipython-8.4.0:
      Successfully uninstalled ipython-8.4.0
  Attempting uninstall: pyzmq
    Found existing installation: pyzmq 23.0.0
    Uninstalling pyzmq-23.0.0:
      Successfully uninstalled pyzmq-23.0.0
    Running setup.py install for pyzmq ... error
    ERROR: Command errored out with exit status 1:
     command: /tmpfs/src/tf_docs_env/bin/python3.9 -u -c 'import sys, setuptools, tokenize; sys.argv[0] = '"'"'/tmpfs/tmp/pip-install-waff0y9w/pyzmq/setup.py'"'"'; __file__='"'"'/tmpfs/tmp/pip-install-waff0y9w/pyzmq/setup.py'"'"';f=getattr(tokenize, '"'"'open'"'"', open)(__file__);code=f.read().replace('"'"'\r\n'"'"', '"'"'\n'"'"');f.close();exec(compile(code, __file__, '"'"'exec'"'"'))' install --record /tmpfs/tmp/pip-record-mjvr7b45/install-record.txt --single-version-externally-managed --compile --install-headers /tmpfs/src/tf_docs_env/include/site/python3.9/pyzmq
         cwd: /tmpfs/tmp/pip-install-waff0y9w/pyzmq/
    Complete output (530 lines):
    running install
    running build
    running build_py
    creating build
    creating build/lib.linux-x86_64-cpython-39
    creating build/lib.linux-x86_64-cpython-39/zmq
    copying zmq/error.py -> build/lib.linux-x86_64-cpython-39/zmq
    copying zmq/_futu

You should consider upgrading via the '/tmpfs/src/tf_docs_env/bin/python3.9 -m pip install --upgrade pip' command.
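The install log shows the pinned `pyzmq==17.0.0` failing to build, so the environment may not end up exactly as pinned. Before continuing, it can help to check which versions are actually installed. Below is a minimal sketch using the standard library's `importlib.metadata` (Python 3.8+); the package list mirrors the pins in the install cell, and `installed_versions` is a hypothetical helper, not part of any of these libraries.

```python
from importlib import metadata


def installed_versions(packages):
    """Return {distribution name: installed version, or None if missing}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not installed in this environment
    return versions


pinned = ["absl-py", "apache-beam", "avro-python3", "pyzmq", "fairness-indicators"]
for name, version in installed_versions(pinned).items():
    print(f"{name}: {version}")
```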


Note that depending on when you run the cell below, you may receive a warning about the default version of TensorFlow in Colab switching to TensorFlow 2.X soon. You can safely ignore that warning as this notebook was designed to be compatible with TensorFlow 1.X and 2.X.

In [3]:
#@title Import Modules
import os
import sys
import tempfile
import urllib.request

import tensorflow as tf
from tensorflow import keras

import tensorflow_datasets as tfds
tfds.disable_progress_bar()

import numpy as np

import tensorflow_constrained_optimization as tfco

from tensorflow_metadata.proto.v0 import schema_pb2
from tfx_bsl.tfxio import tensor_adapter
from tfx_bsl.tfxio import tf_example_record

Additionally, we add a few imports specific to Fairness Indicators, which we will use to evaluate and visualize the model's performance.

In [4]:
#@title Fairness Indicators related imports
import tensorflow_model_analysis as tfma
import fairness_indicators as fi
from google.protobuf import text_format
import apache_beam as beam

Although TFCO is compatible with eager and graph execution, this notebook assumes that eager execution is enabled by default as it is in TensorFlow 2.x. To ensure that nothing breaks, eager execution will be enabled in the cell below.

In [5]:
#@title Enable Eager Execution and Print Versions
if tf.__version__ < "2.0.0":
  tf.compat.v1.enable_eager_execution()
  print("Eager execution enabled.")
else:
  print("Eager execution enabled by default.")

print("TensorFlow " + tf.__version__)
print("TFMA " + tfma.VERSION_STRING)
print("TFDS " + tfds.version.__version__)
print("FI " + fi.version.__version__)

Eager execution enabled by default.
TensorFlow 2.9.1
TFMA 0.39.0
TFDS 4.5.2
FI 0.39.0
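The eager-execution cell compares `tf.__version__` with `"2.0.0"` as strings. That works for TensorFlow 1.x and 2.x releases, but string comparison goes character by character, so, for example, `"2.10.0"` sorts before `"2.9.0"`. A minimal sketch of a numeric comparison is below; `version_tuple` is a hypothetical helper that only parses leading digits, and the third-party `packaging.version` module is a more complete alternative.

```python
def version_tuple(version_string):
    """Parse the leading numeric components of a version such as '2.9.1'.

    Stops at the first component without leading digits; trailing
    suffixes like 'rc2' are dropped ('1.23.0rc2' -> (1, 23, 0)).
    """
    parts = []
    for piece in version_string.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


# String comparison looks at characters, not numbers:
print("2.10.0" < "2.9.0")                                # True
print(version_tuple("2.10.0") < version_tuple("2.9.0"))  # False
print(version_tuple("2.9.1") >= (2, 0, 0))               # True
```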


# CelebA Dataset
[CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) is a large-scale face attributes dataset with more than 200,000 celebrity images, each with 40 attribute annotations (such as hair type, fashion accessories, and facial features) and 5 landmark locations (eyes, mouth and nose positions). For more details, see [the paper](https://liuziwei7.github.io/projects/FaceAttributes.html).
With the permission of the owners, we have stored this dataset on Google Cloud Storage and mostly access it via [TensorFlow Datasets (`tfds`)](https://www.tensorflow.org/datasets).

In this notebook:
* Our model will attempt to classify whether the subject of the image is smiling, as represented by the "Smiling" attribute<sup>*</sup>.
*   Images will be resized from 218x178 to 28x28 to reduce execution time and memory usage during training.
*   Our model's performance will be evaluated across age groups, using the binary "Young" attribute. We will call this "age group" in this notebook.

___

<sup>*</sup> While there is little information available about the labeling methodology for this dataset, we will assume that the "Smiling" attribute was determined by a pleased, kind, or amused expression on the subject's face. For the purpose of this case study, we will take these labels as ground truth.
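To put the resize from 218x178 to 28x28 in perspective, the arithmetic below counts the scalar inputs per RGB image before and after resizing. Note that `tf.image.resize` does not preserve aspect ratio by default, so the non-square CelebA images are also slightly stretched into 28x28 squares. `scalar_inputs` is a hypothetical helper used only for this calculation.

```python
def scalar_inputs(height, width, channels=3):
    """Number of scalar values in an image of the given size."""
    return height * width * channels


original = scalar_inputs(218, 178)  # CelebA image size stated above
resized = scalar_inputs(28, 28)     # IMAGE_SIZE x IMAGE_SIZE used in this notebook
print(original, resized, round(original / resized, 1))  # 116412 2352 49.5
```

So each training image shrinks from 116,412 values to 2,352, roughly 50 times fewer inputs to the first dense layer.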




In [6]:
gcs_base_dir = "gs://celeb_a_dataset/"
celeb_a_builder = tfds.builder("celeb_a", data_dir=gcs_base_dir, version='2.0.0')

celeb_a_builder.download_and_prepare()

num_test_shards_dict = {'0.3.0': 4, '2.0.0': 2} # Used because we download the test dataset separately
version = str(celeb_a_builder.info.version)
print('Celeb_A dataset version: %s' % version)

Celeb_A dataset version: 2.0.0


In [7]:
#@title Test dataset helper functions
local_root = tempfile.mkdtemp(prefix='test-data')
def local_test_filename_base():
  return local_root

def local_test_file_full_prefix():
  return os.path.join(local_test_filename_base(), "celeb_a-test.tfrecord")

def copy_test_files_to_local():
  filename_base = local_test_file_full_prefix()
  num_test_shards = num_test_shards_dict[version]
  for shard in range(num_test_shards):
    url = "https://storage.googleapis.com/celeb_a_dataset/celeb_a/%s/celeb_a-test.tfrecord-0000%s-of-0000%s" % (version, shard, num_test_shards)
    filename = "%s-0000%s-of-0000%s" % (filename_base, shard, num_test_shards)
    res = urllib.request.urlretrieve(url, filename)
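`copy_test_files_to_local` builds each shard name with a `0000%s` prefix, which produces the usual five-digit, zero-padded TFDS shard suffix (for example `-00000-of-00002`) only while the shard index and count stay below 10. A minimal sketch of a more general variant using `%05d` formatting follows; `shard_names` is a hypothetical helper, not part of the notebook or of TFDS, and it returns bare file names that would still need to be joined with the local directory.

```python
def shard_names(version, num_shards, prefix="celeb_a-test.tfrecord"):
    """Hypothetical helper: (url, local file name) for each test shard.

    Uses %05d zero padding, which matches the notebook's '0000%s' pattern
    for fewer than 10 shards and keeps working beyond that.
    """
    base_url = "https://storage.googleapis.com/celeb_a_dataset/celeb_a"
    names = []
    for shard in range(num_shards):
        suffix = "-%05d-of-%05d" % (shard, num_shards)
        names.append((f"{base_url}/{version}/{prefix}{suffix}", prefix + suffix))
    return names


for url, filename in shard_names("2.0.0", 2):
    print(url, "->", filename)
```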

## Caveats
Before moving forward, there are several considerations to keep in mind in using CelebA:
*   Although in principle this notebook could use any dataset of face images, CelebA was chosen because it contains public domain images of public figures.
*   All of the attribute annotations in CelebA are operationalized as binary categories. For example, the "Young" attribute (as determined by the dataset labelers) is denoted as either present or absent in the image.
*   CelebA's categorizations do not reflect real human diversity of attributes.
*   For the purposes of this notebook, the feature containing the "Young" attribute is referred to as "age group": an image in which the "Young" attribute is present is treated as a member of the "Young" age group, and an image in which it is absent is treated as a member of the "Not Young" age group. These are assumptions, as this information is not mentioned in the [original paper](http://openaccess.thecvf.com/content_iccv_2015/html/Liu_Deep_Learning_Face_ICCV_2015_paper.html).
*   As such, the performance of the models trained in this notebook is tied to the ways the attributes have been operationalized and annotated by the authors of CelebA.
*   This model should not be used for commercial purposes as that would violate [CelebA's non-commercial research agreement](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html).

# Setting Up Input Functions
The subsequent cells will help streamline the input pipeline as well as visualize performance.

First, we define some data-related variables and the required preprocessing function.

In [8]:
#@title Define Variables
ATTR_KEY = "attributes"
IMAGE_KEY = "image"
LABEL_KEY = "Smiling"
GROUP_KEY = "Young"
IMAGE_SIZE = 28

In [9]:
#@title Define Preprocessing Functions
def preprocess_input_dict(feat_dict):
  # Separate out the image and target variable from the feature dictionary.
  image = feat_dict[IMAGE_KEY]
  label = feat_dict[ATTR_KEY][LABEL_KEY]
  group = feat_dict[ATTR_KEY][GROUP_KEY]

  # Resize and normalize image.
  image = tf.cast(image, tf.float32)
  image = tf.image.resize(image, [IMAGE_SIZE, IMAGE_SIZE])
  image /= 255.0

  # Cast label and group to float32.
  label = tf.cast(label, tf.float32)
  group = tf.cast(group, tf.float32)

  feat_dict[IMAGE_KEY] = image
  feat_dict[ATTR_KEY][LABEL_KEY] = label
  feat_dict[ATTR_KEY][GROUP_KEY] = group

  return feat_dict

get_image_and_label = lambda feat_dict: (feat_dict[IMAGE_KEY], feat_dict[ATTR_KEY][LABEL_KEY])
get_image_label_and_group = lambda feat_dict: (feat_dict[IMAGE_KEY], feat_dict[ATTR_KEY][LABEL_KEY], feat_dict[ATTR_KEY][GROUP_KEY])

Then, we build out the data functions we need in the rest of the colab.

In [10]:
# Train data returning either 2 or 3 elements (the third element being the group)
def celeb_a_train_data_wo_group(batch_size):
  celeb_a_train_data = celeb_a_builder.as_dataset(split='train').shuffle(1024).repeat().batch(batch_size).map(preprocess_input_dict)
  return celeb_a_train_data.map(get_image_and_label)
def celeb_a_train_data_w_group(batch_size):
  celeb_a_train_data = celeb_a_builder.as_dataset(split='train').shuffle(1024).repeat().batch(batch_size).map(preprocess_input_dict)
  return celeb_a_train_data.map(get_image_label_and_group)

# Test data for the overall evaluation
celeb_a_test_data = celeb_a_builder.as_dataset(split='test').batch(1).map(preprocess_input_dict).map(get_image_label_and_group)
# Copy test data locally to be able to read it into tfma
copy_test_files_to_local()
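Both training pipelines call `.shuffle(1024)` before `.repeat()` and `.batch()`. `tf.data`'s shuffle keeps a buffer of `buffer_size` elements and samples from it, so with a 1,024-element buffer the mixing is local rather than a full permutation of the training split. The pure-Python generator below is a simplified illustration of that buffer idea, not the `tf.data` implementation; `buffered_shuffle` is a hypothetical name.

```python
import random


def buffered_shuffle(iterable, buffer_size, seed=None):
    """Yield items in a locally shuffled order using a fixed-size buffer.

    Fill a buffer first; then, for each new input, emit a random buffered
    element and put the new input in its place. Drain the rest at the end.
    """
    rng = random.Random(seed)
    buffer = []
    for item in iterable:
        if len(buffer) < buffer_size:
            buffer.append(item)
            continue
        index = rng.randrange(buffer_size)
        yield buffer[index]
        buffer[index] = item
    # Input exhausted: emit whatever is left, in random order.
    rng.shuffle(buffer)
    yield from buffer


print(list(buffered_shuffle(range(10), buffer_size=3, seed=0)))
```

Because only the buffer is randomized, an element can move at most a limited distance ahead of its original position, which is why a small buffer relative to the dataset size gives only partial shuffling.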

# Build a simple DNN Model
Because this notebook focuses on TFCO, we will assemble a simple, unconstrained `tf.keras.Sequential` model.

We may be able to greatly improve model performance by adding some complexity (e.g., more densely connected layers, different activation functions, or a larger image size), but that may distract from the goal of demonstrating how easy it is to apply the TFCO library when working with Keras. For that reason, the model will be kept simple, but feel free to explore this space.

In [11]:
def create_model():
  # For this notebook, accuracy will be used to evaluate performance.
  METRICS = [
    tf.keras.metrics.BinaryAccuracy(name='accuracy')
  ]

  # The model consists of:
  # 1. An input layer that flattens the 28x28x3 image.
  # 2. A fully connected layer with 64 units activated by a ReLU function.
  # 3. A single-unit readout layer that outputs real-valued scores instead of probabilities.
  model = keras.Sequential([
      keras.layers.Flatten(input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3), name='image'),
      keras.layers.Dense(64, activation='relu'),
      keras.layers.Dense(1, activation=None)
  ])

  # TFCO uses hinge loss by default, so the model is trained with hinge loss as well.
  model.compile(
      optimizer=tf.keras.optimizers.Adam(0.001),
      loss='hinge',
      metrics=METRICS)
  return model
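The model is compiled with Keras's `'hinge'` loss, while the labels produced by `preprocess_input_dict` are 0.0/1.0 floats. Keras documents that binary 0/1 labels are converted to -1/+1 before computing `max(0, 1 - y * score)` and averaging. A plain-Python version of that computation, for illustration only (`keras_style_hinge` is a hypothetical helper, not the Keras implementation):

```python
def keras_style_hinge(labels, scores):
    """Mean hinge loss over a batch, mapping {0, 1} labels to {-1, +1} first.

    Illustrative only: mirrors the documented behaviour of Keras's 'hinge'
    loss for binary labels.
    """
    total = 0.0
    for label, score in zip(labels, scores):
        signed_label = 2.0 * label - 1.0  # 0 -> -1, 1 -> +1
        total += max(0.0, 1.0 - signed_label * score)
    return total / len(labels)


# A confidently correct positive costs nothing; a negative scored at 0.5 costs 1.5.
print(keras_style_hinge([1.0, 0.0], [2.0, 0.5]))  # 0.75
```

Because the readout layer has no activation, its raw scores feed straight into this loss, which only penalizes examples that are misclassified or fall inside the margin.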

We also define a function to set seeds to ensure reproducible results. Note that this colab is meant as an educational tool and does not have the stability of a finely tuned production pipeline. Running without setting a seed may lead to varied results. 

In [12]:
def set_seeds():
  np.random.seed(121212)
  tf.compat.v1.set_random_seed(212121)

# Fairness Indicators Helper Functions
Before training our model, we define a number of helper functions that will allow us to evaluate the model's performance via Fairness Indicators.


First, we create a helper function to save our model once we train it.

In [13]:
def save_model(model, subdir):
  base_dir = tempfile.mkdtemp(prefix='saved_models')
  model_location = os.path.join(base_dir, subdir)
  model.save(model_location, save_format='tf')
  return model_location

Next, we define functions used to preprocess the data in order to correctly pass it through to TFMA.

In [14]:
#@title Data Preprocessing functions for 
def tfds_filepattern_for_split(dataset_name, split):
  return f"{local_test_file_full_prefix()}*"

class PreprocessCelebA(object):
  """Class that deserializes, decodes and applies additional preprocessing for CelebA input."""
  def __init__(self, dataset_name):
    builder = tfds.builder(dataset_name)
    self.features = builder.info.features
    example_specs = self.features.get_serialized_info()
    self.parser = tfds.core.example_parser.ExampleParser(example_specs)

  def __call__(self, serialized_example):
    # Deserialize
    deserialized_example = self.parser.parse_example(serialized_example)
    # Decode
    decoded_example = self.features.decode_example(deserialized_example)
    # Additional preprocessing
    image = decoded_example[IMAGE_KEY]
    label = decoded_example[ATTR_KEY][LABEL_KEY]
    # Resize and scale image.
    image = tf.cast(image, tf.float32)
    image = tf.image.resize(image, [IMAGE_SIZE, IMAGE_SIZE])
    image /= 255.0
    image = tf.reshape(image, [-1])
    # Cast label to float32; the group is encoded as bytes below.
    label = tf.cast(label, tf.float32)

    group = decoded_example[ATTR_KEY][GROUP_KEY]

    output = tf.train.Example()
    output.features.feature[IMAGE_KEY].float_list.value.extend(image.numpy().tolist())
    output.features.feature[LABEL_KEY].float_list.value.append(label.numpy())
    output.features.feature[GROUP_KEY].bytes_list.value.append(b"Young" if group.numpy() else b'Not Young')
    return output.SerializeToString()

def tfds_as_pcollection(beam_pipeline, dataset_name, split):
  return (
      beam_pipeline
      | 'Read records' >> beam.io.ReadFromTFRecord(tfds_filepattern_for_split(dataset_name, split))
      | 'Preprocess' >> beam.Map(PreprocessCelebA(dataset_name))
  )
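`PreprocessCelebA` flattens each resized image with `tf.reshape(image, [-1])` before storing it as a float list, and the schema in the next cell re-declares that column as a 28x28x3 dense tensor. Since `tf.reshape` uses row-major order, element `(h, w, c)` lands at flat index `(h * W + w) * C + c`, so the round trip is lossless. The stdlib sketch below demonstrates that index mapping on nested lists; `flatten_hwc` and `unflatten_hwc` are hypothetical helpers.

```python
def flatten_hwc(image):
    """Row-major (C-order) flatten of a nested [height][width][channels] list."""
    return [value for row in image for pixel in row for value in pixel]


def unflatten_hwc(values, height, width, channels):
    """Inverse of flatten_hwc: element (h, w, c) lives at (h * width + w) * channels + c."""
    if len(values) != height * width * channels:
        raise ValueError("length does not match height * width * channels")
    return [
        [list(values[(h * width + w) * channels:(h * width + w + 1) * channels])
         for w in range(width)]
        for h in range(height)
    ]


# Tiny 2x4 RGB "image" whose values encode their own (h, w, c) position.
image = [[[100 * h + 10 * w + c for c in range(3)] for w in range(4)] for h in range(2)]
flat = flatten_hwc(image)
print(len(flat), flat[(1 * 4 + 2) * 3 + 1])       # 24 121
print(unflatten_hwc(flat, 2, 4, 3) == image)      # True
```

For the notebook's 28x28x3 images the flattened list has 2,352 values.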

Finally, we define a function that evaluates the results in TFMA.

In [15]:
def get_eval_results(model_location, eval_subdir):
  base_dir = tempfile.mkdtemp(prefix='saved_eval_results')
  tfma_eval_result_path = os.path.join(base_dir, eval_subdir)

  eval_config_pbtxt = """
        model_specs {
          label_key: "%s"
        }
        metrics_specs {
          metrics {
            class_name: "FairnessIndicators"
            config: '{ "thresholds": [0.22, 0.5, 0.75] }'
          }
          metrics {
            class_name: "ExampleCount"
          }
        }
        slicing_specs {}
        slicing_specs { feature_keys: "%s" }
        options {
          compute_confidence_intervals { value: False }
          disabled_outputs{values: "analysis"}
        }
      """ % (LABEL_KEY, GROUP_KEY)
      
  eval_config = text_format.Parse(eval_config_pbtxt, tfma.EvalConfig())

  eval_shared_model = tfma.default_eval_shared_model(
        eval_saved_model_path=model_location, tags=[tf.saved_model.SERVING])

  schema_pbtxt = """
        tensor_representation_group {
          key: ""
          value {
            tensor_representation {
              key: "%s"
              value {
                dense_tensor {
                  column_name: "%s"
                  shape {
                    dim { size: 28 }
                    dim { size: 28 }
                    dim { size: 3 }
                  }
                }
              }
            }
          }
        }
        feature {
          name: "%s"
          type: FLOAT
        }
        feature {
          name: "%s"
          type: FLOAT
        }
        feature {
          name: "%s"
          type: BYTES
        }
        """ % (IMAGE_KEY, IMAGE_KEY, IMAGE_KEY, LABEL_KEY, GROUP_KEY)
  schema = text_format.Parse(schema_pbtxt, schema_pb2.Schema())
  coder = tf_example_record.TFExampleBeamRecord(
      physical_format='inmem', schema=schema,
      raw_record_column_name=tfma.ARROW_INPUT_COLUMN)
  tensor_adapter_config = tensor_adapter.TensorAdapterConfig(
    arrow_schema=coder.ArrowSchema(),
    tensor_representations=coder.TensorRepresentations())
  # Run the fairness evaluation.
  with beam.Pipeline() as pipeline:
    _ = (
          tfds_as_pcollection(pipeline, 'celeb_a', 'test')
          | 'ExamplesToRecordBatch' >> coder.BeamSource()
          | 'ExtractEvaluateAndWriteResults' >>
          tfma.ExtractEvaluateAndWriteResults(
              eval_config=eval_config,
              eval_shared_model=eval_shared_model,
              output_path=tfma_eval_result_path,
              tensor_adapter_config=tensor_adapter_config)
    )
  return tfma.load_eval_result(output_path=tfma_eval_result_path)
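The eval config above asks TFMA for `FairnessIndicators` at thresholds 0.22, 0.5 and 0.75, on the overall slice (`slicing_specs {}`) and per value of the `Young` feature. As a rough, hand-rolled illustration of what a per-slice, per-threshold rate means, the sketch below computes false positive and false negative rates, two of the rates Fairness Indicators reports. The strict `>` comparison, the `"Overall"` slice name, and the toy data are assumptions of this sketch, not TFMA internals.

```python
def rates_at_threshold(scores, labels, threshold):
    """Return (false positive rate, false negative rate) at one threshold.

    Assumption: a score strictly above the threshold is a positive prediction.
    """
    fp = fn = positives = negatives = 0
    for score, label in zip(scores, labels):
        predicted_positive = score > threshold
        if label == 1:
            positives += 1
            fn += not predicted_positive
        else:
            negatives += 1
            fp += predicted_positive
    fpr = fp / negatives if negatives else float("nan")
    fnr = fn / positives if positives else float("nan")
    return fpr, fnr


def sliced_rates(scores, labels, groups, thresholds):
    """Rates for the overall slice and for each distinct group value."""
    slices = {"Overall": list(range(len(scores)))}
    for group in sorted(set(groups)):
        slices[group] = [i for i, g in enumerate(groups) if g == group]
    results = {}
    for name, indices in slices.items():
        s = [scores[i] for i in indices]
        y = [labels[i] for i in indices]
        results[name] = {t: rates_at_threshold(s, y, t) for t in thresholds}
    return results


# Toy data: four examples, two per age group.
report = sliced_rates(
    scores=[0.9, 0.3, 0.6, 0.1],
    labels=[1, 0, 0, 1],
    groups=["Young", "Young", "Not Young", "Not Young"],
    thresholds=[0.22, 0.5, 0.75],
)
for slice_name, by_threshold in report.items():
    print(slice_name, by_threshold)
```

A gap between the per-group rates at the same threshold is the kind of disparity that Fairness Indicators surfaces and that a constraint can later target.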


# Train & Evaluate Unconstrained Model

With the model defined and the input pipeline in place, we’re ready to train our model. To reduce execution time and memory usage, we will train the model on small batches for only a few iterations.

Note that running this notebook with TensorFlow < 2.0.0 may result in a deprecation warning for `np.where`. You can safely ignore this warning; TensorFlow 2.x addresses it by using `tf.where` in place of `np.where`.
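Because the training dataset is built with `.repeat()`, it never runs out, so `fit` relies on `steps_per_epoch` to decide where an epoch ends. The plain-Python arithmetic below, using the values from the training cell, shows how much data that amounts to. The CelebA train-split size is an assumption taken from the TFDS catalog rather than from this notebook.

```python
batch_size = 32          # BATCH_SIZE in the training cell
steps_per_epoch = 1000   # steps_per_epoch passed to fit()
epochs = 5

# Assumption: size of the TFDS CelebA "train" split, per the TFDS catalog.
celeb_a_train_size = 162_770

examples_per_epoch = batch_size * steps_per_epoch
total_examples = examples_per_epoch * epochs
passes_over_train_split = total_examples / celeb_a_train_size

print(examples_per_epoch, total_examples, round(passes_over_train_split, 2))  # 32000 160000 0.98
```

In other words, each "epoch" here is 1,000 batches rather than a full pass, and the whole run sees slightly less than one pass over the training split.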

In [16]:
BATCH_SIZE = 32

# Set seeds to get reproducible results
set_seeds()

model_unconstrained = create_model()
model_unconstrained.fit(celeb_a_train_data_wo_group(BATCH_SIZE), epochs=5, steps_per_epoch=1000)

Epoch 1/5


   1/1000 [..............................] - ETA: 1:13:38 - loss: 0.9872 - accuracy: 0.5312

  20/1000 [..............................] - ETA: 2s - loss: 1.0012 - accuracy: 0.5547     

  38/1000 [>.............................] - ETA: 2s - loss: 0.8899 - accuracy: 0.5995

  52/1000 [>.............................] - ETA: 2s - loss: 0.8668 - accuracy: 0.6070

  67/1000 [=>............................] - ETA: 2s - loss: 0.8373 - accuracy: 0.6255

  82/1000 [=>............................] - ETA: 2s - loss: 0.8182 - accuracy: 0.6296

  97/1000 [=>............................] - ETA: 2s - loss: 0.7912 - accuracy: 0.6408

 112/1000 [==>...........................] - ETA: 2s - loss: 0.7715 - accuracy: 0.6476

 127/1000 [==>...........................] - ETA: 2s - loss: 0.7571 - accuracy: 0.6543

 141/1000 [===>..........................] - ETA: 2s - loss: 0.7486 - accuracy: 0.6580

 155/1000 [===>..........................] - ETA: 2s - loss: 0.7384 - accuracy: 0.6635

 169/1000 [====>.........................] - ETA: 2s - loss: 0.7332 - accuracy: 0.6659

 183/1000 [====>.........................] - ETA: 2s - loss: 0.7262 - accuracy: 0.6699

 198/1000 [====>.........................] - ETA: 2s - loss: 0.7063 - accuracy: 0.6798

 213/1000 [=====>........................] - ETA: 2s - loss: 0.6914 - accuracy: 0.6865

 228/1000 [=====>........................] - ETA: 2s - loss: 0.6821 - accuracy: 0.6911

Epoch 2/5


   1/1000 [..............................] - ETA: 4s - loss: 0.2802 - accuracy: 0.8125

  15/1000 [..............................] - ETA: 3s - loss: 0.4904 - accuracy: 0.7771

  29/1000 [..............................] - ETA: 3s - loss: 0.4367 - accuracy: 0.8028

  43/1000 [>.............................] - ETA: 3s - loss: 0.4456 - accuracy: 0.7980

  58/1000 [>.............................] - ETA: 3s - loss: 0.4374 - accuracy: 0.8023

  73/1000 [=>............................] - ETA: 3s - loss: 0.4326 - accuracy: 0.8052

  87/1000 [=>............................] - ETA: 3s - loss: 0.4414 - accuracy: 0.7996

 101/1000 [==>...........................] - ETA: 3s - loss: 0.4513 - accuracy: 0.7995

 116/1000 [==>...........................] - ETA: 3s - loss: 0.4460 - accuracy: 0.8001

 131/1000 [==>...........................] - ETA: 3s - loss: 0.4400 - accuracy: 0.8046

 146/1000 [===>..........................] - ETA: 3s - loss: 0.4310 - accuracy: 0.8076

 161/1000 [===>..........................] - ETA: 2s - loss: 0.4204 - accuracy: 0.8123

 175/1000 [====>.........................] - ETA: 2s - loss: 0.4144 - accuracy: 0.8157

 189/1000 [====>.........................] - ETA: 2s - loss: 0.4157 - accuracy: 0.8155

 203/1000 [=====>........................] - ETA: 2s - loss: 0.4287 - accuracy: 0.8107

 218/1000 [=====>........................] - ETA: 2s - loss: 0.4314 - accuracy: 0.8091

 232/1000 [=====>........................] - ETA: 2s - loss: 0.4298 - accuracy: 0.8102

Epoch 3/5


   1/1000 [..............................] - ETA: 4s - loss: 0.5148 - accuracy: 0.7188

  16/1000 [..............................] - ETA: 3s - loss: 0.3507 - accuracy: 0.8438

  31/1000 [..............................] - ETA: 3s - loss: 0.3612 - accuracy: 0.8357

  46/1000 [>.............................] - ETA: 3s - loss: 0.3582 - accuracy: 0.8390

  61/1000 [>.............................] - ETA: 3s - loss: 0.3706 - accuracy: 0.8325

  75/1000 [=>............................] - ETA: 3s - loss: 0.3658 - accuracy: 0.8342

  90/1000 [=>............................] - ETA: 3s - loss: 0.3628 - accuracy: 0.8333

 105/1000 [==>...........................] - ETA: 3s - loss: 0.3594 - accuracy: 0.8360

 119/1000 [==>...........................] - ETA: 3s - loss: 0.3567 - accuracy: 0.8393

 133/1000 [==>...........................] - ETA: 3s - loss: 0.3524 - accuracy: 0.8423

 148/1000 [===>..........................] - ETA: 3s - loss: 0.3585 - accuracy: 0.8397

 162/1000 [===>..........................] - ETA: 2s - loss: 0.3564 - accuracy: 0.8410

 177/1000 [====>.........................] - ETA: 2s - loss: 0.3564 - accuracy: 0.8416

 192/1000 [====>.........................] - ETA: 2s - loss: 0.3594 - accuracy: 0.8407

 207/1000 [=====>........................] - ETA: 2s - loss: 0.3609 - accuracy: 0.8410

 222/1000 [=====>........................] - ETA: 2s - loss: 0.3632 - accuracy: 0.8402

Epoch 4/5


   1/1000 [..............................] - ETA: 5s - loss: 0.4204 - accuracy: 0.7188

  16/1000 [..............................] - ETA: 3s - loss: 0.3827 - accuracy: 0.8340

  31/1000 [..............................] - ETA: 3s - loss: 0.3490 - accuracy: 0.8438

  45/1000 [>.............................] - ETA: 3s - loss: 0.3350 - accuracy: 0.8472

  59/1000 [>.............................] - ETA: 3s - loss: 0.3284 - accuracy: 0.8538

  73/1000 [=>............................] - ETA: 3s - loss: 0.3266 - accuracy: 0.8553

  87/1000 [=>............................] - ETA: 3s - loss: 0.3264 - accuracy: 0.8556

 101/1000 [==>...........................] - ETA: 3s - loss: 0.3267 - accuracy: 0.8564

 116/1000 [==>...........................] - ETA: 3s - loss: 0.3306 - accuracy: 0.8516

 131/1000 [==>...........................] - ETA: 3s - loss: 0.3382 - accuracy: 0.8500

 146/1000 [===>..........................] - ETA: 3s - loss: 0.3365 - accuracy: 0.8510

 161/1000 [===>..........................] - ETA: 2s - loss: 0.3378 - accuracy: 0.8509

 176/1000 [====>.........................] - ETA: 2s - loss: 0.3404 - accuracy: 0.8507

 190/1000 [====>.........................] - ETA: 2s - loss: 0.3424 - accuracy: 0.8505

 205/1000 [=====>........................] - ETA: 2s - loss: 0.3420 - accuracy: 0.8505

 220/1000 [=====>........................] - ETA: 2s - loss: 0.3403 - accuracy: 0.8506

Epoch 5/5


   1/1000 [..............................] - ETA: 4s - loss: 0.4340 - accuracy: 0.8750

  16/1000 [..............................] - ETA: 3s - loss: 0.3367 - accuracy: 0.8496

  30/1000 [..............................] - ETA: 3s - loss: 0.3474 - accuracy: 0.8333

  44/1000 [>.............................] - ETA: 3s - loss: 0.3432 - accuracy: 0.8359

  59/1000 [>.............................] - ETA: 3s - loss: 0.3321 - accuracy: 0.8459

  73/1000 [=>............................] - ETA: 3s - loss: 0.3247 - accuracy: 0.8519

  88/1000 [=>............................] - ETA: 3s - loss: 0.3266 - accuracy: 0.8509

 102/1000 [==>...........................] - ETA: 3s - loss: 0.3254 - accuracy: 0.8532

 117/1000 [==>...........................] - ETA: 3s - loss: 0.3246 - accuracy: 0.8555

 131/1000 [==>...........................] - ETA: 3s - loss: 0.3243 - accuracy: 0.8552

 145/1000 [===>..........................] - ETA: 3s - loss: 0.3238 - accuracy: 0.8556

 160/1000 [===>..........................] - ETA: 3s - loss: 0.3250 - accuracy: 0.8551

 175/1000 [====>.........................] - ETA: 2s - loss: 0.3309 - accuracy: 0.8512

 190/1000 [====>.........................] - ETA: 2s - loss: 0.3363 - accuracy: 0.8498

 204/1000 [=====>........................] - ETA: 2s - loss: 0.3341 - accuracy: 0.8493

 219/1000 [=====>........................] - ETA: 2s - loss: 0.3350 - accuracy: 0.8489

<keras.callbacks.History at 0x7f98405d8370>

Evaluating the model on the test data should give a final accuracy of just over 85%, which is not bad for a simple model with no fine-tuning.
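Because the test dataset is batched with `.batch(1)`, each evaluation step covers a single image, and the `accuracy` shown in the progress bar is a running average: correct predictions so far divided by images seen so far. The minimal sketch here reproduces that bookkeeping in plain Python. It assumes the model outputs one probability per image and that the `accuracy` metric counts a prediction as positive when the probability is strictly greater than 0.5 (the default threshold for Keras binary accuracy). The function name `running_binary_accuracy` and the sample inputs are hypothetical, chosen only for illustration.

```python
def running_binary_accuracy(probs, labels, threshold=0.5):
    """Return the cumulative accuracy after each example.

    Mirrors what a progress bar reports when the batch size is 1:
    after step k, the value is (correct predictions in the first k) / k.
    """
    correct = 0
    history = []
    for step, (p, y) in enumerate(zip(probs, labels), start=1):
        pred = 1 if p > threshold else 0  # strict '>' so 0.5 maps to class 0
        correct += int(pred == y)
        history.append(correct / step)
    return history


# Hypothetical model outputs and ground-truth labels for four images.
print(running_binary_accuracy([0.9, 0.2, 0.7, 0.4], [1, 0, 0, 0]))
```

Read this way, the value of 0.8824 after 17 steps in the output is consistent with 15 of the first 17 test images being classified correctly (15/17 ≈ 0.8824).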

In [17]:
print('Overall Results, Unconstrained')
# Batch the test split one example at a time and apply the preprocessing functions defined earlier.
celeb_a_test_data = (celeb_a_builder.as_dataset(split='test').batch(1)
                     .map(preprocess_input_dict).map(get_image_label_and_group))
results = model_unconstrained.evaluate(celeb_a_test_data)

Overall Results, Unconstrained

    1/19962 [..............................] - ETA: 5:56:20 - loss: 0.0000e+00 - accuracy: 1.0000

   17/19962 [..............................] - ETA: 19:07 - loss: 0.1585 - accuracy: 0.8824      

   41/19962 [..............................] - ETA: 8:03 - loss: 0.1812 - accuracy: 0.9024 

   66/19962 [..............................] - ETA: 5:13 - loss: 0.1806 - accuracy: 0.8939

   90/19962 [..............................] - ETA: 3:59 - loss: 0.1738 - accuracy: 0.8889

  114/19962 [..............................] - ETA: 3:17 - loss: 0.2350 - accuracy: 0.8684

  138/19962 [..............................] - ETA: 2:50 - loss: 0.2226 - accuracy: 0.8768

  161/19962 [..............................] - ETA: 2:31 - loss: 0.2099 - accuracy: 0.8696

  184/19962 [..............................] - ETA: 2:17 - loss: 0.1983 - accuracy: 0.8696

  207/19962 [..............................] - ETA: 2:07 - loss: 0.1912 - accuracy: 0.8696

  230/19962 [..............................] - ETA: 1:58 - loss: 0.2057 - accuracy: 0.8652

  253/19962 [..............................] - ETA: 1:51 - loss: 0.2109 - accuracy: 0.8538

  276/19962 [..............................] - ETA: 1:46 - loss: 0.2234 - accuracy: 0.8478

  299/19962 [..............................] - ETA: 1:41 - loss: 0.2281 - accuracy: 0.8462

  322/19962 [..............................] - ETA: 1:36 - loss: 0.2264 - accuracy: 0.8509

  345/19962 [..............................] - ETA: 1:33 - loss: 0.2267 - accuracy: 0.8493

  369/19962 [..............................] - ETA: 1:29 - loss: 0.2238 - accuracy: 0.8509

  392/19962 [..............................] - ETA: 1:26 - loss: 0.2340 - accuracy: 0.8418

  415/19962 [..............................] - ETA: 1:24 - loss: 0.2401 - accuracy: 0.8386

  438/19962 [..............................] - ETA: 1:22 - loss: 0.2343 - accuracy: 0.8379

  461/19962 [..............................] - ETA: 1:20 - loss: 0.2353 - accuracy: 0.8351

  485/19962 [..............................] - ETA: 1:18 - loss: 0.2271 - accuracy: 0.8371

  509/19962 [..............................] - ETA: 1:16 - loss: 0.2399 - accuracy: 0.8369

  533/19962 [..............................] - ETA: 1:14 - loss: 0.2379 - accuracy: 0.8405

  557/19962 [..............................] - ETA: 1:13 - loss: 0.2323 - accuracy: 0.8456

  580/19962 [..............................] - ETA: 1:11 - loss: 0.2365 - accuracy: 0.8448

  604/19962 [..............................] - ETA: 1:10 - loss: 0.2407 - accuracy: 0.8411

  627/19962 [..............................] - ETA: 1:09 - loss: 0.2387 - accuracy: 0.8373

  651/19962 [..............................] - ETA: 1:08 - loss: 0.2346 - accuracy: 0.8372

  675/19962 [>.............................] - ETA: 1:07 - loss: 0.2284 - accuracy: 0.8400

  699/19962 [>.............................] - ETA: 1:06 - loss: 0.2294 - accuracy: 0.8398

  724/19962 [>.............................] - ETA: 1:05 - loss: 0.2317 - accuracy: 0.8384

  748/19962 [>.............................] - ETA: 1:04 - loss: 0.2294 - accuracy: 0.8396

  772/19962 [>.............................] - ETA: 1:03 - loss: 0.2270 - accuracy: 0.8420

  796/19962 [>.............................] - ETA: 1:02 - loss: 0.2274 - accuracy: 0.8430

  820/19962 [>.............................] - ETA: 1:02 - loss: 0.2228 - accuracy: 0.8439

  844/19962 [>.............................] - ETA: 1:01 - loss: 0.2234 - accuracy: 0.8412

  868/19962 [>.............................] - ETA: 1:00 - loss: 0.2232 - accuracy: 0.8422

  892/19962 [>.............................] - ETA: 1:00 - loss: 0.2214 - accuracy: 0.8419

  916/19962 [>.............................] - ETA: 59s - loss: 0.2200 - accuracy: 0.8439 

  941/19962 [>.............................] - ETA: 58s - loss: 0.2208 - accuracy: 0.8438

  965/19962 [>.............................] - ETA: 58s - loss: 0.2228 - accuracy: 0.8446

  990/19962 [>.............................] - ETA: 57s - loss: 0.2202 - accuracy: 0.8455

 1015/19962 [>.............................] - ETA: 57s - loss: 0.2243 - accuracy: 0.8463

 1040/19962 [>.............................] - ETA: 56s - loss: 0.2266 - accuracy: 0.8462

 1064/19962 [>.............................] - ETA: 56s - loss: 0.2260 - accuracy: 0.8477

 1089/19962 [>.............................] - ETA: 55s - loss: 0.2234 - accuracy: 0.8485

 1113/19962 [>.............................] - ETA: 55s - loss: 0.2268 - accuracy: 0.8464

 1138/19962 [>.............................] - ETA: 55s - loss: 0.2243 - accuracy: 0.8480

 1162/19962 [>.............................] - ETA: 54s - loss: 0.2218 - accuracy: 0.8477

 1187/19962 [>.............................] - ETA: 54s - loss: 0.2215 - accuracy: 0.8467

 1212/19962 [>.............................] - ETA: 53s - loss: 0.2215 - accuracy: 0.8482

 1237/19962 [>.............................] - ETA: 53s - loss: 0.2233 - accuracy: 0.8480

 1261/19962 [>.............................] - ETA: 53s - loss: 0.2211 - accuracy: 0.8493

 1286/19962 [>.............................] - ETA: 52s - loss: 0.2194 - accuracy: 0.8499

 1310/19962 [>.............................] - ETA: 52s - loss: 0.2222 - accuracy: 0.8481

 1334/19962 [=>............................] - ETA: 52s - loss: 0.2218 - accuracy: 0.8478

 1358/19962 [=>............................] - ETA: 51s - loss: 0.2217 - accuracy: 0.8476

 1382/19962 [=>............................] - ETA: 51s - loss: 0.2205 - accuracy: 0.8488

 1406/19962 [=>............................] - ETA: 51s - loss: 0.2175 - accuracy: 0.8506

 1430/19962 [=>............................] - ETA: 51s - loss: 0.2182 - accuracy: 0.8510

 1454/19962 [=>............................] - ETA: 50s - loss: 0.2188 - accuracy: 0.8508

 1478/19962 [=>............................] - ETA: 50s - loss: 0.2190 - accuracy: 0.8525

 1503/19962 [=>............................] - ETA: 50s - loss: 0.2170 - accuracy: 0.8550

 1528/19962 [=>............................] - ETA: 49s - loss: 0.2194 - accuracy: 0.8541

 1553/19962 [=>............................] - ETA: 49s - loss: 0.2192 - accuracy: 0.8551

 1577/19962 [=>............................] - ETA: 49s - loss: 0.2193 - accuracy: 0.8554

 1601/19962 [=>............................] - ETA: 49s - loss: 0.2202 - accuracy: 0.8551

 1625/19962 [=>............................] - ETA: 49s - loss: 0.2193 - accuracy: 0.8554

 1650/19962 [=>............................] - ETA: 48s - loss: 0.2187 - accuracy: 0.8552

 1674/19962 [=>............................] - ETA: 48s - loss: 0.2207 - accuracy: 0.8554

 1697/19962 [=>............................] - ETA: 48s - loss: 0.2197 - accuracy: 0.8568

 1721/19962 [=>............................] - ETA: 48s - loss: 0.2221 - accuracy: 0.8565

 1745/19962 [=>............................] - ETA: 48s - loss: 0.2214 - accuracy: 0.8579

 1769/19962 [=>............................] - ETA: 47s - loss: 0.2201 - accuracy: 0.8587

 1793/19962 [=>............................] - ETA: 47s - loss: 0.2174 - accuracy: 0.8600

 1817/19962 [=>............................] - ETA: 47s - loss: 0.2146 - accuracy: 0.8613

 1840/19962 [=>............................] - ETA: 47s - loss: 0.2140 - accuracy: 0.8620

 1863/19962 [=>............................] - ETA: 47s - loss: 0.2131 - accuracy: 0.8615

 1887/19962 [=>............................] - ETA: 47s - loss: 0.2142 - accuracy: 0.8617

 1910/19962 [=>............................] - ETA: 46s - loss: 0.2147 - accuracy: 0.8602

 1934/19962 [=>............................] - ETA: 46s - loss: 0.2130 - accuracy: 0.8604

 1958/19962 [=>............................] - ETA: 46s - loss: 0.2123 - accuracy: 0.8611

 1982/19962 [=>............................] - ETA: 46s - loss: 0.2146 - accuracy: 0.8607

 2006/19962 [==>...........................] - ETA: 46s - loss: 0.2145 - accuracy: 0.8614

 2030/19962 [==>...........................] - ETA: 46s - loss: 0.2148 - accuracy: 0.8611

 2053/19962 [==>...........................] - ETA: 45s - loss: 0.2133 - accuracy: 0.8612

 2077/19962 [==>...........................] - ETA: 45s - loss: 0.2132 - accuracy: 0.8599

 2100/19962 [==>...........................] - ETA: 45s - loss: 0.2154 - accuracy: 0.8590

 2123/19962 [==>...........................] - ETA: 45s - loss: 0.2161 - accuracy: 0.8592

 2147/19962 [==>...........................] - ETA: 45s - loss: 0.2185 - accuracy: 0.8570

 2170/19962 [==>...........................] - ETA: 45s - loss: 0.2175 - accuracy: 0.8576

 2194/19962 [==>...........................] - ETA: 45s - loss: 0.2183 - accuracy: 0.8573

 2218/19962 [==>...........................] - ETA: 44s - loss: 0.2174 - accuracy: 0.8580

 2242/19962 [==>...........................] - ETA: 44s - loss: 0.2171 - accuracy: 0.8577

 2265/19962 [==>...........................] - ETA: 44s - loss: 0.2155 - accuracy: 0.8587

 2287/19962 [==>...........................] - ETA: 44s - loss: 0.2149 - accuracy: 0.8588

 2310/19962 [==>...........................] - ETA: 44s - loss: 0.2133 - accuracy: 0.8597

 2333/19962 [==>...........................] - ETA: 44s - loss: 0.2128 - accuracy: 0.8603

 2355/19962 [==>...........................] - ETA: 44s - loss: 0.2128 - accuracy: 0.8599

 2377/19962 [==>...........................] - ETA: 44s - loss: 0.2128 - accuracy: 0.8595

 2400/19962 [==>...........................] - ETA: 44s - loss: 0.2128 - accuracy: 0.8600

 2423/19962 [==>...........................] - ETA: 44s - loss: 0.2132 - accuracy: 0.8597

 2447/19962 [==>...........................] - ETA: 43s - loss: 0.2138 - accuracy: 0.8590

 2471/19962 [==>...........................] - ETA: 43s - loss: 0.2153 - accuracy: 0.8580

 2494/19962 [==>...........................] - ETA: 43s - loss: 0.2138 - accuracy: 0.8581

 2519/19962 [==>...........................] - ETA: 43s - loss: 0.2147 - accuracy: 0.8583

 2543/19962 [==>...........................] - ETA: 43s - loss: 0.2129 - accuracy: 0.8592

 2567/19962 [==>...........................] - ETA: 43s - loss: 0.2141 - accuracy: 0.8586

 2591/19962 [==>...........................] - ETA: 43s - loss: 0.2142 - accuracy: 0.8587

 2616/19962 [==>...........................] - ETA: 43s - loss: 0.2148 - accuracy: 0.8586

 2640/19962 [==>...........................] - ETA: 42s - loss: 0.2158 - accuracy: 0.8583

 2664/19962 [===>..........................] - ETA: 42s - loss: 0.2149 - accuracy: 0.8596

 2688/19962 [===>..........................] - ETA: 42s - loss: 0.2148 - accuracy: 0.8597

 2712/19962 [===>..........................] - ETA: 42s - loss: 0.2156 - accuracy: 0.8595

 2735/19962 [===>..........................] - ETA: 42s - loss: 0.2142 - accuracy: 0.8596

 2758/19962 [===>..........................] - ETA: 42s - loss: 0.2151 - accuracy: 0.8593

 2782/19962 [===>..........................] - ETA: 42s - loss: 0.2138 - accuracy: 0.8602

 2806/19962 [===>..........................] - ETA: 42s - loss: 0.2142 - accuracy: 0.8599

 2830/19962 [===>..........................] - ETA: 42s - loss: 0.2137 - accuracy: 0.8594

 2854/19962 [===>..........................] - ETA: 41s - loss: 0.2150 - accuracy: 0.8591

 2877/19962 [===>..........................] - ETA: 41s - loss: 0.2149 - accuracy: 0.8592

 2901/19962 [===>..........................] - ETA: 41s - loss: 0.2144 - accuracy: 0.8590

 2925/19962 [===>..........................] - ETA: 41s - loss: 0.2150 - accuracy: 0.8588

 2949/19962 [===>..........................] - ETA: 41s - loss: 0.2155 - accuracy: 0.8583

 2973/19962 [===>..........................] - ETA: 41s - loss: 0.2156 - accuracy: 0.8587

 2997/19962 [===>..........................] - ETA: 41s - loss: 0.2153 - accuracy: 0.8589

 3021/19962 [===>..........................] - ETA: 41s - loss: 0.2141 - accuracy: 0.8587

 3046/19962 [===>..........................] - ETA: 41s - loss: 0.2138 - accuracy: 0.8588

 3070/19962 [===>..........................] - ETA: 41s - loss: 0.2140 - accuracy: 0.8586

 3095/19962 [===>..........................] - ETA: 40s - loss: 0.2149 - accuracy: 0.8585

 3119/19962 [===>..........................] - ETA: 40s - loss: 0.2147 - accuracy: 0.8583

 3143/19962 [===>..........................] - ETA: 40s - loss: 0.2142 - accuracy: 0.8587

 3168/19962 [===>..........................] - ETA: 40s - loss: 0.2153 - accuracy: 0.8580

 3193/19962 [===>..........................] - ETA: 40s - loss: 0.2151 - accuracy: 0.8578

 3218/19962 [===>..........................] - ETA: 40s - loss: 0.2156 - accuracy: 0.8583

 3242/19962 [===>..........................] - ETA: 40s - loss: 0.2151 - accuracy: 0.8587

 3266/19962 [===>..........................] - ETA: 40s - loss: 0.2148 - accuracy: 0.8588

 3290/19962 [===>..........................] - ETA: 40s - loss: 0.2150 - accuracy: 0.8593

 3315/19962 [===>..........................] - ETA: 40s - loss: 0.2158 - accuracy: 0.8588

 3340/19962 [====>.........................] - ETA: 39s - loss: 0.2148 - accuracy: 0.8587

 3364/19962 [====>.........................] - ETA: 39s - loss: 0.2170 - accuracy: 0.8582

 3388/19962 [====>.........................] - ETA: 39s - loss: 0.2159 - accuracy: 0.8589

 3413/19962 [====>.........................] - ETA: 39s - loss: 0.2154 - accuracy: 0.8588

 3437/19962 [====>.........................] - ETA: 39s - loss: 0.2142 - accuracy: 0.8583

 3461/19962 [====>.........................] - ETA: 39s - loss: 0.2149 - accuracy: 0.8581

 3486/19962 [====>.........................] - ETA: 39s - loss: 0.2148 - accuracy: 0.8580

 3510/19962 [====>.........................] - ETA: 39s - loss: 0.2148 - accuracy: 0.8590

 3534/19962 [====>.........................] - ETA: 39s - loss: 0.2142 - accuracy: 0.8596

 3558/19962 [====>.........................] - ETA: 39s - loss: 0.2139 - accuracy: 0.8595

 3582/19962 [====>.........................] - ETA: 39s - loss: 0.2138 - accuracy: 0.8590

 3606/19962 [====>.........................] - ETA: 38s - loss: 0.2134 - accuracy: 0.8594

 3631/19962 [====>.........................] - ETA: 38s - loss: 0.2130 - accuracy: 0.8595

 3655/19962 [====>.........................] - ETA: 38s - loss: 0.2125 - accuracy: 0.8594

 3679/19962 [====>.........................] - ETA: 38s - loss: 0.2115 - accuracy: 0.8597

 3703/19962 [====>.........................] - ETA: 38s - loss: 0.2120 - accuracy: 0.8598

 3728/19962 [====>.........................] - ETA: 38s - loss: 0.2119 - accuracy: 0.8600

 3752/19962 [====>.........................] - ETA: 38s - loss: 0.2116 - accuracy: 0.8606

 3776/19962 [====>.........................] - ETA: 38s - loss: 0.2117 - accuracy: 0.8604

 3800/19962 [====>.........................] - ETA: 38s - loss: 0.2117 - accuracy: 0.8600

 3824/19962 [====>.........................] - ETA: 38s - loss: 0.2118 - accuracy: 0.8601

 3848/19962 [====>.........................] - ETA: 38s - loss: 0.2116 - accuracy: 0.8604

 3873/19962 [====>.........................] - ETA: 37s - loss: 0.2126 - accuracy: 0.8598

 3897/19962 [====>.........................] - ETA: 37s - loss: 0.2126 - accuracy: 0.8596

 3921/19962 [====>.........................] - ETA: 37s - loss: 0.2134 - accuracy: 0.8592

 3946/19962 [====>.........................] - ETA: 37s - loss: 0.2132 - accuracy: 0.8594

 3971/19962 [====>.........................] - ETA: 37s - loss: 0.2135 - accuracy: 0.8590

 3995/19962 [=====>........................] - ETA: 37s - loss: 0.2132 - accuracy: 0.8591

 4020/19962 [=====>........................] - ETA: 37s - loss: 0.2133 - accuracy: 0.8592

 4044/19962 [=====>........................] - ETA: 37s - loss: 0.2123 - accuracy: 0.8600

 4068/19962 [=====>........................] - ETA: 37s - loss: 0.2125 - accuracy: 0.8604

 4092/19962 [=====>........................] - ETA: 37s - loss: 0.2128 - accuracy: 0.8605

 4116/19962 [=====>........................] - ETA: 37s - loss: 0.2128 - accuracy: 0.8603

 4140/19962 [=====>........................] - ETA: 37s - loss: 0.2133 - accuracy: 0.8604

 4164/19962 [=====>........................] - ETA: 37s - loss: 0.2132 - accuracy: 0.8607

 4188/19962 [=====>........................] - ETA: 36s - loss: 0.2130 - accuracy: 0.8603

 4212/19962 [=====>........................] - ETA: 36s - loss: 0.2123 - accuracy: 0.8606

 4237/19962 [=====>........................] - ETA: 36s - loss: 0.2127 - accuracy: 0.8605

 4261/19962 [=====>........................] - ETA: 36s - loss: 0.2137 - accuracy: 0.8601

 4286/19962 [=====>........................] - ETA: 36s - loss: 0.2131 - accuracy: 0.8600

 4310/19962 [=====>........................] - ETA: 36s - loss: 0.2139 - accuracy: 0.8594

 4335/19962 [=====>........................] - ETA: 36s - loss: 0.2140 - accuracy: 0.8597

 4360/19962 [=====>........................] - ETA: 36s - loss: 0.2142 - accuracy: 0.8596

 4385/19962 [=====>........................] - ETA: 36s - loss: 0.2155 - accuracy: 0.8593

 4410/19962 [=====>........................] - ETA: 36s - loss: 0.2145 - accuracy: 0.8592

 4434/19962 [=====>........................] - ETA: 36s - loss: 0.2153 - accuracy: 0.8593

 4459/19962 [=====>........................] - ETA: 36s - loss: 0.2158 - accuracy: 0.8594

 4484/19962 [=====>........................] - ETA: 35s - loss: 0.2154 - accuracy: 0.8595

 4509/19962 [=====>........................] - ETA: 35s - loss: 0.2152 - accuracy: 0.8596

 4533/19962 [=====>........................] - ETA: 35s - loss: 0.2149 - accuracy: 0.8590

 4558/19962 [=====>........................] - ETA: 35s - loss: 0.2144 - accuracy: 0.8594

 4583/19962 [=====>........................] - ETA: 35s - loss: 0.2136 - accuracy: 0.8601

 4608/19962 [=====>........................] - ETA: 35s - loss: 0.2130 - accuracy: 0.8607

 4633/19962 [=====>........................] - ETA: 35s - loss: 0.2129 - accuracy: 0.8608

 4657/19962 [=====>........................] - ETA: 35s - loss: 0.2131 - accuracy: 0.8609



























































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































However, performance evaluated across age groups may reveal some shortcomings.

To explore this further, we evaluate the model with Fairness Indicators (via TFMA). In particular, we are interested in seeing whether there is a significant gap in performance between "Young" and "Not Young" categories when evaluated on false positive rate.

A false positive error occurs when the model incorrectly predicts the positive class. In this context, a false positive occurs when the ground truth is an image of a celebrity 'Not Smiling' and the model predicts 'Smiling'. By extension, the false positive rate, which is the metric examined in the visualization below, is the fraction of 'Not Smiling' images that the model predicts as 'Smiling'. While this is a relatively mundane error to make in this context, false positive errors can sometimes cause more problematic behaviors. For instance, a false positive error in a spam classifier could cause a user to miss an important email.
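To make the metric concrete, the sketch below computes the false positive rate, FP / (FP + TN), both overall and within each group using NumPy. It is illustrative only. The function names, the toy arrays, the 0.5 threshold, and the 0/1 group encoding (0 for "Not Young", 1 for "Young") are assumptions. In the notebook itself, Fairness Indicators computes these per-slice metrics through TFMA.

```python
import numpy as np


def false_positive_rate(labels, scores, threshold=0.5):
    """Fraction of true negatives (label 0) that are predicted positive."""
    labels = np.asarray(labels)
    predicted_positive = np.asarray(scores) >= threshold
    negatives = labels == 0
    if not negatives.any():
        return float("nan")  # FPR is undefined when there are no negatives.
    return float(np.mean(predicted_positive[negatives]))


def fpr_by_group(labels, scores, groups, threshold=0.5):
    """FPR overall and within each group value (assumed 0 = "Not Young", 1 = "Young")."""
    labels, scores, groups = map(np.asarray, (labels, scores, groups))
    result = {"overall": false_positive_rate(labels, scores, threshold)}
    for g in np.unique(groups):
        mask = groups == g
        result[int(g)] = false_positive_rate(labels[mask], scores[mask], threshold)
    return result


# Toy data: 1 = "Smiling", 0 = "Not Smiling" (hypothetical values).
labels = [0, 0, 0, 0, 1, 0, 0, 0, 1, 1]
scores = [0.9, 0.7, 0.2, 0.1, 0.8, 0.6, 0.3, 0.1, 0.4, 0.9]
groups = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
print(fpr_by_group(labels, scores, groups))  # overall 3/7, group 0: 2/4, group 1: 1/3
```

With these toy values, the two groups have different false positive rates (0.5 versus about 0.33) even though the overall rate is a single number. This kind of per-slice gap is what the Fairness Indicators view is designed to surface.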

In [18]:
model_location = save_model(model_unconstrained, 'model_export_unconstrained')
eval_results_unconstrained = get_eval_results(model_location, 'eval_results_unconstrained')

INFO:tensorflow:Assets written to: /tmpfs/tmp/saved_modelszogx3zsm/model_export_unconstrained/assets


Instructions for updating:
Use eager execution and: 
`tf.data.TFRecordDataset(path)`



As mentioned above, we are concentrating on the false positive rate. The current version of Fairness Indicators (0.1.2) selects false negative rate by default. After running the line below, deselect false_negative_rate and select false_positive_rate to look at the metric we are interested in.

In [19]:
tfma.addons.fairness.view.widget_view.render_fairness_indicator(eval_results_unconstrained)


As the results show above, we do see a **disproportionate gap between "Young" and "Not Young" categories**.

This is where TFCO can help, by constraining the false positive rate to stay within an acceptable bound.


# Constrained Model Setup
As documented in [TFCO's library](https://github.com/google-research/tensorflow_constrained_optimization/blob/master/README.md), there are several helpers that will make it easier to constrain the problem:

1.   `tfco.rate_context()` – Creates a context from the model's predictions and labels. Calling `subset()` on this context restricts it to a subset of examples (here, the "Not Young" age group), and that subset is used to construct the constraint.
2.   `tfco.RateMinimizationProblem()` – Defines the problem to solve. Here, the objective to minimize is the overall error rate, subject to a constraint on the false positive rate of the "Not Young" age group. For this demonstration, that group's false positive rate must be less than or equal to 5%.
3.   `tfco.ProxyLagrangianOptimizerV2()` – This is the helper that will actually solve the rate constraint problem.
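TFCO's proxy-Lagrangian algorithm is more elaborate, but the basic idea behind Lagrangian-style constrained training can be seen on a one-dimensional toy problem: minimize f(x) = (x - 3)^2 subject to g(x) = x - 1 <= 0. The sketch below runs plain gradient descent on x and projected gradient ascent on a multiplier lam >= 0. The toy problem, step size, and iteration count are arbitrary choices for illustration; this is not TFCO's implementation.

```python
def solve_toy_constrained_problem(learning_rate=0.05, num_steps=2000):
    """Gradient descent-ascent on the Lagrangian L(x, lam) = f(x) + lam * g(x).

    Hypothetical toy problem, for illustration only:
        minimize   f(x) = (x - 3)**2
        subject to g(x) = x - 1 <= 0
    The constrained optimum is x = 1 with multiplier lam = 4.
    """
    x, lam = 0.0, 0.0
    for _ in range(num_steps):
        grad_x = 2.0 * (x - 3.0) + lam  # dL/dx
        violation = x - 1.0             # dL/dlam = g(x); positive means violated
        # Simultaneous update: descend in x, ascend in lam, keep lam >= 0.
        x, lam = (x - learning_rate * grad_x,
                  max(0.0, lam + learning_rate * violation))
    return x, lam


x, lam = solve_toy_constrained_problem()
print(f"x = {x:.4f}, lambda = {lam:.4f}")  # approaches x = 1, lambda = 4
```

The multiplier grows while the constraint is violated, which pushes x back toward the feasible region; at the solution the pull of the objective and the penalty balance out.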

The cell below will call on these helpers to set up model training with the fairness constraint.




In [20]:
# The batch size is needed to create the input, labels and group tensors.
# These tensors are initialized with all 0's and are later assigned the
# contents of each batch. The batch size should be large enough that each
# batch contains enough "Young" and "Not Young" examples.
set_seeds()
model_constrained = create_model()
BATCH_SIZE = 32

# Create input tensor.
input_tensor = tf.Variable(
    np.zeros((BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3), dtype="float32"),
    name="input")

# Create labels and group tensors (assuming both labels and groups are binary).
labels_tensor = tf.Variable(
    np.zeros(BATCH_SIZE, dtype="float32"), name="labels")
groups_tensor = tf.Variable(
    np.zeros(BATCH_SIZE, dtype="float32"), name="groups")

# Create a function that applies the model to the current input tensor
# and returns the constrained model's predictions.
def predictions():
  return model_constrained(input_tensor)

# Create overall context and subsetted context.
# The subsetted context contains the subset of examples where the group
# attribute is < 1 (i.e. the "Not Young" celebrity images).
# "groups_tensor < 1" is used instead of "groups_tensor == 0" as the former
# would be a comparison on the tensor value, while the latter would be a
# comparison on the Tensor object.
context = tfco.rate_context(predictions, labels=lambda: labels_tensor)
context_subset = context.subset(lambda: groups_tensor < 1)

# Set up the list of constraints.
# In this notebook, there is a single constraint: the false positive rate
# on the "Not Young" subset must be at most 5%.
constraints = [tfco.false_positive_rate(context_subset) <= 0.05]

# Set up the rate minimization problem: minimize the overall error rate
# subject to the constraints.
problem = tfco.RateMinimizationProblem(tfco.error_rate(context), constraints)

# Create the constrained optimizer.
# Separate optimizers are specified for the objective and the constraints.
optimizer = tfco.ProxyLagrangianOptimizerV2(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    constraint_optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    num_constraints=problem.num_constraints)

# A list of all trainable variables is also needed to use TFCO.
var_list = (model_constrained.trainable_weights + list(problem.trainable_variables) +
            optimizer.trainable_variables())

The model is now set up and ready to be trained with the false positive rate constraint on the "Not Young" age group.

Because the last iterate of the constrained model is not necessarily the best one in terms of the defined constraint, the TFCO library provides `tfco.find_best_candidate_index()`, which helps choose the best iterate among the snapshots saved during training (here, one every `SKIP_ITERATIONS` iterations). Think of `tfco.find_best_candidate_index()` as an added heuristic that ranks the candidates separately by the objective (overall error) and by constraint violation (in this case, the false positive rate on the "Not Young" group), using the values recorded on training batches. That way, it can search for a better trade-off between overall accuracy and the fairness constraint.
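The ranking idea described above can be sketched in NumPy. This is a simplified, hypothetical re-implementation for intuition only. `tfco.find_best_candidate_index()` has its own options and may differ in details such as how ties are broken, so the library function is what the notebook actually uses. The input shapes mirror the notebook's `objective_list` and `violations_list`.

```python
import numpy as np


def pick_best_candidate(objectives, violations):
    """Simplified rank-based choice among saved iterates (illustrative only).

    objectives: shape (num_candidates,), objective value of each snapshot.
    violations: shape (num_candidates, num_constraints); a value <= 0 means
        that constraint is satisfied.
    Each candidate is ranked separately by objective and by its largest
    violation; the candidate whose worse rank is smallest wins, with ties
    broken by smaller violation, then smaller objective.
    """
    objectives = np.asarray(objectives, dtype=float)
    # All satisfied constraints count as zero violation.
    max_violation = np.maximum(np.asarray(violations, dtype=float).max(axis=1), 0.0)

    def rank(values):
        # Rank = number of candidates with a strictly smaller value,
        # so tied candidates share the same rank.
        return (values[None, :] < values[:, None]).sum(axis=1)

    worst_rank = np.maximum(rank(objectives), rank(max_violation))
    # np.lexsort uses the last key as the primary sort key.
    order = np.lexsort((objectives, max_violation, worst_rank))
    return int(order[0])


objectives = [0.9, 0.5, 0.7, 0.6]
violations = [[-0.05], [0.6], [-0.05], [0.3]]
print(pick_best_candidate(objectives, violations))  # 2
```

Here candidate 1 has the lowest objective but the largest violation, so it is not chosen. Candidates 2 and 3 tie on their worse rank, and candidate 2 wins because it satisfies the constraint.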

The following cell trains the model with the constraint, saves a snapshot of the weights every `SKIP_ITERATIONS` iterations, and then loads the best snapshot.

In [21]:
import shutil

# Train with the rate constraint, saving a snapshot of the weights every
# SKIP_ITERATIONS iterations so that the best iterate can be chosen afterwards.

NUM_ITERATIONS = 100  # Number of training iterations.
SKIP_ITERATIONS = 10  # Record stats and save a snapshot once in this many iterations.

# Create a temporary directory for saving snapshots of models.
temp_directory = tempfile.mkdtemp()

# Lists of objective values and constraint violations across snapshots.
objective_list = []
violations_list = []

# Training iterations.
iteration_count = 0
for (image, label, group) in celeb_a_train_data_w_group(BATCH_SIZE):
  # Assign current batch to input, labels and groups tensors.
  input_tensor.assign(image)
  labels_tensor.assign(label)
  groups_tensor.assign(group)

  # Run gradient update.
  optimizer.minimize(problem, var_list=var_list)

  # Record objective and violations.
  objective = problem.objective()
  violations = problem.constraints()

  sys.stdout.write(
      "\r Iteration %d: Hinge Loss = %.3f, Max. Constraint Violation = %.3f"
      % (iteration_count + 1, objective, max(violations)))

  # Snapshot model once in SKIP_ITERATIONS iterations.
  if iteration_count % SKIP_ITERATIONS == 0:
    objective_list.append(objective)
    violations_list.append(violations)

    # Save snapshot of model weights, indexed by snapshot number (0, 1, 2, ...)
    # so that it lines up with the index returned by find_best_candidate_index.
    snapshot_index = iteration_count // SKIP_ITERATIONS
    model_constrained.save_weights(
        os.path.join(temp_directory, "celeb_a_constrained_%d.h5" % snapshot_index))

  iteration_count += 1
  if iteration_count >= NUM_ITERATIONS:
    break

# Choose best model from recorded iterates and load that model.
best_index = tfco.find_best_candidate_index(
    np.array(objective_list), np.array(violations_list))

model_constrained.load_weights(
    os.path.join(temp_directory, "celeb_a_constrained_%d.h5" % best_index))

# Remove temp directory.
shutil.rmtree(temp_directory)

 Iteration 1: Hinge Loss = 1.653, Max. Constraint Violation = 0.950 Iteration 2: Hinge Loss = 1.263, Max. Constraint Violation = -0.050 Iteration 3: Hinge Loss = 0.970, Max. Constraint Violation = 0.807

 Iteration 4: Hinge Loss = 1.054, Max. Constraint Violation = 1.150 Iteration 5: Hinge Loss = 1.030, Max. Constraint Violation = 1.825 Iteration 6: Hinge Loss = 0.815, Max. Constraint Violation = 0.266

 Iteration 7: Hinge Loss = 1.170, Max. Constraint Violation = -0.050 Iteration 8: Hinge Loss = 1.010, Max. Constraint Violation = -0.050 Iteration 9: Hinge Loss = 0.795, Max. Constraint Violation = -0.050

 Iteration 10: Hinge Loss = 1.007, Max. Constraint Violation = -0.050 Iteration 11: Hinge Loss = 0.907, Max. Constraint Violation = -0.050 Iteration 12: Hinge Loss = 0.990, Max. Constraint Violation = 0.283

 Iteration 13: Hinge Loss = 0.915, Max. Constraint Violation = 0.634 Iteration 14: Hinge Loss = 0.817, Max. Constraint Violation = 1.283 Iteration 15: Hinge Loss = 0.849, Max. Constraint Violation = 0.588

 Iteration 16: Hinge Loss = 0.852, Max. Constraint Violation = -0.050 Iteration 17: Hinge Loss = 1.063, Max. Constraint Violation = -0.050 Iteration 18: Hinge Loss = 0.947, Max. Constraint Violation = -0.050

 Iteration 19: Hinge Loss = 0.961, Max. Constraint Violation = -0.050 Iteration 20: Hinge Loss = 0.794, Max. Constraint Violation = -0.050 Iteration 21: Hinge Loss = 0.910, Max. Constraint Violation = 0.289

 Iteration 22: Hinge Loss = 0.846, Max. Constraint Violation = 0.637 Iteration 23: Hinge Loss = 0.853, Max. Constraint Violation = 0.647 Iteration 24: Hinge Loss = 0.683, Max. Constraint Violation = 0.656

 Iteration 25: Hinge Loss = 0.797, Max. Constraint Violation = 0.635 Iteration 26: Hinge Loss = 0.873, Max. Constraint Violation = 0.301 Iteration 27: Hinge Loss = 0.868, Max. Constraint Violation = 0.310

 Iteration 28: Hinge Loss = 0.635, Max. Constraint Violation = -0.050 Iteration 29: Hinge Loss = 0.818, Max. Constraint Violation = 1.051 Iteration 30: Hinge Loss = 0.541, Max. Constraint Violation = -0.050

 Iteration 31: Hinge Loss = 0.954, Max. Constraint Violation = 1.057 Iteration 32: Hinge Loss = 0.838, Max. Constraint Violation = 0.686 Iteration 33: Hinge Loss = 0.802, Max. Constraint Violation = 0.692

 Iteration 34: Hinge Loss = 0.686, Max. Constraint Violation = 1.047 Iteration 35: Hinge Loss = 0.641, Max. Constraint Violation = -0.050 Iteration 36: Hinge Loss = 0.776, Max. Constraint Violation = 0.677

 Iteration 37: Hinge Loss = 0.624, Max. Constraint Violation = -0.050 Iteration 38: Hinge Loss = 0.755, Max. Constraint Violation = -0.050 Iteration 39: Hinge Loss = 0.790, Max. Constraint Violation = 0.301

 Iteration 40: Hinge Loss = 0.746, Max. Constraint Violation = 0.295 Iteration 41: Hinge Loss = 0.673, Max. Constraint Violation = 0.639 Iteration 42: Hinge Loss = 0.815, Max. Constraint Violation = 0.644

 Iteration 43: Hinge Loss = 0.704, Max. Constraint Violation = 0.644 Iteration 44: Hinge Loss = 0.692, Max. Constraint Violation = 0.989 Iteration 45: Hinge Loss = 0.649, Max. Constraint Violation = 0.299

 Iteration 46: Hinge Loss = 0.755, Max. Constraint Violation = 0.304 Iteration 47: Hinge Loss = 0.553, Max. Constraint Violation = -0.050 Iteration 48: Hinge Loss = 0.670, Max. Constraint Violation = -0.050

 Iteration 49: Hinge Loss = 0.935, Max. Constraint Violation = -0.050 Iteration 50: Hinge Loss = 1.294, Max. Constraint Violation = -0.050 Iteration 51: Hinge Loss = 0.895, Max. Constraint Violation = -0.050

 Iteration 52: Hinge Loss = 0.613, Max. Constraint Violation = 0.643 Iteration 53: Hinge Loss = 0.596, Max. Constraint Violation = 0.638 Iteration 54: Hinge Loss = 0.889, Max. Constraint Violation = 0.642

 Iteration 55: Hinge Loss = 0.832, Max. Constraint Violation = 0.646 Iteration 56: Hinge Loss = 0.659, Max. Constraint Violation = 0.298 Iteration 57: Hinge Loss = 0.689, Max. Constraint Violation = 0.645

 Iteration 58: Hinge Loss = 0.640, Max. Constraint Violation = 0.289 Iteration 59: Hinge Loss = 0.906, Max. Constraint Violation = -0.050 Iteration 60: Hinge Loss = 1.144, Max. Constraint Violation = -0.050

 Iteration 61: Hinge Loss = 0.771, Max. Constraint Violation = -0.050 Iteration 62: Hinge Loss = 0.792, Max. Constraint Violation = -0.050 Iteration 63: Hinge Loss = 0.795, Max. Constraint Violation = -0.050

 Iteration 64: Hinge Loss = 0.827, Max. Constraint Violation = 0.283 Iteration 65: Hinge Loss = 0.763, Max. Constraint Violation = 0.620 Iteration 66: Hinge Loss = 0.725, Max. Constraint Violation = 0.288

 Iteration 67: Hinge Loss = 0.719, Max. Constraint Violation = 0.950 Iteration 68: Hinge Loss = 0.755, Max. Constraint Violation = 0.282 Iteration 69: Hinge Loss = 0.905, Max. Constraint Violation = 0.285

 Iteration 70: Hinge Loss = 0.656, Max. Constraint Violation = -0.050 Iteration 71: Hinge Loss = 0.755, Max. Constraint Violation = -0.050 Iteration 72: Hinge Loss = 0.668, Max. Constraint Violation = 0.285

 Iteration 73: Hinge Loss = 0.798, Max. Constraint Violation = 0.283 Iteration 74: Hinge Loss = 0.801, Max. Constraint Violation = -0.050 Iteration 75: Hinge Loss = 0.777, Max. Constraint Violation = -0.050

 Iteration 76: Hinge Loss = 0.734, Max. Constraint Violation = 0.632 Iteration 77: Hinge Loss = 0.687, Max. Constraint Violation = 0.634 Iteration 78: Hinge Loss = 0.875, Max. Constraint Violation = 1.312

 Iteration 79: Hinge Loss = 0.788, Max. Constraint Violation = 1.300 Iteration 80: Hinge Loss = 0.750, Max. Constraint Violation = -0.050 Iteration 81: Hinge Loss = 0.531, Max. Constraint Violation = -0.050

 Iteration 82: Hinge Loss = 0.588, Max. Constraint Violation = -0.050 Iteration 83: Hinge Loss = 0.577, Max. Constraint Violation = 0.283 Iteration 84: Hinge Loss = 0.591, Max. Constraint Violation = 0.283

 Iteration 85: Hinge Loss = 0.588, Max. Constraint Violation = -0.050 Iteration 86: Hinge Loss = 0.587, Max. Constraint Violation = 0.289 Iteration 87: Hinge Loss = 0.733, Max. Constraint Violation = 0.617

 Iteration 88: Hinge Loss = 0.735, Max. Constraint Violation = 0.283 Iteration 89: Hinge Loss = 0.670, Max. Constraint Violation = -0.050 Iteration 90: Hinge Loss = 0.618, Max. Constraint Violation = -0.050

 Iteration 91: Hinge Loss = 0.720, Max. Constraint Violation = 0.936 Iteration 92: Hinge Loss = 0.592, Max. Constraint Violation = -0.050 Iteration 93: Hinge Loss = 0.468, Max. Constraint Violation = -0.050

 Iteration 94: Hinge Loss = 0.617, Max. Constraint Violation = 0.605 Iteration 95: Hinge Loss = 0.575, Max. Constraint Violation = -0.050 Iteration 96: Hinge Loss = 0.597, Max. Constraint Violation = -0.050

 Iteration 97: Hinge Loss = 0.792, Max. Constraint Violation = 0.270 Iteration 98: Hinge Loss = 0.556, Max. Constraint Violation = 0.267 Iteration 99: Hinge Loss = 0.596, Max. Constraint Violation = -0.050

 Iteration 100: Hinge Loss = 0.614, Max. Constraint Violation = 0.268


After having applied the constraint, we evaluate the results once again using Fairness Indicators.

In [22]:
model_location = save_model(model_constrained, 'model_export_constrained')
eval_result_constrained = get_eval_results(model_location, 'eval_results_constrained')

INFO:tensorflow:Assets written to: /tmpfs/tmp/saved_modelses71pd9f/model_export_constrained/assets


As before, deselect false_negative_rate and select false_positive_rate to view the metric we are interested in.

Note that to compare the two versions of our model fairly, it is important to use thresholds at which their overall false positive rates are roughly equal. This ensures that we are looking at a real change in the model rather than a shift that could be reproduced simply by moving the decision threshold. In our case, comparing the unconstrained model at a threshold of 0.5 with the constrained model at 0.22 provides a fair comparison.
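The notebook compares the two models at thresholds (0.5 and 0.22) that give roughly equal overall false positive rates. One hypothetical way to find such a threshold from raw scores is to take the (1 - target) quantile of the scores on negative examples. The helper names and the synthetic scores below are assumptions for illustration. Real scores would come from the exported models' predictions, and this sketch counts a score strictly above the threshold as a positive prediction.

```python
import numpy as np


def threshold_for_target_fpr(labels, scores, target_fpr):
    """Pick a threshold at which roughly `target_fpr` of negatives score above it."""
    labels = np.asarray(labels)
    negative_scores = np.asarray(scores, dtype=float)[labels == 0]
    # The (1 - target_fpr) quantile of negative scores leaves about a
    # target_fpr fraction of negatives strictly above it.
    return float(np.quantile(negative_scores, 1.0 - target_fpr))


def fpr_at_threshold(labels, scores, threshold):
    """Fraction of negatives whose score is strictly above the threshold."""
    labels = np.asarray(labels)
    negative_scores = np.asarray(scores, dtype=float)[labels == 0]
    return float(np.mean(negative_scores > threshold))


# Synthetic scores standing in for two models' outputs on the same data.
rng = np.random.default_rng(0)
labels = rng.integers(0, 2, size=10_000)
scores_a = rng.normal(loc=labels, scale=1.0)
scores_b = rng.normal(loc=labels - 0.5, scale=1.0)  # systematically lower scores

target = 0.05
for name, scores in [("model_a", scores_a), ("model_b", scores_b)]:
    t = threshold_for_target_fpr(labels, scores, target)
    print(name, round(t, 3), fpr_at_threshold(labels, scores, t))
```

Because the second synthetic model scores systematically lower, it needs a lower threshold to reach the same overall false positive rate. This is the same kind of adjustment as reading the constrained model at 0.22 instead of 0.5.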

In [23]:
eval_results_dict = {
    'constrained': eval_result_constrained,
    'unconstrained': eval_results_unconstrained,
}
tfma.addons.fairness.view.widget_view.render_fairness_indicator(multi_eval_results=eval_results_dict)


With TFCO's ability to express a more complex requirement as a rate constraint, we helped this model achieve a more desirable outcome with little impact on overall performance. There is, of course, still room for improvement, but TFCO was able to find a model that comes close to satisfying the constraint and reduces the disparity between the groups.