-
Notifications
You must be signed in to change notification settings - Fork 144
/
nvidia.go
152 lines (128 loc) · 3.19 KB
/
nvidia.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
package nvidia
import (
"fmt"
"strings"
log "github.com/golang/glog"
"github.com/NVIDIA/gpu-monitoring-tools/bindings/go/nvml"
"golang.org/x/net/context"
pluginapi "k8s.io/kubernetes/pkg/kubelet/apis/deviceplugin/v1beta1"
)
var (
// gpuMemory holds the per-GPU memory value. It is written by setGPUMemory
// and read through getGPUMemory; zero means it has not been set yet.
gpuMemory uint
// metric is the memory unit in effect. setGPUMemory compares it with
// GiBPrefix to decide whether the raw value is divided by 1024.
metric MemoryUnit
)
// check is a fail-fast helper: a nil err is a no-op, any other value is
// reported through log.Fatalln, which terminates the process.
func check(err error) {
	if err == nil {
		return
	}
	log.Fatalln("Fatal:", err)
}
// generateFakeDeviceID builds a fake device ID of the form
// "<realID>-_-<fakeCounter>", with the counter rendered in decimal.
// extractRealDeviceID reverses the mapping.
func generateFakeDeviceID(realID string, fakeCounter uint) string {
	suffix := fmt.Sprint(fakeCounter)
	return realID + "-_-" + suffix
}
// extractRealDeviceID returns the part of fakeDeviceID that precedes the first
// "-_-" separator. An ID without the separator is returned unchanged.
func extractRealDeviceID(fakeDeviceID string) string {
	if cut := strings.Index(fakeDeviceID, "-_-"); cut >= 0 {
		return fakeDeviceID[:cut]
	}
	return fakeDeviceID
}
// setGPUMemory stores raw in the package-level gpuMemory variable and logs the
// stored value. When metric equals GiBPrefix the value is integer-divided by
// 1024 first (presumably converting MiB to GiB; confirm against the caller).
func setGPUMemory(raw uint) {
	switch metric {
	case GiBPrefix:
		gpuMemory = raw / 1024
	default:
		gpuMemory = raw
	}
	log.Infof("set gpu memory: %d", gpuMemory)
}
// getGPUMemory returns the current value of the package-level gpuMemory
// variable, i.e. whatever setGPUMemory last stored (zero if it was never set).
func getGPUMemory() uint {
return gpuMemory
}
// getDeviceCount returns the device count reported by nvml.GetDeviceCount.
// A failure of that call is passed to check, which terminates the process.
func getDeviceCount() uint {
	count, err := nvml.GetDeviceCount()
	check(err)
	return count
}
// getDevices enumerates the GPUs reported by NVML and expands each of them
// into a series of fake plugin devices, one per unit of GPU memory.
//
// It returns:
//   - the fake devices, all marked pluginapi.Healthy, whose IDs are produced by
//     generateFakeDeviceID from the GPU UUID and a counter in
//     [0, getGPUMemory());
//   - a map from each GPU UUID to the number parsed from its
//     "/dev/nvidia<N>" device path.
//
// While the package-level GPU memory value is still zero it is set from the
// current device through setGPUMemory; once non-zero, the same value is reused
// for every later device. Any NVML failure, an unparsable device path, or a
// device that does not report its memory terminates the process.
func getDevices() ([]*pluginapi.Device, map[string]uint) {
	n := getDeviceCount()

	var devs []*pluginapi.Device
	realDevNames := map[string]uint{}
	for i := uint(0); i < n; i++ {
		d, err := nvml.NewDevice(i)
		check(err)

		var id uint
		log.Infof("Deivce %s's Path is %s", d.UUID, d.Path)
		_, err = fmt.Sscanf(d.Path, "/dev/nvidia%d", &id)
		check(err)
		realDevNames[d.UUID] = id

		// d.Memory is a pointer: fail with an explicit message rather than a
		// nil-pointer panic when the device does not report it.
		if d.Memory == nil {
			log.Fatalf("Fatal: device %s does not report its memory size", d.UUID)
		}
		mem := uint(*d.Memory)
		log.Infof("# device Memory: %d", mem)
		if getGPUMemory() == 0 {
			setGPUMemory(mem)
		}

		// The value cannot change during the inner loop, so read it once.
		units := getGPUMemory()
		for j := uint(0); j < units; j++ {
			fakeID := generateFakeDeviceID(d.UUID, j)
			if j == 0 {
				log.Infoln("# Add first device ID: " + fakeID)
			}
			if j == units-1 {
				log.Infoln("# Add last device ID: " + fakeID)
			}
			devs = append(devs, &pluginapi.Device{
				ID:     fakeID,
				Health: pluginapi.Healthy,
			})
		}
	}
	return devs, realDevNames
}
// deviceExists reports whether any element of devs has an ID equal to id.
// It returns false for an empty or nil slice.
func deviceExists(devs []*pluginapi.Device, id string) bool {
	for i := range devs {
		if devs[i].ID == id {
			return true
		}
	}
	return false
}
// watchXIDs registers for XID critical-error events on the real GPU behind
// every device in devs and then loops until ctx is cancelled, sending each
// device that should be considered unhealthy on xids.
//
// Parameters:
//   - ctx:  cancelling it makes the function return; it is checked before
//     every wait.
//   - devs: the (fake) devices to watch; the real GPU UUID is recovered from
//     each ID with extractRealDeviceID.
//   - xids: receives unhealthy devices. Sends block until the receiver takes
//     the value.
//
// A registration error ending in "Not Supported" marks the device unhealthy
// immediately; any other registration error terminates the process.
func watchXIDs(ctx context.Context, devs []*pluginapi.Device, xids chan<- *pluginapi.Device) {
	eventSet := nvml.NewEventSet()
	defer nvml.DeleteEventSet(eventSet)

	// Several fake devices map to the same real GPU, so register each real GPU
	// only once and reuse the outcome for its remaining fake devices.
	registered := make(map[string]error)
	for _, d := range devs {
		realDeviceID := extractRealDeviceID(d.ID)
		regErr, seen := registered[realDeviceID]
		if !seen {
			regErr = nvml.RegisterEventForDevice(eventSet, nvml.XidCriticalError, realDeviceID)
			registered[realDeviceID] = regErr
		}
		if regErr != nil && strings.HasSuffix(regErr.Error(), "Not Supported") {
			log.Infof("Warning: %s (%s) is too old to support healthchecking: %s. Marking it unhealthy.", realDeviceID, d.ID, regErr)
			xids <- d
			continue
		}
		if regErr != nil {
			// Fatalf needs a verb for the error; without one the error is
			// rendered as "%!(EXTRA ...)".
			log.Fatalf("Fatal error: %v", regErr)
		}
	}

	for {
		select {
		case <-ctx.Done():
			return
		default:
		}

		// 5000 is the wait timeout handed to NVML (presumably milliseconds;
		// confirm against the binding).
		e, err := nvml.WaitForEvent(eventSet, 5000)
		// Ignore failed waits (including timeouts) as well as any event that
		// is not an XID critical error.
		if err != nil || e.Etype != nvml.XidCriticalError {
			continue
		}

		// FIXME: formalize the full list and document it.
		// http://docs.nvidia.com/deploy/xid-errors/index.html#topic_4
		// Application errors: the GPU should still be healthy
		if e.Edata == 31 || e.Edata == 43 || e.Edata == 45 {
			continue
		}

		if e.UUID == nil || len(*e.UUID) == 0 {
			// The event is not tied to a specific GPU: treat all devices as
			// unhealthy.
			for _, d := range devs {
				xids <- d
			}
			continue
		}

		// Report every fake device backed by the GPU that raised the event.
		for _, d := range devs {
			if extractRealDeviceID(d.ID) == *e.UUID {
				xids <- d
			}
		}
	}
}