In [None]:
import os
import argparse
import logging
from pathlib import Path, PurePath
from collections import defaultdict

from src.models.structures import *
from src.models.intermediate_robust_generator.model import *
from src.models.lstm_based.helper import retrieve_dataset, aggregate

import tensorflow as tf
import tensorflow.keras as tfk
import tensorflow_addons as tfa

import sklearn as sk
from sklearn.model_selection import train_test_split
import numpy as np

In [None]:
# tf.data pipeline settings: BUFFER is handed to `.shuffle(...)` and
# BATCH_SIZE to `.batch(...)` when the datasets are built further down.
BUFFER = 2 ** 11
BATCH_SIZE = 2 ** 7

In [None]:
# Experiment settings.
# K and SEED are forwarded to `aggregate(...)` when the dataset is loaded.
# EPSILON is not referenced by any cell visible in this notebook
# (presumably a sweep of perturbation/noise levels -- TODO confirm).
K = 1_000
EPSILON = [1e-3, 5e-3, 1e-2, 5e-2, 1e-1]
SEED = 8_008

In [None]:
# Locations on disk. Raw strings are required for these Windows paths: in a
# normal string literal "\v" (as in "models\v1") is the vertical-tab escape,
# which silently corrupts the path, and the other backslash pairs are invalid
# escapes that trigger warnings.
# NOTE(review): these are machine-specific absolute paths; consider deriving
# them from a configurable project root.
PARTITION_DIR = r"D:\SUMMER_2022\PROJECT\PredictionsFromAggregations\data\interim\lstm"
DATASET = 'MNIST'
DIR = r"D:\SUMMER_2022\PROJECT\PredictionsFromAggregations\models\v1"

In [None]:
# Fetch the raw dataset selected by DATASET, wrap it in the project's
# `Dataset` container, and run it through `aggregate`, which returns a
# (mean, dataset) pair. `mean` is reused later as the LSTM unit count.
name, data = retrieve_dataset(DATASET)
mean, dataset = aggregate(Dataset(name, data), K, PARTITION_DIR, SEED)

In [None]:

# Flatten every training sample (axes 1..3 of x_train) into a single feature
# vector and one-hot encode the labels.
a, b, c, d = dataset.x_train.shape
n_classes = len(np.unique(dataset.y_train))

x_train = dataset.x_train.reshape(a, b*c*d)
y_train = tfk.utils.to_categorical(dataset.y_train, n_classes)
x_test = dataset.x_test.reshape(dataset.x_test.shape[0], b*c*d)
y_test = tfk.utils.to_categorical(dataset.y_test, n_classes)

# Hold out 20% of the training data for validation. `train_test_split` is
# imported explicitly: `import sklearn as sk` alone does not guarantee that
# the `sk.model_selection` submodule has been loaded.
x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=0.2, random_state=42)


def tf_convert(x, y, type):
    """Build a shuffled, batched and prefetched `tf.data.Dataset` from arrays.

    Args:
        x: feature array; cast to `type`.
        y: label array; cast to `type`.
        type: TensorFlow dtype applied to both `x` and `y`.
            NOTE(review): callers pass tf.uint8, which truncates any
            non-integer feature values -- confirm that is intended.

    Returns:
        A dataset of `(x, y)` batches of exactly BATCH_SIZE elements (the
        final partial batch is dropped).
    """
    tensors = (tf.cast(x, type), tf.cast(y, type))
    return (
        tf.data.Dataset.from_tensor_slices(tensors)
        # Cache the raw elements *before* shuffling: caching after
        # shuffle/batch would freeze the first epoch's order and batch
        # composition for every later epoch.
        .cache()
        .shuffle(BUFFER)
        .batch(BATCH_SIZE, drop_remainder=True)
        .prefetch(tf.data.AUTOTUNE)
    )


train_set = tf_convert(x_train, y_train, tf.uint8)
test_set = tf_convert(x_test, y_test, tf.uint8)
val_set = tf_convert(x_val, y_val, tf.uint8)

