In [None]:
import json
import os
import platform
import shutil
import sys
import threading
import warnings
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import pandas as pd
import tensorflow as tf
from skimage.color import label2rgb
from skimage.io import imread, imsave
from skimage.segmentation import slic
from tensorflow import keras
from tensorflow.keras.preprocessing import image_dataset_from_directory
from tqdm import tqdm

Report the Python, TensorFlow and Keras versions, and check whether TensorFlow can see a GPU.

In [None]:
# Record the interpreter / library versions so results can be tied to an environment.
print(f"Python Platform: {platform.platform()}")
print(f"Tensor Flow Version: {tf.__version__}")
print(f"Keras Version: {keras.__version__}")
print()

print(f"Python {sys.version}")
# An empty device list means TensorFlow will run on the CPU only.
gpu = bool(tf.config.list_physical_devices('GPU'))
availability = "available" if gpu else "NOT available"
print("GPU is", availability)

Set up the global configuration variables; modify or add to them as needed.

In [None]:
# Directory layout, resolved relative to the parent of the notebook's working directory.
project_directory = os.path.dirname(os.getcwd())
data_directory = os.path.join(project_directory, "data")  # Original directory
dataset_directory = os.path.join(project_directory, "tf_data")  # Processed directory (one sub-directory per class)
model_directory = os.path.join(project_directory, "models")  # Saved models and training histories

# Vote-fraction table; later cells read its "GalaxyID" column and the columns named in data_selection.
solutions_csv = pd.read_csv(os.path.join(data_directory, "training_solutions_rev1.csv"))

# Model parameters
model_name = "model_v1"  # Change this to a unique name
epochs = 5
learning_rate = 0.001
batch_size = 32
random_seed = 123
random_seat = random_seed  # Original (misspelled) name, kept so existing cells keep working

# Seed Python `random`, NumPy and TensorFlow together so weight initialisation,
# shuffling and the random augmentation layers repeat across runs
# (some GPU kernels can still be nondeterministic).
tf.keras.utils.set_random_seed(random_seed)

# Image parameters
img_height = 200
img_width = 200

# Data selection from the solutions csv: a galaxy is copied into a class's
# directory when its value in that column is strictly greater than the threshold.
# NOTE(review): thresholds are applied independently per class, so a galaxy that
# exceeds two thresholds is copied into both directories (and the copies can then
# land in different train/validation splits) — confirm this is intended.
data_selection = {
    'Class8.1': 0.35,
    'Class8.2': 0.35,
    'Class8.3': 0.35,
    'Class8.4': 0.35,
    'Class8.5': 0.35,
    'Class8.6': 0.35,
    'Class8.7': 0.35,
}
num_classes = len(data_selection)  # One output unit per selected class

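Set up and prepare the data: segment each galaxy image, keep only its central segment, and write the results into one directory per class.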

In [None]:
def get_center_pixel(segment):
    """Return the value stored at the midpoint of a 2-D array.

    The midpoint is found by integer division, so for an even dimension the
    lower/right of the two central indices is used.
    """
    rows, cols = segment.shape
    midpoint = (rows // 2, cols // 2)
    return segment[midpoint]

def alter_image(image):
    """Black out every SLIC superpixel except the one under the image centre.

    Args:
        image: image array with channels on the last axis (``image[..., 0]`` is
            read as the first channel). Presumably an RGB galaxy JPEG — TODO
            confirm before using with other inputs.

    Returns:
        The result of ``label2rgb(kind='avg')`` applied to a label field built
        from the first channel: pixels outside the central segment are set to
        label 0 (the background), and the remaining pixels are labelled by their
        first-channel value.

    The caller's ``image`` is not modified.
    """
    image_segment = slic(image, n_segments=10, compactness=50)
    # Loop-invariant: the id of the segment covering the centre pixel.
    center_label = get_center_pixel(image_segment)

    # ``image[..., 0]`` is a view, so copy before writing; otherwise the
    # caller's image would be modified in place.
    new_image = image[..., 0].copy()
    new_image[image_segment != center_label] = 0

    # NOTE(review): raw intensities are used as labels here, so 'avg' averages
    # the colour of equal-intensity pixels inside the central segment (a centre
    # pixel with intensity 0 also becomes background). `colors` is only used by
    # kind='overlay'. Confirm this is the intended masking.
    return label2rgb(new_image, image, kind='avg', colors='gray')

# Serialises the warnings-filter swap below: warnings.catch_warnings() replaces and
# restores process-global state and is not thread-safe, but process_image runs
# concurrently from a ThreadPoolExecutor in setup_dataset.
_imsave_warnings_lock = threading.Lock()

def process_image(row, directory):
    """Load one training image, mask it with alter_image, and save it into ``directory``.

    Args:
        row: record with a ``"GalaxyID"`` entry (e.g. a DataFrame row); the
            source file is ``<data_directory>/images_training_rev1/<GalaxyID>.jpg``.
        directory: destination directory; the output keeps the same file name.
    """
    image_name = f'{int(row["GalaxyID"])}.jpg'
    src_path = os.path.join(data_directory, "images_training_rev1", image_name)
    dst_path = os.path.join(directory, image_name)

    # The expensive work stays outside the lock so workers still run in parallel.
    image = imread(src_path)
    altered_image = alter_image(image)

    # Holding the lock prevents two workers from restoring each other's filter
    # snapshots, which could otherwise leave "ignore" installed globally.
    with _imsave_warnings_lock, warnings.catch_warnings():
        warnings.simplefilter("ignore")
        imsave(dst_path, altered_image, check_contrast=False)

def setup_dataset(directory: str, dataset: pd.DataFrame):
    """(Re)create ``directory`` and fill it with a processed image for every row of ``dataset``.

    Any existing contents of ``directory`` are deleted first. Rows are processed
    in worker threads via process_image. A worker exception is re-raised here
    when its future's result is collected.

    Args:
        directory: destination directory (one per class in this notebook).
        dataset: rows to process; each must carry a ``"GalaxyID"`` column.
    """
    # os.removedirs only deletes *empty* directories (and then walks up deleting
    # empty parents), so it never cleared stale images. rmtree removes the whole
    # tree, and a missing directory is the only error that is safe to ignore.
    try:
        shutil.rmtree(directory)
    except FileNotFoundError:
        pass
    os.makedirs(directory, exist_ok=True)

    total_images = len(dataset)
    with ThreadPoolExecutor() as executor:
        futures = [executor.submit(process_image, row, directory) for _, row in dataset.iterrows()]
        for future in tqdm(futures, total=total_images, desc='Setting up dataset'):
            future.result()  # Re-raises any exception raised in the worker

    print(f'Finished setting up dataset in {directory}')

# Start from an empty dataset directory so classes removed from data_selection
# leave no stale sub-directories behind (best effort: errors are reported, not raised).
try:
    shutil.rmtree(dataset_directory)
except FileNotFoundError:
    pass
except Exception as e:
    print(f"An error occurred while trying to remove the directory: {e}")

# One sub-directory per class, named after the lower-cased column. Each one keeps
# the galaxies whose value in that column exceeds the class threshold.
for name, per in data_selection.items():
    print(f"Setting up dataset for {name.lower()}")
    # Boolean indexing keeps the original dtypes and rows. `.where(mask).dropna()`
    # cast every column to float and also dropped rows with NaN in unrelated columns.
    selected_rows = solutions_csv[solutions_csv[name] > per]
    setup_dataset(os.path.join(dataset_directory, name.lower()), selected_rows)

In [None]:
# Labels are inferred from the class sub-directory names; subset="both" returns the
# (training, validation) pair split 80/20 with a fixed seed.
train_ds, val_ds = image_dataset_from_directory(
    directory=dataset_directory,
    color_mode='grayscale',
    image_size=(img_height, img_width),
    batch_size=batch_size,
    validation_split=0.2,
    subset="both",
    seed=random_seat,
)

# Random augmentations, applied to the model input before rescaling, so pixels are
# still in [0, 255] (RandomBrightness's default value_range).
# NOTE(review): GaussianNoise(0.1) on 0–255 pixels is negligible noise — confirm the
# intended scale.
data_augmentation = tf.keras.Sequential()
data_augmentation.add(tf.keras.layers.RandomFlip('horizontal'))
data_augmentation.add(tf.keras.layers.RandomRotation(0.2))
data_augmentation.add(tf.keras.layers.RandomZoom(0.2))
data_augmentation.add(tf.keras.layers.RandomContrast(0.2))
data_augmentation.add(tf.keras.layers.RandomBrightness(0.2))
data_augmentation.add(tf.keras.layers.GaussianNoise(0.1))

In [None]:
AUTOTUNE = tf.data.AUTOTUNE
# Cache after the first pass and overlap input loading with training; only the
# training split is shuffled.
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

model = tf.keras.Sequential([
    # Declare the input shape on the first element. An `input_shape` argument on a
    # later layer (as Rescaling previously had) is ignored by Sequential.
    tf.keras.Input(shape=(img_height, img_width, 1)),
    data_augmentation,  # Random layers are only active during training
    tf.keras.layers.Rescaling(1./255),
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(128, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(128, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dense(num_classes, activation='softmax')
])

# Integer class labels with a softmax output -> sparse categorical cross-entropy.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs)

# model.save does not create missing parent directories.
os.makedirs(model_directory, exist_ok=True)
model.save(os.path.join(model_directory, f'{model_name}.keras'))

# Cast per-epoch metric values to plain floats so json can always serialise them,
# even if they arrive as NumPy scalars.
history_dict = {metric: [float(value) for value in values] for metric, values in history.history.items()}
with open(os.path.join(model_directory, f'{model_name}_history.json'), 'w') as f:
    json.dump(history_dict, f)