-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
206 lines (164 loc) · 5.74 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
# Reduce terminal logging verbosity
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
print("Starting up...")
# Import libraries
print("\nImporting libraries...")
import tensorflow as tf
from tensorflow.python import keras, data
from tensorflow.python.keras import layers, losses
print("Tensorflow imported")
from keras import utils
from keras.layers import Rescaling
from keras.callbacks import History
print("Keras imported")
import pathlib
import matplotlib.pyplot as plt
print("Other dependencies imported")
print("All libraries imported\n")
# Environment Variables: default locations of the image dataset used for
# training, the image dataset used for testing, and the saved model files.
dataPath, testPath, modelPath = (
    "../Generated Images",
    "../Equalised Images",
    ".model",
)
print("Environment variables set\n")
# Class container for model
class CNN(tf.Module):
    """Container for a small image-classification CNN and its training state.

    Holds the compiled Keras model, the training/validation datasets it is
    fitted on (``trainData`` / ``testData``) and the ``History`` object
    returned by the most recent ``train`` call.

    NOTE: every Keras name here comes from the public ``tf.keras`` namespace.
    The file also imports the private ``tensorflow.python.keras`` package and
    standalone ``keras``; mixing layers from those two implementations in one
    ``Sequential`` can be rejected as "not a Layer" instances.
    """

    def __init__(self, imgDim: tuple = (512, 512, 3)):
        """Build and compile the network.

        Args:
            imgDim: Input image shape as (height, width, channels). Its
                (height, width) part is also the size images are resized to
                when datasets are loaded.
        """
        # tf.Module expects its own initialiser to run (it sets up the
        # module name / name scope).
        super().__init__()
        self.imgDim = tuple(imgDim)
        kl = tf.keras.layers
        self.model = tf.keras.Sequential(
            [
                # Normalise raw 0-255 pixels to [0, 1] *before* the first
                # convolution so every conv layer sees scaled inputs.
                # Weights saved by a build that rescaled after the first
                # Conv2D were trained on unscaled inputs; retrain them.
                kl.Rescaling(1. / 255, input_shape = self.imgDim),
                kl.Conv2D(8, (4, 4), padding = 'same', activation = 'relu'),
                kl.MaxPooling2D(pool_size = (2, 2)),
                kl.Conv2D(16, (4, 4), padding = 'same', activation = 'relu'),
                kl.MaxPooling2D(pool_size = (2, 2)),
                kl.Conv2D(8, (4, 4), padding = 'same', activation = 'relu'),
                kl.MaxPooling2D(pool_size = (2, 2)),
                kl.Dropout(0.2),
                kl.Flatten(),
                kl.Dense(64, activation = "relu"),
                # Two raw logits (no softmax); the loss below uses from_logits.
                kl.Dense(2)
            ]
        )
        self.model.compile(
            optimizer = "adam",
            loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits = True)
        )
        self.trainData = None
        self.testData = None
        self.history = None

    def train(self, seed: int = 1542, epochs: int = 10, shuffleBuffer: int = 1000):
        """Fit the model on ``trainData``, validating on ``testData``.

        Loads the datasets via ``loadData()`` first if they are not set yet.

        Args:
            seed: Seed for the per-epoch shuffle of the training data.
            epochs: Number of training epochs.
            shuffleBuffer: Shuffle buffer size. The datasets are batched, so
                this counts batches, not individual images.

        Returns:
            The Keras ``History`` of this run (also stored on ``self.history``).
        """
        if self.trainData is None or self.testData is None:
            self.loadData()
        tuner = tf.data.AUTOTUNE
        # Build the input pipelines locally so repeated train() calls do not
        # keep stacking cache()/shuffle() stages onto the stored datasets.
        # shuffle()'s first argument is the buffer size; the seed is passed
        # separately.
        trainPipe = (self.trainData.cache()
                     .shuffle(shuffleBuffer, seed = seed)
                     .prefetch(buffer_size = tuner))
        testPipe = self.testData.cache().prefetch(buffer_size = tuner)
        # Keep the History so trainSummary() has something to plot.
        self.history = self.model.fit(trainPipe,
                                      validation_data = testPipe,
                                      epochs = epochs)
        return self.history

    def trainSummary(self):
        """Plot training and validation loss per epoch of the last train() run.

        Raises:
            RuntimeError: if ``train()`` has not been run yet.
        """
        if self.history is None:
            raise RuntimeError("No training history available; call train() first.")
        loss = self.history.history['loss']
        val_loss = self.history.history['val_loss']
        epochs_range = self.history.epoch
        plt.figure(figsize=(8, 8))
        plt.plot(epochs_range, loss, label='Training Loss')
        plt.plot(epochs_range, val_loss, label='Validation Loss')
        plt.legend(loc='upper right')
        plt.title('Training and Validation Loss')
        plt.show()

    def loadData(self, trainTestSplit: float = 0.35, seed: int = 1574, dataDir: str = None):
        """Load the image dataset and split it into training/validation sets.

        No ``@tf.function`` here: this is Python-side directory listing, not
        tensor math. Tracing it gains nothing, and Python attributes of the
        returned datasets (e.g. ``class_names``) may not survive the graph
        boundary.

        Args:
            trainTestSplit: Fraction of images reserved for validation.
            seed: Split seed; both subsets must use the same value so they
                partition the same file list.
            dataDir: Dataset root directory; ``None`` uses the module-level
                ``dataPath``.

        Returns:
            ``(trainData, testData)`` batched datasets. They are also stored on
            ``self.trainData`` / ``self.testData`` for ``train()``.
        """
        root = pathlib.Path(dataPath if dataDir is None else dataDir)
        batchSize = 32
        imageSize = self.imgDim[:2]
        # Creating dataset loaders
        trainData = tf.keras.utils.image_dataset_from_directory(root,
                                                                validation_split = trainTestSplit,
                                                                subset = "training",
                                                                seed = seed,
                                                                image_size = imageSize,
                                                                batch_size = batchSize)
        testData = tf.keras.utils.image_dataset_from_directory(root,
                                                               validation_split = trainTestSplit,
                                                               subset = "validation",
                                                               seed = seed,
                                                               image_size = imageSize,
                                                               batch_size = batchSize)
        self.trainData, self.testData = trainData, testData
        return trainData, testData

    def rawPredict(self, testPath: str = "test"):
        """Run the model over every image found under *testPath*.

        Returns:
            ``(testData, predictions)``, where ``predictions`` are the raw
            two-class logits from the final Dense layer.
        """
        testDir = pathlib.Path(testPath)
        # shuffle=False: with the default (True) the dataset reshuffles on
        # every iteration, so iterating the returned testData again would not
        # line up with the order predict() saw.
        testData = tf.keras.utils.image_dataset_from_directory(testDir,
                                                               shuffle = False,
                                                               image_size = self.imgDim[:2])
        return testData, self.model.predict(testData)

    def saveModel(self, modelFile: str = ".model"):
        """Save the full Keras model to *modelFile*."""
        self.model.save(modelFile)

    def loadModel(self, modelFile: str = ".model"):
        """Load saved weights from *modelFile* into this already-built model."""
        self.model.load_weights(modelFile)

    def debugInfo(self):
        """Print the model object, each layer, the summary and the raw weights."""
        print(self.model)
        for layer in self.model.layers:
            print(layer)
        self.model.summary()
        print(self.model.get_weights())
