In [None]:
import matplotlib.pyplot as plt
import numpy as np
import mymodels
import tensorflow as tf

In [None]:
# Root directory that every other path in this notebook is built from.
main_folder = '../'

# Sub-directories of main_folder. Each keeps its trailing slash so a file name
# can be appended with plain string concatenation.
dataset_folder = f'{main_folder}dataset_tfrecord_small/'
dataset2_folder = f'{main_folder}dataset2/'
logs_folder = f'{main_folder}logs/'
checkpoints_folder = f'{main_folder}checkpoints_sect1/'

In [None]:
def _parse_image_function(example_proto, label_shape=(128, 10), image_shape=(128, 128, 128, 1)):
    """Decode one serialized ``tf.train.Example`` into an ``(image, label)`` pair.

    The Example must carry two bytes features, ``'image'`` and ``'label'``,
    each holding a tensor serialized with ``tf.io.serialize_tensor``.

    Args:
        example_proto: one serialized Example (string tensor), as yielded by
            ``tf.data.TFRecordDataset``.
        label_shape: static shape attached to the decoded label tensor.
            A tuple (not a list) so the shared default object is immutable.
        image_shape: static shape attached to the decoded image tensor.
            Defaults to the previously hard-coded (128, 128, 128, 1). The
            leading 128 appears to be the number of images stored per record
            (NOTE(review): confirm against the writer of the input file).

    Returns:
        Tuple ``(image, label)``: the image decoded as float32 and the label
        decoded as int32, each with the given static shape set.
    """
    # Both fields are stored as serialized tensors, i.e. scalar byte strings.
    image_feature_description = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.string),
    }

    parsed_features = tf.io.parse_single_example(example_proto, image_feature_description)

    # parse_tensor needs the dtype the tensors were serialized with.
    image = tf.io.parse_tensor(parsed_features['image'], out_type=tf.float32)
    label = tf.io.parse_tensor(parsed_features['label'], out_type=tf.int32)

    # parse_tensor yields tensors of unknown static shape; pin them so
    # downstream graph code (and the model) sees fully defined shapes.
    image.set_shape(image_shape)
    label.set_shape(label_shape)

    return image, label

In [None]:
# Each TFRecord element is decoded into an (image, label) pair by _parse_image_function.
train_dataset = tf.data.TFRecordDataset(dataset_folder + 'train.tfrecord').map(_parse_image_function)
# Held-out split. This previously re-read 'train.tfrecord' (copy-paste), so the
# "test" cropped output was just the training data again.
# NOTE(review): confirm the held-out file name inside dataset_folder.
val_dataset = tf.data.TFRecordDataset(dataset_folder + 'test.tfrecord').map(_parse_image_function)

In [None]:
# Build the section-1 model and restore its epoch-100 checkpoint; crop_images
# uses this model to predict where to crop.
model = mymodels.sect1()
weights = f'{checkpoints_folder}sect1_epoch_100.weights.h5'
model.compile()
model.load_weights(weights)

In [None]:
def _clamp(value, low, high):
    """Clamp ``value`` to [low, high]; if ``low > high``, ``low`` wins."""
    # Upper bound first, then lower bound: same order as the original if-chain.
    return max(min(value, high), low)


