From 6ca2fe4f6928a2023327a28c8bf7a04f7dafc86a Mon Sep 17 00:00:00 2001 From: Avihu Dekel Date: Sun, 22 Jan 2023 16:16:58 -0500 Subject: [PATCH] doc update --- fuse/dl/models/backbones/backbone_vit.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/fuse/dl/models/backbones/backbone_vit.py b/fuse/dl/models/backbones/backbone_vit.py index 1c33b1f9..4748ea88 100644 --- a/fuse/dl/models/backbones/backbone_vit.py +++ b/fuse/dl/models/backbones/backbone_vit.py @@ -30,6 +30,9 @@ def __init__( assert len(image_shape) == len(patch_shape), "patch and image must have identical dimensions" np_image_shape = np.array(image_shape) np_patch_shape = np.array(patch_shape) + + # when shapes not divisable and strict=False, handaling is similar to pytorch strided convolution + # see here: https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv1d if strict: assert (np_image_shape % np_patch_shape == 0).all(), "Image dimensions must be divisible by the patch size." self.num_tokens = int(np.prod(np_image_shape // np_patch_shape))