In [None]:
# Mount Google Drive at /content/drive (forcing a fresh mount) so the dataset
# archive and the checkpoints written below are reachable.
from google.colab import drive
drive.mount(mountpoint='/content/drive', force_remount=True)

## Supervised Learning — Classification

- Supervised pneumonia classification (NORMAL vs. PNEUMONIA) on chest X-ray images

In [None]:
!unzip /content/drive/MyDrive/MTL_on_Chest_X_Ray_Images/archive.zip

In [None]:
import os
import numpy as np
import pandas as pd
import random
import cv2
import seaborn as sns
from PIL import Image
from skimage.io import imread
from skimage.transform import resize
import matplotlib.pyplot as plt
from glob import glob
import argparse
%matplotlib inline
import tensorflow
from tensorflow.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array
from keras.layers import Dense, GlobalAveragePooling2D, Input, Flatten, Dropout, BatchNormalization
from keras.models import Model, Sequential
from keras import backend as K
from keras.applications.densenet import DenseNet121
from keras.applications.imagenet_utils import preprocess_input
from keras.preprocessing import image
from keras.optimizers import Adam
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dropout, Flatten, Dense, Input, concatenate
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D

In [None]:
# Train / test split folders, relative to the directory the archive was unzipped into.
train_dir_path, test_dir_path = (
    "chest_xray/chest_xray/train",
    "chest_xray/chest_xray/test",
)

In [None]:
# Per-class image folders inside the training split.
normal_cases_dir = os.path.join(train_dir_path, 'NORMAL')
pneumonia_cases_dir = os.path.join(train_dir_path, 'PNEUMONIA')

# All JPEG paths per class. glob() returns files in an arbitrary,
# file-system-dependent order; sorting makes the "first five" samples plotted
# below (and everything derived from these lists) reproducible.
normal_cases = sorted(glob(os.path.join(normal_cases_dir, '*.jpeg')))
pneumonia_cases = sorted(glob(os.path.join(pneumonia_cases_dir, '*.jpeg')))

In [None]:
# Report how many training images each class has.
print(
    f"Total number of image for normal cases : {len(normal_cases)}\n"
    f"Total number of image for pneumonia cases : {len(pneumonia_cases)}"
)

In [None]:
# Five NORMAL paths followed by five PNEUMONIA paths, used for the preview plots.
samples = [normal_cases[k] for k in range(5)] + [pneumonia_cases[k] for k in range(5)]

In [None]:
# Show the 10 collected samples on a 2x5 grid: the first row holds the NORMAL
# samples, the second row the PNEUMONIA samples.
f, ax = plt.subplots(2,5, figsize=(40,15))
for i, axis in enumerate(ax.flat):
    axis.imshow(imread(samples[i]), cmap='gray')
    axis.set_title("Normal" if i < 5 else "Pneumonia")
    axis.axis('off')
    axis.set_aspect('auto')
plt.show()

In [None]:
# Training-time augmentation; validation_split reserves 10% of train/ for validation.
# NOTE(review): preprocess_input is imported from keras.applications.imagenet_utils,
# whose default mode is 'caffe'. keras.applications.densenet.preprocess_input uses
# 'torch' mode, which is what the DenseNet121 ImageNet weights were trained with —
# confirm which one is intended before retraining.
augment_options = dict(zoom_range=0.1, horizontal_flip=True, fill_mode='constant')
train_datagen = ImageDataGenerator(**augment_options,
                                   validation_split=0.1,
                                   preprocessing_function=preprocess_input)

# Test images are only preprocessed, never augmented.
test_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)

In [None]:
image_size = 224
batch_size = 16

# Training batches: augmented, shuffled, drawn from the 90% "training" subset.
train_gen = train_datagen.flow_from_directory(train_dir_path,
                                              target_size=(image_size, image_size),
                                              batch_size=batch_size,
                                              class_mode='binary',
                                              shuffle=True,
                                              subset='training')

# Validation images must not be augmented, so they get their own generator.
# The same validation_split on the same directory selects the same held-out 10%
# that train_datagen excludes from its "training" subset.
val_datagen = ImageDataGenerator(validation_split=0.1,
                                 preprocessing_function=preprocess_input)
