In [None]:
# Custom model and custom training loop for the MNIST, Fashion-MNIST and CIFAR datasets

# Data loading
import pandas as pd
from matplotlib import pyplot as plt
import tensorflow as tf
# Load the first 30,000 rows of the sample MNIST CSV (fewer if the file is shorter).
# NOTE(review): read_csv is called without header=None, so the file's first line
# becomes the column names; if the file has no header row (the '6' target column
# name used below suggests so) one sample is consumed as a header — confirm.
df = pd.read_csv('/content/sample_data/mnist_train_small.csv').iloc[:30000]

def convert_and_shuflle(df, Target_class, buffer_size=2500, seed=40, batch_size=32):
    """Turn a labelled DataFrame into batched train/test/validation datasets.

    The ``Target_class`` column becomes the label; every other column is a
    feature and is scaled by 1/255. Rows are shuffled once into a fixed order
    and batched. The first fifth of the batches is the test set, the next
    fifth the validation set, and the rest the training set.

    Args:
        df: DataFrame with one label column plus feature columns.
        Target_class: name of the label column in ``df``.
        buffer_size: shuffle buffer size, in rows.
        seed: seed used for both shuffles.
        batch_size: rows per batch.

    Returns:
        ``(train_set, test_set, valid_set)``: ``tf.data.Dataset`` objects that
        yield ``(features, labels)`` float32 batches. The training set is
        reshuffled (among training rows only) every time it is iterated.
    """
    y = df[Target_class]
    X = df.drop(Target_class, axis=1) / 255.0
    X_tens = tf.convert_to_tensor(X, dtype=tf.float32)
    y_tens = tf.convert_to_tensor(y, dtype=tf.float32)
    dataset = tf.data.Dataset.from_tensor_slices((X_tens, y_tens))

    # Freeze the shuffle order. With the default reshuffle_each_iteration=True,
    # every new iteration (each derived take/skip pipeline, each .repeat() pass)
    # gets a different permutation, so the splits below would overlap and
    # validation/test rows would leak into training.
    dataset = dataset.shuffle(
        buffer_size=buffer_size, seed=seed, reshuffle_each_iteration=False
    ).batch(batch_size)

    n_batches = tf.data.experimental.cardinality(dataset)
    test_size = n_batches // 5
    test_set = dataset.take(test_size)
    valid_set = dataset.skip(test_size).take(test_size)
    train_set = dataset.skip(2 * test_size)

    # Now that the training rows are fixed, reshuffle them every epoch.
    # This cannot pull in held-out rows.
    train_set = (
        train_set.unbatch()
        .shuffle(buffer_size=buffer_size, seed=seed)
        .batch(batch_size)
        .prefetch(1)
    )
    test_set = test_set.cache()
    valid_set = valid_set.cache()
    return (train_set, test_set, valid_set)

# Build the train/test/validation splits, using the column named '6' as the label.
# NOTE(review): presumably '6' is the first sample's label that read_csv promoted
# to a column name (no header row in the CSV) — confirm before reusing this call.
train_set, test_set, valid_set = convert_and_shuflle(df, Target_class='6')

