In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense, Dropout, LayerNormalization, Add, Layer,BatchNormalization
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import EarlyStopping, ReduceLROnPlateau, TensorBoard

# Define the CNN part of the model
def create_cnn(input_shape, num_classes=8):
    """Build the convolutional branch as a Keras functional ``Model``.

    The network is a stem convolution followed by four
    Conv -> BatchNorm -> MaxPool -> Dropout stages, then a dense head that
    ends in a ``num_classes``-way softmax.

    Args:
        input_shape: Shape passed to ``Input`` (batch dimension excluded).
        num_classes: Width of the final softmax layer.

    Returns:
        A ``Model`` mapping the input tensor to the softmax output.
    """
    # Settings shared by every convolution in this branch.
    conv_kwargs = dict(kernel_size=(3, 3), activation='relu', padding='same')

    image = Input(shape=input_shape)

    # Stem: one extra convolution ahead of the first pooled stage.
    features = Conv2D(filters=256, **conv_kwargs)(image)

    # Four identical stages that differ only in their filter count.
    for stage_filters in (512, 384, 192, 384):
        features = Conv2D(filters=stage_filters, **conv_kwargs)(features)
        features = BatchNormalization()(features)
        features = MaxPooling2D(pool_size=(2, 2))(features)
        features = Dropout(0.4)(features)

    # Dense head on top of the flattened feature map.
    features = Flatten()(features)
    features = Dense(256, activation='relu')(features)
    features = BatchNormalization()(features)
    features = Dropout(0.3)(features)

    probabilities = Dense(num_classes, activation='softmax')(features)

    return Model(image, probabilities)

# Custom layer for patch extraction and embedding
class PatchEmbedding(Layer):
    """Cut an image into square patches, project them and add position embeddings.

    Args:
        num_patches: Despite the name, this is the side length (in pixels) of
            each square patch: it is used as both the window size and the
            stride passed to ``tf.image.extract_patches``. The name is kept
            so existing callers keep working.
        projection_dim: Output width of the ``Dense`` projection applied to
            every flattened patch.
        **kwargs: Standard ``Layer`` keyword arguments (``name``, ``dtype``,
            ...), forwarded to the base class. They are also what
            ``get_config`` emits, so ``from_config`` can rebuild the layer.

    Call input: a 4-D image batch (as required by ``tf.image.extract_patches``).
    Call output: tensor of shape ``(batch, patch_count, projection_dim)``.
    """

    def __init__(self, num_patches, projection_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches
        self.projection_dim = projection_dim
        self.projection = Dense(projection_dim)

    def build(self, input_shape):
        """Create the position-embedding table sized for ``input_shape``."""
        patch_size = self.num_patches
        height, width, channels = input_shape[1], input_shape[2], input_shape[3]
        if height is None or width is None:
            raise ValueError(
                "PatchEmbedding needs a static image height and width to size "
                f"its position embeddings, got input_shape={input_shape}."
            )

        # Non-overlapping 'VALID' patches: count rows and columns separately
        # so that non-square images get the right number of positions.
        patch_count = (height // patch_size) * (width // patch_size)
        self.position_embedding = self.add_weight(
            name="position_embedding",
            shape=[patch_count, self.projection_dim],
            initializer=tf.keras.initializers.RandomNormal(),
        )

        # Build the projection here as well, so all of the layer's weights
        # exist once build() has run. Each flattened patch holds
        # patch_size * patch_size * channels values.
        if channels is not None and not self.projection.built:
            self.projection.build((None, None, patch_size * patch_size * channels))

    def call(self, images):
        """Return projected patches with position embeddings added."""
        patch_size = self.num_patches
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, patch_size, patch_size, 1],
            strides=[1, patch_size, patch_size, 1],
            rates=[1, 1, 1, 1],
            padding='VALID'
        )
        # Collapse the patch grid into a sequence: (batch, patch_count, patch_dim).
        patch_dim = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dim])
        # The embedding table broadcasts over the batch dimension.
        return self.projection(patches) + self.position_embedding

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update({
            "num_patches": self.num_patches,
            "projection_dim": self.projection_dim,
        })
        return config

