Commit

Fix syntax in reformatted paddle-paddle and torch notebook snippets (#5523)

* Fix PaddlePaddle external input example
* Fix torch external input example

Signed-off-by: Kamil Tokarski <ktokarski@nvidia.com>

stiepan committed Jun 14, 2024
1 parent 755439d commit 242d65a
Showing 2 changed files with 22 additions and 14 deletions.
18 changes: 11 additions & 7 deletions docs/examples/frameworks/paddle/paddle-external_input.ipynb
@@ -166,19 +166,23 @@
  }
  ],
  "source": [
- "from nvidia.dali.plugin.paddle import DALIClassificationIterator as PaddleIterator\n",
+ "from nvidia.dali.plugin.paddle import (\n",
+ "    DALIClassificationIterator as PaddleIterator,\n",
+ ")\n",
  "from nvidia.dali.plugin.paddle import LastBatchPolicy\n",
  "\n",
  "eii = ExternalInputIterator(batch_size, 0, 1)\n",
- "pipe = ExternalSourcePipeline(batch_size=batch_size, num_threads=2, device_id=0,\n",
- "                              external_data=eii)\n",
- "pii = PaddleIterator(pipe, last_batch_padded=True,\n",
- "                     last_batch_policy=LastBatchPolicy.PARTIAL)\n",
+ "pipe = ExternalSourcePipeline(\n",
+ "    batch_size=batch_size, num_threads=2, device_id=0, external_data=eii\n",
+ ")\n",
+ "pii = PaddleIterator(\n",
+ "    pipe, last_batch_padded=True, last_batch_policy=LastBatchPolicy.PARTIAL\n",
+ ")\n",
  "\n",
  "for e in range(epochs):\n",
  "    for i, data in enumerate(pii):\n",
- "        print(f\"epoch: {e}, iter {i}, real batch size: \"\n",
- "              f\"{len(np.array(data[0][\"data\"]))}\")\n",
+ "        real_batch_size = len(np.array(data[0][\"data\"]))\n",
+ "        print(f\"epoch: {e}, iter {i}, real batch size: {real_batch_size}\")\n",
  "    pii.reset()"
  ]
  }
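For context on the syntax error both notebooks shared: before Python 3.12, an f-string could not reuse its own quote character inside a replacement field, so the original one-line print with data[0]["data"] embedded in a double-quoted f-string failed to parse. Below is a minimal, self-contained sketch of the broken pattern and of the hoisting the commit applies; the dummy batch only stands in for what the DALI iterator yields:

    import numpy as np

    # Dummy stand-in for one batch from DALIClassificationIterator:
    # a list holding a dict with the "data" tensor.
    data = [{"data": np.zeros((4, 3, 224, 224))}]

    # Pre-commit form -- a SyntaxError on Python < 3.12, because the inner
    # double quotes end the f-string early:
    #     print(f"real batch size: {len(np.array(data[0]["data"]))}")

    # Committed form: hoist the expression out of the f-string.
    real_batch_size = len(np.array(data[0]["data"]))
    print(f"real batch size: {real_batch_size}")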
18 changes: 11 additions & 7 deletions docs/examples/frameworks/pytorch/pytorch-external_input.ipynb
@@ -169,19 +169,23 @@
  }
  ],
  "source": [
- "from nvidia.dali.plugin.pytorch import DALIClassificationIterator as PyTorchIterator\n",
+ "from nvidia.dali.plugin.pytorch import (\n",
+ "    DALIClassificationIterator as PyTorchIterator,\n",
+ ")\n",
  "from nvidia.dali.plugin.pytorch import LastBatchPolicy\n",
  "\n",
  "eii = ExternalInputIterator(batch_size, 0, 1)\n",
- "pipe = ExternalSourcePipeline(batch_size=batch_size, num_threads=2, device_id=0,\n",
- "                              external_data=eii)\n",
- "pii = PyTorchIterator(pipe, last_batch_padded=True,\n",
- "                      last_batch_policy=LastBatchPolicy.PARTIAL)\n",
+ "pipe = ExternalSourcePipeline(\n",
+ "    batch_size=batch_size, num_threads=2, device_id=0, external_data=eii\n",
+ ")\n",
+ "pii = PyTorchIterator(\n",
+ "    pipe, last_batch_padded=True, last_batch_policy=LastBatchPolicy.PARTIAL\n",
+ ")\n",
  "\n",
  "for e in range(epochs):\n",
  "    for i, data in enumerate(pii):\n",
- "        print(f\"epoch: {e}, iter {i}, real batch size: \"\n",
- "              f\"{len(data[0][\"data\"])}\")\n",
+ "        real_batch_size = len(data[0][\"data\"])\n",
+ "        print(f\"epoch: {e}, iter {i}, real batch size: {real_batch_size}\")\n",
  "    pii.reset()"
  ]
  }
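The PyTorch notebook gets the same treatment; the only difference is that its batch size is read without the np.array wrapper. An alternative that would also have parsed on older interpreters is to switch the inner quotes, shown below purely for contrast (the dummy values stand in for the notebook's loop variables); the committed version sidesteps nested quoting altogether and keeps the print line short:

    # Hypothetical alternative, not what the commit does: single quotes inside
    # the replacement field are accepted by every Python 3 release.
    e, i = 0, 0
    data = [{"data": [[0.0] * 3] * 4}]  # dummy batch in place of DALI output
    print(f"epoch: {e}, iter {i}, real batch size: {len(data[0]['data'])}")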
