amp_dygraph.py
import time

import paddle
import paddle.nn as nn

start_time = None

def start_timer():
    # record the wall-clock start of a training run
    global start_time
    start_time = time.time()

def end_timer_and_print(msg):
    # report the elapsed time since start_timer() was called
    end_time = time.time()
    print("\n" + msg)
    print("total time cost = {:.3f} sec".format(end_time - start_time))
class SimpleNet(nn.Layer):
    # three Linear + ReLU stages; the example uses input_size == output_size,
    # so every layer can keep the same in/out width
    def __init__(self, input_size, output_size):
        super(SimpleNet, self).__init__()
        self.linear1 = nn.Linear(input_size, output_size)
        self.relu1 = nn.ReLU()
        self.linear2 = nn.Linear(input_size, output_size)
        self.relu2 = nn.ReLU()
        self.linear3 = nn.Linear(input_size, output_size)

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu1(x)
        x = self.linear2(x)
        x = self.relu2(x)
        x = self.linear3(x)
        return x
epochs = 5
input_size = 4096
output_size = 4096
batch_size = 512
nums_batch = 50

# synthetic data: random inputs paired with random regression targets
train_data = [paddle.randn((batch_size, input_size)) for _ in range(nums_batch)]
labels = [paddle.randn((batch_size, output_size)) for _ in range(nums_batch)]

mse = paddle.nn.MSELoss()

model = SimpleNet(input_size, output_size)
optimizer = paddle.optimizer.SGD(learning_rate=0.0001, parameters=model.parameters())
# baseline: train in the default float32 precision
start_timer()
for epoch in range(epochs):
    for i, (data, label) in enumerate(zip(train_data, labels)):
        output = model(data)
        loss = mse(output, label)
        loss.backward()
        optimizer.step()
        optimizer.clear_grad()
print(loss)
end_timer_and_print("using default mode:")
# rebuild the model and optimizer so the AMP run starts from scratch
model = SimpleNet(input_size, output_size)
optimizer = paddle.optimizer.SGD(learning_rate=0.0001, parameters=model.parameters())

# Step 1: define a GradScaler to scale the loss, avoiding float16 underflow/overflow
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
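
# A sketch of the dynamic-loss-scaling knobs GradScaler also accepts; the
# values below are illustrative assumptions, not tuned settings:
#
#   scaler = paddle.amp.GradScaler(
#       init_loss_scaling=1024,      # initial multiplier applied to the loss
#       incr_ratio=2.0,              # grow the scale while gradients stay finite
#       decr_ratio=0.5,              # shrink it after inf/nan gradients appear
#       incr_every_n_steps=1000,     # finite-gradient steps required to grow
#       decr_every_n_nan_or_inf=2,   # inf/nan events tolerated before shrinking
#       use_dynamic_loss_scaling=True,
#   )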
start_timer()
for epoch in range(epochs):
    for i, (data, label) in enumerate(zip(train_data, labels)):
        # Step 2: open an auto_cast context so ops run in auto mixed precision
        with paddle.amp.auto_cast():
            output = model(data)
            loss = mse(output, label)
        # Step 3: scale the loss with the GradScaler from Step 1, back-propagate
        # on the scaled loss, then let the scaler unscale the gradients and
        # apply the optimizer update
        scaled = scaler.scale(loss)
        scaled.backward()
        scaler.minimize(optimizer, scaled)
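        # Equivalent two-call form, assuming a Paddle release whose GradScaler
        # provides step()/update() (minimize() above does both in one call):
        #   scaler.step(optimizer)   # unscale gradients and run the optimizer
        #   scaler.update()          # adjust the loss scale for the next step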
        optimizer.clear_grad()
print(loss)
end_timer_and_print("using AMP mode:")
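
# A minimal sketch of level-O2 AMP (pure float16 parameters), assuming
# paddle.amp.decorate is available in the installed Paddle release; the
# setup mirrors the O1 run above and the choices are illustrative:
#
#   model = SimpleNet(input_size, output_size)
#   optimizer = paddle.optimizer.SGD(learning_rate=0.0001,
#                                    parameters=model.parameters())
#   model, optimizer = paddle.amp.decorate(models=model, optimizers=optimizer,
#                                          level='O2')
#   scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
#   with paddle.amp.auto_cast(level='O2'):
#       output = model(data)
#       loss = mse(output, label)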