Implementing multimodal image fusion using self-supervised transformers on the COCO dataset

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation, Concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

Define the self-supervised transformer block

In [None]:
def self_supervised_transformer(inputs, attention_pool_size=16, num_heads=8, key_dim=64):
    """Build the fusion block: three conv branches, concatenation, self-attention, conv head.

    Args:
        inputs: 4D Keras tensor; the code treats axes 1 and 2 as spatial and the
            last axis as channels (channels-last Conv2D defaults).
        attention_pool_size: Factor by which the concatenated feature map is
            average-pooled before self-attention and nearest-neighbour upsampled
            afterwards. Use 1 to attend at full resolution. Statically known
            spatial dimensions must be divisible by this value.
        num_heads: Number of attention heads.
        key_dim: Per-head size of the query/key projections.

    Returns:
        A Keras tensor with 3 channels and sigmoid activation.

    Raises:
        ValueError: If `attention_pool_size` is smaller than 1, or a statically
            known spatial dimension is not divisible by it.
    """
    if attention_pool_size < 1:
        raise ValueError(
            f"attention_pool_size must be >= 1, got {attention_pool_size}")
    for dim in (inputs.shape[1], inputs.shape[2]):
        # Pooling floors the size, so a non-divisible dimension would come back
        # smaller after upsampling and no longer match the training target.
        if dim is not None and dim % attention_pool_size:
            raise ValueError(
                f"Spatial size {dim} is not divisible by "
                f"attention_pool_size={attention_pool_size}")

    # Branches named after auxiliary tasks. NOTE: no auxiliary loss is attached
    # to them in this function; they only contribute extra feature channels.
    colorization = Conv2D(filters=3, kernel_size=3, padding='same', activation='sigmoid')(inputs)
    super_resolution = Conv2D(filters=64, kernel_size=3, padding='same', activation='relu')(inputs)
    super_resolution = Conv2D(filters=3, kernel_size=3, padding='same', activation='sigmoid')(super_resolution)
    inpainting = Conv2D(filters=3, kernel_size=3, padding='same', activation='sigmoid')(inputs)

    # Stack the raw input with the three branch outputs along the channel axis.
    concat = Concatenate()([inputs, colorization, super_resolution, inpainting])

    # With the default attention_axes, MultiHeadAttention attends over every
    # spatial position, so the score tensor grows with (H * W) ** 2. At 256x256
    # that is 65,536 tokens, which does not fit in memory. Pooling first keeps
    # the token count tractable (16x16 = 256 tokens for the default settings).
    tokens = concat
    if attention_pool_size > 1:
        tokens = tf.keras.layers.AveragePooling2D(pool_size=attention_pool_size)(tokens)

    transformer = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=key_dim)(tokens, tokens)

    if attention_pool_size > 1:
        # Restore the original spatial resolution before the convolutional head.
        transformer = tf.keras.layers.UpSampling2D(size=attention_pool_size)(transformer)

    transformer = BatchNormalization()(transformer)
    transformer = Activation('relu')(transformer)
    transformer = Conv2D(filters=64, kernel_size=3, padding='same', activation='relu')(transformer)
    transformer = Conv2D(filters=3, kernel_size=3, padding='same', activation='sigmoid')(transformer)

    # Return the fused image
    return transformer

Define the multimodal image fusion model

In [None]:

def multimodal_image_fusion(input_shape=(256, 256, 3)):
    """Build the image-fusion Keras model around `self_supervised_transformer`.

    Args:
        input_shape: Shape of one input image without the batch dimension.
            Defaults to (256, 256, 3).

    Returns:
        A `tf.keras.Model` mapping an image tensor of `input_shape` to the
        output of `self_supervised_transformer`.
    """
    # Symbolic input that the fusion block is wired onto.
    inputs = Input(shape=input_shape)

    # Apply the self-supervised transformer block
    fused_image = self_supervised_transformer(inputs)

    return Model(inputs=inputs, outputs=fused_image)

Load the COCO dataset

In [None]:
def _to_reconstruction_pair(example):
    """Turn one COCO feature dict into an (input, target) pair of resized images.

    The image is resized to 256x256 (tf.image.resize returns float32) and used
    as both input and target, i.e. a reconstruction objective.
    NOTE(review): the target choice is an assumption; COCO provides no
    ground-truth "fused" image, so confirm this matches the intended objective.
    """
    image = tf.image.resize(example['image'], (256, 256))
    return image, image


# 'coco/2017' defines no supervised keys, so as_supervised=True cannot be used;
# load the feature dictionaries and build the (input, target) pairs explicitly.
dataset = tfds.load('coco/2017', split='train').map(_to_reconstruction_pair)

Normalize the pixel values to the [0, 1] range by dividing by 255

In [None]:

def _scale_to_unit_range(image, target):
    """Divide both elements of an (input, target) pair by 255."""
    scale = 255.0
    return image / scale, target / scale


# Apply the scaling to every (input, target) pair in the pipeline.
dataset = dataset.map(_scale_to_unit_range)

Define the model

In [None]:

# Build the fusion network with the builder's default 256x256x3 input.
model = multimodal_image_fusion()

Compile the model

In [None]:

# Use `learning_rate`: the `lr` alias is deprecated and rejected by newer
# Keras releases. Pixel-wise mean squared error is the training loss.
model.compile(optimizer=Adam(learning_rate=1e-4),
              loss='mean_squared_error')

Train the model

In [None]:

# Keras expects a tf.data pipeline to yield batches; unbatched elements lack
# the leading batch dimension the model input requires.
BATCH_SIZE = 8

# Bind the batched pipeline to a new name so re-running this cell does not
# batch an already-batched dataset.
train_batches = dataset.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

model.fit(train_batches, epochs=10)
