In [None]:
%matplotlib inline
%load_ext tensorboard

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import shutil

# Start every run from an empty TensorBoard log directory.
# ignore_errors keeps this best-effort (e.g. the directory does not exist yet)
# without a bare `except:` that would also swallow KeyboardInterrupt/SystemExit.
shutil.rmtree('logs', ignore_errors=True)

In [None]:
# Load the MNIST dataset shipped with Keras and unpack it into
# train / test images (x_*) and their digit labels (y_*).
(x_train,y_train),(x_test,y_test) = tf.keras.datasets.mnist.load_data()

In [None]:
def create_example(x,y):
    """Turn one grayscale digit into a noisy RGB image tinted in a random channel.

    Args:
        x: 28x28 digit image; it is added to a single 28x28 colour plane.
           Values are divided by 255, so presumably 0-255 pixel intensities.
        y: digit label, passed through unchanged.

    Returns:
        (image, y, channel) where image is a float array of shape (28, 28, 3)
        and channel is the index of the plane that received the digit
        (0, 1 or 2 -- all three channels are possible).
    """
    # randint's upper bound is exclusive, so this picks 0, 1 or 2.
    channel = np.random.randint(0, 3)
    # Uniform background noise in [0, 0.5) on every channel.
    noisy = np.random.rand(28, 28, 3) * 0.5
    # The digit contributes at most another 0.5 to the chosen channel only.
    noisy[..., channel] += x * 0.5 / 255.
    return noisy, y, channel

In [None]:
# Shape of a single raw training image (shown as the cell output).
x_train[0].shape


In [None]:
# Display the first training digit exactly as loaded, before any colouring.
plt.imshow(x_train[0])
plt.show()

In [None]:
# Colour the first training digit and display it together with its labels.
x,y,c = create_example(x_train[0],y_train[0])
plt.imshow(x)
plt.show()
# Maps the channel index returned by create_example to a readable colour name.
colors = {0:'red',1:'green',2:'blue'}
print('Color :',colors[c],' Digit : ',y)

In [None]:
#db generator :
def generate_data(x,y,batch_size=32):
    """Endlessly yield random batches of coloured digits for the two-output model.

    Args:
        x: indexable collection of digit images accepted by create_example.
        y: matching digit labels; len(y) defines the sampling range.
        batch_size: number of examples per yielded batch.

    Yields:
        (images, [digits, channels]) where images has shape
        (batch_size, 28, 28, 3) and digits / channels are float arrays of
        shape (batch_size,) holding the digit label and colour index.
        Examples are drawn with replacement; the generator never stops.
    """
    total = len(y)

    while True:
        # Fresh buffers per batch so previously yielded arrays are never mutated.
        images = np.zeros((batch_size,28,28,3))
        digits = np.zeros((batch_size,))
        channels = np.zeros((batch_size,))

        for slot in range(batch_size):
            pick = np.random.randint(0, total)
            images[slot], digits[slot], channels[slot] = create_example(x[pick], y[pick])

        yield images, [digits, channels]
        

In [None]:
# Pull one single-example batch from a fresh generator to sanity-check its output.
x1, [y1,c1]=next(generate_data(x_test,y_test,batch_size=1))
plt.imshow(x1[0])
plt.show()
# Same index -> name mapping as defined earlier; repeated so this cell is self-contained.
colors = {0:'red',1:'green',2:'blue'}
# c1 is a float array (the generator fills np.zeros buffers); 0.0 == 0 and hashes
# alike, so the int-keyed lookup still works. y1 is printed as a length-1 array.
print('Color :',colors[c1[0]],' Digit : ',y1)

In [None]:
#Creating the model : we have two outputs : color and the label

from tensorflow.keras.layers import Input, Conv2D, Activation,MaxPool2D, Flatten,Add, Dense
from tensorflow.keras.models import Model

# Layer names avoid whitespace: a name with a space is not a valid TensorFlow
# scope/op name and can be rejected depending on the TF version and code path.
input_ = Input(shape=(28,28,3),name="input_layer")

# Shared stem: a single conv + ReLU whose activations feed both heads.
conv_1 = Conv2D(filters=32,kernel_size=3,name='conv_1')(input_)
act_1 = Activation(activation='relu',name='act_1')(conv_1)

# Colour head: pool and flatten the stem output, then classify.
pool_1 = MaxPool2D(pool_size=4,name='pool_1')(act_1)
flat_1 = Flatten(name='flat_1')(pool_1)

# Three-way softmax (red / green / blue), trained with sparse integer labels.
color = Dense(units=3,activation='softmax',name='color')(flat_1)

