-
Notifications
You must be signed in to change notification settings - Fork 18
/
Memory.scala
336 lines (232 loc) · 11.1 KB
/
Memory.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
package com.thoughtworks.compute
import java.nio._
import org.lwjgl.PointerBuffer
import org.lwjgl.system.{CustomBuffer, MemoryUtil, Pointer}
import scala.reflect.ClassTag
/**
 * Type class describing how elements of type `Element` are laid out in a host-side buffer.
 *
 * @author 杨博 (Yang Bo) <pop.atry@gmail.com>
 * @tparam Element the element type stored in the buffer
 */
trait Memory[Element] {

  /** The concrete buffer type that holds elements of type `Element`. */
  type HostBuffer

  /** Views the bytes of `byteBuffer` as a [[HostBuffer]]. */
  def fromByteBuffer(byteBuffer: ByteBuffer): HostBuffer

  /** The size, in bytes, of a single element. */
  def numberOfBytesPerElement: Int

  /** Returns the memory address of `buffer` as a `Long`. */
  def address(buffer: HostBuffer): Long

  /** Returns the number of elements remaining in `buffer`. */
  def remaining(buffer: HostBuffer): Int

  /**
   * Returns the number of bytes occupied by the remaining elements of `buffer`.
   *
   * @throws java.lang.ArithmeticException if the byte count does not fit in an `Int`
   */
  def remainingBytes(buffer: HostBuffer): Int =
    // multiplyExact: a large buffer of wide elements would otherwise silently wrap to a bogus (possibly negative) size.
    Math.multiplyExact(numberOfBytesPerElement, remaining(buffer))

  /** Reads the element at `index` of `buffer`. */
  def get(buffer: HostBuffer, index: Int): Element

  /** Writes `value` at `index` of `buffer`. */
  def put(buffer: HostBuffer, index: Int, value: Element): Unit

  /** Allocates a buffer with room for `numberOfElement` elements; release it with [[free]]. */
  def allocate(numberOfElement: Int): HostBuffer

  /** Releases a buffer obtained from [[allocate]]. */
  def free(buffer: HostBuffer): Unit

  /** Copies the remaining elements of `buffer` into a new array. */
  def toArray(buffer: HostBuffer): Array[Element]
}
object Memory extends LowPriorityMemory {

  /** Summons the implicit [[Memory]] for `Element`, preserving its precise singleton type. */
  def apply[Element](implicit memory: Memory[Element]): memory.type = memory

  /** A [[Memory]] whose `HostBuffer` type member is fixed to `HostBuffer0`. */
  type Aux[Element, HostBuffer0] = Memory[Element] {
    type HostBuffer = HostBuffer0
  }

  /** Base for [[Memory]] instances whose host buffer is a `java.nio.Buffer`. */
  trait NioMemory[Element] extends Memory[Element] {
    type HostBuffer <: java.nio.Buffer
    override def remaining(buffer: HostBuffer): Int = buffer.remaining
  }

  /** Base for [[Memory]] instances whose host buffer is an LWJGL `CustomBuffer`. */
  trait CustomMemory[Element] extends Memory[Element] {
    type HostBuffer <: CustomBuffer[HostBuffer]
    override def remaining(buffer: HostBuffer): Int = buffer.remaining
    override def address(buffer: HostBuffer): Long = buffer.address
  }

  /** [[Memory]] for native pointers, stored in an LWJGL `PointerBuffer`. */
  implicit object PointerMemory extends CustomMemory[Pointer] {
    override type HostBuffer = PointerBuffer

    /**
     * Wraps a raw address as a [[Pointer]]. The trailing `{}` creates an anonymous subclass of
     * `Pointer.Default` (presumably because it cannot be instantiated directly — confirm against
     * the LWJGL version in use).
     */
    private def toPointer(rawAddress: Long): Pointer = new Pointer.Default(rawAddress) {}

    override def numberOfBytesPerElement: Int = Pointer.POINTER_SIZE
    override def fromByteBuffer(byteBuffer: ByteBuffer): PointerBuffer = {
      PointerBuffer.create(byteBuffer)
    }
    override def get(buffer: PointerBuffer, index: Int): Pointer = toPointer(buffer.get(index))
    override def put(buffer: PointerBuffer, index: Int, value: Pointer): Unit = buffer.put(index, value)
    override def allocate(numberOfElement: Int): PointerBuffer = MemoryUtil.memAllocPointer(numberOfElement)
    override def free(buffer: PointerBuffer): Unit = MemoryUtil.memFree(buffer)
    override def toArray(buffer: PointerBuffer): Array[Pointer] = {
      // Read through a duplicate so the caller's buffer position is never modified.
      val view = buffer.duplicate()
      val addresses = Array.ofDim[Long](view.remaining())
      view.get(addresses, 0, addresses.length)
      addresses.map(toPointer)
    }
  }

