### PASTIS Implementation 

In [1]:
import tensorflow as tf
import numpy as np
import math

from tensorflow.keras.layers import Input, Conv2D, Conv1D, UpSampling2D, MaxPooling2D, Dropout, Conv2DTranspose, concatenate





In [2]:
# Window size of the MaxPooling2D layer in the encoder pooling step.
POOL_SIZE = 3
# Stride shared by the encoder max pooling and the decoder transposed convolution.
STRIDES = 2
# Kernel size of the regular convolutions and of the transposed convolution.
KERNEL_SIZE = 3
# Spatial scale factor applied when upsampling the attention masks.
UPSAMPLE_FACTOR = 2
# Kernel size and stride of the 1x1 convolution.
KERNEL_SIZE_1x1 = 1
STRIDES_1x1 = 1


In [3]:
def single_conv_block(inputs, n_filters, dropout_prob):
    """
    Single convolution unit: Conv2D, optional Dropout, GroupNormalization, ReLU.

    Arguments:
        inputs -- Input tensor (batch_size,T,H,W,C)
        n_filters -- Number of filters of the convolution
        dropout_prob -- Dropout probability; no Dropout layer is added when it is <= 0
    Returns:
        Activated feature map (batch_size,T,H,W,n_filters)
    """
    features = Conv2D(
        filters=n_filters,
        kernel_size=KERNEL_SIZE,
        strides=1,
        padding='same',
        kernel_initializer='he_normal',
    )(inputs)  # (batch_size,T,H,W,n_filters)

    # Dropout, when enabled, sits between the convolution and the normalization.
    if dropout_prob > 0:
        features = Dropout(dropout_prob)(features)

    normalized = tf.keras.layers.GroupNormalization(groups=4, axis=-1)(features)
    return tf.keras.layers.ReLU()(normalized)  # (batch_size,T,H,W,n_filters)

def conv_block(inputs, n_filters, dropout_prob):
    """
    Double convolution block whose two stage outputs are concatenated.

    Arguments:
        inputs -- Input tensor (batch_size,T,H,W,C)
        n_filters -- Number of filters of each convolution
        dropout_prob -- Dropout probability forwarded to each stage
    Returns:
        Channel-wise concatenation of both stages (batch_size,T,H,W,n_filters*2)
    """
    stage_outputs = []
    current = inputs
    # Two chained stages; each stage consumes the output of the previous one.
    for _ in range(2):
        current = single_conv_block(inputs=current, n_filters=n_filters, dropout_prob=dropout_prob)
        stage_outputs.append(current)  # (batch_size,T,H,W,n_filters)

    return tf.concat(stage_outputs, axis=-1)


def conv_pool_block(inputs, n_filters, dropout_prob):
    """
    Convolution block followed by a spatial max pooling applied to every time step.

    Arguments:
        inputs -- Input tensor (batch_size,T,H,W,C)
        n_filters -- Number of filters of each convolution in the block
        dropout_prob -- Dropout probability
    Returns:
        pooled, skip_connection -- Pooled tensor
        (batch_size,T,(H-1)//STRIDES+1,(W-1)//STRIDES+1,n_filters*2) and the
        un-pooled block output (batch_size,T,H,W,n_filters*2)
    """
    skip_connection = conv_block(inputs=inputs, n_filters=n_filters, dropout_prob=dropout_prob) # (batch_size,T,H,W,n_filters*2)
    _, T, H, W, n_channels = skip_connection.shape

    # Fold the time axis into the batch axis so MaxPooling2D sees 4-D input.
    # -1 lets reshape infer batch_size*T, which also works when the static
    # batch dimension is unknown (None).
    folded = tf.reshape(skip_connection, shape=(-1, H, W, n_channels)) # (batch_size*T,H,W,n_channels)
    pooled = MaxPooling2D(pool_size=POOL_SIZE, strides=STRIDES, padding='same')(folded)
    _, H2, W2, _ = pooled.shape

    # Restore the separate batch and time axes.
    pooled = tf.reshape(pooled, shape=(-1, T, H2, W2, n_channels))

    return pooled, skip_connection



In [4]:
# debug
# Smoke test: check the output shapes of the encoder blocks on a random tensor.
batch_size, T, H, W, C = 2,4,8,8,6
dropout_prob = 0.15
input_shape = [batch_size, T, H, W, C]
inputs = tf.random.normal(shape=input_shape)
strides = 2
# Expected spatial size after the pooling step for the stride above.
conv_pool_block_outputshape_H = math.floor((H - 1) / strides) + 1
n_filters = 32

