<a href="https://colab.research.google.com/github/Swapnadeep1998/Custom_and_Distributed_Tensorflow_Training/blob/main/Customize_Tensorflow_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
from tensorflow import keras
import tensorflow_datasets as tfds
from tqdm import tqdm
import itertools
import numpy as np

In [None]:
!pip install -q "tqdm>=4.36.1"


## Define Model

In [None]:
def base_model():
    """Build the fully connected Fashion-MNIST classifier.

    Architecture: a flat 784-feature input named 'clothing', two 64-unit ReLU
    hidden layers ('dense_1', 'dense_2'), and a 10-unit softmax output layer
    named 'predictions'.

    Returns:
        keras.Model: the (uncompiled) functional model.
    """
    flat_pixels = keras.Input(shape=(784,), name='clothing')
    hidden = flat_pixels
    for layer_name in ('dense_1', 'dense_2'):
        hidden = keras.layers.Dense(64, activation='relu', name=layer_name)(hidden)
    # Softmax output: the model emits class probabilities, not raw logits.
    class_probs = keras.layers.Dense(10, activation='softmax', name='predictions')(hidden)
    return keras.Model(inputs=flat_pixels, outputs=class_probs)
    

## Data Pipeline

In [None]:
# Load both Fashion-MNIST splits; each element is a feature dict that
# format_image reads via its "image" and "label" keys.
train_data, test_data = (
    tfds.load("fashion_mnist", split=split_name) for split_name in ("train", "test")
)


def format_image(data):
    """Turn one dataset example into a (flat float vector, label) pair.

    The image is flattened to rank 1, cast to float32 and divided by 255.
    NOTE(review): presumably a 28x28x1 uint8 image, giving the 784 features
    base_model expects — confirm against the dataset spec.

    Args:
        data: feature dict with "image" and "label" entries.

    Returns:
        tuple: (scaled flat image tensor, label) — the label is passed through unchanged.
    """
    pixels = tf.cast(tf.reshape(data["image"], [-1]), tf.float32)
    return pixels / 255.0, data["label"]

AUTOTUNE = tf.data.experimental.AUTOTUNE

# Parallel map; tf.data keeps element order deterministic by default.
train_data = train_data.map(format_image, num_parallel_calls=AUTOTUNE)
test_data = test_data.map(format_image, num_parallel_calls=AUTOTUNE)

batch_size = 64
# prefetch overlaps preparing the next batch with the current training step.
# It does not change which elements are produced or their order, and it keeps
# the dataset's cardinality.
train = train_data.shuffle(buffer_size=1024).batch(batch_size).prefetch(AUTOTUNE)
test = test_data.batch(batch_size=batch_size).prefetch(AUTOTUNE)

In [None]:
# Display names for the 10 classes.
# NOTE(review): presumably index i names label id i, following the standard
# Fashion-MNIST label order — confirm before using for plots.
class_names = [
    'T-shirt',
    'Pants',
    'Pullover shirt',
    'Dress',
    'Coat',
    'Sandal',
    'Shirt',
    'Sneaker',
    'Bag',
    'Ankle Boot',
]

## Define Loss and Optimizer

In [None]:
# base_model ends in softmax, so the loss receives probabilities rather than
# logits; from_logits=False spells out the default explicitly.
loss_object = keras.losses.SparseCategoricalCrossentropy(from_logits=False)
optimizer = keras.optimizers.Adam()

## Define Metrics

In [None]:
# Two independent accumulators so training and validation accuracy never mix.
train_acc_metric, val_acc_metric = (
    tf.keras.metrics.SparseCategoricalAccuracy(),
    tf.keras.metrics.SparseCategoricalAccuracy(),
)

## Define Custom Training Loop

In [None]:
def apply_gradient(optimizer, model, x, y):
    """Run one optimization step on a single batch.

    Args:
        optimizer: Keras optimizer used to apply the computed gradients.
        model: Keras model to train; it is called in training mode.
        x: input batch.
        y: integer class labels for ``x`` (as required by the sparse
            categorical cross-entropy in ``loss_object``).

    Returns:
        tuple: (model output for ``x`` computed before the update, loss tensor).
        Despite the name ``logits``, with ``base_model()`` the output is softmax
        probabilities.
    """
    with tf.GradientTape() as tape:
        # training=True so layers that behave differently during training
        # (dropout, batch norm) run in training mode.
        logits = model(x, training=True)
        loss_value = loss_object(y_true=y, y_pred=logits)
    gradients = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))
    return logits, loss_value

