Skip to content

Commit

Permalink
Fixed bug with tanh
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexPasqua committed Dec 29, 2020
1 parent ddfc526 commit 6ad7942
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 16 deletions.
8 changes: 4 additions & 4 deletions src/model_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ def cross_valid(net, inputs, targets, epochs=1, batch_size=1, k_folds=5):
net.compile(opt='gd',
loss='squared',
metr='bin_class_acc',
lr=0.5,
lr_decay='linear',
limit_step=500,
momentum=0.7)
lr=0.65,
# lr_decay='linear',
# limit_step=800,
momentum=0.8)
tr_err, tr_metric, val_err, val_metric = net.fit(tr_x=train_set,
tr_y=train_targets,
val_x=valid_set,
Expand Down
28 changes: 16 additions & 12 deletions src/monk_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,17 @@
from model_selection import cross_valid

if __name__ == '__main__':
parameters = {
'input_dim': 17,
'units_per_layer': (4, 1),
'acts': ('leaky_relu', 'tanh'),
'init_type': 'random',
'weights_value': 0.2,
'lower_lim': 0.01,
'upper_lim': 0.2
}
model = Network(**parameters)

# read the dataset
col_names = ['class', 'a1', 'a2', 'a3', 'a4', 'a5', 'a6', 'Id']
monk1_train = pd.read_csv("../datasets/monks/monks-1.train", sep=' ', names=col_names)
Expand All @@ -17,35 +28,28 @@

# transform labels from pandas dataframe to numpy ndarray
labels = labels.to_numpy()[:, np.newaxis]
labels[labels == 0] = -1

# shuffle the whole dataset once
indexes = list(range(len(monk1_train)))
np.random.shuffle(indexes)
monk1_train = monk1_train[indexes]
labels = labels[indexes]

parameters = {
'input_dim': 17,
'units_per_layer': (4, 1),
'acts': ('leaky_relu', 'tanh'),
'init_type': 'random',
'weights_value': 0.2,
'lower_lim': 0.01,
'upper_lim': 0.1
}
model = Network(**parameters)
tr_error_values, tr_metric_values, val_error_values, val_metric_values = cross_valid(net=model,
inputs=monk1_train,
targets=labels,
epochs=400,
batch_size=15,
epochs=1000,
batch_size=len(monk1_train),
k_folds=4)
# plot learning curve
figure, ax = plt.subplots(1, 2, figsize=(12, 4))
ax[0].plot(range(len(tr_error_values)), tr_error_values, val_error_values)
ax[0].set_xlabel('Epochs', fontweight='bold')
ax[0].set_ylabel('loss', fontweight='bold')
ax[0].grid()
ax[1].plot(range(len(tr_metric_values)), tr_metric_values, val_metric_values)
ax[1].set_xlabel('Epochs', fontweight='bold')
ax[1].set_ylabel('accuracy', fontweight='bold')
ax[1].grid()
plt.show()

0 comments on commit 6ad7942

Please sign in to comment.