In [None]:
seed = 42

import os
# These environment variables are set before TensorFlow and matplotlib are
# imported further down, because both libraries read them when they load.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # hide TF C++ INFO/WARNING/ERROR logs
# PYTHONHASHSEED is read at interpreter start-up, so setting it here only
# affects subprocesses launched from this kernel, not the kernel itself.
os.environ['PYTHONHASHSEED'] = str(seed)
# Keep matplotlib's config/cache inside the working directory (portable join).
os.environ['MPLCONFIGDIR'] = os.path.join(os.getcwd(), 'configs')

import warnings
# The second filter (category=Warning) already covers FutureWarning; the first
# is kept only to make that intent explicit.
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=Warning)

import numpy as np
np.random.seed(seed)

import logging
import tensorflow as tf
# Quiet TensorFlow's Python-side loggers as well as the C++ ones above.
tf.autograph.set_verbosity(0)
tf.get_logger().setLevel(logging.ERROR)
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
# Seed both the TF2 global RNG and its tf.compat.v1 counterpart.
tf.random.set_seed(seed)
tf.compat.v1.set_random_seed(seed)

import random
random.seed(seed)

from tensorflow.keras import backend as K
import tf2onnx
import onnxruntime as rt
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split

# Let TensorFlow allocate GPU memory on demand instead of reserving it all up
# front. This has to run before any GPU has been initialised.
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

# Short aliases used throughout the notebook.
tfk = tf.keras
tfkl = tf.keras.layers
# `tf.version` is a module (its repr is not the version); print the string.
print(tf.__version__)

In [None]:
# Call sgde_client.auth.get_users(); its return value is shown as the cell
# output. (get_user is imported but not used in the cells shown here.)
from sgde_client.auth import get_users, get_user

get_users()

In [None]:
# Authenticate with sgde_client.auth.login(). Presumably this is required
# before the exchange upload/download calls later on — TODO confirm.
from sgde_client.auth import login

login()

In [None]:
import textwrap

from sgde_client.models.training import train_image_generator
from sgde_client.models.inference import generate_samples_onnx

# Merge MNIST's predefined train/test split so the generator is trained on the
# full corpus (70k images) rather than only the training portion.
(X_train, y_train), (X_test, y_test) = tfk.datasets.mnist.load_data()
X = np.concatenate([X_train, X_test], axis=0)
y = np.concatenate([y_train, y_test], axis=0)

# Train a deliberately tiny (1 epoch, small model) generator as a demo.
# Returns the path of the produced generator file plus its metadata.
file_path, metadata = train_image_generator(
    name="my_other_1ep_gan",
    # dedent + strip: without it every line keeps the 4-space source
    # indentation, which Markdown renders as a code block instead of a
    # heading and paragraph.
    description=textwrap.dedent("""
        # A dummy GAN for MNIST
        ...with a slick Markdown description.
    """).strip(),
    X=X,
    y=y,
    epochs=1,
    classifier_epochs=1,
    batch_size=64,
    image_size=28,
    model_size="small",
    task="classification",
    sub_task="Handwritten digit classification",
    data_description="A balanced grayscale image dataset containing 10 different classes",
    dataset_name="MNIST",
    verbose=True,
)

In [None]:
# Show `file_path`, the first value returned by train_image_generator above.
file_path

In [None]:
# Show `metadata`, returned alongside file_path by train_image_generator.
metadata

In [None]:
# Upload the freshly trained generator (file_path + metadata from the
# training cell) via sgde_client.exchange.upload_generator.
from sgde_client.exchange import upload_generator

upload_generator(file_path, metadata)

In [None]:
# Call sgde_client.exchange.get_generators() and display the result;
# presumably this lists the generators available on the exchange.
from sgde_client.exchange import get_generators

get_generators()

In [None]:
from sgde_client.exchange import get_generator, download_generator

# Single source of truth for which generator to fetch, so the metadata and the
# downloaded file can never refer to different generators.
# NOTE: this is not the name trained/uploaded above ("my_other_1ep_gan");
# change it to that name to round-trip your own upload.
GENERATOR_NAME = "mnist_small_gan"

downloaded_metadata = get_generator(GENERATOR_NAME)
downloaded_gan_path = download_generator(GENERATOR_NAME)
downloaded_metadata

In [None]:
from sgde_client.models.inference import generate_samples_onnx

# How many images to draw; the plotting cell sizes its grid from this value.
num_gen_samples = 10

# The model's input shape and class count are taken from the downloaded
# generator's metadata rather than hard-coded.
samples = generate_samples_onnx(
    path=downloaded_gan_path,
    num_samples=num_gen_samples,
    input_shape=downloaded_metadata.generator_input_shape,
    num_classes=downloaded_metadata.num_classes,
)

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

fig, axes = plt.subplots(1, num_gen_samples, figsize=(20, 3 * num_gen_samples))
for i in range(num_gen_samples):
    img = tfk.preprocessing.image.array_to_img(samples[i])
    ax = axes[i % num_gen_samples]
    ax.imshow(np.squeeze(img), cmap="gray")
    ax.set_xticks([]), ax.set_yticks([])
plt.tight_layout()
plt.show()