In [None]:
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

from sklearn.model_selection import train_test_split
from utils import load_galaxy_data

import app


# Load the galaxy images and their labels through the project helper.
input_data, labels = load_galaxy_data()

print(input_data.shape, labels.shape)

# Hold out 20% of the samples for validation. `stratify=labels` keeps each
# label's proportion the same in both splits; the fixed seed makes the split
# reproducible.
x_train, x_test, y_train, y_test = train_test_split(
    input_data,
    labels,
    test_size=0.2,
    random_state=222,
    stratify=labels,
    shuffle=True,
)

# Normalize pixel intensities. The previous factor 1/128 looks like the image
# side length (128 px) typed in place of the 8-bit pixel range. It scaled
# 0-255 pixels to [0, ~2] instead of [0, 1].
# NOTE(review): assumes load_galaxy_data returns raw 0-255 pixel values —
# confirm against utils.load_galaxy_data.
generator = ImageDataGenerator(rescale=1.0 / 255)

train_iterator = generator.flow(x_train, y_train, batch_size=5)
test_iterator = generator.flow(x_test, y_test, batch_size=5)

# Small CNN over 128x128 RGB inputs: two stages of (stride-2 3x3 conv with
# 8 filters, then 2x2 max-pool), flattened into a 16-unit ReLU layer and a
# 4-way softmax output.
model = tf.keras.Sequential(
    [
        tf.keras.layers.Input(shape=(128, 128, 3)),
        tf.keras.layers.Conv2D(
            8, (3, 3), strides=2, activation="relu", padding="valid"
        ),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=2),
        tf.keras.layers.Conv2D(
            8, (3, 3), strides=2, activation="relu", padding="valid"
        ),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=2),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(16, activation="relu"),
        tf.keras.layers.Dense(4, activation="softmax"),
    ]
)

# Multi-class setup: categorical cross-entropy on the 4-way softmax output,
# tracking categorical accuracy and AUC. The model has a single output, so a
# plain loss object is passed instead of a one-element list.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss=tf.keras.losses.CategoricalCrossentropy(),
    metrics=[tf.keras.metrics.CategoricalAccuracy(), tf.keras.metrics.AUC()],
)

# summary() prints the table itself and returns None. Wrapping it in print()
# emitted a stray "None".
model.summary()

# Keep `batch_size` consistent with the iterators. The old hard-coded 16
# disagreed with their real batch size (5), so each epoch saw only ~5/16 of
# the data.
batch_size = train_iterator.batch_size

# len(iterator) is the integer number of batches per pass over the data, so
# every epoch and every validation run covers the full split exactly once.
model.fit(
    train_iterator,
    steps_per_epoch=len(train_iterator),
    epochs=8,
    validation_data=test_iterator,
    validation_steps=len(test_iterator),
)