# H == W here, so the same expected size is used for both spatial axes.
assert(single_conv_block(inputs, n_filters, dropout_prob).shape == [batch_size,T,H,W,n_filters])
assert(conv_block(inputs, n_filters, dropout_prob).shape == [batch_size,T,H,W, n_filters*2])
assert(conv_pool_block(inputs, n_filters, dropout_prob)[0].shape == [batch_size,T,conv_pool_block_outputshape_H,conv_pool_block_outputshape_H, n_filters*2])


In [5]:

def single_upconv_block(input, n_filters):
    """
    Convolution unit made of Conv2D, BatchNormalization and ReLU.

    Arguments:
        input -- Input tensor (batch_size,H,W,C)
        n_filters -- Number of filters of the convolution

    Returns:
        Activated feature map (batch_size,H,W,n_filters)
    """
    conv_layer = Conv2D(n_filters, kernel_size=KERNEL_SIZE, padding="same", kernel_initializer='he_normal')
    batch_norm = tf.keras.layers.BatchNormalization()
    activation = tf.keras.layers.ReLU()

    # Apply the three layers in sequence: convolution, normalization, activation.
    return activation(batch_norm(conv_layer(input)))


def conv_2D_transpose(input, n_filters):
    """
    Upsample a feature map with a strided transposed convolution.

    Arguments:
        input -- Input tensor (batch_size, H, W, C)
        n_filters -- Number of filters of the transposed convolution

    Returns:
        Upsampled tensor (batch_size, H*STRIDES, W*STRIDES, n_filters)
    """
    transposed_conv = Conv2DTranspose(
        filters=n_filters,
        kernel_size=KERNEL_SIZE,
        strides=STRIDES,
        padding='same',
    )
    return transposed_conv(input)


def up_conv_block(previous_layer, skip_connexion, n_filters):
    """
    Decoder block: upsample, merge with the skip connection, then two convolutions.

    Arguments:
        previous_layer -- Tensor from the previous (coarser) layer (batch_size, H, W, C1)
        skip_connexion -- Skip tensor at the target resolution (batch_size, H*STRIDES, W*STRIDES, C2)
        n_filters -- Number of filters of every convolution in the block

    Returns:
        Concatenation of both convolution outputs (batch_size, H*STRIDES, W*STRIDES, n_filters*2)
    """
    # Bring the coarse features to the skip connection resolution and stack them channel-wise.
    fused = tf.concat(
        [conv_2D_transpose(previous_layer, n_filters), skip_connexion],
        axis=-1,
    )  # (batch_size, H*STRIDES, W*STRIDES, n_filters + C2)

    first = single_upconv_block(fused, n_filters=n_filters)   # (batch_size, H*STRIDES, W*STRIDES, n_filters)
    second = single_upconv_block(first, n_filters=n_filters)  # (batch_size, H*STRIDES, W*STRIDES, n_filters)

    return tf.concat([first, second], axis=-1)

In [6]:
#debug
# Smoke test: check the output shapes of the decoder blocks on random tensors.
batch_size, H, W, C1, C2 = 2,8,8,6,7
input_shape = [batch_size, H, W, C1]
inputs = tf.random.normal(shape=input_shape)
# The skip connection lives at the upsampled resolution.
skip_connexion_shape = [batch_size, H*STRIDES, W*STRIDES, C2]
skip_connexion = tf.random.normal(shape=skip_connexion_shape)

n_filters = 32

assert(single_upconv_block(inputs, n_filters).shape == [batch_size, H, W, n_filters])
previous_layer = conv_2D_transpose(inputs, n_filters)
assert(previous_layer.shape == [batch_size, H*STRIDES, W*STRIDES, n_filters])
assert(up_conv_block(inputs, skip_connexion, n_filters).shape == [batch_size, H*STRIDES, W*STRIDES, n_filters*2])



