# In [None]:
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam

# In [None]:


def load_data(path, width, height, batch_size):
    """Build a shuffled, augmented training iterator over an image directory.

    Uses ``ImageDataGenerator.flow_from_directory``, which expects ``path`` to
    contain one sub-directory per class.

    Args:
        path: Root directory of the training images.
        width: Target image width in pixels.
        height: Target image height in pixels.
        batch_size: Number of images per generated batch.

    Returns:
        The iterator returned by ``flow_from_directory``. It yields batches of
        images rescaled by 1/255, randomly flipped horizontally and vertically,
        together with categorical (one-hot) labels.
    """
    # More augmentation techniques (rotations, zoom, shifts, ...) could be added here.
    train_datagen = ImageDataGenerator(
        rescale=1 / 255,
        horizontal_flip=True,
        vertical_flip=True,
    )

    train_generator = train_datagen.flow_from_directory(
        path,
        seed=123,  # fixed seed for reproducible shuffling/transformations
        # Keras expects target_size as (height, width), not (width, height).
        target_size=(height, width),
        batch_size=batch_size,
        class_mode="categorical",
        shuffle=True,
    )
    return train_generator
    
def create_resnet_model(input_shape, num_classes=8, learning_rate=0.001):
    """Build and compile a ResNet50 transfer-learning classifier.

    An ImageNet-pretrained ResNet50 backbone (without its top) is frozen, so
    only the new pooling + softmax head is trained.

    Args:
        input_shape: Shape of a single input image. Assumes channels-last
            ``(height, width, 3)``; TODO confirm against the data pipeline.
        num_classes: Number of units in the softmax output layer. It should
            match the number of classes produced by the data generator.
            Defaults to 8.
        learning_rate: Learning rate for the Adam optimizer. Defaults to 0.001.

    Returns:
        A ``Sequential`` model compiled with the Adam optimizer, categorical
        cross-entropy loss and the accuracy metric.
    """
    backbone = ResNet50(input_shape=input_shape, weights='imagenet', include_top=False)

    # Keep the pretrained parameters intact: freeze every backbone layer.
    for layer in backbone.layers:
        layer.trainable = False

    # Only the classification head below is trained.
    model = Sequential()
    model.add(backbone)
    model.add(GlobalAveragePooling2D())
    # An extra hidden layer could be added here, e.g. Dense(128, activation='relu').
    model.add(Dense(num_classes, activation='softmax'))

    optimizer = Adam(learning_rate=learning_rate)
    model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
    return model
    
 def train_model(model, train_data, val_data, epochs=10):

    model.fit(
        train_data,
        validation_data=val_data,
        epochs=epochs,
        verbose=1
    )
    return model