val_gen = val_datagen.flow_from_directory(train_dir_path,  # same directory as training data
                                          target_size=(image_size, image_size),
                                          batch_size=batch_size,
                                          shuffle=False,
                                          class_mode='binary',
                                          subset='validation')

# Test batches: no augmentation and a fixed order, so results can be matched to files.
test_gen = test_datagen.flow_from_directory(test_dir_path,
                                            target_size=(image_size, image_size),
                                            batch_size=batch_size,
                                            class_mode='binary',
                                            shuffle=False)

In [None]:
# Labels of the images the model is actually trained on. train_gen.classes holds
# one class index per training sample, encoded with train_gen.class_indices
# (alphanumeric folder order: NORMAL -> 0, PNEUMONIA -> 1). Unlike counting every
# file under train/, it excludes the 10% validation subset.
train_labels = np.asarray(train_gen.classes)

# 'balanced' weights: n_samples / (n_classes * count_of_class), in np.unique order.
from sklearn.utils import class_weight
class_weights = class_weight.compute_class_weight(class_weight='balanced',
                                                  classes=np.unique(train_labels),
                                                  y=train_labels)
print(class_weights)

In [None]:
# Pre-trained DenseNet121 backbone (ImageNet weights, original classifier removed).
base_model = DenseNet121(weights='imagenet',include_top=False)

# Freeze the backbone so the first training phase only fits the new head.
# Keras layers are trainable by default; without this line the whole network is
# trained from the first epoch and the later "unfreeze" cell has no effect.
# NOTE(review): base_model is wired in via base_model.output rather than being
# called with training=False, so its BatchNormalization layers will update their
# statistics once it is unfrozen; the Keras fine-tuning guide recommends keeping
# them in inference mode — consider it.
base_model.trainable = False

x = base_model.output

# Global spatial average pooling: one feature vector per image.
x = GlobalAveragePooling2D()(x)

# Dropout to regularise the new head.
x = Dropout(0.2)(x)

# Single sigmoid unit for the binary NORMAL/PNEUMONIA decision.
prediction = Dense(1, activation="sigmoid")(x)

model = Model(inputs=base_model.inputs, outputs=prediction)

In [None]:
# Learning rate as in the CheXNeXt paper.
base_learning_rate = 0.0001

# Adam + binary cross-entropy for the single sigmoid output.
optimizer = Adam(learning_rate=base_learning_rate)
model.compile(optimizer=optimizer,
              loss="binary_crossentropy",
              metrics=['accuracy'])

In [None]:
initial_epochs = 20

# Keep only the best weights (ModelCheckpoint monitors val_loss by default).
ckpt_filename = "/content/drive/MyDrive/dn121_class_weights_pretrained.hdf5"
checkpoint = ModelCheckpoint(
    filepath=ckpt_filename,
    save_best_only=True,
    save_weights_only=True,
    verbose=1,
)
# Divide the LR by 10 after one stagnant epoch; stop after two.
lr_reduce = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=1, verbose=1)
early_stop = EarlyStopping(
    monitor='val_loss', min_delta=0.0001, patience=2, mode='min', verbose=1
)

In [None]:
# First training phase, weighting each class by the weights computed above.
phase1_callbacks = [checkpoint, early_stop, lr_reduce]
history = model.fit(
    train_gen,
    epochs=initial_epochs,
    validation_data=val_gen,
    callbacks=phase1_callbacks,
    class_weight={0: class_weights[0], 1: class_weights[1]},
)

In [None]:
# Learning curves of the first training phase.
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

fig, (acc_ax, loss_ax) = plt.subplots(2, 1, figsize=(15, 10))

acc_ax.plot(acc, label='Training Accuracy')
acc_ax.plot(val_acc, label="validation accuracy")
acc_ax.legend(loc='lower right')
acc_ax.set_ylabel('Accuracy')
acc_ax.set_ylim([min(acc_ax.get_ylim()), 1])
acc_ax.set_title('Training and Validation Accuracy')

loss_ax.plot(loss, label='Training Loss')
loss_ax.plot(val_loss, label='Validation Loss')
loss_ax.legend(loc='upper right')
loss_ax.set_ylabel('Cross Entropy')
loss_ax.set_ylim([min(loss_ax.get_ylim()), 2])
loss_ax.set_title('Training and Validation Loss')
loss_ax.set_xlabel('epoch')
plt.show()

