# Cassavaのリベンジ用

In [None]:
# Install efficientnet (uncomment the line below when the package is needed)
# !pip install --quiet efficientnet

In [None]:
# ライブラリーのロード
import numpy as np
import math
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow.keras.backend as K
from kaggle_datasets import KaggleDatasets
from albumentations import (
    Compose, RandomBrightness, JpegCompression, HueSaturationValue, RandomContrast, HorizontalFlip,
    Rotate, Normalize
)
import albumentations

## TPU or GPUの設定

In [None]:
# Distribution strategy: use a TPU when one can be resolved, otherwise fall back
# to TensorFlow's default strategy.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print(f'Running on TPU {tpu.master()}')
except ValueError:
    # Presumably raised when no TPU is available to this runtime.
    tpu = None

if not tpu:
    strategy = tf.distribute.get_strategy()
else:
    # Connect to and initialise the TPU before building the strategy on it.
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)

AUTO = tf.data.experimental.AUTOTUNE  # let tf.data choose parallelism levels
REPLICAS = strategy.num_replicas_in_sync
print(f'REPLICAS: {REPLICAS}')

In [None]:
# Global hyper-parameters shared by the input pipeline.
N_CLASSES, BATCH_SIZE = 5, 16  # one-hot depth for labels; examples per batch
IMAGE_SIZE = [512] * 2  # [height, width] used for crops and the affine transform

In [None]:
# Collect the hyper-parameters into one dict so they can be looked up by name,
# then display the number of classes.
CFG = {"N_CLASSES": N_CLASSES, "IMAGE_SIZE": IMAGE_SIZE, "BATCH_SIZE": BATCH_SIZE}
CFG["N_CLASSES"]

## データのロードセクション

In [None]:
# Locate the dataset on GCS (the data must be stored as TFRecords) and list
# every *.tfrec shard under it.
GCS_PATH = KaggleDatasets().get_gcs_path('cassava-leaf-disease-tfrecords-512x512')
FILENAME_COMP = tf.io.gfile.glob(f'{GCS_PATH}/*.tfrec')

In [None]:
# Count the examples in a set of TFRecord files.
def count_data_items(filename):
    """Return the total number of examples stored in a set of TFRecord files.

    The per-file example count is read from the file name, which must contain
    ``-<digits>.`` (e.g. ``ld_train00-1338.tfrec`` -> 1338).

    Args:
        filename: an iterable of TFRecord file paths, or a single path string.

    Returns:
        The summed example count as an ``int`` (0 for an empty iterable).

    Raises:
        ValueError: if a file name carries no ``-<digits>.`` count.
    """
    import re

    count_pattern = re.compile(r'-([0-9]+)\.')
    paths = [filename] if isinstance(filename, str) else filename
    total = 0
    for path in paths:
        # Search only the base name so hyphens in the bucket/directory part cannot match.
        match = count_pattern.search(path.rsplit('/', 1)[-1])
        if match is None:
            raise ValueError(f'cannot read an example count from file name: {path!r}')
        total += int(match.group(1))
    return total

