# Training the Model

## Preparation

In [None]:
import pickle as pkl, numpy as np
from tf_tm import *

#### Set Parameters

In [None]:
# Batching and image geometry.
# NOTE(review): BUFFER_SIZE is presumably a tf.data shuffle buffer size; it is
# not used in the visible part of this notebook — confirm.
BUFFER_SIZE, BATCH_SIZE = 64, 32
IMG_SIZE = 256  # square side length in pixels; reused for the discriminator/generator shapes below

# Training length and generator input size.
EPOCHS = 1_000
noise_dim = 10  # number of random values concatenated with each row of image_info as generator input
num_examples_to_generate = 9

# Learning rates, presumably for the generator/discriminator optimizers
# (the generator's is 10x larger).
generator_lr = 1e-3
discriminator_lr = 1e-4

#### Select a Device (TPU if available, otherwise GPU/CPU)

In [None]:
# Prefer a TPU when one is attached; otherwise fall back to TensorFlow's
# default distribution strategy so that `strategy` is defined (and
# `strategy.scope()` usable) in every environment.
# NOTE: `tf` is expected to come from `from tf_tm import *` above.
try:
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
    tf.config.experimental_connect_to_cluster(resolver)
    # This is the TPU initialization code that has to be at the beginning.
    tf.tpu.experimental.initialize_tpu_system(resolver)
    print("All devices: ", tf.config.list_logical_devices('TPU'))
    strategy = tf.distribute.TPUStrategy(resolver)
except ValueError:
    # Reached when the TPU cluster resolver cannot find a TPU.
    print('\x1b[31mNo TPU found\x1b[0m')
    # Inside an `except` block a bare expression is not displayed by Jupyter,
    # so print the GPU list explicitly instead of discarding it.
    print("GPU devices: ", tf.config.list_physical_devices('GPU'))
    # Default strategy: runs on the default device (a visible GPU, else CPU).
    # Without this, `strategy` would be undefined on non-TPU machines.
    strategy = tf.distribute.get_strategy()

#### Load the Data and Build the Augmentation Pipeline

In [None]:
# Load the training arrays from the working directory.
# SECURITY NOTE: pickle.load can execute arbitrary code embedded in the file;
# only load planets.pkl from a trusted source.
with open('planets.pkl', 'rb') as planets_file:  # `with` closes the handle the original leaked
    planets = pkl.load(planets_file)
images_arrs = np.load('training_images.npy')
image_info = np.load('training_info.npy')
labels = np.load('training_labels.npy')

# Augmentation pipeline. featurewise_center subtracts a dataset-wide mean,
# which is only available after datagen.fit() has been called.
datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    featurewise_center=True,
    rotation_range=15,      # random rotations in degrees
    width_shift_range=16,   # int value => shift in pixels
    height_shift_range=16,  # int value => shift in pixels
    horizontal_flip=True,
    vertical_flip=True,
)
# augment=True computes the feature-wise statistics on randomly augmented samples.
datagen.fit(images_arrs, augment=True)

# from_logits=True: the loss expects raw scores (no sigmoid applied).
# NOTE(review): presumably used for the GAN losses later; the discriminator
# must then output logits — confirm against make_discriminator_model.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

## Modeling

#### Defining the Discriminator

In [None]:
disc_shape = (128, )
disc_img_shape = (IMG_SIZE, IMG_SIZE, 3)  # (height, width, channels)

# NOTE: building inside `strategy.scope()` would place the variables on the
# selected device strategy; it is left disabled here as in the original.
# with strategy.scope():
discriminator = make_discriminator_model(disc_shape, disc_img_shape)
# Model.summary() prints itself and returns None, so it is not wrapped in print().
discriminator.summary()

# Smoke test: run the untrained discriminator over the training images.
# NOTE(review): these images have not gone through datagen's feature-wise
# centering, unlike batches drawn from datagen.flow() — confirm this is intended.
decision = discriminator.predict(images_arrs)
# map_structure handles predict() returning a single array or several outputs.
print("Discriminator output shape(s) for training data:",
      tf.nest.map_structure(np.shape, decision))

#### Defining the Generator

In [None]:
gen_shape = (32, 64, 128, IMG_SIZE)
gen_depth = (32, 32, 16, 1)
# Generator input width: random noise concatenated with each row of image_info.
gen_input = image_info.shape[1] + noise_dim

# with strategy.scope():
generator = make_generator_model(gen_shape, gen_depth, gen_input)

# One input vector per row of image_info: [noise (noise_dim) | conditioning features].
# NOTE(review): this runs the generator over every row just to display one
# image; slicing image_info[:num_examples_to_generate] would be cheaper if
# `noise`/`generated_image` are not needed at full size later.
noise = tf.concat(
    (
        tf.random.normal([image_info.shape[0], noise_dim]),
        tf.convert_to_tensor(image_info, dtype=tf.float32),
    ),
    axis=1,
)
generated_image = generator(noise, training=False)

# plt.imshow rejects (H, W, 1) arrays; drop a trailing singleton channel so a
# single-channel output renders (multi-channel outputs are left unchanged).
sample_image = np.asarray(generated_image[0])
if sample_image.ndim == 3 and sample_image.shape[-1] == 1:
    sample_image = sample_image[..., 0]

fig, ax = plt.subplots()
ax.imshow(sample_image)
ax.set_title("Untrained generator sample")
plt.show()

# Model.summary() prints itself and returns None, so it is not wrapped in print().
generator.summary()