In [1]:
import utils
import pathlib
utils.mount_src()

from data_loader import DataLoader

# Data location comes from `utils.load_config()`; the rest are notebook constants.
CONFIG = utils.load_config()
DATA_PATH = CONFIG["images_path"]

# Directory that model checkpoints ("<name>.keras") are written to.
MODEL_PATH = pathlib.Path("..", "models")

# Image shape (height, width, channels) shared by every backbone, and training length.
INPUT_SHAPE = (224, 224, 3)
EPOCHS = 20

In [2]:
# Each head is a scalar regression: MSE is optimized, MAE is reported alongside.
# Keys must match the model's output names.
LOSS = dict.fromkeys(("year_output", "lat_output", "lon_output"), "mse")
METRICS = {head: ["mae"] for head in LOSS}

In [3]:
from tensorflow import keras

def model_traning_pipeline(model: keras.Model, data_path: str, epochs=10, target_size=(224, 224), callbacks=None):
    """Fit ``model`` on the train/test split produced by ``DataLoader``.

    Args:
        model: compiled Keras model; it is trained in place.
        data_path: root path handed to ``DataLoader``.
        epochs: number of passes over the training split.
        target_size: image size forwarded to ``DataLoader.dataset``.
        callbacks: optional list of Keras callbacks passed to ``model.fit``.

    Returns:
        dict with ``"model"`` (the same model object, now trained) and
        ``"history"`` (the value returned by ``model.fit``).
    """
    train_split, holdout_split = DataLoader(data_path).dataset(target_size=target_size)

    # The second split doubles as validation data, so val_* metrics get logged.
    fit_history = model.fit(
        train_split,
        validation_data=holdout_split,
        epochs=epochs,
        callbacks=callbacks,
    )
    return dict(model=model, history=fit_history)

In [4]:
def train_models(ms: list[keras.Model], names: list[str]) -> dict[str, dict]:
    """Train each model in ``ms`` and checkpoint its best epoch under ``MODEL_PATH``.

    Args:
        ms: compiled Keras models, in the same order as ``names``.
        names: one identifier per model; used as the result key and as the
            checkpoint file stem (``MODEL_PATH / f"{name}.keras"``).

    Returns:
        ``{name: {"model": ..., "history": ...}}``, one entry per model, as
        returned by ``model_traning_pipeline``.

    Raises:
        ValueError: if ``ms`` and ``names`` differ in length (``zip`` would
            otherwise silently skip the unmatched entries).
    """
    if len(ms) != len(names):
        raise ValueError(
            f"got {len(ms)} models but {len(names)} names; they must pair one-to-one"
        )

    out = {}
    for model, name in zip(ms, names):
        # Use a single checkpoint on the aggregate validation loss. Several
        # save_best_only checkpoints sharing one filepath would overwrite each
        # other's file. Per-output keys also depend on the compiled loss/metric
        # names (loss="mse" is logged as "<output>_loss", not "<output>_mse"),
        # and ModelCheckpoint skips saving when its monitor key is missing.
        # "val_loss" is always logged when validation data is supplied.
        callbacks = [
            keras.callbacks.ModelCheckpoint(
                filepath=MODEL_PATH / f"{name}.keras",
                monitor="val_loss",
                mode="min",
                save_best_only=True,
            )
        ]
        out[name] = model_traning_pipeline(
            model,
            DATA_PATH,
            epochs=EPOCHS,
            # Keep the loader's image size in sync with the models' input layer.
            target_size=INPUT_SHAPE[:2],
            callbacks=callbacks,
        )
    return out

In [None]:
from tensorflow.keras.applications import EfficientNetB3
from tensorflow.keras import layers
import tensorflow as tf

def create_efficient_net() -> keras.Model:
    """Build a frozen ImageNet EfficientNetB3 backbone with three scalar regression heads.

    Returns:
        An uncompiled ``tf.keras.Model``. It maps an ``INPUT_SHAPE`` image to a
        dict with keys ``"year_output"``, ``"lat_output"`` and ``"lon_output"``,
        each produced by a ``Dense(1, activation="sigmoid")`` head.
    """
    base_model = EfficientNetB3(weights="imagenet", include_top=False, input_shape=INPUT_SHAPE)
    # Freeze the pretrained weights; only the layers added below are trained.
    base_model.trainable = False

    inputs = layers.Input(shape=INPUT_SHAPE)
    # Keras' EfficientNet includes its own input normalization and expects raw
    # [0, 255] pixels, so no Rescaling layer is added (1/255 would normalize twice).
    # NOTE(review): assumes DataLoader yields un-normalized pixels - confirm.
    # training=False keeps the frozen backbone's BatchNorm layers in inference mode.
    x = base_model(inputs, training=False)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(128, activation="relu")(x)

    # Output names must match the LOSS / METRICS keys passed to compile().
    # NOTE(review): assumes DataLoader targets use these same keys - confirm.
    # NOTE(review): sigmoid bounds predictions to (0, 1); this only fits if the
    # targets are scaled to that range - confirm.
    outputs = {
        head: layers.Dense(1, name=head, activation="sigmoid")(x)
        for head in ("year_output", "lat_output", "lon_output")
    }
    return tf.keras.Model(inputs=inputs, outputs=outputs)

