# Iteration 2 - TableNet Model 2.0

> **Dataset**: Client Dataset <br/>
> **Model**: TableNet Model <br/>
> **Creator**: Ryo

In [None]:
import os
import xml.etree.ElementTree as ET
import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras import layers, models, applications
from tensorflow.keras.layers import RandomRotation, RandomTranslation, RandomZoom
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.metrics import BinaryAccuracy, MeanIoU

In [None]:
# Notebook display settings for pandas output.
pd.set_option('display.max_colwidth', None)  # do not truncate long cell text
pd.set_option('display.max_columns', None)   # do not elide columns
pd.set_option('display.width', 1000)         # allow wide rows before wrapping

In [None]:
# Folder scanned for input images and root folder of the mask files.
IMAGE_FOLDER = "../data/tablenet_data/images/3. Resized"
MASK_FOLDER = "../data/tablenet_data/masking/3. Resized Masks"

# Annotation types. Each one is used as a subfolder name under MASK_FOLDER and
# as the prefix of a "<type>_mask" column in the file table.
ANNOTATION_TYPES = ["row", "column", "cell", "year", "location"]

In [None]:
def create_complete_file_table(image_folder, mask_folder, annotation_types=None):
    """Build a table pairing every image with its mask file per annotation type.

    Args:
        image_folder: Directory listed (non-recursively) for files whose name
            ends with ``.JPG``.
        mask_folder: Directory containing one subfolder per annotation type.
            Each subfolder is searched recursively for ``<image base name>.png``.
        annotation_types: Iterable of annotation type names. Defaults to the
            module-level ``ANNOTATION_TYPES``.

    Returns:
        A ``pandas.DataFrame`` with an ``image`` column and one
        ``<annotation_type>_mask`` column per annotation type. A mask entry is
        ``None`` when no matching file exists.
    """
    if annotation_types is None:
        annotation_types = ANNOTATION_TYPES
    annotation_types = list(annotation_types)

    # os.listdir returns entries in arbitrary order; sorting makes the table,
    # and therefore any seeded split of it, reproducible across machines.
    image_paths = sorted(
        os.path.join(image_folder, name)
        for name in os.listdir(image_folder)
        if name.endswith(".JPG")
    )

    def index_masks(subfolder):
        """Map base name -> first ``.png`` path found while walking *subfolder*."""
        suffix = ".png"
        found = {}
        for root, _, files in os.walk(subfolder):
            for file_name in files:
                if file_name.endswith(suffix):
                    # setdefault keeps the first hit in walk order, which is
                    # what a per-image search of the same tree would return.
                    found.setdefault(
                        file_name[: -len(suffix)], os.path.join(root, file_name)
                    )
        return found

    base_names = [
        os.path.splitext(os.path.basename(path))[0] for path in image_paths
    ]

    data = {"image": image_paths}
    for annotation_type in annotation_types:
        # Walk each mask subtree once instead of once per image.
        mask_index = index_masks(os.path.join(mask_folder, annotation_type))
        data[f"{annotation_type}_mask"] = [
            mask_index.get(base_name) for base_name in base_names
        ]

    return pd.DataFrame(data)

In [None]:
# Build the image/mask path table from the configured folders and preview the
# first rows as the cell output.
file_table = create_complete_file_table(IMAGE_FOLDER, MASK_FOLDER)
file_table.head()

In [None]:
def split_dataset(df, test_size=0.2, val_size=0.1, random_state=42):
    """Split *df* into train, validation and test frames.

    The test set is carved off first. ``val_size`` is then applied to the rows
    that remain, so it is a fraction of the train+validation pool rather than
    of the whole of ``df``. Both splits use the same ``random_state``.

    Returns:
        A ``(train_df, val_df, test_df)`` tuple.
    """
    remaining_df, test_df = train_test_split(
        df,
        test_size=test_size,
        random_state=random_state,
    )
    train_df, val_df = train_test_split(
        remaining_df,
        test_size=val_size,
        random_state=random_state,
    )
    return train_df, val_df, test_df

In [None]:
# split_dataset applies val_size to what remains after the test split, so the
# effective shares are roughly 80.75% train / 14.25% validation / 5% test.
# The percentages in the labels printed below are therefore approximate.
train_df, val_df, test_df = split_dataset(
    file_table, 
    test_size=0.05,
    val_size=0.15,
    random_state=42
)

