forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
observer_config.h
99 lines (87 loc) · 3.27 KB
/
observer_config.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
#pragma once
#include "observers/macros.h"
#include "observers/net_observer_reporter.h"
#include "caffe2/core/common.h"
namespace caffe2 {
/*
netInitSampleRate_ == 1 && operatorNetSampleRatio_ == 1 :
Log operator metrics in every iteration
netInitSampleRate_ == 1 && operatorNetSampleRatio_ == 0 :
Log net metrics in every iteration
netInitSampleRate_ == n && netFollowupSampleRate_ == m &&
netFollowupSampleCount == c && operatorNetSampleRatio_ == 1 :
Log operator metrics first at odds of 1 / n. Once first logged,
the following c logs are at odds of 1 / min(n, m). Then repeat
netInitSampleRate_ == n && netFollowupSampleRate_ == m &&
netFollowupSampleCount == c && operatorNetSampleRatio_ == 0 :
Log net metrics first at odds of 1 / n. Once first logged,
the following c logs are at odds of 1 / min(n, m). Then repeat
netInitSampleRate_ == n && netFollowupSampleRate_ == m &&
netFollowupSampleCount == c && operatorNetSampleRatio_ == o :
Log net metrics first at odds of 1 / n. Once first logged,
the following c logs are at odds of 1 / min(n, m). If the random number
is a multiple of o, log operator metrics instead. Then repeat
skipIters_ == n: skip the first n iterations of the net.
*/
class CAFFE2_OBSERVER_API ObserverConfig {
public:
static void initSampleRate(
int netInitSampleRate,
int netFollowupSampleRate,
int netFollowupSampleCount,
int operatorNetSampleRatio,
int skipIters) {
CAFFE_ENFORCE(netFollowupSampleRate <= netInitSampleRate);
CAFFE_ENFORCE(netFollowupSampleRate >= 1 || netInitSampleRate == 0);
netInitSampleRate_ = netInitSampleRate;
netFollowupSampleRate_ = netFollowupSampleRate;
netFollowupSampleCount_ = netFollowupSampleCount;
operatorNetSampleRatio_ = operatorNetSampleRatio;
skipIters_ = skipIters;
}
static int getNetInitSampleRate() {
return netInitSampleRate_;
}
static int getNetFollowupSampleRate() {
return netFollowupSampleRate_;
}
static int getNetFollowupSampleCount() {
return netFollowupSampleCount_;
}
static int getOpoeratorNetSampleRatio() {
return operatorNetSampleRatio_;
}
static int getSkipIters() {
return skipIters_;
}
static void setReporter(unique_ptr<NetObserverReporter> reporter) {
reporter_ = std::move(reporter);
}
static NetObserverReporter* getReporter() {
CAFFE_ENFORCE(reporter_);
return reporter_.get();
}
static void setMarker(int marker) {
marker_ = marker;
}
static int getMarker() {
return marker_;
}
private:
/* The odds of log net metric initially or immediately after reset */
static int netInitSampleRate_;
/* The odds of log net metric after log once after start of reset */
static int netFollowupSampleRate_;
/* The number of follow up logs to be collected for odds of
netFollowupSampleRate_ */
static int netFollowupSampleCount_;
/* The odds to log the operator metric instead of the net metric.
When the operator is logged the net is not logged. */
static int operatorNetSampleRatio_;
/* skip the first few iterations */
static int skipIters_;
static unique_ptr<NetObserverReporter> reporter_;
/* marker used in identifying the metrics in certain reporters */
static int marker_;
};
}