-
Notifications
You must be signed in to change notification settings - Fork 0
/
TestModel.py
executable file
·49 lines (37 loc) · 1.61 KB
/
TestModel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#!/bin/python
from Models.FeedForward import Model as FeedForwardAE
from Models.RecurrendAE import Model as LSTMAE
from Models.CNN_AE import Model as CNNAE
from Models.AttentionBasedAE import Model as AttentionModel
from SetWrappers.UCRArchive import loadData as UCRDataSet
from SetWrappers.SMD import loadData as SMDDataSet
from SetWrappers.ECGDataSet import loadData as ECGDataSet
from DataGenerators.Sines import generateData as Sines
from Trainers.SingleInstanceTrainer import Trainer as OnlineTrainer
from Trainers.BatchedTrainer import Trainer as BatchedTrainer
from Benchmark import benchmark,initializeDevice
from Evaluation.QuickOverview import plotOverview
# --- Run configuration --------------------------------------------------------
# Output location handed to benchmark() via its `pathToSave` keyword.
pathToSave = "Results/DebugSession"
device = initializeDevice()
# Dataset dimensions; passed to the data loader and to the model constructor.
Dimensions = 2

# --- Data -----------------------------------------------------------------------
# Active choice: the "ACSF1" set from the UCR archive wrapper.
# Alternatives (swap in by replacing the right-hand side below):
#   Sines(Dimensions)
#   ECGDataSet(Dimensions)
#   SMDDataSet(Dimensions, nNormalDimensions=1)
trainingSet, validationSet, testSet = UCRDataSet(Dimensions, DataSet="ACSF1")

# --- Model ----------------------------------------------------------------------
# Active choice: the feed-forward autoencoder, with InputSize taken from
# trainingSet.Length().
# Alternatives:
#   LSTMAE(Dimensions, device, CellKind="LSTM")
#   CNNAE(Dimensions, device, InputSize=trainingSet.Length())
#   CNNAE(Dimensions, device, hasFFTEncoder=True)
#   AttentionModel(Dimensions, device)
model = FeedForwardAE(Dimensions, device, InputSize=trainingSet.Length())

# --- Trainer --------------------------------------------------------------------
# Alternative: OnlineTrainer(model, device)
trainer = BatchedTrainer(model, device)

# --- Run the benchmark and plot a summary of its output folder -----------------
resultFolder = benchmark(
    trainingSet,
    validationSet,
    testSet,
    model,
    trainer,
    n_epochs=20,
    pathToSave=pathToSave,
    device=device,
)
plotOverview(resultFolder)