In [8]:
import tensorflow as tf
import numpy as np
import pandas as pd

from tensorflow.keras.layers import Input, Conv2D, Dense, Flatten, Dropout, GlobalMaxPooling2D, MaxPooling2D, BatchNormalization
from tensorflow.keras.models import Model

In [9]:
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu = '')
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
print('All devices: ', tf.config.list_logical_devices('TPU'))



All devices:  [LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:0', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:1', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:2', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:3', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:4', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:5', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:6', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:7', device_type='TPU')]


In [10]:
strategy = tf.distribute.TPUStrategy(resolver)

In [11]:
#load in the data
cifar10 = tf.keras.datasets.cifar10

(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
# Flatten the input data
y_train, y_test = y_train.flatten(), y_test.flatten()
print('x_train.shape:', x_train.shape)
print('y_train.shape:', y_train.shape)

x_train.shape: (50000, 32, 32, 3)
y_train.shape: (50000,)


In [18]:
# number of classes
K = len(set(y_train))
print("number of classes:", K)

number of classes: 10


In [19]:
# model creation must be in strategy scope
# we will define the function now, but this code
# won't run outside the scope

In [13]:
# Build model
def create_model():
  i = Input(shape = x_train[0].shape)

  x = Conv2D(32, (3, 3), activation = 'relu', padding = 'same')(i)
  x = BatchNormalization()(x)
  x = Conv2D(32, (3, 3), activation = 'relu', padding = 'same')(x)
  x = BatchNormalization()(x)
  x = MaxPooling2D((2, 2))(x)
  x = Conv2D(32, (3, 3), activation = 'relu', padding = 'same')(x)
  x = BatchNormalization()(x)
  x = Conv2D(32, (3, 3), activation = 'relu', padding = 'same')(x)
  x = BatchNormalization()(x)
  x = MaxPooling2D((2, 2))(x)
  x = Conv2D(32, (3, 3), activation = 'relu', padding = 'same')(x)
  x = BatchNormalization()(x)
  x = Conv2D(32, (3, 3), activation = 'relu', padding = 'same')(x)
  x = BatchNormalization()(x)
  x = MaxPooling2D((2, 2))(x)

  x = Flatten()(x)
  x = Dropout(0.2)(x)
  x = Dense(1024, activation = 'relu')(x)
  x = Dropout(0.2)(x)
  x = Dense(K, activation = 'softmax')(x)

  model = Model(i, x)
  return model

In [14]:
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))

In [20]:
with strategy.scope():
  model = create_model()
  model.compile(
      optimizer = 'adam',
      loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits = True),
      metrics = ['sparse_categorical_accuracy']
  )

  batch_size = 256

  # reshuffle_each_iteration = None is default but is later set to True if None
  # thus "True" is the actual default
  train_dataset = train_dataset.shuffle(1000).batch(batch_size)
  test_dataset = test_dataset.batch(batch_size)


In [21]:
model.fit(train_dataset, epochs = 5, validation_data = test_dataset),

Epoch 1/5


  output, from_logits = _get_logits(


Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


(<keras.callbacks.History at 0x7f3b90726980>,)

In [22]:
model.save('mymodel.h5')

In [23]:
with strategy.scope():
  model = tf.keras.models.load_model('mymodel.h5')
  out = model.predict(x_test[:1])
  print(out)

[[0.00756784 0.02367925 0.01765121 0.6528193  0.00301466 0.19226538
  0.07879657 0.00409831 0.01655927 0.00354816]]


In [26]:
with strategy.scope():
    model = tf.keras.models.load_model('mymodel.h5')
    out = model.predict(x_test)
    predicted_classes = out.argmax(axis=1)  # Find the index of the maximum value along axis 1
    print("predicted:", predicted_classes[0], "actual:", y_test[0])


predicted: 3 actual: 3
