### Base model

Our base model is a plain ResNet50 that is initialized with random weights (no pretraining) and therefore has no prior knowledge of any images.

In [None]:
from keras import preprocessing
from keras.applications import ResNet50
import numpy as np
import matplotlib.pyplot as plt
from keras.applications.resnet50 import preprocess_input

In [None]:
# Seed Python's `random`, NumPy and the Keras backend in a single call.
# Seeding NumPy alone leaves the backend RNG that draws the initial model
# weights unseeded, so two runs of this notebook would start from different
# weights. (Some GPU ops may still be nondeterministic.)
from keras.utils import set_random_seed

random_seed = 42
set_random_seed(random_seed)

In [None]:
train_data_dir = './data_split/train'
test_data_dir = './data_split/test'

img_size = (224, 224)
batch_size = 32

# Training data: images are resized to img_size and labels are one-hot
# encoded (label_mode='categorical'). The explicit seed makes the shuffle
# order independent of how many random numbers were drawn before this cell.
train_ds = preprocessing.image_dataset_from_directory(
    train_data_dir,
    label_mode='categorical',
    image_size=img_size,
    batch_size=batch_size,
    shuffle=True,
    seed=random_seed,
)

# Held-out data used for validation during training; kept in directory order.
val_ds = preprocessing.image_dataset_from_directory(
    test_data_dir,
    label_mode='categorical',
    image_size=img_size,
    batch_size=batch_size,
    shuffle=False,
)

# Read class_names before map(): the mapped dataset no longer carries it.
class_names = train_ds.class_names

# Each dataset derives its label indices from its own directory, so differing
# class folders would silently mis-label the validation data.
if val_ds.class_names != class_names:
    raise ValueError(
        f"Class folders differ between {train_data_dir!r} and "
        f"{test_data_dir!r}: {class_names} vs {val_ds.class_names}"
    )

# Apply the ResNet50 input preprocessing to every image batch; labels pass
# through unchanged.
train_ds = train_ds.map(lambda x, y: (preprocess_input(x), y))
val_ds = val_ds.map(lambda x, y: (preprocess_input(x), y))

In [None]:
# Build model from scratch (random init)
# weights=None skips loading pretrained weights; include_top=True adds the
# classification layer, sized to the number of classes found in the data.
# The input shape is derived from img_size (plus 3 colour channels) so it
# cannot drift from the size the datasets are resized to.
base_model = ResNet50(
    include_top=True,
    weights=None,
    input_shape=(*img_size, 3),
    classes=len(class_names),
)

# Categorical cross-entropy matches the one-hot labels produced by
# label_mode='categorical'.
base_model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

In [None]:
# Train for 10 epochs with val_ds as validation data; the returned History
# object holds the per-epoch metrics that the plots below read.
history = base_model.fit(x=train_ds, epochs=10, validation_data=val_ds)

In [None]:
def plot_loss(history):
    """Plot the training loss per epoch, plus validation loss when present.

    Args:
        history: Mapping of metric name to a per-epoch sequence of values
            (the ``history`` attribute of the object returned by ``fit``,
            not that object itself). Must contain ``'loss'``;
            ``'val_loss'`` is optional.
    """
    plt.plot(history['loss'], label='loss')
    # 'val_loss' is missing when training ran without validation data; skip
    # the curve instead of failing with a KeyError.
    if 'val_loss' in history:
        plt.plot(history['val_loss'], label='val_loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()


plot_loss(history.history)

In [None]:
def plot_accuracy(history):
    """Plot the training accuracy per epoch, plus validation accuracy when present.

    Args:
        history: Mapping of metric name to a per-epoch sequence of values
            (the ``history`` attribute of the object returned by ``fit``,
            not that object itself). Must contain ``'accuracy'``;
            ``'val_accuracy'`` is optional.
    """
    plt.plot(history['accuracy'], label='accuracy')
    # 'val_accuracy' is missing when training ran without validation data;
    # skip the curve instead of failing with a KeyError.
    if 'val_accuracy' in history:
        plt.plot(history['val_accuracy'], label='val_accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.show()


plot_accuracy(history.history)

In [None]:
#base_model.save('./models/PlainResNet50.keras')

# TODO: interpretation of the results