Commit
Merge pull request #652 from VijayKalmath/Fix-WordCnn_GradientCalculation

Fix WordCNN input embedding gradient calculation
jxmorris12 authored Jun 3, 2022
2 parents 2f26520 + e377aca commit 2f3e561
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion textattack/models/wrappers/pytorch_model_wrapper.py
@@ -89,7 +89,16 @@ def grad_hook(module, grad_in, grad_out):
         loss.backward()
 
         # grad w.r.t to word embeddings
-        grad = torch.transpose(emb_grads[0], 0, 1)[0].cpu().numpy()
+
+        # Fix for Issue #601
+
+        # Check if the gradient has shape [max_sequence, 1, _]
+        # (the case when the model input is the transpose of the input sequence)
+        if emb_grads[0].shape[1] == 1:
+            grad = torch.transpose(emb_grads[0], 0, 1)[0].cpu().numpy()
+        else:
+            # gradient has shape [1, max_sequence, _]
+            grad = emb_grads[0][0].cpu().numpy()
 
         embedding_layer.weight.requires_grad = original_state
         emb_hook.remove()
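For readers landing on this commit, here is a minimal, self-contained sketch of the two embedding-gradient layouts the new branch distinguishes. It is not part of the commit, and the tensors and shapes are illustrative stand-ins for emb_grads[0]. A model fed a transposed (sequence-first) input yields a gradient of shape [max_sequence, 1, _], while a batch-first model such as WordCNN yields [1, max_sequence, _]; both branches reduce to [max_sequence, _]:

import torch

max_sequence, emb_dim = 5, 3

# Illustrative stand-ins for emb_grads[0] under the two input layouts
seq_first_grad = torch.randn(max_sequence, 1, emb_dim)    # transposed input
batch_first_grad = torch.randn(1, max_sequence, emb_dim)  # WordCNN-style input

for emb_grad in (seq_first_grad, batch_first_grad):
    if emb_grad.shape[1] == 1:
        # [max_sequence, 1, _] -> [max_sequence, _]
        grad = torch.transpose(emb_grad, 0, 1)[0].cpu().numpy()
    else:
        # [1, max_sequence, _] -> [max_sequence, _]
        grad = emb_grad[0].cpu().numpy()
    assert grad.shape == (max_sequence, emb_dim)

Note that the shape[1] == 1 test is a heuristic: if max_sequence were itself 1, the two layouts would coincide, but both branches then return the same array, so the ambiguity is harmless.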
