diff --git a/Makefile b/Makefile index a990e1c..4bb822d 100644 --- a/Makefile +++ b/Makefile @@ -131,4 +131,9 @@ test-coverage: # git rm cache .PHONY: git-rm-cache git-rm-cache: - @git rm -r --cached . \ No newline at end of file + @git rm -r --cached . + +# find files with a specific name +.PHONY: find +find: + @find . -name "dataset" -type d \ No newline at end of file diff --git a/README.md b/README.md index e8d31be..415525d 100644 --- a/README.md +++ b/README.md @@ -155,12 +155,15 @@ transfer_model = TransferModel( dataset_address=Path(__file__).parent / "dataset" ) -transfer_model.plot_classes_number(figure_folder_path=Path(__file__).parent / "figures") +transfer_model.plot_classes_number() transfer_model.analyze_image_names() transfer_model.plot_data_images(num_rows=3, num_cols=3) -transfer_model.train_model() +transfer_model.train_model(epochs=3, + model_save_path=(Path(__file__).parent / ".." / "deep_model").resolve(), + model_name="tf_model_2.h5") transfer_model.plot_metrics_results() transfer_model.results() +# one can pass the model address to the predict_test method transfer_model.predcit_test() transfer_model.grad_cam_viz(num_rows=3, num_cols=2) ``` diff --git a/figures/classes_number.png b/figures/classes_number.png index 9bf06bf..8037e40 100644 Binary files a/figures/classes_number.png and b/figures/classes_number.png differ diff --git a/figures/cluster_number_per_class.png b/figures/cluster_number_per_class.png index 83114da..e379534 100644 Binary files a/figures/cluster_number_per_class.png and b/figures/cluster_number_per_class.png differ diff --git a/figures/image_width_height.png b/figures/image_width_height.png index dce158e..50e979e 100644 Binary files a/figures/image_width_height.png and b/figures/image_width_height.png differ diff --git a/figures/images.png b/figures/images.png index f5f7504..6352602 100644 Binary files a/figures/images.png and b/figures/images.png differ diff --git a/figures/metrics.png b/figures/metrics.png index b6a983f..4fe3856 100644 Binary files a/figures/metrics.png and b/figures/metrics.png differ diff --git a/figures/transf_cam.jpg b/figures/transf_cam.jpg index 46bc96a..9f83a3c 100644 Binary files a/figures/transf_cam.jpg and b/figures/transf_cam.jpg differ diff --git a/neural_network_model/transfer_learning.py b/neural_network_model/transfer_learning.py index 8dfb94a..60c3242 100644 --- a/neural_network_model/transfer_learning.py +++ b/neural_network_model/transfer_learning.py @@ -474,7 +474,10 @@ def train_model(self, epochs=10, batch_size=32, **kwargs): kwargs: model_save_location: location to save the model default is self.model_save_location """ - if kwargs.get("model_save_location"): + if kwargs.get("model_save_path"): + # check if the path exists if not create it + if not os.path.exists(kwargs.get("model_save_path")): + os.makedirs(kwargs.get("model_save_path")) self.model_save_path = kwargs.get("model_save_path") model_name = kwargs.get("model_name", "tf_model.h5") @@ -821,8 +824,11 @@ def grad_cam_viz(self, *args, **kwargs): transfer_model.plot_classes_number() transfer_model.analyze_image_names() transfer_model.plot_data_images(num_rows=3, num_cols=3) - transfer_model.train_model(epochs=3) + transfer_model.train_model(epochs=3, + model_save_path=(Path(__file__).parent / ".." / "deep_model").resolve(), + model_name="tf_model_2.h5") transfer_model.plot_metrics_results() transfer_model.results() + # one can pass the model address to the predict_test method transfer_model.predcit_test() transfer_model.grad_cam_viz(num_rows=3, num_cols=2) diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index 56ea90a..62b4689 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -178,3 +178,10 @@ def test_property_image_dict_3(mock_iterdir, mock_categorie_property, _object): "pdc_bit": {"image_list": [], "number_of_images": 0}, "rollercone_bit": {"image_list": [], "number_of_images": 0}, } + + +def test_integrated(_object): + _object = Preprocessing(dataset_address=Path(__file__).parent / ".." / "dataset") + _object.download_images() + # _object.augment_data(number_of_images_tobe_gen=10) + # _object.train_test_split()