-
Notifications
You must be signed in to change notification settings - Fork 0
/
Program.cs
44 lines (34 loc) · 1.32 KB
/
Program.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
namespace PRSGD
{
    public class Program
    {
        /// <summary>
        /// Demo entry point: builds one noisy <c>QuadraticLoss</c> per worker, wraps them
        /// in an <c>AverageLoss</c>, runs PR-SGD on that averaged objective via
        /// <c>Core.RunPRSGD</c>, and finally calls <c>Core.ShowLoss</c>.
        /// </summary>
        /// <returns>A task that completes once the PR-SGD run has finished and the loss has been shown.</returns>
        public static async Task Main()
        {
            // Problem size.
            const int dim = 256;
            const int numWorkers = 4;

            // Per-worker objectives. Worker w's centers are (k + w*w) for k in [0, dim).
            // A single scratch buffer is reused for every worker; this relies on
            // QuadraticLoss cloning the array in its constructor (as the original
            // code notes — not verifiable from this file).
            var workerLosses = new LossFunction[numWorkers];
            var buffer = new float[dim];
            for (int worker = 0; worker < workerLosses.Length; worker++)
            {
                int offset = worker * worker;
                for (int k = 0; k < buffer.Length; k++)
                {
                    buffer[k] = k + offset;
                }

                workerLosses[worker] = new QuadraticLoss(dim, buffer, noisyGrads: true);
            }

            var averagedLoss = new AverageLoss(workerLosses);

            // PR-SGD hyperparameters: every worker performs the same number of local steps.
            const int numIterations = 50;
            const int localStepsPerWorker = 100;
            const float learningRate = 0.01f;

            var numLocalSteps = new int[numWorkers];
            for (int worker = 0; worker < numLocalSteps.Length; worker++)
            {
                numLocalSteps[worker] = localStepsPerWorker;
            }

            // Run PR-SGD on the averaged objective, then report the loss.
            var core = new Core(numWorkers, averagedLoss);
            await core.RunPRSGD(numIterations, numLocalSteps, learningRate);
            core.ShowLoss();
        }
    }
}