## Setup

In [None]:
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Conv2D, MaxPooling2D, Flatten
from tensorflow.keras.models import Model
import tensorflow as tf

# Show which TensorFlow build is installed; later cells rely on TF 2.x-era
# APIs such as tf.keras.layers.experimental.preprocessing.
print(tf.version.VERSION)

## Load the MNIST Dataset

In [None]:
# Load the MNIST train and test splits. Only the training split is consumed by
# the visible cells below; x_test / y_test are currently unused here.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

## Build the tf.data Input Pipeline

In [None]:
def make_datasets(x, y, batch_size=32, shuffle_buffer_size=100):
    """Build a shuffled, batched, prefetched ``tf.data`` input pipeline.

    Every image gets a trailing channel axis (e.g. (28, 28) -> (28, 28, 1)) and
    every label is one-hot encoded with depth 10.

    Args:
        x: Array of images, sliced along its first axis. Presumably
            (N, 28, 28) MNIST pixels -- confirm for other inputs.
        y: Integer class labels, one per image. Expected in [0, 10);
            ``tf.one_hot`` yields an all-zero row for out-of-range values.
        batch_size: Examples per batch (formerly hard-coded to 32).
        shuffle_buffer_size: Size of the shuffle buffer. The default of 100
            preserves the original behaviour; a buffer closer to ``len(x)``
            gives a more uniform shuffle at the cost of memory.

    Returns:
        A ``tf.data.Dataset`` yielding ``(images, one_hot_labels)`` batches.
    """
    def _preprocess(image, label):
        # Add the channel axis the (28, 28, 1) model input expects, and one-hot
        # the label to match CategoricalCrossentropy / CategoricalAccuracy.
        return image[..., tf.newaxis], tf.one_hot(label, depth=10)

    ds = tf.data.Dataset.from_tensor_slices((x, y))
    ds = ds.map(_preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds = ds.shuffle(shuffle_buffer_size).batch(batch_size)
    ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
    return ds
    
# Training pipeline of (images with a channel axis, one-hot labels) batches.
ds = make_datasets(x=x_train, y=y_train)

## Build the Model

In [None]:
# Preprocessing applied inside the model: multiply pixel values by 1/255
# (maps 0..255 to 0..1).
preprocessing_layer = tf.keras.models.Sequential(
    layers=[tf.keras.layers.experimental.preprocessing.Rescaling(scale=1.0 / 255)]
)

# simple CNN model
def get_model():
    """Build a small convolutional classifier for 28x28 single-channel images.

    The graph is: rescaling -> three 3x3 ReLU convolutions (32, 64, 64 filters),
    with 2x2 max pooling after the first two -> flatten -> Dense(64, relu) ->
    Dense(10, softmax).

    Returns:
        An uncompiled functional ``Model`` whose outputs are softmax
        probabilities, not logits.
    """
    inputs = Input(shape=(28, 28, 1))
    hidden = preprocessing_layer(inputs)

    # (filters, followed_by_pooling) for each convolution, in creation order.
    conv_specs = ((32, True), (64, True), (64, False))
    for n_filters, pool_after in conv_specs:
        hidden = Conv2D(filters=n_filters, kernel_size=(3, 3), activation='relu')(hidden)
        if pool_after:
            hidden = MaxPooling2D((2, 2))(hidden)

    hidden = Flatten()(hidden)
    hidden = Dense(64, activation='relu')(hidden)
    outputs = Dense(10, activation='softmax')(hidden)

    return Model(inputs=inputs, outputs=outputs)

# Build the CNN once at module level (train_step/train use this global) and
# print its layer summary.
model = get_model()
model.summary()

## Training with Gradient Accumulation

In [None]:
epochs = 10
num_accum = 4  # number of batches whose gradients are accumulated per optimizer step

# get_model() ends in Dense(10, activation='softmax'), so the model emits
# probabilities, not logits. from_logits must be False: with True, the loss
# would apply a second softmax to already-normalized outputs, distorting both
# the loss value and its gradients.
loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=False)
optimizer = tf.keras.optimizers.Adam()
train_acc_metric = tf.keras.metrics.CategoricalAccuracy()

In [None]:
@tf.function
def train_step(x, y):
    """Compute gradients for one batch without applying them.

    Runs a forward pass of the global ``model``, evaluates ``loss_fn``,
    differentiates it w.r.t. ``model.trainable_weights`` and updates
    ``train_acc_metric``. Applying the gradients is left to the caller so
    they can be accumulated across batches.

    Returns:
        A ``(gradients, loss_value)`` pair; ``gradients`` is ordered like
        ``model.trainable_weights``.
    """
    with tf.GradientTape() as tape:
        # The model ends in softmax, so these are probabilities, not logits.
        predictions = model(x)
        loss_value = loss_fn(y, predictions)
    grads = tape.gradient(loss_value, model.trainable_weights)

    train_acc_metric.update_state(y, predictions)
    return grads, loss_value

def _zero_like_trainable_weights():
    """Return one zero tensor per entry of ``model.trainable_weights``."""
    return [tf.zeros_like(weight) for weight in model.trainable_weights]


def _apply_accumulated_gradients(summed_gradients, num_batches):
    """Average ``summed_gradients`` over ``num_batches`` and take one optimizer step."""
    averaged = [grad / num_batches for grad in summed_gradients]
    optimizer.apply_gradients(zip(averaged, model.trainable_weights))


def train():
    """Train ``model`` on ``ds`` for ``epochs`` epochs using gradient accumulation.

    The gradients of ``num_accum`` consecutive batches are summed, averaged
    and applied in a single optimizer step. This emulates an effective batch
    size of ``num_accum`` times the dataset batch size. If an epoch ends with
    fewer than ``num_accum`` batches pending, those are averaged over their
    actual count and applied too, so no batch's gradients are discarded.

    Prints the loss every 100 steps and the training accuracy at the end of
    each epoch, then resets ``train_acc_metric``.
    """
    for epoch in range(epochs):
        print(f'################ Start of epoch: {epoch} ################')
        # Running sum of gradients since the last optimizer update.
        accumulation_gradients = _zero_like_trainable_weights()
        pending_batches = 0

        for step, (batch_x_train, batch_y_train) in enumerate(ds):
            gradients, loss_value = train_step(batch_x_train, batch_y_train)

            # Add the current batch *before* deciding whether to update, so
            # every batch contributes exactly once to some optimizer step.
            accumulation_gradients = [
                accum_grad + grad
                for accum_grad, grad in zip(accumulation_gradients, gradients)
            ]
            pending_batches += 1

            if pending_batches == num_accum:
                # Apply the averaged *accumulated* gradients, not merely the
                # gradients of the current batch.
                _apply_accumulated_gradients(accumulation_gradients, pending_batches)
                accumulation_gradients = _zero_like_trainable_weights()
                pending_batches = 0

            if step % 100 == 0:
                print(f"Loss at Step: {step} : {float(loss_value):.4f}")

        # Flush a trailing partial window so the epoch's last batches are not
        # silently dropped.
        if pending_batches:
            _apply_accumulated_gradients(accumulation_gradients, pending_batches)

        train_acc = train_acc_metric.result()
        print(f'Accuracy : {(float(train_acc) * 100):.4f}%')
        train_acc_metric.reset_states()
        
# Start training: run the custom training loop defined in train() using the
# module-level model, dataset, optimizer and metric.
train()