<a href="https://colab.research.google.com/github/ProtossDragoon/Deep-Learning-with-Python/blob/main/Segmentation_Augmentation_Layer_With_TPU.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Segmentation Augmentation Layer With TPU

## Author

name : Janghoo Lee <br>
github : https://github.com/ProtossDragoon <br>
contact : dlwkdgn1@naver.com <br>
published date : November, 2021

## Third Party

- github : https://github.com/qubvel/segmentation_models

## Related Issue

- issue : https://github.com/tensorflow/tensorflow/issues/53051

## Environment

In [None]:
import tensorflow as tf
import numpy as np
import os, sys
import time

### Setup CPU/GPU/TPU

In [None]:
if 'COLAB_TPU_ADDR' in os.environ: # Check TPU
    assert 'COLAB_TPU_ADDR' in os.environ, 'Missing TPU.'
    tf_master = 'grpc://{}'.format(os.environ['COLAB_TPU_ADDR'])
    TPU_ADDRESS = tf_master
            
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(TPU_ADDRESS)
    tf.config.experimental_connect_to_cluster(resolver) # initialize the colab tpu
    tf.tpu.experimental.initialize_tpu_system(resolver) # https://colab.research.google.com/github/tensorflow/tpu/blob/master/tools/colab/keras_mnist_tpu.ipynb?hl=ko&authuser=2#scrollTo=Hd5zB1G7Y9-7

    TRAINING_PARALLEL_STRATEGY = tf.distribute.TPUStrategy(resolver) # Choose distribution strategy for parallel processing.
    tpus = tf.config.list_logical_devices('TPU')
    print(f'total {len(tpus)} of TPU devices: {tpus}')

    USE_TPU = True
    USE_GPU = False
else: 
    USE_TPU = False
    print('TPU Not found')
    TRAINING_PARALLEL_STRATEGY = None
    device_name = tf.test.gpu_device_name()
    if not device_name:
        USE_GPU = False
        print('GPU device not found.')
    else:
        USE_GPU = True
        !nvidia-smi -L
        gpus = tf.config.list_logical_devices('GPU')
        print(f'total {len(gpus)} GPU devices: {gpus}')

if USE_TPU:
    CURRENT_DEVICE = 'tpu'
elif USE_GPU:
    CURRENT_DEVICE = 'gpu'
else:
    CURRENT_DEVICE = 'cpu'

INFO:tensorflow:Deallocate tpu buffers before initializing tpu system.


INFO:tensorflow:Deallocate tpu buffers before initializing tpu system.






INFO:tensorflow:Initializing the TPU system: grpc://10.25.231.146:8470


INFO:tensorflow:Initializing the TPU system: grpc://10.25.231.146:8470


INFO:tensorflow:Finished initializing TPU system.


INFO:tensorflow:Finished initializing TPU system.


INFO:tensorflow:Found TPU system:


INFO:tensorflow:Found TPU system:


INFO:tensorflow:*** Num TPU Cores: 8


INFO:tensorflow:*** Num TPU Cores: 8


INFO:tensorflow:*** Num TPU Workers: 1


INFO:tensorflow:*** Num TPU Workers: 1


INFO:tensorflow:*** Num TPU Cores Per Worker: 8


INFO:tensorflow:*** Num TPU Cores Per Worker: 8


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)


total 8 of TPU devices: [LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:0', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:1', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:2', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:3', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:4', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:5', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:6', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:7', device_type='TPU')]


## Create dummy data for segmentation

In [None]:
def make_dummy_data(
    image_input_hw:tuple,
    mask_input_hw:tuple,
    class_n:int,
    data_n:int=50,
):
    """Build zero-filled placeholder arrays for a segmentation task.

    Args:
        image_input_hw: (height, width) of every image.
        mask_input_hw: (height, width) of every mask.
        class_n: size of the last (channel) axis of the masks.
        data_n: number of samples to create.

    Returns:
        A tuple ``(images, masks)`` of zero arrays with shapes
        ``(data_n, *image_input_hw, 3)`` and
        ``(data_n, *mask_input_hw, class_n)``.
    """
    rgb_channel_n = 3  # images always get three channels
    images = np.zeros((data_n, *image_input_hw, rgb_channel_n))
    masks = np.zeros((data_n, *mask_input_hw, class_n))
    return images, masks

