In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
import numpy as np
import matplotlib.pyplot as plt

# Constants
BATCH_SIZE = 128
EPOCHS = 200
NOISE_DIM = 100
NUM_EXAMPLES_TO_GENERATE = 16

# Load and preprocess MNIST data (labels and the test split are not needed)
(train_images, _), (_, _) = tf.keras.datasets.mnist.load_data()
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5  # Normalize to [-1, 1]

# Create tf.data.Dataset.
# The shuffle buffer covers the whole array so every epoch is a full
# permutation; it is derived from the data instead of hard-coding 60000.
SHUFFLE_BUFFER_SIZE = train_images.shape[0]
train_dataset = (
    tf.data.Dataset.from_tensor_slices(train_images)
    .shuffle(SHUFFLE_BUFFER_SIZE)
    .batch(BATCH_SIZE)
)

# Load the trained teacher generator
teacher_generator = tf.keras.models.load_model('mnist_generator.h5')

In [16]:
# Inspect the teacher architecture. The layer order printed here matters:
# the teacher feature extraction further down addresses layers by index.
teacher_generator.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense (Dense)               (None, 12544)             1254400   
                                                                 
 batch_normalization (BatchN  (None, 12544)            50176     
 ormalization)                                                   
                                                                 
 leaky_re_lu (LeakyReLU)     (None, 12544)             0         
                                                                 
 reshape (Reshape)           (None, 7, 7, 256)         0         
                                                                 
 conv2d_transpose (Conv2DTra  (None, 7, 7, 128)        524288    
 nspose)                                                         
                                                                 
 batch_normalization_1 (Batc  (None, 7, 7, 128)        5

In [17]:

class StudentGenerator(tf.keras.Model):
    """Compact generator mapping a noise vector to a 28x28x1 image.

    Pipeline: Dense -> BatchNorm -> LeakyReLU -> Reshape(7, 7, 128)
    -> Conv2DTranspose(64, stride 2) -> BatchNorm -> LeakyReLU
    -> Conv2DTranspose(1, stride 2, tanh).
    """

    def __init__(self):
        super().__init__()
        # Settings shared by both upsampling stages (kept as a local so it is
        # not tracked as a model attribute).
        upsample_cfg = dict(strides=(2, 2), padding='same', use_bias=False)

        # Stage 1: project the noise vector and fold it into a 7x7x128 map.
        self.dense = layers.Dense(7*7*128, use_bias=False, input_shape=(NOISE_DIM,))
        self.batch_norm1 = layers.BatchNormalization()
        self.leaky_relu1 = layers.LeakyReLU(alpha=0.2)
        self.reshape = layers.Reshape((7, 7, 128))

        # Stage 2: first stride-2 transposed convolution, 64 filters.
        self.conv_transpose1 = layers.Conv2DTranspose(64, (4, 4), **upsample_cfg)
        self.batch_norm2 = layers.BatchNormalization()
        self.leaky_relu2 = layers.LeakyReLU(alpha=0.2)

        # Stage 3: second stride-2 transposed convolution, 1 filter, tanh output.
        self.conv_transpose2 = layers.Conv2DTranspose(1, (4, 4), activation='tanh', **upsample_cfg)

    def call(self, x, training=False):
        """Return the generated image batch, i.e. the last intermediate activation."""
        return self.get_intermediate_layers(x, training=training)[-1]

    def get_intermediate_layers(self, x, training=False):
        """Run the forward pass and return the three stage outputs as a list.

        The list holds, in order: the reshaped activation of stage 1, the
        activation after the second LeakyReLU, and the final tanh output.
        `training` is forwarded to both BatchNormalization layers.
        """
        stage1 = self.reshape(
            self.leaky_relu1(self.batch_norm1(self.dense(x), training=training))
        )
        stage2 = self.leaky_relu2(
            self.batch_norm2(self.conv_transpose1(stage1), training=training)
        )
        stage3 = self.conv_transpose2(stage2)
        return [stage1, stage2, stage3]

def make_discriminator_model():
    """Build the convolutional discriminator for 28x28x1 images.

    Two stride-2 Conv2D stages (64 then 128 filters), each followed by
    LeakyReLU(0.2) and Dropout(0.3), then Flatten and a single-unit Dense
    layer with no activation (the output is a raw logit).
    """
    discriminator_net = models.Sequential()

    # First stage carries the input shape declaration.
    discriminator_net.add(
        layers.Conv2D(64, (4, 4), strides=(2, 2), padding='same', input_shape=[28, 28, 1])
    )
    discriminator_net.add(layers.LeakyReLU(alpha=0.2))
    discriminator_net.add(layers.Dropout(0.3))

    # Second stage.
    discriminator_net.add(layers.Conv2D(128, (4, 4), strides=(2, 2), padding='same'))
    discriminator_net.add(layers.LeakyReLU(alpha=0.2))
    discriminator_net.add(layers.Dropout(0.3))

    # Classification head.
    discriminator_net.add(layers.Flatten())
    discriminator_net.add(layers.Dense(1))
    return discriminator_net

# Create student generator and discriminator.
# Both are module-level objects; the training step and the evaluation code
# below refer to them by these names.
student_generator = StudentGenerator()
discriminator = make_discriminator_model()

In [21]:


# Define loss functions
# from_logits=True: the discriminator ends in Dense(1) with no activation,
# so its outputs are raw logits rather than probabilities.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
# Mean squared error, used by distillation_loss to compare teacher and
# student activations.
mse = tf.keras.losses.MeanSquaredError()

def discriminator_loss(real_output, fake_output):
    """Binary cross-entropy loss for the discriminator.

    Real logits are scored against all-ones targets and fake logits against
    all-zeros targets; the two terms are summed.
    """
    scored_pairs = (
        (tf.ones_like(real_output), real_output),
        (tf.zeros_like(fake_output), fake_output),
    )
    real_loss, fake_loss = [cross_entropy(target, logits) for target, logits in scored_pairs]
    return real_loss + fake_loss

def generator_loss(fake_output):
    """Adversarial generator loss: cross-entropy of the fake logits against all-ones targets."""
    desired_labels = tf.ones_like(fake_output)
    return cross_entropy(desired_labels, fake_output)

def get_teacher_intermediate_layers(teacher, x, capture_indices=(6, 9, 10)):
    """Run `x` through the teacher layer by layer and collect selected activations.

    Args:
        teacher: Object exposing an ordered `layers` sequence of callables
            (here the loaded Sequential teacher generator).
        x: Input passed to the first layer.
        capture_indices: Non-negative indices into `teacher.layers` whose
            outputs are collected. The default (6, 9, 10) reproduces the
            three tap points this notebook distils from.

    Returns:
        List of captured activations, ordered by layer position. Layers
        after the largest requested index are not executed.

    Raises:
        IndexError: If the teacher has fewer layers than the largest
            requested index requires.
    """
    if not capture_indices:
        return []

    wanted = set(capture_indices)
    last_index = max(wanted)
    if last_index >= len(teacher.layers):
        raise IndexError(
            f"teacher has {len(teacher.layers)} layers, "
            f"cannot capture output of layer index {last_index}"
        )

    features = []
    for index, layer in enumerate(teacher.layers[:last_index + 1]):
        # The teacher is frozen: always run in inference mode so that
        # BatchNormalization uses (and does not update) its moving statistics.
        x = layer(x, training=False)
        if index in wanted:
            features.append(x)
    return features

# 1x1 projection layers used when a student activation has a different channel
# count than the matching teacher activation. Keyed by
# (pair index, student channels, teacher channels) so each layer is created
# exactly once and reused on every later call.
_distillation_projections = {}


def distillation_loss(teacher_layers, student_layers):
    """Sum of MSE terms between paired teacher and student activations.

    Pairs are formed with `zip`, so extra entries in the longer list are
    ignored. When the channel (last) dimensions of a pair differ, the student
    activation is mapped to the teacher's channel count with a 1x1 Conv2D.

    The projection is cached rather than rebuilt per call: a fresh layer on
    every call would give a different random mapping each step and would
    create variables on later traces of a `tf.function`, which TensorFlow
    rejects. NOTE: the cached projection is not part of the student model, so
    it stays at its initial weights unless its variables are also handed to
    an optimizer.

    Args:
        teacher_layers: Sequence of teacher activations.
        student_layers: Sequence of student activations, in matching order.

    Returns:
        The summed MSE over all pairs (0.0 when there are no pairs).
    """
    loss = 0.0
    for pair_index, (t_layer, s_layer) in enumerate(zip(teacher_layers, student_layers)):
        t_channels = t_layer.shape[-1]
        s_channels = s_layer.shape[-1]
        if t_channels != s_channels:
            cache_key = (pair_index, s_channels, t_channels)
            projection = _distillation_projections.get(cache_key)
            if projection is None:
                projection = layers.Conv2D(t_channels, (1, 1), padding='same')
                _distillation_projections[cache_key] = projection
            s_layer = projection(s_layer)
        loss += mse(t_layer, s_layer)
    return loss

# Define optimizers
# One Adam instance per network, each constructed with 1e-4, so the student
# generator and the discriminator keep independent optimizer state.
student_generator_optimizer = optimizers.Adam(1e-4)
discriminator_optimizer = optimizers.Adam(1e-4)


In [None]:


# Define training step
@tf.function
def train_step(images):
    """Run one combined adversarial + distillation update.

    Args:
        images: Batch of real images fed to the discriminator.

    Returns:
        Tuple `(gen_loss, disc_loss, distill_loss)` for this batch. The
        student is optimised on `0.5 * gen_loss + 0.5 * distill_loss`; the
        discriminator on `disc_loss`.
    """
    # Draw as many noise vectors as there are real images, so the final
    # (smaller) batch of an epoch stays balanced between real and fake.
    batch_size = tf.shape(images)[0]
    noise = tf.random.normal([batch_size, NOISE_DIM])

    # The teacher is frozen; compute its features outside the tapes so its
    # operations are not recorded for differentiation.
    teacher_layers = get_teacher_intermediate_layers(teacher_generator, noise)

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        # Single student forward pass: the last intermediate output is the
        # generated image. Running the student twice would double the work
        # and update its BatchNormalization moving statistics twice per step.
        student_layers = student_generator.get_intermediate_layers(noise, training=True)
        student_generated_images = student_layers[-1]

        # Discriminator outputs
        real_output = discriminator(images, training=True)
        fake_output = discriminator(student_generated_images, training=True)

        # Calculate losses
        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)
        distill_loss = distillation_loss(teacher_layers, student_layers)

        # Combine generator and distillation loss
        total_gen_loss = 0.5 * gen_loss + 0.5 * distill_loss

    # Calculate gradients
    gradients_of_generator = gen_tape.gradient(total_gen_loss, student_generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    # Apply gradients
    student_generator_optimizer.apply_gradients(zip(gradients_of_generator, student_generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

    return gen_loss, disc_loss, distill_loss

# Training loop
def train(dataset, epochs):
    """Train for `epochs` passes over `dataset`.

    Every 10th epoch the losses of that epoch's final batch are printed;
    every 20th epoch sample grids are saved for both the student and the
    teacher generator.

    Args:
        dataset: Iterable yielding image batches accepted by `train_step`.
        epochs: Number of passes over `dataset`.
    """
    for epoch in range(epochs):
        # Losses of the most recent batch; stays None if the dataset yields
        # nothing, so the report below is skipped instead of failing on
        # unbound names.
        last_losses = None
        for image_batch in dataset:
            last_losses = train_step(image_batch)

        # Print losses every 10 epochs
        if (epoch + 1) % 10 == 0 and last_losses is not None:
            gen_loss, disc_loss, distill_loss = (float(value) for value in last_losses)
            print(f"Epoch {epoch+1}, Gen Loss: {gen_loss:.4f}, Disc Loss: {disc_loss:.4f}, Distill Loss: {distill_loss:.4f}")

        # Generate and save images every 20 epochs
        if (epoch + 1) % 20 == 0:
            generate_and_save_images(student_generator, epoch + 1, "student")
            generate_and_save_images(teacher_generator, epoch + 1, "teacher")

# Function to generate and save images
def generate_and_save_images(model, epoch, prefix):
    """Sample images from `model` and save them as one PNG grid.

    Args:
        model: Generator called as `model(noise, training=False)`.
        epoch: Epoch number, zero-padded to four digits in the file name.
        prefix: File-name prefix, e.g. "student" or "teacher".

    The file is written to `{prefix}_image_at_epoch_{epoch:04d}.png` in the
    current working directory.
    """
    test_input = tf.random.normal([NUM_EXAMPLES_TO_GENERATE, NOISE_DIM])
    predictions = model(test_input, training=False)

    # Smallest square grid that fits every sample (4x4 for the default 16),
    # so changing NUM_EXAMPLES_TO_GENERATE cannot overflow a fixed layout.
    num_images = predictions.shape[0]
    grid_side = max(1, int(np.ceil(np.sqrt(num_images))))

    fig = plt.figure(figsize=(grid_side, grid_side))
    for i in range(num_images):
        ax = fig.add_subplot(grid_side, grid_side, i + 1)
        # Undo the [-1, 1] normalisation to get pixel intensities back.
        ax.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
        ax.axis('off')
    fig.savefig(f'{prefix}_image_at_epoch_{epoch:04d}.png')
    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)

# Train the model
train(train_dataset, EPOCHS)

# Save the trained student generator.
# StudentGenerator is a subclassed tf.keras.Model, and Keras can only write
# Functional/Sequential models to HDF5: saving to a '.h5' path raises
# NotImplementedError. Export a SavedModel directory instead.
student_generator.save('mnist_student_generator', save_format='tf')

print("Knowledge distillation completed and student generator model saved.")

# Evaluate the student model
def evaluate_models(teacher, student, num_samples=1000):
    """Print the mean squared error between teacher and student outputs.

    Both generators are run in inference mode on the same `num_samples`
    freshly drawn noise vectors.
    """
    latent = tf.random.normal([num_samples, NOISE_DIM])
    teacher_images, student_images = [
        generator(latent, training=False) for generator in (teacher, student)
    ]
    squared_error = tf.keras.losses.MSE(teacher_images, student_images)
    mse_loss = tf.reduce_mean(squared_error)
    print(f"Mean Squared Error between teacher and student generated images: {mse_loss:.4f}")

# Compare the student against the teacher on a fresh batch of noise
# (uses the function's default sample count).
evaluate_models(teacher_generator, student_generator)

In [23]:
# Save only the weights of the student generator.
# This writes weights, not the architecture: presumably a StudentGenerator
# must be re-instantiated before these weights can be loaded back
# (TODO confirm the reload path).
student_generator.save_weights('mnist_student_generator_weights.h5')
