Skip to content

Commit

Permalink
fix cutout
Browse files Browse the repository at this point in the history
  • Loading branch information
ain-soph committed Nov 30, 2021
1 parent 55719a9 commit 1ead700
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions trojanvision/utils/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,14 +197,15 @@ def __repr__(self) -> str:
return s.format(**self.__dict__)


def cutout(img: torch.Tensor, length: Union[int, tuple[int, int]],
def cutout(img: torch.Tensor, length: Union[int, tuple[int, int], torch.Tensor],
fill_values: Union[float, torch.Tensor] = 0.0) -> torch.Tensor:
if isinstance(length, int):
length = (length, length)
h, w = img.size(-2), img.size(-1)
mask = torch.ones(h, w, dtype=torch.bool, device=img.device)
y = torch.randint(0, h, [1])
x = torch.randint(0, w, [1])
device = length.device if isinstance(length, torch.Tensor) else img.device
y = torch.randint(0, h, [1], device=device)
x = torch.randint(0, w, [1], device=device)
first_half = [length[0] // 2, length[1] // 2]
second_half = [length[0] - first_half[0], length[1] - first_half[1]]

Expand Down

0 comments on commit 1ead700

Please sign in to comment.