In [None]:
# Merging layer handed to the generator config: a bidirectional LSTM with
# `mean` units. The Keras wrapper is spelled `Bidirectional`;
# `tfk.layers.BiDirectional` does not exist and raises AttributeError.
merger = tfk.layers.Bidirectional(tfk.layers.LSTM(mean, activation='relu', name='merging_layer'))

config = generator_config(b*c*d, 10, n_classes, 4, None, merger)
model = stochastic_model(config)

# Piecewise-constant schedule keyed on `step`: multiplier 1.0 up to step
# 10000, 0.1 up to step 15000, then 0.01.
step = tf.Variable(0, trainable=False)
schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
    [10000, 15000], [1e-0, 1e-1, 1e-2])
# `lr` is computed once here (a fixed tensor under eager execution), whereas
# `wd` is a callable that re-reads the global `step` each time it is invoked.
# NOTE(review): nothing in this cell advances `step`; confirm whether the
# schedule is meant to progress during training.
lr = 1e-1 * schedule(step)
wd = lambda: 1e-4 * schedule(step)

optim = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd)
loss_fn = generator_loss

In [None]:
# Container for per-epoch training artefacts; the loop fills
# `results.history`, `results.acc_score` and `results.val_acc_score`.
results = Result(defaultdict(list), {}, defaultdict(list))

# Labels are one-hot encoded (`to_categorical`), so accuracy has to compare
# argmax(y_true) with argmax(y_pred). `SparseCategoricalAccuracy` expects
# integer class ids and does not handle one-hot targets.
# NOTE(review): assumes the model output has one score per class -- confirm.
train_acc_metric = tfk.metrics.CategoricalAccuracy()
val_acc_metric = tfk.metrics.CategoricalAccuracy()

In [None]:
# Custom training loop: one optimisation pass over `train_set` per epoch,
# followed by an accuracy evaluation on `val_set`.
#
# NOTE(review): neither `args` nor `logger` is defined in any visible cell
# (argparse/logging are imported but never used), so the original loop failed
# with NameError on a fresh kernel. EPOCHS below is a placeholder value --
# set it to the intended number of epochs.
logger = logging.getLogger(__name__)
EPOCHS = 10

for epoch in range(EPOCHS):
    logger.info('Epoch {}...'.format(epoch))
    # The index is called `batch_idx` rather than `step` so that it does not
    # overwrite the `tf.Variable` named `step` read by the weight-decay lambda.
    for batch_idx, (x_batch, y_batch) in enumerate(train_set):
        with tf.GradientTape() as tape:
            # training=True mirrors the explicit training=False used for
            # validation below.
            pred = model(x_batch, training=True)
            loss_value = loss_fn(y_batch, pred, [gen.dense.kernel for gen in model.generators])
        results.history[epoch].append(loss_value)
        grads = tape.gradient(loss_value, model.trainable_weights)
        optim.apply_gradients(zip(grads, model.trainable_weights))

        train_acc_metric.update_state(y_batch, pred)

        # Periodic per-batch loss report. This check used to sit outside the
        # batch loop, where it ran only once per epoch on the last index.
        if batch_idx % BATCH_SIZE == 0:
            logger.info(
                "Training loss (for one batch) at step %d: %.4f"
                % (batch_idx, float(loss_value))
            )

    train_acc = train_acc_metric.result()
    results.acc_score[epoch].append(train_acc)
    logger.info("Training acc over epoch: %.4f" % (float(train_acc),))

    train_acc_metric.reset_states()

    # Validation pass: forward only, no gradient updates.
    for x_batch, y_batch in val_set:
        val_pred = model(x_batch, training=False)
        val_acc_metric.update_state(y_batch, val_pred)
    val_acc = val_acc_metric.result()
    results.val_acc_score[epoch] = val_acc
    logger.info("Validation acc over epoch: %.4f" % (float(val_acc),))
    val_acc_metric.reset_states()