Skip to content

Commit

Permalink
feat: improved DECIMER hand-drawn model
Browse files Browse the repository at this point in the history
  • Loading branch information
Kohulan committed Mar 5, 2024
1 parent 95d6049 commit 03d508b
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 14 deletions.
28 changes: 20 additions & 8 deletions DECIMER/decimer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,10 @@

model_urls = {
"DECIMER": "https://zenodo.org/record/8300489/files/models.zip",
"DECIMER_HandDrawn": "https://zenodo.org/records/10781330/files/DECIMER_HandDrawn_model.zip"
"DECIMER_HandDrawn": "https://zenodo.org/records/10781330/files/DECIMER_HandDrawn_model.zip",
}


def get_models(model_urls: dict) -> Tuple[object, tf.saved_model, tf.saved_model]:
"""Download and load models from the provided URLs.
Expand All @@ -54,7 +55,9 @@ def get_models(model_urls: dict) -> Tuple[object, tf.saved_model, tf.saved_model
model_paths = utils.ensure_models(default_path=default_path, model_urls=model_urls)

# Load tokenizers
tokenizer_path = os.path.join(model_paths["DECIMER"], "assets", "tokenizer_SMILES.pkl")
tokenizer_path = os.path.join(
model_paths["DECIMER"], "assets", "tokenizer_SMILES.pkl"
)
tokenizer = pickle.load(open(tokenizer_path, "rb"))

# Load DECIMER models
Expand All @@ -63,8 +66,10 @@ def get_models(model_urls: dict) -> Tuple[object, tf.saved_model, tf.saved_model

return tokenizer, DECIMER_V2, DECIMER_Hand_drawn


tokenizer, DECIMER_V2, DECIMER_Hand_drawn = get_models(model_urls)


def detokenize_output(predicted_array: int) -> str:
"""This function takes the predited tokens from the DECIMER model and
returns the decoded SMILES string.
Expand Down Expand Up @@ -115,7 +120,10 @@ def detokenize_output_add_confidence(
decoded_prediction_with_confidence.append(prediction_with_confidence_[-1])
return decoded_prediction_with_confidence

def predict_SMILES(image_path: str, confidence: bool = False, hand_drawn: bool = False) -> str:

def predict_SMILES(
image_path: str, confidence: bool = False, hand_drawn: bool = False
) -> str:
"""Predicts SMILES representation of a molecule depicted in the given image.
Args:
Expand All @@ -127,18 +135,21 @@ def predict_SMILES(image_path: str, confidence: bool = False, hand_drawn: bool =
str: SMILES representation of the molecule in the input image, optionally with confidence values
"""
chemical_structure = config.decode_image(image_path)

model = DECIMER_Hand_drawn if hand_drawn else DECIMER_V2
predicted_tokens, confidence_values = model(tf.constant(chemical_structure))

predicted_SMILES = utils.decoder(detokenize_output(predicted_tokens))

if confidence:
predicted_SMILES_with_confidence = detokenize_output_add_confidence(predicted_tokens, confidence_values)
predicted_SMILES_with_confidence = detokenize_output_add_confidence(
predicted_tokens, confidence_values
)
return predicted_SMILES, predicted_SMILES_with_confidence

return predicted_SMILES


def main():
"""This function take the path of the image as user input and returns the
predicted SMILES as output in CLI.
Expand All @@ -155,5 +166,6 @@ def main():
SMILES = predict_SMILES(sys.argv[1])
print(SMILES)


if __name__ == "__main__":
main()
9 changes: 3 additions & 6 deletions DECIMER/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,8 @@ def decoder(predictions):
)
return modified

def ensure_models(
default_path: str,
model_urls: dict
) -> dict:

def ensure_models(default_path: str, model_urls: dict) -> dict:
"""Function to ensure models are present locally.
Convenient function to ensure model downloads before usage
Expand All @@ -89,9 +87,8 @@ def ensure_models(
config.download_trained_weights(model_url, default_path)
elif not os.path.exists(model_path):
config.download_trained_weights(model_url, default_path)

# Store the model path
model_paths[model_name] = model_path

return model_paths

0 comments on commit 03d508b

Please sign in to comment.