<a href="https://colab.research.google.com/github/AliBenovaa/IANNwTF_Group24/blob/main/Homework04Group24.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
import datetime
import numpy as np
import tqdm 
import matplotlib as plt


# 1. get mnist from tensorflow_datasets
mnist = tfds.load("mnist", split =["train","test"], as_supervised=True)

train_ds = mnist[0]
val_ds = mnist[1]

train_ds2 =mnist[0]
val_ds2 = mnist[1]


# 2. write function to create the dataset that we want
def preprocess(data, batch_size,condition):
    # image should be float
    data = data.map(lambda x, t: (tf.cast(x, float), t))
    # image should be flattened
    data = data.map(lambda x, t: (tf.reshape(x, (-1,)), t))
    # image vector will here have values between -1 and 1
    data = data.map(lambda x,t: ((x/128.)-1., t))
    # we want to have two mnist images in each example
    # this leads to a single example being ((x1,y1),(x2,y2))
    zipped_ds = tf.data.Dataset.zip((data.shuffle(2000), 
                                     data.shuffle(2000)))
    
    if (condition == "bigger_5"):
        # map ((x1,y1),(x2,y2)) to (x1,x2, x1+x2>=5*) *boolean
        zipped_ds = zipped_ds.map(lambda x1, x2: (x1[0], x2[0], x1[1]+x2[1]>=5))
        # transform boolean target to int
        zipped_ds = zipped_ds.map(lambda x1, x2, t: (x1,x2, tf.cast(t, tf.int32)))
    elif (condition=="minus"):
        # map ((x1,y1),(x2,y2)) to (x1,x2, x1-x2=y1*) *integer
        zipped_ds = zipped_ds.map(lambda x1, x2: (x1[0], x2[0], x1[1]-x2[1]))
        # transform boolean target to int
        zipped_ds = zipped_ds.map(lambda x1, x2, t: (x1,x2, tf.cast(t, tf.float32)))
    
    # batch the dataset
    zipped_ds = zipped_ds.batch(batch_size)
    # prefetch
    zipped_ds = zipped_ds.prefetch(tf.data.AUTOTUNE)
    return zipped_ds

train_ds = preprocess(train_ds, batch_size=32,condition = "bigger_5") 
val_ds = preprocess(val_ds, batch_size=32, condition="bigger_5" ) 

train_ds2 = preprocess(train_ds2, batch_size=32,condition = "minus") 
val_ds2 = preprocess(val_ds2, batch_size=32, condition="minus" )


#checking
for img1, img2, label in train_ds.take(1):
    print(img1.shape, img2.shape, label.shape)





    