# Digit head: two more convolutions on top of the shared stem activations.
conv_2= Conv2D(32,3,padding='same',name='conv_2')(act_1)
act_2=Activation('relu',name='act_2')(conv_2)

# padding='same' and 32 filters keep conv_3 the same shape as act_1,
# which the residual Add below requires.
conv_3 = Conv2D(32,3,padding='same',name='conv_3')(act_2)
add= Add(name='add')([act_1,conv_3])
act_3 = Activation('relu',name='act_3')(add)

pool_2 = MaxPool2D(4,name='pool_2')(act_3)
flat_2 = Flatten(name='flat_2')(pool_2)

digit = Dense(10,activation='softmax',name='digit')(flat_2)

# Output order [digit, color] matches the [y_batch, c_batch] targets yielded by generate_data.
model = Model (input_, [digit,color])

# Losses are keyed by output-layer name; both heads take integer class labels.
model.compile(
loss={
    
    'digit': 'sparse_categorical_crossentropy',
    'color':'sparse_categorical_crossentropy'
},
    optimizer='adam',
    metrics=['accuracy']
)

model.summary()


In [None]:
#plotting the model
from tensorflow.keras.utils import plot_model

# Draw the layer graph of the two-headed model (shown as the cell output).
plot_model(model)

In [None]:
class Logger(tf.keras.callbacks.Callback):
    """Keras callback that prints train/validation accuracy of both heads after each epoch."""

    @staticmethod
    def _fmt(value):
        """Format a metric with two decimals, or 'n/a' when it was not reported."""
        return 'n/a' if value is None else f'{value:.2f}'

    def on_epoch_end(self, epoch, logs=None):
        """Print the digit/color accuracies found in `logs` for this epoch.

        Args:
            epoch: zero-based epoch index; printed one-based.
            logs: metric dict supplied by Keras. May be None, and individual
                keys may be absent (for example the val_* entries when no
                validation data is used); missing values print as 'n/a'
                instead of raising.
        """
        logs = logs or {}
        digit_accuracy = logs.get('digit_accuracy')
        color_accuracy = logs.get('color_accuracy')
        val_digit_accuracy = logs.get('val_digit_accuracy')
        val_color_accuracy = logs.get('val_color_accuracy')
        print('='*30, epoch + 1, '='*30)
        print(f'digit_accuracy: {self._fmt(digit_accuracy)}, color_accuracy: {self._fmt(color_accuracy)}')
        print(f'val_digit_accuracy: {self._fmt(val_digit_accuracy)}, val_color_accuracy: {self._fmt(val_color_accuracy)}')
    

In [None]:
#training the model
# generate_data never terminates, so steps_per_epoch / validation_steps
# (batches per epoch) are what bound each training and validation pass.
train_gen= generate_data(x_train,y_train)
test_gen = generate_data(x_test,y_test)

# The returned History object is discarded; progress is reported via the callbacks.
_ = model.fit(
    train_gen,
    validation_data=test_gen,
    steps_per_epoch=200,
    validation_steps=100,
    epochs=10,
    callbacks=[
        # Prints the per-epoch accuracies of both heads.
        Logger(),
        # Writes to the same 'logs' directory that the first cell clears.
        tf.keras.callbacks.TensorBoard(log_dir='./logs')
    ],
    # Built-in progress output is off; Logger is the only console report.
    verbose=False
)


In [None]:
# Open TensorBoard on the directory written by the TensorBoard callback during training.
%tensorboard --logdir logs

In [None]:
def test_model(test,show=True):
    """Predict one example drawn from `test` and visualise ground truth vs. prediction.

    Relies on the notebook-level `model` and `colors` objects.

    Args:
        test: generator yielding (x, [y, c]) batches as produced by
            generate_data; only the first example of the batch is used.
        show: if True, print the labels and show the image on its own.
            If False, draw into the current axes with the labels as axis
            text (green when both predictions are right, red otherwise) so
            the caller can arrange several examples in a subplot grid.
    """
    x, [y, c] = next(test)

    # The model returns [digit_probs, color_probs], each of shape (batch, n_classes).
    preds = model.predict(x)
    pred_digit = int(np.argmax(preds[0][0]))
    # The colour head is a 3-way softmax, so take the argmax; thresholding the
    # whole probability row at 0.5 and calling int() on it raises a TypeError.
    pred_color = int(np.argmax(preds[1][0]))
    # Targets come back as floats from the generator; cast for dict lookup/printing.
    gt_digit = int(y[0])
    gt_color = int(c[0])

    plt.imshow(x[0])
    if show:
        print(f'GT: {gt_digit}, {colors[gt_color]}')
        print(f'Pr: {pred_digit}, {colors[pred_color]}')
        plt.show()
    else:
        col = 'green' if gt_digit == pred_digit and gt_color == pred_color else 'red'
        plt.ylabel(f'GT: {gt_digit}, {colors[gt_color]}', color=col)
        plt.xlabel(f'Pr: {pred_digit}, {colors[pred_color]}', color=col)
        plt.xticks([])
        plt.yticks([])


