feat: Update repack to fit to 512 x 512
Kohulan committed May 30, 2023
1 parent 960ffc6 commit 960904b
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions DECIMER/Repack_model.py
@@ -5,6 +5,7 @@
import Transformer_decoder
import Efficient_Net_encoder
import config
+import utils

print(tf.__version__)

@@ -19,9 +20,9 @@
max_length = pickle.load(open("max_length_TPU_Stereo.pkl", "rb"))

# Image parameters
-IMG_EMB_DIM = (10, 10, 232)
+IMG_EMB_DIM = (16, 16, 232)
IMG_EMB_DIM = (IMG_EMB_DIM[0] * IMG_EMB_DIM[1], IMG_EMB_DIM[2])
-IMG_SHAPE = (299, 299, 3)
+IMG_SHAPE = (512, 512, 3)
PE_INPUT = IMG_EMB_DIM[0]
IMG_SEQ_LEN, IMG_EMB_DEPTH = IMG_EMB_DIM
D_MODEL = IMG_EMB_DEPTH
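
The two constant changes above are linked: EfficientNet backbones reduce spatial resolution by roughly 32x overall, so a 299 x 299 input gives a 10 x 10 final feature map while a 512 x 512 input gives exactly 16 x 16; the channel depth (232) is unchanged. A minimal sketch of that arithmetic, assuming the 32x factor (a property of the encoder, not stated in this diff):

# Minimal sketch, assuming ~32x total spatial downsampling in the encoder.
import math

DOWNSAMPLING = 32
for side in (299, 512):
    fmap = math.ceil(side / DOWNSAMPLING)  # 299 -> 10, 512 -> 16
    print(side, fmap, fmap * fmap)         # flattened IMG_SEQ_LEN: 100 or 256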
@@ -84,6 +85,17 @@
start_epoch = int(ckpt_manager.latest_checkpoint.split("-")[-1])


+def detokenize_output(predicted_array):
+    outputs = [tokenizer.index_word[i] for i in predicted_array[0].numpy()]
+    prediction = (
+        "".join([str(elem) for elem in outputs])
+        .replace("<start>", "")
+        .replace("<end>", "")
+    )
+
+    return prediction
+
+
class DECIMER_Predictor(tf.Module):
"""This is a class which takes care of inference. It loads the saved checkpoint and the necessary
tokenizers. The inference begins with the start token (<start>) and ends when the end token(<end>)
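
The detokenize_output helper added above maps a batch of predicted token indices back to a string through the tokenizer's index_word table, then strips the <start> and <end> sentinels. A self-contained toy run, repeating the helper verbatim against an invented vocabulary (the real tokenizer is unpickled earlier in this file; the table below is illustration only):

import tensorflow as tf

# Invented toy vocabulary for illustration; not DECIMER's real token table.
class _ToyTokenizer:
    index_word = {1: "<start>", 2: "C", 3: "=", 4: "O", 5: "<end>"}

tokenizer = _ToyTokenizer()

def detokenize_output(predicted_array):
    # Look each index up in the vocabulary, join, and drop the sentinels.
    outputs = [tokenizer.index_word[i] for i in predicted_array[0].numpy()]
    prediction = (
        "".join([str(elem) for elem in outputs])
        .replace("<start>", "")
        .replace("<end>", "")
    )
    return prediction

print(detokenize_output(tf.constant([[1, 2, 3, 4, 5]])))  # prints "C=O"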
@@ -116,10 +128,9 @@ def __call__(self, Decoded_image):
        Returns:
            output (tf.Tensor[tf.int64]): predicted output as an array.
        """
-
        assert isinstance(Decoded_image, tf.Tensor)
        if len(Decoded_image.shape) == 0:
-            Decoded_image = Decoded_image[tf.newaxis]
+            sentence = Decoded_image[tf.newaxis]

        _image_batch = tf.expand_dims(Decoded_image, 0)
        _image_embedding = encoder(_image_batch, training=False)
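The class docstring describes the overall flow: embed the decoded image with the EfficientNet encoder, then decode token by token from <start> until <end> or max_length is reached. A hedged usage sketch follows; the constructor arguments, preprocessing, and output shape are assumptions, since __init__ and the decoding loop are collapsed out of this diff:

import tensorflow as tf

# Assumed constructor signature; __init__ is not shown in this diff.
predictor = DECIMER_Predictor(encoder, tokenizer, transformer, max_length)

raw = tf.io.read_file("caffeine.png")               # hypothetical input file
image = tf.io.decode_png(raw, channels=3)           # rank-3 uint8 tensor
image = tf.image.resize(image, (512, 512)) / 255.0  # assumed preprocessing to match IMG_SHAPE
tokens = predictor(image)                           # tf.int64 token ids per the docstring
smiles = detokenize_output(tokens)                  # assumes a batched (1, seq) output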