# Getting the OS currently running the model
from platform import system
currentOs = system()


# Helper functions
def clearCommand(osName: str) -> str:
    """Return the shell command that clears the terminal on *osName*.

    ``cls`` only exists on Windows. Linux, macOS ('Darwin') and other
    POSIX-like systems use ``clear``. Previously every non-Linux OS got
    ``cls``, which fails on macOS.
    """
    return "cls" if osName == "Windows" else "clear"


def clrscr():
    """Clear the terminal window of the OS this program is running on."""
    os.system(clearCommand(currentOs))
def interpretPred(data, pred):
    """Placeholder for turning a dataset and its raw predictions into readable results.

    Not implemented yet: the arguments are ignored and ``None`` is returned.
    """
    return None
# Wrapper functions
def manualTrain():
    """Placeholder menu wrapper for training; not implemented yet (returns None)."""
    return None
def rawPredict():
    """Placeholder menu wrapper for predictions; not implemented yet (returns None)."""
    return None
def editEnvVars():
    """Placeholder menu wrapper for editing settings; not implemented yet (returns None)."""
    return None
# Main menu function
def parseChoice(raw: str) -> str:
    """Normalise a raw menu entry to its lower-cased first non-blank character.

    Returns "" for empty or all-whitespace input instead of raising IndexError.
    """
    raw = raw.strip()
    return raw[0].lower() if raw else ""


