In [10]:
import os
import cv2
import numpy as np
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

def load_images_and_labels(data_root_path, target_size=(224, 224)):
    image_data = []
    image_labels = []
    class_names = []

    for domain_dir in os.listdir(data_root_path):
        domain_path = os.path.join(data_root_path, domain_dir)
        if os.path.isdir(domain_path):  
            for class_dir in os.listdir(domain_path):
                class_path = os.path.join(domain_path, class_dir)
                if os.path.isdir(class_path):  
                    for img_file in os.listdir(class_path):
                        img_path = os.path.join(class_path, img_file)
                        img = cv2.imread(img_path)
                        if img is not None:
                            resized_img = cv2.resize(img, target_size)
                            image_data.append(resized_img)
                            image_labels.append(class_dir)
                            if class_dir not in class_names:
                                class_names.append(class_dir)

    image_data = np.array(image_data)
    image_labels = np.array(image_labels)
    class_to_index = {name: idx for idx, name in enumerate(class_names)}
    image_labels = np.array([class_to_index[label] for label in image_labels])

    return image_data, image_labels, class_names

def split_data_by_class(image_data, image_labels, test_ratio=0.5):
    train_data, train_labels, test_data, test_labels = [], [], [], []

    for label in np.unique(image_labels):
        indices = np.where(image_labels == label)[0]
        label_data, label_labels = image_data[indices], image_labels[indices]
        data_train, data_test, labels_train, labels_test = train_test_split(
            label_data, label_labels, test_size=test_ratio, random_state=42
        )
        train_data.extend(data_train)
        train_labels.extend(labels_train)
        test_data.extend(data_test)
        test_labels.extend(labels_test)

    return np.array(train_data), np.array(train_labels), np.array(test_data), np.array(test_labels)



In [11]:
data_root_path = "/Users/yanzhu/Documents/Office31"
image_data, image_labels, class_names = load_images_and_labels(data_root_path)
train_data, train_labels, test_data, test_labels = split_data_by_class(image_data, image_labels, test_ratio=0.5)
base_resnet_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
x = base_resnet_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(512, activation='relu')(x)
output = Dense(len(class_names), activation='softmax')(x)

resnet_model = Model(inputs=base_resnet_model.input, outputs=output)
for layer in base_resnet_model.layers:
    layer.trainable = False
resnet_model.compile(optimizer=Adam(learning_rate=0.001), loss='sparse_categorical_crossentropy', metrics=['accuracy'])
resnet_model.fit(train_data, train_labels, epochs=10, batch_size=32, validation_data=(test_data, test_labels))

Epoch 1/10
[1m57/57[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m64s[0m 1s/step - accuracy: 0.4973 - loss: 2.0145 - val_accuracy: 0.7822 - val_loss: 0.8658
Epoch 2/10
[1m57/57[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m63s[0m 1s/step - accuracy: 0.8988 - loss: 0.3699 - val_accuracy: 0.8164 - val_loss: 0.7434
Epoch 3/10
[1m57/57[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m61s[0m 1s/step - accuracy: 0.9647 - loss: 0.1533 - val_accuracy: 0.8335 - val_loss: 0.6722
Epoch 4/10
[1m57/57[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m64s[0m 1s/step - accuracy: 0.9756 - loss: 0.1043 - val_accuracy: 0.8445 - val_loss: 0.6836
Epoch 5/10
[1m57/57[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m64s[0m 1s/step - accuracy: 0.9978 - loss: 0.0348 - val_accuracy: 0.8523 - val_loss: 0.6421
Epoch 6/10
[1m57/57[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m64s[0m 1s/step - accuracy: 0.9997 - loss: 0.0170 - val_accuracy: 0.8572 - val_loss: 0.6527
Epoch 7/10
[1m57/57[0m [32m━━━━━━━━━━

<keras.src.callbacks.history.History at 0x323501400>