print("Training set size (80%):", len(train_df))
print("Validation set size (15%):", len(val_df))
print("Test set size (5%):", len(test_df))

In [None]:
def extract_paths(df):
    """Pull the path arrays out of a file table.

    Returns:
        A tuple ``(image_paths, mask_paths)`` where ``image_paths`` is the
        ``image`` column's values and ``mask_paths`` maps every column whose
        name contains ``_mask`` to that column's values.
    """
    mask_columns = [name for name in df.columns if '_mask' in name]

    masks_by_column = {}
    for name in mask_columns:
        masks_by_column[name] = df[name].values

    return df['image'].values, masks_by_column

In [None]:
# Pull the image path array and the per-column mask path dict out of each split.
train_image_paths, train_mask_paths = extract_paths(train_df)
val_image_paths, val_mask_paths = extract_paths(val_df)
test_image_paths, test_mask_paths = extract_paths(test_df)

In [None]:
def replicate_data(image_paths, mask_paths, replicate_count):
    """Repeat the path arrays ``replicate_count`` times with ``np.tile``.

    Args:
        image_paths: Array-like of image paths.
        mask_paths: Dict mapping a mask name to an array-like of mask paths.
        replicate_count: Number of repetitions passed to ``np.tile``.

    Returns:
        A tuple ``(image_paths, mask_paths)`` holding the tiled image array and
        a new dict with every mask array tiled the same way.
    """
    tiled_masks = {}
    for name, paths in mask_paths.items():
        tiled_masks[name] = np.tile(paths, replicate_count)

    tiled_images = np.tile(image_paths, replicate_count)
    return tiled_images, tiled_masks

In [None]:
def preprocess(image_path, *mask_paths):
    """Load one image and any number of masks as float32 tensors scaled by 1/255.

    Args:
        image_path: Path of the image file, decoded with 3 channels.
        *mask_paths: Paths of mask files, each decoded with 1 channel.

    Returns:
        A list ``[image, mask_1, ..., mask_n]`` in the order the paths were
        given.
    """
    image = tf.io.read_file(image_path)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.cast(image, tf.float32) / 255.0

    # The decoder yields uint8. Cast before dividing, as the image branch above
    # and parse_image do; dividing an integer tensor by a Python float is not
    # accepted by TensorFlow's default dtype rules.
    masks = [
        tf.cast(
            tf.image.decode_jpeg(tf.io.read_file(mask_path), channels=1),
            tf.float32,
        ) / 255.0
        for mask_path in mask_paths
    ]

    return [image] + masks

In [None]:
def parse_image(image_path, row_mask_path, column_mask_path, cell_mask_path, year_mask_path, location_mask_path):
    """Load one training example: an image plus its five masks.

    Every file is read, decoded, cast to float32 and divided by 255. The image
    is decoded with 3 channels, each mask with 1 channel.

    Returns:
        A tuple ``(image, masks)`` where ``masks`` is a dict keyed by
        ``row_output``, ``column_output``, ``cell_output``, ``year_output``
        and ``location_output``.
    """
    def load_scaled(path, channels):
        # Shared read -> decode -> float32 -> divide-by-255 pipeline.
        decoded = tf.image.decode_jpeg(tf.io.read_file(path), channels=channels)
        return tf.cast(decoded, tf.float32) / 255.0

    image = load_scaled(image_path, 3)

    mask_sources = (
        ('row_output', row_mask_path),
        ('column_output', column_mask_path),
        ('cell_output', cell_mask_path),
        ('year_output', year_mask_path),
        ('location_output', location_mask_path),
    )
    masks = {output_name: load_scaled(path, 1) for output_name, path in mask_sources}

    return image, masks

In [None]:
def load_dataset(image_paths, mask_paths, batch_size):
    """Create a batched, prefetched ``tf.data`` pipeline from path arrays.

    Args:
        image_paths: Sequence of image paths.
        mask_paths: Dict with the keys ``row_mask``, ``column_mask``,
            ``cell_mask``, ``year_mask`` and ``location_mask``, each a sequence
            of mask paths.
        batch_size: Batch size passed to ``Dataset.batch``.

    Returns:
        A tuple ``(dataset, size)`` where ``size`` is ``len(image_paths)``.
    """
    mask_keys = ('row_mask', 'column_mask', 'cell_mask', 'year_mask', 'location_mask')
    path_columns = (image_paths,) + tuple(mask_paths[key] for key in mask_keys)

    # Each element is a 6-tuple of paths, which map() unpacks into
    # parse_image's six positional arguments.
    pipeline = tf.data.Dataset.from_tensor_slices(path_columns)
    pipeline = pipeline.map(parse_image, num_parallel_calls=tf.data.AUTOTUNE)
    pipeline = pipeline.batch(batch_size)
    pipeline = pipeline.prefetch(buffer_size=tf.data.AUTOTUNE)

    return pipeline, len(image_paths)