In [None]:
# Module-level configuration reused by later cells: spatial size of the
# images and masks, and the number of mask channels.
image_input_hw = (224, 224)
mask_input_hw = (224, 224)
class_n = 10

# Create the placeholder arrays and print their shapes as a sanity check.
im, ma = make_dummy_data(image_input_hw, mask_input_hw, class_n=class_n)
print(f'image : {im.shape}')
print(f'mask : {ma.shape}')

image : (50, 224, 224, 3)
mask : (50, 224, 224, 10)


In [None]:
# Helper function for creating a tf.data.Dataset object
def make_tensorflow_dataset(
    batched_images,
    batched_masks,
    batch_size:int=10,
):
    """Wrap paired image/mask arrays in a shuffled, batched tf.data.Dataset.

    Args:
        batched_images: rank-4 array of images; must expose ``.shape``.
        batched_masks: rank-4 array of masks whose first three axes match
            those of ``batched_images``.
        batch_size: number of (image, mask) pairs per batch.

    Returns:
        A dataset of (images, masks) batches, shuffled with a buffer of
        100 elements; a trailing incomplete batch is dropped.

    Raises:
        AssertionError: if either shape is not rank 4 or the first three
            axes of the two shapes differ.
    """
    def assert_valid_shape(images_shape, masks_shape):
        """Check both shapes are rank 4 and agree on their first three axes."""
        assert len(images_shape) == 4
        assert len(masks_shape) == 4
        for axis in range(3):
            assert images_shape[axis] == masks_shape[axis]

    assert_valid_shape(batched_images.shape, batched_masks.shape)

    # Slice each array along its first axis, then pair the slices up.
    paired = tf.data.Dataset.zip((
        tf.data.Dataset.from_tensor_slices(batched_images),
        tf.data.Dataset.from_tensor_slices(batched_masks),
    ))
    shuffled = paired.shuffle(buffer_size=100)
    return shuffled.batch(batch_size, drop_remainder=True)

In [None]:
# Create the tf.data.Dataset object.
# This is the object mainly used when feeding data into a TensorFlow model.
tf_dataset = make_tensorflow_dataset(im, ma)

# Pull out just one element of the dataset
def get_a_batch(tf_dataset):
    """Return the first element produced by iterating ``tf_dataset``.

    Args:
        tf_dataset: any iterable; in this notebook, the batched dataset of
            (image, mask) pairs.

    Returns:
        The first element, unchanged.

    Raises:
        StopIteration: if ``tf_dataset`` yields nothing.
    """
    # The element is returned as-is; it is no longer unpacked into unused
    # locals, so elements that are not 2-tuples are accepted as well.
    return next(iter(tf_dataset))

# NOTE: this rebinds the module-level `im` / `ma` (previously the full
# numpy arrays) to a single batch taken from the dataset.
im, ma = get_a_batch(tf_dataset)
print(f'{type(im)} typed image tensor: {im.shape}')
print(f'{type(ma)} typed mask tensor: {ma.shape}')

<class 'tensorflow.python.framework.ops.EagerTensor'> typed image tensor: (10, 224, 224, 3)
<class 'tensorflow.python.framework.ops.EagerTensor'> typed mask tensor: (10, 224, 224, 10)


## Create model for segmentation

In [None]:
# Name shared by the third-party repository and the pip package.
GIT_REPO_THIRDPARTY_NAME = 'segmentation_models'
# NOTE(review): when the target directory already exists (e.g. the cell is
# re-run) git reports "destination path ... already exists" and does not clone.
!git clone https://github.com/qubvel/{GIT_REPO_THIRDPARTY_NAME}.git -q
# NOTE(review): presumably pip resolves this bare name on the package index
# rather than installing the checkout cloned above — confirm.
!pip install {GIT_REPO_THIRDPARTY_NAME} -q

fatal: destination path 'segmentation_models' already exists and is not an empty directory.


