Skip to content

Commit

Permalink
update logic on patch check so non-patched rectangular regions aren't…
Browse files Browse the repository at this point in the history
… rejected (#524)
  • Loading branch information
daviddpruitt committed May 29, 2024
1 parent c88a2fe commit 253ea57
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions examples/generative/corrdiff/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,11 +330,11 @@ def main(cfg: DictConfig) -> None:
c.patch_shape_x = img_shape_x
if (c.patch_shape_y is None) or (c.patch_shape_y > img_shape_y):
c.patch_shape_y = img_shape_y
if c.patch_shape_x != c.patch_shape_y:
raise NotImplementedError("Rectangular patch not supported yet")
if c.patch_shape_x % 32 != 0 or c.patch_shape_y % 32 != 0:
raise ValueError("Patch shape needs to be a multiple of 32")
if c.patch_shape_x != img_shape_x or c.patch_shape_y != img_shape_y:
if c.patch_shape_x != c.patch_shape_y:
raise NotImplementedError("Rectangular patch not supported yet")
if c.patch_shape_x % 32 != 0 or c.patch_shape_y % 32 != 0:
raise ValueError("Patch shape needs to be a multiple of 32")
logger0.info("Patch-based training enabled")
else:
logger0.info("Patch-based training disabled")
Expand Down

0 comments on commit 253ea57

Please sign in to comment.