In [None]:
import tensorflow as tf
from tensorflow.keras.layers import (Input, Conv2D, MaxPooling2D, AveragePooling2D,
                                   GlobalAveragePooling2D, BatchNormalization,
                                   Activation, Add, Concatenate, Dense, Dropout,
                                   Multiply, Reshape, Permute)
from tensorflow.keras.models import Model
from sklearn.model_selection import train_test_split
from tensorflow import keras
from tensorflow.keras import layers, models

In [None]:
# 1. Upload kaggle.json file
from google.colab import files
files.upload()  # Upload kaggle.json here

# 2. Make directory and move kaggle.json
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

# 3. Install kaggle package if not installed
!pip install kaggle

# 4. Download PlantVillage dataset
# Dataset URL on Kaggle: https://www.kaggle.com/datasets/emmarex/plantdisease

!kaggle datasets download -d abdallahalidev/plantvillage-dataset
!unzip -q plantvillage-dataset.zip -d /content/plantvillage

# 6. List extracted folders and files
!ls /content/plantvillage/plantvillage dataset



Saving kaggle.json to kaggle.json
Dataset URL: https://www.kaggle.com/datasets/abdallahalidev/plantvillage-dataset
License(s): CC-BY-NC-SA-4.0
Downloading plantvillage-dataset.zip to /content
 99% 2.02G/2.04G [00:17<00:00, 261MB/s]
100% 2.04G/2.04G [00:17<00:00, 123MB/s]
ls: cannot access '/content/plantvillage/plantvillage': No such file or directory
ls: cannot access 'dataset': No such file or directory


In [None]:
import os

# Show the entries found directly under the colour-image directory.
color_dir = '/content/plantvillage/plantvillage dataset/color'
print(os.listdir(color_dir))

