# Setup dependencies

In [36]:
%pip install opencv-python-headless matplotlib lxml numpy

[0mNote: you may need to restart the kernel to use updated packages.


In [52]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Flatten, Conv2D, Input

# Clean data

In [128]:
import PIL
from pathlib import Path
from PIL import UnidentifiedImageError

def aggregate_images(
    dataset_root,
    extensions = ["png", "jpg", "jpeg"],
):
    """
    Globs for images in a given data directory and returns them
    """
    dataset_root = Path(dataset_root)
    image_paths = []

    for extension in extensions:
        image_paths.extend(list(dataset_root.glob("**/*.{}".format(extension))))

    return image_paths

base_dir = "/tf/workspace/machine-learning-adventures/widget-classification" 
label_dir os.path.join(base_dir, "data/label")
img_dir = "/tf/workspace/machine-learning-adventures/widget-classification/data/raw"

path = aggregate_images(os.path.join(img_dir, "button"))
for img_p in path:
    try:
        img = PIL.Image.open(img_p)
    except PIL.UnidentifiedImageError:
            print(img_p)

# Load data

In [53]:
import os
import cv2
import numpy as np
import xml.etree.ElementTree as ET

In [60]:
def parse_annotation(xml_file):
    tree = ET.parse(xml_file)
    root = tree.getroot()

    boxes = []
    for member in root.findall('object'):
        xmin = int(member.find('bndbox/xmin').text)
        ymin = int(member.find('bndbox/ymin').text)
        xmax = int(member.find('bndbox/xmax').text)
        ymax = int(member.find('bndbox/ymax').text)
        label = member.find('name').text

        boxes.append([xmin, ymin, xmax, ymax, label])
    
    return boxes

In [61]:
def load_data(image_dir, annotation_dir):
    images = []
    labels = []

    for image_file in os.listdir(image_dir):
        if image_file.endswith('.jpg'):  # Adjust the file extension as per your dataset
            image_path = os.path.join(image_dir, image_file)
            annotation_path = os.path.join(annotation_dir, image_file.replace('.jpg', '.xml'))  # Adjust the extension

            # Load image
            img = cv2.imread(image_path)
            img = cv2.resize(img, (224, 224))  # Resize to match model input
            images.append(img)

            # Parse XML file for annotations
            boxes = parse_annotation(annotation_path)
            labels.append(boxes)

    return np.array(images), np.array(labels)

# Train model

In [81]:
# Load a pre-trained model (for example, MobileNetV2)
base_model = tf.keras.applications.MobileNetV2(input_shape=(224, 224, 3),
                                               include_top=False,
                                               weights='imagenet')
base_model.trainable = False  # Freeze the base model

In [126]:
# Add custom layers for object detection
inputs = Input(shape=(224, 224, 3))
#x = base_model(inputs, training=False)
Conv2D(filters=32, kernel_size=3, activation='relu', inputs=inputs)
MaxPooling2D(pool_size=2)
Conv2D(filters=64, kernel_size=3, activation='relu')
MaxPooling2D(pool_size=2)(x)
Flatten()
Dense(1024, activation='relu')
#outputs = Dense(5, activation='sigmoid')(x)  # 4 for bounding box, 1 for class

ValueError: Unrecognized keyword arguments passed to Conv2D: {'inputs': <KerasTensor shape=(None, 224, 224, 3), dtype=float32, sparse=None, name=keras_tensor_1286>}

In [116]:
model = Model(inputs, outputs)

In [117]:
# Compile the model
model.compile(optimizer='adam',
              loss='mean_squared_error',  # Adjust the loss function as needed
              metrics=['accuracy'])

In [118]:
image_dir = '/tf/workspace/widget-classification/data/labeled/button'
annotation_dir = '/tf/workspace/widget-classification/data/labeled/button'
x_train, y_train = load_data(image_dir, annotation_dir)

In [119]:
# Normalize images
x_train = x_train / 224.0

In [120]:
label_map = {'button': 0}  # Add more classes as needed
y_train = [[*box[:4], label_map[box[4]]] for boxes in y_train for box in boxes]

In [121]:
print(x_train.shape)
print(x_train.dtype)

(25, 224, 224, 3)
float64


In [122]:
print(np.any(np.isnan(x_train)))  # Should be False
print(np.any(np.isinf(x_train)))  # Should be False

False
False


In [123]:
x_train

array([[[[0.        , 0.        , 0.        ],
         [0.        , 0.        , 0.        ],
         [0.        , 0.        , 0.        ],
         ...,
         [0.        , 0.        , 0.        ],
         [0.        , 0.        , 0.        ],
         [0.        , 0.        , 0.        ]],

        [[0.        , 0.        , 0.        ],
         [0.        , 0.        , 0.        ],
         [0.04910714, 0.04910714, 0.04910714],
         ...,
         [0.04910714, 0.04910714, 0.04910714],
         [0.        , 0.        , 0.        ],
         [0.        , 0.        , 0.        ]],

        [[0.        , 0.        , 0.        ],
         [0.04910714, 0.04910714, 0.04910714],
         [0.49553571, 0.49553571, 0.49553571],
         ...,
         [0.49553571, 0.49553571, 0.49553571],
         [0.04910714, 0.04910714, 0.04910714],
         [0.        , 0.        , 0.        ]],

        ...,

        [[0.        , 0.        , 0.        ],
         [0.04910714, 0.04910714, 0.04910714]

In [124]:
model.summary()

In [125]:
# Train the model
model.fit(x_train, y_train, epochs=10, batch_size=32)

ValueError: Unrecognized data type: x=[[[[0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.         0.         0.        ]
   ...
   [0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.         0.         0.        ]]

  [[0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.04910714 0.04910714 0.04910714]
   ...
   [0.04910714 0.04910714 0.04910714]
   [0.         0.         0.        ]
   [0.         0.         0.        ]]

  [[0.         0.         0.        ]
   [0.04910714 0.04910714 0.04910714]
   [0.49553571 0.49553571 0.49553571]
   ...
   [0.49553571 0.49553571 0.49553571]
   [0.04910714 0.04910714 0.04910714]
   [0.         0.         0.        ]]

  ...

  [[0.         0.         0.        ]
   [0.04910714 0.04910714 0.04910714]
   [0.49553571 0.49553571 0.49553571]
   ...
   [1.09821429 1.09821429 1.09821429]
   [1.13392857 1.13392857 1.13392857]
   [1.13839286 1.13839286 1.13839286]]

  [[0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.04910714 0.04910714 0.04910714]
   ...
   [1.13392857 1.13392857 1.13392857]
   [1.13839286 1.13839286 1.13839286]
   [1.13839286 1.13839286 1.13839286]]

  [[0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.         0.         0.        ]
   ...
   [1.13839286 1.13839286 1.13839286]
   [1.13839286 1.13839286 1.13839286]
   [1.13839286 1.13839286 1.13839286]]]


 [[[0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.         0.         0.        ]
   ...
   [0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.         0.         0.        ]]

  [[0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.04910714 0.04910714 0.04910714]
   ...
   [0.04910714 0.04910714 0.04910714]
   [0.         0.         0.        ]
   [0.         0.         0.        ]]

  [[0.         0.         0.        ]
   [0.04910714 0.04910714 0.04910714]
   [0.49553571 0.49553571 0.49553571]
   ...
   [0.49553571 0.49553571 0.49553571]
   [0.04910714 0.04910714 0.04910714]
   [0.         0.         0.        ]]

  ...

  [[0.         0.         0.        ]
   [0.04910714 0.04910714 0.04910714]
   [0.49553571 0.49553571 0.49553571]
   ...
   [1.09821429 1.09821429 1.09821429]
   [1.13392857 1.13392857 1.13392857]
   [1.13839286 1.13839286 1.13839286]]

  [[0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.04910714 0.04910714 0.04910714]
   ...
   [1.13392857 1.13392857 1.13392857]
   [1.13839286 1.13839286 1.13839286]
   [1.13839286 1.13839286 1.13839286]]

  [[0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.         0.         0.        ]
   ...
   [1.13839286 1.13839286 1.13839286]
   [1.13839286 1.13839286 1.13839286]
   [1.13839286 1.13839286 1.13839286]]]


 [[[0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.         0.         0.        ]
   ...
   [0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.         0.         0.        ]]

  [[0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.04910714 0.04910714 0.04910714]
   ...
   [0.04910714 0.04910714 0.04910714]
   [0.         0.         0.        ]
   [0.         0.         0.        ]]

  [[0.         0.         0.        ]
   [0.04910714 0.04910714 0.04910714]
   [0.49553571 0.49553571 0.49553571]
   ...
   [0.49553571 0.49553571 0.49553571]
   [0.04910714 0.04910714 0.04910714]
   [0.         0.         0.        ]]

  ...

  [[0.         0.         0.        ]
   [0.04910714 0.04910714 0.04910714]
   [0.49553571 0.49553571 0.49553571]
   ...
   [1.09821429 1.09821429 1.09821429]
   [1.13392857 1.13392857 1.13392857]
   [1.13839286 1.13839286 1.13839286]]

  [[0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.04910714 0.04910714 0.04910714]
   ...
   [1.13392857 1.13392857 1.13392857]
   [1.13839286 1.13839286 1.13839286]
   [1.13839286 1.13839286 1.13839286]]

  [[0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.         0.         0.        ]
   ...
   [1.13839286 1.13839286 1.13839286]
   [1.13839286 1.13839286 1.13839286]
   [1.13839286 1.13839286 1.13839286]]]


 ...


 [[[0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.         0.         0.        ]
   ...
   [0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.         0.         0.        ]]

  [[0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.04910714 0.04910714 0.04910714]
   ...
   [0.04910714 0.04910714 0.04910714]
   [0.         0.         0.        ]
   [0.         0.         0.        ]]

  [[0.         0.         0.        ]
   [0.04910714 0.04910714 0.04910714]
   [0.49553571 0.49553571 0.49553571]
   ...
   [0.49553571 0.49553571 0.49553571]
   [0.04910714 0.04910714 0.04910714]
   [0.         0.         0.        ]]

  ...

  [[0.         0.         0.        ]
   [0.04910714 0.04910714 0.04910714]
   [0.49553571 0.49553571 0.49553571]
   ...
   [1.09821429 1.09821429 1.09821429]
   [1.13392857 1.13392857 1.13392857]
   [1.13839286 1.13839286 1.13839286]]

  [[0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.04910714 0.04910714 0.04910714]
   ...
   [1.13392857 1.13392857 1.13392857]
   [1.13839286 1.13839286 1.13839286]
   [1.13839286 1.13839286 1.13839286]]

  [[0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.         0.         0.        ]
   ...
   [1.13839286 1.13839286 1.13839286]
   [1.13839286 1.13839286 1.13839286]
   [1.13839286 1.13839286 1.13839286]]]


 [[[0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.         0.         0.        ]
   ...
   [0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.         0.         0.        ]]

  [[0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.04910714 0.04910714 0.04910714]
   ...
   [0.04910714 0.04910714 0.04910714]
   [0.         0.         0.        ]
   [0.         0.         0.        ]]

  [[0.         0.         0.        ]
   [0.04910714 0.04910714 0.04910714]
   [0.49553571 0.49553571 0.49553571]
   ...
   [0.49553571 0.49553571 0.49553571]
   [0.04910714 0.04910714 0.04910714]
   [0.         0.         0.        ]]

  ...

  [[0.         0.         0.        ]
   [0.04910714 0.04910714 0.04910714]
   [0.49553571 0.49553571 0.49553571]
   ...
   [1.09821429 1.09821429 1.09821429]
   [1.13392857 1.13392857 1.13392857]
   [1.13839286 1.13839286 1.13839286]]

  [[0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.04910714 0.04910714 0.04910714]
   ...
   [1.13392857 1.13392857 1.13392857]
   [1.13839286 1.13839286 1.13839286]
   [1.13839286 1.13839286 1.13839286]]

  [[0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.         0.         0.        ]
   ...
   [1.13839286 1.13839286 1.13839286]
   [1.13839286 1.13839286 1.13839286]
   [1.13839286 1.13839286 1.13839286]]]


 [[[0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.         0.         0.        ]
   ...
   [0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.         0.         0.        ]]

  [[0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.04910714 0.04910714 0.04910714]
   ...
   [0.04910714 0.04910714 0.04910714]
   [0.         0.         0.        ]
   [0.         0.         0.        ]]

  [[0.         0.         0.        ]
   [0.04910714 0.04910714 0.04910714]
   [0.49553571 0.49553571 0.49553571]
   ...
   [0.49553571 0.49553571 0.49553571]
   [0.04910714 0.04910714 0.04910714]
   [0.         0.         0.        ]]

  ...

  [[0.         0.         0.        ]
   [0.04910714 0.04910714 0.04910714]
   [0.49553571 0.49553571 0.49553571]
   ...
   [1.09821429 1.09821429 1.09821429]
   [1.13392857 1.13392857 1.13392857]
   [1.13839286 1.13839286 1.13839286]]

  [[0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.04910714 0.04910714 0.04910714]
   ...
   [1.13392857 1.13392857 1.13392857]
   [1.13839286 1.13839286 1.13839286]
   [1.13839286 1.13839286 1.13839286]]

  [[0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.         0.         0.        ]
   ...
   [1.13839286 1.13839286 1.13839286]
   [1.13839286 1.13839286 1.13839286]
   [1.13839286 1.13839286 1.13839286]]]] (of type <class 'numpy.ndarray'>)

In [None]:
# Save the model
model.save('widget-classification.h5')

# Scale data

# Build Deep Learning Model

# Plot performance

# Evaluate model performance

# Test model

# Save model