In [1]:
import tensorflow as tf
from tensorflow.keras.applications import InceptionV3
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, GlobalAveragePooling2D, Dense, Dropout, BatchNormalization
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split
import pickle

# --- Configuration: everything a reader might tune for data + model setup ---
batch_dir = "batches"  # directory holding X_batch_<i>.pickle / y_batch_<i>.pickle files
IMG_SHAPE = (299, 299, 3)  # InceptionV3's native input resolution (height, width, channels)
NUM_CLASSES = 128  # width of the softmax head; must match the label set in the y_batch files — TODO confirm
LEARNING_RATE = 0.001
SEED = 42

# Seed Python, NumPy and TensorFlow RNGs so head-weight initialisation and
# dropout masks are repeatable across Restart & Run All.
tf.keras.utils.set_random_seed(SEED)

# Frozen ImageNet feature extractor (classification top removed).
base_model = InceptionV3(weights='imagenet', include_top=False, input_shape=IMG_SHAPE)
base_model.trainable = False

# Classification head on top of the frozen backbone.
inputs = Input(shape=IMG_SHAPE)
# preprocess_input rescales pixels to [-1, 1] inside the graph (the truediv/subtract
# ops in the summary); assumes the stored images are raw 0-255 values — TODO confirm.
x = tf.keras.applications.inception_v3.preprocess_input(inputs)
# training=False keeps the backbone's BatchNormalization layers in inference mode
# even if the backbone is later unfrozen for fine-tuning.
x = base_model(x, training=False)
x = GlobalAveragePooling2D()(x)
x = Dense(256, activation='relu')(x)
x = Dropout(0.5)(x)
# NOTE(review): Dropout placed *before* BatchNormalization can shift activation
# variance between training and inference; consider BatchNormalization -> Dropout.
# Left unchanged to keep the architecture identical.
x = BatchNormalization()(x)
outputs = Dense(NUM_CLASSES, activation='softmax')(x)

model = Model(inputs, outputs)
model.summary()

# Sparse categorical cross-entropy: labels are integer class ids, not one-hot vectors.
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=Adam(learning_rate=LEARNING_RATE),
    metrics=['accuracy'],
)

# --- Training ---
# Each epoch makes one pass over *all* batch files. Previously the code ran
# 10 epochs on file 0, then 10 on file 1, and so on. That over-fits each file
# in turn and lets the last files dominate the final weights. Interleaving the
# files gives every file equal influence for the same total number of gradient
# steps. Files are re-read each epoch rather than cached, presumably because
# the full dataset does not fit in memory (hence the batch files) — TODO confirm.
num_batches = 42  # number of X_batch_<i>/y_batch_<i> pickle pairs in batch_dir; adjust to match
EPOCHS = 10
FIT_BATCH_SIZE = 16
TEST_SIZE = 0.3
# With a fixed split seed, every file's train/test partition is identical across
# epochs. Without it, re-splitting each epoch would leak test rows into training.
SPLIT_SEED = 42

assert num_batches > 0, "no batch files configured"


def load_batch_split(batch_index):
    """Load one pickled (X, y) batch-file pair and split it into train/test parts.

    Returns the 4-tuple from ``train_test_split``:
    ``(x_train, x_test, y_train, y_test)``. The split holds out ``TEST_SIZE`` of
    the file, is stratified on the labels, and is seeded with ``SPLIT_SEED``,
    so repeated calls for the same index return the same partition.

    Raises ValueError (from ``train_test_split``) when the file cannot be
    stratified, e.g. a class with fewer than two samples.
    """
    # NOTE(security): pickle.load can execute arbitrary code; only load batch
    # files produced by this project.
    with open(f"{batch_dir}/X_batch_{batch_index}.pickle", "rb") as fh:
        X_batch = pickle.load(fh)
    with open(f"{batch_dir}/y_batch_{batch_index}.pickle", "rb") as fh:
        y_batch = pickle.load(fh)
    return train_test_split(
        X_batch, y_batch, test_size=TEST_SIZE, stratify=y_batch, random_state=SPLIT_SEED
    )


for epoch in range(EPOCHS):
    file_losses = []
    for batch_index in range(num_batches):
        x_train_batch, _, y_train_batch, _ = load_batch_split(batch_index)
        # One epoch over this file's training part. The same compiled model (and
        # its Adam optimizer state) is reused, so progress accumulates across calls.
        history = model.fit(
            x_train_batch, y_train_batch, batch_size=FIT_BATCH_SIZE, epochs=1, verbose=0
        )
        file_losses.append(history.history['loss'][-1])
    print(f"epoch {epoch + 1}/{EPOCHS}: mean per-file training loss "
          f"{sum(file_losses) / len(file_losses):.4f}")

# --- Evaluation on the held-out split of every file ---
# These test splits were previously computed but never used.
test_count = 0
test_loss_sum = 0.0
test_acc_sum = 0.0
for batch_index in range(num_batches):
    _, x_test_batch, _, y_test_batch = load_batch_split(batch_index)
    scores = model.evaluate(
        x_test_batch, y_test_batch, batch_size=FIT_BATCH_SIZE, verbose=0, return_dict=True
    )
    n_test = len(y_test_batch)
    test_count += n_test
    # Weight per-file metrics by sample count so unevenly sized files average correctly.
    test_loss_sum += scores['loss'] * n_test
    test_acc_sum += scores['accuracy'] * n_test
print(f"held-out loss {test_loss_sum / test_count:.4f}, "
      f"accuracy {test_acc_sum / test_count:.4f} over {test_count} samples")

# Save the trained model. The .h5 extension selects Keras' HDF5 format.
model.save('128dogs.h5')





Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 299, 299, 3)]     0         
                                                                 
 tf.math.truediv (TFOpLambd  (None, 299, 299, 3)       0         
 a)                                                              
                                                                 
 tf.math.subtract (TFOpLamb  (None, 299, 299, 3)       0         
 da)                                                             
                                                                 
 inception_v3 (Functional)   (None, 8, 8, 2048)        21802784  
                                                                 
 global_average_pooling2d (  (None, 2048)              0         
 GlobalAveragePooling2D)                                         
                                                          

KeyboardInterrupt: 