In [1]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Input, Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import LearningRateScheduler, EarlyStopping
from sklearn.model_selection import train_test_split

2024-06-03 14:05:02.685017: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-06-03 14:05:02.685130: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-06-03 14:05:02.808595: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)

val_datagen = ImageDataGenerator(rescale=1./255)
test_datagen = ImageDataGenerator(rescale=1./255)

train_dir = '/kaggle/input/tomato-village/Variant-a(Multiclass Classification)/train'
val_dir = '/kaggle/input/tomato-village/Variant-a(Multiclass Classification)/val'
test_dir = '/kaggle/input/tomato-village/Variant-a(Multiclass Classification)/test'

train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical'
)

val_generator = val_datagen.flow_from_directory(
    val_dir,
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical'
)

test_generator = test_datagen.flow_from_directory(
    test_dir,
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical'
)


Found 3162 images belonging to 8 classes.
Found 902 images belonging to 8 classes.
Found 461 images belonging to 8 classes.


In [3]:
def cbam_block(cbam_feature, ratio=8):
    cbam_feature = channel_attention(cbam_feature, ratio)
    cbam_feature = spatial_attention(cbam_feature)
    return cbam_feature

def channel_attention(input_feature, ratio=8):
    channel_axis = 1 if tf.keras.backend.image_data_format() == "channels_first" else -1
    channel = input_feature.shape[channel_axis]

    shared_layer_one = tf.keras.layers.Dense(channel // ratio,
                                             activation='relu',
                                             kernel_initializer='he_normal',
                                             use_bias=True,
                                             bias_initializer='zeros')
    shared_layer_two = tf.keras.layers.Dense(channel,
                                             kernel_initializer='he_normal',
                                             use_bias=True,
                                             bias_initializer='zeros')

    avg_pool = tf.keras.layers.GlobalAveragePooling2D()(input_feature)    
    avg_pool = tf.keras.layers.Reshape((1, 1, channel))(avg_pool)
    avg_pool = shared_layer_one(avg_pool)
    avg_pool = shared_layer_two(avg_pool)

    max_pool = tf.keras.layers.GlobalMaxPooling2D()(input_feature)
    max_pool = tf.keras.layers.Reshape((1, 1, channel))(max_pool)
    max_pool = shared_layer_one(max_pool)
    max_pool = shared_layer_two(max_pool)

    cbam_feature = tf.keras.layers.Add()([avg_pool, max_pool])
    cbam_feature = tf.keras.layers.Activation('sigmoid')(cbam_feature)

    if tf.keras.backend.image_data_format() == "channels_first":
        cbam_feature = tf.keras.layers.Permute((3, 1, 2))(cbam_feature)

    return tf.keras.layers.multiply([input_feature, cbam_feature])

def spatial_attention(input_feature):
    kernel_size = 7
    if tf.keras.backend.image_data_format() == "channels_first":
        channel = input_feature.shape[1]
        cbam_feature = tf.keras.layers.Permute((2, 3, 1))(input_feature)
    else:
        channel = input_feature.shape[-1]
        cbam_feature = input_feature

    avg_pool = tf.keras.layers.Lambda(lambda x: tf.keras.backend.mean(x, axis=3, keepdims=True))(cbam_feature)
    max_pool = tf.keras.layers.Lambda(lambda x: tf.keras.backend.max(x, axis=3, keepdims=True))(cbam_feature)
    concat = tf.keras.layers.Concatenate(axis=3)([avg_pool, max_pool])
    cbam_feature = tf.keras.layers.Conv2D(filters=1,
                                          kernel_size=kernel_size,
                                          strides=1,
                                          padding='same',
                                          activation='sigmoid',
                                          kernel_initializer='he_normal',
                                          use_bias=False)(concat)

    if tf.keras.backend.image_data_format() == "channels_first":
        cbam_feature = tf.keras.layers.Permute((3, 1, 2))(cbam_feature)

    return tf.keras.layers.multiply([input_feature, cbam_feature])

input_shape = (224, 224, 3)
base_model = ResNet50(weights='imagenet', include_top=False, input_tensor=Input(shape=input_shape))

x = base_model.output
x = cbam_block(x)
x = GlobalAveragePooling2D()(x)
x = Dense(256, activation='relu')(x)
predictions = Dense(train_generator.num_classes, activation='softmax')(x)

model = Model(inputs=base_model.input, outputs=predictions)

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5
[1m94765736/94765736[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step


In [4]:
model.compile(optimizer=Adam(learning_rate=0.0001),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

def scheduler(epoch, lr):
    if epoch < 10:
        return float(lr)
    else:
        return float(lr * tf.math.exp(-0.1))

lr_scheduler = LearningRateScheduler(scheduler)
early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)

history = model.fit(
    train_generator,
    epochs=50,
    validation_data=val_generator,
    callbacks=[lr_scheduler, early_stopping]
)


Epoch 1/50


  self._warn_if_super_not_called()
I0000 00:00:1717423594.944421      90 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
W0000 00:00:1717423595.064896      90 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update


[1m99/99[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 856ms/step - accuracy: 0.4567 - loss: 1.5236

W0000 00:00:1717423683.841443      89 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update


[1m99/99[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m175s[0m 1s/step - accuracy: 0.4581 - loss: 1.5200 - val_accuracy: 0.1996 - val_loss: 3.4554 - learning_rate: 1.0000e-04
Epoch 2/50
[1m99/99[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m40s[0m 381ms/step - accuracy: 0.7566 - loss: 0.6704 - val_accuracy: 0.1098 - val_loss: 2.5429 - learning_rate: 1.0000e-04
Epoch 3/50
[1m99/99[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m42s[0m 394ms/step - accuracy: 0.8268 - loss: 0.4767 - val_accuracy: 0.0920 - val_loss: 2.7902 - learning_rate: 1.0000e-04
Epoch 4/50
[1m99/99[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m46s[0m 433ms/step - accuracy: 0.8787 - loss: 0.3348 - val_accuracy: 0.2018 - val_loss: 3.2076 - learning_rate: 1.0000e-04
Epoch 5/50
[1m99/99[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m41s[0m 386ms/step - accuracy: 0.8934 - loss: 0.3052 - val_accuracy: 0.1175 - val_loss: 2.4277 - learning_rate: 1.0000e-04
Epoch 6/50
[1m99/99[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37

In [5]:
test_loss, test_acc = model.evaluate(test_generator)
print(f"Test accuracy: {test_acc:.2f}")


[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 509ms/step - accuracy: 0.9110 - loss: 0.4518
Test accuracy: 0.94
