Skip to content

Commit

Permalink
Fix AssertionError
Browse files Browse the repository at this point in the history
  • Loading branch information
ericup committed Mar 5, 2024
1 parent 983bc38 commit 6eb23e2
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 161 deletions.
163 changes: 28 additions & 135 deletions demos/Cell Detection with Contour Proposal Networks.ipynb

Large diffs are not rendered by default.

24 changes: 11 additions & 13 deletions demos/demo-binary.ipynb
Expand Up @@ -65,6 +65,8 @@
" samples=128, # number of coordinates per contour\n",
" refinement_iterations=3,\n",
" refinement_buckets=6,\n",
" inputs_mean=.5,\n",
" inputs_std=.5,\n",
" tweaks={\n",
" 'BatchNorm2d': {'momentum': 0.05}\n",
" },\n",
Expand All @@ -80,7 +82,7 @@
" amp=torch.cuda.is_available(),\n",
" \n",
" # misc\n",
" num_workers=8 * int(os.name != 'nt'),\n",
" num_workers=0,\n",
" device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
")\n",
"print(conf)"
Expand Down Expand Up @@ -108,9 +110,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"metadata": {},
"outputs": [],
"source": [
"class Dataset:\n",
Expand All @@ -129,13 +129,12 @@
" \n",
" @staticmethod\n",
" def map(image):\n",
" image = image / 127.5\n",
" image -= 1\n",
" image = image / 255\n",
" return image\n",
" \n",
" @staticmethod\n",
" def unmap(image):\n",
" image = (image + 1) * 127.5\n",
" image = image * 255\n",
" image = np.clip(image, 0, 255).astype('uint8')\n",
" return image\n",
" \n",
Expand Down Expand Up @@ -196,7 +195,8 @@
"model = getattr(models, conf.cpn)(in_channels=conf.in_channels, order=conf.order, samples=conf.samples,\n",
" refinement_iterations=conf.refinement_iterations, nms_thresh=conf.nms_thresh,\n",
" score_thresh=conf.score_thresh, contour_head_stride=conf.contour_head_stride,\n",
" classes=conf.classes, refinement_buckets=conf.refinement_buckets)\n",
" classes=conf.classes, refinement_buckets=conf.refinement_buckets,\n",
" backbone_kwargs=dict(inputs_mean=conf.inputs_mean, inputs_std=conf.inputs_std))\n",
"cd.conf2tweaks_(conf.tweaks, model)\n",
"model.to(conf.device)\n",
"optimizer = cd.conf2optimizer(conf.optimizer, model.parameters())\n",
Expand Down Expand Up @@ -262,9 +262,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"metadata": {},
"outputs": [],
"source": [
"for epoch in range(1, conf.epochs):\n",
Expand All @@ -276,7 +274,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -290,7 +288,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
"version": "3.9.16"
}
},
"nbformat": 4,
Expand Down
24 changes: 11 additions & 13 deletions demos/demo-multiclass.ipynb
Expand Up @@ -64,6 +64,8 @@
" samples=128, # number of coordinates per contour\n",
" refinement_iterations=3,\n",
" refinement_buckets=6,\n",
" inputs_mean=.5,\n",
" inputs_std=.5,\n",
" tweaks={\n",
" 'BatchNorm2d': {'momentum': 0.05}\n",
" },\n",
Expand All @@ -79,7 +81,7 @@
" amp=torch.cuda.is_available(),\n",
" \n",
" # misc\n",
" num_workers=8 * int(os.name != 'nt'),\n",
" num_workers=0,\n",
" device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
")\n",
"print(conf)"
Expand Down Expand Up @@ -107,9 +109,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"metadata": {},
"outputs": [],
"source": [
"class Dataset:\n",
Expand All @@ -128,13 +128,12 @@
" \n",
" @staticmethod\n",
" def map(image):\n",
" image = image / 127.5\n",
" image -= 1\n",
" image = image / 255\n",
" return image\n",
" \n",
" @staticmethod\n",
" def unmap(image):\n",
" image = (image + 1) * 127.5\n",
" image = image * 255\n",
" image = np.clip(image, 0, 255).astype('uint8')\n",
" return image\n",
" \n",
Expand Down Expand Up @@ -197,7 +196,8 @@
"model = getattr(models, conf.cpn)(in_channels=conf.in_channels, order=conf.order, samples=conf.samples,\n",
" refinement_iterations=conf.refinement_iterations, nms_thresh=conf.nms_thresh,\n",
" contour_head_stride=conf.contour_head_stride, classes=conf.classes,\n",
" refinement_buckets=conf.refinement_buckets)\n",
" refinement_buckets=conf.refinement_buckets,\n",
" backbone_kwargs=dict(inputs_mean=conf.inputs_mean, inputs_std=conf.inputs_std))\n",
"cd.conf2tweaks_(conf.tweaks, model)\n",
"model.to(conf.device)\n",
"optimizer = cd.conf2optimizer(conf.optimizer, model.parameters())\n",
Expand Down Expand Up @@ -264,9 +264,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"metadata": {},
"outputs": [],
"source": [
"for epoch in range(1, conf.epochs):\n",
Expand All @@ -278,7 +276,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -292,7 +290,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
"version": "3.9.16"
}
},
"nbformat": 4,
Expand Down

0 comments on commit 6eb23e2

Please sign in to comment.