In [1]:
import mxnet as mx
from mxnet import nd
from mxnet.gluon import nn

mx.cpu(), mx.gpu(), mx.gpu(1)

(cpu(0), gpu(0), gpu(1))

In [None]:
x = nd.array([1, 2, 3])
x

In [9]:
x.context

cpu(0)

In [10]:
a = nd.array([1, 2, 3], ctx=mx.gpu())
a


[1. 2. 3.]
<NDArray 3 @gpu(0)>

In [12]:
B = nd.random.uniform(shape=(2, 3), ctx=mx.gpu())
B


[[0.6686509  0.17409194 0.3850025 ]
 [0.24678314 0.35134333 0.8404298 ]]
<NDArray 2x3 @gpu(0)>

In [13]:
# 我们也可以通过copyto函数和as_in_context函数在设备之间传输数据。
# 下面我们将内存上的NDArray变量x复制到gpu(0)上。
# 如果源和目标的context一致，as_in_context不复制，而copyto总是会新建内存。
y = x.copyto(mx.gpu())
y


[1. 2. 3.]
<NDArray 3 @gpu(0)>

In [14]:
z = x.as_in_context(mx.gpu())
z


[1. 2. 3.]
<NDArray 3 @gpu(0)>

In [15]:
y.as_in_context(mx.gpu()) is y

True

In [16]:
y.copyto(mx.gpu()) is y

False

In [17]:
(z + 2).exp() * y
# MXNet要求计算的所有输入数据都在内存或同一块显卡的显存上。这样设计的原因是CPU和不同的GPU之间的数据交互通常比较耗时。
# 当我们打印NDArray或将NDArray转换成NumPy格式时，如果数据不在内存里，MXNet会将它先复制到内存，从而造成额外的传输开销。


[ 20.085537 109.1963   445.2395  ]
<NDArray 3 @gpu(0)>

In [18]:
net = nn.Sequential()
net.add(nn.Dense(1))
net.initialize(ctx=mx.gpu())
net(y)


[[0.0068339 ]
 [0.01366779]
 [0.02050169]]
<NDArray 3x1 @gpu(0)>

In [19]:
net[0].weight.data()


[[0.0068339]]
<NDArray 1x1 @gpu(0)>

In [8]:
import datetime
X = nd.random.normal(shape=(3000, 3000), ctx=mx.gpu())
Y = nd.random.normal(shape=(3000, 3000), ctx=mx.gpu())
start = datetime.datetime.now()
Z=X*Y
end = datetime.datetime.now()
(end - start).total_seconds()

0.0

In [9]:
A = nd.random.normal(shape=(3000, 3000))
B = nd.random.normal(shape=(3000, 3000))
start = datetime.datetime.now()
C=A*B
end = datetime.datetime.now()
(end - start).total_seconds()

0.000999

In [10]:
start = datetime.datetime.now()
print(Z)
end = datetime.datetime.now()
(end - start).total_seconds()


[[ 0.89465535 -1.9110221  -0.10733627 ... -0.1253494   0.04253551
   0.22577497]
 [ 0.5332856  -0.10255311  1.0588793  ... -1.2813464  -1.7994809
   0.4718116 ]
 [ 1.0075074   0.33341265 -0.56720525 ...  2.0934336  -0.06873859
   0.00940023]
 ...
 [ 1.7307513   0.5995128  -1.3888181  ... -0.03910211 -0.9913452
   0.04238579]
 [ 2.2259228   0.50668216 -0.5604384  ... -0.2725489  -1.4792203
   0.74461156]
 [-0.18553199  0.99963915 -1.2432638  ...  0.05341112  0.13708664
   0.3441057 ]]
<NDArray 3000x3000 @gpu(0)>


0.023164

In [11]:
start = datetime.datetime.now()
print(C)
end = datetime.datetime.now()
(end - start).total_seconds()


[[ 2.0642863e-01 -1.7820270e+00 -2.3931114e-01 ...  9.9502936e-02
  -1.2102258e+00  3.3463445e-01]
 [ 2.4327440e+00  9.3441612e-01  8.3713609e-01 ... -4.4053119e-01
  -3.3660698e-01 -9.5323630e-02]
 [-4.7700271e-01 -5.2013403e-01  6.5845318e-02 ...  4.0167645e-01
  -2.1521257e-01  4.8741961e-01]
 ...
 [ 1.0601908e+00  3.1669009e-01  5.7301050e-01 ...  1.7588601e+00
  -7.5194031e-02 -8.6234754e-01]
 [-2.4109283e-02 -3.0116704e-01  2.1263559e+00 ... -1.7738518e+00
   8.4608167e-01 -1.6735253e-01]
 [ 1.2022187e-03  3.6249746e-02  1.7142835e-01 ... -6.0034692e-01
   1.7511520e+00 -2.0559394e-01]]
<NDArray 3000x3000 @cpu(0)>


0.019946