In [None]:
# Per-split batch sizes passed to load_dataset.
BATCH_SIZE_TRAIN = 8
BATCH_SIZE_VAL = 3
BATCH_SIZE_TEST = 3

In [None]:
# Tile the training paths so each training example appears replicate_count
# times in the dataset.
# NOTE(review): this rebinds train_image_paths / train_mask_paths, so
# re-running this cell without re-running the extract_paths cell tiles the
# already-tiled arrays again.
replicate_count = 3
train_image_paths, train_mask_paths = replicate_data(train_image_paths, train_mask_paths, replicate_count)

# Build one batched dataset per split; each call also returns the number of
# examples in that split.
train_dataset, train_size = load_dataset(
    train_image_paths,
    train_mask_paths,
    batch_size=BATCH_SIZE_TRAIN
)

val_dataset, val_size = load_dataset(
    val_image_paths,
    val_mask_paths,
    batch_size=BATCH_SIZE_VAL,
)

test_dataset, test_size = load_dataset(
    test_image_paths,
    test_mask_paths,
    batch_size=BATCH_SIZE_TEST,
)

In [None]:
# Show the training dataset object as the cell output.
train_dataset

In [None]:
def visualize_batch(batch, num_samples=5):
    """Plot images from a batch side by side with each of their masks.

    Args:
        batch: Tuple ``(images, masks)`` where ``images`` is a tensor and
            ``masks`` is a dict of tensors; all are converted with ``.numpy()``.
            Masks are indexed as ``[sample, :, :, 0]``.
        num_samples: Maximum number of samples to draw, capped at the batch
            size.

    One figure is shown per sample: the image first, then one panel per mask
    in dict order.
    """
    images, masks = batch
    image_arrays = images.numpy()
    mask_arrays = {name: tensor.numpy() for name, tensor in masks.items()}

    panel_count = len(mask_arrays) + 1
    shown = min(num_samples, image_arrays.shape[0])

    for sample_idx in range(shown):
        plt.figure(figsize=(20, 5))

        # Panel 1: the input image.
        plt.subplot(1, panel_count, 1)
        plt.imshow(image_arrays[sample_idx])
        plt.title(f"Sample {sample_idx + 1}: Image")
        plt.axis('off')

        # Panels 2..N: one grayscale panel per mask.
        for panel, (mask_name, mask_array) in enumerate(mask_arrays.items(), start=2):
            plt.subplot(1, panel_count, panel)
            plt.imshow(mask_array[sample_idx, :, :, 0], cmap='gray')
            plt.title(f"Sample {sample_idx + 1}: {mask_name.replace('_', ' ').capitalize()}")
            plt.axis('off')

        plt.show()

In [None]:
# Plot up to three samples from the first training batch.
for batch in train_dataset.take(1):
    visualize_batch(batch, num_samples=3)

