diff --git a/hipamd/src/hip_graph.cpp b/hipamd/src/hip_graph.cpp index c467f18ea8..229523d6b8 100644 --- a/hipamd/src/hip_graph.cpp +++ b/hipamd/src/hip_graph.cpp @@ -51,7 +51,7 @@ inline hipError_t ihipGraphUpload(hipGraphExec_t graphExec, hipStream_t stream) inline hipError_t ihipGraphAddNode(hip::GraphNode* graphNode, hip::Graph* graph, hip::GraphNode* const* pDependencies, size_t numDependencies, - bool capture = true) { + bool capture = true, int devId = 0) { graph->AddNode(graphNode); std::unordered_set DuplicateDep; for (size_t i = 0; i < numDependencies; i++) { @@ -76,6 +76,9 @@ inline hipError_t ihipGraphAddNode(hip::GraphNode* graphNode, hip::Graph* graph, } } } + if (devId != 0) { + graphNode->SetDeviceId(devId); + } return hipSuccess; } @@ -107,16 +110,14 @@ hipError_t ihipGraphAddKernelNode(hip::GraphNode** pGraphNode, hip::Graph* graph } *pGraphNode = new hip::GraphKernelNode(pNodeParams, pNodeEvents, coopKernel); - if (devId != 0) { - (*pGraphNode)->SetDeviceId(devId); - } - status = ihipGraphAddNode(*pGraphNode, graph, pDependencies, numDependencies, capture); + status = ihipGraphAddNode(*pGraphNode, graph, pDependencies, numDependencies, capture, devId); return status; } hipError_t ihipGraphAddMemcpyNode(hip::GraphNode** pGraphNode, hip::Graph* graph, hip::GraphNode* const* pDependencies, size_t numDependencies, - const hipMemcpy3DParms* pCopyParams, bool capture = true) { + const hipMemcpy3DParms* pCopyParams, bool capture = true, + int devId = 0) { if (pGraphNode == nullptr || graph == nullptr || (numDependencies > 0 && pDependencies == nullptr) || pCopyParams == nullptr) { return hipErrorInvalidValue; @@ -133,7 +134,7 @@ hipError_t ihipGraphAddMemcpyNode(hip::GraphNode** pGraphNode, hip::Graph* graph hipError_t ihipDrvGraphAddMemcpyNode(hip::GraphNode** pGraphNode, hip::Graph* graph, hip::GraphNode* const* pDependencies, size_t numDependencies, const HIP_MEMCPY3D* pCopyParams, hipCtx_t ctx, - bool capture = true) { + bool capture = true, int devId = 0) { if (pGraphNode == nullptr || graph == nullptr || (numDependencies > 0 && pDependencies == nullptr) || pCopyParams == nullptr) { return hipErrorInvalidValue; @@ -150,7 +151,7 @@ hipError_t ihipDrvGraphAddMemcpyNode(hip::GraphNode** pGraphNode, hip::Graph* gr hipError_t ihipGraphAddMemcpyNode1D(hip::GraphNode** pGraphNode, hip::Graph* graph, hip::GraphNode* const* pDependencies, size_t numDependencies, void* dst, const void* src, size_t count, hipMemcpyKind kind, - bool capture = true) { + bool capture = true, int devId = 0) { if (pGraphNode == nullptr || graph == nullptr || (numDependencies > 0 && pDependencies == nullptr) || count ==0) { return hipErrorInvalidValue; @@ -167,7 +168,8 @@ hipError_t ihipGraphAddMemcpyNode1D(hip::GraphNode** pGraphNode, hip::Graph* gra hipError_t ihipGraphAddMemsetNode(hip::GraphNode** pGraphNode, hip::Graph* graph, hip::GraphNode* const* pDependencies, size_t numDependencies, const hipMemsetParams* pMemsetParams, bool capture = true, - size_t depth = 1, size_t arrWidth = 1, size_t arrHeight = 1) { + size_t depth = 1, size_t arrWidth = 1, size_t arrHeight = 1, + int devId = 0) { if (pGraphNode == nullptr || graph == nullptr || pMemsetParams == nullptr || (numDependencies > 0 && pDependencies == nullptr) || pMemsetParams->height == 0) { return hipErrorInvalidValue; @@ -204,7 +206,7 @@ hipError_t ihipGraphAddMemsetNode(hip::GraphNode** pGraphNode, hip::Graph* graph return status; } *pGraphNode = new hip::GraphMemsetNode(pMemsetParams, depth, arrWidth, arrHeight); - status = ihipGraphAddNode(*pGraphNode, graph, pDependencies, numDependencies, capture); + status = ihipGraphAddNode(*pGraphNode, graph, pDependencies, numDependencies, capture, devId); return status; } @@ -275,9 +277,9 @@ hipError_t ihipExtLaunchKernel(hipStream_t stream, hipFunction_t f, uint32_t glo nodeParams.kernelParams = kernelParams; nodeParams.sharedMemBytes = sharedMemBytes; - status = - ihipGraphAddKernelNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(), - s->GetLastCapturedNodes().size(), &nodeParams, &nodeEvents); + status = ihipGraphAddKernelNode( + &pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(), + s->GetLastCapturedNodes().size(), &nodeParams, &nodeEvents, capture, 0, s->DeviceId()); if (status != hipSuccess) { return status; @@ -454,7 +456,7 @@ hipError_t capturehipMemcpy3DAsync(hipStream_t& stream, const hipMemcpy3DParms*& hip::GraphNode* pGraphNode; hipError_t status = ihipGraphAddMemcpyNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(), - s->GetLastCapturedNodes().size(), p); + s->GetLastCapturedNodes().size(), p, true, s->DeviceId()); if (status != hipSuccess) { return status; } @@ -501,7 +503,7 @@ hipError_t capturehipMemcpy2DAsync(hipStream_t& stream, void*& dst, size_t& dpit hipError_t status = ihipGraphAddMemcpyNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(), - s->GetLastCapturedNodes().size(), &p); + s->GetLastCapturedNodes().size(), &p, true, s->DeviceId()); if (status != hipSuccess) { return status; } @@ -543,7 +545,7 @@ hipError_t capturehipMemcpy2DFromArrayAsync(hipStream_t& stream, void*& dst, siz p.extent = {width / hip::getElementSize(p.srcArray), height, 1}; hipError_t status = ihipGraphAddMemcpyNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(), - s->GetLastCapturedNodes().size(), &p); + s->GetLastCapturedNodes().size(), &p, true, s->DeviceId()); if (status != hipSuccess) { return status; } @@ -584,7 +586,7 @@ hipError_t capturehipMemcpy2DToArrayAsync(hipStream_t& stream, hipArray_t& dst, p.extent = {width / hip::getElementSize(p.dstArray), height, 1}; hipError_t status = ihipGraphAddMemcpyNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(), - s->GetLastCapturedNodes().size(), &p); + s->GetLastCapturedNodes().size(), &p, true, s->DeviceId()); if (status != hipSuccess) { return status; } @@ -658,7 +660,7 @@ hipError_t capturehipMemcpyParam2DAsync(hipStream_t& stream, const hip_Memcpy2D* } hipError_t status = ihipGraphAddMemcpyNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(), - s->GetLastCapturedNodes().size(), &p); + s->GetLastCapturedNodes().size(), &p, true, s->DeviceId()); if (status != hipSuccess) { return status; } @@ -687,7 +689,7 @@ hipError_t capturehipMemcpyAtoHAsync(hipStream_t& stream, void*& dstHost, hipArr p.kind = hipMemcpyDeviceToHost; hipError_t status = ihipGraphAddMemcpyNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(), - s->GetLastCapturedNodes().size(), &p); + s->GetLastCapturedNodes().size(), &p, true, s->DeviceId()); if (status != hipSuccess) { return status; } @@ -715,7 +717,7 @@ hipError_t capturehipMemcpyHtoAAsync(hipStream_t& stream, hipArray_t& dstArray, p.extent = {ByteCount / hip::getElementSize(p.dstArray), 1, 1}; hipError_t status = ihipGraphAddMemcpyNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(), - s->GetLastCapturedNodes().size(), &p); + s->GetLastCapturedNodes().size(), &p, true, s->DeviceId()); if (status != hipSuccess) { return status; } @@ -732,12 +734,9 @@ hipError_t capturehipMemcpy(hipStream_t stream, void* dst, const void* src, size std::vector pDependencies = s->GetLastCapturedNodes(); size_t numDependencies = s->GetLastCapturedNodes().size(); hip::Graph* graph = s->GetCaptureGraph(); - hipError_t status = ihipMemcpy_validate(dst, src, sizeBytes, kind); - if (status != hipSuccess) { - return status; - } hip::GraphNode* node = new hip::GraphMemcpyNode1D(dst, src, sizeBytes, kind); - status = ihipGraphAddNode(node, graph, pDependencies.data(), numDependencies); + hipError_t status = ihipGraphAddMemcpyNode1D(&node, graph, pDependencies.data(), numDependencies, + dst, src, sizeBytes, kind, true, s->DeviceId()); if (status != hipSuccess) { return status; } @@ -814,7 +813,7 @@ hipError_t capturehipMemcpyFromSymbolAsync(hipStream_t& stream, void*& dst, cons hip::GraphNode* pGraphNode = new hip::GraphMemcpyNodeFromSymbol(dst, symbol, sizeBytes, offset, kind); status = ihipGraphAddNode(pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(), - s->GetLastCapturedNodes().size()); + s->GetLastCapturedNodes().size(), true, s->DeviceId()); if (status != hipSuccess) { return status; } @@ -848,7 +847,7 @@ hipError_t capturehipMemcpyToSymbolAsync(hipStream_t& stream, const void*& symbo hip::Stream* s = reinterpret_cast(stream); hip::GraphNode* pGraphNode = new hip::GraphMemcpyNodeToSymbol(symbol, src, sizeBytes, offset, kind); status = ihipGraphAddNode(pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(), - s->GetLastCapturedNodes().size()); + s->GetLastCapturedNodes().size(), true, s->DeviceId()); if (status != hipSuccess) { return status; } @@ -872,9 +871,9 @@ hipError_t capturehipMemsetAsync(hipStream_t& stream, void*& dst, int& value, si hip::Stream* s = reinterpret_cast(stream); hip::GraphNode* pGraphNode; - hipError_t status = - ihipGraphAddMemsetNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(), - s->GetLastCapturedNodes().size(), &memsetParams); + hipError_t status = ihipGraphAddMemsetNode( + &pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(), + s->GetLastCapturedNodes().size(), &memsetParams, true, 1, 1, 1, s->DeviceId()); if (status != hipSuccess) { return status; } @@ -898,9 +897,9 @@ hipError_t capturehipMemset2DAsync(hipStream_t& stream, void*& dst, size_t& pitc memsetParams.elementSize = 1; hip::Stream* s = reinterpret_cast(stream); hip::GraphNode* pGraphNode; - hipError_t status = - ihipGraphAddMemsetNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(), - s->GetLastCapturedNodes().size(), &memsetParams); + hipError_t status = ihipGraphAddMemsetNode( + &pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(), + s->GetLastCapturedNodes().size(), &memsetParams, true, 1, 1, 1, s->DeviceId()); if (status != hipSuccess) { return status; } @@ -933,7 +932,7 @@ hipError_t capturehipMemset3DAsync(hipStream_t& stream, hipPitchedPtr& pitchedDe hipError_t status = ihipGraphAddMemsetNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(), s->GetLastCapturedNodes().size(), &memsetParams, true, extent.depth, - pitchedDevPtr.xsize, pitchedDevPtr.ysize); + pitchedDevPtr.xsize, pitchedDevPtr.ysize, s->DeviceId()); if (status != hipSuccess) { return status; } @@ -957,7 +956,7 @@ hipError_t capturehipLaunchHostFunc(hipStream_t& stream, hipHostFn_t& fn, void*& hip::GraphNode* pGraphNode = new hip::GraphHostNode(&hostParams); hipError_t status = ihipGraphAddNode(pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(), - s->GetLastCapturedNodes().size()); + s->GetLastCapturedNodes().size(), true, s->DeviceId()); if (status != hipSuccess) { return status; } @@ -990,8 +989,9 @@ hipError_t capturehipMallocAsync(hipStream_t stream, hipMemPool_t mem_pool, node_params.bytesize = size; auto mem_alloc_node = new hip::GraphMemAllocNode(&node_params); - auto status = ihipGraphAddNode(mem_alloc_node, s->GetCaptureGraph(), - s->GetLastCapturedNodes().data(), s->GetLastCapturedNodes().size()); + auto status = + ihipGraphAddNode(mem_alloc_node, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(), + s->GetLastCapturedNodes().size(), true, s->DeviceId()); if (status != hipSuccess) { return status; } @@ -1006,8 +1006,9 @@ hipError_t capturehipMallocAsync(hipStream_t stream, hipMemPool_t mem_pool, hipError_t capturehipFreeAsync(hipStream_t stream, void* dev_ptr) { hip::Stream* s = reinterpret_cast(stream); auto mem_free_node = new hip::GraphMemFreeNode(dev_ptr); - auto status = ihipGraphAddNode(mem_free_node, s->GetCaptureGraph(), - s->GetLastCapturedNodes().data(), s->GetLastCapturedNodes().size()); + auto status = + ihipGraphAddNode(mem_free_node, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(), + s->GetLastCapturedNodes().size(), true, s->DeviceId()); if (status != hipSuccess) { return status; } @@ -1594,6 +1595,9 @@ hipError_t hipGraphLaunch_common(hip::GraphExec* graphExec, hipStream_t stream) if (graphExec == nullptr || !hip::GraphExec::isGraphExecValid(graphExec)) { return hipErrorInvalidValue; } + if (!hip::isValid(stream)) { + return hipErrorContextIsDestroyed; + } if (graphExec->GetNodeCount() == 0) { return hipSuccess; } diff --git a/hipamd/src/hip_graph_internal.hpp b/hipamd/src/hip_graph_internal.hpp index 5934a16f5f..4bc165d0e5 100644 --- a/hipamd/src/hip_graph_internal.hpp +++ b/hipamd/src/hip_graph_internal.hpp @@ -1319,6 +1319,7 @@ class GraphKernelNode : public GraphNode { } hipError_t SetParams(GraphNode* node) override { + dev_id_ = ihipGetDevice(); const GraphKernelNode* kernelNode = static_cast(node); return SetParams(&kernelNode->kernelParams_); } @@ -1528,13 +1529,37 @@ class GraphMemcpyNode1D : public GraphMemcpyNode { hipMemcpyKind kind_; public: + // When device memory is on dev1 and graph node is added from different device update the device + // id accordingly so that node can be executed on dev1. + void UpdateDevId() { + size_t sOffset = 0; + amd::Memory* srcMemory = getMemoryObject(src_, sOffset); + size_t dOffset = 0; + amd::Memory* dstMemory = getMemoryObject(dst_, dOffset); + hip::MemcpyType memType = ihipGetMemcpyType(src_, dst_, kind_); + switch (memType) { + case hipCopyBuffer: + // D2H/H2D source/dst is pinned memory + // Override the device id when node is created + if (!((srcMemory->GetDeviceById() != dstMemory->GetDeviceById()) && + srcMemory->getContext().devices().size() == 1 && + dstMemory->getContext().devices().size() == 1)) { + if (srcMemory->getContext().devices().size() == 1) { + dev_id_ = srcMemory->GetDeviceById()->index(); + } else { + dev_id_ = dstMemory->GetDeviceById()->index(); + } + } + break; + default: + break; + } + } GraphMemcpyNode1D(void* dst, const void* src, size_t count, hipMemcpyKind kind, hipGraphNodeType type = hipGraphNodeTypeMemcpy) - : GraphMemcpyNode(nullptr), - dst_(dst), - src_(src), - count_(count), - kind_(kind) {} + : GraphMemcpyNode(nullptr), dst_(dst), src_(src), count_(count), kind_(kind) { + UpdateDevId(); + } ~GraphMemcpyNode1D() {} @@ -1543,6 +1568,7 @@ class GraphMemcpyNode1D : public GraphMemcpyNode { src_ = rhs.src_; count_ = rhs.count_; kind_ = rhs.kind_; + UpdateDevId(); } GraphNode* clone() const override { return new GraphMemcpyNode1D(*this); } @@ -1647,6 +1673,7 @@ class GraphMemcpyNode1D : public GraphMemcpyNode { src_ = src; count_ = count; kind_ = kind; + UpdateDevId(); return hipSuccess; }