-
Notifications
You must be signed in to change notification settings - Fork 1
/
start_training.py
81 lines (61 loc) · 2.71 KB
/
start_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
"""Module for training starting"""
import glob
from typing import List, Tuple, Dict
import numpy as np
from modules.data.dataset import Dataset
from modules.network.hopfield import HopfieldNetwork
from modules.utils.plots import plot_images_compare
from modules.train import get_trained_model
from config import Config
def evaluation(model: HopfieldNetwork,
               image_paths: List[str],
               image_size: Tuple[int, int] = (256, 256),
               num_iter: int = 20,
               threshold: int = 50,
               add_noise: bool = False) -> Dict[str, List[np.ndarray]]:
    """
    Runs a HopfieldNetwork on the images at image_paths and pairs every
    input with the network's prediction for it
    :param model: model instance whose predict() is called
    :param image_paths: paths to the images to evaluate
    :param image_size: image size passed to Dataset when loading the images
    :param num_iter: intended number of passes of the images through the network
    :param threshold: intended threshold for the sign function in the network
    :param add_noise: if True, Dataset adds noise to the input images
    :return: dict with key 'original_image' holding the images returned by
        Dataset.get_all_flatten_images() and key 'prediction_image' holding
        model.predict() applied to those same images
    """
    # NOTE(review): num_iter and threshold are accepted but never forwarded
    # to model.predict, so they currently have no effect here — confirm
    # whether predict() is meant to receive them.
    inputs = Dataset(list_of_paths=image_paths,
                     image_size=image_size,
                     add_noise=add_noise).get_all_flatten_images()
    return {
        'original_image': inputs,
        'prediction_image': model.predict(data=inputs),
    }
def find_image_paths(pattern: str) -> List[str]:
    """
    Returns the files matching a glob pattern, sorted for reproducibility
    :param pattern: glob pattern such as 'images_same/*.*' ('**' is supported)
    :return: sorted list of matching paths
    :raises FileNotFoundError: if nothing matches the pattern
    """
    # glob yields files in arbitrary filesystem order; sorting keeps training
    # and the plotted comparison in the same order on every machine
    paths = sorted(glob.glob(pattern, recursive=True))
    if not paths:
        # Fail fast with a clear message (e.g. script run from the wrong
        # directory) instead of passing an empty list into training.
        raise FileNotFoundError(f'No images found matching {pattern!r}')
    return paths


def main(use_numbers_dataset: bool = False) -> None:
    """
    Trains a HopfieldNetwork and plots the evaluated images next to its predictions
    :param use_numbers_dataset: if True, trains on 'images_diff/train' and
        evaluates on the 'images_diff/test' images without noise, using
        Config.asynchronous; otherwise trains on 'images_same' (synchronously)
        and evaluates on noisy copies of those same training images
    :raises FileNotFoundError: if a required image directory yields no files
    """
    if use_numbers_dataset:
        train_paths = find_image_paths('images_diff/train/*.*')
        test_paths = find_image_paths('images_diff/test/*.*')
        asynchronous = Config.asynchronous
        add_noise = False
    else:
        # Evaluate on the training images themselves, with noise added.
        train_paths = find_image_paths('images_same/*.*')
        test_paths = train_paths
        asynchronous = False
        add_noise = True
    model = get_trained_model(image_paths=train_paths,
                              image_size=Config.image_size,
                              asynchronous=asynchronous)
    data = evaluation(model=model,
                      image_paths=test_paths,
                      image_size=Config.image_size,
                      num_iter=Config.num_iter,
                      threshold=Config.threshold,
                      add_noise=add_noise)
    plot_images_compare(data=data)


if __name__ == '__main__':
    main()