In [35]:
# autopep8: off
import shutil
import os
os.environ['SM_FRAMEWORK'] = 'tf.keras'
# autopep8: on

import yaml
import cv2
import keras
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import albumentations as A
import segmentation_models as sm
import wandb

backend=tf.keras.backend
layers=tf.keras.layers
models=tf.keras.models
keras_utils = tf.keras.utils

In [36]:
from keras_applications import get_submodules_from_kwargs

from segmentation_models.models._common_blocks import Conv2dBn
from segmentation_models.models._utils import freeze_model
from segmentation_models.backbones.backbones_factory import Backbones

In [37]:
tf.config.list_physical_devices('GPU')

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU'),
 PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU')]

In [38]:
DATA_DIR = './Images/'
MASK_DIR = './Masks/'

In [39]:
def visualize(rows=1, **images):
    """PLot images in one row."""
    n = len(images)
    plt.figure(figsize=(16, 5))
    for i, (name, image) in enumerate(images.items()):
        plt.subplot(rows, n//rows, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.title(' '.join(name.split('_')).title())
        plt.imshow(image)
    plt.show()


def denormalize(x):
    """Scale image to range 0..1 for correct plot"""
    x_max = np.percentile(x, 98)
    x_min = np.percentile(x, 2)
    x = (x - x_min) / (x_max - x_min)
    x = x.clip(0, 1)
    return x


def image_to_same_shape(image, height, width):
    if len(image.shape) == 2:
        old_image_height, old_image_width = image.shape
    else:
        old_image_height, old_image_width, channels = image.shape

    # create new image of desired size and color (blue) for padding
    new_image_width = width
    new_image_height = height
    color = (0)
    if len(image.shape) == 2:
        result = np.full((new_image_height, new_image_width),
                         color, dtype=np.uint8)
    else:
        result = np.full((new_image_height, new_image_width,
                         channels), color, dtype=np.uint8)

    # compute center offset
    x_center = (new_image_width - old_image_width) // 2
    y_center = (new_image_height - old_image_height) // 2

    # copy img image into center of result image
    result[y_center:y_center+old_image_height,
           x_center:x_center+old_image_width] = image

    return result


class Dataset:
    """Dataset. Read images, apply augmentation and preprocessing transformations.

    Args:
        images_dir (str): path to images folder
        masks_dir (str): path to segmentation masks folder
        class_values (list): values of classes to extract from segmentation mask
        augmentation (albumentations.Compose): data transfromation pipeline 
            (e.g. flip, scale, etc.)
        preprocessing (albumentations.Compose): data preprocessing 
            (e.g. noralization, shape manipulation, etc.)

    """

    CLASSES = ['unlabelled', 'seed', 'pulp', 'albedo', 'flavedo']

    def __init__(
            self,
            ids,
            images_dir,
            masks_dir,
            classes=None,
            augmentation=None,
            preprocessing=None,
    ):
        self.ids = ids
        # self.ids = os.listdir(images_dir)
        self.images_fps = [os.path.join(images_dir, image_id)
                           for image_id in self.ids]
        self.masks_fps = [os.path.join(
            masks_dir, image_id)+'.png' for image_id in self.ids]

        # convert str names to class values on masks
        self.class_values = [self.CLASSES.index(
            cls.lower()) for cls in classes]

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, i):

        # read data
        image = cv2.imread(self.images_fps[i])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(self.masks_fps[i], 0)

        image = image_to_same_shape(image, 1024, 1024)
        mask = image_to_same_shape(mask, 1024, 1024)

        # extract certain classes from mask (e.g. cars)
        masks = [(mask == v*(255//4)) for v in self.class_values]
        mask = np.stack(masks, axis=-1).astype('float')

        # add background if mask is not binary
        if mask.shape[-1] != 1:
            background = 1 - mask.sum(axis=-1, keepdims=True)
            mask = np.concatenate((mask, background), axis=-1)

        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return image, mask

    def __len__(self):
        return len(self.ids)


class_labels = {0: "Seed",
                1: "Pulp",
                2: "Albedo",
                3: "Flavedo",
                4: "Background"
                }


class Dataloder(keras.utils.Sequence):
    """Load data from dataset and form batches

    Args:
        dataset: instance of Dataset class for image loading and preprocessing.
        batch_size: Integet number of images in batch.
        shuffle: Boolean, if `True` shuffle image indexes each epoch.
    """

    def __init__(self, dataset, batch_size=1, shuffle=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indexes = np.arange(len(dataset))

        self.on_epoch_end()

    def __getitem__(self, i):

        # collect batch data
        start = i * self.batch_size
        stop = (i + 1) * self.batch_size
        data = []
        for j in range(start, stop):
            data.append(self.dataset[j])

        # transpose list of lists
        batch = [np.stack(samples, axis=0) for samples in zip(*data)]

        return batch

    def __len__(self):
        """Denotes the number of batches per epoch"""
        return len(self.indexes) // self.batch_size

    def on_epoch_end(self):
        """Callback function to shuffle indexes each epoch"""
        if self.shuffle:
            self.indexes = np.random.permutation(self.indexes)


def get_preprocessing(preprocessing_fn):
    """Construct preprocessing transform

    Args:
        preprocessing_fn (callbale): data normalization function 
            (can be specific for each pretrained neural network)
    Return:
        transform: albumentations.Compose

    """

    _transform = [
        A.Lambda(image=preprocessing_fn),
    ]
    return A.Compose(_transform)

In [40]:
def filter_keras_submodules(kwargs):
    """Selects only arguments that define keras_application submodules. """
    submodule_keys = kwargs.keys() & {'backend', 'layers', 'models', 'utils'}
    return {key: kwargs[key] for key in submodule_keys}

In [41]:
BATCH_SIZE = 1
CLASSES = ['seed', 'pulp', 'albedo', 'flavedo']
LR = 0.0001
EPOCHS = 10
n_classes = len(CLASSES) + 1
activation = 'softmax'

In [42]:
model = sm.Linknet("efficientnetb0", classes=5, activation="softmax")

In [43]:
model.get_layer("top_activation")

<keras.layers.core.activation.Activation at 0x2b9028a8a90>

In [44]:
def get_submodules():
    return {
        'backend': backend,
        'models': models,
        'layers': layers,
        'utils': keras_utils,
    }

In [45]:
from segmentation_models.models.linknet import Conv3x3BnReLU, Conv1x1BnReLU

In [46]:
def DecoderUpsamplingX2Block(filters, stage, use_batchnorm, branch):
    conv_block1_name = 'decoder_stage{}a_{}'.format(stage, branch)
    conv_block2_name = 'decoder_stage{}b_{}'.format(stage, branch)
    conv_block3_name = 'decoder_stage{}c_{}'.format(stage, branch)
    up_name = 'decoder_stage{}_upsampling_{}'.format(stage, branch)
    add_name = 'decoder_stage{}_add_{}'.format(stage, branch)

    channels_axis = 3 if backend.image_data_format() == 'channels_last' else 1

    def wrapper(input_tensor, skip=None):
        input_filters = backend.int_shape(input_tensor)[channels_axis]
        output_filters = backend.int_shape(skip)[channels_axis] if skip is not None else filters

        x = Conv1x1BnReLU(input_filters // 4, use_batchnorm, name=conv_block1_name)(input_tensor)
        x = layers.UpSampling2D((2, 2), name=up_name)(x)
        x = Conv3x3BnReLU(input_filters // 4, use_batchnorm, name=conv_block2_name)(x)
        x = Conv1x1BnReLU(output_filters, use_batchnorm, name=conv_block3_name)(x)

        if skip is not None:
            x = layers.Add(name=add_name)([x, skip])
        return x

    return wrapper


def DecoderTransposeX2Block(filters, stage, use_batchnorm, branch):
    conv_block1_name = 'decoder_stage{}a_{}'.format(stage, branch)
    transpose_name = 'decoder_stage{}b_transpose_{}'.format(stage, branch)
    bn_name = 'decoder_stage{}b_bn_{}'.format(stage, branch)
    relu_name = 'decoder_stage{}b_relu_{}'.format(stage, branch)
    conv_block3_name = 'decoder_stage{}c_{}'.format(stage, branch)
    add_name = 'decoder_stage{}_add_{}'.format(stage, branch)

    channels_axis = bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1

    def wrapper(input_tensor, skip=None):
        input_filters = backend.int_shape(input_tensor)[channels_axis]
        output_filters = backend.int_shape(skip)[channels_axis] if skip is not None else filters

        x = Conv1x1BnReLU(input_filters // 4, use_batchnorm, name=conv_block1_name)(input_tensor)
        x = layers.Conv2DTranspose(
            filters=input_filters // 4,
            kernel_size=(4, 4),
            strides=(2, 2),
            padding='same',
            name=transpose_name,
            use_bias=not use_batchnorm,
        )(x)

        if use_batchnorm:
            x = layers.BatchNormalization(axis=bn_axis, name=bn_name)(x)

        x = layers.Activation('relu', name=relu_name)(x)
        x = Conv1x1BnReLU(output_filters, use_batchnorm, name=conv_block3_name)(x)

        if skip is not None:
            x = layers.Add(name=add_name)([x, skip])

        return x

    return wrapper

In [47]:
def build_linknet(
        backbone,
        decoder_block,
        skip_connection_layers,
        decoder_filters=(256, 128, 64, 32, 16),
        n_upsample_blocks=5,
        classes=1,
        activation='sigmoid',
        use_batchnorm=True,
        branch=0
):
    input_ = backbone.input
    x = backbone.output

    # extract skip connections
    skips = ([backbone.get_layer(name=i).output if isinstance(i, str)
              else backbone.get_layer(index=i).output for i in skip_connection_layers])

    # add center block if previous operation was maxpooling (for vgg models)
    if isinstance(backbone.layers[-1], layers.MaxPooling2D):
        x = Conv3x3BnReLU(512, use_batchnorm, name='center_block1')(x)
        x = Conv3x3BnReLU(512, use_batchnorm, name='center_block2')(x)

    # building decoder blocks
    for i in range(n_upsample_blocks):

        if i < len(skips):
            skip = skips[i]
        else:
            skip = None

        x = decoder_block(decoder_filters[i], stage=i, use_batchnorm=use_batchnorm, branch=branch)(x, skip)

    # model head (define number of output classes)
    x = layers.Conv2D(
        filters=classes,
        kernel_size=(3, 3),
        padding='same',
        use_bias=True,
        kernel_initializer='glorot_uniform'
    )(x)
    x = layers.Activation(activation, name=f"{activation}_{branch}")(x)

    return x

In [48]:
def Linknet(
        backbone_name='vgg16',
        input_shape=(None, None, 3),
        classes=1,
        activation='sigmoid',
        weights=None,
        encoder_weights='imagenet',
        encoder_freeze=False,
        encoder_features='default',
        decoder_block_type='upsampling',
        decoder_filters=(None, None, None, None, 16),
        decoder_use_batchnorm=True,
        decoder_branches=2,
        **kwargs
):

    global backend, layers, models, keras_utils
    submodule_args = filter_keras_submodules(kwargs)
    backend, layers, models, keras_utils = get_submodules_from_kwargs(submodule_args)

    if decoder_block_type == 'upsampling':
        decoder_block = DecoderUpsamplingX2Block
    elif decoder_block_type == 'transpose':
        decoder_block = DecoderTransposeX2Block
    else:
        raise ValueError('Decoder block type should be in ("upsampling", "transpose"). '
                         'Got: {}'.format(decoder_block_type))

    backbone = Backbones.get_backbone(
        backbone_name,
        input_shape=input_shape,
        weights=encoder_weights,
        include_top=False,
        **kwargs,
    )

    if encoder_features == 'default':
        encoder_features = Backbones.get_feature_layers(backbone_name, n=4)

    task1_branch_output  = build_linknet(
        backbone=backbone,
        decoder_block=decoder_block,
        skip_connection_layers=encoder_features,
        decoder_filters=decoder_filters,
        classes=classes,
        activation=activation,
        n_upsample_blocks=len(decoder_filters),
        use_batchnorm=decoder_use_batchnorm,
        branch=0
    )
    
    task2_branch_output  = build_linknet(
        backbone=backbone,
        decoder_block=decoder_block,
        skip_connection_layers=encoder_features,
        decoder_filters=decoder_filters,
        classes=classes,
        activation=activation,
        n_upsample_blocks=len(decoder_filters),
        use_batchnorm=decoder_use_batchnorm,
        branch=1
    )

    # lock encoder weights for fine-tuning
    if encoder_freeze:
        freeze_model(backbone, **kwargs)

    # loading model weights
    if weights is not None:
        model.load_weights(weights)
        
    model = tf.keras.Model(inputs = backbone.input, outputs = [task1_branch_output, task2_branch_output])

    return model

In [94]:
branched_model = Linknet("efficientnetb0", classes=5, activation="softmax", backend=backend, layers=layers, models=models, utils=keras_utils)

In [95]:
# define optomizer
optim = keras.optimizers.Adam(1e-3)

# Segmentation models losses can be combined together by '+' and scaled by integer or float factor
# set class weights for dice_loss (car: 1.; pedestrian: 2.; background: 0.5;)
dice_loss = sm.losses.DiceLoss()
focal_loss = sm.losses.CategoricalFocalLoss()
total_loss = dice_loss + (1 * focal_loss)

metrics = [sm.metrics.IOUScore(threshold=0.5),
           sm.metrics.FScore(threshold=0.5)]

In [96]:
model.compile(optim, "mae", metrics)

In [97]:
branched_model.compile(optimizer=optim,
                       loss={'softmax_0': "mae", 
                             'softmax_1': "mae"},
                       metrics=metrics,)

In [53]:
data_ids = [image_id.replace(".png", "") for image_id in os.listdir(MASK_DIR)]
SIZE = len(data_ids)
TRAIN_SIZE = int(0.6 * SIZE)
VAL_SIZE = int(0.2 * SIZE)

# Dataset for train images
# train_dataset = dataset[:TRAIN_SIZE]
train_dataset = Dataset(
    data_ids[:TRAIN_SIZE],
    DATA_DIR,
    MASK_DIR,
    classes=CLASSES,
)

# Dataset for validation images
# valid_dataset = dataset[TRAIN_SIZE:VAL_SIZE]
valid_dataset = Dataset(
    data_ids[TRAIN_SIZE:TRAIN_SIZE+VAL_SIZE],
    DATA_DIR,
    MASK_DIR,
    classes=CLASSES,
)

# Dataset for test images
# test_dataset = dataset[VAL_SIZE:]
test_dataset = Dataset(
    data_ids[TRAIN_SIZE+VAL_SIZE:],
    DATA_DIR,
    MASK_DIR,
    classes=CLASSES,
)

dataset_image, mask = test_dataset[0]  # get some sample

train_dataloader = Dataloder(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_dataloader = Dataloder(valid_dataset, batch_size=1, shuffle=False)

In [54]:
checkpoint_dir = f'./keras_checkpoints/branching'
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)
checkpoint_path = os.path.join(checkpoint_dir, "best_model_{epoch:02d}")
callbacks = [
    keras.callbacks.ModelCheckpoint(
        checkpoint_path, save_weights_only=True, save_best_only=True),
    keras.callbacks.ReduceLROnPlateau(),
]

In [55]:
%%script false

# train model
history = model.fit(
    train_dataloader,
    steps_per_epoch=len(train_dataloader),
    epochs=EPOCHS,
    callbacks=callbacks,
    validation_data=valid_dataloader,
    validation_steps=len(valid_dataloader),
)

test_dataloader = Dataloder(test_dataset, batch_size=1, shuffle=False)
# load best weights
latest = tf.train.latest_checkpoint(checkpoint_dir)
scores = model.evaluate(test_dataloader)
model.load_weights(latest)
scores = model.evaluate(test_dataloader)

print("Loss: {:.5}".format(scores[0]))
for metric, value in zip(metrics, scores[1:]):
    print("mean {}: {:.5}".format(metric.__name__, value))

Couldn't find program: 'false'


In [56]:
%%script false

del train_dataloader
del valid_dataloader
del test_dataloader
del train_dataset
del valid_dataset
del test_dataset

Couldn't find program: 'false'


In [92]:
import gc
gc.collect()

0

In [98]:
class Dataset2:
    """Dataset. Read images, apply augmentation and preprocessing transformations.

    Args:
        images_dir (str): path to images folder
        masks_dir (str): path to segmentation masks folder
        class_values (list): values of classes to extract from segmentation mask
        augmentation (albumentations.Compose): data transfromation pipeline 
            (e.g. flip, scale, etc.)
        preprocessing (albumentations.Compose): data preprocessing 
            (e.g. noralization, shape manipulation, etc.)

    """

    CLASSES = ['unlabelled', 'seed', 'pulp', 'albedo', 'flavedo']

    def __init__(
            self,
            ids,
            images_dir,
            masks_dir,
            classes=None,
            augmentation=None,
            preprocessing=None,
    ):
        self.ids = ids
        # self.ids = os.listdir(images_dir)
        self.images_fps = [os.path.join(images_dir, image_id)
                           for image_id in self.ids]
        self.masks_fps = [os.path.join(
            masks_dir, image_id)+'.png' for image_id in self.ids]

        # convert str names to class values on masks
        self.class_values = [self.CLASSES.index(
            cls.lower()) for cls in classes]

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, i):

        # read data
        image = cv2.imread(self.images_fps[i])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(self.masks_fps[i], 0)

        image = image_to_same_shape(image, 1024, 1024)
        mask = image_to_same_shape(mask, 1024, 1024)

        # extract certain classes from mask (e.g. cars)
        masks = [(mask == v*(255//4)) for v in self.class_values]
        mask = np.stack(masks, axis=-1).astype('float')

        # add background if mask is not binary
        if mask.shape[-1] != 1:
            background = 1 - mask.sum(axis=-1, keepdims=True)
            mask = np.concatenate((mask, background), axis=-1)

        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return image, [mask, mask]

    def __len__(self):
        return len(self.ids)

In [99]:
data_ids = [image_id.replace(".png", "") for image_id in os.listdir(MASK_DIR)]
SIZE = len(data_ids)
TRAIN_SIZE = int(0.6 * SIZE)
VAL_SIZE = int(0.2 * SIZE)

# Dataset for train images
# train_dataset = dataset[:TRAIN_SIZE]
train_dataset2 = Dataset2(
    data_ids[:TRAIN_SIZE],
    DATA_DIR,
    MASK_DIR,
    classes=CLASSES,
)

# Dataset for validation images
# valid_dataset = dataset[TRAIN_SIZE:VAL_SIZE]
valid_dataset2 = Dataset2(
    data_ids[TRAIN_SIZE:TRAIN_SIZE+VAL_SIZE],
    DATA_DIR,
    MASK_DIR,
    classes=CLASSES,
)

# Dataset for test images
# test_dataset = dataset[VAL_SIZE:]
test_dataset2 = Dataset2(
    data_ids[TRAIN_SIZE+VAL_SIZE:],
    DATA_DIR,
    MASK_DIR,
    classes=CLASSES,
)

# dataset_image2, (mask1, mask2) = test_dataset2[0]  # get some sample

train_dataloader2 = Dataloder(
    train_dataset2, batch_size=BATCH_SIZE, shuffle=True)
valid_dataloader2 = Dataloder(valid_dataset2, batch_size=1, shuffle=False)

In [100]:
output = branched_model.predict(dataset_image[tf.newaxis, ...])
np.array(output).shape



(2, 1, 1024, 1024, 5)

In [101]:
# train model
history = branched_model.fit(
    train_dataloader2,
    steps_per_epoch=len(train_dataloader2),
    epochs=EPOCHS,
    callbacks=callbacks,
    validation_data=valid_dataloader2,
    validation_steps=len(valid_dataloader2),
)

test_dataloader2 = Dataloder(test_dataset2, batch_size=1, shuffle=False)
# load best weights
latest = tf.train.latest_checkpoint(checkpoint_dir)
branched_model.load_weights(latest)
scores = branched_model.evaluate(test_dataloader2)

print("Loss: {:.5}".format(scores[0]))
for metric, value in zip(metrics, scores[1:]):
    print("mean {}: {:.5}".format(metric.__name__, value))

Epoch 1/10


InvalidArgumentError: Graph execution error:

Detected at node 'mean_absolute_error/remove_squeezable_dimensions/Squeeze' defined at (most recent call last):
    File "C:\Users\shaki\miniconda3\envs\fruitQuality\lib\runpy.py", line 194, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "C:\Users\shaki\miniconda3\envs\fruitQuality\lib\runpy.py", line 87, in _run_code
      exec(code, run_globals)
    File "C:\Users\shaki\miniconda3\envs\fruitQuality\lib\site-packages\ipykernel_launcher.py", line 17, in <module>
      app.launch_new_instance()
    File "C:\Users\shaki\miniconda3\envs\fruitQuality\lib\site-packages\traitlets\config\application.py", line 1043, in launch_instance
      app.start()
    File "C:\Users\shaki\miniconda3\envs\fruitQuality\lib\site-packages\ipykernel\kernelapp.py", line 712, in start
      self.io_loop.start()
    File "C:\Users\shaki\miniconda3\envs\fruitQuality\lib\site-packages\tornado\platform\asyncio.py", line 215, in start
      self.asyncio_loop.run_forever()
    File "C:\Users\shaki\miniconda3\envs\fruitQuality\lib\asyncio\base_events.py", line 570, in run_forever
      self._run_once()
    File "C:\Users\shaki\miniconda3\envs\fruitQuality\lib\asyncio\base_events.py", line 1859, in _run_once
      handle._run()
    File "C:\Users\shaki\miniconda3\envs\fruitQuality\lib\asyncio\events.py", line 81, in _run
      self._context.run(self._callback, *self._args)
    File "C:\Users\shaki\miniconda3\envs\fruitQuality\lib\site-packages\ipykernel\kernelbase.py", line 510, in dispatch_queue
      await self.process_one()
    File "C:\Users\shaki\miniconda3\envs\fruitQuality\lib\site-packages\ipykernel\kernelbase.py", line 499, in process_one
      await dispatch(*args)
    File "C:\Users\shaki\miniconda3\envs\fruitQuality\lib\site-packages\ipykernel\kernelbase.py", line 406, in dispatch_shell
      await result
    File "C:\Users\shaki\miniconda3\envs\fruitQuality\lib\site-packages\ipykernel\kernelbase.py", line 730, in execute_request
      reply_content = await reply_content
    File "C:\Users\shaki\miniconda3\envs\fruitQuality\lib\site-packages\ipykernel\ipkernel.py", line 383, in do_execute
      res = shell.run_cell(
    File "C:\Users\shaki\miniconda3\envs\fruitQuality\lib\site-packages\ipykernel\zmqshell.py", line 528, in run_cell
      return super().run_cell(*args, **kwargs)
    File "C:\Users\shaki\miniconda3\envs\fruitQuality\lib\site-packages\IPython\core\interactiveshell.py", line 2961, in run_cell
      result = self._run_cell(
    File "C:\Users\shaki\miniconda3\envs\fruitQuality\lib\site-packages\IPython\core\interactiveshell.py", line 3016, in _run_cell
      result = runner(coro)
    File "C:\Users\shaki\miniconda3\envs\fruitQuality\lib\site-packages\IPython\core\async_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "C:\Users\shaki\miniconda3\envs\fruitQuality\lib\site-packages\IPython\core\interactiveshell.py", line 3221, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "C:\Users\shaki\miniconda3\envs\fruitQuality\lib\site-packages\IPython\core\interactiveshell.py", line 3400, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "C:\Users\shaki\miniconda3\envs\fruitQuality\lib\site-packages\IPython\core\interactiveshell.py", line 3460, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "C:\Users\shaki\AppData\Local\Temp\ipykernel_22804\3465977345.py", line 2, in <module>
      history = branched_model.fit(
    File "C:\Users\shaki\miniconda3\envs\fruitQuality\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\shaki\miniconda3\envs\fruitQuality\lib\site-packages\keras\engine\training.py", line 1564, in fit
      tmp_logs = self.train_function(iterator)
    File "C:\Users\shaki\miniconda3\envs\fruitQuality\lib\site-packages\keras\engine\training.py", line 1160, in train_function
      return step_function(self, iterator)
    File "C:\Users\shaki\miniconda3\envs\fruitQuality\lib\site-packages\keras\engine\training.py", line 1146, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "C:\Users\shaki\miniconda3\envs\fruitQuality\lib\site-packages\keras\engine\training.py", line 1135, in run_step
      outputs = model.train_step(data)
    File "C:\Users\shaki\miniconda3\envs\fruitQuality\lib\site-packages\keras\engine\training.py", line 994, in train_step
      loss = self.compute_loss(x, y, y_pred, sample_weight)
    File "C:\Users\shaki\miniconda3\envs\fruitQuality\lib\site-packages\keras\engine\training.py", line 1052, in compute_loss
      return self.compiled_loss(
    File "C:\Users\shaki\miniconda3\envs\fruitQuality\lib\site-packages\keras\engine\compile_utils.py", line 265, in __call__
      loss_value = loss_obj(y_t, y_p, sample_weight=sw)
    File "C:\Users\shaki\miniconda3\envs\fruitQuality\lib\site-packages\keras\losses.py", line 152, in __call__
      losses = call_fn(y_true, y_pred)
    File "C:\Users\shaki\miniconda3\envs\fruitQuality\lib\site-packages\keras\losses.py", line 265, in call
      y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(
    File "C:\Users\shaki\miniconda3\envs\fruitQuality\lib\site-packages\keras\utils\losses_utils.py", line 200, in squeeze_or_expand_dimensions
      y_true, y_pred = remove_squeezable_dimensions(y_true, y_pred)
    File "C:\Users\shaki\miniconda3\envs\fruitQuality\lib\site-packages\keras\utils\losses_utils.py", line 139, in remove_squeezable_dimensions
      labels = tf.squeeze(labels, [-1])
Node: 'mean_absolute_error/remove_squeezable_dimensions/Squeeze'
Can not squeeze dim[4], expected a dimension of 1, got 5
	 [[{{node mean_absolute_error/remove_squeezable_dimensions/Squeeze}}]] [Op:__inference_train_function_135794]

In [104]:
!python --version

Python 3.8.16