In [7]:
def compute_attention_masks(input, G, key_dim):
    """
    Compute per-head temporal attention masks (Lightweight Temporal Attention Encoder).

    Each head owns a query vector and a key projection. Keys are computed for
    every (time, pixel) position, scored against the head's query, and the
    scores are normalized with a softmax over the time axis.

    Note: the query and key-projection variables are created, randomly
    initialized, on every call.

    Arguments:
        input  -- Input tensor (batch, T, H, W, C)
        G -- number of attention heads
        key_dim -- Dimension of the key and query vectors.
    Returns:
        attention_mask -- attention mask (batch_size, G, T, H, W)
    """
    query_shape = [G, key_dim]
    Q = tf.Variable(initial_value=tf.random.normal(shape=query_shape), shape=query_shape, trainable=True) # (G, key_dim)

    C = input.shape[-1]
    key_weight_shape = (G, key_dim, C)
    Wk = tf.Variable(initial_value=tf.random.normal(shape=key_weight_shape), shape=key_weight_shape, trainable=True) # (G, key_dim, C)

    # Contract the channel axis directly. A broadcast multiply followed by
    # reduce_sum would first materialize a (batch, G, key_dim, T, H, W, C)
    # tensor, which is C times larger than the keys themselves.
    K = tf.einsum('bthwc,gkc->bgkthw', input, Wk)  # (batch, G, key_dim, T, H, W)

    # Dot product between each head's query and its keys.
    QK = tf.einsum('gk,bgkthw->bgthw', Q, K)  # (batch, G, T, H, W)

    # Scaled scores, normalized over the time axis (axis 2).
    attention_mask = tf.nn.softmax(QK/tf.math.sqrt(float(key_dim)), axis=2) # (batch_size, G, T, H, W)
    return attention_mask



In [8]:
#debug
# Smoke test: check the shape of the attention masks computed from a random tensor.
batch_size, T, H, W, C = 2, 8, 64, 64, 6
input_shape = [batch_size, T, H, W, C]
input = tf.random.normal(shape=input_shape)
G, key_dim = 9, 32

attention_mask = compute_attention_masks(input, G, key_dim)
assert(attention_mask.shape == [batch_size, G, T, H, W])

In [9]:

def conv_one_one(input, n_filters):
    """
    1x1 convolution block with ReLU activation.

    Arguments:
        input -- Input tensor (batch_size, H, W, C)
        n_filters -- Number of filters of the convolution

    Returns:
        Output tensor (batch_size, H, W, n_filters)
    """
    pointwise_conv = Conv2D(
        filters=n_filters,
        kernel_size=KERNEL_SIZE_1x1,
        strides=STRIDES_1x1,
        padding='same',
        activation='relu',
    )
    return pointwise_conv(input)  # (batch_size, H, W, n_filters)

def upsample_attention(attention_mask):
    """
    Bilinear spatial upsampling of the attention masks.

    Arguments:
        attention_mask -- Input tensor (batch_size, G, T, H, W)
    Returns:
        up_attention -- Output tensor (batch_size, G, T, H*UPSAMPLE_FACTOR, W*UPSAMPLE_FACTOR)
    """
    _, G, T, H, W = attention_mask.shape

    # Fold the heads into the batch axis so that T plays the role of the channel
    # axis of a channels_first UpSampling2D. -1 lets reshape infer batch_size*G,
    # which also works when the static batch dimension is unknown (None).
    folded = tf.reshape(attention_mask, shape=[-1, T, H, W])
    upsampled = UpSampling2D(size=UPSAMPLE_FACTOR, interpolation='bilinear', data_format='channels_first')(folded) # (batch_size*G, T, H*UPSAMPLE_FACTOR, W*UPSAMPLE_FACTOR)

    # Restore the separate batch and head axes; only the last two axes were upsampled.
    up_attention = tf.reshape(upsampled, shape=[-1, G, T, H*UPSAMPLE_FACTOR, W*UPSAMPLE_FACTOR])
    return up_attention

