In [17]:
import tensorflow as tf
import pathlib
import cv2

# Load your pre-trained segmentation model
def weighted_crossentropy(class_weights):
      def loss_fn(y_true, y_pred):
          # Apply softmax activation to logits
          y_pred = tf.nn.softmax(y_pred, axis=-1)

          # Flatten both inputs
          y_true_f = tf.reshape(y_true, [-1])
          y_pred_f = tf.reshape(y_pred, [-1, tf.shape(y_pred)[-1]])

          # Calculate weighted cross-entropy loss
          loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.cast(y_true_f, dtype=tf.int32), logits=y_pred_f)
          weighted_loss = tf.reduce_mean(tf.multiply(loss, tf.gather(class_weights, tf.cast(y_true_f, dtype=tf.int32))))

          return weighted_loss

      return loss_fn

segmentation_model = tf.keras.models.load_model('saved_seg_models/model2023-03-18 15_39_47.663348', compile=False)
segmentation_model.compile(
      optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
      loss=weighted_crossentropy,
      metrics=["accuracy"],
    )

# Freeze the layers of the segmentation model
for layer in segmentation_model.layers:
    layer.trainable = False

# Load your training dataset and pass each image through the segmentation model to obtain its segmented output
train_dir = pathlib.Path('plantvillage')
batch_size = 32
img_height = 512
img_width = 512

training_dataset = tf.keras.utils.image_dataset_from_directory(
  train_dir,
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size)  # load your training dataset

segmented_training_dataset = []
for images, labels in training_dataset:
     segmentation_model.predict(images)
    # for index in range(len(images)):
    #     image = images[index]
    #     label = labels[index]
    # #     print(images.shape, label.shape)
    #     segmented_image = segmentation_model.predict(image[tf.newaxis, ...])
    #     segmented_training_dataset.append((segmented_image, label))

# Remove all pixels labeled as class 0 and 1 from the segmented output
segmented_training_dataset = tf.convert_to_tensor(segmented_training_dataset)
class_mask = tf.math.logical_and(tf.not_equal(segmented_training_dataset, 0), tf.not_equal(segmented_training_dataset, 1))
segmented_training_dataset = tf.where(class_mask, segmented_training_dataset, tf.zeros_like(segmented_training_dataset))

# Use the segmented output with class 0 and 1 removed as input to the segmentation model to extract features
feature_extraction_model = tf.keras.models.Sequential([
    segmentation_model.layers[0],  # input layer
    segmentation_model.layers[1],  # first convolutional layer
    segmentation_model.layers[2],  # second convolutional layer
    segmentation_model.layers[3]   # third convolutional layer
])

# Use the extracted features as input to the image classification model
classification_model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(64, 64, 1)),  # flatten the extracted features
    tf.keras.layers.Dense(64, activation='relu'),     # fully connected layer
    tf.keras.layers.Dense(10, activation='softmax')   # output layer with 10 classes
])

# Train the classifier on the extracted features using your training dataset
classification_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
classification_model.fit(segmented_training_dataset, epochs=1)

# # Test the performance of the trained image classification model on a test dataset
# test_dataset = ...  # load your test dataset
# segmented_test_dataset = []
# for image in test_dataset:
#     segmented_image = segmentation_model.predict(image)
#     segmented_test_dataset.append(segmented_image)
# segmented_test_dataset = tf.convert_to


Found 18835 files belonging to 10 classes.


2023-03-21 16:25:49.231150: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.


ResourceExhaustedError: Graph execution error:

