In [1]:
from __future__ import absolute_import, division, print_function
import numpy as np
import os
import tensorflow.compat.v1 as tf
import tensorflow.compat.v2 as tf2
from scipy import linalg
from tensorflow.compat.v2.keras.utils import get_file
import tarfile
import tensorflow_hub as tfhub
import six
#from train_utils import num_device
import tensorflow_gan as tfgan
from PIL import Image




# FID & IS

In [2]:
INCEPTION_TFHUB = 'D:/gpu-torch/improved replay buffer for EBM/inception model'
INCEPTION_OUTPUT = 'logits'
INCEPTION_FINAL_POOL = 'pool_3'
_DEFAULT_DTYPES = {
    INCEPTION_OUTPUT: tf2.float32,
    INCEPTION_FINAL_POOL: tf2.float32
}
def get_inception_model():
    return tfhub.load(INCEPTION_TFHUB)

#inception_model = get_inception_model()

def classifier_fn_from_tfhub(tfhub_module, output_fields, inception_model,
                             return_tensor=False):
  """Returns a function that can be as a classifier function.

  Copied from tfgan but avoid loading the model each time calling _classifier_fn

  Wrapping the TF-Hub module in another function defers loading the module until
  use, which is useful for mocking and not computing heavy default arguments.

  Args:
    tfhub_module: A string handle for a TF-Hub module.
    output_fields: A string, list, or `None`. If present, assume the module
      outputs a dictionary, and select this field.
    inception_model: A model loaded from TFHub.
    return_tensor: If `True`, return a single tensor instead of a dictionary.

  Returns:
    A one-argument function that takes an image Tensor and returns outputs.
  """
  if isinstance(output_fields, six.string_types):
    output_fields = [output_fields]

  def _classifier_fn(images):
    output = inception_model(images)
    if output_fields is not None:
      output = {x: output[x] for x in output_fields}
    if return_tensor:
      assert len(output) == 1
      output = list(output.values())[0]
    return tf2.nest.map_structure(tf.compat.v1.layers.flatten, output)

  return _classifier_fn

def load_dataset_stats():
    """Load the pre-computed dataset statistics."""
    filename = 'statistics/statistics_{}.npz'.format('cifar10')
    with tf2.io.gfile.GFile(filename, 'rb') as fin:
        stats = np.load(fin)
        return stats

def cal_fid_is(images,inception_model,fid_n_samples):
    inputs = (tf2.cast(images, tf2.float32) - 127.5) / 127.5
    res = tfgan.eval.run_classifier_fn(
          inputs,
          num_batches=1,
          classifier_fn=classifier_fn_from_tfhub(INCEPTION_TFHUB, None,
                                                 inception_model),
          dtypes=_DEFAULT_DTYPES)
    
    all_pools=[]
    all_pools.append(res["pool_3"])
    all_pools = np.concatenate(all_pools, axis=0)
    
    all_logits = []
    all_logits.append(res["logits"])
    all_logits = np.concatenate(all_logits, axis=0)

    data_stats = load_dataset_stats()
    data_pools = data_stats["pool_3"][1800:1800+fid_n_samples]

    fid = tfgan.eval.frechet_classifier_distance_from_activations(
          data_pools, all_pools)

    inception_score = tfgan.eval.classifier_score_from_logits(all_logits)
    return fid,inception_score

In [3]:
inception_model = get_inception_model()
datapath = 'D:/Sample_Data/'
re_img = np.zeros((100,32,32,3))
for i in range(100):
    img = Image.open(datapath+'4680_'+str(i)+'.jpg')
    re_img[i] = np.asarray(img)
    print(re_img[i])
#cal_fid_is(re_img,inception_model,100)

[[[ 84. 155.  25.]
  [107. 177.  26.]
  [100. 169.   0.]
  ...
  [237. 255. 117.]
  [225. 255. 125.]
  [206. 237. 117.]]

 [[ 91. 161.  29.]
  [115. 184.  33.]
  [111. 177.   0.]
  ...
  [222. 255. 107.]
  [211. 251. 116.]
  [211. 250. 125.]]

 [[ 68. 134.   0.]
  [ 93. 157.   8.]
  [ 97. 160.   0.]
  ...
  [218. 255. 111.]
  [192. 247. 102.]
  [180. 233.  99.]]

 ...

 [[241. 253.   0.]
  [241. 254.   0.]
  [238. 254.   8.]
  ...
  [150. 233.  21.]
  [161. 237.  40.]
  [188. 255.  68.]]

 [[231. 249.   0.]
  [230. 250.   0.]
  [229. 251.   5.]
  ...
  [161. 251.  32.]
  [168. 254.  33.]
  [187. 255.  49.]]

 [[226. 246.   0.]
  [227. 248.   0.]
  [226. 250.   4.]
  ...
  [161. 255.  34.]
  [170. 255.  29.]
  [186. 255.  38.]]]
[[[144. 163. 144.]
  [178. 202. 186.]
  [149. 180. 174.]
  ...
  [ 73. 151. 164.]
  [ 77. 158. 161.]
  [ 80. 162. 158.]]

 [[127. 152. 133.]
  [ 93. 124. 108.]
  [ 84. 122. 111.]
  ...
  [ 83. 158. 164.]
  [ 74. 149. 144.]
  [ 76. 152. 140.]]

 [[ 85. 124. 103.]