In [None]:
import tensorflow as tf
import pathlib
import matplotlib.pyplot as plt

from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Dense, Input, GlobalAveragePooling2D, Conv2D, MaxPooling2D, Layer, BatchNormalization, ReLU, Add
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers.schedules import ExponentialDecay
from tensorflow.keras.applications import EfficientNetB1
from tensorflow.data import Dataset, TextLineDataset

# Show the GPU devices TensorFlow can see (displayed as the cell output);
# an empty list means no GPU was detected.
tf.config.list_physical_devices('GPU')

In [None]:
class BottleneckBlock(Layer):
    """ResNet-style bottleneck residual block.

    Main path: 1x1 conv -> BN -> ReLU -> 3x3 conv -> BN -> ReLU ->
    1x1 conv (``filters * 4`` channels) -> BN. The result is added to the
    skip path and passed through a final ReLU.

    Args:
        filters: Number of filters of the first two convolutions; the block
            outputs ``filters * 4`` channels.
        skip_connection: How the input is matched to the main-path output
            before the addition. One of ``'identity'`` (input unchanged),
            ``'projection'`` (1x1 convolution) or ``'padding'`` (zero-pad the
            channel axis). ``'identity'`` and ``'padding'`` do not change the
            spatial size, so they only line up with ``downsample=False``.
        downsample: If true, the first convolution (and the projection
            shortcut, when used) has stride 2.
        **kwargs: Forwarded to ``Layer.__init__``.

    Raises:
        ValueError: If ``skip_connection`` is not one of the accepted values.
    """

    SKIP_CONNECTIONS = ('identity', 'projection', 'padding')

    def __init__(self, filters, skip_connection, downsample=False, **kwargs):
        super().__init__(**kwargs)
        # Keep the raw constructor arguments: get_config() must return these
        # (not the sub-layers) so the block can be rebuilt by from_config().
        self.filters = filters
        self.skip_connection_type = skip_connection
        self.downsample = downsample

        strides = 2 if downsample else 1
        self.conv1 = Conv2D(filters, 1, padding='same', strides=strides)
        self.conv2 = Conv2D(filters, 3, padding='same')
        self.conv3 = Conv2D(filters*4, 1, padding='same')
        self.batch_norm1 = BatchNormalization(momentum=0.9)
        self.batch_norm2 = BatchNormalization(momentum=0.9)
        self.batch_norm3 = BatchNormalization(momentum=0.9)
        self.relu = ReLU()
        self.add = Add()
        self.skip_connection = self.build_skip_connection(skip_connection, filters*4, strides=strides)

    def build_skip_connection(self, skip_connection, filters, strides=2):
        """Return a callable that maps the block input onto the skip path.

        Args:
            skip_connection: ``'identity'``, ``'projection'`` or ``'padding'``.
            filters: Number of channels the skip path must produce.
            strides: Stride of the projection convolution; it has to match
                the stride of the main path or the addition will fail.

        Raises:
            ValueError: If ``skip_connection`` is not an accepted value.
        """
        if skip_connection not in self.SKIP_CONNECTIONS:
            raise ValueError('skip_connection must be either identity, projection or padding')

        if skip_connection == 'identity':
            return lambda x: x

        if skip_connection == 'projection':
            return Conv2D(filters, 1, strides=strides)

        def pad_channels(x):
            # Zero-pad the channel axis so both operands of the addition have
            # the same number of channels. Padding can only add channels, so
            # an input that is already wider than the target is an error.
            missing = filters - x.shape[-1]
            if missing < 0:
                raise ValueError(
                    f'padding skip connection cannot reduce channels: '
                    f'input has {x.shape[-1]}, block outputs {filters}'
                )
            return tf.pad(x, paddings=[[0, 0], [0, 0], [0, 0], [0, missing]])

        return pad_channels

    def call(self, x, training=False):
        """Run the block; ``training`` is forwarded to the batch-norm layers."""
        # Compute the shortcut from the untouched input before x is reassigned.
        skip_outputs = self.skip_connection(x)

        x = self.conv1(x)
        x = self.batch_norm1(x, training=training)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.batch_norm2(x, training=training)
        x = self.relu(x)
        x = self.conv3(x)
        x = self.batch_norm3(x, training=training)
        x = self.add([x, skip_outputs])
        x = self.relu(x)

        return x

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update({
            'filters': self.filters,
            'skip_connection': self.skip_connection_type,
            'downsample': self.downsample,
        })

        return config