# Instantiate and compile the EfficientNetB3 regressor, then print its architecture.
efficient_net = create_efficient_net()
efficient_net.compile(loss=LOSS, metrics=METRICS, optimizer="adam")
efficient_net.summary()

In [None]:
from tensorflow.keras.applications import ResNet50

def create_res_net():
    """Build a frozen ImageNet ResNet50 backbone with three scalar regression heads.

    Returns:
        An uncompiled ``tf.keras.Model``. It maps an ``INPUT_SHAPE`` image to a
        dict with keys ``"year_output"``, ``"lat_output"`` and ``"lon_output"``,
        each produced by a ``Dense(1, activation="sigmoid")`` head.
    """
    base_model = ResNet50(weights="imagenet", include_top=False, input_shape=INPUT_SHAPE)
    # Freeze the pretrained weights; only the layers added below are trained.
    base_model.trainable = False

    inputs = layers.Input(shape=INPUT_SHAPE)
    # Keras' ResNet50 weights expect `resnet50.preprocess_input` (RGB->BGR plus
    # ImageNet mean subtraction, no scaling); a plain 1/255 rescale does not match.
    # NOTE(review): assumes DataLoader yields RGB pixels in [0, 255] - confirm.
    x = tf.keras.applications.resnet50.preprocess_input(inputs)
    # training=False keeps the frozen backbone's BatchNorm layers in inference mode.
    x = base_model(x, training=False)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(128, activation="relu")(x)

    # Output names must match the LOSS / METRICS keys passed to compile().
    # NOTE(review): assumes DataLoader targets use these same keys - confirm.
    # NOTE(review): sigmoid bounds predictions to (0, 1); this only fits if the
    # targets are scaled to that range - confirm.
    outputs = {
        head: layers.Dense(1, name=head, activation="sigmoid")(x)
        for head in ("year_output", "lat_output", "lon_output")
    }
    return tf.keras.Model(inputs=inputs, outputs=outputs)

# Instantiate and compile the ResNet50 regressor, then print its architecture.
res_net = create_res_net()
res_net.compile(loss=LOSS, metrics=METRICS, optimizer="adam")
res_net.summary()

In [None]:
from tensorflow.keras.applications import Xception

def create_xception():
    """Build a frozen ImageNet Xception backbone with three scalar regression heads.

    Returns:
        An uncompiled ``tf.keras.Model``. It maps an ``INPUT_SHAPE`` image to a
        dict with keys ``"year_output"``, ``"lat_output"`` and ``"lon_output"``,
        each produced by a ``Dense(1, activation="sigmoid")`` head.
    """
    base_model = Xception(weights="imagenet", include_top=False, input_shape=INPUT_SHAPE)
    # Freeze the pretrained weights; only the layers added below are trained.
    base_model.trainable = False

    inputs = layers.Input(shape=INPUT_SHAPE)
    # Keras' Xception weights expect pixels in [-1, 1]. For [0, 255] input,
    # x / 127.5 - 1 is what `xception.preprocess_input` computes.
    # NOTE(review): assumes DataLoader yields pixels in [0, 255] - confirm.
    x = layers.Rescaling(1. / 127.5, offset=-1)(inputs)
    # training=False keeps the frozen backbone's BatchNorm layers in inference mode.
    x = base_model(x, training=False)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(128, activation="relu")(x)

    # Output names must match the LOSS / METRICS keys passed to compile().
    # NOTE(review): assumes DataLoader targets use these same keys - confirm.
    # NOTE(review): sigmoid bounds predictions to (0, 1); this only fits if the
    # targets are scaled to that range - confirm.
    outputs = {
        head: layers.Dense(1, name=head, activation="sigmoid")(x)
        for head in ("year_output", "lat_output", "lon_output")
    }
    return tf.keras.Model(inputs=inputs, outputs=outputs)

# Instantiate and compile the Xception regressor with the shared loss/metrics.
xception = create_xception()
xception.compile(
    optimizer="adam",
    loss=LOSS,
    metrics=METRICS,
)
# Print the architecture, as the EfficientNet and ResNet cells do.
xception.summary()
    

In [8]:
# Identifiers for the models, in the same order as `models`; also used as checkpoint file stems.
names = "efficient_net res_net xception".split()

In [None]:
# Train every compiled model in order; `names` pairs with `models` by position.
models = list((efficient_net, res_net, xception))
results = train_models(models, names=names)