In [1]:
import argparse
import logging
import os
from collections import defaultdict
from pathlib import Path, PurePath

import numpy as np
import sklearn as sk
import sklearn.model_selection  # makes sk.model_selection available; `import sklearn` alone is not guaranteed to load submodules
import tensorflow as tf
import tensorflow.keras as tfk
import tensorflow_addons as tfa

from src.models.lstm_based.base_model import epsilon_3_model, epsilon_5_model
from src.models.structures import *
from src.models.intermediate_robust_generator.model import *
from src.models.lstm_based.helper import retrieve_dataset, aggregate

In [2]:
# Input-pipeline and optimisation hyper-parameters.
BUFFER = 2_048        # shuffle buffer size (elements) for the tf.data pipeline
BATCH_SIZE = 16       # examples per batch
LEARNING_RATE = 1e-4  # Adam learning rate
EPOCHS = 1            # passes over the training set

In [3]:
# Experiment configuration.
N_GEN = 3              # selects the model constructor; the model cell accepts 3 or 5
DATA_PATH = ""
DATASET = 'MNIST'
PARTITIONS = 1000
# Raw strings are required for Windows paths. In a normal literal "\v" is a
# vertical-tab escape, which silently corrupted MODEL_STORE
# ("...models" + "\x0b" + "1"). The other backslash sequences are invalid
# escapes (SyntaxWarning on Python 3.12+).
# Both locations can be overridden with environment variables so the notebook
# also runs on machines without this D:\ layout.
PARTITION_DIR = os.environ.get(
    "PARTITION_DIR",
    r"D:\SUMMER_2022\PROJECT\PredictionsFromAggregations\data\interim\lstm",
)
MODEL_STORE = os.environ.get(
    "MODEL_STORE",
    r"D:\SUMMER_2022\PROJECT\PredictionsFromAggregations\models\v1",
)
SEED = 8008

In [4]:
# Load the raw dataset and wrap its splits in the project's Dataset container.
# NOTE(review): `data` must unpack into exactly four items; they are assumed to
# be (x_train, x_test, y_train, y_test) in that order — confirm against
# retrieve_dataset.
name, data = retrieve_dataset(DATASET, None)
(x_train, x_test,
 y_train, y_test) = data
dataset = Dataset(
    name,
    x_train, x_test,
    y_train, y_test,
)

In [5]:
# Aggregate the dataset using the configured partition count, partition
# directory and seed; returns (mean, shape, transformed dataset).
# Caution: this rebinds `dataset` (its own input), so re-running this cell on
# its own feeds the previous result back into `aggregate`. Re-run the loading
# cell first.
mean, shape, dataset = aggregate(
    dataset,
    PARTITIONS,
    PARTITION_DIR,
    SEED,
)

In [6]:
# Unpack the aggregated shape. A 3-element shape is treated as having a single
# trailing channel (d = 1). Only b, c, d are used later.
if len(shape) == 3:
    a, b, c = shape
    d = 1
else:
    a, b, c, d = shape

n_classes = len(np.unique(dataset.y_train))

# Give 3-D feature arrays an explicit trailing channel axis.
if len(dataset.x_train.shape) == 3:
    x_train = np.expand_dims(dataset.x_train, axis=-1)
    x_test = np.expand_dims(dataset.x_test, axis=-1)
else:
    x_train = dataset.x_train
    x_test = dataset.x_test
y_train = tfk.utils.to_categorical(dataset.y_train, n_classes)
y_test = tfk.utils.to_categorical(dataset.y_test, n_classes)

# Hold out 10% of the training data for validation.
x_train, x_val, y_train, y_val = sk.model_selection.train_test_split(
    x_train, y_train, test_size=0.1, random_state=42
)


def tf_convert(x, y, types, training=True):
    """Build a batched ``tf.data`` pipeline from in-memory arrays.

    Args:
        x: Feature array; cast to ``types[0]``.
        y: Label array (same first dimension as ``x``); cast to ``types[1]``.
        types: Sequence of two dtypes for ``x`` and ``y``.
        training: When True, examples are reshuffled on every pass (epoch).
            When False (evaluation), the order is fixed, so every pass sees
            the same batches and the same dropped remainder.

    Returns:
        A ``tf.data.Dataset`` yielding ``(x, y)`` batches of ``BATCH_SIZE``.
        The final partial batch is dropped, as in the original pipeline, so up
        to ``BATCH_SIZE - 1`` examples are skipped per pass.
    """
    ds = tf.data.Dataset.from_tensor_slices((tf.cast(x, types[0]), tf.cast(y, types[1])))
    # Cache *before* shuffling. Caching after shuffle/batch would freeze the
    # first epoch's order and replay it on every later epoch.
    ds = ds.cache()
    if training:
        ds = ds.shuffle(BUFFER, seed=SEED, reshuffle_each_iteration=True)
    return ds.batch(BATCH_SIZE, drop_remainder=True).prefetch(tf.data.AUTOTUNE)


