Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
Atashnezhad committed Sep 7, 2023
1 parent bde8738 commit d38e703
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions neural_network_model/transfer_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -1069,7 +1069,7 @@ def predict_one_image(
logger.info(f"Loading the model from {model_path}")
self.model = tf.keras.models.load_model(model_path)
else:
logger.info(f"Using the self.model from memory")
logger.info("Using the self.model from memory")
# Predict the image
prediction = self.model.predict(img_array)
predicted_class = tf.argmax(prediction, axis=1)[0]
Expand Down Expand Up @@ -1153,9 +1153,9 @@ def predict_image_patch_classes(self, **kwargs):
):
# Extract the patch using the sliding window
patch = image[
y:y + window_height,
x:x + window_width
]
y: y + window_height,
x: x + window_width
]

# Save the patch as an image
patch_filename = os.path.join(patch_images_dir, f'patch_{count}.jpg')
Expand Down Expand Up @@ -1260,9 +1260,9 @@ def predict_image_patch_classes_2(self, **kwargs):
):
# Extract the patch using the sliding window
patch = image[
y:y + window_height,
x:x + window_width
]
y: y + window_height,
x: x + window_width
]

# Save the patch as an image
patch_filename = os.path.join(patch_images_dir, f'patch_{count}.jpg')
Expand Down Expand Up @@ -1313,6 +1313,7 @@ def predict_image_patch_classes_2(self, **kwargs):

if __name__ == "__main__":
from neural_network_model.process_data import Preprocessing
from PIL import Image # noqa: F811

# download the dataset
# obj = Preprocessing()
Expand All @@ -1325,11 +1326,11 @@ def predict_image_patch_classes_2(self, **kwargs):
# transfer_model.plot_classes_number()
# transfer_model.analyze_image_names()
# transfer_model.plot_data_images(num_rows=3, num_cols=3, cmap="jet")
transfer_model.train_model(
epochs=5,
model_save_path=(Path(__file__).parent / ".." / "deep_model").resolve(),
model_name="tf_model_core_1.h5",
)
# transfer_model.train_model(
# epochs=5,
# model_save_path=(Path(__file__).parent / ".." / "deep_model").resolve(),
# model_name="tf_model_core_1.h5",
# )
# transfer_model.plot_metrics_results()
# transfer_model.results()
# one can pass the model address to the predict_test method
Expand Down Expand Up @@ -1382,7 +1383,6 @@ def predict_image_patch_classes_2(self, **kwargs):

img_path = Path(__file__).parent / ".." / "dataset_core" / "patch_images" / "patch_0.jpg"
# load the img from img_path
from PIL import Image

img = Image.open(img_path)
# Resize the image to the expected shape (224x224)
Expand Down

0 comments on commit d38e703

Please sign in to comment.