In [None]:
# Restore the best first-phase weights written by the checkpoint callback.
model.load_weights(filepath=ckpt_filename)

In [None]:
# Fine-tuning: make the whole DenseNet121 backbone trainable (every layer, not
# only the last ones). The model is recompiled in the next cell so the change
# takes effect.
base_model.trainable = True

In [None]:
# Recompile so the trainability change takes effect, with a 10x smaller learning
# rate so fine-tuning does not wipe out the pretrained features.
model.compile(optimizer= Adam(learning_rate=base_learning_rate/10), loss= "binary_crossentropy", metrics = ['accuracy'])

# Callbacks
# Fine-tuning writes to its own file. A new ModelCheckpoint starts from best=inf,
# so reusing the first-phase file would overwrite its best weights after the first
# fine-tuning epoch, even if that epoch is worse.
final_ckpt_filename = "/content/drive/MyDrive/dn121_class_weights_finetuned.hdf5"
checkpoint = ModelCheckpoint(filepath=final_ckpt_filename, save_best_only=True, save_weights_only=True, verbose=1)

In [None]:
# Fresh LR-schedule and early-stopping callbacks for fine-tuning (same settings as before).
lr_reduce, early_stop = (
    ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=1, verbose=1),
    EarlyStopping(monitor='val_loss', min_delta=0.0001, patience=2, mode='min', verbose=1),
)

In [None]:
fine_tune_epochs = 20

# Fine-tuning phase with the same class weights and validation data.
fine_tune_class_weight = {0: class_weights[0], 1: class_weights[1]}
history_unfreeze = model.fit(
    train_gen,
    epochs=fine_tune_epochs,
    validation_data=val_gen,
    callbacks=[checkpoint, early_stop, lr_reduce],
    class_weight=fine_tune_class_weight,
)

In [None]:
# Learning curves of the fine-tuning phase.
curves = history_unfreeze.history
acc, val_acc = curves['accuracy'], curves['val_accuracy']
loss, val_loss = curves['loss'], curves['val_loss']

fig, (top, bottom) = plt.subplots(2, 1, figsize=(8, 8))

top.plot(acc, label='Training Accuracy')
top.plot(val_acc, label='Validation Accuracy')
top.legend(loc='lower right')
top.set_ylabel('Accuracy')
top.set_ylim([min(top.get_ylim()), 1])
top.set_title('Training and Validation Accuracy after unfreezing all layers')

bottom.plot(loss, label='Training Loss')
bottom.plot(val_loss, label='Validation Loss')
bottom.legend(loc='upper right')
bottom.set_ylabel('Cross Entropy')
bottom.set_ylim([0, 1.0])
bottom.set_title('Training and Validation Loss after unfreezing all layers')
bottom.set_xlabel('epoch')
plt.show()

In [None]:
# Restore the best fine-tuned weights written by the fine-tuning checkpoint.
model.load_weights(filepath=final_ckpt_filename)

In [None]:
model.summary(line_length=None)  # layer-by-layer overview of the classifier

In [None]:
# Loss and accuracy on the held-out test split.
test_loss, test_score = model.evaluate(test_gen)
for caption, value in (("Loss on test set: ", test_loss),
                       ("Accuracy on test set: ", test_score)):
    print(caption, value)

In [None]:
def fn_preprocess_images(data_directory, image_size):
    """Load every NORMAL/PNEUMONIA ``.jpeg`` under ``data_directory`` as one batch.

    Each image is resized to ``(image_size, image_size)``, converted to a float
    array and passed through ``preprocess_input`` (the same function the
    ImageDataGenerators in this notebook use).

    Args:
        data_directory: Folder containing ``NORMAL`` and ``PNEUMONIA`` sub-folders.
        image_size: Target height and width in pixels.

    Returns:
        ``(img_data, labels)``: ``img_data`` has shape
        ``(n_images, image_size, image_size, 3)``; ``labels`` is an int array,
        0 for NORMAL and 1 for PNEUMONIA, aligned with ``img_data``. If no images
        are found, both are empty (``img_data`` then has shape
        ``(0, image_size, image_size, 3)``).
    """
    # (sub-folder, label) pairs; 0/1 matches flow_from_directory's alphanumeric order.
    class_dirs = (('NORMAL', 0), ('PNEUMONIA', 1))

    image_arrays = []
    labels = []
    for class_name, label in class_dirs:
        # sorted() makes the sample order reproducible across file systems.
        for img_path in sorted(glob(os.path.join(data_directory, class_name, '*.jpeg'))):
            img = tf.keras.utils.load_img(img_path, target_size=(image_size, image_size))
            image_arrays.append(tf.keras.utils.img_to_array(img))
            labels.append(label)

    if image_arrays:
        # Stack to (n_images, h, w, channels) and preprocess the whole batch at once.
        img_data = preprocess_input(np.stack(image_arrays))
    else:
        img_data = np.empty((0, image_size, image_size, 3), dtype=np.float32)
    print("Final data shape: "+str(img_data.shape))

    return img_data, np.array(labels, dtype=int)

