diff --git a/neural_network_model/model.py b/neural_network_model/model.py index 958a4e7..d16e14f 100644 --- a/neural_network_model/model.py +++ b/neural_network_model/model.py @@ -278,6 +278,16 @@ class Config: # Allow extra fields in the model (to ignore the pydantic ConfigError) extra = Extra.allow + SUPPORTED_FILE_FORMATS: list = [ + ".jpg", + ".jpeg", + ".png", + ".bmp", + ".tif", + ".tiff", + ".gif", + ] + class localBinaryPatterns(BaseModel): NUM_POINTS: int = 8 diff --git a/neural_network_model/transfer_learning.py b/neural_network_model/transfer_learning.py index 0c91617..4f2731b 100644 --- a/neural_network_model/transfer_learning.py +++ b/neural_network_model/transfer_learning.py @@ -326,7 +326,7 @@ def plot_data_images( filepath = self.image_df.Filepath[i] file_extension = os.path.splitext(filepath)[1].lower() - if file_extension in [".png", ".jpg"]: + if file_extension in TRANSFER_LEARNING_SETTING.SUPPORTED_FILE_FORMATS: try: image = plt.imread(filepath) ax.imshow(image, cmap=cmap)