Skip to content

Commit 68c2167

Browse files
committed
Add AWS Trainium device
1 parent 5df55b6 commit 68c2167

File tree

3 files changed

+5
-0
lines changed

3 files changed

+5
-0
lines changed

dpctl/tensor/_dlpack.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ cdef extern from "dlpack/dlpack.h" nogil:
4141
int device_WebGPU "kDLWebGPU"
4242
int device_Hexagon "kDLHexagon"
4343
int device_MAIA "kDLMAIA"
44+
int device_Trn "kDLTrn"
4445

4546
cpdef object to_dlpack_capsule(usm_ndarray array) except +
4647
cpdef object to_dlpack_versioned_capsule(

dpctl/tensor/_dlpack.pyx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ cdef extern from "dlpack/dlpack.h" nogil:
7676
kDLWebGPU
7777
kDLHexagon
7878
kDLMAIA
79+
kDLTrn
7980

8081
ctypedef struct DLDevice:
8182
DLDeviceType device_type

dpctl/tensor/_usmarray.pyx

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ class DLDeviceType(IntEnum):
8686
Qualcomm Hexagon DSP
8787
``kDLMAIA``:
8888
Microsoft MAIA device
89+
``kDLTrn``:
90+
AWS Trainium device
8991
"""
9092
kDLCPU = c_dlpack.device_CPU
9193
kDLCUDA = c_dlpack.device_CUDA
@@ -101,6 +103,7 @@ class DLDeviceType(IntEnum):
101103
kDLWebGPU = c_dlpack.device_WebGPU
102104
kDLHexagon = c_dlpack.device_Hexagon
103105
kDLMAIA = c_dlpack.device_MAIA
106+
kDLTrn = c_dlpack.device_Trn
104107

105108

106109
cdef class InternalUSMArrayError(Exception):

0 commit comments

Comments
 (0)