train_set = tf_convert(x_train, y_train, [tf.float32, tf.uint8], training=True)
test_set = tf_convert(x_test, y_test, [tf.float32, tf.uint8], training=False)
val_set = tf_convert(x_val, y_val, [tf.float32, tf.uint8], training=False)

In [7]:
# Build the generator configuration. b*c*d is the product of the non-leading
# dims of the aggregated `shape` (presumably the per-sample feature size).
# NOTE(review): the meaning of the literals 10, 4, None, None depends on
# generator_config's signature (star-imported) — TODO name them in the config cell.
config = generator_config(b*c*d, 10, n_classes, 4, None, None)

# Map N_GEN to the model constructor to use.
models = {
    3: epsilon_3_model,
    5: epsilon_5_model,
}
if N_GEN not in models:
    # Fail loudly here. Previously a bare `exit` (which is a no-op without a
    # call) let execution continue, and a KeyError raised *inside* the
    # constructor would also have been misreported as "no model matched".
    raise ValueError(
        "No model matched n_gen value: {} (expected one of {})".format(N_GEN, sorted(models))
    )
model = models[N_GEN](config)

optim = tfk.optimizers.Adam(learning_rate=LEARNING_RATE)
loss_fn = tfk.losses.CategoricalCrossentropy()

In [None]:
# Attach optimizer, loss and metrics so the model can also be used with
# Keras' built-in fit/evaluate. The training cell uses a custom loop and does
# not depend on this.
# The metric must be a metric *object* (or the uncalled function).
# `tfk.metrics.categorical_accuracy()` called with no arguments raises TypeError.
model.compile(
    optimizer=optim,
    loss=loss_fn,
    metrics=[tfk.metrics.CategoricalAccuracy()],
)

In [9]:
# Custom training loop. Each epoch does one gradient step per training batch,
# then a validation pass; a final test pass runs after all epochs.
print('Training Model for {} epochs'.format(EPOCHS))

# NOTE(review): Result comes from a star import. Its use below shows fields
# .history, .acc_score, .val_acc_score and .test_acc — confirm which positional
# slot receives the plain {}.
results = Result(defaultdict(list), {}, defaultdict(list), defaultdict(list))


def _accuracy(model, ds):
    """Return the categorical accuracy of ``model`` over every batch in ``ds``.

    The model is run with ``training=False``. A fresh metric is created per
    call, so no accumulated state leaks between evaluations.
    """
    metric = tfk.metrics.CategoricalAccuracy()
    for x_batch, y_batch in ds:
        metric.update_state(y_batch, model(x_batch, training=False))
    return metric.result()


train_acc_metric = tfk.metrics.CategoricalAccuracy()

for epoch in range(EPOCHS):
    print('Epoch {}...'.format(epoch))
    for step, (x_batch, y_batch) in enumerate(train_set):
        with tf.GradientTape() as tape:
            # Explicit training=True so any dropout / batch-norm layers run in
            # training mode. Without it Keras generally uses inference behaviour.
            pred = model(x_batch, training=True)
            loss_value = loss_fn(y_batch, pred)
        results.history[epoch].append(loss_value)
        grads = tape.gradient(loss_value, model.trainable_weights)
        optim.apply_gradients(zip(grads, model.trainable_weights))
        train_acc_metric.update_state(y_batch, pred)

        # Periodic per-batch loss log (every BATCH_SIZE steps). This previously
        # sat outside the batch loop, so it fired at most once per epoch.
        if step % BATCH_SIZE == 0:
            print(
                "Training loss (for one batch) at step %d: %.4f"
                % (step, float(loss_value))
            )

    train_acc = train_acc_metric.result()
    results.acc_score[epoch].append(train_acc)
    print("Training acc over epoch: %.4f" % (float(train_acc),))
    train_acc_metric.reset_state()

    val_acc = _accuracy(model, val_set)
    results.val_acc_score[epoch] = val_acc
    print("Validation acc over epoch: %.4f" % (float(val_acc),))

test_acc = _accuracy(model, test_set)
results.test_acc = test_acc
print("Test acc: %.4f" % (float(test_acc),))

Training Model for 1 epochs
Epoch 0...
(16, 10) (16, 10)
(16, 10) (16, 10)


KeyboardInterrupt: 