In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf1
import tensorflow_datasets as tfds
from google.colab import drive
import cv2

# Expose Google Drive under /content/drive; the TFRecord cells below read and
# write under '/content/drive/My Drive/...'.
drive.mount('/content/drive')
# Later cells call `.numpy()` on tensors, which requires eager mode (already
# the default on TF2).
tf1.enable_eager_execution()

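Load the MNIST and STL10 datasets and convert them to in-memory lists. For each dataset, the original train and test splits are concatenated and then re-split into 90% train / 10% test (`TRAIN_FRAC`).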

In [None]:
TRAIN_FRAC = 0.9  # Fraction of all labeled examples assigned to the new 'train' split.

#-------------------------------------------------------------------------------


def _load_and_resplit(name, train_frac=TRAIN_FRAC):
  """Loads a TFDS dataset's 'train' and 'test' splits and re-splits them.

  The two original splits are concatenated (train first, then test) into one
  in-memory list. The first round(train_frac * N) examples become the new
  'train' split and the rest become the new 'test' split. No shuffling is done.

  Args:
    name: TFDS dataset name, e.g. 'mnist' or 'stl10'.
    train_frac: Fraction of the concatenated examples to put in 'train'.

  Returns:
    A tuple (splits, num_train, num_test), where `splits` is
    {'train': [...], 'test': [...]} and each list element is an example dict
    as yielded by `as_numpy_iterator()`.
  """
  all_splits = tfds.load(name=name, split=None)
  examples = (list(all_splits['train'].as_numpy_iterator()) +
              list(all_splits['test'].as_numpy_iterator()))
  num_train = int(round(train_frac * len(examples)))
  splits = {'train': examples[:num_train], 'test': examples[num_train:]}
  return splits, num_train, len(examples) - num_train


def _print_split_summary(display_name, splits):
  """Prints split sizes, image shape, and example keys of a re-split dataset."""
  first_example = splits['train'][0]
  print(f'{display_name} train dataset size: ', len(splits['train']))
  print(f'{display_name} test dataset size : ', len(splits['test']))
  print(f'{display_name} image shape       : ', first_example['image'].shape)
  print(f'{display_name} keys              : ', first_example.keys())


# Setup MNIST train and test datasets.
mnist_ds, mnist_num_train, mnist_num_test = _load_and_resplit('mnist')
mnist_num_labeled = mnist_num_train + mnist_num_test
_print_split_summary('MNIST', mnist_ds)

# Setup STL10 train and test datasets.
stl10_ds, stl10_num_train, stl10_num_test = _load_and_resplit('stl10')
stl10_num_labeled = stl10_num_train + stl10_num_test
_print_split_summary('STL10', stl10_ds)

Function that will generate a single image view pair. Details:

* Each image in the view pair consists of an MNIST digit overlaid on top of an STL10 image, which serves as the background.
* Up to three invariances can be imposed between the two views:
  1. The same MNIST class.
  2. The same x/y position of the MNIST digit.
  3. The same STL10 class.

* The x/y positions are quantized (by num_x_pos and num_y_pos) to be equally spaced throughout the image.
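* The STL10 background of each view is an `img_size` x `img_size` crop of the full-sized (96x96) STL10 image. By default this is a center crop (`center_bkgnd_crop=True`); otherwise it is a random crop chosen independently for each view.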


In [3]:
def _sample_partner_index(ds, anchor_label, same_label, max_attempts):
  """Rejection-samples an index into `ds` constrained by `anchor_label`.

  Uniformly random indices are drawn until one is found whose 'label' equals
  `anchor_label` (when `same_label` is truthy) or differs from it (otherwise).

  Args:
    ds: List of example dicts, each with a 'label' entry.
    anchor_label: Label the sampled example must match or avoid.
    same_label: Whether the sampled label must equal `anchor_label`.
    max_attempts: Maximum number of random draws before giving up.

  Returns:
    An index into `ds` satisfying the constraint.

  Raises:
    RuntimeError: if no draw satisfies the constraint within `max_attempts`.
      This happens, for example, when every example in `ds` has the same label
      and `same_label` is False, which makes the constraint unsatisfiable.
  """
  want_same = bool(same_label)
  for _ in range(max_attempts):
    index = np.random.randint(0, len(ds))
    if bool(ds[index]['label'] == anchor_label) == want_same:
      return index
  raise RuntimeError(
      'Could not sample an example with a %s label (anchor label %r) in %d '
      'attempts.' % ('matching' if want_same else 'different', anchor_label,
                     max_attempts))


def _crop_square(image, top, left, size):
  """Returns a copy (safe to modify) of the size x size patch at (top, left)."""
  return np.copy(image[top:top + size, left:left + size, :])


def _gray_to_rgb(image):
  """Turns a [H, W, 1] image into [H, W, 3] by repeating its single channel."""
  return np.concatenate((image,) * 3, axis=-1)


def _digit_offset(slot, num_slots, img_size, digit_size):
  """Returns the top (or left) pixel offset of a digit placed in grid `slot`.

  If `num_slots` digits fit side by side, the image is divided into
  `num_slots` equal cells and the digit is centered in cell `slot`. Otherwise
  the `num_slots` offsets are spread evenly from 0 to img_size - digit_size,
  so adjacent positions overlap.
  """
  cell_size = img_size // num_slots
  if img_size >= digit_size * num_slots:
    return slot * cell_size + (cell_size - digit_size) // 2
  # Reaching here implies num_slots >= 2 (img_size >= digit_size), so the
  # denominator is non-zero. Multiply before dividing to keep precision.
  return slot * (img_size - digit_size) // (num_slots - 1)


def _overlay_digit(background, digit_rgb, top, left):
  """Draws `digit_rgb` onto `background` in place using a pixelwise maximum."""
  height, width = digit_rgb.shape[:2]
  rows = slice(top, top + height)
  cols = slice(left, left + width)
  background[rows, cols, :] = np.maximum(digit_rgb, background[rows, cols, :])


def create_mnist_on_stl10_view_pair(
    mnist_ds,                # MNIST dataset. List of {'image', 'label'} dict for only *one* split.
    stl10_ds,                # STL10 dataset. List of {'image', 'label'} dict for only *one* split.
    img_size=64,             # x or y size (in pix) of each view in a created pair.
    num_x_pos=4,             # Number of possible evenly-spaced x values for digit position.
    num_y_pos=2,             # Number of possible evenly-spaced y values for digit position.
    same_digit=False,        # Should both views in the pair have digits corresponding to the same MNIST digit class?
    same_pos=False,          # Should both views in the pair position the digits at the same x/y pixel location?
    same_bkgnd=False,        # Should both views in the pair have backgrounds corresponding to the same STL10 class?
    center_bkgnd_crop=True,  # Should STL10 images be center cropped down to |img_size|? Else, they are randomly cropped.
    max_resample_attempts=10000,  # Max random draws when searching for a partner digit/background by label.
    ):
  """Creates one MNIST-on-STL10 view pair with controlled shared factors.

  Each view is an `img_size` x `img_size` crop of an STL10 image with a
  (3-channel) MNIST digit drawn on top via a pixelwise maximum. If a `same_*`
  flag is True, the corresponding factor is equal in both views. If it is
  False, that factor is guaranteed to differ between the views.

  Returns:
    (view0, view1, labels0, labels1). The views are [img_size, img_size, 3]
    arrays with the STL10 images' dtype. Each labels tuple is
    (mnist_label, pos_label, stl10_label). `pos_label` is
    row * num_x_pos + col of the digit's grid cell.

  Raises:
    AssertionError: on empty datasets or incompatible image/grid sizes.
    ValueError: if same_pos is False but the grid has fewer than 2 cells, or
      if max_resample_attempts < 1.
    RuntimeError: if no digit/background with the required label relation is
      drawn within max_resample_attempts tries. This happens, for example, with
      a single-class dataset when the corresponding same_* flag is False.
  """
  assert len(mnist_ds) > 0
  assert len(stl10_ds) > 0
  mnist_img_size = mnist_ds[0]['image'].shape[0]
  stl10_img_size = stl10_ds[0]['image'].shape[0]
  assert mnist_img_size <= stl10_img_size
  assert mnist_img_size <= img_size <= stl10_img_size
  assert img_size % num_x_pos == 0
  assert img_size % num_y_pos == 0
  if not same_pos and num_x_pos * num_y_pos < 2:
    # Two distinct positions cannot exist, so sampling could never finish.
    raise ValueError(
        'same_pos=False needs num_x_pos * num_y_pos >= 2, got %d x %d.'
        % (num_x_pos, num_y_pos))
  if max_resample_attempts < 1:
    raise ValueError('max_resample_attempts must be >= 1, got %r.'
                     % (max_resample_attempts,))

  # Pick STL10 backgrounds. view1's label equals view0's iff same_bkgnd.
  stl10_index_for_view0 = np.random.randint(0, len(stl10_ds))
  stl10_label_for_view0 = stl10_ds[stl10_index_for_view0]['label']
  stl10_index_for_view1 = _sample_partner_index(
      stl10_ds, stl10_label_for_view0, same_bkgnd, max_resample_attempts)
  stl10_label_for_view1 = stl10_ds[stl10_index_for_view1]['label']

  # Crop the STL10 images down to img_size x img_size.
  max_offset = stl10_img_size - img_size
  if center_bkgnd_crop:
    top_for_view0 = lft_for_view0 = max_offset // 2
    top_for_view1 = lft_for_view1 = max_offset // 2
  else:
    top_for_view0 = np.random.randint(0, max_offset + 1)
    lft_for_view0 = np.random.randint(0, max_offset + 1)
    top_for_view1 = np.random.randint(0, max_offset + 1)
    lft_for_view1 = np.random.randint(0, max_offset + 1)
  # Crops are copies, so drawing digits below never modifies the dataset.
  view0 = _crop_square(stl10_ds[stl10_index_for_view0]['image'],
                       top_for_view0, lft_for_view0, img_size)
  view1 = _crop_square(stl10_ds[stl10_index_for_view1]['image'],
                       top_for_view1, lft_for_view1, img_size)
  assert view0.shape == (img_size, img_size, 3)
  assert view1.shape == (img_size, img_size, 3)

  # Pick MNIST digits. view1's label equals view0's iff same_digit.
  mnist_index_for_view0 = np.random.randint(0, len(mnist_ds))
  mnist_label_for_view0 = mnist_ds[mnist_index_for_view0]['label']
  mnist_index_for_view1 = _sample_partner_index(
      mnist_ds, mnist_label_for_view0, same_digit, max_resample_attempts)
  mnist_label_for_view1 = mnist_ds[mnist_index_for_view1]['label']

  # Pick (row, col) grid cells for the digits. They coincide iff same_pos.
  top_label_for_view0 = np.random.randint(0, num_y_pos)
  lft_label_for_view0 = np.random.randint(0, num_x_pos)
  pos_label_for_view0 = top_label_for_view0 * num_x_pos + lft_label_for_view0
  if same_pos:
    top_label_for_view1 = top_label_for_view0
    lft_label_for_view1 = lft_label_for_view0
  else:
    # Rejection-sample a different cell. The loop terminates with probability 1
    # because at least two cells exist (checked above).
    while True:
      top_label_for_view1 = np.random.randint(0, num_y_pos)
      lft_label_for_view1 = np.random.randint(0, num_x_pos)
      if top_label_for_view1 * num_x_pos + lft_label_for_view1 != pos_label_for_view0:
        break
  pos_label_for_view1 = top_label_for_view1 * num_x_pos + lft_label_for_view1

  # Overlay each digit at its cell's pixel offset.
  _overlay_digit(
      view0, _gray_to_rgb(mnist_ds[mnist_index_for_view0]['image']),
      _digit_offset(top_label_for_view0, num_y_pos, img_size, mnist_img_size),
      _digit_offset(lft_label_for_view0, num_x_pos, img_size, mnist_img_size))
  _overlay_digit(
      view1, _gray_to_rgb(mnist_ds[mnist_index_for_view1]['image']),
      _digit_offset(top_label_for_view1, num_y_pos, img_size, mnist_img_size),
      _digit_offset(lft_label_for_view1, num_x_pos, img_size, mnist_img_size))

  return (view0, view1,
          (mnist_label_for_view0, pos_label_for_view0, stl10_label_for_view0),
          (mnist_label_for_view1, pos_label_for_view1, stl10_label_for_view1))


Test the view pair generating function.

In [None]:
# Settings for previewing a few generated view pairs.
NUM_VIEW_PAIRS = 5            # Number of view pairs to generate and plot.
SPLIT = 'train'               # Which re-split ('train' or 'test') to sample from.
IMG_SIZE = 64                 # Side length (pixels) of each view.
NUM_X_POS, NUM_Y_POS = 4, 2   # Digit position grid: 4 columns x 2 rows.
SAME_DIGIT, SAME_POS, SAME_BKGND = True, False, True  # Factors shared by both views.
CENTER_BKGND_CROP = True      # Center-crop (True) or random-crop (False) the backgrounds.

#-------------------------------------------------------------------------------

def MakePlot(fig, num_plot_rows, num_plot_cols, plot_num, plot_img, plot_title,
             cmap='jet', vmin=0.0, vmax=1.0):
  """Shows `plot_img` in cell `plot_num` of a rows x cols subplot grid on `fig`.

  The subplot gets `plot_title` as its title, with grid lines and axes turned
  off. `cmap`, `vmin` and `vmax` are passed straight through to imshow.
  """
  subplot_args = (num_plot_rows, num_plot_cols, plot_num)
  fig.add_subplot(*subplot_args)  # Also becomes pyplot's current axes.
  plt.imshow(plot_img, vmin=vmin, vmax=vmax, cmap=cmap)
  plt.grid(False)
  plt.axis('off')
  plt.title(plot_title)

# Generate and display NUM_VIEW_PAIRS pairs, one figure per pair. Titles show
# the MNIST digit / position / STL10 class labels, abbreviated M/P/S.
for i in range(NUM_VIEW_PAIRS):
  (view0, view1,
   (mnist_label_for_view0, pos_label_for_view0, stl10_label_for_view0),
   (mnist_label_for_view1, pos_label_for_view1, stl10_label_for_view1)
   ) = create_mnist_on_stl10_view_pair(
       mnist_ds[SPLIT],
       stl10_ds[SPLIT],
       img_size=IMG_SIZE,
       num_x_pos=NUM_X_POS,
       num_y_pos=NUM_Y_POS,
       same_digit=SAME_DIGIT,
       same_pos=SAME_POS,
       same_bkgnd=SAME_BKGND,
       center_bkgnd_crop=CENTER_BKGND_CROP)

  num_plot_cols = 2
  num_plot_rows = 1
  fig = plt.figure(figsize=(5 * num_plot_cols, 5 * num_plot_rows))

  # Pixel values are scaled from 0..255 to [0, 1] floats for display.
  MakePlot(fig, num_plot_rows, num_plot_cols, 1, view0 / 255.0, f'View0: M/P/S Labels = {mnist_label_for_view0}/{pos_label_for_view0}/{stl10_label_for_view0}')
  MakePlot(fig, num_plot_rows, num_plot_cols, 2, view1 / 255.0, f'View1: M/P/S Labels = {mnist_label_for_view1}/{pos_label_for_view1}/{stl10_label_for_view1}')
  # Render this pair's figure now instead of accumulating all figures.
  plt.show()


Set up a class that generates a specified number of view pairs and writes them to disk as sharded TFRecord files.

In [5]:
class DatasetWriter:
  """Generates MNIST-on-STL10 view pairs and writes them as sharded TFRecords.

  Pairs are drawn with `create_mnist_on_stl10_view_pair` from the
  notebook-level `mnist_ds[split]` and `stl10_ds[split]` lists. Each
  serialized tf.train.Example has these features:
    'view0/encoded', 'view1/encoded': JPEG bytes of the two views, stored in
        standard RGB channel order (decodable with tf.io.decode_jpeg).
    'height', 'width', 'channels': shape of view0 (both views share it).
    'label': view0's shared factors folded into one int (see `_get_label`).
  """

  def __init__(
      self,
      num_view_pairs,      # Total number of view pairs (examples) to write.
      split,               # Which split of mnist_ds / stl10_ds to sample from ('train' or 'test').
      output_dir,          # Directory for the TFRecord shards; created if missing.
      num_shards=1,        # Number of output files; must be >= 1.
      img_size=64,
      num_x_pos=4,
      num_y_pos=2,
      same_digit=False,
      same_pos=False,
      same_bkgnd=False,
      center_bkgnd_crop=True,
      ):
    # img_size .. center_bkgnd_crop are forwarded unchanged to
    # create_mnist_on_stl10_view_pair for every example.
    if num_shards < 1:
      # Zero shards would divide by zero in write_tfrecords().
      raise ValueError('num_shards must be >= 1, got %r.' % (num_shards,))
    self.num_view_pairs = num_view_pairs
    self.split = split
    self.img_size = img_size
    self.num_x_pos = num_x_pos
    self.num_y_pos = num_y_pos
    self.same_digit = same_digit
    self.same_pos = same_pos
    self.same_bkgnd = same_bkgnd
    self.center_bkgnd_crop = center_bkgnd_crop

    self.output_dir = output_dir
    self.num_shards = num_shards
    if not tf1.gfile.Exists(self.output_dir):
      tf1.gfile.MakeDirs(self.output_dir)

  def _bytes_feature(self, value):
    """Returns a bytes_list from a string / byte."""
    if isinstance(value, type(tf.constant(0))):
      value = value.numpy()  # BytesList won't unpack a string from an EagerTensor.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

  def _float_feature(self, value):
    """Returns a float_list from a float / double."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

  def _int64_feature(self, value):
    """Returns an int64_list from a bool / enum / int / uint."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

  def _image_to_encoded_jpg_bytes(self, image):
    """JPEG-encodes an RGB [H, W, 3] image and returns the raw bytes.

    cv2.imencode interprets 3-channel input as BGR. The views are RGB, so they
    are converted first. Otherwise red and blue would be swapped for standard
    decoders such as tf.io.decode_jpeg. Note that cv2.imdecode returns BGR.

    Raises:
      RuntimeError: if OpenCV fails to encode the image.
    """
    bgr_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    success, encoded_image = cv2.imencode('.jpg', bgr_image)
    if not success:
      raise RuntimeError('Could not encode image')
    return encoded_image.tobytes()

  def _get_label(self, digit_label, pos_label, bkgnd_label):
    """Folds the factors shared by both views into one integer class label.

    Only factors that are invariant across the pair (same_digit / same_pos /
    same_bkgnd) contribute. They are combined in mixed radix: digit is the
    most significant, then position, then background. The ranges in the
    comments assume the default 4 x 2 position grid (8 positions). Returns -1
    when no factor is shared.
    """
    num_digit_labels = 10
    num_pos_labels = self.num_x_pos * self.num_y_pos
    num_bkgnd_labels = 10

    if self.same_digit and self.same_pos and self.same_bkgnd:
      label = num_bkgnd_labels * pos_label + bkgnd_label  # 0 - 79
      label = num_bkgnd_labels * num_pos_labels * digit_label + label  # 0 - 799
    elif self.same_digit and self.same_pos:
      label = num_pos_labels * digit_label + pos_label  # 0 - 79
    elif self.same_digit and self.same_bkgnd:
      label = num_bkgnd_labels * digit_label + bkgnd_label  # 0 - 99
    elif self.same_pos and self.same_bkgnd:
      label = num_bkgnd_labels * pos_label + bkgnd_label  # 0 - 79
    elif self.same_digit:
      label = digit_label  # 0 - 9
    elif self.same_pos:
      label = pos_label  # 0 - 7
    elif self.same_bkgnd:
      label = bkgnd_label  # 0 - 9
    else:
      # None of same_digit, same_pos, or same_bkgnd are True. Set label to -1.
      label = -1

    return label

  def _create_example(self):
    """Generates one view pair and packs it into a tf.train.Example.

    Only view0's factor labels are used for 'label'. For every shared factor,
    view1's label is identical by construction.
    """
    (view0, view1, (digit_label, pos_label, bkgnd_label), _) = create_mnist_on_stl10_view_pair(
        mnist_ds[self.split],
        stl10_ds[self.split],
        img_size=self.img_size,
        num_x_pos=self.num_x_pos,
        num_y_pos=self.num_y_pos,
        same_digit=self.same_digit,
        same_pos=self.same_pos,
        same_bkgnd=self.same_bkgnd,
        center_bkgnd_crop=self.center_bkgnd_crop,
        )

    label = self._get_label(digit_label, pos_label, bkgnd_label)

    example = tf.train.Example(features=tf.train.Features(feature={
      'view0/encoded': self._bytes_feature(self._image_to_encoded_jpg_bytes(view0)),
      'view1/encoded': self._bytes_feature(self._image_to_encoded_jpg_bytes(view1)),
      'height': self._int64_feature(view0.shape[0]),
      'width': self._int64_feature(view0.shape[1]),
      'channels': self._int64_feature(view0.shape[2]),
      'label': self._int64_feature(label),
      }))

    return example

  def write_tfrecords(self):
    """Generates `num_view_pairs` examples and writes them to `num_shards` files.

    Shards are filled in order with ceil(num_view_pairs / num_shards) examples
    each until the total is reached, so trailing shards may be smaller or
    empty. The files are written to `output_dir` and named
    'same_digit_<b>_same_pos_<b>_same_bkgnd_<b>_<split>-SSSSS-of-NNNNN',
    where <b> is 'true' or 'false'.
    """
    max_records_per_shard = (self.num_view_pairs + self.num_shards - 1) // self.num_shards  # ceil division
    prefix = 'same_digit_%s_same_pos_%s_same_bkgnd_%s_%s' % (
        'true' if self.same_digit else 'false',
        'true' if self.same_pos else 'false',
        'true' if self.same_bkgnd else 'false', self.split)
    num_processed_examples = 0
    for shard in range(self.num_shards):
      filename = os.path.join(self.output_dir, '%s-%.5d-of-%.5d' % (prefix, shard, self.num_shards))
      print('Writing file ', filename)
      num_records_to_process = min(max_records_per_shard, self.num_view_pairs - num_processed_examples)
      # The context manager closes the file even if example creation raises.
      with tf.io.TFRecordWriter(filename) as writer:
        for _ in range(num_records_to_process):
          writer.write(self._create_example().SerializeToString())
      num_processed_examples += num_records_to_process


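Create three of the 2 * 2 * 2 possible datasets and write them to disk. In each dataset, the two views share exactly two of the three factors (digit class, digit position, background class).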

In [None]:
OVERWRITE_TFRECORDS = True  # Be careful: running this cell with True overwrites the TFRecord files on disk!!!

if OVERWRITE_TFRECORDS:
  # Per-split size settings: number of view pairs, source split, and shard count.
  NUM_TRAIN_VIEW_PAIRS = 100000
  TRAIN_SPLIT = 'train'
  NUM_TRAIN_SHARDS = 75

  NUM_TEST_VIEW_PAIRS = 10000
  TEST_SPLIT = 'test'
  NUM_TEST_SHARDS = 10

  OUTPUT_DIR = '/content/drive/My Drive/mnist-stl10/' # '/cns/ok-d/home/viscam/contrastive_learning/mnist_on_stl10'

  # View-generation settings shared by every dataset written below.
  IMG_SIZE = 64
  NUM_X_POS = 4
  NUM_Y_POS = 2
  CENTER_BKGND_CROP = True

  #-------------------------------------------------------------------------------

  def WriteTrainAndTestDataset(same_digit, same_pos, same_bkgnd):
    """Writes the train shards, then the test shards, for one invariance setting."""
    view_kwargs = dict(
        img_size=IMG_SIZE, num_x_pos=NUM_X_POS, num_y_pos=NUM_Y_POS,
        same_digit=same_digit, same_pos=same_pos, same_bkgnd=same_bkgnd,
        center_bkgnd_crop=CENTER_BKGND_CROP)
    for num_view_pairs, split, num_shards in (
        (NUM_TRAIN_VIEW_PAIRS, TRAIN_SPLIT, NUM_TRAIN_SHARDS),
        (NUM_TEST_VIEW_PAIRS, TEST_SPLIT, NUM_TEST_SHARDS)):
      DatasetWriter(num_view_pairs, split, OUTPUT_DIR, num_shards=num_shards,
                    **view_kwargs).write_tfrecords()

  # In each dataset, the two views share exactly two of the three factors.
  WriteTrainAndTestDataset(same_digit=True, same_pos=True, same_bkgnd=False)
  WriteTrainAndTestDataset(same_digit=True, same_pos=False, same_bkgnd=True)
  WriteTrainAndTestDataset(same_digit=False, same_pos=True, same_bkgnd=True)


Check that the files on disk contain valid data.

In [None]:
SAME_DIGIT = 'true'  # 'true' or 'false' strings
SAME_POS = 'true'  # 'true' or 'false' strings
SAME_BKGND = 'false'  # 'true' or 'false' strings

OUTPUT_DIR = '/content/drive/My Drive/mnist-stl10'
NUM_TRAIN_SHARDS = 75

#-------------------------------------------------------------------------------

# Inspect the first 10 examples of the first train shard for this setting.
filename = os.path.join(OUTPUT_DIR, 'same_digit_%s_same_pos_%s_same_bkgnd_%s_train-00000-of-%.5d' % (SAME_DIGIT, SAME_POS, SAME_BKGND, NUM_TRAIN_SHARDS))
tfrecord_dataset = tf.data.TFRecordDataset(filename).take(10)

feature_description = {
    'view0/encoded': tf.io.FixedLenFeature([], tf.string),
    'view1/encoded': tf.io.FixedLenFeature([], tf.string),
    'height': tf.io.FixedLenFeature([], tf.int64),
    'width': tf.io.FixedLenFeature([], tf.int64),
    'channels': tf.io.FixedLenFeature([], tf.int64),
    'label': tf.io.FixedLenFeature([], tf.int64),
    }

def parse(example_proto):
  """Parses one serialized tf.train.Example using `feature_description`."""
  return tf.io.parse_single_example(example_proto, feature_description)

parsed_dataset = tfrecord_dataset.map(parse)
for features in parsed_dataset:
  # Decode with TF's JPEG decoder (RGB output), which is what a tf.data input
  # pipeline reading these files will see.
  view0 = tf.io.decode_jpeg(features['view0/encoded'], channels=3).numpy()
  view1 = tf.io.decode_jpeg(features['view1/encoded'], channels=3).numpy()
  height = features['height'].numpy()
  width = features['width'].numpy()
  channels = features['channels'].numpy()
  label = features['label'].numpy()
  # The stored shape metadata must agree with the decoded images.
  assert view0.shape == (height, width, channels), (view0.shape, height, width, channels)
  assert view1.shape == view0.shape, (view1.shape, view0.shape)

  num_plot_cols = 2
  num_plot_rows = 1
  fig = plt.figure(figsize=(5 * num_plot_cols, 5 * num_plot_rows))

  MakePlot(fig, num_plot_rows, num_plot_cols, 1, view0 / 255.0, f'View0: shape=[{height},{width},{channels}], label={label}')
  MakePlot(fig, num_plot_rows, num_plot_cols, 2, view1 / 255.0, f'View1: shape=[{height},{width},{channels}], label={label}')
  # Render this example's figure now instead of accumulating all figures.
  plt.show()