# Define the Vision Transformer part of the model
def create_vit(input_shape, num_patches, projection_dim, transformer_layers,
               num_heads=8, mlp_dim=2048, attention_dropout=0.1):
    """Build the Vision-Transformer branch as a Keras functional ``Model``.

    The image is turned into a patch sequence by ``PatchEmbedding`` and then
    passed through ``transformer_layers`` pre-norm encoder blocks
    (LayerNorm -> self-attention -> residual, LayerNorm -> MLP -> residual).
    The final token sequence is flattened into one feature vector per image.

    Args:
        input_shape: Shape passed to ``Input`` (batch dimension excluded).
        num_patches: Forwarded to ``PatchEmbedding``, where it is used as the
            patch side length.
        projection_dim: Token embedding width. It is also passed as
            ``key_dim`` to ``MultiHeadAttention``.
        transformer_layers: Number of encoder blocks to stack.
        num_heads: Attention heads per block (default 8, the previously
            hard-coded value).
        mlp_dim: Hidden width of the feed-forward sub-block (default 2048,
            the previously hard-coded value).
        attention_dropout: Dropout rate passed to ``MultiHeadAttention``
            (default 0.1, the previously hard-coded value).

    Returns:
        A ``Model`` mapping the input tensor to the flattened token features.
    """
    inputs = Input(shape=input_shape)
    tokens = PatchEmbedding(num_patches, projection_dim)(inputs)

    for _ in range(transformer_layers):
        # Self-attention sub-block with a residual connection.
        normed_tokens = LayerNormalization(epsilon=1e-6)(tokens)
        attention_output = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=attention_dropout
        )(normed_tokens, normed_tokens)
        attended = Add()([attention_output, tokens])

        # Feed-forward sub-block with a residual connection; the second Dense
        # maps back to projection_dim so the residual shapes match.
        normed_attended = LayerNormalization(epsilon=1e-6)(attended)
        feedforward_output = tf.keras.Sequential([
            Dense(mlp_dim, activation=tf.nn.gelu),
            Dense(projection_dim)
        ])(normed_attended)
        tokens = Add()([feedforward_output, attended])

    flat_features = Flatten()(tokens)
    return Model(inputs, flat_features)

input_shape = (48, 48, 1)
# Single source of truth for the label count, used by the CNN head and the
# fused classifier below (previously the literal 8 was repeated).
num_classes = 8

cnn_model = create_cnn(input_shape, num_classes=num_classes)
vit_model = create_vit(input_shape, num_patches=4, projection_dim=64, transformer_layers=4)

# Feed the same image tensor to both branches.
combined_input = Input(shape=input_shape)
cnn_output = cnn_model(combined_input)
vit_output = vit_model(combined_input)

# No extra flattening is needed: the CNN branch ends in a Dense layer and the
# ViT branch ends in Flatten, so both are already one vector per image.
# The aliases are kept so code that refers to these names keeps working.
# NOTE(review): the CNN branch contributes its softmax probabilities, not its
# penultimate features, to the fusion - confirm that this is intended.
cnn_output_flattened = cnn_output
vit_output_flattened = vit_output

combined_output = tf.keras.layers.concatenate([cnn_output_flattened, vit_output_flattened])
final_output = Dense(num_classes, activation='softmax')(combined_output)

model = Model(combined_input, final_output)

# Define the paths to the dataset
train_dir = '/kaggle/input/google-fer-image-format/train'
val_dir = '/kaggle/input/google-fer-image-format/val'

# Loader settings shared by both splits. The image size is derived from the
# model's input shape so the two cannot drift apart.
image_size = tuple(input_shape[:2])
batch_size = 64

# Data Augmentation (training split only)
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=10,
    zoom_range=0.1,
    horizontal_flip=True
)

# Validation images are only rescaled, never augmented.
val_datagen = ImageDataGenerator(rescale=1./255)

