In [1]:
!pip install mlflow mlflow[extras]



In [2]:
import tensorflow as tf
from tensorflow.keras import layers, Model

In [34]:
class CustomConvolved(Model):
  def __init__(self, filters, kernel_size, strides, padding, relu = False):
    super(CustomConvolved, self).__init__()
    self.conv = layers.Conv2D(filters, kernel_size, strides, padding)
    self.bn = layers.BatchNormalization()

    self.relu = layers.ReLU() if relu else None

  def call(self, inputs, training = False):
    x = self.conv(inputs)
    x = self.bn(x, training = training)
    if self.relu:
      x = self.relu(x)
    return x

class ResidualBlock(Model):
    def __init__(self, filters, downsample=False):
        super(ResidualBlock, self).__init__()

        stride = 2 if downsample else 1

        self.conv1 = CustomConvolved(filters, kernel_size=3, strides=stride, padding="same", relu=True)
        self.conv2 = CustomConvolved(filters, kernel_size=3, strides=1, padding="same", relu=False) ## We do not downsample the second convolution

        ## If downsample, we use a 1x1 convolution
        self.shortcut = None
        if downsample:
            self.shortcut = CustomConvolved(filters, kernel_size=1, strides=stride, padding="same", relu=False)

            ## THIS IS (1, 16, 16, 128)
        self.relu = layers.ReLU()

    def call(self, inputs, training=False):
        x = self.conv1(inputs, training = training)

        x = self.conv2(x, training = training)

        # Residual connection
        if self.shortcut:
          shortcut = self.shortcut(inputs)

          x += shortcut

        x = self.relu(x)
        return x


In [35]:
class ResNet34(Model):
    def __init__(self, num_classes):
        super(ResNet34, self).__init__()

        # Initial convolutional layer
        self.initial_conv = CustomConvolved(filters=64, kernel_size=7, strides=2, padding="same", relu=True)
        self.max_pool = layers.MaxPooling2D(pool_size=3, strides=2, padding="same")

        # Residual blocks following ResNet-34 architecture
        self.layer1 = self._build_residual_block(64, num_blocks=3, downsample=False)
        self.layer2 = self._build_residual_block(128, num_blocks=4, downsample=True)
        self.layer3 = self._build_residual_block(256, num_blocks=6, downsample=True)
        self.layer4 = self._build_residual_block(512, num_blocks=3, downsample=True)

        # Global Average Pooling and Fully Connected Layer
        self.global_avg_pool = layers.GlobalAveragePooling2D()
        self.fc = layers.Dense(num_classes, activation='softmax')

    def _build_residual_block(self, filters, num_blocks, downsample):
        blocks = []
        blocks.append(ResidualBlock(filters, downsample=downsample))  # First block, handles downsampling if required
        for _ in range(1, num_blocks):
            blocks.append(ResidualBlock(filters, downsample=False))  # Rest without downsampling
        return tf.keras.Sequential(blocks)

    def call(self, inputs, training=False):
        x = self.initial_conv(inputs, training=training)
        x = self.max_pool(x)

        x = self.layer1(x, training=training)
        x = self.layer2(x, training=training)
        x = self.layer3(x, training=training)
        x = self.layer4(x, training=training)

        x = self.global_avg_pool(x)
        x = self.fc(x)
        return x





In [42]:
import tensorflow as tf
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical
# Load CIFAR-10 dataset
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# One-hot encode the labels
y_train = to_categorical(y_train, num_classes=10)
y_test = to_categorical(y_test, num_classes=10)

# Normalize the images
x_train, x_test = x_train / 255.0, x_test / 255.0

# Create a dataset
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))

# Shuffle, batch, and prefetch the data for better performance
train_dataset = train_dataset.shuffle(50000).batch(64).prefetch(tf.data.experimental.AUTOTUNE)
test_dataset = test_dataset.batch(64).prefetch(tf.data.experimental.AUTOTUNE)




In [46]:
# This is important for TPU usage in Colab
try:
    # Detect TPU and initialize the distribution strategy
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver()  # Detect TPU
    tf.config.experimental_connect_to_cluster(resolver)  # Connect to TPU cluster
    tf.config.set_logical_device_configuration(
        tf.config.list_physical_devices('TPU')[0],
        [tf.config.LogicalDeviceConfiguration(memory_limit=8192)])  # Adjust memory if needed

    strategy = tf.distribute.TPUStrategy(resolver)  # This enables the use of TPU
    print("TPU found")
except ValueError:
    # If there's no TPU available, print an error
    strategy = tf.distribute.get_strategy()  # Default to CPU/GPU if TPU isn't available
    print("No TPU found. Using CPU/GPU.")
with strategy.scope():
    model = ResNet34(num_classes=10) ## CIFAR10 HAS 10 CLASSES
    model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])
    history = model.fit(train_dataset, epochs=20, validation_data=test_dataset)




No TPU found. Using CPU/GPU.
Epoch 1/20


NotFoundError: Graph execution error:

Detected at node StatefulPartitionedCall defined at (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main

  File "<frozen runpy>", line 88, in _run_code

  File "/usr/local/lib/python3.11/dist-packages/colab_kernel_launcher.py", line 37, in <module>

  File "/usr/local/lib/python3.11/dist-packages/traitlets/config/application.py", line 992, in launch_instance

  File "/usr/local/lib/python3.11/dist-packages/ipykernel/kernelapp.py", line 712, in start

  File "/usr/local/lib/python3.11/dist-packages/tornado/platform/asyncio.py", line 205, in start

  File "/usr/lib/python3.11/asyncio/base_events.py", line 608, in run_forever

  File "/usr/lib/python3.11/asyncio/base_events.py", line 1936, in _run_once

  File "/usr/lib/python3.11/asyncio/events.py", line 84, in _run

  File "/usr/local/lib/python3.11/dist-packages/ipykernel/kernelbase.py", line 510, in dispatch_queue

  File "/usr/local/lib/python3.11/dist-packages/ipykernel/kernelbase.py", line 499, in process_one

  File "/usr/local/lib/python3.11/dist-packages/ipykernel/kernelbase.py", line 406, in dispatch_shell

  File "/usr/local/lib/python3.11/dist-packages/ipykernel/kernelbase.py", line 730, in execute_request

  File "/usr/local/lib/python3.11/dist-packages/ipykernel/ipkernel.py", line 383, in do_execute

  File "/usr/local/lib/python3.11/dist-packages/ipykernel/zmqshell.py", line 528, in run_cell

  File "/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py", line 2975, in run_cell

  File "/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py", line 3030, in _run_cell

  File "/usr/local/lib/python3.11/dist-packages/IPython/core/async_helpers.py", line 78, in _pseudo_sync_runner

  File "/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py", line 3257, in run_cell_async

  File "/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py", line 3473, in run_ast_nodes

  File "/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py", line 3553, in run_code

  File "<ipython-input-46-342298433b57>", line 21, in <cell line: 0>

  File "/usr/local/lib/python3.11/dist-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler

  File "/usr/local/lib/python3.11/dist-packages/keras/src/backend/tensorflow/trainer.py", line 371, in fit

  File "/usr/local/lib/python3.11/dist-packages/keras/src/backend/tensorflow/trainer.py", line 219, in function

  File "/usr/local/lib/python3.11/dist-packages/keras/src/backend/tensorflow/trainer.py", line 132, in multi_step_on_iterator

could not find registered transfer manager for platform Host -- check target linkage
	 [[{{node StatefulPartitionedCall}}]] [Op:__inference_multi_step_on_iterator_189768]

In [None]:
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)
print(f'Test accuracy: {test_acc}')