In [None]:
test_data, test_labels = fn_preprocess_images(data_directory=test_dir_path, image_size=224)

In [None]:
# Sigmoid probabilities for every test image, one row per image.
preds = model.predict(test_data, batch_size=32)

In [None]:
# Classification report
from sklearn.metrics import accuracy_score,classification_report, roc_curve, confusion_matrix

# Hard labels by rounding the probabilities (np.round: exactly 0.5 rounds to 0).
rounded_preds = np.round(preds)
acc = accuracy_score(test_labels, rounded_preds) * 100
print("Test data accuracy : " + str(acc))
print("Classification report")
print(classification_report(test_labels, rounded_preds))

In [None]:
# Confusion matrix of the rounded test predictions. Rows/columns follow label
# order [0, 1], i.e. the 0 = NORMAL / 1 = PNEUMONIA encoding produced by
# fn_preprocess_images; labels=[0, 1] keeps it 2x2 even if a class is never predicted.
conf_matrix = confusion_matrix(test_labels, np.round(preds), labels=[0, 1])
class_names = ['Normal', 'Pneumonia']
plt.figure(figsize=(12.8,6))
sns.heatmap(conf_matrix,
            annot=True,
            xticklabels=class_names,
            yticklabels=class_names,
            cmap="Blues",
            fmt='g')
plt.ylabel('Actual')
plt.xlabel('Predicted')
plt.title('Confusion matrix')
plt.show()

In [None]:
from sklearn.metrics import roc_curve, auc

# Flatten the single-column probability matrix into one score per image.
y_preds = np.ravel(preds)
model_fpr, model_tpr, model_threshold = roc_curve(test_labels, y_preds)
model_auc = auc(model_fpr, model_tpr)

In [None]:
fig, roc_ax = plt.subplots()
roc_ax.plot([0, 1], [0, 1], 'k--')  # chance-level diagonal
roc_ax.plot(model_fpr, model_tpr, label='AUC Score(area = {:.3f})'.format(model_auc))
roc_ax.set_xlabel('False positive rate')
roc_ax.set_ylabel('True positive rate')
roc_ax.set_title('DenseNet 121 using class weights - ROC curve ')
roc_ax.legend(loc='best')
plt.show()

In [None]:
# Create the folder the models are saved to (Python equivalent of `mkdir -p`).
# NOTE(review): /content/gdrive/My Drive is not where Drive is mounted (the first
# cell mounts it at /content/drive), so files saved here stay on the Colab
# runtime's local disk — confirm this is intended.
os.makedirs('/content/gdrive/My Drive/', exist_ok=True)

In [None]:
def save_model(model, filename, directory='/content/gdrive/My Drive/'):
    """Save a whole model via its ``save`` method as ``directory/filename``.

    Args:
        model: Object with a Keras-style ``save(filepath)`` method.
        filename: File name, e.g. ``'supervised_learning.h5'``.
        directory: Destination folder, created if it does not exist yet. The
            default is the folder the multitask section loads models from.
            NOTE(review): this default is not the Drive mount (/content/drive/MyDrive).
    """
    os.makedirs(directory, exist_ok=True)
    model.save(os.path.join(directory, filename))

save_model(model, filename='supervised_learning.h5')


## Unsupervised Learning — Image Segmentation

- Unsupervised (classical edge- and contour-based) pneumonia image segmentation on chest X-ray images

