In [14]:
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from timm.models import vit_base_patch16_224

# Load pre-trained models
efficientnet = tf.keras.applications.EfficientNetB0(weights='imagenet', include_top=False)
vit_model = vit_base_patch16_224(pretrained=True)

# Freeze pre-trained model weights (optional)
for layer in efficientnet.layers:
  layer.trainable = False
for layer in vit_model.parameters():
  layer.requires_grad = False  # Adjust if fine-tuning is needed

# Extract features
def extract_features(x):
  efficientnet_out = efficientnet(x)
  vit_out = vit_model.patch_embed(x)  # Extract features from patch embedding of ViT
  return efficientnet_out, vit_out

# Concatenate features
def forward(x):
  efficientnet_features, vit_features = extract_features(x)
  features = tf.keras.layers.concatenate([efficientnet_features, vit_features])
  # ... rest of the classifier (same as previous example) ...
  x = tf.keras.layers.Dense(64, activation='relu')(features)
  output = tf.keras.layers.Dense(1, activation='sigmoid')(x)
  return output

# Model definition (corrected input shape)
model = tf.keras.Model(inputs=efficientnet.input, outputs=forward(tf.keras.Input(shape=(224, 224, 3))))

# Data augmentation (replace with your data paths)
train_datagen = ImageDataGenerator(rescale=1./255, shear_range=0.2, zoom_range=0.2, horizontal_flip=True)
validation_datagen = ImageDataGenerator(rescale=1./255)

# Preprocess data function
def preprocess_data(image):
  # Your data preprocessing logic here (e.g., convert to RGB if necessary)
  image = tf.image.resize(image, (224, 224))  # Example resizing
  return image

# Apply preprocessing to data generators
train_generator = (
    (preprocess_data(x[0]), x[1]) for x in train_datagen.flow_from_directory(
        'path/to/training/data/',
        target_size=(224, 224),
        batch_size=32,
        class_mode='binary'  # Adjust class_mode based on your classification problem
    )
)

validation_generator = (
    (preprocess_data(x[0]), x[1]) for x in validation_datagen.flow_from_directory(
        'path/to/validation/data/',
        target_size=(224, 224),
        batch_size=32,
        class_mode='binary'  # Adjust class_mode based on your classification problem
    )
)

# Early stopping (optional)
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5)

# Compile and train the model (replace with your desired epochs and patience)
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
model.fit(
    train_generator,
    epochs=10,  # Adjust based on your data and needs
    validation_data=validation_generator,
    steps_per_epoch=len(train_generator),  # Adjust based on your data size
    validation_steps=len(validation_generator),  # Adjust based on your data size
    callbacks=[early_stopping]  # Add early stopping callback (optional)
)

# Save the model (optional)
model.save('leaf_disease_detection_model.h5')


AssertionError: Input width (3) doesn't match model (224).