In [None]:
def load_resnet50(input_shape):
    """Assemble the ResNet-50-style backbone as a ``Sequential`` model.

    The network is a stem (7x7 convolution with 64 filters, stride 2 and
    ReLU, followed by max pooling with stride 2) and four stages of
    ``BottleneckBlock`` layers with 3, 4, 6 and 3 blocks respectively.
    No pooling or dense head is attached.

    Args:
        input_shape: Shape of one input sample, passed to ``Input``.

    Returns:
        The assembled ``Sequential`` model.
    """
    # (filters, number of blocks) for each stage, in order.
    stage_specs = ((64, 3), (128, 4), (256, 6), (512, 3))

    backbone = Sequential()
    backbone.add(Input(input_shape))

    # Stem
    backbone.add(Conv2D(64, 7, strides=2, padding='same', activation='relu'))
    backbone.add(MaxPooling2D(strides=2))

    for stage_index, (filters, block_count) in enumerate(stage_specs):
        if stage_index == 0:
            # The first stage keeps the spatial size and only widens the
            # shortcut by zero-padding its channels.
            opening_block = BottleneckBlock(filters=filters, skip_connection='padding')
        else:
            # Every later stage opens with a downsampling block whose
            # shortcut is a projection.
            opening_block = BottleneckBlock(filters=filters, skip_connection='projection', downsample=True)
        backbone.add(opening_block)

        # The remaining blocks of the stage use identity shortcuts.
        for _ in range(block_count - 1):
            backbone.add(BottleneckBlock(filters=filters, skip_connection='identity'))

    return backbone

In [None]:
class C:
    """Notebook-wide constants."""

    # Samples per batch; DataManager.load_dataset passes this to Dataset.batch.
    BATCH_SIZE: int = 128

In [None]:
class DataManager:
    """Builds the training and validation ``tf.data`` pipelines.

    Directory layout read by this class::

        <log_dir>/<set_name>/file_names.csv   one image name per line, no extension
        <log_dir>/<set_name>/labels.csv       one comma-separated row of numbers per line
        <log_dir>/<set_name>/imgs/<name>.png  the image files

    Line *i* of ``file_names.csv`` is paired with line *i* of ``labels.csv``.
    """

    def __init__(self, log_dir):
        self.log_dir = log_dir

    def load_dataset(self, shuffle_buffer=25000):
        """Create the ``'train'`` and ``'validation'`` datasets.

        Args:
            shuffle_buffer: Size of the shuffle buffer, in samples.

        Returns:
            ``(train_ds, val_ds)``. Each yields ``(images, labels)`` batches
            of ``C.BATCH_SIZE`` samples; both are shuffled and an incomplete
            final batch is dropped.
        """
        autotune = tf.data.experimental.AUTOTUNE

        def build_ds(set_name):
            set_dir = pathlib.Path(self.log_dir, set_name)
            name_ds = TextLineDataset(str(set_dir / 'file_names.csv'))
            row_ds = TextLineDataset(str(set_dir / 'labels.csv'))

            # Pair the raw text lines first and shuffle *them*: the buffer
            # then holds short strings rather than decoded float32 images.
            ds = Dataset.zip((name_ds, row_ds))
            ds = ds.shuffle(shuffle_buffer)
            ds = ds.map(
                lambda name, row: (self.parse_img(name, set_name), self.parse_label(row)),
                num_parallel_calls=autotune,
            )
            ds = ds.batch(C.BATCH_SIZE, drop_remainder=True)
            # Overlap input preprocessing with the consumer of the dataset.
            ds = ds.prefetch(autotune)

            return ds

        train_ds = build_ds('train')
        val_ds = build_ds('validation')

        return train_ds, val_ds

    def parse_img(self, file_name, set_name):
        """Load one PNG as a float32 RGB tensor scaled to ``[0, 1]``.

        Args:
            file_name: Scalar string tensor with the image name (no extension).
            set_name: Name of the sub-directory of ``log_dir`` to read from.

        Returns:
            A float32 tensor with 3 channels and values in ``[0, 1]``.
        """
        # str() so that a pathlib.Path log_dir can be joined with tensor strings.
        img_path = tf.strings.join([str(self.log_dir), f'/{set_name}/imgs/', file_name, '.png'])

        img = tf.io.read_file(img_path)
        img = tf.io.decode_png(img, channels=3)
        img = tf.cast(img, tf.float32)
        img = img / 255.0

        return img

    def parse_label(self, label):
        """Convert one comma-separated line into a 1-D float32 tensor."""
        label = tf.strings.split(label, sep=',')
        label = tf.strings.to_number(label, out_type=tf.float32)

        return label