In [None]:
# Re-display the same 10 samples at the start of the segmentation section:
# row 0 = NORMAL samples, row 1 = PNEUMONIA samples.
f, ax = plt.subplots(2,5, figsize=(40,15))
for row, title in enumerate(("Normal", "Pneumonia")):
    for col in range(5):
        panel = ax[row, col]
        panel.imshow(imread(samples[row * 5 + col]), cmap='gray')
        panel.set_title(title)
        panel.axis('off')
        panel.set_aspect('auto')
plt.show()

In [None]:
# Segmentation attempt 1: median blur -> Canny edges -> morphological closing ->
# fill the external contours -> keep only the pixels inside them.
kernel = np.ones((5,5),np.uint8)  # structuring element for the closing step
for i in range(10):
    img = cv2.imread(samples[i], cv2.IMREAD_GRAYSCALE)
    img = cv2.medianBlur(img, 5)  # smooth noise before edge detection

    edges = cv2.Canny(img, 100, 200)
    closed_edges = cv2.morphologyEx(edges, cv2.MORPH_CLOSE, kernel)  # bridge small gaps

    contours, hierarchy = cv2.findContours(closed_edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    mask = np.zeros_like(img)
    cv2.drawContours(mask, contours, -1, (255,255,255), -1)  # thickness -1 fills the contours
    masked_img = cv2.bitwise_and(img, img, mask=mask)

    # Each panel is scaled to its own min/max intensity.
    panels = ((img, 'Original Image'),
              (edges, 'Canny Edges'),
              (closed_edges, 'Closed Edges'),
              (masked_img, 'Segmented Image'))
    fig, ax = plt.subplots(1, 4, figsize=(12, 4))
    for axis, (panel, title) in zip(ax, panels):
        axis.imshow(panel, cmap='gray', vmin=panel.min(), vmax=panel.max())
        axis.set_title(title)
    plt.show()


In [None]:
# Segmentation attempt 2: Gaussian blur -> adaptive threshold -> Canny ->
# dilation + closing -> fill external contours -> mask the original image.
kernel = np.ones((5,5),np.uint8)            # dilation structuring element
closing_kernel = np.ones((15,15),np.uint8)  # larger element for the closing step
for i in range(10):
    img = cv2.imread(samples[i], cv2.IMREAD_GRAYSCALE)
    blur = cv2.GaussianBlur(img, (5, 5), 0)

    # Mean-based adaptive threshold (11x11 neighbourhood, C=4), inverted:
    # pixels at or below (local mean - 4) become 255, the rest 0.
    thresh = cv2.adaptiveThreshold(blur, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY_INV, 11, 4)
    edges = cv2.Canny(thresh, 100, 200)
    dilated_edges = cv2.dilate(edges,kernel,iterations = 1)
    closed_edges = cv2.morphologyEx(dilated_edges, cv2.MORPH_CLOSE, closing_kernel)

    contours, hierarchy = cv2.findContours(closed_edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    mask = np.zeros_like(img)
    cv2.drawContours(mask, contours, -1, (255,255,255), -1)
    masked_img = cv2.bitwise_and(img, img, mask=mask)

    fig, ax = plt.subplots(1, 4, figsize=(12, 4))
    titled_panels = zip(ax,
                        (img, thresh, edges, masked_img),
                        ('Original Image', 'Thresholded Image', 'Canny Edges', 'Segmented Image'))
    for axis, panel, title in titled_panels:
        axis.imshow(panel, cmap='gray')
        axis.set_title(title)
    plt.show()

In [None]:
def segment_image(img_path, show=True):
    """Segment an image with a classical (non-learned) edge/contour pipeline.

    Steps: Gaussian blur -> inverted mean adaptive threshold -> Canny edges ->
    dilation and closing -> fill external contours -> mask the original image.

    Args:
        img_path: Path to an image file readable by OpenCV; it is loaded as grayscale.
        show: If True (default, the original behaviour), plot the original,
            thresholded, edge and segmented images side by side.

    Returns:
        The masked grayscale image, same shape as the loaded image and zero
        outside the filled contours.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``img_path``.
    """
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {img_path}")

    blur = cv2.GaussianBlur(img, (5, 5), 0)
    # Inverted threshold: pixels at or below (local 11x11 mean - 4) become 255.
    thresh = cv2.adaptiveThreshold(blur, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY_INV, 11, 4)
    edges = cv2.Canny(thresh, 100, 200)

    # Dilation thickens edge fragments; the larger closing kernel then fills gaps.
    dilated_edges = cv2.dilate(edges, np.ones((5, 5), np.uint8), iterations=1)
    closed_edges = cv2.morphologyEx(dilated_edges, cv2.MORPH_CLOSE, np.ones((15, 15), np.uint8))

    contours, _hierarchy = cv2.findContours(closed_edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    mask = np.zeros_like(img)
    cv2.drawContours(mask, contours, -1, (255, 255, 255), -1)  # thickness -1 fills the contours
    masked_img = cv2.bitwise_and(img, img, mask=mask)

    if show:
        fig, ax = plt.subplots(1, 4, figsize=(12, 4))
        ax[0].imshow(img, cmap='gray')
        ax[0].set_title('Original Image')
        ax[1].imshow(thresh, cmap='gray')
        ax[1].set_title('Thresholded Image')
        ax[2].imshow(edges, cmap='gray')
        ax[2].set_title('Canny Edges')
        ax[3].imshow(masked_img, cmap='gray')
        ax[3].set_title('Segmented Image')
        plt.show()

    return masked_img

In [None]:
# Segment (and display) the first five NORMAL training images, collecting
# whatever segment_image returns for each of them.
new_output = [segment_image(image_path) for image_path in normal_cases[:5]]

In [None]:
def build_model(input_shape):
    """Build a small three-stage CNN with a two-way softmax output.

    Each stage is Conv2D(3x3, ReLU) -> MaxPooling2D(2x2) -> Dropout(0.25), with
    32, 64 and 128 filters. The head is Flatten -> Dense(64, ReLU) ->
    Dropout(0.5) -> Dense(2, softmax).

    Args:
        input_shape: Shape of a single input image, e.g. ``(224, 224, 3)``.

    Returns:
        An uncompiled ``Sequential`` model.
    """
    conv_stages = []
    for stage, filters in enumerate((32, 64, 128)):
        # Only the very first layer is told the input shape.
        conv_kwargs = {'input_shape': input_shape} if stage == 0 else {}
        conv_stages += [
            Conv2D(filters, (3, 3), activation='relu', **conv_kwargs),
            MaxPooling2D((2, 2)),
            Dropout(0.25),
        ]

    head = [
        Flatten(),
        Dense(64, activation='relu'),
        Dropout(0.5),
        Dense(2, activation='softmax'),
    ]
    return Sequential(conv_stages + head)

In [None]:
# Dataset folders for the small CNN.
# NOTE(review): these differ from train_dir_path/test_dir_path used above, and the
# val/ folder of the public Kaggle release is believed to be very small — confirm
# it is large enough to drive EarlyStopping.
train_data_dir = 'chest_xray/train/'
val_data_dir = 'chest_xray/val/'

img_size = (224, 224)
batch_size = 16

# Training data is augmented; validation data is only rescaled to [0, 1].
train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale=1./255, shear_range=0.2, zoom_range=0.2, horizontal_flip=True)
val_datagen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)

# Shared flow settings: one-hot labels for the two-way softmax model.
flow_kwargs = dict(target_size=img_size, batch_size=batch_size, class_mode='categorical')
train_generator = train_datagen.flow_from_directory(train_data_dir, **flow_kwargs)
val_generator = val_datagen.flow_from_directory(val_data_dir, **flow_kwargs)



In [None]:
# Build and compile the small CNN for the two-class (one-hot) problem.
model = build_model(input_shape=img_size + (3,))
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# Stop once val_loss has not improved for 3 epochs.
early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3)
# Keep the best weights on Drive.
ckpt_filename= "/content/drive/MyDrive/unsupervised_learning.hdf5"
checkpoint = ModelCheckpoint(filepath=ckpt_filename, save_best_only=True, save_weights_only=True, verbose = 1)

