Skip to content

Commit

Permalink
Merge pull request #9 from Atashnezhad/feature/fix_bug_2
Browse files Browse the repository at this point in the history
edit make file
  • Loading branch information
Atashnezhad committed Jul 17, 2023
2 parents 49c5d85 + 5b6acca commit 6a8d5ea
Show file tree
Hide file tree
Showing 10 changed files with 26 additions and 5 deletions.
7 changes: 6 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -131,4 +131,9 @@ test-coverage:
# git rm cache
.PHONY: git-rm-cache
git-rm-cache:
@git rm -r --cached .
@git rm -r --cached .

# find files with a specific name
.PHONY: find
find:
@find . -name "dataset" -type d
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,15 @@ transfer_model = TransferModel(
dataset_address=Path(__file__).parent / "dataset"
)

transfer_model.plot_classes_number(figure_folder_path=Path(__file__).parent / "figures")
transfer_model.plot_classes_number()
transfer_model.analyze_image_names()
transfer_model.plot_data_images(num_rows=3, num_cols=3)
transfer_model.train_model()
transfer_model.train_model(epochs=3,
model_save_path=(Path(__file__).parent / ".." / "deep_model").resolve(),
model_name="tf_model_2.h5")
transfer_model.plot_metrics_results()
transfer_model.results()
# one can pass the model address to the predict_test method
transfer_model.predcit_test()
transfer_model.grad_cam_viz(num_rows=3, num_cols=2)
```
Expand Down
Binary file modified figures/classes_number.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified figures/cluster_number_per_class.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified figures/image_width_height.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified figures/images.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified figures/metrics.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified figures/transf_cam.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
10 changes: 8 additions & 2 deletions neural_network_model/transfer_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,10 @@ def train_model(self, epochs=10, batch_size=32, **kwargs):
kwargs:
model_save_location: location to save the model default is self.model_save_location
"""
if kwargs.get("model_save_location"):
if kwargs.get("model_save_path"):
# check if the path exists if not create it
if not os.path.exists(kwargs.get("model_save_path")):
os.makedirs(kwargs.get("model_save_path"))
self.model_save_path = kwargs.get("model_save_path")
model_name = kwargs.get("model_name", "tf_model.h5")

Expand Down Expand Up @@ -821,8 +824,11 @@ def grad_cam_viz(self, *args, **kwargs):
transfer_model.plot_classes_number()
transfer_model.analyze_image_names()
transfer_model.plot_data_images(num_rows=3, num_cols=3)
transfer_model.train_model(epochs=3)
transfer_model.train_model(epochs=3,
model_save_path=(Path(__file__).parent / ".." / "deep_model").resolve(),
model_name="tf_model_2.h5")
transfer_model.plot_metrics_results()
transfer_model.results()
# one can pass the model address to the predict_test method
transfer_model.predcit_test()
transfer_model.grad_cam_viz(num_rows=3, num_cols=2)
7 changes: 7 additions & 0 deletions tests/test_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,10 @@ def test_property_image_dict_3(mock_iterdir, mock_categorie_property, _object):
"pdc_bit": {"image_list": [], "number_of_images": 0},
"rollercone_bit": {"image_list": [], "number_of_images": 0},
}


def test_integrated(_object):
_object = Preprocessing(dataset_address=Path(__file__).parent / ".." / "dataset")
_object.download_images()
# _object.augment_data(number_of_images_tobe_gen=10)
# _object.train_test_split()

0 comments on commit 6a8d5ea

Please sign in to comment.