In [None]:
def TableNet(input_shape=(960, 1280, 3)):
    """Build a multi-head segmentation model on a frozen VGG19 encoder.

    One shared VGG19 encoder (ImageNet weights, all layers frozen) feeds five
    independent decoders. Each decoder starts from ``block5_conv4``, applies
    four decoder blocks that upsample by 2 and concatenate an encoder skip
    connection, and ends in a 1-channel sigmoid ``Conv2D``.

    Args:
        input_shape: ``(height, width, channels)`` of the input images.

    Returns:
        A Keras ``Model`` whose outputs are, in order, ``row_output``,
        ``column_output``, ``cell_output``, ``year_output`` and
        ``location_output``.
    """
    inputs = layers.Input(shape=input_shape)

    # VGG19 as encoder
    # NOTE(review): the data pipeline in this file scales images to [0, 1];
    # the ImageNet VGG19 weights were presumably trained with
    # applications.vgg19.preprocess_input -- confirm the mismatch is intended.
    vgg19 = applications.VGG19(include_top=False, weights='imagenet', input_tensor=inputs)

    # Freeze the encoder layers
    for layer in vgg19.layers:
        layer.trainable = False

    bottleneck = vgg19.get_layer("block5_conv4").output

    # (skip connection, decoder filters), ordered from deepest to shallowest.
    skip_stages = [
        (vgg19.get_layer("block4_conv4").output, 256),
        (vgg19.get_layer("block3_conv4").output, 128),
        (vgg19.get_layer("block2_conv2").output, 64),
        (vgg19.get_layer("block1_conv2").output, 32),
    ]

    def decoder_block(x, skip_features, filters):
        """Upsample *x* by 2, merge the skip features, then refine twice."""
        x = layers.Conv2DTranspose(filters, (2, 2), strides=2, padding='same')(x)
        x = layers.Concatenate()([x, skip_features])
        x = layers.SeparableConv2D(filters, (3, 3), padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.LeakyReLU(negative_slope=0.1)(x)
        x = layers.SeparableConv2D(filters, (3, 3), padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.LeakyReLU(negative_slope=0.1)(x)
        x = layers.Dropout(0.2)(x)
        return x

    def segmentation_head(name):
        """Build one full decoder from the bottleneck and its sigmoid output."""
        x = bottleneck
        for skip_features, filters in skip_stages:
            x = decoder_block(x, skip_features, filters)
        return layers.Conv2D(1, (1, 1), activation='sigmoid', name=name)(x)

    # Heads are built one after another in this order, so layers are created
    # (and auto-named) in the same sequence as when each decoder is written
    # out by hand.
    head_names = [
        'row_output',
        'column_output',
        'cell_output',
        'year_output',
        'location_output',
    ]
    outputs = [segmentation_head(name) for name in head_names]

    # Define the model
    model = models.Model(inputs=inputs, outputs=outputs)

    return model

In [None]:
# Model input resolution; the channel count (3) is appended to form the
# input shape passed to TableNet.
TARGET_SIZE = (960, 1280) # Height x Width
input_shape = (TARGET_SIZE[0], TARGET_SIZE[1], 3)

# Instantiate the multi-output model.
model = TableNet(input_shape=input_shape)

In [None]:
# model.summary()

In [None]:
def binary_iou(y_true, y_pred):
    """Intersection-over-union between ``y_true`` and thresholded ``y_pred``.

    Predictions are binarised at 0.5; ``y_true`` is used as given. The sums
    run over every element of the tensors passed in, so the result is a single
    scalar for the whole call. Keras' epsilon is added to the denominator to
    avoid dividing by zero.
    """
    hard_pred = tf.cast(y_pred > 0.5, tf.float32)

    overlap = tf.reduce_sum(y_true * hard_pred)
    true_area = tf.reduce_sum(y_true)
    pred_area = tf.reduce_sum(hard_pred)
    combined = true_area + pred_area - overlap

    return overlap / (combined + tf.keras.backend.epsilon())

In [None]:
# Every segmentation head is trained with the same loss and tracked with the
# same pair of metrics, so the per-head dicts are built from one list of names.
OUTPUT_HEADS = [
    'row_output',
    'column_output',
    'cell_output',
    'year_output',
    'location_output',
]

model.compile(
    optimizer=Adam(learning_rate=1e-4),
    loss={head: 'binary_crossentropy' for head in OUTPUT_HEADS},
    # A fresh BinaryAccuracy instance is created for each head.
    metrics={
        head: [BinaryAccuracy(name='accuracy'), binary_iou]
        for head in OUTPUT_HEADS
    },
)

In [None]:
# callbacks = [
#     EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True),
#     ModelCheckpoint(filepath='best_model.keras', save_best_only=True, monitor='val_loss'),
#     ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, min_lr=1e-6)
# ]

In [None]:
# Train for a fixed 20 epochs, validating on val_dataset after each epoch.
# The callbacks argument is commented out, so no early stopping,
# checkpointing or learning-rate scheduling is active for this run.
history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=20,
    # callbacks=callbacks,
    verbose=1
)

In [None]:
# Persist the trained model to disk.
# NOTE(review): the `.h5` extension selects Keras' legacy HDF5 format, while
# the commented-out checkpoint in this notebook uses `.keras` -- confirm which
# format downstream loaders expect before changing the path.
model.save('final_tablenet_model.h5')