Detected at node 'model/up_sampling2d_2/resize/ResizeBilinear' defined at (most recent call last):
    File "/opt/homebrew/Caskroom/miniconda/base/envs/project/lib/python3.10/runpy.py", line 196, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "/opt/homebrew/Caskroom/miniconda/base/envs/project/lib/python3.10/runpy.py", line 86, in _run_code
      exec(code, run_globals)
    File "/opt/homebrew/Caskroom/miniconda/base/envs/project/lib/python3.10/site-packages/ipykernel_launcher.py", line 17, in <module>
      app.launch_new_instance()
    File "/opt/homebrew/Caskroom/miniconda/base/envs/project/lib/python3.10/site-packages/traitlets/config/application.py", line 992, in launch_instance
      app.start()
    File "/opt/homebrew/Caskroom/miniconda/base/envs/project/lib/python3.10/site-packages/ipykernel/kernelapp.py", line 711, in start
      self.io_loop.start()
    File "/opt/homebrew/Caskroom/miniconda/base/envs/project/lib/python3.10/site-packages/tornado/platform/asyncio.py", line 215, in start
      self.asyncio_loop.run_forever()
    File "/opt/homebrew/Caskroom/miniconda/base/envs/project/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
      self._run_once()
    File "/opt/homebrew/Caskroom/miniconda/base/envs/project/lib/python3.10/asyncio/base_events.py", line 1906, in _run_once
      handle._run()
    File "/opt/homebrew/Caskroom/miniconda/base/envs/project/lib/python3.10/asyncio/events.py", line 80, in _run
      self._context.run(self._callback, *self._args)
    File "/opt/homebrew/Caskroom/miniconda/base/envs/project/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 510, in dispatch_queue
      await self.process_one()
    File "/opt/homebrew/Caskroom/miniconda/base/envs/project/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 499, in process_one
      await dispatch(*args)
    File "/opt/homebrew/Caskroom/miniconda/base/envs/project/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 406, in dispatch_shell
      await result
    File "/opt/homebrew/Caskroom/miniconda/base/envs/project/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 729, in execute_request
      reply_content = await reply_content
    File "/opt/homebrew/Caskroom/miniconda/base/envs/project/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 411, in do_execute
      res = shell.run_cell(
    File "/opt/homebrew/Caskroom/miniconda/base/envs/project/lib/python3.10/site-packages/ipykernel/zmqshell.py", line 531, in run_cell
      return super().run_cell(*args, **kwargs)
    File "/opt/homebrew/Caskroom/miniconda/base/envs/project/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 2945, in run_cell
      result = self._run_cell(
    File "/opt/homebrew/Caskroom/miniconda/base/envs/project/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3000, in _run_cell
      return runner(coro)
    File "/opt/homebrew/Caskroom/miniconda/base/envs/project/lib/python3.10/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "/opt/homebrew/Caskroom/miniconda/base/envs/project/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3203, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "/opt/homebrew/Caskroom/miniconda/base/envs/project/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3382, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "/opt/homebrew/Caskroom/miniconda/base/envs/project/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3442, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "/var/folders/qf/vw9dzh494wd2yqv7qpcv886h0000gn/T/ipykernel_6621/1175652454.py", line 48, in <module>
      segmentation_model.predict(images)
    File "/opt/homebrew/Caskroom/miniconda/base/envs/project/lib/python3.10/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/opt/homebrew/Caskroom/miniconda/base/envs/project/lib/python3.10/site-packages/keras/engine/training.py", line 2033, in predict
      tmp_batch_outputs = self.predict_function(iterator)
    File "/opt/homebrew/Caskroom/miniconda/base/envs/project/lib/python3.10/site-packages/keras/engine/training.py", line 1845, in predict_function
      return step_function(self, iterator)
    File "/opt/homebrew/Caskroom/miniconda/base/envs/project/lib/python3.10/site-packages/keras/engine/training.py", line 1834, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/opt/homebrew/Caskroom/miniconda/base/envs/project/lib/python3.10/site-packages/keras/engine/training.py", line 1823, in run_step
      outputs = model.predict_step(data)
    File "/opt/homebrew/Caskroom/miniconda/base/envs/project/lib/python3.10/site-packages/keras/engine/training.py", line 1791, in predict_step
      return self(x, training=False)
    File "/opt/homebrew/Caskroom/miniconda/base/envs/project/lib/python3.10/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/opt/homebrew/Caskroom/miniconda/base/envs/project/lib/python3.10/site-packages/keras/engine/training.py", line 490, in __call__
      return super().__call__(*args, **kwargs)
    File "/opt/homebrew/Caskroom/miniconda/base/envs/project/lib/python3.10/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/opt/homebrew/Caskroom/miniconda/base/envs/project/lib/python3.10/site-packages/keras/engine/base_layer.py", line 1014, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/opt/homebrew/Caskroom/miniconda/base/envs/project/lib/python3.10/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/opt/homebrew/Caskroom/miniconda/base/envs/project/lib/python3.10/site-packages/keras/engine/functional.py", line 458, in call
      return self._run_internal_graph(
    File "/opt/homebrew/Caskroom/miniconda/base/envs/project/lib/python3.10/site-packages/keras/engine/functional.py", line 596, in _run_internal_graph
      outputs = node.layer(*args, **kwargs)
    File "/opt/homebrew/Caskroom/miniconda/base/envs/project/lib/python3.10/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/opt/homebrew/Caskroom/miniconda/base/envs/project/lib/python3.10/site-packages/keras/engine/base_layer.py", line 1014, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/opt/homebrew/Caskroom/miniconda/base/envs/project/lib/python3.10/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/opt/homebrew/Caskroom/miniconda/base/envs/project/lib/python3.10/site-packages/keras/layers/reshaping/up_sampling2d.py", line 129, in call
      return backend.resize_images(
    File "/opt/homebrew/Caskroom/miniconda/base/envs/project/lib/python3.10/site-packages/keras/backend.py", line 3432, in resize_images
      x = tf.image.resize(x, new_shape, method=interpolations[interpolation])
Node: 'model/up_sampling2d_2/resize/ResizeBilinear'
OOM when allocating tensor with shape[32,512,512,256] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator Simple allocator
	 [[{{node model/up_sampling2d_2/resize/ResizeBilinear}}]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.
 [Op:__inference_predict_function_500116]