# Custom model definition
class Model(tf.keras.Model):
  """Fully connected 10-class classifier with built-in regularization.

  Layers: Flatten -> Gaussian input noise -> Dense(128, relu) -> BatchNorm
  -> Dropout(0.4) -> Dense(64, relu) -> BatchNorm -> Dropout(0.4)
  -> Dense(32, relu) -> Dense(10, softmax). The output is already a
  probability distribution, so pair it with a loss built with
  ``from_logits=False``.

  Noise, dropout and batch normalization only act as regularizers when the
  model is called with ``training=True``.
  """

  def __init__(self):
    super().__init__()
    self.dense1 = tf.keras.layers.Dense(128, activation='relu', kernel_regularizer=None)
    self.dense2 = tf.keras.layers.Dense(64, activation='relu', kernel_regularizer=None)
    self.dense3 = tf.keras.layers.Dense(32, activation='relu', kernel_regularizer=None)
    self.dense4 = tf.keras.layers.Dense(10, activation='softmax')
    # A single Dropout layer is reused after both hidden blocks; it has no weights.
    self.dropout = tf.keras.layers.Dropout(0.4)
    self.batch_norm1 = tf.keras.layers.BatchNormalization()
    self.batch_norm2 = tf.keras.layers.BatchNormalization()
    self.flatten = tf.keras.layers.Flatten()

  def gaussian_noise(self, x):
    """Return ``x`` plus zero-mean Gaussian noise with stddev 0.1."""
    return x + tf.random.normal(tf.shape(x), stddev=0.1)

  def call(self, x, training=None):
    """Run the forward pass.

    Args:
      x: input batch; flattened to (batch, features) first.
      training: pass True while fitting. This enables input noise and
        dropout, and makes batch normalization use (and update) batch
        statistics. None/False gives deterministic inference.

    Returns:
      A (batch, 10) tensor of class probabilities.
    """
    x = self.flatten(x)
    if training:
      # Input noise is a training-only augmentation. Applying it at inference
      # would make predictions and validation metrics random.
      x = self.gaussian_noise(x)
    x = self.dense1(x)
    x = self.batch_norm1(x, training=training)
    x = self.dropout(x, training=training)
    x = self.dense2(x)
    x = self.batch_norm2(x, training=training)
    x = self.dropout(x, training=training)
    x = self.dense3(x)
    return self.dense4(x)

# --- Custom training loop ---

