In [1]:
import os, pickle
import tensorflow as tf
import tensorflow.keras.backend as K
from datetime import datetime
import numpy as np
from sklearn.model_selection import KFold
import src.training as tr_fn
from loguru import logger
from config import CFG, GCFG

AUTO = tf.data.experimental.AUTOTUNE
CFG2 = GCFG()
class_dict = pickle.load(open('src/class_dict.pkl', 'rb'))

2023-10-28 02:04:45.804816: I tensorflow/compiler/xla/stream_executor/tpu/tpu_initializer_helper.cc:242] Libtpu path is: libtpu.so
D1028 02:04:45.957949690   76915 config.cc:175]                        gRPC EXPERIMENT tcp_frame_size_tuning               OFF (default:OFF)
D1028 02:04:45.957971562   76915 config.cc:175]                        gRPC EXPERIMENT tcp_rcv_lowat                       OFF (default:OFF)
D1028 02:04:45.957976735   76915 config.cc:175]                        gRPC EXPERIMENT peer_state_based_framing            OFF (default:OFF)
D1028 02:04:45.957980979   76915 config.cc:175]                        gRPC EXPERIMENT memory_pressure_controller          OFF (default:OFF)
D1028 02:04:45.957985248   76915 config.cc:175]                        gRPC EXPERIMENT unconstrained_max_quota_buffer_size OFF (default:OFF)
D1028 02:04:45.957989468   76915 config.cc:175]                        gRPC EXPERIMENT event_engine_client                 OFF (default:OFF)
D1028 02:04:45.95799373

In [6]:
def decode_image(image_data, CFG):
    image = tf.image.decode_jpeg(image_data, channels=3)  # image format uint8 [0,255]
    image = tf.reshape(image, [*CFG.IMAGE_SIZE, 3]) # explicit size needed for TPU
    image = tf.cast(image, tf.float32)
    return image

def read_labeled_tfrecord(CFG, example):
    feature_description = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'dataset': tf.io.FixedLenFeature([], tf.int64),
        'longitude': tf.io.FixedLenFeature([], tf.float32),
        'latitude': tf.io.FixedLenFeature([], tf.float32),
        'norm_date': tf.io.FixedLenFeature([], tf.float32),
        'class_priors': tf.io.FixedLenFeature([], tf.float32),
        'class_id': tf.io.FixedLenFeature([], tf.int64),
    }
    example = tf.io.parse_single_example(example, feature_description)
    image = decode_image(example['image'], CFG)
    label = tf.cast(example['class_id'], tf.int32)
    return image, label

def get_mat(rotation, shear, height_zoom, width_zoom, height_shift, width_shift):
    # returns 3x3 transform matrix which transforms indices

    # CONVERT DEGREES TO RADIANS
    rotation = math.pi * rotation / 180.0
    shear = math.pi * shear / 180.0

    def get_3x3_mat(lst):
        return tf.reshape(tf.concat([lst], axis=0), [3, 3])

    # ROTATION MATRIX
    c1 = tf.math.cos(rotation)
    s1 = tf.math.sin(rotation)
    one = tf.constant([1], dtype='float32')
    zero = tf.constant([0], dtype='float32')

    rotation_matrix = get_3x3_mat([c1, s1, zero, -s1, c1, zero, zero, zero, one])
    # SHEAR MATRIX
    c2 = tf.math.cos(shear)
    s2 = tf.math.sin(shear)

    shear_matrix = get_3x3_mat([one, s2, zero, zero, c2, zero, zero, zero, one])
    # ZOOM MATRIX
    zoom_matrix = get_3x3_mat([one / height_zoom, zero, zero, zero, one / width_zoom, zero, zero, zero, one])
    # SHIFT MATRIX
    shift_matrix = get_3x3_mat([one, zero, height_shift, zero, one, width_shift, zero, zero, one])

    return K.dot(K.dot(rotation_matrix, shear_matrix), K.dot(zoom_matrix, shift_matrix))


