In [1]:
import os
import numpy as np
from sklearn.model_selection import train_test_split
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Define the directory
dataset_dir = 'D:/AI Algorithm/cucumber'

# Split data
train_data_gen = ImageDataGenerator(
    rescale=1.0/255,
    validation_split=0.2  # 20% for validation
)

# Training data
train_data = train_data_gen.flow_from_directory(
    dataset_dir,
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical',
    subset='training'
)

# Validation and Test data
val_data_gen = ImageDataGenerator(rescale=1.0/255)
val_data, test_data = train_test_split(
    list(range(train_data.n)),
    test_size=0.5,
    random_state=42
)


Found 922 images belonging to 2 classes.


In [2]:
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model

# Load ResNet50 with pretrained weights
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))

# Freeze base layers
for layer in base_model.layers:
    layer.trainable = False

# Add custom classification layers
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(256, activation='relu')(x)
x = Dense(128, activation='relu')(x)
predictions = Dense(len(train_data.class_indices), activation='softmax')(x)

model = Model(inputs=base_model.input, outputs=predictions)


In [3]:
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau

# Compile the model
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

# Callbacks
checkpoint = ModelCheckpoint('best_model.h5', monitor='val_accuracy', save_best_only=True)
early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
lr_scheduler = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5)

# Train the model
history = model.fit(
    train_data,
    validation_data=val_data,
    epochs=30,
    callbacks=[checkpoint, early_stopping, lr_scheduler]
)


ValueError: Data is expected to be in format `x`, `(x,)`, `(x, y)`, or `(x, y, sample_weight)`, found: (75, 450, 673, 854, 277, 68, 446, 578, 188, 271, 680, 569, 236, 88, 781, 117, 125, 753, 289, 238, 0, 912, 495, 802, 395, 545, 126, 278, 710, 116, 473, 228, 678, 672, 57, 529, 274, 318, 620, 144, 651, 572, 838, 659, 369, 268, 638, 307, 423, 310, 354, 46, 349, 195, 921, 767, 714, 263, 443, 621, 304, 341, 889, 149, 124, 723, 50, 353, 852, 142, 470, 399, 576, 320, 19, 744, 777, 743, 407, 537, 635, 38, 175, 245, 812, 616, 692, 789, 154, 287, 554, 17, 127, 322, 255, 606, 887, 190, 115, 567, 180, 301, 697, 655, 666, 630, 734, 517, 906, 45, 835, 157, 706, 171, 16, 511, 48, 893, 773, 515, 631, 480, 283, 603, 225, 26, 867, 437, 874, 364, 229, 37, 888, 374, 469, 890, 877, 668, 194, 785, 795, 503, 892, 764, 579, 891, 162, 866, 152, 626, 644, 693, 111, 226, 718, 103, 421, 419, 586, 119, 53, 151, 403, 869, 207, 885, 915, 8, 756, 36, 452, 253, 303, 685, 571, 623, 653, 662, 262, 610, 297, 414, 150, 774, 640, 816, 550, 728, 488, 147, 146, 705, 855, 679, 348, 463, 325, 186, 123, 784, 608, 143, 881, 197, 609, 279, 293, 400, 122, 183, 202, 438, 246, 415, 703, 827, 129, 637, 402, 721, 708, 839, 219, 641, 841, 758, 844, 624, 837, 741, 386, 894, 509, 267, 754, 441, 496, 112, 691, 232, 800, 607, 671, 373, 903, 847, 233, 722, 676, 317, 648, 410, 825, 709, 358, 258, 627, 632, 282, 376, 384, 224, 876, 749, 472, 347, 505, 639, 909, 853, 904, 619, 786, 645, 778, 556, 880, 577, 85, 242, 698, 159, 524, 35, 540, 170, 654, 817, 788, 783, 868, 733, 95, 563, 240, 742, 574, 690, 460, 553, 806, 206, 392, 397, 780, 217, 4, 768, 642, 824, 612, 738, 546, 725, 683, 98, 727, 573, 406, 502, 47, 32, 779, 200, 134, 27, 808, 230, 489, 772, 378, 288, 418, 674, 391, 592, 498, 138, 62, 471, 647, 128, 898, 520, 64, 14, 156, 40, 492, 379, 187, 763, 216, 791, 52, 337, 748, 719, 724, 295, 701, 251, 726, 461, 455, 918, 815, 269, 201, 161, 555, 729, 401, 702, 476, 821, 771, 105, 565, 389, 1, 861, 561, 80, 205, 34, 775, 508, 427, 454, 366, 91, 339, 564, 345, 776, 241, 13, 315, 600, 387, 273, 166, 840, 914, 646, 818, 484, 902, 504, 831, 243, 566, 562, 686, 189, 782, 699, 475, 681, 510, 58, 474, 560, 856, 747, 252, 21, 313, 459, 160, 276, 191, 385, 805, 413, 491, 343, 769, 308, 661, 130, 663, 871, 99, 372, 87, 458, 330, 214, 466, 121, 614, 20, 700, 71, 106, 270, 860, 435, 102)

In [4]:
print(type(train_data), type(val_data))


<class 'keras.src.legacy.preprocessing.image.DirectoryIterator'> <class 'list'>


In [6]:
for data in train_data(1):
    print(data)


TypeError: 'DirectoryIterator' object is not callable