diff --git a/src/config.py b/src/config.py index 4ab4db54..6217efbd 100644 --- a/src/config.py +++ b/src/config.py @@ -34,7 +34,7 @@ class config: label_file_path = os.path.join(labeled_data_path, 'label.txt') lr = 0.005 - epochs = 100 + epochs = 20 batch_size = 16 log_interval = 100 diff --git a/src/main.py b/src/main.py index b6bf2cd4..1f783e04 100644 --- a/src/main.py +++ b/src/main.py @@ -78,5 +78,17 @@ def test(): model = ResNetMini(3, 2) +def transfer_model(): + model = ResNetMini(3, 2) + model.load_state_dict(torch.load(config.model_path)) + model.eval() + torch.onnx.export(model, + torch.randn(1, 3, 64, 64), + config.model_onnx_path, + verbose=False, + export_params=True) + + if __name__ == '__main__': - train() \ No newline at end of file + train() + transfer_model() \ No newline at end of file diff --git a/src/test_cv2.py b/src/test_cv2.py new file mode 100644 index 00000000..3cc0345b --- /dev/null +++ b/src/test_cv2.py @@ -0,0 +1,49 @@ +import os +import sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +import cv2 +import numpy as np +from config import config, mkdir + + +def val_cv2(img_path): + img_list = os.listdir(img_path) + # print(img_list) + + model = cv2.dnn.readNetFromONNX(config.model_onnx_path) + acc = 0 + + for img_name in img_list: + # print(img_name) + img = cv2.imread(os.path.join(img_path, img_name)) + img = cv2.resize(img, (64, 64)) + + blob = cv2.dnn.blobFromImage(img, 1 / 255.0, (64, 64), (0, 0, 0), swapRB=True, crop=False) + model.setInput(blob) + out = model.forward() + # print(out.shape) + # print(out) + label = np.argmax(out, axis=1)[0] + # print(label) + if label == 1: + acc += 1 + # break + + print(f'{img_path}\n err: {len(img_list)-acc} acc: {acc / len(img_list)}') + + +if __name__ == '__main__': + val_img_path = [ + os.path.join(config.data_path, 'val', 'airplane in the sky flying left', 'yes'), + os.path.join(config.data_path, 'val', 'airplane in the sky flying left', 'bad'), + os.path.join(config.data_path, 'val', 'airplanes in the sky that are flying to the right', + 'yes'), + os.path.join(config.data_path, 'val', 'airplanes in the sky that are flying to the right', + 'bad'), + ] + for val_img_path_ in val_img_path: + val_cv2(val_img_path_) + + test_img_path = os.path.join(config.data_path, 'test') + for class_ in config.classes: + mkdir(os.path.join(test_img_path, class_), remove=True)