# One-example-per-batch generator over the test set, consumed by test_model.
test = generate_data(x_test, y_test, batch_size=1)

In [None]:
# 4x4 grid of random test examples; test_model(show=False) draws into the
# current subplot and colours the labels by correctness.
plt.figure(figsize=(12,12))

for i in range(0,16):
    plt.subplot(4,4,i+1) #subplots start from 1
    test_model(test,False)
plt.show()

In [None]:
def show_confusion_matrix_digit(y_true, y_pred, classes):
    """Plot a row-normalised confusion matrix for the digit head.

    Args:
        y_true: iterable of ground-truth class indices (integer-valued;
            floats such as 7.0 are accepted and cast to int).
        y_pred: iterable of predicted class indices.
        classes: tick labels, one per class, ordered by class index.
    """
    from sklearn.metrics import confusion_matrix

    tick_positions = list(range(len(classes)))
    # Cast to plain ints so float targets and integer predictions share one label type.
    y_true = [int(label) for label in y_true]
    y_pred = [int(label) for label in y_pred]
    # Pinning `labels` keeps the matrix len(classes) x len(classes) even when a
    # class is absent from this sample, so the tick labels always line up.
    cm = confusion_matrix(y_true, y_pred, labels=tick_positions, normalize='true')

    plt.figure(figsize=(8, 8))
    sp = plt.subplot(1, 1, 1)
    ctx = sp.matshow(cm)
    plt.xticks(tick_positions, labels=classes)
    plt.yticks(tick_positions, labels=classes)
    plt.colorbar(ctx)
    plt.show()
    
    

In [None]:
# Collect digit predictions and ground truth over 11 random test batches
# (one initial batch plus ten more in the loop).
test = generate_data(x_test, y_test, batch_size=32)
x, [y, c] = next(test)
# model.predict returns [digit_probs, color_probs]; index 0 is the digit head.
y_pred = list(np.argmax(model.predict(x)[0],axis=-1))

y_true=list(y)
for i in range(10) :
    x, [y, c] = next(test)
    y_pred.extend(list(np.argmax(model.predict(x)[0],axis=-1)))
    y_true.extend(list(y))


In [None]:
# Uses the y_true / y_pred gathered in the previous cell; those names are
# rebound later for the colour head, so run this cell right after that one.
show_confusion_matrix_digit(y_true,y_pred,classes=[0,1,2,3,4,5,6,7,8,9])

In [None]:
def show_confusion_matrix_color(y_true, y_pred, classes):
    """Plot a row-normalised confusion matrix for the colour head.

    Args:
        y_true: iterable of ground-truth colour indices (integer-valued;
            floats such as 2.0 are accepted and cast to int).
        y_pred: iterable of predicted colour indices.
        classes: tick labels, one per class, ordered by class index.
    """
    from sklearn.metrics import confusion_matrix

    tick_positions = list(range(len(classes)))
    # Cast to plain ints so float targets and integer predictions share one label type.
    y_true = [int(label) for label in y_true]
    y_pred = [int(label) for label in y_pred]
    # Pinning `labels` keeps the matrix len(classes) x len(classes) even when a
    # colour is absent from this sample, so the tick labels always line up.
    cm = confusion_matrix(y_true, y_pred, labels=tick_positions, normalize='true')

    plt.figure(figsize=(8, 8))
    sp = plt.subplot(1, 1, 1)
    ctx = sp.matshow(cm)
    plt.xticks(tick_positions, labels=classes)
    plt.yticks(tick_positions, labels=classes)
    plt.colorbar(ctx)
    plt.show()
    
    

In [None]:
# Collect colour predictions and ground truth over 11 random test batches.
# NOTE: this rebinds y_pred / y_true, replacing the digit results gathered earlier.
test = generate_data(x_test, y_test, batch_size=32)
x, [y, c] = next(test)
# model.predict returns [digit_probs, color_probs]; index 1 is the colour head.
y_pred = list(np.argmax(model.predict(x)[1],axis=-1))

y_true=list(c)
for i in range(10) :
    x, [y, c] = next(test)
    y_pred.extend(list(np.argmax(model.predict(x)[1],axis=-1)))
    y_true.extend(list(c))


In [None]:
# Class names are listed in channel-index order (0, 1, 2), matching the `colors` mapping.
show_confusion_matrix_color(y_true,y_pred,classes=['red','green','blue'])