def crop_images(X_train_canvas, batch, crop_amount=32, base_grace=5,
                image_size=128, predictor=None):
    """Cut a fixed-size window out of each image around a model-predicted point.

    The coordinate model is run once on the whole array. Columns 0 and 1 of
    its output are read as the integer (x, y) corner of the core crop region.
    A square window of side ``crop_amount + 2 * base_grace`` is cut around that
    region. The corner is clamped so the window stays inside the
    ``image_size`` x ``image_size`` canvas.

    Args:
        X_train_canvas: array of N images. Each image must reshape to
            (image_size, image_size, 1); presumably shape (N, 128, 128, 1),
            TODO confirm.
        batch: batch index, used only in the progress message.
        crop_amount: side length, in pixels, of the core crop region.
        base_grace: margin, in pixels, added on every side of the core region.
        image_size: height and width of each (square) input image.
        predictor: callable mapping an (N, image_size, image_size, 1) array to
            an (N, >=2) array of coordinates. Defaults to the notebook-global
            ``model.predict``.

    Returns:
        np.ndarray of shape (N, window, window, 1), where
        ``window = crop_amount + 2 * base_grace``.

    Raises:
        ValueError: if an image cannot yield a full-size window, e.g. the
            images are smaller than ``image_size``.
    """
    small_data = np.asarray(X_train_canvas)
    num_images = small_data.shape[0]
    window = crop_amount + 2 * base_grace
    expected_shape = (window, window, 1)
    # Largest corner coordinate that still keeps the right/bottom margin inside the canvas.
    max_corner = image_size - crop_amount - base_grace

    if num_images == 0:
        X_cropped = np.empty((0,) + expected_shape, dtype=small_data.dtype)
        print(X_cropped.shape)
        return X_cropped

    if predictor is None:
        predictor = model.predict  # notebook-global coordinate model
    # One batched call instead of N single-image calls, each of which had its
    # own overhead and progress output. The explicit leading dimension makes a
    # wrong image size fail loudly here instead of silently misaligning coords.
    all_coords = np.asarray(
        predictor(small_data.reshape(num_images, image_size, image_size, 1))
    )

    crops = []
    for i in range(num_images):
        x = _clamp(int(all_coords[i][0]), base_grace, max_corner)
        y = _clamp(int(all_coords[i][1]), base_grace, max_corner)

        image = small_data[i]
        cropped_image = image[y - base_grace:y + crop_amount + base_grace,
                              x - base_grace:x + crop_amount + base_grace]
        if cropped_image.shape != expected_shape:
            # Previously this only printed/plotted and appended the bad crop,
            # which later broke np.array; fail here with the details instead.
            raise ValueError(
                f"Image {i} (batch {batch}): crop at x={x}, y={y} has shape "
                f"{cropped_image.shape}, expected {expected_shape}"
            )
        crops.append(cropped_image)
        print(f"\rNum: {i+1} / {num_images} Batch: {batch} ", end='')

    X_cropped = np.stack(crops)
    print(X_cropped.shape)
    return X_cropped

In [None]:
def _bytes_feature(value):
    """Wrap a bytes value in a ``tf.train.Feature`` holding a one-element BytesList.

    If ``value`` is an eager tensor (same type as ``tf.constant(0)``), its
    ``.numpy()`` payload is used instead, because BytesList cannot unpack an
    EagerTensor itself.
    """
    eager_tensor_type = type(tf.constant(0))
    payload = value.numpy() if isinstance(value, eager_tensor_type) else value
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[payload]))


def example_test(image, label):
    """Build a ``tf.train.Example`` with bytes features ``'image'`` and ``'label'``.

    Both arguments are passed through ``_bytes_feature``, so they should be
    bytes (e.g. output of ``tf.io.serialize_tensor(...).numpy()``) or eager
    tensors holding bytes.
    """
    named_payloads = (('image', image), ('label', label))
    feature_map = {name: _bytes_feature(payload) for name, payload in named_payloads}
    return tf.train.Example(features=tf.train.Features(feature=feature_map))

def write_in_batches(data, filename, coords=True):
    """Crop every batch of ``data`` and write one tf.train.Example per batch.

    Args:
        data: iterable of ``(images, labels)`` eager tensor pairs (``.numpy()``
            is called on ``images``), e.g. a mapped ``tf.data.Dataset``.
        filename: path of the TFRecord file to create.
        coords: if True, labels are serialized unchanged; otherwise they are
            flattened and given a trailing axis, i.e. shape (M, 1).

    Each record stores the ``crop_images`` output and the labels as serialized
    tensors under the keys ``'image'`` and ``'label'``. Tensor shapes are
    printed as progress output.
    """
    with tf.io.TFRecordWriter(filename) as writer:
        for batch_index, (images, labels) in enumerate(data):
            cropped = crop_images(images.numpy(), batch_index)
            image_tensor = tf.convert_to_tensor(cropped)
            print(image_tensor.shape)
            serialized_image = tf.io.serialize_tensor(image_tensor).numpy()

            if coords:
                label_tensor = tf.convert_to_tensor(labels)
            else:
                # Flatten to a column vector: (M,) -> (M, 1).
                flat_labels = tf.reshape(tf.convert_to_tensor(labels), [-1])
                label_tensor = tf.expand_dims(flat_labels, axis=-1)
            print(label_tensor.shape)
            serialized_label = tf.io.serialize_tensor(label_tensor).numpy()

            record = example_test(serialized_image, serialized_label)
            writer.write(record.SerializeToString())

In [None]:
# Crop both splits and write them to the current working directory.
for source_dataset, output_path in (
    (train_dataset, 'train_dataset_cropped.tfrecord'),
    (val_dataset, 'test_dataset_cropped.tfrecord'),
):
    write_in_batches(source_dataset, output_path)