/
CountingCheckpointPolicy.cs
60 lines (53 loc) · 1.96 KB
/
CountingCheckpointPolicy.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
// --------------------------------------------------------------------------------------------------------------------
// <copyright file="CountingCheckpointPolicy.cs">
// Copyright (c) by respective owners including Yahoo!, Microsoft, and
// individual contributors. All rights reserved. Released under a BSD
// license as described in the file LICENSE.
// </copyright>
// --------------------------------------------------------------------------------------------------------------------
namespace VW.Azure.Trainer.Checkpoint
{
/// <summary>
/// Implements an example-count-based checkpoint policy: a checkpoint is requested
/// each time the running example count reaches the configured threshold.
/// </summary>
public class CountingCheckpointPolicy : ICheckpointPolicy
{
    /// <summary>
    /// Running-count threshold at which a checkpoint is requested.
    /// </summary>
    private readonly int exampleSyncCount;

    /// <summary>
    /// Examples accumulated since the last requested checkpoint or the last <see cref="Reset"/>.
    /// </summary>
    private int exampleCount;

    /// <summary>
    /// Initializes a new <see cref="CountingCheckpointPolicy"/> instance.
    /// </summary>
    /// <param name="exampleSyncCount">
    /// Number of accumulated examples that triggers a checkpoint. A value of zero
    /// requests a checkpoint on every call to <see cref="ShouldCheckpointAfterExample"/>.
    /// </param>
    public CountingCheckpointPolicy(int exampleSyncCount)
    {
        this.exampleSyncCount = exampleSyncCount;
    }

    /// <summary>
    /// Return true if the trainer should checkpoint the model, false otherwise.
    /// </summary>
    /// <param name="examples">Number of examples to add to the running count.</param>
    /// <returns>
    /// True when the running count has reached the threshold; the count is then reduced
    /// to the remainder so that surplus examples carry over to the next interval.
    /// </returns>
    public bool ShouldCheckpointAfterExample(int examples)
    {
        this.exampleCount += examples;
        if (this.exampleCount < this.exampleSyncCount)
        {
            return false;
        }

        if (this.exampleSyncCount == 0)
        {
            // Integer remainder by zero throws DivideByZeroException. With a zero
            // threshold every call checkpoints, so there is nothing to carry over.
            this.exampleCount = 0;
        }
        else
        {
            // Keep the overshoot so it counts towards the next checkpoint interval.
            this.exampleCount %= this.exampleSyncCount;
        }

        return true;
    }

    /// <summary>
    /// Reset checkpoint policy state by clearing the running example count.
    /// </summary>
    public void Reset()
    {
        this.exampleCount = 0;
    }

    /// <summary>
    /// Serialize to string for logging.
    /// </summary>
    /// <returns>The policy name followed by the configured threshold.</returns>
    public override string ToString()
    {
        return $"CountingCheckpointPolicy: {this.exampleSyncCount}";
    }
}
}