This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
/
Context.scala
83 lines (69 loc) · 2.41 KB
/
Context.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.mxnet
object Context {
  // Defining an implicit conversion is a feature-flagged construct; without this
  // import the compiler emits a feature warning for ctx2Array below.
  import scala.language.implicitConversions

  /** Maps a numeric device type id to its device type name. */
  val devtype2str: Map[Int, String] = Map(1 -> "cpu", 2 -> "gpu", 3 -> "cpu_pinned")

  /**
   * Inverse of [[devtype2str]]: maps a device type name to its numeric id.
   * Derived from devtype2str so the two tables can never disagree.
   */
  val devstr2type: Map[String, Int] = devtype2str.map(_.swap)

  // Process-wide default context. It must stay below devstr2type, because the
  // Context constructor reads devstr2type. It is a plain unsynchronized var that
  // Context#withScope swaps in and out, so withScope is not safe to use
  // concurrently from multiple threads.
  private var _defaultCtx = new Context("cpu", 0)

  /** The current default context: cpu(0) unless overridden by Context#withScope. */
  def defaultCtx: Context = _defaultCtx

  /** Returns a CPU context with the given device id (default 0). */
  def cpu(deviceId: Int = 0): Context = new Context("cpu", deviceId)

  /** Returns a GPU context with the given device id (default 0). */
  def gpu(deviceId: Int = 0): Context = new Context("gpu", deviceId)

  /** Lets a single Context be passed wherever an Array[Context] is expected. */
  implicit def ctx2Array(ctx: Context): Array[Context] = Array(ctx)
}

/**
 * Constructing a context.
 * @param deviceTypeName one of "cpu", "gpu" or "cpu_pinned" (the keys of Context.devstr2type)
 * @param deviceId (default=0) The device id of the device, needed for GPU
 * @throws NoSuchElementException if deviceTypeName is not a known device type
 */
class Context(deviceTypeName: String, val deviceId: Int = 0) extends Serializable {
  /** Numeric device type id corresponding to deviceTypeName. */
  val deviceTypeid: Int = Context.devstr2type.getOrElse(deviceTypeName,
    // Same exception type as a plain Map lookup, but with an actionable message.
    throw new NoSuchElementException(
      s"Unknown device type '$deviceTypeName'; expected one of: " +
        Context.devstr2type.keys.mkString(", ")))

  /** Creates a copy of an existing context (same device type and id). */
  def this(context: Context) = this(context.deviceType, context.deviceId)

  /**
   * Evaluates `body` with this context installed as Context.defaultCtx, then
   * restores the previous default, including when `body` throws.
   * @param body the computation to run under this default context
   * @return the value produced by `body`
   */
  def withScope[T](body: => T): T = {
    val oldDefaultCtx = Context.defaultCtx
    Context._defaultCtx = this
    try body
    finally Context._defaultCtx = oldDefaultCtx
  }

  /**
   * Return device type of current context.
   * @return the device type name, e.g. "cpu", "gpu" or "cpu_pinned"
   */
  def deviceType: String = Context.devtype2str(deviceTypeid)

  override def toString: String = s"$deviceType($deviceId)"

  /** Two contexts are equal when both their device type id and device id match. */
  override def equals(other: Any): Boolean = other match {
    case that: Context => that.deviceId == deviceId && that.deviceTypeid == deviceTypeid
    case _ => false
  }

  // Consistent with equals: toString is fully determined by (deviceTypeid, deviceId),
  // because devtype2str maps each id to a distinct name.
  override def hashCode: Int = toString.hashCode
}