In [None]:
def auto_tpu(device='cpu'):
    """Build a decorator that times a call and, on TPU, runs it in the strategy scope.

    Args:
        device: device label. When it equals ``'tpu'`` the decorated
            function is called inside ``TRAINING_PARALLEL_STRATEGY.scope()``;
            for any other value it is called directly.

    Returns:
        A decorator. The decorated function returns whatever the original
        returns and, after each call that completes, prints the device label
        and the elapsed wall-clock time.
    """
    # Local import keeps this block self-contained.
    import functools

    def decorator(fn):
        # Preserve fn's __name__, __doc__ and other metadata on the wrapper.
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            s = time.time()
            if device == 'tpu':
                # The module-level strategy is looked up only on this branch.
                with TRAINING_PARALLEL_STRATEGY.scope():
                    ret = fn(*args, **kwargs)
            else:
                ret = fn(*args, **kwargs)
            e = time.time()
            print(f'device: {repr(device)}, time elapse: {e-s:.3} second(s)')
            return ret
        return wrapper
    return decorator

In [None]:
tf.keras.backend.clear_session()

import segmentation_models as sm
sm.set_framework('tf.keras')

@auto_tpu(device=CURRENT_DEVICE)
def create_segmentation_model(class_n):
    """Build an ``sm.Unet`` segmentation model on the 'vgg16' backbone.

    Args:
        class_n: value passed to ``sm.Unet`` as ``classes``.

    Returns:
        The model returned by ``sm.Unet``, configured with a softmax
        activation, ``encoder_weights=None`` and ``encoder_freeze=False``.
    """
    unet_options = dict(
        classes=class_n,
        activation='softmax',
        encoder_weights=None,
        encoder_freeze=False,
    )
    return sm.Unet('vgg16', **unet_options)

# Build the model once at module level and show its layer table.
model = create_segmentation_model(class_n)
model.summary()

device: 'tpu', time elapse: 2.28 second(s)
Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, None, None,  0           []                               
                                 3)]                                                              
                                                                                                  
 block1_conv1 (Conv2D)          (None, None, None,   1792        ['input_1[0][0]']                
                                64)                                                               
                                                                                                  
 block1_conv2 (Conv2D)          (None, None, None,   36928       ['block1_conv1[0][0]']           
                                64)                

In [None]:
def get_loss(class_n):
    """Pick the focal loss matching the number of classes.

    Args:
        class_n: number of classes.

    Returns:
        ``sm.losses.BinaryFocalLoss()`` when ``class_n == 1``, otherwise
        ``sm.losses.CategoricalFocalLoss()``.
    """
    loss_factory = (
        sm.losses.BinaryFocalLoss
        if class_n == 1
        else sm.losses.CategoricalFocalLoss
    )
    return loss_factory()

def get_metrics():
    """Return an ``sm.metrics.IOUScore`` created with ``threshold=0.5``."""
    iou_score = sm.metrics.IOUScore(threshold=0.5)
    return iou_score

@auto_tpu(device=CURRENT_DEVICE)
def run(model):
    """Compile ``model`` and fit it on the module-level ``tf_dataset``.

    The optimizer is 'adam'; the loss comes from ``get_loss`` (using the
    module-level ``class_n``) and the metric from ``get_metrics``.

    Args:
        model: the model to compile and fit.
    """
    model.compile(
        optimizer='adam',
        loss=get_loss(class_n),
        metrics=get_metrics(),
    )
    model.fit(tf_dataset)

# training start
run(model)

device: 'tpu', time elapse: 36.1 second(s)


- 지금까지는 문제가 없었다. 일반적인 TPU 동작 예제와 같다.
- 그런데 이제 모델의 앞단에 augmentation 파이프라인을 붙이려고 하면서 문제가 생긴다.
- Classification model 이면 별 문제가 되지 않지만, 우리는 image 와 label(mask) 모두에 augmentation 을 적용해 주어야 하기 때문이다.

## Create model for augmentation

### Functional & Sequential API for segmentation task augmentation

In [None]:
tf.keras.backend.clear_session()

