Merge pull request #47 from alihassanijr/main
Minor bug fix and cleanup
honghuis committed Jul 22, 2022
2 parents a9a7580 + a380984 commit f64cc00
Showing 3 changed files with 11 additions and 11 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -77,7 +77,7 @@ NAT-Tiny reaches 83.2% top-1 accuracy on ImageNet with only
 ![computeplot_light](assets/computeplot_light.png#gh-light-mode-only)
 
 # How it works
-Natural Attention localizes the query's (red) receptive field to its nearest neighborhood (green).
+Neighborhood Attention localizes the query's (red) receptive field to its nearest neighborhood (green).
 This is equivalent to dot-product self attention when the neighborhood size is identical to the image dimensions.
 Note that the edges are special (edge) cases.
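To make the "How it works" paragraph above concrete, here is a minimal, naive sketch of neighborhood attention for a single head (plain PyTorch, not the NATTEN CUDA kernel; the function name, the sqrt(dim) scaling, and the assumption that kernel_size is odd and no larger than H and W are illustrative choices, not taken from this repository):

# Naive neighborhood attention for one head; illustrative sketch only.
# Assumes kernel_size is odd and <= H, W; scores scaled by sqrt(dim).
import torch

def neighborhood_attention_naive(q, k, v, kernel_size):
    H, W, dim = q.shape                      # per-head feature maps of shape (H, W, dim)
    r = kernel_size // 2
    out = torch.empty_like(q)
    for i in range(H):
        for j in range(W):
            # Clamp the window so queries near the border still attend to a full
            # kernel_size x kernel_size neighborhood (the "edge case" noted above).
            i0 = min(max(i - r, 0), H - kernel_size)
            j0 = min(max(j - r, 0), W - kernel_size)
            keys = k[i0:i0 + kernel_size, j0:j0 + kernel_size].reshape(-1, dim)
            vals = v[i0:i0 + kernel_size, j0:j0 + kernel_size].reshape(-1, dim)
            scores = (q[i, j] @ keys.T) / dim ** 0.5
            out[i, j] = torch.softmax(scores, dim=-1) @ vals
    return out

When kernel_size equals both H and W, every window covers the entire feature map and the computation reduces to ordinary dot-product self-attention, as the README states.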
14 changes: 7 additions & 7 deletions classification/train.py
@@ -655,13 +655,13 @@ def main(args):
 
     if args.log_wandb:
         if has_wandb:
-            wandb.init(entity='compactdonut',
-                id=args.experiment,
-                project=args.project,
-                name=args.model,
-                config=args,
-                resume=bool(args.resume)
-            )
+            wandb.init(
+                id=args.experiment,
+                project=args.project,
+                name=args.model,
+                config=args,
+                resume=bool(args.resume)
+            )
         else:
             builtin_print("You've requested to log metrics to wandb but package not found. "
                           "Metrics not being logged to wandb, try `pip install wandb`")
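With the hard-coded entity removed, wandb.init logs runs to the caller's default W&B account. If a specific team is still wanted, one option (a usage suggestion, not part of this commit; the entity name below is hypothetical) is wandb's standard WANDB_ENTITY environment variable:

# Illustrative only: pick a W&B entity without hard-coding it in train.py.
# "my-team" is a hypothetical entity/team name.
import os
os.environ["WANDB_ENTITY"] = "my-team"   # read by wandb.init() when no entity kwarg is passed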
6 changes: 3 additions & 3 deletions natten/nattencuda.py
@@ -43,7 +43,7 @@ def forward(ctx, attn, value):
         attn = attn.contiguous()
         value = value.contiguous()
         out = nattenav_cuda.forward(
-            attn,
+            attn,
             value)
         ctx.save_for_backward(attn, value)
         return out
@@ -114,8 +114,8 @@ def forward(self, x):
         pad_l = pad_t = pad_r = pad_b = 0
         if H < self.kernel_size or W < self.kernel_size:
             pad_l = pad_t = 0
-            pad_r = max(0, self.window_size - W)
-            pad_b = max(0, self.window_size - H)
+            pad_r = max(0, self.kernel_size - W)
+            pad_b = max(0, self.kernel_size - H)
             x = pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
             _, H, W, _ = x.shape
         qkv = self.qkv(x).reshape(B, H, W, 3, self.num_heads, self.head_dim).permute(3, 0, 4, 1, 2, 5)
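The last hunk makes the padding amounts consistent with the guard on the line above them, which already tests self.kernel_size; previously they were computed from self.window_size. A small self-contained sketch of the resulting behavior (shapes and values below are hypothetical, chosen only to show how torch.nn.functional.pad acts on this channels-last layout):

# Illustrative only: padding applied when the feature map is smaller than the
# attention kernel. Shapes here are hypothetical.
import torch
from torch.nn.functional import pad

kernel_size = 7
x = torch.randn(1, 5, 6, 32)                  # (B, H, W, C) with H, W < kernel_size
_, H, W, _ = x.shape
pad_r = max(0, kernel_size - W)               # 1 extra column on the right
pad_b = max(0, kernel_size - H)               # 2 extra rows at the bottom
# F.pad pads from the last dimension backwards: channels first, then W, then H.
x = pad(x, (0, 0, 0, pad_r, 0, pad_b))
print(x.shape)                                # torch.Size([1, 7, 7, 32])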