# Train the model using the generators. The checkpoint has to be passed to fit():
# a ModelCheckpoint that is only constructed never writes anything.
history = model.fit(train_generator,
                    steps_per_epoch=len(train_generator),
                    epochs=20,
                    validation_data=val_generator,
                    validation_steps=len(val_generator),
                    callbacks=[early_stop, checkpoint])



In [None]:
# Accuracy and loss curves of the CNN run, side by side.
curves = history.history
acc, val_acc = curves['accuracy'], curves['val_accuracy']
loss, val_loss = curves['loss'], curves['val_loss']
epochs_range = range(len(acc))

fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 4))
panel_specs = (
    (acc_ax, acc, val_acc, 'Training Accuracy', 'Validation Accuracy',
     'lower right', 'Training and Validation Accuracy'),
    (loss_ax, loss, val_loss, 'Training Loss', 'Validation Loss',
     'upper right', 'Training and Validation Loss'),
)
for axis, train_curve, val_curve, train_label, val_label, legend_loc, title in panel_specs:
    axis.plot(epochs_range, train_curve, label=train_label)
    axis.plot(epochs_range, val_curve, label=val_label)
    axis.legend(loc=legend_loc)
    axis.set_title(title)
plt.show()


In [None]:
# Make sure the save folder exists (Python equivalent of `mkdir -p`).
# NOTE(review): /content/gdrive/My Drive is not the Drive mount point
# (/content/drive), so saved models stay on the runtime's local disk.
os.makedirs('/content/gdrive/My Drive/', exist_ok=True)