@auto_tpu(device=CURRENT_DEVICE)
def create_augmentation_model(
    image_input_hw, 
    mask_input_hw, 
    class_n:int
):
    """Build a two-input, two-output functional model that augments an image/mask pair.

    Args:
        image_input_hw: (height, width) of the image input.
        mask_input_hw: (height, width) of the mask input.
        class_n: number of channels of the mask input.

    Returns:
        A ``tf.keras.Model`` taking ``[image, mask]`` and returning both
        tensors after they pass through the same ``Sequential`` of
        ``RandomFlip("horizontal")`` and ``RandomRotation(0.02)``.
    """
    _default_channel_n = 3

    # runtime augmentation pipeline
    # NOTE(review): `seq` is invoked once per input below; this code does not
    # by itself guarantee that the image and the mask receive the same random
    # flip/rotation — confirm before relying on it.
    seq = tf.keras.Sequential(
        [
            tf.keras.layers.RandomFlip("horizontal"),
            tf.keras.layers.RandomRotation(0.02),
        ],
        name='sequential_augmentation_layers'
    )

    image_input_shape = list(image_input_hw) + [_default_channel_n]
    # The mask shape is derived from mask_input_hw; previously it was built
    # from image_input_hw, which left the mask_input_hw parameter unused.
    mask_input_shape = list(mask_input_hw) + [class_n]
    x_im = tf.keras.Input(shape=image_input_shape)
    x_ma = tf.keras.Input(shape=mask_input_shape)
    return tf.keras.Model(
        inputs=[x_im, x_ma], 
        outputs=[seq(x_im), seq(x_ma)],
        name='sequential_augmentation_model'
        )

# Build the augmentation model at module level and show its layer table.
aug_model = create_augmentation_model(
    image_input_hw,
    mask_input_hw,
    class_n
)
aug_model.summary()