In [None]:
def load_model(input_shape=(100, 100, 3), num_outputs=8):
    """Build the regression model: backbone, global pooling and a dense head.

    The backbone comes from ``load_resnet50``. It is followed by global
    average pooling, two ReLU dense layers (64 and 32 units) and a sigmoid
    output layer, so every prediction lies in ``(0, 1)``. The model summary
    is printed as a side effect.

    Args:
        input_shape: Shape of one input image. The same value is used for
            the model input and for the backbone so they cannot drift apart.
        num_outputs: Number of units of the output layer.

    Returns:
        The uncompiled ``Sequential`` model.
    """
    model = Sequential()

    model.add(Input(input_shape))

    # Feature extractor. Another backbone that accepts `input_shape` and
    # returns a 4-D feature map can be swapped in here.
    model.add(load_resnet50(input_shape=input_shape))

    # Regression head
    model.add(GlobalAveragePooling2D())
    model.add(Dense(64, activation='relu'))
    model.add(Dense(32, activation='relu'))
    model.add(Dense(num_outputs, activation='sigmoid'))

    model.summary()

    return model

In [None]:
# Build the training and validation pipelines from the 'dataset_2_lines' directory.
dm = DataManager(log_dir='dataset_2_lines')
train_ds, val_ds = dm.load_dataset()

In [None]:
# Regression setup: mean-squared-error loss optimised with Adam.
# `learning_rate` is the supported keyword; the old `lr` alias is deprecated
# and rejected by current Keras releases.
model = load_model()
model.compile(optimizer=Adam(learning_rate=1e-4), loss='mse')

In [None]:
# Train on the training pipeline only; no validation data is passed,
# so the returned history contains 'loss' but no 'val_loss'.
hist = model.fit(x=train_ds, epochs=30)

In [None]:
# Loss curve on a single full-width axes.
fig, ax = plt.subplots()
ax.plot(hist.history['loss'], label='loss')
# 'val_loss' is only recorded when fit() was given validation data.
if 'val_loss' in hist.history:
    ax.plot(hist.history['val_loss'], label='val_loss')
ax.set_xlabel('epoch')
ax.set_ylabel('loss (MSE)')
ax.set_title('Training loss')
ax.legend()
plt.show()

In [None]:
# Sanity check: recompute the MSE by hand for a single training sample and
# compare it with the value reported by Keras' MeanSquaredError.
x, y = next(iter(train_ds))
x = x[:1]
y = y[:1]

preds = model.predict(x, batch_size=1)

err = y - preds
sq_err = err ** 2
sq_sum = tf.math.reduce_sum(sq_err).numpy()
# Average over the actual number of outputs rather than a hard-coded count,
# otherwise 'avg' cannot agree with 'loss' below.
n_outputs = tf.size(sq_err).numpy()

print('y          ', y.numpy())
print()
print('preds      ', preds)
print()
print('sub        ', err.numpy())
print()
print('subsquared ', sq_err.numpy())
print()
print('sum        ', sq_sum)
print()
print('avg        ', sq_sum / n_outputs)
print()
print('loss       ', tf.keras.losses.MeanSquaredError()(y, preds).numpy())