Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 43 additions & 39 deletions hipamd/src/hip_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ inline hipError_t ihipGraphUpload(hipGraphExec_t graphExec, hipStream_t stream)

inline hipError_t ihipGraphAddNode(hip::GraphNode* graphNode, hip::Graph* graph,
hip::GraphNode* const* pDependencies, size_t numDependencies,
bool capture = true) {
bool capture = true, int devId = 0) {
graph->AddNode(graphNode);
std::unordered_set<hip::GraphNode*> DuplicateDep;
for (size_t i = 0; i < numDependencies; i++) {
Expand All @@ -76,6 +76,9 @@ inline hipError_t ihipGraphAddNode(hip::GraphNode* graphNode, hip::Graph* graph,
}
}
}
if (devId != 0) {
graphNode->SetDeviceId(devId);
}
return hipSuccess;
}

Expand Down Expand Up @@ -107,16 +110,14 @@ hipError_t ihipGraphAddKernelNode(hip::GraphNode** pGraphNode, hip::Graph* graph
}

*pGraphNode = new hip::GraphKernelNode(pNodeParams, pNodeEvents, coopKernel);
if (devId != 0) {
(*pGraphNode)->SetDeviceId(devId);
}
status = ihipGraphAddNode(*pGraphNode, graph, pDependencies, numDependencies, capture);
status = ihipGraphAddNode(*pGraphNode, graph, pDependencies, numDependencies, capture, devId);
return status;
}