In [None]:
def train_data_for_one_epoch():
    """Train the global ``model`` for one pass over the global ``train`` dataset.

    Each batch goes through ``apply_gradient`` and updates ``train_acc_metric``.
    A tqdm progress bar is shown, and its description reports the current
    step's loss.

    Returns:
        list: the per-batch loss tensors, in step order.
    """
    # Ask tf.data for the batch count instead of iterating the entire dataset
    # just to count it. Negative values mean unknown/infinite cardinality; in
    # that case tqdm runs without a total.
    num_batches = int(tf.data.experimental.cardinality(train))
    total = num_batches if num_batches >= 0 else None

    losses = []
    # The context manager closes the bar (even on error), so bars from earlier
    # epochs do not linger half-finished in the output. '{desc}' is needed in
    # bar_format, otherwise set_description() below is never displayed.
    with tqdm(total=total, position=0, leave=True,
              bar_format='{desc} {bar}| {n_fmt}/{total_fmt} ') as pbar:
        for step, (x_batch_train, y_batch_train) in enumerate(train):
            logits, loss_value = apply_gradient(optimizer, model, x_batch_train, y_batch_train)
            losses.append(loss_value)
            train_acc_metric(y_batch_train, logits)
            pbar.set_description("Training loss for step %s : %.4f" % (int(step), float(loss_value)))
            pbar.update()
    return losses

## Define Validation Function

In [None]:
def perform_validation():
    """Evaluate the global ``model`` on every batch of the global ``test`` dataset.

    Updates ``val_acc_metric`` as a side effect. No weights are modified.

    Returns:
        list: the per-batch validation loss tensors, in batch order.
    """
    losses = []
    for x_val, y_val in test:
        # training=False: inference mode, the counterpart of training=True in apply_gradient.
        val_logits = model(x_val, training=False)
        val_loss = loss_object(y_true=y_val, y_pred=val_logits)
        losses.append(val_loss)
        val_acc_metric(y_val, val_logits)
    return losses


## Model Training

In [None]:
model = base_model()
# Rebuild the optimizer together with the model. If this cell is re-run, the
# optimizer from the earlier cell is already built for the previous model's
# variables. Newer Keras optimizers reject that, and older ones would carry
# stale Adam state.
optimizer = keras.optimizers.Adam()

# Start from clean metric state, in case a previous run of this cell was
# interrupted mid-epoch.
train_acc_metric.reset_state()
val_acc_metric.reset_state()

epochs = 10
epochs_val_losses, epochs_train_losses = [], []
for epoch in range(epochs):
    print('Start of epoch %d' % (epoch,))
    losses_train = train_data_for_one_epoch()
    train_acc = train_acc_metric.result()

    losses_val = perform_validation()
    val_acc = val_acc_metric.result()

    losses_train_mean = np.mean(losses_train)
    losses_val_mean = np.mean(losses_val)
    epochs_val_losses.append(losses_val_mean)
    epochs_train_losses.append(losses_train_mean)
    print(f"train acc: {train_acc}, val_acc: {val_acc}")
    # reset_state() is the current API name; reset_states() was removed in Keras 3.
    train_acc_metric.reset_state()
    val_acc_metric.reset_state()


Start of epoch 0


████████████████████| 938/938 

train acc: 0.8359827995300293, val_acc: 0.8450999855995178
Start of epoch 1


████████████████████| 938/938 
█████████▉█████████▉| 937/938 

train acc: 0.8668000102043152, val_acc: 0.8604999780654907
Start of epoch 2


█████████▉█████████▉| 937/938 

train acc: 0.878166675567627, val_acc: 0.864799976348877
Start of epoch 3


████████████████████| 938/938 
████████████████████| 938/938 


train acc: 0.8854833245277405, val_acc: 0.8733999729156494
Start of epoch 4


█████████▉█████████▉| 937/938 

train acc: 0.8909500241279602, val_acc: 0.867900013923645
Start of epoch 5


████████████████████| 938/938 
█████████▉█████████▉| 937/938 

train acc: 0.8935666680335999, val_acc: 0.8752999901771545
Start of epoch 6


████████████████████| 938/938 
████████████████████| 938/938 


train acc: 0.8966666460037231, val_acc: 0.8654000163078308
Start of epoch 7


█████████▉█████████▉| 937/938 

train acc: 0.9001166820526123, val_acc: 0.8773999810218811
Start of epoch 8


████████████████████| 938/938 
███████▉  ███████▉  | 749/938 
████████████████████| 938/938 
█████████▉█████████▉| 937/938 

train acc: 0.904116690158844, val_acc: 0.8738999962806702
Start of epoch 9


████████████████████| 938/938 
█████████▉█████████▉| 937/938 

train acc: 0.9049166440963745, val_acc: 0.8792999982833862
