In [7]:
import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np
import tensorflow_addons as tfa
import warnings
warnings.filterwarnings('ignore')


TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.
Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). 

For more information see: https://github.com/tensorflow/addons/issues/2807 

 The versions of TensorFlow you are currently using is 2.16.1 and is not supported. 
Some things might work, some things might not.
If you were to encounter a bug, do not file an issue.
If you want to make sure you're using a tested and supported configuration, either change the TensorFlow version or the TensorFlow Addons's version. 
You can find the compatibility matrix in TensorFlow Addon's readme:
https://github.com/tensorflow/addons


ModuleNotFoundError: No module named 'keras.src.engine'

In [6]:
# !pip install --upgrade tensorflow

In [10]:
import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np
from tensorflow.keras.utils import plot_model

# Define Mish activation function
class Mish(layers.Layer):
    def __init__(self):
        super(Mish, self).__init__()

    def call(self, inputs):
        return inputs * tf.math.tanh(tf.math.softplus(inputs))

# Define ResNet block with Mish activation
def resnet_block(inputs, filters, strides=(1, 1)):
    x = layers.Conv2D(filters, (3, 3), strides=strides, padding='same')(inputs)
    x = layers.BatchNormalization()(x)
    x = Mish()(x)
    x = layers.Conv2D(filters, (3, 3), padding='same')(x)
    x = layers.BatchNormalization()(x)
    
    shortcut = inputs
    if strides != (1, 1) or inputs.shape[-1] != filters:
        shortcut = layers.Conv2D(filters, (1, 1), strides=strides, padding='same')(inputs)
        shortcut = layers.BatchNormalization()(shortcut)

    x = layers.Add()([x, shortcut])
    x = Mish()(x)
    return x

# Define Elastic block (just as an example, it could be a drop-in replacement for resnet_block)
def elastic_block(inputs, filters, strides=(1, 1)):
    x = layers.Conv2D(filters, (3, 3), strides=strides, padding='same', kernel_regularizer=tf.keras.regularizers.l2(0.001))(inputs)
    x = layers.BatchNormalization()(x)
    x = Mish()(x)
    x = layers.Conv2D(filters, (3, 3), padding='same', kernel_regularizer=tf.keras.regularizers.l2(0.001))(x)
    x = layers.BatchNormalization()(x)
    
    shortcut = inputs
    if strides != (1, 1) or inputs.shape[-1] != filters:
        shortcut = layers.Conv2D(filters, (1, 1), strides=strides, padding='same')(+inputs)
        shortcut = layers.BatchNormalization()(shortcut)

    x = layers.Add()([x, shortcut])
    x = Mish()(x)
    return x

# Define ResNet-18 architecture with Mish activation and Elastic block
def ResNet18(input_shape, num_classes):
    inputs = layers.Input(shape=input_shape)
    x = layers.Conv2D(64, (7, 7), strides=(2, 2), padding='same')(inputs)
    x = layers.BatchNormalization()(x)
    x = Mish()(x)
    x = layers.MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)

    x = resnet_block(x, 64)
    x = elastic_block(x, 64)

    x = resnet_block(x, 128, strides=(2, 2))
    x = resnet_block(x, 128)

    x = elastic_block(x, 256, strides=(2, 2))
    x = resnet_block(x, 256)

    x = resnet_block(x, 512, strides=(2, 2))
    x = elastic_block(x, 512)

    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(num_classes, activation='softmax')(x)

    model = models.Model(inputs, x)
    return model

# Create and compile the model
model = ResNet18(input_shape=(224, 224, 3), num_classes=1000)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Visualize the model structure
plot_model(model, to_file='model_structure.png', show_shapes=True, show_layer_names=True)

# Set up TensorBoard callback
log_dir = "logs/fit/"
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

# Dummy data for demonstration purposes
x_train = np.random.rand(10, 224, 224, 3)
y_train = tf.keras.utils.to_categorical(np.random.randint(10, size=(10, 1)), num_classes=1000)

# Train the model
model.fit(x_train, y_train, epochs=1, callbacks=[tensorboard_callback])

# Launch TensorBoard
%load_ext tensorboard
%tensorboard --logdir logs/fit

You must install pydot (`pip install pydot`) for `plot_model` to work.
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m34s[0m 34s/step - accuracy: 0.0000e+00 - loss: 8.7857
The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6006 (pid 10344), started 0:46:35 ago. (Use '!kill 10344' to kill it.)