def transform(image, CFG):
    # input image - is one image of size [dim,dim,3] not a batch of [b,dim,dim,3]
    # output - image randomly rotated, sheared, zoomed, and shifted
    DIM = CFG.IMAGE_SIZE[0]
    XDIM = DIM % 2  # fix for size 331   

    rot = CFG.ROT_ * tf.random.normal([1], dtype='float32')
    shr = CFG.SHR_ * tf.random.normal([1], dtype='float32')
    h_zoom = 1.0 + tf.random.normal([1], dtype='float32') / CFG.HZOOM_
    w_zoom = 1.0 + tf.random.normal([1], dtype='float32') / CFG.WZOOM_
    h_shift = CFG.HSHIFT_ * tf.random.normal([1], dtype='float32')
    w_shift = CFG.WSHIFT_ * tf.random.normal([1], dtype='float32')

    # GET TRANSFORMATION MATRIX
    m = get_mat(rot, shr, h_zoom, w_zoom, h_shift, w_shift)

    # LIST DESTINATION PIXEL INDICES
    x = tf.repeat(tf.range(DIM // 2, -DIM // 2, -1), DIM)
    y = tf.tile(tf.range(-DIM // 2, DIM // 2), [DIM])
    z = tf.ones([DIM * DIM], dtype='int32')
    idx = tf.stack([x, y, z])

    # ROTATE DESTINATION PIXELS ONTO ORIGIN PIXELS
    idx2 = K.dot(m, tf.cast(idx, dtype='float32'))
    idx2 = K.cast(idx2, dtype='int32')
    idx2 = K.clip(idx2, -DIM // 2 + XDIM + 1, DIM // 2)

    # FIND ORIGIN PIXEL VALUES
    idx3 = tf.stack([DIM // 2 - idx2[0,], DIM // 2 - 1 + idx2[1,]])
    d = tf.gather_nd(image, tf.transpose(idx3))

    return tf.reshape(d, [DIM, DIM, 3])


def prepare_image(img, CFG, augment=True, dim=256):
    img = tf.image.decode_jpeg(img, channels=3)
    img = tf.cast(img, tf.float32) / 255.0

    if augment:
        img = transform(img, CFG)
        img = tf.image.random_flip_left_right(img)
        # img = tf.image.random_hue(img, 0.01)
        img = tf.image.random_saturation(img, 0.7, 1.3)
        img = tf.image.random_contrast(img, 0.8, 1.2)
        img = tf.image.random_brightness(img, 0.1)

    img = tf.reshape(img, [CFG.IMAGE_SIZE[0], CFG.IMAGE_SIZE[0], 3])

    return img


def get_dataset(
    files, CFG, augment=False, shuffle=False, repeat=False, labeled=True, batch_size=16, dim=256
    ):
    ds = tf.data.TFRecordDataset(files, num_parallel_reads=AUTO)
    ds = ds.cache()

    if repeat:
        ds = ds.repeat()

    if shuffle:
        ds = ds.shuffle(1024 * 8)
        opt = tf.data.Options()
        opt.experimental_deterministic = False
        ds = ds.with_options(opt)

    if labeled:
        ds = ds.map(lambda example: read_labeled_tfrecord(CFG, example), num_parallel_calls=AUTO)
    else:
        ds = ds.map(lambda example: read_unlabeled_tfrecord(example), num_parallel_calls=AUTO)

    ds = ds.map(
        lambda img, imgname_or_label: (prepare_image(
            img, CFG, augment=augment, dim=dim), imgname_or_label), num_parallel_calls=AUTO
    )

    ds = ds.batch(batch_size * CFG.REPLICAS)
    ds = ds.prefetch(AUTO)
    return ds

In [7]:
def get_history(model, fold, files_train, files_valid, CFG):
    logger.info("Training...")
    history = model.fit(
        get_dataset(files_train, CFG),
        epochs=CFG.EPOCHS,
        callbacks=tr_fn.make_callbacks(CFG),
        steps_per_epoch=CFG.STEPS_PER_EPOCH,
        validation_data=get_dataset(files_valid, CFG),  # class_weight = {0:1,1:2},
        verbose=CFG.VERBOSE,
    )
    return history

def get_gcs_path(image_size):
    GCS_PATH_SELECT = {
        192: f"{CFG2.GCS_REPO}/tfrecords-jpeg-192x192",
        224: f"{CFG2.GCS_REPO}/tfrecords-jpeg-224x224v2",
        384: f"{CFG2.GCS_REPO}/tfrecords-jpeg-384x384",
        512: f"{CFG2.GCS_REPO}/tfrecords-jpeg-512x512",
    }
    GCS_PATH = GCS_PATH_SELECT[image_size]    
    return GCS_PATH

In [4]:
cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='local')
tf.config.experimental_connect_to_cluster(cluster_resolver)
tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
strategy = tf.distribute.TPUStrategy(cluster_resolver)

CFG2.REPLICAS = strategy.num_replicas_in_sync
print("Number of accelerators: ", strategy.num_replicas_in_sync)

INFO:tensorflow:Deallocate tpu buffers before initializing tpu system.
INFO:tensorflow:Initializing the TPU system: local


2023-10-28 02:04:58.898527: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x5642a0268de0 initialized for platform TPU (this does not guarantee that XLA will be used). Devices:
2023-10-28 02:04:58.898567: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): TPU, 2a886c8
2023-10-28 02:04:58.898579: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (1): TPU, 2a886c8
2023-10-28 02:04:58.898589: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (2): TPU, 2a886c8
2023-10-28 02:04:58.898599: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (3): TPU, 2a886c8
2023-10-28 02:04:58.898609: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (4): TPU, 2a886c8
2023-10-28 02:04:58.898619: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (5): TPU, 2a886c8
2023-10-28 02:04:58.898629: I tensorflow/compiler/xla/service/service.cc:176]   StreamEx

INFO:tensorflow:Finished initializing TPU system.
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:TPU:0, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:TPU:1, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:TPU:2, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:TPU:3, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:TPU:4, TPU, 0, 0)
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO

In [8]:
GCS_PATH = get_gcs_path(CFG2.IMAGE_SIZE[0])
skf = KFold(n_splits=CFG2.FOLDS, shuffle=True, random_state=CFG2.SEED)
oof_pred = []
oof_tar = []
oof_val = []
oof_names = []
oof_folds = []

for fold, (idxT, idxV) in enumerate(skf.split(np.arange(107))):
    # DISPLAY FOLD INFO
    print("#" * 25)
    print("#### FOLD", fold + 1)

    files_train = tf.io.gfile.glob([f"{GCS_PATH}/train{x:02d}*.tfrec" for x in idxT])
    files_valid = tf.io.gfile.glob(f"{GCS_PATH}/train{x:02d}*.tfrec" for x in idxV)
    files_test = tf.io.gfile.glob(f"{GCS_PATH}/val*.tfrec")

    CFG2.NUM_TRAINING_IMAGES = tr_fn.count_data_items(files_train)
    CFG2.NUM_VALIDATION_IMAGES = tr_fn.count_data_items(files_valid)

    CFG = CFG(REPLICAS=CFG2.REPLICAS, NUM_TRAINING_IMAGES=CFG2.NUM_TRAINING_IMAGES, NUM_VALIDATION_IMAGES=CFG2.NUM_VALIDATION_IMAGES)

    logger.debug(
        f"# Image Size {CFG2.IMAGE_SIZE} with Model {CFG2.MODEL} and batch_sz {CFG2.BASE_BATCH_SIZE * CFG2.REPLICAS}"
    )

    logger.info("Build & Compile Model...")
    K.clear_session()
    with strategy.scope():
        model = tr_fn.create_model(CFG, class_dict)
        opt = tr_fn.create_optimizer(CFG)
        loss = tf.keras.losses.SparseCategoricalCrossentropy()

        top3_acc = tf.keras.metrics.SparseTopKCategoricalAccuracy(
            k=3, name='sparse_top_3_categorical_accuracy'
        )
    model.compile(optimizer=opt, loss=loss, metrics=['sparse_categorical_accuracy', top3_acc])

    logger.info("Training Model...")
    # TRAIN
    history = get_history(model, fold, files_train, files_valid, CFG)

    # PREDICT OOF USING TTA
    logger.info("Predicting OOF with TTA...")
    ds_valid = get_dataset(files_valid, CFG),
    ct_valid = tr_fn.count_data_items(files_valid)
    STEPS = CFG.TTA * ct_valid / CFG.BATCH_SIZES / 4 / CFG.REPLICAS
    pred = model.predict(ds_valid, steps=STEPS, verbose=CFG.VERBOSE)[
        : CFG.TTA * ct_valid,
    ]
    oof_pred.append(np.mean(pred.reshape((ct_valid, CFG.TTA), order="F"), axis=1))
    # oof_pred.append(model.predict(get_dataset(files_valid,dim=CFG.IMG_SIZES),verbose=1))

    # GET OOF TARGETS AND NAMES
    ds_valid = get_dataset(
        files_valid,
        CFG,
        augment=False,
        repeat=False,
        dim=CFG.IMG_SIZES,
        labeled=True,
        return_image_names=True,
    )
    oof_tar.append(
        np.array([target.numpy() for img, target in iter(ds_valid.unbatch())])
    )
    oof_folds.append(np.ones_like(oof_tar[-1], dtype="int8") * fold)
    ds = get_dataset(
        files_valid,
        CFG,
        augment=False,
        repeat=False,
        dim=CFG.IMG_SIZES,
        labeled=False,
        return_image_names=True,
    )
    oof_names.append(
        np.array(
            [
                img_name.numpy().decode("utf-8")
                for img, img_name in iter(ds.unbatch())
            ]
        )
    )

    # REPORT RESULTS
    auc = roc_auc_score(oof_tar[-1], oof_pred[-1])
    oof_val.append(np.max(history.history["val_auc"]))
    logger.info(
        f"#### FOLD {fold + 1} OOF AUC without TTA = {oof_val[-1]}, with TTA = {auc}"
    )

#########################
#### FOLD 1


TypeError: 'CFG' object is not callable