In [1]:
import os

import tensorflow as tf
import numpy as np

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.utils import Sequence

from sklearn.model_selection import train_test_split
from tensorflow.keras.preprocessing.image import load_img, img_to_array

In [2]:
# Network input size as (height, width): the order Keras' load_img expects for
# target_size and the order the model's input_shape is built from.
# Caution: cv2.resize expects (width, height) instead.
TARGET_HEIGHT, TARGET_WIDTH = 256, 144
TARGET_SIZE = (TARGET_HEIGHT, TARGET_WIDTH)

In [3]:
def _is_int(text):
    """Return True if ``int(text)`` succeeds."""
    try:
        int(text)
    except ValueError:
        return False
    return True


def parse_filename(filename):
    """Extract the label values encoded in an image filename.

    Expected layouts (fields separated by ``_``)::

        <time>_<x>_<y>_<z>_<r>_<is_flying>.<ext>
        <prefix>_<time>_<x>_<y>_<z>_<r>_<is_flying>.<ext>

    where the second form is detected by a second field that contains ``-``
    and is not itself an integer (e.g. ``14-25-37``).

    Args:
        filename: bare file name (no directory part).

    Returns:
        ``(True, [x, y, z, r, is_flying])`` as ints on success, or
        ``(False, [])`` if the name has too few fields or a label field is not
        an integer.
    """
    components = filename.split("_")

    # Must check length before touching components[1]; a name without "_"
    # would otherwise raise IndexError.
    if len(components) < 2:
        return False, []

    # A "-" inside an integer (e.g. "-1") is a negative label, not a
    # timestamp-like field, so only shift when the field is not an integer.
    offset = 1 if "-" in components[1] and not _is_int(components[1]) else 0

    if len(components) < 6 + offset:
        return False, []

    try:
        x, y, z, r = [int(part) for part in components[1 + offset:5 + offset]]
        # The last label field carries the file extension, e.g. "2.png".
        is_flying = int(components[5 + offset].split(".")[0])
    except ValueError:
        # Malformed label field: report invalid like the too-short case.
        return False, []

    return True, [x, y, z, r, is_flying]

In [None]:
# Collect (image path, label) pairs from every sub-folder of "data".
# os.listdir returns entries in arbitrary order, so both listings are sorted:
# that makes the sample order, and therefore the seeded train/validation split
# below, reproducible across machines and file systems.
image_data = []
label_data = []

for date_folder in sorted(os.listdir("data")):
    date_folder_path = os.path.join("data", date_folder)
    if not os.path.isdir(date_folder_path):
        continue

    for filename in sorted(os.listdir(date_folder_path)):
        if not filename.endswith(".png"):
            continue

        valid, components = parse_filename(filename)
        if valid:
            image_data.append(os.path.join(date_folder_path, filename))
            label_data.append(components)

In [None]:
# image_data (file paths) and label_data (label lists) are plain Python lists at
# this point and have no .shape, so report their lengths instead.
print(len(image_data), len(label_data))

In [None]:
# Labels: five integer components per sample (x, y, z, r, is_flying), each
# expected to be a class index in [0, 3). Every component is one-hot encoded
# and the result flattened to one row of 5 * 3 = 15 values per sample.
#
# Results go to new names instead of overwriting label_data / image_data so
# the cell is idempotent. Re-running it would otherwise one-hot encode the
# already-encoded labels (silently producing 45-wide rows) and try to load
# image arrays as file paths.
label_array = np.array(label_data)
# to_categorical fails on indices >= 3 and may silently wrap negative ones,
# so check the range explicitly.
assert ((label_array >= 0) & (label_array < 3)).all(), "label components must be in [0, 3)"

label_onehot = tf.keras.utils.to_categorical(label_array, num_classes=3)
label_onehot = label_onehot.reshape(label_onehot.shape[0], -1)


def preprocess_image(image_path):
    """Load an image resized to TARGET_SIZE (height, width) and scaled to [0, 1].

    Args:
        image_path: path of the image file.

    Returns:
        A NumPy array of shape (height, width, channels) from ``img_to_array``.
    """
    image = load_img(image_path, target_size=TARGET_SIZE)
    return img_to_array(image) / 255.0


image_arrays = np.array([preprocess_image(image_path) for image_path in image_data])
image_train_data, image_validation_data, label_train_data, label_validation_data = train_test_split(
    image_arrays, label_onehot, test_size=0.2, random_state=42
)

In [None]:
# Sanity-check the split: train/validation images, then train/validation labels.
split_arrays = (image_train_data, image_validation_data, label_train_data, label_validation_data)
print(*(arr.shape for arr in split_arrays))

In [None]:
# The labels are 5 components, each one-hot over 3 classes, flattened to 15
# values (see the label-encoding cell). A single softmax across all 15 outputs
# forces them to sum to 1, so it can never match a target with five ones.
# Instead, apply the softmax within each component's group of 3 outputs and
# flatten back to (batch, 15) so the output still lines up with the labels.
NUM_COMPONENTS = 5
NUM_CLASSES = 3