class Custom_Loop:
  """Hand-written training loop that records per-epoch loss and accuracy.

  Args:
    model: Keras model whose output is compared with labels by ``loss_fn``.
    optimizer_: optimizer used to apply the gradients.
    loss_fn: callable ``loss_fn(y_true, y_pred)`` returning a scalar loss.
    train_set: dataset yielding ``(x, y)`` training batches.
    valid_set: dataset yielding ``(x, y)`` validation batches.
    metrics_: accepted for interface compatibility but unused; accuracy is
      tracked with ``SparseCategoricalAccuracy``.
  """

  def __init__(self, model, optimizer_, loss_fn, train_set, valid_set, metrics_):
    self.model = model
    self.optimizer = optimizer_
    self.loss_fn = loss_fn
    self.train_set = train_set
    self.valid_set = valid_set
    # Per-epoch results (metric result tensors), appended by loop().
    self.train_loss_history = []
    self.valid_loss_history = []
    self.train_acc_history = []
    self.valid_acc_history = []
    # Running metrics, reset at the end of every epoch.
    self.train_loss_metric = tf.keras.metrics.Mean()
    self.valid_loss_metric = tf.keras.metrics.Mean()
    self.train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
    self.valid_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()

  @tf.function
  def training_step(self, x, y, x_valid, y_valid):
    """Score one validation batch, then take one optimizer step on (x, y).

    Updates the running loss/accuracy metrics for both batches.
    """
    # Validation pass in inference mode: no dropout or input noise, and
    # BatchNorm uses its moving statistics.
    y_valid_pred = self.model(x_valid, training=False)
    loss_valid = self.loss_fn(y_valid, y_valid_pred)

    with tf.GradientTape() as tape:
      # training=True activates the regularizers and lets BatchNorm
      # update its moving statistics.
      y_pred = self.model(x, training=True)
      # model.losses holds any regularization penalties the model adds.
      loss = self.loss_fn(y, y_pred) + tf.reduce_sum(self.model.losses)
    trainable_var = self.model.trainable_variables
    grads = tape.gradient(loss, trainable_var)
    # Use the optimizer this loop was built with, not a module-level global.
    self.optimizer.apply_gradients(zip(grads, trainable_var))

    self.train_loss_metric.update_state(loss)
    self.valid_loss_metric.update_state(loss_valid)
    self.train_acc_metric.update_state(y, y_pred)
    self.valid_acc_metric.update_state(y_valid, y_valid_pred)

  def loop(self, epochs, steps_per_epoch, patience=5):
    """Train for up to ``epochs`` epochs of ``steps_per_epoch`` steps each.

    After each epoch the mean metrics are appended to the history lists and
    printed. Early stopping: if the epoch's validation loss fails to beat the
    best value so far for ``patience`` consecutive epochs, training stops.
    Pass ``patience=None`` to always run all ``epochs``.
    """
    self.best_loss = float('inf')
    # Remaining non-improving epochs before stopping (attribute name kept
    # for code that already reads it).
    self.paitance = patience
    # Create the iterators once so each epoch continues through the data
    # instead of replaying the same leading batches.
    train_iter = iter(self.train_set.repeat())
    valid_iter = iter(self.valid_set.repeat())

    for epoch in range(epochs):
      for _ in range(steps_per_epoch):
        x, y = next(train_iter)
        x_valid, y_valid = next(valid_iter)
        self.training_step(x, y, x_valid, y_valid)

      train_loss = self.train_loss_metric.result()
      valid_loss = self.valid_loss_metric.result()
      train_acc = self.train_acc_metric.result()
      valid_acc = self.valid_acc_metric.result()
      self.train_loss_history.append(train_loss)
      self.valid_loss_history.append(valid_loss)
      self.train_acc_history.append(train_acc)
      self.valid_acc_history.append(valid_acc)

      print(f"Epoch {epoch+1} - "
            f"Train Loss: {float(train_loss):.4f}, "
            f"Val Loss: {float(valid_loss):.4f}, "
            f"Train Acc: {float(train_acc):.4f}, "
            f"Val Acc: {float(valid_acc):.4f}")

      for metric in (self.train_loss_metric, self.valid_loss_metric,
                     self.train_acc_metric, self.valid_acc_metric):
        metric.reset_state()

      if patience is None:
        continue
      # Checked once per epoch so that patience counts epochs, not steps.
      epoch_val_loss = float(valid_loss)
      if epoch_val_loss < self.best_loss:
        self.best_loss = epoch_val_loss
        self.paitance = patience
      else:
        self.paitance -= 1
        if self.paitance <= 0:
          print(f"Early stopping at epoch {epoch+1}: no validation-loss "
                f"improvement in {patience} epochs (best {self.best_loss:.4f}).")
          break

  def history(self):
    """Return the recorded histories as four lists of Python floats.

    Order: (train_loss, valid_loss, train_acc, valid_acc).
    """
    train_loss = [float(x.numpy()) for x in self.train_loss_history]
    valid_loss = [float(x.numpy()) for x in self.valid_loss_history]
    train_acc = [float(x.numpy()) for x in self.train_acc_history]
    valid_acc = [float(x.numpy()) for x in self.valid_acc_history]
    return (train_loss, valid_loss, train_acc, valid_acc)




# Training setup: the model ends in softmax, so the loss takes probabilities.
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
optimizer = tf.keras.optimizers.AdamW(learning_rate=0.001, weight_decay=1e-4)
metrics = tf.keras.metrics.SparseCategoricalAccuracy()
model = Model()
loop = Custom_Loop(
    model=model,
    optimizer_=optimizer,
    loss_fn=loss,
    train_set=train_set,
    valid_set=valid_set,
    metrics_=metrics,
)
loop.loop(epochs=42, steps_per_epoch=300)

train_loss, valid_loss, train_acc, valid_acc = loop.history()

# One entry per 2x2 subplot (filled left-to-right, top-to-bottom); each entry
# lists the (curve, legend label) pairs drawn on that subplot, in order.
panels = [
    [(train_loss, 'train_loss'), (valid_loss, 'valid_loss')],
    [(train_acc, 'train_acc'), (valid_acc, 'valid_acc')],
    [(train_loss, 'train_loss'), (train_acc, 'train_acc')],
    [(valid_loss, 'valid_loss'), (valid_acc, 'valid_acc')],
]

plt.figure(figsize=(15, 15))
for position, curves in enumerate(panels, start=1):
    plt.subplot(2, 2, position)
    for values, label_text in curves:
        plt.plot(values, label=label_text)
    plt.legend()
plt.show()
