In [1]:
import pickle

# Load the augmented images and labels from the file
with open('../local_data/augmented_data.pickle', 'rb') as f:
    augmented_images, augmented_labels = pickle.load(f)

In [19]:
augmented_labels[0:5]

[0, 1, 1, 1, 0]

In [2]:
import tensorflow as tf
import numpy as np
from keras.utils import to_categorical

augmented_images = np.array(augmented_images, dtype=np.float32)
augmented_labels = np.array(augmented_labels, dtype=np.int32)
augmented_labels = to_categorical(augmented_labels)

from sklearn.model_selection import train_test_split

X_train, X_val, y_train, y_val = train_test_split(augmented_images, augmented_labels, test_size=0.2, random_state=42)

In [3]:
augmented_labels[0:5]

array([[1., 0.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [1., 0.]])

In [14]:
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout, Conv2D, MaxPooling2D, Normalization, Flatten
from tensorflow.keras.models import Model

base_model = tf.keras.applications.ResNet50(
  input_shape=(512, 512, 3), 
  include_top=False, 
  weights='imagenet',
  pooling=None,
  classes = 2,
  classifier_activation='softmax'
  )

base_model.trainable = False

x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(1024, activation='relu')(x)
x = Flatten()(x)
x = Dropout(0.5)(x)

predictions = Dense(2, activation='softmax')(x)

model = Model(inputs=base_model.input, outputs=predictions)

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5
[1m94765736/94765736[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m20s[0m 0us/step


In [15]:
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

In [16]:
model.fit(X_train, y_train, epochs=5, validation_data=(X_val, y_val))

Epoch 1/5
[1m50/50[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m228s[0m 4s/step - accuracy: 0.5462 - loss: 2.3499 - val_accuracy: 0.5800 - val_loss: 0.6585
Epoch 2/5
[1m50/50[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m221s[0m 4s/step - accuracy: 0.5367 - loss: 0.7889 - val_accuracy: 0.5575 - val_loss: 0.6771
Epoch 3/5
[1m50/50[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m227s[0m 5s/step - accuracy: 0.5570 - loss: 0.7104 - val_accuracy: 0.5925 - val_loss: 0.6503
Epoch 4/5
[1m50/50[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m225s[0m 5s/step - accuracy: 0.5846 - loss: 0.6737 - val_accuracy: 0.6000 - val_loss: 0.6416
Epoch 5/5
[1m50/50[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m213s[0m 4s/step - accuracy: 0.5639 - loss: 0.6728 - val_accuracy: 0.5125 - val_loss: 0.6814


<keras.src.callbacks.history.History at 0x238089c5690>

In [43]:
model.save('../models/pre_model_green_4.h5')

