Skip to content

Commit

Permalink
Allow CPU/GPU inputs in DeviceWorkspace
Browse files Browse the repository at this point in the history
Signed-off-by: Joaquin Anton <janton@nvidia.com>
  • Loading branch information
jantonguirao committed Jul 28, 2020
1 parent 57417d9 commit acb8d28
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 8 deletions.
1 change: 1 addition & 0 deletions dali/operators/random/normal_distribution_op.h
Expand Up @@ -141,6 +141,7 @@ class NormalDistributionCpu : public NormalDistribution<CPUBackend> {

protected:
void RunImpl(workspace_t<CPUBackend> &ws) override;
using NormalDistribution<CPUBackend>::RunImpl;

private:
void AssignTensorToOutput(workspace_t<CPUBackend> &ws);
Expand Down
36 changes: 28 additions & 8 deletions dali/pipeline/operator/operator.h
Expand Up @@ -276,6 +276,30 @@ class DLL_PUBLIC OperatorBase {
template <typename Backend>
class Operator : public OperatorBase {};

template <typename Workspace>
TensorLayout GetInputLayout(Workspace& ws, int i) {
if (ws.template InputIsType<CPUBackend>(i)) {
auto &in = ws.template InputRef<CPUBackend>(i);
return in.GetLayout();
}

assert(ws.template InputIsType<GPUBackend>(i));
auto &in = ws.template InputRef<GPUBackend>(i);
return in.GetLayout();
}

template <typename Workspace>
TensorLayout GetOutputLayout(Workspace &ws, int i) {
if (ws.template OutputIsType<CPUBackend>(i)) {
auto &out = ws.template OutputRef<CPUBackend>(i);
return out.GetLayout();
}

assert(ws.template OutputIsType<GPUBackend>(i));
auto &out = ws.template OutputRef<GPUBackend>(i);
return out.GetLayout();
}

template <>
class Operator<CPUBackend> : public OperatorBase {
public:
Expand All @@ -297,10 +321,8 @@ class Operator<CPUBackend> : public OperatorBase {
ws.GetThreadPool().WaitForWork();

if (ws.NumInput() > 0 && ws.NumOutput() > 0) {
auto &in = ws.template InputRef<CPUBackend>(0);
auto &out = ws.template OutputRef<CPUBackend>(0);
auto in_layout = in.GetLayout();
auto out_layout = out.GetLayout();
auto in_layout = GetInputLayout(ws, 0);
auto out_layout = GetOutputLayout(ws, 0);
DALI_ENFORCE(
!out_layout.empty() || in_layout.empty() || spec_.name() == "DLTensorPythonFunctionImpl",
make_string("Operator: ", spec_.name(), " produced an empty layout. Input layout was ",
Expand Down Expand Up @@ -377,10 +399,8 @@ class Operator<GPUBackend> : public OperatorBase {
SetupSharedSampleParams(ws);
RunImpl(ws);
if (ws.NumInput() > 0 && ws.NumOutput() > 0) {
auto &in = ws.template InputRef<GPUBackend>(0);
auto &out = ws.template OutputRef<GPUBackend>(0);
auto in_layout = in.GetLayout();
auto out_layout = out.GetLayout();
auto in_layout = GetInputLayout(ws, 0);
auto out_layout = GetOutputLayout(ws, 0);
DALI_ENFORCE(
!out_layout.empty() || in_layout.empty() || spec_.name() == "DLTensorPythonFunctionImpl",
make_string("Operator: ", spec_.name(), " produced an empty layout. Input layout was ",
Expand Down

0 comments on commit acb8d28

Please sign in to comment.