Downloading and preparing dataset 11.06 MiB (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to ~/tensorflow_datasets/mnist/3.0.1...


Dl Completed...:   0%|          | 0/4 [00:00<?, ? file/s]

Dataset mnist downloaded and prepared to ~/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.
(32, 784) (32, 784) (32,)


In [17]:
class TwinMNISTModel(tf.keras.Model):

    # 1. constructor
    def __init__(self,optimizer,loss_function,train_ds,test_ds):
        super().__init__()
        # inherit functionality from parent class

        # optimizer, loss function and metrics
        self.metrics_list = [tf.keras.metrics.BinaryAccuracy(),
                             tf.keras.metrics.Mean(name="loss")]
        
        self.optimizer = tf.keras.optimizers.Adam()
        
        self.loss_function = loss_function

      
        
        # layers to be used with activation functions
        self.dense1 = tf.keras.layers.Dense(32, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(32, activation=tf.nn.relu)
        
        self.out_layer = tf.keras.layers.Dense(1,activation=tf.nn.sigmoid)

        
        
    # 2. call method (forward computation)
    def call(self, images, training=False):
        img1, img2 = images
        
        img1_x = self.dense1(img1)
        img1_x = self.dense2(img1_x)
        
        img2_x = self.dense1(img2)
        img2_x = self.dense2(img2_x)
        
        combined_x = tf.concat([img1_x, img2_x ], axis=1)
        
        return self.out_layer(combined_x)

    

    # 3. metrics property
    @property
    def metrics(self):
        return self.metrics_list
        # return a list with all metrics in the model



    # 4. reset all metrics objects
    def reset_metrics(self):
        for metric in self.metrics:
            metric.reset_states()


    #train_step method
    @tf.function
    def train_step(self, data):
        img1, img2, label = data
        
        with tf.GradientTape() as tape:
            output = self((img1, img2), training=True)
            loss = self.loss_function(label, output)
    
            
        gradients = tape.gradient(loss, self.trainable_variables)
        
        self.optimizer.apply_gradients(zip(gradients,self.trainable_variables))
        
        # update the state of the metrics according to loss
        self.metrics[0].update_state(label, output)
        self.metrics[1].update_state(loss)
        
        # return a dictionary with metric names as keys and metric results as values
        return {m.name : m.result() for m in self.metrics}


        

   # 6. test_step method
    @tf.function
    def test_step(self, data):
        img1, img2, label = data
        # same as train step (without parameter updates)
        output = self((img1, img2), training=False)
        loss = self.loss_function(label, output)
        self.metrics[0].update_state(label, output)
        self.metrics[1].update_state(loss)
        
        return {m.name : m.result() for m in self.metrics}




    

In [26]:

def training_loop(subtask,train_ds, test_ds):

  #items for visualization
  train = np.empty(0)
  test = np.empty(0)
  
  
  if subtask == "minus":
    model  = TwinMNISTModel(tf.keras.optimizers.Adam,tf.keras.losses.MeanSquaredError(),train_ds,test_ds)

  elif subtask == "bigger_5":
    model = TwinMNISTModel(tf.keras.optimizers.Adam,tf.keras.losses.BinaryCrossentropy(from_logits=True),train_ds,test_ds)

  for epoch in range(10):
      print(f"Epoch {epoch}:")
      
      # Training:
      
      for data in tqdm.tqdm(train_ds, position=0, leave=True):
          metrics = model.train_step(data)
          
    
          for metric in model.metrics:
              tf.summary.scalar(f"{metric.name}", metric.result(), step=epoch)

      # print the metrics
      print([f"{key}: {value.numpy()}" for (key, value) in metrics.items()])
      #experiments: save into array for later 
      train = np.append(train,metrics["loss"])
      #train_acc = np.append(train_acc,metrics["acc"])
      

      # reset all metrics (requires a reset_metrics method in the model)
      model.reset_metrics()    
      
      # Testing:
      for data in test_ds:
          metrics = model.test_step(data)
  
          for metric in model.metrics:
              tf.summary.scalar(f"{metric.name}", metric.result(), step=epoch)
                  
      print([f"test_{key}: {value.numpy()}" for (key, value) in metrics.items()])
      #experiments: save into array for later 
      test = np.append(test,metrics["loss"])
      #test_acc = np.append(test_acc,metrics["acc"])

      # reset all metrics
      model.reset_metrics()
      print("\n")
  return train,test

In [27]:
import matplotlib.pyplot as plt

# 2. pass arguments to training loop function
training_loop("bigger_5",train_ds,val_ds)
training_loop("minus",train_ds2,val_ds2)






Epoch 0:


100%|██████████| 1875/1875 [00:19<00:00, 97.75it/s] 


['binary_accuracy: 0.840833306312561', 'loss: 0.4720420837402344']
['test_binary_accuracy: 0.8424000144004822', 'test_loss: 0.47081053256988525']


Epoch 1:


100%|██████████| 1875/1875 [00:20<00:00, 91.42it/s] 


['binary_accuracy: 0.8409000039100647', 'loss: 0.47234994173049927']
['test_binary_accuracy: 0.8411999940872192', 'test_loss: 0.47190871834754944']


Epoch 2:


100%|██████████| 1875/1875 [00:20<00:00, 91.51it/s] 


['binary_accuracy: 0.8424500226974487', 'loss: 0.4708012342453003']
['test_binary_accuracy: 0.833299994468689', 'test_loss: 0.4799957871437073']


Epoch 3:


100%|██████████| 1875/1875 [00:20<00:00, 91.51it/s] 


['binary_accuracy: 0.8418833613395691', 'loss: 0.4713667035102844']
['test_binary_accuracy: 0.8355000019073486', 'test_loss: 0.47779932618141174']


Epoch 4:


100%|██████████| 1875/1875 [00:40<00:00, 45.77it/s] 


['binary_accuracy: 0.8412833213806152', 'loss: 0.47197553515434265']
['test_binary_accuracy: 0.8431000113487244', 'test_loss: 0.4702114760875702']


Epoch 5:


100%|██████████| 1875/1875 [00:21<00:00, 85.80it/s] 


['binary_accuracy: 0.8423166871070862', 'loss: 0.47093334794044495']
['test_binary_accuracy: 0.8406999707221985', 'test_loss: 0.4725078046321869']


Epoch 6:


100%|██████████| 1875/1875 [00:21<00:00, 86.27it/s] 


['binary_accuracy: 0.840399980545044', 'loss: 0.47284993529319763']
['test_binary_accuracy: 0.8359000086784363', 'test_loss: 0.47730013728141785']


Epoch 7:


100%|██████████| 1875/1875 [00:20<00:00, 91.51it/s] 


['binary_accuracy: 0.8402000069618225', 'loss: 0.47304990887641907']
['test_binary_accuracy: 0.8421000242233276', 'test_loss: 0.47120988368988037']


Epoch 8:


100%|██████████| 1875/1875 [00:19<00:00, 97.50it/s] 


['binary_accuracy: 0.8405333161354065', 'loss: 0.4727165699005127']
['test_binary_accuracy: 0.8388000130653381', 'test_loss: 0.47450461983680725']


Epoch 9:


100%|██████████| 1875/1875 [00:21<00:00, 86.44it/s] 


['binary_accuracy: 0.8407166600227356', 'loss: 0.47253331542015076']
['test_binary_accuracy: 0.8406999707221985', 'test_loss: 0.4728073179721832']


Epoch 0:


100%|██████████| 1875/1875 [00:40<00:00, 45.76it/s] 


['binary_accuracy: 0.11694999784231186', 'loss: 14.519569396972656']
['test_binary_accuracy: 0.12280000001192093', 'test_loss: 14.250250816345215']


Epoch 1:


100%|██████████| 1875/1875 [00:19<00:00, 98.51it/s] 


['binary_accuracy: 0.12470000237226486', 'loss: 14.1488618850708']
['test_binary_accuracy: 0.1371999979019165', 'test_loss: 13.735511779785156']


Epoch 2:


100%|██████████| 1875/1875 [00:17<00:00, 107.29it/s]


['binary_accuracy: 0.12781666219234467', 'loss: 14.09836483001709']
['test_binary_accuracy: 0.12309999763965607', 'test_loss: 14.16234016418457']


Epoch 3:


100%|██████████| 1875/1875 [00:17<00:00, 104.77it/s]


['binary_accuracy: 0.1287333369255066', 'loss: 14.098254203796387']
['test_binary_accuracy: 0.1290999948978424', 'test_loss: 13.978679656982422']


Epoch 4:


100%|██████████| 1875/1875 [00:18<00:00, 104.06it/s]


['binary_accuracy: 0.13178333640098572', 'loss: 14.132402420043945']
['test_binary_accuracy: 0.1306000053882599', 'test_loss: 14.199557304382324']


Epoch 5:


100%|██████████| 1875/1875 [00:20<00:00, 91.51it/s] 


['binary_accuracy: 0.13243333995342255', 'loss: 14.081136703491211']
['test_binary_accuracy: 0.13619999587535858', 'test_loss: 13.93023681640625']


Epoch 6:


100%|██████████| 1875/1875 [00:17<00:00, 108.66it/s]


['binary_accuracy: 0.1336333304643631', 'loss: 13.997218132019043']
['test_binary_accuracy: 0.12939999997615814', 'test_loss: 14.125307083129883']


Epoch 7:


100%|██████████| 1875/1875 [00:20<00:00, 91.51it/s] 


['binary_accuracy: 0.13623332977294922', 'loss: 14.030960083007812']
['test_binary_accuracy: 0.13099999725818634', 'test_loss: 14.196012496948242']


Epoch 8:


100%|██████████| 1875/1875 [00:19<00:00, 96.16it/s] 


['binary_accuracy: 0.13663333654403687', 'loss: 14.093245506286621']
['test_binary_accuracy: 0.1331000030040741', 'test_loss: 14.09048843383789']


Epoch 9:


100%|██████████| 1875/1875 [00:20<00:00, 91.50it/s] 


['binary_accuracy: 0.13893333077430725', 'loss: 14.075468063354492']
['test_binary_accuracy: 0.14350000023841858', 'test_loss: 13.903180122375488']




(array([14.5195694 , 14.14886189, 14.09836483, 14.0982542 , 14.13240242,
        14.0811367 , 13.99721813, 14.03096008, 14.09324551, 14.07546806]),
 array([14.25025082, 13.73551178, 14.16234016, 13.97867966, 14.1995573 ,
        13.93023682, 14.12530708, 14.1960125 , 14.09048843, 13.90318012]))