In [1]:
import numpy as np
import pandas as pd
import keras
import tensorflow as tf
from tensorflow.keras.datasets import cifar10
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.applications import VGG19
from sklearn.model_selection import train_test_split
import time
from transformers import TFAutoModelForImageClassification
from tensorflow.keras.layers import Input, Permute
from tensorflow.keras.models import Model
import tensorflow as tf
from transformers import TFAutoModelForImageClassification
from tensorflow.keras.layers import Input, Dense, ReLU, Dropout
from tensorflow.keras.models import Model

In [2]:
(X_train, y_train), (X_test, y_test) = cifar10.load_data()


Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


In [3]:
NUM_CLASSES = len(np.unique(y_train))
IMAGE_SIZE = (224, 224)
BATCH_SIZE = 64

In [4]:
y_train = to_categorical(y_train, NUM_CLASSES)
y_test = to_categorical(y_test, NUM_CLASSES)

In [5]:
X_train = X_train.astype('float32') / 255.0
X_test = X_test.astype('float32') / 255.0

In [6]:
X_train,X_val,y_train,y_val=train_test_split(X_train,y_train,test_size=0.2)

In [7]:
data_gen = ImageDataGenerator(horizontal_flip=False)
train_iter = data_gen.flow(
    X_train,
    y_train,
    batch_size=BATCH_SIZE,
    shuffle=True
)
val_iter = data_gen.flow(
    X_val,
    y_val,
    batch_size=BATCH_SIZE,
    shuffle=True
)
test_iter = data_gen.flow(
    X_test,
    y_test,
    batch_size=BATCH_SIZE,
    shuffle=True
)

In [8]:
def resize_generator(data_iter):
    for images, labels in data_iter:
        resized_images = tf.image.resize(images, (224, 224), method=tf.image.ResizeMethod.BILINEAR)
        yield resized_images, labels


gen_train = resize_generator(train_iter)
gen_val = resize_generator(val_iter)
gen_test = resize_generator(test_iter)


In [16]:
model1 = TFAutoModelForImageClassification.from_pretrained('facebook/deit-base-distilled-patch16-224')


for layer in model1.deit.encoder.layer[:11]:
    layer.trainable = False

for layer in model1.deit.encoder.layer[11:]:
    layer.trainable = True

All model checkpoint layers were used when initializing TFDeiTForImageClassificationWithTeacher.

All the layers of TFDeiTForImageClassificationWithTeacher were initialized from the model checkpoint at facebook/deit-base-distilled-patch16-224.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFDeiTForImageClassificationWithTeacher for predictions without further training.


In [10]:
class DeiTLayer(tf.keras.layers.Layer):
    def __init__(self, deit_model):
        super(DeiTLayer, self).__init__()
        self.deit_model = deit_model

    def call(self, inputs):
        return self.deit_model(inputs).logits


In [17]:
inputs = keras.layers.Input(shape=(224, 224, 3))
x = tf.transpose(inputs, perm=[0, 3, 1, 2])
x = DeiTLayer(model1)(x)
x = keras.layers.Dense(256,activation='elu')(x)
x = keras.layers.Dropout(0.5)(x)
x = keras.layers.Dense(NUM_CLASSES, activation='softmax')(x)
model = keras.models.Model(inputs, x)

In [12]:
model.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 224, 224, 3)]     0         
                                                                 
 tf.compat.v1.transpose (TF  (None, 3, 224, 224)       0         
 OpLambda)                                                       
                                                                 
 dei_t_layer (DeiTLayer)     (None, 1000)              87338192  
                                                                 
 dense (Dense)               (None, 256)               256256    
                                                                 
 dropout (Dropout)           (None, 256)               0         
                                                                 
 dense_1 (Dense)             (None, 10)                2570      
                                                             

In [18]:
model.compile(optimizer=keras.optimizers.Adam(learning_rate=0.0001),
                loss='categorical_crossentropy',
                metrics=['accuracy'])

In [19]:
training_times = []
validation_times = []


class TimeHistory(tf.keras.callbacks.Callback):
    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_start_time = time.time()

    def on_epoch_end(self, epoch, logs=None):
        epoch_time = time.time() - self.epoch_start_time
        training_times.append(epoch_time)
        print(f"Epoch {epoch+1} training time: {epoch_time:.2f} seconds")

    def on_test_begin(self, logs=None):
        self.test_start_time = time.time()

    def on_test_end(self, logs=None):
        test_time = time.time() - self.test_start_time
        validation_times.append(test_time)
        print(f"\nValidation time: {test_time:.2f} seconds")
time_callback = TimeHistory()

In [20]:
history = model.fit(gen_train,steps_per_epoch=len(X_train) // BATCH_SIZE, validation_data = gen_val , validation_steps=len(X_val) // BATCH_SIZE ,epochs=3, callbacks=[time_callback], verbose=1)

Epoch 1/3
Validation time: 142.63 seconds
Epoch 1 training time: 1375.07 seconds
Epoch 2/3
Validation time: 134.34 seconds
Epoch 2 training time: 1333.98 seconds
Epoch 3/3
Validation time: 134.03 seconds
Epoch 3 training time: 1333.97 seconds


In [24]:
avg_training_time = sum(training_times) / len(training_times)
avg_validation_time = sum(validation_times) / len(validation_times)
print(f'Average training time per epoch: {avg_training_time:.2f} seconds')
print(f'Average validation time per epoch: {avg_validation_time:.2f} seconds')

Average training time per epoch: 1347.67 seconds
Average validation time per epoch: 137.00 seconds


In [21]:
import plotly.express as px

# metric: 'accuracy' or 'loss'
def display_curves(history, metric):
  df = pd.DataFrame(history.history[metric], columns=[metric])
  df['val_'+metric] = history.history['val_'+metric]
  fig = px.line(df, x= df.index+1, y= [metric, 'val_'+metric])
  fig.update_layout(xaxis_title='Epochs', yaxis_title=metric)
  fig.show()

In [22]:
display_curves(history, 'loss')

In [23]:
display_curves(history, 'accuracy')