def block_wise_temporal_ws(input, attention_mask):
    """
    Collapse the time axis with a per-head attention-weighted sum.

    The C channels are split into G contiguous groups of C//G channels, where G
    is the number of heads of attention_mask; every group is weighted over time
    by the mask of its own head.

    Arguments:
        input -- (batch_size, T, H, W, C)
        attention_mask -- (batch_size, G, T, H, W)
    Returns :
        weighted_sum (batch_size, H, W, C)
    Raises:
        ValueError -- if C is not divisible by the number of heads G
    """
    # The head count comes from the mask itself rather than from a module-level
    # variable, so the result only depends on the arguments.
    n_heads = attention_mask.shape[1]

    channels_first = tf.transpose(input, perm=[0,4,1,2,3]) # (batch_size, C, T, H, W)
    _, C, T, H, W = channels_first.shape
    if C % n_heads != 0:
        raise ValueError(
            f"number of channels ({C}) must be divisible by the number of attention heads ({n_heads})"
        )

    # -1 lets reshape infer the batch size, including when it is statically unknown.
    grouped = tf.reshape(channels_first, shape=[-1, n_heads, C // n_heads, T, H, W])
    mask_bcast = attention_mask[:, :, tf.newaxis, ...] # (batch_size, G, 1, T, H, W)

    # Sum over the time axis (axis 3).
    weighted_sum = tf.math.reduce_sum(grouped * mask_bcast, axis=3) # (batch_size, G, C//G, H, W)
    weighted_sum = tf.reshape(weighted_sum, shape=[-1, C, H, W])
    weighted_sum = tf.transpose(weighted_sum, perm=[0,2,3,1]) # (batch_size, H, W, C)
    return weighted_sum

def temporal_ws_and_conv_one_one(input, attention_mask, n_filters):
    """
    Attention-weighted temporal sum followed by a 1x1 convolution.

    Arguments:
        input -- (batch_size, T, H, W, C)
        attention_mask -- (batch_size, G, T, H, W)
        n_filters -- Number of filters of the 1x1 convolution
    Returns:
        Output tensor (batch_size, H, W, n_filters)
    """
    return conv_one_one(block_wise_temporal_ws(input, attention_mask), n_filters)

In [10]:
# debug conv_one_one & upsample_attention
# Smoke test: check the output shapes of upsample_attention and conv_one_one.
batch_size, T, H, W = 2, 8, 64, 64
G = 4
attention_mask_shape = [batch_size, G, T, H, W]
# Random tensor with the shape of an attention mask; only the shape is checked.
input = tf.random.normal(shape=attention_mask_shape)
key_dim = 32
C = 6
n_filters = 4
input_one_one = tf.random.normal(shape=[batch_size, H, W, C])

assert(upsample_attention(input).shape == [batch_size, G, T, H*UPSAMPLE_FACTOR, W*UPSAMPLE_FACTOR])
assert(conv_one_one(input_one_one,n_filters).shape == [batch_size, H, W, n_filters])

In [11]:
#debug block_wise_temporal_ws
# Smoke test: check the output shape of the attention-weighted temporal sum.
batch_size, T, H, W = 2, 8, 64, 64
C = 12
G = 4
input = tf.random.normal(shape=[batch_size, T, H, W, C]) # output from conv block
# Random values stand in for a real attention mask; only the shape is checked.
attention_mask = tf.random.normal(shape=[batch_size, G, T, H, W]) 

assert(block_wise_temporal_ws(input,attention_mask).shape == [batch_size, H, W, C])

In [12]:
def unet_model(input, n_filters, G, dropout_prob, key_dim, L):
    """
    U-Net model with an LTAE temporal attention head.

    The encoder keeps the time axis. Attention masks are computed on the
    coarsest level, upsampled for every finer level, and used to collapse the
    time axis of each skip connection before decoding.

    Arguments:
        input -- Input tensor (batch_size, T, H, W, C)
        n_filters -- Number of filters for the convolutional layers
        G -- number of attention heads
        dropout_prob -- Dropout probability
        key_dim -- dimension of encoding of the attention head
        L  -- number of layer/levels in the UNET architecture (at least 1)

    Returns:
        outputs -- list of L output tensors, ordered from the finest resolution
        (index 0) to the coarsest one (index L-1)

    Raises:
        ValueError -- if L is smaller than 1
    """
    if L < 1:
        raise ValueError(f"L must be at least 1, got {L}")

    # encoder: L-1 pooled levels followed by one un-pooled bottom level
    skip_connections = []
    features = input
    for _ in range(L-1):
        features, skip_connection = conv_pool_block(features, n_filters, dropout_prob)
        skip_connections.append(skip_connection)
    skip_connections.append(conv_block(features, n_filters, dropout_prob))

    # temporal attention, computed on the coarsest level
    attention_mask = compute_attention_masks(skip_connections[-1], G, key_dim)

    # apply temporal attention mask and conv1x1, from the coarsest level up to
    # the finest one, upsampling the mask before moving to the next finer level
    for i in range(L-1, -1, -1): #(L-1, L-2, ..., 0)
        skip_connections[i] = temporal_ws_and_conv_one_one(skip_connections[i], attention_mask, n_filters)
        if i > 0:
            attention_mask = upsample_attention(attention_mask)

    # decoder
    outputs = [skip_connections[-1]]
    previous_layer = skip_connections[-1]
    for i in range(L-1, 0, -1): #(L-1, L-2, ..., 1)
        previous_layer = up_conv_block(previous_layer, skip_connections[i-1], n_filters)
        outputs.append(previous_layer)

    # Return a real list (not a one-shot reversed iterator) so that callers can
    # index it, take its length and iterate it several times.
    return outputs[::-1]

In [13]:
# debug
# Smoke test: run the full model once on a random tensor.
n_filters, G, dropout_prob, key_dim, L = 8, 2, 0.15, 32, 4
batch_size,T,H,W,C = 2, 4, 64, 64, 32
input_shape = (batch_size,T,H,W,C)
input = tf.random.normal(shape=input_shape)

outputs = unet_model(input, n_filters, G, dropout_prob, key_dim, L)