# Load the training and validation data
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=image_size,
    color_mode='grayscale',
    batch_size=batch_size,
    class_mode='sparse'
)

val_generator = val_datagen.flow_from_directory(
    val_dir,
    target_size=image_size,
    color_mode='grayscale',
    batch_size=batch_size,
    class_mode='sparse',
    # Keep a fixed order so predictions line up with val_generator.classes;
    # loss and accuracy do not depend on the order.
    shuffle=False
)

lr = 1e-3
optimizer = tf.keras.optimizers.Adam(learning_rate=lr)

# Labels come from the generators as integer class indices
# (class_mode='sparse'), hence the sparse cross-entropy loss.
model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Divide the learning rate by 10 after 5 epochs without val_loss improvement.
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5, min_lr=0.000001)
# Stop after 10 epochs without val_loss improvement. restore_best_weights
# rolls the model back to the best epoch, so the evaluation and the saved
# file below do not use weights from up to 10 epochs past the optimum.
early_stopping = EarlyStopping(
    monitor='val_loss',
    min_delta=0,
    patience=10,
    verbose=1,
    restore_best_weights=True
)
# Define the log directory for TensorBoard
log_dir = "./log"
tensorboard_callback = TensorBoard(log_dir=log_dir, histogram_freq=1)

# Train the model; keep the History object so curves can be inspected later.
history = model.fit(train_generator,
                    validation_data=val_generator,
                    epochs=100,
                    callbacks=[tensorboard_callback, early_stopping, reduce_lr])

# Evaluate the model
# NOTE: this is measured on the validation split (val_generator); no separate
# test set is loaded in this cell, so the "Test" label below is nominal.
test_loss, test_accuracy = model.evaluate(val_generator, verbose=1)
print(f'Test accuracy: {test_accuracy}')
model.save("./cnn-vit_model.h5")

2024-06-18 19:20:03.966978: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-06-18 19:20:03.967108: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-06-18 19:20:04.094332: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Found 32645 images belonging to 8 classes.
Found 8166 images belonging to 8 classes.
Epoch 1/100


  self._warn_if_super_not_called()
I0000 00:00:1718738513.984999     122 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
W0000 00:00:1718738514.028718     122 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update
W0000 00:00:1718738514.030429     122 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update
W0000 00:00:1718738514.033090     122 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update


[1m316/511[0m [32m━━━━━━━━━━━━[0m[37m━━━━━━━━[0m [1m1:15[0m 390ms/step - accuracy: 0.2286 - loss: 3.0524

W0000 00:00:1718738637.289717     122 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update
W0000 00:00:1718738637.291133     122 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update


[1m511/511[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 352ms/step - accuracy: 0.2517 - loss: 2.6809

W0000 00:00:1718738697.280861     123 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update
W0000 00:00:1718738743.128714     121 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update


[1m511/511[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m312s[0m 451ms/step - accuracy: 0.2518 - loss: 2.6796 - val_accuracy: 0.3734 - val_loss: 1.6291 - learning_rate: 0.0010
Epoch 2/100
[1m511/511[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m122s[0m 237ms/step - accuracy: 0.3500 - loss: 1.6713 - val_accuracy: 0.3691 - val_loss: 1.5909 - learning_rate: 0.0010
Epoch 3/100
[1m511/511[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m125s[0m 243ms/step - accuracy: 0.3648 - loss: 1.6255 - val_accuracy: 0.3973 - val_loss: 1.5424 - learning_rate: 0.0010
Epoch 4/100
[1m511/511[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m125s[0m 243ms/step - accuracy: 0.3860 - loss: 1.5863 - val_accuracy: 0.4198 - val_loss: 1.4993 - learning_rate: 0.0010
Epoch 5/100
[1m511/511[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m125s[0m 243ms/step - accuracy: 0.4038 - loss: 1.5430 - val_accuracy: 0.4259 - val_loss: 1.4905 - learning_rate: 0.0010
Epoch 6/100
[1m511/511[0m [32m━━━━━━━━━━━━━━━━━━━━[0m