  // Disabled HNil instance, kept for reference.
  //  implicit object HNilMemory extends NioMemory[HNil] {
  //    override type HostBuffer = ByteBuffer
  //
  //    override def fromByteBuffer(byteBuffer: ByteBuffer): ByteBuffer = byteBuffer
  //
  //    override def numberOfBytesPerElement: Int = 0
  //
  //    override def address(buffer: ByteBuffer): Long = MemoryUtil.memAddress(buffer)
  //
  //    override def free(buffer: ByteBuffer): Unit = MemoryUtil.memFree(buffer)
  //
  //    override def get(buffer: ByteBuffer, index: Int): HNil = HNil
  //
  //    override def put(buffer: ByteBuffer, index: Int, value: HNil): Unit = {}
  //
  //    override def allocate(numberOfElement: Int): ByteBuffer = MemoryUtil.memAlloc(1)
  //
  //  }

  /** [[Memory]] for `Int`, stored in an `IntBuffer`; `get`/`put` use absolute indices. */
  implicit object IntMemory extends NioMemory[Int] {
    override type HostBuffer = IntBuffer
    override def fromByteBuffer(byteBuffer: ByteBuffer): IntBuffer = byteBuffer.asIntBuffer
    override def numberOfBytesPerElement: Int = java.lang.Integer.BYTES
    override def address(buffer: IntBuffer): Long = MemoryUtil.memAddress(buffer)
    override def allocate(numberOfElement: Int): IntBuffer = MemoryUtil.memAllocInt(numberOfElement)
    override def free(buffer: IntBuffer): Unit = MemoryUtil.memFree(buffer)
    override def get(buffer: IntBuffer, index: Int): Int = buffer.get(index)
    override def put(buffer: IntBuffer, index: Int, value: Int): Unit = buffer.put(index, value)
    override def toArray(buffer: IntBuffer): Array[Int] = {
      // A duplicate shares content but has its own position, so the caller's buffer is untouched.
      val view = buffer.duplicate()
      val result = Array.ofDim[Int](view.remaining())
      view.get(result, 0, result.length)
      result
    }
  }

  /** [[Memory]] for `Long`, stored in a `LongBuffer`; `get`/`put` use absolute indices. */
  implicit object LongMemory extends NioMemory[Long] {
    override type HostBuffer = LongBuffer
    override def fromByteBuffer(byteBuffer: ByteBuffer): LongBuffer = byteBuffer.asLongBuffer
    override def numberOfBytesPerElement: Int = java.lang.Long.BYTES
    override def address(buffer: LongBuffer): Long = MemoryUtil.memAddress(buffer)
    override def allocate(numberOfElement: Int): LongBuffer = MemoryUtil.memAllocLong(numberOfElement)
    override def free(buffer: LongBuffer): Unit = MemoryUtil.memFree(buffer)
    override def get(buffer: LongBuffer, index: Int): Long = buffer.get(index)
    override def put(buffer: LongBuffer, index: Int, value: Long): Unit = buffer.put(index, value)
    override def toArray(buffer: LongBuffer): Array[Long] = {
      // A duplicate shares content but has its own position, so the caller's buffer is untouched.
      val view = buffer.duplicate()
      val result = Array.ofDim[Long](view.remaining())
      view.get(result, 0, result.length)
      result
    }
  }

