Skip to content

Commit

Permalink
model.py: properly encoding text
Browse files Browse the repository at this point in the history
  • Loading branch information
rafaelvalle committed Mar 20, 2021
1 parent 817ec30 commit d5362cc
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,19 @@ def __init__(self, hparams):
batch_first=True, bidirectional=True)

def forward(self, x, input_lengths):
for conv in self.convolutions:
x = F.dropout(F.relu(conv(x)), drop_rate, self.training)

x = x.transpose(1, 2)
if x.size()[0] > 1:
print("here")
x_embedded = []
for b_ind in range(x.size()[0]): # TODO: Speed up
curr_x = x[b_ind:b_ind+1, :, :input_lengths[b_ind]].clone()
for conv in self.convolutions:
curr_x = F.dropout(F.relu(conv(curr_x)), drop_rate, self.training)
x_embedded.append(curr_x[0].transpose(0, 1))
x = torch.nn.utils.rnn.pad_sequence(x_embedded, batch_first=True)
else:
for conv in self.convolutions:
x = F.dropout(F.relu(conv(x)), drop_rate, self.training)
x = x.transpose(1, 2)

# pytorch tensor are not reversible, hence the conversion
input_lengths = input_lengths.cpu().numpy()
Expand Down

0 comments on commit d5362cc

Please sign in to comment.