In [1]:
import tensorflow as tf
import numpy as np
import os
import random

# Define the list of classes
classes = ['보쌈', '볶음면', '비빔밥', '빵']

# Define the image size
img_size = 256


# Define the paths for the training and validation sets
train_path = r'/content/drive/MyDrive/음식/'
valid_path = r'/content/drive/MyDrive/음식/'

# Define the batch size and number of epochs
batch_size = 32
num_epochs = 8

# Load the ResNet50 model
base_model = tf.keras.applications.ResNet50(input_shape=(img_size, img_size, 3), include_top=False, weights='imagenet')

# Freeze the layers in the base model
base_model.trainable = False

# Add a classification head to the model
model = tf.keras.Sequential([
    base_model,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(len(classes), activation='softmax')
])

# Compile the model
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Create data generators for the training and validation sets
train_data_gen = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)

valid_data_gen = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale=1./255
)

train_generator = train_data_gen.flow_from_directory(
    train_path,
    target_size=(img_size, img_size),
    batch_size=batch_size,
    class_mode='categorical'
)

valid_generator = valid_data_gen.flow_from_directory(
    valid_path,
    target_size=(img_size, img_size),
    batch_size=batch_size,
    class_mode='categorical'
)

# Train the model
history = model.fit(
    train_generator,
    epochs=num_epochs,
    validation_data=valid_generator
)

# Save the trained model in the models directory
model.save('food_classifier_resnet50.js')

# Evaluate the model on the validation set
loss, accuracy = model.evaluate(valid_generator)

# Print the validation accuracy
print('Validation accuracy:', accuracy)

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5
Found 400 images belonging to 4 classes.
Found 400 images belonging to 4 classes.
Epoch 1/8
Epoch 2/8
Epoch 3/8
Epoch 4/8
Epoch 5/8
Epoch 6/8
Epoch 7/8
Epoch 8/8




Validation accuracy: 0.5375000238418579