hipError_t ihipGraphAddMemcpyNode(hip::GraphNode** pGraphNode, hip::Graph* graph,
hip::GraphNode* const* pDependencies, size_t numDependencies,
const hipMemcpy3DParms* pCopyParams, bool capture = true) {
const hipMemcpy3DParms* pCopyParams, bool capture = true,
int devId = 0) {
if (pGraphNode == nullptr || graph == nullptr ||
(numDependencies > 0 && pDependencies == nullptr) || pCopyParams == nullptr) {
return hipErrorInvalidValue;
Expand All @@ -133,7 +134,7 @@ hipError_t ihipGraphAddMemcpyNode(hip::GraphNode** pGraphNode, hip::Graph* graph
hipError_t ihipDrvGraphAddMemcpyNode(hip::GraphNode** pGraphNode, hip::Graph* graph,
hip::GraphNode* const* pDependencies, size_t numDependencies,
const HIP_MEMCPY3D* pCopyParams, hipCtx_t ctx,
bool capture = true) {
bool capture = true, int devId = 0) {
if (pGraphNode == nullptr || graph == nullptr ||
(numDependencies > 0 && pDependencies == nullptr) || pCopyParams == nullptr) {
return hipErrorInvalidValue;
Expand All @@ -150,7 +151,7 @@ hipError_t ihipDrvGraphAddMemcpyNode(hip::GraphNode** pGraphNode, hip::Graph* gr
hipError_t ihipGraphAddMemcpyNode1D(hip::GraphNode** pGraphNode, hip::Graph* graph,
hip::GraphNode* const* pDependencies, size_t numDependencies,
void* dst, const void* src, size_t count, hipMemcpyKind kind,
bool capture = true) {
bool capture = true, int devId = 0) {
if (pGraphNode == nullptr || graph == nullptr ||
(numDependencies > 0 && pDependencies == nullptr) || count ==0) {
return hipErrorInvalidValue;
Expand All @@ -167,7 +168,8 @@ hipError_t ihipGraphAddMemcpyNode1D(hip::GraphNode** pGraphNode, hip::Graph* gra
hipError_t ihipGraphAddMemsetNode(hip::GraphNode** pGraphNode, hip::Graph* graph,
hip::GraphNode* const* pDependencies, size_t numDependencies,
const hipMemsetParams* pMemsetParams, bool capture = true,
size_t depth = 1, size_t arrWidth = 1, size_t arrHeight = 1) {
size_t depth = 1, size_t arrWidth = 1, size_t arrHeight = 1,
int devId = 0) {
if (pGraphNode == nullptr || graph == nullptr || pMemsetParams == nullptr ||
(numDependencies > 0 && pDependencies == nullptr) || pMemsetParams->height == 0) {
return hipErrorInvalidValue;
Expand Down Expand Up @@ -204,7 +206,7 @@ hipError_t ihipGraphAddMemsetNode(hip::GraphNode** pGraphNode, hip::Graph* graph
return status;
}
*pGraphNode = new hip::GraphMemsetNode(pMemsetParams, depth, arrWidth, arrHeight);
status = ihipGraphAddNode(*pGraphNode, graph, pDependencies, numDependencies, capture);
status = ihipGraphAddNode(*pGraphNode, graph, pDependencies, numDependencies, capture, devId);
return status;
}

Expand Down Expand Up @@ -275,9 +277,9 @@ hipError_t ihipExtLaunchKernel(hipStream_t stream, hipFunction_t f, uint32_t glo
nodeParams.kernelParams = kernelParams;
nodeParams.sharedMemBytes = sharedMemBytes;

status =
ihipGraphAddKernelNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
s->GetLastCapturedNodes().size(), &nodeParams, &nodeEvents);
status = ihipGraphAddKernelNode(
&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
s->GetLastCapturedNodes().size(), &nodeParams, &nodeEvents, capture, 0, s->DeviceId());

if (status != hipSuccess) {
return status;
Expand Down Expand Up @@ -454,7 +456,7 @@ hipError_t capturehipMemcpy3DAsync(hipStream_t& stream, const hipMemcpy3DParms*&
hip::GraphNode* pGraphNode;
hipError_t status =
ihipGraphAddMemcpyNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
s->GetLastCapturedNodes().size(), p);
s->GetLastCapturedNodes().size(), p, true, s->DeviceId());
if (status != hipSuccess) {
return status;
}
Expand Down Expand Up @@ -501,7 +503,7 @@ hipError_t capturehipMemcpy2DAsync(hipStream_t& stream, void*& dst, size_t& dpit

hipError_t status =
ihipGraphAddMemcpyNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
s->GetLastCapturedNodes().size(), &p);
s->GetLastCapturedNodes().size(), &p, true, s->DeviceId());
if (status != hipSuccess) {
return status;
}
Expand Down Expand Up @@ -543,7 +545,7 @@ hipError_t capturehipMemcpy2DFromArrayAsync(hipStream_t& stream, void*& dst, siz
p.extent = {width / hip::getElementSize(p.srcArray), height, 1};
hipError_t status =
ihipGraphAddMemcpyNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
s->GetLastCapturedNodes().size(), &p);
s->GetLastCapturedNodes().size(), &p, true, s->DeviceId());
if (status != hipSuccess) {
return status;
}
Expand Down Expand Up @@ -584,7 +586,7 @@ hipError_t capturehipMemcpy2DToArrayAsync(hipStream_t& stream, hipArray_t& dst,
p.extent = {width / hip::getElementSize(p.dstArray), height, 1};
hipError_t status =
ihipGraphAddMemcpyNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
s->GetLastCapturedNodes().size(), &p);
s->GetLastCapturedNodes().size(), &p, true, s->DeviceId());
if (status != hipSuccess) {
return status;
}
Expand Down Expand Up @@ -658,7 +660,7 @@ hipError_t capturehipMemcpyParam2DAsync(hipStream_t& stream, const hip_Memcpy2D*
}
hipError_t status =
ihipGraphAddMemcpyNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
s->GetLastCapturedNodes().size(), &p);
s->GetLastCapturedNodes().size(), &p, true, s->DeviceId());
if (status != hipSuccess) {
return status;
}
Expand Down Expand Up @@ -687,7 +689,7 @@ hipError_t capturehipMemcpyAtoHAsync(hipStream_t& stream, void*& dstHost, hipArr
p.kind = hipMemcpyDeviceToHost;
hipError_t status =
ihipGraphAddMemcpyNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
s->GetLastCapturedNodes().size(), &p);
s->GetLastCapturedNodes().size(), &p, true, s->DeviceId());
if (status != hipSuccess) {
return status;
}
Expand Down Expand Up @@ -715,7 +717,7 @@ hipError_t capturehipMemcpyHtoAAsync(hipStream_t& stream, hipArray_t& dstArray,
p.extent = {ByteCount / hip::getElementSize(p.dstArray), 1, 1};
hipError_t status =
ihipGraphAddMemcpyNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
s->GetLastCapturedNodes().size(), &p);
s->GetLastCapturedNodes().size(), &p, true, s->DeviceId());
if (status != hipSuccess) {
return status;
}
Expand All @@ -732,12 +734,9 @@ hipError_t capturehipMemcpy(hipStream_t stream, void* dst, const void* src, size
std::vector<hip::GraphNode*> pDependencies = s->GetLastCapturedNodes();
size_t numDependencies = s->GetLastCapturedNodes().size();
hip::Graph* graph = s->GetCaptureGraph();
hipError_t status = ihipMemcpy_validate(dst, src, sizeBytes, kind);
if (status != hipSuccess) {
return status;
}
hip::GraphNode* node = new hip::GraphMemcpyNode1D(dst, src, sizeBytes, kind);
status = ihipGraphAddNode(node, graph, pDependencies.data(), numDependencies);
hipError_t status = ihipGraphAddMemcpyNode1D(&node, graph, pDependencies.data(), numDependencies,
dst, src, sizeBytes, kind, true, s->DeviceId());
if (status != hipSuccess) {
return status;
}
Expand Down Expand Up @@ -814,7 +813,7 @@ hipError_t capturehipMemcpyFromSymbolAsync(hipStream_t& stream, void*& dst, cons
hip::GraphNode* pGraphNode =
new hip::GraphMemcpyNodeFromSymbol(dst, symbol, sizeBytes, offset, kind);
status = ihipGraphAddNode(pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
s->GetLastCapturedNodes().size());
s->GetLastCapturedNodes().size(), true, s->DeviceId());
if (status != hipSuccess) {
return status;
}
Expand Down Expand Up @@ -848,7 +847,7 @@ hipError_t capturehipMemcpyToSymbolAsync(hipStream_t& stream, const void*& symbo
hip::Stream* s = reinterpret_cast<hip::Stream*>(stream);
hip::GraphNode* pGraphNode = new hip::GraphMemcpyNodeToSymbol(symbol, src, sizeBytes, offset, kind);
status = ihipGraphAddNode(pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
s->GetLastCapturedNodes().size());
s->GetLastCapturedNodes().size(), true, s->DeviceId());
if (status != hipSuccess) {
return status;
}
Expand All @@ -872,9 +871,9 @@ hipError_t capturehipMemsetAsync(hipStream_t& stream, void*& dst, int& value, si

hip::Stream* s = reinterpret_cast<hip::Stream*>(stream);
hip::GraphNode* pGraphNode;
hipError_t status =
ihipGraphAddMemsetNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
s->GetLastCapturedNodes().size(), &memsetParams);
hipError_t status = ihipGraphAddMemsetNode(
&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
s->GetLastCapturedNodes().size(), &memsetParams, true, 1, 1, 1, s->DeviceId());
if (status != hipSuccess) {
return status;
}
Expand All @@ -898,9 +897,9 @@ hipError_t capturehipMemset2DAsync(hipStream_t& stream, void*& dst, size_t& pitc
memsetParams.elementSize = 1;
hip::Stream* s = reinterpret_cast<hip::Stream*>(stream);
hip::GraphNode* pGraphNode;
hipError_t status =
ihipGraphAddMemsetNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
s->GetLastCapturedNodes().size(), &memsetParams);
hipError_t status = ihipGraphAddMemsetNode(
&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
s->GetLastCapturedNodes().size(), &memsetParams, true, 1, 1, 1, s->DeviceId());
if (status != hipSuccess) {
return status;
}
Expand Down Expand Up @@ -933,7 +932,7 @@ hipError_t capturehipMemset3DAsync(hipStream_t& stream, hipPitchedPtr& pitchedDe
hipError_t status =
ihipGraphAddMemsetNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
s->GetLastCapturedNodes().size(), &memsetParams, true, extent.depth,
pitchedDevPtr.xsize, pitchedDevPtr.ysize);
pitchedDevPtr.xsize, pitchedDevPtr.ysize, s->DeviceId());
if (status != hipSuccess) {
return status;
}
Expand All @@ -957,7 +956,7 @@ hipError_t capturehipLaunchHostFunc(hipStream_t& stream, hipHostFn_t& fn, void*&
hip::GraphNode* pGraphNode = new hip::GraphHostNode(&hostParams);
hipError_t status =
ihipGraphAddNode(pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
s->GetLastCapturedNodes().size());
s->GetLastCapturedNodes().size(), true, s->DeviceId());
if (status != hipSuccess) {
return status;
}
Expand Down Expand Up @@ -990,8 +989,9 @@ hipError_t capturehipMallocAsync(hipStream_t stream, hipMemPool_t mem_pool,
node_params.bytesize = size;

auto mem_alloc_node = new hip::GraphMemAllocNode(&node_params);
auto status = ihipGraphAddNode(mem_alloc_node, s->GetCaptureGraph(),
s->GetLastCapturedNodes().data(), s->GetLastCapturedNodes().size());
auto status =
ihipGraphAddNode(mem_alloc_node, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
s->GetLastCapturedNodes().size(), true, s->DeviceId());
if (status != hipSuccess) {
return status;
}
Expand All @@ -1006,8 +1006,9 @@ hipError_t capturehipMallocAsync(hipStream_t stream, hipMemPool_t mem_pool,
hipError_t capturehipFreeAsync(hipStream_t stream, void* dev_ptr) {
hip::Stream* s = reinterpret_cast<hip::Stream*>(stream);
auto mem_free_node = new hip::GraphMemFreeNode(dev_ptr);
auto status = ihipGraphAddNode(mem_free_node, s->GetCaptureGraph(),
s->GetLastCapturedNodes().data(), s->GetLastCapturedNodes().size());
auto status =
ihipGraphAddNode(mem_free_node, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
s->GetLastCapturedNodes().size(), true, s->DeviceId());
if (status != hipSuccess) {
return status;
}
Expand Down Expand Up @@ -1594,6 +1595,9 @@ hipError_t hipGraphLaunch_common(hip::GraphExec* graphExec, hipStream_t stream)
if (graphExec == nullptr || !hip::GraphExec::isGraphExecValid(graphExec)) {
return hipErrorInvalidValue;
}
if (!hip::isValid(stream)) {
return hipErrorContextIsDestroyed;
}
if (graphExec->GetNodeCount() == 0) {
return hipSuccess;
}
Expand Down
37 changes: 32 additions & 5 deletions hipamd/src/hip_graph_internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1319,6 +1319,7 @@ class GraphKernelNode : public GraphNode {
}

// Copies the kernel parameters held by another kernel node into this node.
//
// The node's device id is first overwritten with the value returned by
// ihipGetDevice() (presumably the calling thread's current device - confirm),
// and the actual parameter update is delegated to the SetParams overload that
// takes a pointer to the parameter struct.
//
// `node` is downcast with static_cast and no runtime check, so the caller must
// pass a GraphKernelNode. Returns whatever the delegated overload returns.
hipError_t SetParams(GraphNode* node) override {
  dev_id_ = ihipGetDevice();
  const auto* source = static_cast<const GraphKernelNode*>(node);
  return SetParams(&source->kernelParams_);
}
Expand Down Expand Up @@ -1528,13 +1529,37 @@ class GraphMemcpyNode1D : public GraphMemcpyNode {
hipMemcpyKind kind_;

public:
// When device memory is on dev1 and graph node is added from different device update the device
// id accordingly so that node can be executed on dev1.
void UpdateDevId() {
size_t sOffset = 0;
amd::Memory* srcMemory = getMemoryObject(src_, sOffset);
size_t dOffset = 0;
amd::Memory* dstMemory = getMemoryObject(dst_, dOffset);
hip::MemcpyType memType = ihipGetMemcpyType(src_, dst_, kind_);
switch (memType) {
case hipCopyBuffer:
// D2H/H2D source/dst is pinned memory
// Override the device id when node is created
if (!((srcMemory->GetDeviceById() != dstMemory->GetDeviceById()) &&
srcMemory->getContext().devices().size() == 1 &&
dstMemory->getContext().devices().size() == 1)) {
if (srcMemory->getContext().devices().size() == 1) {
dev_id_ = srcMemory->GetDeviceById()->index();
} else {
dev_id_ = dstMemory->GetDeviceById()->index();
}
}
break;
default:
break;
}
}
GraphMemcpyNode1D(void* dst, const void* src, size_t count, hipMemcpyKind kind,
hipGraphNodeType type = hipGraphNodeTypeMemcpy)
: GraphMemcpyNode(nullptr),
dst_(dst),
src_(src),
count_(count),
kind_(kind) {}
: GraphMemcpyNode(nullptr), dst_(dst), src_(src), count_(count), kind_(kind) {
UpdateDevId();
}

// Destructor with an empty body: this class performs no cleanup of its own here.
~GraphMemcpyNode1D() {}

Expand All @@ -1543,6 +1568,7 @@ class GraphMemcpyNode1D : public GraphMemcpyNode {
src_ = rhs.src_;
count_ = rhs.count_;
kind_ = rhs.kind_;
UpdateDevId();
}

// Returns a copy of this node allocated with `new` via the copy constructor.
// The caller receives a raw pointer; who deletes it is not visible from here.
GraphNode* clone() const override { return new GraphMemcpyNode1D(*this); }
Expand Down Expand Up @@ -1647,6 +1673,7 @@ class GraphMemcpyNode1D : public GraphMemcpyNode {
src_ = src;
count_ = count;
kind_ = kind;
UpdateDevId();
return hipSuccess;
}

Expand Down