device: 'tpu', time elapse: 0.664 second(s)
Model: "sequential_augmentation_model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 224, 224, 3  0           []                               
                                )]                                                                
                                                                                                  
 input_2 (InputLayer)           [(None, 224, 224, 1  0           []                               
                                0)]                                                               
                                                                                                  
 sequential_augmentation_layers  (None, 224, 224, No  0          ['input_1[0][0]',                
  (Sequential)            

In [None]:
# Pass one batch through the augmentation model and print the output types
# and shapes. NOTE: this rebinds the module-level `im` / `ma`.
im, ma = aug_model(get_a_batch(tf_dataset))
print(f'{type(im)} typed image tensor: {im.shape}')
print(f'{type(ma)} typed mask tensor: {ma.shape}')

<class 'tensorflow.python.framework.ops.EagerTensor'> typed image tensor: (10, 224, 224, 3)
<class 'tensorflow.python.framework.ops.EagerTensor'> typed mask tensor: (10, 224, 224, 10)


- 우리의 segmentation model 은 model.fit(tf_dataset) 로 훈련한다.
- 이 model.fit(tf_dataset) 을 그대로 두고 사용하기 위해서는 도대체 어떻게 해야할까?
- tf_dataset 은 (image, mask) tuple 로 이루어져 있다.
- model.fit(x=image, y=mask) 로 언패킹되어 들어가는 셈이다.


In [None]:
class AugConcatedSegModel(tf.keras.Model):
    """Keras model that augments each (image, mask) batch inside ``train_step``.

    The model is built like a functional model from ``inputs``/``outputs``.
    During ``fit`` every batch is first passed through ``augmentation_model``,
    and the augmented mask is used as the target for loss and metrics.
    """

    def __init__(
        self,
        inputs=None,
        outputs=None,
        augmentation_model=None, 
        **kwargs
    ):
        """Create the model.

        Args:
            inputs: forwarded to ``tf.keras.Model.__init__``.
            outputs: forwarded to ``tf.keras.Model.__init__``.
            augmentation_model: callable taking an ``(image, mask)`` tuple
                and returning the augmented pair; ``None`` disables
                augmentation.
            **kwargs: forwarded to ``tf.keras.Model.__init__``.
        """
        super().__init__(inputs=inputs, outputs=outputs, **kwargs)
        self.augmentation_model = augmentation_model

    def train_step(self, data):
        """Run one optimisation step on an ``(image, mask)`` batch.

        Args:
            data: a 2-tuple ``(image, mask)`` as yielded by the dataset.

        Returns:
            Dict mapping each metric name to its current value.
        """
        im, ma = data
        # With the default augmentation_model=None the batch is used as-is
        # instead of failing on a call to None.
        if self.augmentation_model is not None:
            # Request training mode explicitly so the random layers do not
            # depend on how Keras would otherwise infer the training flag.
            im, ma = self.augmentation_model((im, ma), training=True)

        with tf.GradientTape() as tape:
            ma_pred = self(im, training=True)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compiled_loss(ma, ma_pred, regularization_losses=self.losses)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update metrics (includes the metric that tracks the loss)
        self.compiled_metrics.update_state(ma, ma_pred)
        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}

In [None]:
tf.keras.backend.clear_session()

def new_concatenated_model(
    image_input_hw,
    mask_input_hw,
    class_n
):
    """Build a segmentation model that applies augmentation during training.

    Args:
        image_input_hw: (height, width) passed to ``create_augmentation_model``.
        mask_input_hw: (height, width) passed to ``create_augmentation_model``.
        class_n: number of classes, passed to both model factories.

    Returns:
        An ``AugConcatedSegModel`` named 'seg_model_train_with_aug' that
        wraps a fresh segmentation model and carries a fresh augmentation
        model.
    """
    seg_model = create_segmentation_model(class_n)
    aug_model = create_augmentation_model(
        image_input_hw, mask_input_hw, class_n)

    # The unused locals `_default_channel_n` and `image_input_shape` were
    # removed: the wrapper reuses the segmentation model's own input tensor.
    @auto_tpu(device=CURRENT_DEVICE)
    def create():
        im = seg_model.input
        model = AugConcatedSegModel(
            inputs=im,
            outputs=seg_model(im),
            augmentation_model=aug_model,
            name='seg_model_train_with_aug'
        )
        return model

    model = create()
    return model

# Build the combined model at module level and show its layer table.
new_seg_model = new_concatenated_model(
    image_input_hw,
    mask_input_hw,
    class_n
)
new_seg_model.summary()

device: 'tpu', time elapse: 2.3 second(s)
device: 'tpu', time elapse: 0.645 second(s)
device: 'tpu', time elapse: 0.5 second(s)
Model: "seg_model_train_with_aug"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, None, None, 3)]   0         
                                                                 
 model (Functional)          (None, None, None, 10)    23753578  
                                                                 
 sequential_augmentation_mod  [(None, 224, 224, 3),    0         
 el (Functional)              (None, 224, 224, 10)]              
                                                                 
Total params: 23,753,578
Trainable params: 23,749,546
Non-trainable params: 4,032
_________________________________________________________________


In [None]:
# start training
# NOTE: the recorded output of this cell is an InvalidArgumentError.
run(new_seg_model)

InvalidArgumentError: ignored

In [None]:
# start training
# Retry with soft device placement enabled.
# NOTE: the recorded output of this cell is still an InvalidArgumentError.
tf.config.set_soft_device_placement(True)
run(new_seg_model)

InvalidArgumentError: ignored



- 하지만 위와 같은 형태로 돌리게 된다면, image 는 좌우반전을 시켰지만, mask 는 좌우반전을 시키지 않는 상황이 나타나게 된다. 
- 따라서, subclassing api 를 통해 augmentation 을 위한 모델을 만들고 상태를 제대로 관리해줄 수 있도록 해야 한다.

## Debug area

In [None]:
#FIXME : experimenting. Is this really the TPU's fault..? Or is it because of the Random layers?
tf.keras.backend.clear_session()

