In [2]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models

# Load the real MNIST dataset (Keras downloads it on first use) and prepare
# the images for Conv2D: append a trailing channel axis and scale pixel
# values by 1/255. Note: this is real data, not generated/dummy data.
def _prepare_images(images):
    """Return ``images`` as float32 with a trailing channel axis, scaled by 1/255.

    Accepts any (N, H, W) array; for the MNIST arrays loaded below this yields
    shape (N, 28, 28, 1), identical to an explicit reshape to that shape.
    """
    return images[..., np.newaxis].astype('float32') / 255


(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = _prepare_images(x_train)
x_test = _prepare_images(x_test)

class DiffusionModel(models.Model):
    """Small convolutional image classifier (subclassed Keras model).

    NOTE(review): despite its name this is not a diffusion model -- it is a
    plain CNN whose output is a softmax over class labels. The name is kept
    so existing references to ``DiffusionModel`` keep working.

    Layers, applied in order: 3x3 conv (32 filters, ReLU, 'same' padding),
    3x3 conv (64 filters, ReLU, 'same' padding), flatten, dense (128, ReLU),
    dense (``num_classes``, softmax).
    """

    def __init__(self, num_classes=10, **kwargs):
        """Build the layers.

        Args:
            num_classes: Number of output units in the final softmax layer.
                Defaults to 10, matching the original hard-coded width.
            **kwargs: Forwarded to ``keras.Model.__init__`` (e.g. ``name``),
                which subclassed Keras models are expected to accept.
        """
        super().__init__(**kwargs)
        self.conv1 = layers.Conv2D(32, (3, 3), activation='relu', padding='same')
        self.conv2 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')
        # NOTE(review): there is no pooling or striding, so with 'same'
        # padding the flattened vector has H * W * 64 entries (50,176 for the
        # 28x28 inputs this notebook feeds in) and dense1 carries ~6.4M
        # weights. Adding pooling would shrink that considerably, but it would
        # change the model, so it is left as-is here.
        self.flatten = layers.Flatten()
        self.dense1 = layers.Dense(128, activation='relu')
        self.dense2 = layers.Dense(num_classes, activation='softmax')

    def call(self, inputs):
        """Run the forward pass.

        Args:
            inputs: Image batch. NOTE(review): assumed channels-last
                (N, H, W, C) -- this notebook passes (N, 28, 28, 1).

        Returns:
            Softmax class probabilities with shape (N, num_classes).
        """
        x = self.conv1(inputs)
        x = self.conv2(x)
        x = self.flatten(x)
        x = self.dense1(x)
        return self.dense2(x)


# Training configuration: Adam with default hyper-parameters and integer-label
# cross-entropy. The loss is built with its defaults (from_logits=False per
# the Keras docs), which matches the softmax output layer of DiffusionModel.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam()
model = DiffusionModel()

model.compile(
    loss=loss_fn,
    optimizer=optimizer,
    metrics=['accuracy'],
)

# Fixed 10-epoch run; validation_split=0.1 reserves 10% of the training arrays
# for validation. NOTE(review): the captured log shows val_loss lowest at
# epoch 2 and rising afterwards -- an EarlyStopping callback may be worth it.
model.fit(
    x_train,
    y_train,
    batch_size=32,
    epochs=10,
    validation_split=0.1,
)

# Score the trained model on the held-out MNIST test split.
test_loss, test_acc = model.evaluate(x_test, y_test)
print("Test accuracy:", test_acc)


Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
[1m11490434/11490434[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 1us/step
Epoch 1/10
[1m1688/1688[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m114s[0m 66ms/step - accuracy: 0.9242 - loss: 0.2394 - val_accuracy: 0.9865 - val_loss: 0.0488
Epoch 2/10
[1m1688/1688[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m114s[0m 67ms/step - accuracy: 0.9879 - loss: 0.0379 - val_accuracy: 0.9887 - val_loss: 0.0412
Epoch 3/10
[1m1688/1688[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m111s[0m 66ms/step - accuracy: 0.9941 - loss: 0.0184 - val_accuracy: 0.9855 - val_loss: 0.0631
Epoch 4/10
[1m1688/1688[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m110s[0m 65ms/step - accuracy: 0.9964 - loss: 0.0111 - val_accuracy: 0.9877 - val_loss: 0.0517
Epoch 5/10
[1m1688/1688[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m116s[0m 69ms/step - accuracy: 0.9971 - loss: 0.0085 - val_accuracy: 0.9902 - val_l