Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

not strict mode in ProjectPatchesTokenizer #260

Merged
merged 8 commits into from Jan 23, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
22 changes: 17 additions & 5 deletions fuse/dl/models/backbones/backbone_vit.py
Expand Up @@ -17,13 +17,25 @@ class ProjectPatchesTokenizer(nn.Module):
batch_size, num_tokens, token_dim
"""

def __init__(self, *, image_shape: Sequence[int], patch_shape: Sequence[int], channels: int, token_dim: int):
def __init__(
self,
*,
image_shape: Sequence[int],
patch_shape: Sequence[int],
channels: int,
token_dim: int,
strict: bool = True,
avihu111 marked this conversation as resolved.
Show resolved Hide resolved
):
super().__init__()
assert len(image_shape) == len(patch_shape), "patch and image must have identical dimensions"
image_shape = np.array(image_shape)
patch_shape = np.array(patch_shape)
assert (image_shape % patch_shape == 0).all(), "Image dimensions must be divisible by the patch size."
self.num_tokens = int(np.prod(image_shape // patch_shape))
np_image_shape = np.array(image_shape)
np_patch_shape = np.array(patch_shape)

    # when shapes are not divisible and strict=False, handling is similar to pytorch strided convolution
# see here: https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv1d
if strict:
assert (np_image_shape % np_patch_shape == 0).all(), "Image dimensions must be divisible by the patch size."
self.num_tokens = int(np.prod(np_image_shape // np_patch_shape))
patch_shape = tuple(patch_shape)
self.image_dim = len(image_shape)
if self.image_dim == 1:
Expand Down