@auto_tpu(device=CURRENT_DEVICE)
def create_augmentation_model(
    image_input_hw, 
    mask_input_hw, 
    class_n:int
):
    """Build a debug stand-in for the augmentation model using Conv2D layers.

    The model has the same two-input / two-output layout as the augmentation
    model, but the random layers are replaced by two stacked 3x3 'same'
    convolutions per branch that keep the channel count unchanged.

    Args:
        image_input_hw: (height, width) of the image input.
        mask_input_hw: (height, width) of the mask input.
        class_n: number of channels of the mask input.

    Returns:
        A ``tf.keras.Model`` taking ``[image, mask]`` and returning the
        outputs of the image branch and the mask branch.
    """
    _default_channel_n = 3

    # the runtime augmentation pipeline I want to add
    im_seq = tf.keras.Sequential(
        [
            tf.keras.layers.Conv2D(_default_channel_n, (3,3), padding='same'),
            tf.keras.layers.Conv2D(_default_channel_n, (3,3), padding='same'),
        ],
        name='sequential_image_augmentation_layers_debug'
    )
    ma_seq = tf.keras.Sequential(
        [
            tf.keras.layers.Conv2D(class_n, (3,3), padding='same'),
            tf.keras.layers.Conv2D(class_n, (3,3), padding='same'),
        ],
        name='sequential_mask_augmentation_layers_debug'
    )

    image_input_shape = list(image_input_hw) + [_default_channel_n]
    # The mask shape is derived from mask_input_hw; previously it was built
    # from image_input_hw, which left the mask_input_hw parameter unused.
    mask_input_shape = list(mask_input_hw) + [class_n]
    x_im = tf.keras.Input(shape=image_input_shape)
    x_ma = tf.keras.Input(shape=mask_input_shape)
    return tf.keras.Model(
        inputs=[x_im, x_ma], 
        outputs=[im_seq(x_im), ma_seq(x_ma)],
        name='sequential_augmentation_model_debug'
        )

# NOTE: rebinds the module-level `aug_model` to the debug variant.
aug_model = create_augmentation_model(
    image_input_hw,
    mask_input_hw,
    class_n
)
aug_model.summary()

In [None]:
tf.keras.backend.clear_session()

def test_time(fn, device='cpu'):
    """Time two consecutive calls of ``fn`` on a freshly built VGG16 backbone.

    The backbone is created inside ``TRAINING_PARALLEL_STRATEGY.scope()``
    when ``device == 'tpu'`` and directly otherwise. ``fn`` is first called
    with the module-level ``_d`` (reported as "caching...") and then with the
    module-level ``d``; the elapsed time of each call is printed.

    Args:
        fn: callable invoked as ``fn(model, data, device=device)``.
        device: device label.
    """
    def model_creation():
        # VGG16 without the classification head and without preset weights.
        return tf.keras.applications.VGG16(
            input_shape=(224, 224, 3),
            include_top=False,
            weights=None,
        )

    if device != 'tpu':
        backbone_model = model_creation()
    else:
        with TRAINING_PARALLEL_STRATEGY.scope():
            backbone_model = model_creation()

    # First call, on `_d`.
    first_start = time.time()
    fn(backbone_model, _d, device=device)
    first_elapsed = time.time() - first_start
    print(f'device:{repr(device)}, caching... time elapsed:{first_elapsed}')

    # Second call, on `d`.
    second_start = time.time()
    fn(backbone_model, d, device=device)
    second_elapsed = time.time() - second_start
    print(f'device:{repr(device)}, time elapsed:{second_elapsed}')

# Inputs for the two timed calls: an all-zeros batch and an all-ones batch.
_d = np.zeros([10, 224, 224, 3], dtype=np.uint8)
d = np.ones([10, 224, 224, 3], dtype=np.uint8)

# NOTE: this rebinds the module-level name `run`, replacing the training
# helper of the same name defined in an earlier cell.
def run(model, image_data, device):
    """Call ``model`` on ``image_data``, inside the strategy scope on TPU.

    Args:
        model: callable model.
        image_data: input passed to ``model``.
        device: device label; ``'tpu'`` selects the strategy scope.
    """
    if device != 'tpu':
        model(image_data)
        return
    with TRAINING_PARALLEL_STRATEGY.scope():
        model(image_data)

test_time(run, device=CURRENT_DEVICE)