In [None]:
# Save the CNN with the save_model helper defined in the supervised section,
# rather than redefining an identical copy of it here.
save_model(model, 'Unsupervised_learning.h5')


## Multitask Learning

- Combines the supervised (classification) and unsupervised (segmentation) losses into a single weighted objective

In [None]:
# Reload the two models saved earlier: the classification model first, then the
# segmentation-section CNN.
# NOTE(review): these files live under /content/gdrive/My Drive, not the Drive mount
# (/content/drive), so they only exist within the current runtime.
classification_model, segmentation_model = (
    tf.keras.models.load_model(saved_path)
    for saved_path in ('/content/gdrive/My Drive/supervised_learning.h5',
                       '/content/gdrive/My Drive/Unsupervised_learning.h5')
)

In [None]:
# Remove the last (sigmoid) layer of the classification model.
# `Model.layers` returns a new list on every access, so `layers.pop()` only
# removed an element from that temporary copy and left the model unchanged.
# Rebuilding a functional model that ends at the penultimate layer really drops the head.
classification_model = tf.keras.Model(inputs=classification_model.input,
                                      outputs=classification_model.layers[-2].output)

In [None]:
# Shared multi-task layer on top of the segmentation model's penultimate layer
# (presumably the Dropout(0.5) after Dense(64) in build_model — verify via summary()).
penultimate_seg = segmentation_model.layers[-2]
shared_layer = Dense(256, activation='relu')(penultimate_seg.output)

In [None]:
# Combine the classification features and the shared segmentation features into
# one two-input model with one named output head per task.
# compile() below keys its losses and loss weights by the output names
# 'classification' and 'segmentation'. A model whose only output is the bare
# concatenation has no such outputs, so that compile call cannot match it.
# The heads are sized for those losses: a 2-way softmax for CategoricalCrossentropy
# on one-hot labels, and 1 sigmoid unit for BinaryCrossentropy on binary labels.
classification_output = classification_model.layers[-1].output
multi_task_output = concatenate([shared_layer, classification_output])
classification_head = Dense(2, activation='softmax', name='classification')(multi_task_output)
segmentation_head = Dense(1, activation='sigmoid', name='segmentation')(multi_task_output)
multi_task_model = tf.keras.Model(inputs=[classification_model.input, segmentation_model.input],
                                  outputs=[classification_head, segmentation_head])

In [None]:
# Per-task losses, keyed by model output name. In the combined objective the
# classification loss has weight 1.0 and the segmentation loss weight 0.5.
classification_loss = tf.keras.losses.CategoricalCrossentropy()
segmentation_loss = tf.keras.losses.BinaryCrossentropy()
losses = dict(classification=classification_loss,
              segmentation=segmentation_loss)
loss_weights = dict(classification=1.0, segmentation=0.5)
multi_task_model.compile(optimizer='adam',
                         loss=losses,
                         loss_weights=loss_weights,
                         metrics=['accuracy'])



In [None]:
multi_task_model.summary(line_length=None)  # overview of the combined two-input model