In [None]:
# Matrix for the affine transform.
# Adapted from https://www.kaggle.com/cdeotte/rotation-augmentation-gpu-tpu-0-96
def get_mat(rotation, shear, height_zoom, width_zoom, height_shift, width_shift):
    """Build the 3x3 matrix that maps destination pixel indices onto source indices.

    Every argument is a float32 tensor of shape [1] (they are concatenated with
    shape-[1] constants and reshaped into 3x3 matrices).

    Args:
        rotation: rotation angle in degrees.
        shear: shear angle in degrees.
        height_zoom, width_zoom: zoom factors; the matrix uses their reciprocals.
        height_shift, width_shift: translations placed in the third column.

    Returns:
        A [3, 3] tensor equal to rotation @ shear @ zoom @ shift.
    """
    def as_3x3(*entries):
        # Concatenate nine shape-[1] tensors (row-major) into a 3x3 matrix.
        return tf.reshape(tf.concat(list(entries), axis=0), [3, 3])

    one = tf.constant([1], dtype='float32')
    zero = tf.constant([0], dtype='float32')

    # Degrees -> radians.
    theta = math.pi * rotation / 180.
    phi = math.pi * shear / 180.

    cos_t, sin_t = tf.math.cos(theta), tf.math.sin(theta)
    rotation_matrix = as_3x3(cos_t, sin_t, zero,
                             -sin_t, cos_t, zero,
                             zero, zero, one)

    cos_p, sin_p = tf.math.cos(phi), tf.math.sin(phi)
    shear_matrix = as_3x3(one, sin_p, zero,
                          zero, cos_p, zero,
                          zero, zero, one)

    zoom_matrix = as_3x3(one / height_zoom, zero, zero,
                         zero, one / width_zoom, zero,
                         zero, zero, one)

    shift_matrix = as_3x3(one, zero, height_shift,
                          zero, one, width_shift,
                          zero, zero, one)

    return K.dot(K.dot(rotation_matrix, shear_matrix), K.dot(zoom_matrix, shift_matrix))