def mainMenu() -> str:
    """Clear the screen, show the main menu and return the user's choice.

    Returns:
        The lower-cased first character of the entry (e.g. 't', 'p', 'e',
        'q'), or "" if the user just pressed enter.
    """
    clrscr()
    print("Main Menu\n")
    print("T -> Train the model")
    print("P -> Make a prediction")
    print("E -> Edit Environment Variables")
    print("Q -> Quit")
    return parseChoice(input("\nEnter your choice: "))
# Operation mode specified by launch arguments
# Available options:
#   Refresh mode: overwrite existing model files to start a fresh model
#   Continue mode: use existing model files to continue from the previously saved model
# Model is automatically saved at the end of each program run
import argparse
parser = argparse.ArgumentParser(prog = "cnn-model", description = "CNN Windows EXE Classification Program")
# Defaults come from the Environment Variables section so there is one source of truth.
parser.add_argument('-m', '--model-path', type = str, help = "Model files destination", dest = "modelPath", default = modelPath)
parser.add_argument('-d', '--data-path', type = str, help = "Image Dataset path", dest = "dataPath", default = dataPath)
parser.add_argument('-t', '--test-path', type = str, help = "Image Dataset for Testing path", dest = "testPath", default = testPath)
# Mode selection: exactly one of -r / -c must be given
progMode = parser.add_mutually_exclusive_group(required = True)
progMode.add_argument('-r', '--refresh', action = "store_true", help = "Run program in refresh mode", dest = "ref")
progMode.add_argument('-c', '--continue', action = "store_true", help = "Run program in continue mode", dest = "con")
args = parser.parse_args()
# Propagate the path flags into the module-level settings. CNN.loadData reads
# the global dataPath, so without this --data-path would be silently ignored.
dataPath = args.dataPath
testPath = args.testPath
modelPath = args.modelPath
if args.ref:
    print("Initialising new Model...")
    model = CNN()
    print("Saving model (overwrites existing model files in path)...")
    model.saveModel(modelFile = args.modelPath)
    print("Model successfully saved!")
elif args.con:
    print("Creating Model...")
    model = CNN()
    print("Loading model values")
    model.loadModel(modelFile = args.modelPath)
    print("Model successfully loaded!")
input("\nPress enter to continue to main menu...")
# Main menu loop: dispatch each menu choice to its wrapper function until
# the user enters 'q'; unrecognised choices simply redraw the menu.
menuActions = {
    't': manualTrain,
    'p': rawPredict,
    'e': editEnvVars,
}
ex = 'a'
try:
    while ex != 'q':
        ex = mainMenu()
        action = menuActions.get(ex)
        if action is not None:
            action()
finally:
    # Save on every exit path (normal quit, Ctrl+C, or an error inside a
    # menu action), as the launch-mode comment promises.
    model.saveModel(modelFile = args.modelPath)