In [None]:
# Restore the augmenting train / plain test generators for the multitask section;
# the CNN section had rebound train_datagen to a rescale/shear/zoom/flip variant.
# NOTE(review): the next cell redefines both objects identically, so this cell is redundant.
multitask_augmentation = dict(zoom_range=0.1, horizontal_flip=True, fill_mode='constant')
train_datagen = ImageDataGenerator(**multitask_augmentation,
                                   validation_split=0.1,
                                   preprocessing_function=preprocess_input)
test_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)


In [None]:
# Image generators for the multitask model.
train_datagen = ImageDataGenerator(
    zoom_range=0.1,
    horizontal_flip=True,
    fill_mode='constant',
    validation_split=0.1,
    preprocessing_function=preprocess_input
)
# Validation images must not be augmented. The same validation_split on the same
# directory selects the same held-out 10% that train_datagen excludes.
val_datagen = ImageDataGenerator(
    validation_split=0.1,
    preprocessing_function=preprocess_input
)
test_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)

image_size = 224
batch_size = 16


def _flow_multitask(datagen, directory, class_mode, **kwargs):
    """flow_from_directory with the settings shared by every multitask generator.

    color_mode is 'rgb', not 'grayscale': preprocess_input subtracts a per-channel
    mean from three channels and fails on single-channel images, and both
    pretrained models were built for 3-channel inputs.
    """
    return datagen.flow_from_directory(
        directory,
        target_size=(image_size, image_size),
        batch_size=batch_size,
        class_mode=class_mode,
        color_mode='rgb',
        **kwargs,
    )


train_classification_gen = _flow_multitask(train_datagen, train_dir_path, 'categorical',
                                           shuffle=True, subset='training')
train_segmentation_gen = _flow_multitask(train_datagen, train_dir_path, 'binary',
                                         shuffle=True, subset='training')
val_classification_gen = _flow_multitask(val_datagen, train_dir_path, 'categorical',
                                         shuffle=False, subset='validation')
val_segmentation_gen = _flow_multitask(val_datagen, train_dir_path, 'binary',
                                       shuffle=False, subset='validation')
# Fixed order so test predictions can be matched to files.
test_gen = _flow_multitask(test_datagen, test_dir_path, 'binary', shuffle=False)




In [None]:
# tf.data pipelines for the two-input / two-output multitask model.
# Every batch comes from ONE generator and feeds both inputs, so the
# classification and segmentation targets always describe the same images.
# Zipping two independently shuffled generators paired unrelated images and labels.


def _multitask_batches(classification_gen):
    """Yield one epoch of ``((x, x), {'classification': y, 'segmentation': y_bin})``.

    ``y`` is the one-hot label batch of a ``class_mode='categorical'`` generator;
    ``y_bin`` is its PNEUMONIA column, kept 2-D as ``(batch, 1)``. The loop stops
    after ``len(classification_gen)`` batches because Keras directory iterators
    never raise StopIteration, which would otherwise make an epoch endless.
    """
    pneumonia_col = classification_gen.class_indices['PNEUMONIA']
    for _ in range(len(classification_gen)):
        x, y = next(classification_gen)
        yield (x, x), {'classification': y,
                       'segmentation': y[:, pneumonia_col:pneumonia_col + 1]}


def _multitask_signature(classification_gen):
    """TensorSpecs matching ``_multitask_batches``.

    The batch dimension is None because the last batch of an epoch is usually
    smaller than ``batch_size``.
    """
    image_spec = tf.TensorSpec(shape=(None, *classification_gen.image_shape), dtype=tf.float32)
    return ((image_spec, image_spec),
            {'classification': tf.TensorSpec(shape=(None, classification_gen.num_classes),
                                             dtype=tf.float32),
             'segmentation': tf.TensorSpec(shape=(None, 1), dtype=tf.float32)})


train_dataset = tf.data.Dataset.from_generator(
    lambda: _multitask_batches(train_classification_gen),
    output_signature=_multitask_signature(train_classification_gen),
)

val_dataset = tf.data.Dataset.from_generator(
    lambda: _multitask_batches(val_classification_gen),
    output_signature=_multitask_signature(val_classification_gen),
)

multi_task_model.fit(
    train_dataset,
    epochs=20,
    validation_data=val_dataset
)