In [None]:
# Random affine transform (rotation / zoom / shift).
# Adapted from https://www.kaggle.com/cdeotte/rotation-augmentation-gpu-tpu-0-96
def affine_transform(image):
    """Apply a random rotation, zoom and shift to one square image.

    The parameter ranges appear intended to mirror albumentations'
    ShiftScaleRotate (rotate_limit=45, shift_limit=0.0625 as a fraction of the
    image size); shear is disabled.

    Args:
        image: a single image tensor of shape [DIM, DIM, 3] (not a batch),
            where DIM == IMAGE_SIZE[0].

    Returns:
        The transformed image reshaped to [DIM, DIM, 3]. Source coordinates that
        fall outside the image are clamped, so border pixels are repeated.
    """
    DIM = IMAGE_SIZE[0]
    XDIM = DIM % 2  # adjusts the lower clamp bound for odd sizes (e.g. 331)

    # Random transform parameters, each of shape [1].
    rot = 45. * tf.random.uniform([1], -1.0, 1.0, dtype='float32')  # degrees
    shr = tf.zeros([1], dtype='float32')  # shear disabled
    h_zoom = 1.0 + tf.random.normal([1], dtype='float32') / 10.
    w_zoom = 1.0 + tf.random.normal([1], dtype='float32') / 10.
    # The shift is added to pixel coordinates, so the fractional limit (0.0625 of
    # the image size) must be scaled by DIM. Without that it is below one pixel
    # and has essentially no effect after the int32 cast below.
    h_shift = 0.0625 * DIM * tf.random.uniform([1], -1.0, 1.0, dtype='float32')
    w_shift = 0.0625 * DIM * tf.random.uniform([1], -1.0, 1.0, dtype='float32')

    m = get_mat(rot, shr, h_zoom, w_zoom, h_shift, w_shift)

    # Destination pixel coordinates centred on the image, in homogeneous form (z = 1).
    x = tf.repeat(tf.range(DIM // 2, -DIM // 2, -1), DIM)
    y = tf.tile(tf.range(-DIM // 2, DIM // 2), [DIM])
    z = tf.ones([DIM * DIM], dtype='int32')
    idx = tf.stack([x, y, z])

    # Map destination pixels back onto source pixels, then clamp to the border.
    idx2 = K.dot(m, tf.cast(idx, dtype='float32'))
    idx2 = K.cast(idx2, dtype='int32')
    idx2 = K.clip(idx2, -DIM // 2 + XDIM + 1, DIM // 2)

    # Convert centred coordinates back to array indices and gather the pixels.
    idx3 = tf.stack([DIM // 2 - idx2[0, ], DIM // 2 - 1 + idx2[1, ]])
    d = tf.gather_nd(image, tf.transpose(idx3))

    return tf.reshape(d, [DIM, DIM, 3])

In [None]:
def data_augment(image, label, transpose_arg_p=0, random_crop_arg_p=0,
                 horizontal_flip_arg_p=0.5, vertical_flip_arg_p=0.5,
                 affine_transform_arg_p=0.5):
    """Randomly augment one (image, label) pair; the label is passed through unchanged.

    Each augmentation has its own uniform draw in [0, 1) and is applied when the
    draw is below its ``*_arg_p`` probability. A probability of 0 therefore never
    fires and 1 always fires. The function is meant for ``tf.data.Dataset.map``,
    where AutoGraph turns the tensor-valued ``if`` statements into graph conditionals.

    Args:
        image: a single image tensor of shape [H, W, 3].
        label: returned untouched.
        transpose_arg_p: probability of swapping the height and width axes.
        random_crop_arg_p: probability of a random crop to IMAGE_SIZE.
        horizontal_flip_arg_p: probability of a left-right mirror.
        vertical_flip_arg_p: probability of an up-down mirror.
        affine_transform_arg_p: probability of applying ``affine_transform``.

    Returns:
        The (possibly augmented) image and the unchanged label.
    """
    p_random_crop = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_transpose = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_vertical_flip = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_horizontal_flip = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_affine_transform = tf.random.uniform([], 0, 1.0, dtype=tf.float32)

    if p_random_crop < random_crop_arg_p:
        image = tf.image.random_crop(image, size=[IMAGE_SIZE[0], IMAGE_SIZE[1], 3])

    if p_transpose < transpose_arg_p:
        image = tf.transpose(image, perm=[1, 0, 2])

    # Horizontal flip (left <-> right), driven by its own draw and probability.
    if p_horizontal_flip < horizontal_flip_arg_p:
        image = tf.image.flip_left_right(image)

    # Vertical flip (top <-> bottom).
    if p_vertical_flip < vertical_flip_arg_p:
        image = tf.image.flip_up_down(image)

    if p_affine_transform < affine_transform_arg_p:
        image = affine_transform(image)

    return image, label

In [None]:
# Albumentations pipeline applied per image through tf.numpy_function.
# Only Normalize() is active, presumably with albumentations' default
# mean/std/max_pixel_value (confirm against the installed version).
# Transforms tried earlier and currently disabled: Cutout, RandomBrightness.
alb_transforms = Compose([Normalize()])
def aug_fn(image):
    """Run ``alb_transforms`` on one numpy image and return the transformed array.

    Called through ``tf.numpy_function``, so ``image`` arrives as a numpy array.
    """
    return alb_transforms(image=image)["image"]

def alb_process_data(image, label):
    """Apply the albumentations pipeline (``aug_fn``) inside a tf.data graph.

    ``tf.numpy_function`` returns a tensor with no static shape information, which
    breaks shape-dependent consumers such as Keras input layers. The static shape
    is therefore restored to [IMAGE_SIZE[0], IMAGE_SIZE[1], 3]. This assumes every
    image is IMAGE_SIZE, as the 512x512 TFRecord dataset suggests (confirm if the
    source changes).

    Args:
        image: float32 image tensor.
        label: returned unchanged.

    Returns:
        (transformed float32 image with static shape set, label).
    """
    aug_img = tf.numpy_function(func=aug_fn, inp=[image], Tout=tf.float32)
    aug_img.set_shape([IMAGE_SIZE[0], IMAGE_SIZE[1], 3])
    return aug_img, label

In [None]:
def decode_image(image_data):
    """Decode JPEG-encoded bytes into a 3-channel uint8 image tensor."""
    return tf.image.decode_jpeg(image_data, channels=3)

# Note: the important part here is the conversion to float.
def scale_image(image, label):
    """Cast ``image`` to float32 and divide by 255 (uint8 input ends up in [0, 1]).

    ``label`` is returned unchanged.
    """
    return tf.cast(image, tf.float32) / 255.0, label

def image_to_float(image, label):
    """Cast ``image`` to float32 without rescaling; ``label`` is returned unchanged."""
    return tf.cast(image, tf.float32), label


def read_tfrecord(example, labeled=True):
    """Parse one serialized TFRecord example into (image, label_or_name).

    Args:
        example: a serialized example containing an 'image' JPEG string, plus
            'target' (int64) when ``labeled`` or 'image_name' (string) otherwise.
        labeled: choose which of the two feature layouts to parse.

    Returns:
        The decoded image (via ``decode_image``), and either a one-hot label of
        depth N_CLASSES (labeled) or the image-name string tensor (unlabeled).
    """
    feature_spec = {'image': tf.io.FixedLenFeature([], tf.string)}
    if labeled:
        feature_spec['target'] = tf.io.FixedLenFeature([], tf.int64)
    else:
        feature_spec['image_name'] = tf.io.FixedLenFeature([], tf.string)

    parsed = tf.io.parse_single_example(example, feature_spec)
    image = decode_image(parsed['image'])

    if not labeled:
        return image, parsed['image_name']

    # One-hot encode so the labels work with a "categorical_crossentropy" loss.
    target = tf.cast(parsed['target'], tf.int32)
    return image, tf.one_hot(target, N_CLASSES)


def get_dataset(FILENAMES, labeled=True, ordered=False, cached=False, augment=False):
    """Build a batched, prefetched tf.data pipeline from TFRecord files.

    Args:
        FILENAMES: TFRecord file path(s) or glob pattern(s).
        labeled: parse (image, one-hot label) if True, else (image, image_name).
        ordered: if True, read the files in the given order. If False, use
            ``Dataset.list_files`` (shuffled file order) with non-deterministic
            interleaving.
        cached: if True, cache the raw serialized records after the first pass.
            This avoids re-reading the files each epoch while keeping memory
            lower than caching decoded images. Augmentation still runs fresh
            every epoch.
        augment: if True, apply ``data_augment`` and then the albumentations
            pipeline (``alb_process_data``); otherwise only scale pixels by 1/255.

    Returns:
        A ``tf.data.Dataset`` of batches of size BATCH_SIZE.
    """
    options = tf.data.Options()

    if ordered:
        dataset = tf.data.TFRecordDataset(FILENAMES, num_parallel_reads=AUTO)
    else:
        # Allow elements to be produced out of order.
        options.experimental_deterministic = False
        dataset = tf.data.Dataset.list_files(FILENAMES)
        dataset = dataset.interleave(tf.data.TFRecordDataset, num_parallel_calls=AUTO)

    dataset = dataset.with_options(options)

    if cached:
        dataset = dataset.cache()

    # Parse and decode the records.
    dataset = dataset.map(lambda x: read_tfrecord(x, labeled=labeled), num_parallel_calls=AUTO)

    if augment:
        dataset = dataset.map(data_augment, num_parallel_calls=AUTO)
        # Convert to float (without rescaling) before the albumentations step.
        dataset = dataset.map(image_to_float, num_parallel_calls=AUTO)
        dataset = dataset.map(alb_process_data, num_parallel_calls=AUTO)
    else:
        # NOTE(review): this path only divides by 255, while the augment path ends in
        # albumentations Normalize(). Train and eval inputs are therefore scaled
        # differently; confirm this is intended before using it for validation.
        dataset = dataset.map(scale_image, num_parallel_calls=AUTO)

    dataset = dataset.batch(BATCH_SIZE)
    # Overlap input preprocessing with consumption of the previous batch.
    return dataset.prefetch(AUTO)

## データセット表示用

In [None]:
# Preview the first non-augmented example (pixels scaled by 1/255).
train_dataset = get_dataset(FILENAME_COMP, ordered=True)
dataset = train_dataset.unbatch()
for img, label in dataset.take(1):
    plt.imshow(img.numpy())

In [None]:
# Preview the first augmented example. After albumentations Normalize() the
# values can fall outside [0, 1], and imshow clips out-of-range float RGB data,
# so min-max rescale a copy of the image for display only.
train_dataset = get_dataset(FILENAME_COMP, ordered=True, augment=True)
dataset = train_dataset.unbatch()
for img, label in dataset.take(1):
    shown = img.numpy()
    value_range = shown.max() - shown.min()
    if value_range > 0:
        shown = (shown - shown.min()) / value_range
    plt.imshow(shown)