  /** [[Memory]] for `Double`, stored in a `DoubleBuffer`; `get`/`put` use absolute indices. */
  implicit object DoubleMemory extends NioMemory[Double] {
    override type HostBuffer = DoubleBuffer
    override def fromByteBuffer(byteBuffer: ByteBuffer): DoubleBuffer = byteBuffer.asDoubleBuffer
    override def numberOfBytesPerElement: Int = java.lang.Double.BYTES
    override def address(buffer: DoubleBuffer): Long = MemoryUtil.memAddress(buffer)
    override def allocate(numberOfElement: Int): DoubleBuffer = MemoryUtil.memAllocDouble(numberOfElement)
    override def free(buffer: DoubleBuffer): Unit = MemoryUtil.memFree(buffer)
    override def get(buffer: DoubleBuffer, index: Int): Double = buffer.get(index)
    override def put(buffer: DoubleBuffer, index: Int, value: Double): Unit = buffer.put(index, value)
    override def toArray(buffer: DoubleBuffer): Array[Double] = {
      // A duplicate shares content but has its own position, so the caller's buffer is untouched.
      val view = buffer.duplicate()
      val result = Array.ofDim[Double](view.remaining())
      view.get(result, 0, result.length)
      result
    }
  }

  /** [[Memory]] for `Float`, stored in a `FloatBuffer`; `get`/`put` use absolute indices. */
  implicit object FloatMemory extends NioMemory[Float] {
    override type HostBuffer = FloatBuffer
    override def fromByteBuffer(byteBuffer: ByteBuffer): FloatBuffer = byteBuffer.asFloatBuffer
    override def numberOfBytesPerElement: Int = java.lang.Float.BYTES
    override def address(buffer: FloatBuffer): Long = MemoryUtil.memAddress(buffer)
    override def allocate(numberOfElement: Int): FloatBuffer = MemoryUtil.memAllocFloat(numberOfElement)
    override def free(buffer: FloatBuffer): Unit = MemoryUtil.memFree(buffer)
    override def get(buffer: FloatBuffer, index: Int): Float = buffer.get(index)
    override def put(buffer: FloatBuffer, index: Int, value: Float): Unit = buffer.put(index, value)
    override def toArray(buffer: FloatBuffer): Array[Float] = {
      // A duplicate shares content but has its own position, so the caller's buffer is untouched.
      val view = buffer.duplicate()
      val result = Array.ofDim[Float](view.remaining())
      view.get(result, 0, result.length)
      result
    }
  }

  /** [[Memory]] for `Byte`, stored directly in a `ByteBuffer`; `get`/`put` use absolute indices. */
  implicit object ByteMemory extends NioMemory[Byte] {
    override type HostBuffer = ByteBuffer
    override def fromByteBuffer(byteBuffer: ByteBuffer): ByteBuffer = byteBuffer
    override def numberOfBytesPerElement: Int = java.lang.Byte.BYTES
    override def address(buffer: ByteBuffer): Long = MemoryUtil.memAddress(buffer)
    override def allocate(numberOfElement: Int): ByteBuffer = MemoryUtil.memAlloc(numberOfElement)
    override def free(buffer: ByteBuffer): Unit = MemoryUtil.memFree(buffer)
    override def get(buffer: ByteBuffer, index: Int): Byte = buffer.get(index)
    override def put(buffer: ByteBuffer, index: Int, value: Byte): Unit = buffer.put(index, value)
    override def toArray(buffer: ByteBuffer): Array[Byte] = {
      // A duplicate shares content but has its own position, so the caller's buffer is untouched.
      val view = buffer.duplicate()
      val result = Array.ofDim[Byte](view.remaining())
      view.get(result, 0, result.length)
      result
    }
  }

