In [None]:
import os

import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt

In [None]:
#from google.colab import drive
#drive.mount('/content/drive')

In [None]:
#!ls /content/drive/MyDrive/Blackboxes

In [None]:
# Load the CIFAR-10 train/test splits: (images, labels) for each split.
(train_images, train_labels), (test_images, test_labels) = (
    tf.keras.datasets.cifar10.load_data()
)

# Shape of one input image (height, width, channels) and number of target classes.
input_shape = (32, 32, 3)
num_labels = 10

def preprocess_images(images, target_shape=input_shape):
    """Reshape a batch of images and scale pixel values into [0, 1].

    Args:
        images: array-like whose first axis is the batch axis; the remaining
            elements of each sample are reshaped to ``target_shape``.
        target_shape: per-image shape to reshape to (defaults to ``input_shape``).
            Previously this argument was accepted but ignored and (32, 32, 3)
            was always used.

    Returns:
        float32 array of shape ``(batch,) + tuple(target_shape)``, values divided by 255.
    """
    images = np.asarray(images)
    reshaped = images.reshape((images.shape[0], *target_shape))
    # Cast before dividing so the result is float32 rather than float64.
    return reshaped.astype('float32') / 255.

# Scale both splits to float32 in [0, 1] with shape (N, 32, 32, 3).
train_images = preprocess_images(train_images)
test_images = preprocess_images(test_images)

train_size = train_images.shape[0]
test_size = test_images.shape[0]
batch_size = 128

# Image-only (unlabelled) tf.data pipelines. NOTE(review): the training cell below
# calls fit() directly on the NumPy arrays, so these datasets are not used for training.
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(train_size).batch(batch_size)
# Evaluation data needs no shuffling: keep a deterministic order and avoid filling
# a test_size-element shuffle buffer on every pass.
test_dataset = tf.data.Dataset.from_tensor_slices(test_images).batch(batch_size)

In [None]:
class ResNet(tf.keras.Model):
    """Classifier made of three residual stages followed by a flatten + dense head.

    Each stage computes ``x + block(x)``. The residual block must therefore return
    a tensor that can be added element-wise to its input (in practice, the same shape).

    Args:
        model: Keras Sequential or functional model used as the residual block.
            The first stage uses it as-is. The second and third stages use
            independent copies made with ``tf.keras.models.clone_model``: same
            architecture, new weights. Previously all three stages referenced the
            same object and so silently shared one set of weights.
        num_classes: number of logits produced by the final Dense layer (default 10).
    """

    def __init__(self, model, num_classes=10):
        super().__init__()
        self.model1 = model
        self.model2 = tf.keras.models.clone_model(model)
        self.model3 = tf.keras.models.clone_model(model)
        self.flatten = tf.keras.layers.Flatten()
        self.dense = tf.keras.layers.Dense(num_classes)

    def call(self, input, *args, **kwargs):
        """Apply the three residual stages and return unnormalised class logits.

        Extra positional/keyword arguments (e.g. ``training``) are forwarded to
        every residual block.
        """
        x = input
        for block in (self.model1, self.model2, self.model3):
            x = x + block(x, *args, **kwargs)
        # The original computed the logits but never returned them, so call() returned None.
        return self.dense(self.flatten(x))
        
        
# Hyper-parameters of the two convolutions inside one residual block.
# The ResNet adds each block's output to its input (identity skip), so the block
# must preserve shape. Both convolutions therefore use stride 1 and 'same'
# padding, and the last one outputs as many channels as the input image
# (input_shape[-1]).
# The previous values (64 output channels, a stride of 2, 'valid' padding) turned
# a (32, 32, 3) input into a (15, 15, 64) tensor, which cannot be added to the input.
filters = [32, input_shape[-1]]
kernel_size = [2, 3]
strides = [1, 1]


residualBlock = tf.keras.Sequential([
    tf.keras.layers.Conv2D(filters=filters[0], kernel_size=kernel_size[0], strides=strides[0],
                           padding='same', activation='relu'),
    tf.keras.layers.Conv2D(filters=filters[1], kernel_size=kernel_size[1], strides=strides[1],
                           padding='same', activation='relu'),
])

resNet = ResNet(residualBlock)

# The model outputs raw logits, hence from_logits=True; labels are integer class ids.
resNet.compile(optimizer='adam',
               loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
               metrics=['accuracy'])


In [None]:
# Train the classifier. The test split doubles as validation data for monitoring.
EPOCHS = 50
history = resNet.fit(train_images, train_labels, epochs=EPOCHS,
                     validation_data=(test_images, test_labels))

# Learning curves. The y-axis spans the full [0, 1] range because early-epoch
# accuracy on a 10-class problem can be below 0.5. The previous ylim([0.5, 1])
# would have clipped those points.
fig, ax = plt.subplots()
ax.plot(history.history['accuracy'], label='accuracy')
ax.plot(history.history['val_accuracy'], label='val_accuracy')
ax.set(title='ResNet accuracy on CIFAR-10', xlabel='Epoch', ylabel='Accuracy', ylim=(0, 1))
ax.legend(loc='lower right')
plt.show()

test_loss, test_acc = resNet.evaluate(test_images, test_labels, verbose=2)

In [None]:
# Saving targets Google Drive, which only exists on Colab after
# drive.mount('/content/drive') has run. That call is commented out near the top.
# Without the mount, save() would write into a plain local /content/drive/...
# directory that is lost with the runtime and may prevent a later mount at that path.
DRIVE_ROOT = '/content/drive/MyDrive'
SAVE_PATH = '/content/drive/MyDrive/Blackboxes/ResNet_black_box'
if not os.path.isdir(DRIVE_ROOT):
    raise FileNotFoundError(
        f"{DRIVE_ROOT} does not exist - mount Google Drive before saving the model to {SAVE_PATH}"
    )
resNet.save(SAVE_PATH)