['Strawberry___healthy', 'Corn_(maize)___Northern_Leaf_Blight', 'Tomato___Bacterial_spot', 'Corn_(maize)___Common_rust_', 'Tomato___Late_blight', 'Tomato___Early_blight', 'Soybean___healthy', 'Peach___Bacterial_spot', 'Strawberry___Leaf_scorch', 'Pepper,_bell___healthy', 'Apple___Black_rot', 'Tomato___Target_Spot', 'Tomato___Tomato_mosaic_virus', 'Cherry_(including_sour)___Powdery_mildew', 'Raspberry___healthy', 'Tomato___Tomato_Yellow_Leaf_Curl_Virus', 'Pepper,_bell___Bacterial_spot', 'Orange___Haunglongbing_(Citrus_greening)', 'Blueberry___healthy', 'Tomato___Septoria_leaf_spot', 'Grape___healthy', 'Peach___healthy', 'Apple___Cedar_apple_rust', 'Grape___Black_rot', 'Grape___Esca_(Black_Measles)', 'Tomato___healthy', 'Apple___healthy', 'Cherry_(including_sour)___healthy', 'Tomato___Leaf_Mold', 'Potato___Early_blight', 'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot', 'Tomato___Spider_mites Two-spotted_spider_mite', 'Squash___Powdery_mildew', 'Potato___healthy', 'Corn_(maize)___hea

In [None]:

import tensorflow as tf

# Both subsets must be built from the same directory with the same seed and
# split fraction, so the shared arguments are declared once and reused.
data_root = '/content/plantvillage/plantvillage dataset/color'
split_options = dict(
    validation_split=0.5,  # half of the files go to each subset
    seed=123,
    image_size=(64, 64),
    batch_size=32,
)

dataset_train = tf.keras.preprocessing.image_dataset_from_directory(
    data_root, subset="training", **split_options
)

dataset_val = tf.keras.preprocessing.image_dataset_from_directory(
    data_root, subset="validation", **split_options
)



Found 54305 files belonging to 38 classes.
Using 27153 files for training.
Found 54305 files belonging to 38 classes.
Using 27152 files for validation.


In [None]:
# Step 2: Normalize images
# Rescaling(1./255) multiplies every pixel value by 1/255
# (presumably mapping 0-255 inputs into [0, 1] -- confirm against the loader).
normalization_layer = tf.keras.layers.Rescaling(1./255)

# Run the per-batch rescaling in parallel and prefetch upcoming batches so the
# input pipeline overlaps with training; the produced elements are unchanged.
train_ds = dataset_train.map(
    lambda x, y: (normalization_layer(x), y),
    num_parallel_calls=tf.data.AUTOTUNE,
).prefetch(tf.data.AUTOTUNE)
val_ds = dataset_val.map(
    lambda x, y: (normalization_layer(x), y),
    num_parallel_calls=tf.data.AUTOTUNE,
).prefetch(tf.data.AUTOTUNE)

In [None]:


def simple_attention(x, reduction_ratio=8):
    """Channel-attention gate: rescale each channel of ``x`` by a learned weight.

    The feature map is global-average-pooled, passed through a two-layer
    bottleneck MLP (ReLU then sigmoid), and the resulting per-channel weights
    are multiplied back onto ``x``.

    Args:
        x: Feature map tensor whose last axis is the channel axis; the channel
            count must be statically known.
        reduction_ratio: Positive divisor applied to the channel count to get
            the width of the bottleneck Dense layer.

    Returns:
        Tensor produced by multiplying ``x`` with the ``(1, 1, channels)``
        attention weights.

    Raises:
        ValueError: If the channel dimension of ``x`` is unknown or
            ``reduction_ratio`` is not positive.
    """
    # Only channel attention - lighter than CBAM
    channel = x.shape[-1]
    if channel is None:
        raise ValueError("simple_attention requires a statically known channel dimension")
    if reduction_ratio <= 0:
        raise ValueError(f"reduction_ratio must be positive, got {reduction_ratio}")

    # Floor division yields 0 when channel < reduction_ratio, which would
    # request a Dense layer with no units; keep at least one hidden unit.
    hidden_units = max(1, channel // reduction_ratio)

    # Global Average Pooling, reshaped so the weights broadcast over H and W
    gap = GlobalAveragePooling2D()(x)
    gap = layers.Reshape((1, 1, channel))(gap)

    # Simple MLP with reduction
    attention = layers.Dense(hidden_units, activation='relu')(gap)
    attention = layers.Dense(channel, activation='sigmoid')(attention)

    # Apply attention
    return Multiply()([x, attention])

In [None]:

def residual_block(x, filters, stride=1):
    """Two-convolution residual unit with a projection shortcut when required.

    Args:
        x: Input feature map tensor.
        filters: Number of output channels of both 3x3 convolutions.
        stride: Stride of the first convolution (and of the shortcut
            projection, when one is created).

    Returns:
        ReLU of the sum of the main path and the (possibly projected) shortcut.
    """
    identity = x

    # Main path: conv -> BN -> ReLU -> conv -> BN
    main = Conv2D(filters, 3, strides=stride, padding="same", use_bias=False)(x)
    main = BatchNormalization()(main)
    main = layers.ReLU()(main)
    main = Conv2D(filters, 3, padding="same", use_bias=False)(main)
    main = BatchNormalization()(main)

    # The shortcut can be added as-is only when neither the spatial size
    # (stride) nor the channel count changes; otherwise project with a 1x1 conv.
    shapes_match = stride == 1 and identity.shape[-1] == filters
    if not shapes_match:
        identity = Conv2D(filters, 1, strides=stride, use_bias=False)(identity)
        identity = BatchNormalization()(identity)

    summed = Add()([main, identity])
    return layers.ReLU()(summed)
def residual_block_group(x, filters, n_blocks, stride=1):
    """Chain residual blocks; only the first one uses the given stride.

    Args:
        x: Input feature map tensor.
        filters: Channel count passed to every block in the group.
        n_blocks: Total number of blocks; the first is always created, and
            ``n_blocks - 1`` stride-1 blocks follow it.
        stride: Stride of the first block.

    Returns:
        Output tensor of the last block in the group.
    """
    out = residual_block(x, filters, stride=stride)
    remaining = n_blocks - 1
    for _ in range(remaining):
        out = residual_block(out, filters, stride=1)
    return out


def dense_block(x, num_layers, growth_rate):
    """Densely connected block: each layer sees the concatenation of all earlier outputs.

    Args:
        x: Input feature map tensor.
        num_layers: Number of BN-ReLU-1x1conv-BN-ReLU-3x3conv layers to stack.
        growth_rate: Number of channels each layer appends; the 1x1
            bottleneck convolution uses ``4 * growth_rate`` channels.

    Returns:
        Concatenation of the input and every layer's output.
    """
    feature_maps = [x]
    bottleneck_width = 4 * growth_rate

    for _ in range(num_layers):
        # Everything produced so far is the input of the next layer.
        stacked = Concatenate()(feature_maps)

        # 1x1 bottleneck
        h = BatchNormalization()(stacked)
        h = Activation('relu')(h)
        h = Conv2D(bottleneck_width, (1, 1), padding='same', use_bias=False)(h)

        # 3x3 convolution producing growth_rate new channels
        h = BatchNormalization()(h)
        h = Activation('relu')(h)
        h = Conv2D(growth_rate, (3, 3), padding='same', use_bias=False)(h)

        feature_maps.append(h)

    return Concatenate()(feature_maps)

def transition_layer(x):
    """Halve the channel count with a 1x1 convolution, then 2x2 average-pool.

    Args:
        x: Input feature map tensor with a statically known channel count.

    Returns:
        Tensor after batch normalization, 1x1 convolution to half the
        channels (floor division), and stride-2 average pooling.
    """
    # NOTE(review): there is no activation between the batch norm and the
    # 1x1 convolution here -- confirm this is intentional.
    normalized = BatchNormalization()(x)
    halved_channels = normalized.shape[-1] // 2
    compressed = Conv2D(halved_channels, (1, 1), padding='same', use_bias=False)(normalized)
    return AveragePooling2D(pool_size=(2, 2), strides=2)(compressed)

def gfa_residual_stream(input_tensor):
    """Residual stream: conv stem, three residual stages, channel attention, global pooling.

    Args:
        input_tensor: Image tensor fed to the stem convolution.

    Returns:
        Globally average-pooled feature vector of the attended final stage.
    """
    # Stem: 3x3 conv -> BN -> ReLU
    stem = Conv2D(64, 3, padding='same', use_bias=False)(input_tensor)
    stem = BatchNormalization()(stem)
    features = layers.ReLU()(stem)

    # (filters, stride) per stage; each stage has two residual blocks and
    # no attention is applied between stages.
    stage_specs = ((128, 1), (256, 2), (512, 2))
    for stage_filters, stage_stride in stage_specs:
        features = residual_block_group(
            features, filters=stage_filters, n_blocks=2, stride=stage_stride
        )

    # Channel attention is applied once, right before the global pooling.
    attended = simple_attention(features)
    return GlobalAveragePooling2D()(attended)

def sf_dense_stream(input_tensor, growth_rate=16):
    """Dense stream: conv stem with max-pooling, three dense blocks, global pooling.

    Args:
        input_tensor: Image tensor fed to the stem convolution.
        growth_rate: Growth rate forwarded to every dense block.

    Returns:
        Globally average-pooled feature vector of the last dense block.
    """
    # Stem: 3x3 conv -> BN -> ReLU -> 2x2 max-pool
    stem = Conv2D(64, (3, 3), padding='same', use_bias=False)(input_tensor)
    stem = BatchNormalization()(stem)
    stem = Activation('relu')(stem)
    features = MaxPooling2D((2, 2))(stem)

    # Two 4-layer dense blocks, each followed by a transition layer
    # (no attention anywhere in this stream).
    for depth in (4, 4):
        features = dense_block(features, num_layers=depth, growth_rate=growth_rate)
        features = transition_layer(features)

    # Final 8-layer dense block is pooled directly, without a transition.
    features = dense_block(features, num_layers=8, growth_rate=growth_rate)

    return GlobalAveragePooling2D()(features)

def build_derefnet(input_shape=(64, 64, 3), num_classes=38, growth_rate=16):
    """Assemble the two-stream DeReFNet classifier.

    Args:
        input_shape: Shape of one input image (without the batch axis).
        num_classes: Number of units in the softmax output layer.
        growth_rate: Growth rate forwarded to the dense stream.

    Returns:
        A Keras ``Model`` named ``'DeReFNet_PlantVillage'`` mapping an image
        to class probabilities.
    """
    from tensorflow.keras import Input

    image_in = Input(shape=input_shape)

    # Both streams read the same input; only the residual stream applies
    # attention (inside gfa_residual_stream, before its global pooling).
    residual_features = gfa_residual_stream(image_in)
    dense_features = sf_dense_stream(image_in, growth_rate)

    # Fuse the two pooled feature vectors and regularize with dropout.
    merged = Concatenate()([residual_features, dense_features])
    merged = Dropout(0.5)(merged)

    # The classification head is explicitly kept in float32.
    class_probs = Dense(num_classes, activation='softmax', dtype='float32')(merged)

    return Model(image_in, class_probs, name='DeReFNet_PlantVillage')

In [None]:

# Instantiate DeReFNet for 64x64 RGB inputs and the 38 PlantVillage classes.
model = build_derefnet(
    input_shape=(64, 64, 3),
    num_classes=38,
    growth_rate=16,
)

# Adam with its default settings; the sparse loss expects integer class
# labels (presumably what the dataset loader yields -- confirm).
compile_options = dict(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)
model.compile(**compile_options)

In [None]:


# Train for 10 epochs. Keep the returned History object so the per-epoch
# metrics stay available, instead of leaving only its repr as the cell output.
history = model.fit(train_ds, validation_data=val_ds, epochs=10)


Epoch 1/10
[1m849/849[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m372s[0m 351ms/step - accuracy: 0.5254 - loss: 1.7036 - val_accuracy: 0.5537 - val_loss: 1.8303
Epoch 2/10
[1m849/849[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m245s[0m 287ms/step - accuracy: 0.7773 - loss: 0.7271 - val_accuracy: 0.3484 - val_loss: 3.8719
Epoch 3/10
[1m849/849[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m262s[0m 287ms/step - accuracy: 0.8410 - loss: 0.5203 - val_accuracy: 0.5864 - val_loss: 1.7160
Epoch 4/10
[1m849/849[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m274s[0m 323ms/step - accuracy: 0.8703 - loss: 0.4044 - val_accuracy: 0.6456 - val_loss: 1.3446
Epoch 5/10
[1m849/849[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m244s[0m 287ms/step - accuracy: 0.8968 - loss: 0.3239 - val_accuracy: 0.3038 - val_loss: 6.7771
Epoch 6/10
[1m849/849[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m274s[0m 323ms/step - accuracy: 0.9063 - loss: 0.2785 - val_accuracy: 0.7520 - val_loss: 0.9638
Epoc

<keras.src.callbacks.history.History at 0x7dda9c48fe50>

In [None]:
# Continue training the same model for 10 more epochs: fit() starts from the
# model's current weights, so this extends the previous run rather than
# restarting it. The History object is kept so these metrics are not discarded.
history_continued = model.fit(train_ds, validation_data=val_ds, epochs=10)

Epoch 1/10
[1m849/849[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m243s[0m 287ms/step - accuracy: 0.9501 - loss: 0.1472 - val_accuracy: 0.6134 - val_loss: 2.3508
Epoch 2/10
[1m849/849[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m244s[0m 287ms/step - accuracy: 0.9551 - loss: 0.1411 - val_accuracy: 0.5268 - val_loss: 3.1176
Epoch 3/10
[1m849/849[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m262s[0m 287ms/step - accuracy: 0.9541 - loss: 0.1375 - val_accuracy: 0.7949 - val_loss: 0.8837
Epoch 4/10
[1m849/849[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m243s[0m 287ms/step - accuracy: 0.9576 - loss: 0.1315 - val_accuracy: 0.8310 - val_loss: 0.7063
Epoch 5/10
[1m849/849[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m262s[0m 287ms/step - accuracy: 0.9646 - loss: 0.1065 - val_accuracy: 0.8440 - val_loss: 0.5752
Epoch 6/10
[1m849/849[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m292s[0m 323ms/step - accuracy: 0.9679 - loss: 0.1004 - val_accuracy: 0.3605 - val_loss: 5.0240
Epoc

<keras.src.callbacks.history.History at 0x7dd98263a590>