model = Sequential()

model.add(Conv2D(32, (3, 3), activation="relu", input_shape=(TARGET_SIZE[0], TARGET_SIZE[1], 3)))
model.add(MaxPooling2D((2, 2)))
model.add(Conv2D(64, (3, 3), activation="relu"))
model.add(MaxPooling2D((2, 2)))
model.add(Conv2D(128, (3, 3), activation="relu"))
model.add(MaxPooling2D((2, 2)))
model.add(Flatten())
model.add(Dense(64, activation="relu"))
model.add(Dense(NUM_COMPONENTS * NUM_CLASSES))
model.add(tf.keras.layers.Reshape((NUM_COMPONENTS, NUM_CLASSES)))
model.add(tf.keras.layers.Softmax(axis=-1))
model.add(tf.keras.layers.Reshape((NUM_COMPONENTS * NUM_CLASSES,)))

In [None]:
class DataGenerator(Sequence):
    """Serve pre-loaded inputs and labels to Keras in fixed-size batches.

    Args:
        x_set: sliceable collection of inputs (indexed with ``[start:stop]``).
        y_set: labels aligned element-for-element with ``x_set``.
        batch_size: samples per batch; the final batch may be smaller.
        **kwargs: forwarded to ``Sequence.__init__``. On Keras 3, ``Sequence``
            is ``PyDataset``, which expects its constructor to be called and
            accepts options such as ``workers`` / ``use_multiprocessing``.
    """

    def __init__(self, x_set, y_set, batch_size, **kwargs):
        super().__init__(**kwargs)
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        """Number of batches per epoch, counting a trailing partial batch."""
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        """Return the ``idx``-th ``(inputs, labels)`` batch."""
        start = idx * self.batch_size
        stop = start + self.batch_size
        return self.x[start:stop], self.y[start:stop]

In [None]:
# Wrap the in-memory train/validation arrays in batch generators.
BATCH_SIZE = 64

training_generator, validating_generator = (
    DataGenerator(images, labels, BATCH_SIZE)
    for images, labels in (
        (image_train_data, label_train_data),
        (image_validation_data, label_validation_data),
    )
)

In [None]:
# Per-output binary cross-entropy against the flattened 15-value label vector.
model.compile(
    loss="binary_crossentropy",
    optimizer="adam",
    metrics=["accuracy"],
)

model.fit(
    training_generator,
    validation_data=validating_generator,
    epochs=10,
)

In [None]:
def array_to_components(array, num_classes=3):
    """Decode a flattened per-component one-hot vector back to label values.

    Inverse of the ``to_categorical`` + ``reshape`` used to build the training
    labels. The vector is split into consecutive groups of ``num_classes``
    values, one per component in the order ``parse_filename`` produces
    (x, y, z, r, is_flying). Each group is reduced to the index of its largest
    entry (the first one on ties, so an all-zero group decodes to 0).

    Example: ``[0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]`` -> ``[0, 0, 0, 0, 0]``.

    Args:
        array: 1-D sequence whose length is a multiple of ``num_classes``.
        num_classes: number of classes per component.

    Returns:
        A list of Python ints, one per component.

    Raises:
        ValueError: if the length is not a multiple of ``num_classes``.
    """
    groups = np.asarray(array).reshape(-1, num_classes)
    return [int(index) for index in groups.argmax(axis=1)]

In [None]:
import cv2


def predict_image(filename):
    """Run the trained model on one image file and return a 0/1 label vector.

    Args:
        filename: path of the image to classify.

    Returns:
        A flat NumPy array laid out like the training labels (consecutive
        groups of 3 values, one group per label component). Each group holds
        a single 1 at the class the model scored highest.
    """
    # Preprocess exactly like the training images: load_img at TARGET_SIZE
    # (height, width), RGB, scaled to [0, 1]. cv2 is avoided here because
    # cv2.resize takes (width, height), so TARGET_SIZE produced a transposed
    # image that swapaxes only flipped, it did not resize it. Also cv2.imread
    # returns BGR channels rather than RGB.
    image = img_to_array(load_img(filename, target_size=TARGET_SIZE)) / 255.0
    batch = np.expand_dims(image, axis=0)

    prediction = model.predict(batch)[0]

    # Pick the most likely class within each component's group of 3 scores.
    # Plain rounding can leave a group with no class (or several) selected.
    groups = prediction.reshape(-1, 3)
    one_hot = np.zeros_like(groups)
    one_hot[np.arange(len(groups)), groups.argmax(axis=1)] = 1
    return one_hot.reshape(-1)


In [None]:
# Smoke-test the trained model on a single recorded frame.
SAMPLE_IMAGE_PATH = "data/2023-07-30_14-25-37/14-30-44_0_0_0_0_2.png"
print(predict_image(SAMPLE_IMAGE_PATH))