  /** [[Memory]] for `Short`, stored in a `ShortBuffer`; `get`/`put` use absolute indices. */
  implicit object ShortMemory extends NioMemory[Short] {
    override type HostBuffer = ShortBuffer
    override def fromByteBuffer(byteBuffer: ByteBuffer): ShortBuffer = byteBuffer.asShortBuffer()
    override def numberOfBytesPerElement: Int = java.lang.Short.BYTES
    override def address(buffer: ShortBuffer): Long = MemoryUtil.memAddress(buffer)
    override def allocate(numberOfElement: Int): ShortBuffer = MemoryUtil.memAllocShort(numberOfElement)
    override def free(buffer: ShortBuffer): Unit = MemoryUtil.memFree(buffer)
    override def get(buffer: ShortBuffer, index: Int): Short = buffer.get(index)
    override def put(buffer: ShortBuffer, index: Int, value: Short): Unit = buffer.put(index, value)
    override def toArray(buffer: ShortBuffer): Array[Short] = {
      // A duplicate shares content but has its own position, so the caller's buffer is untouched.
      val view = buffer.duplicate()
      val result = Array.ofDim[Short](view.remaining())
      view.get(result, 0, result.length)
      result
    }
  }

  // TODO: Boolean and Char instances are not implemented yet.

  /** Converts between a boxed element type `Boxed` and the raw type `Raw` actually stored in memory. */
  trait Box[Boxed] {
    type Raw
    def box(raw: Raw): Boxed
    def unbox(boxed: Boxed): Raw
  }

  object Box {

    /** A [[Box]] whose `Raw` type member is fixed to `Raw0`. */
    type Aux[Boxed, Raw0] = Box[Boxed] {
      type Raw = Raw0
    }
  }

  /**
   * [[Memory]] for `Boxed` that delegates storage to the [[Memory]] of its raw representation,
   * boxing on reads and unboxing on writes.
   */
  final class BoxedMemory[Raw, Boxed, HostBuffer0](implicit box: Box.Aux[Boxed, Raw],
                                                   rawMemory: Memory.Aux[Raw, HostBuffer0],
                                                   classTag: ClassTag[Boxed])
      extends Memory[Boxed] {
    override type HostBuffer = HostBuffer0
    override def fromByteBuffer(byteBuffer: ByteBuffer): HostBuffer = {
      rawMemory.fromByteBuffer(byteBuffer)
    }
    override def numberOfBytesPerElement: Int = {
      rawMemory.numberOfBytesPerElement
    }
    override def remaining(buffer: HostBuffer): Int = {
      rawMemory.remaining(buffer)
    }
    override def get(buffer: HostBuffer, index: Int): Boxed = {
      box.box(rawMemory.get(buffer, index))
    }
    override def put(buffer: HostBuffer, index: Int, value: Boxed): Unit = {
      rawMemory.put(buffer, index, box.unbox(value))
    }
    override def address(buffer: HostBuffer): Long = {
      rawMemory.address(buffer)
    }
    override def allocate(numberOfElement: Int): HostBuffer0 = rawMemory.allocate(numberOfElement)
    override def free(buffer: HostBuffer0): Unit = rawMemory.free(buffer)
    override def toArray(buffer: HostBuffer0): Array[Boxed] = {
      // Copy the raw elements out first, then box each one into an array whose runtime
      // element class comes from `classTag`.
      val rawArray: Array[Raw] = rawMemory.toArray(buffer)
      val boxedArray = new Array[Boxed](rawArray.length)
      var i = 0
      while (i < rawArray.length) {
        boxedArray(i) = box.box(rawArray(i))
        i += 1
      }
      boxedArray
    }
  }
}
/**
 * Implicit [[Memory]] instances with lower priority than those defined directly in the
 * `Memory` companion. The companion extends this trait, so its own instances win whenever
 * both would apply.
 */
private[compute] trait LowPriorityMemory { self: Memory.type =>

  /**
   * Derives a [[Memory]] for a boxed type from the [[Memory]] of its raw representation.
   *
   * @param box       converts between `Boxed` and `Raw`
   * @param rawMemory memory layout for `Raw`, whose host buffer type is `HostBuffer0`
   * @param classTag  runtime class of `Boxed`, needed to build `Array[Boxed]`
   */
  implicit def boxedMemory[Raw, Boxed, HostBuffer0](implicit box: Box.Aux[Boxed, Raw],
                                                    rawMemory: Memory.Aux[Raw, HostBuffer0],
                                                    classTag: ClassTag[Boxed]): BoxedMemory[Raw, Boxed, HostBuffer0] =
    new BoxedMemory[Raw, Boxed, HostBuffer0]()(box, rawMemory, classTag)
}