
tutorial-contents/403_RNN_regressor.py bug report #46

@njuhan

Description


Hello,
The following code in tutorial-contents/403_RNN_regressor.py:

r_out = r_out.view(-1, 32)
outs = self.out(r_out)
return outs, h_state

outs has shape 10x1, which does not match (batch, time_step, input_size). This causes RuntimeError: input and target shapes do not match: input [10 x 1], target [1 x 10 x 1]
Could we simply add a dimension like this, so that the output matches (batch, time_step, input_size)?
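A minimal sketch that reproduces the shapes involved. It assumes the tutorial's RNN is batch_first with hidden size 32 (inferred from r_out.view(-1, 32)), a single input feature, and batch=1, time_step=10 (read off the error message). The names rnn, out, x and y are illustrative, not necessarily the tutorial's own.

```python
import torch
import torch.nn as nn

# Assumed dimensions: HIDDEN=32 from r_out.view(-1, 32);
# BATCH=1 and TIME_STEP=10 from the error message.
BATCH, TIME_STEP, INPUT_SIZE, HIDDEN = 1, 10, 1, 32

rnn = nn.RNN(input_size=INPUT_SIZE, hidden_size=HIDDEN, num_layers=1, batch_first=True)
out = nn.Linear(HIDDEN, 1)

x = torch.randn(BATCH, TIME_STEP, INPUT_SIZE)   # (batch, time_step, input_size)
y = torch.randn(BATCH, TIME_STEP, 1)            # target, same layout as x

r_out, h_state = rnn(x, None)                   # r_out: (1, 10, 32)
outs = out(r_out.view(-1, HIDDEN))              # flattening drops the batch dim: (10, 1)

print(tuple(outs.shape), tuple(y.shape))        # → (10, 1) (1, 10, 1)
```

Whether this mismatch actually raises depends on the PyTorch version: the RuntimeError above came from an older release, while newer versions of nn.MSELoss may instead emit a UserWarning about the differing sizes and broadcast the two tensors.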

r_out = r_out.view(-1, 32)  
outs = self.out(r_out) 
return torch.unsqueeze(outs,0), h_state
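The proposed unsqueeze works for a single sequence, but because it prepends a dimension of size 1 it only lines up with (batch, time_step, 1) when batch == 1. The sketch below compares it with two alternatives under the same assumed dimensions, with batch set to 3 to expose the difference: reshaping back to (batch, time_step, 1), or applying nn.Linear directly to the 3-D r_out, since nn.Linear operates on the last dimension of its input. These alternatives are suggestions, not a confirmed fix from the tutorial's author; all variable names are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Assumed dimensions; BATCH=3 is chosen on purpose to show where
# unsqueeze(0) stops matching the (batch, time_step, 1) target.
BATCH, TIME_STEP, INPUT_SIZE, HIDDEN = 3, 10, 1, 32

rnn = nn.RNN(INPUT_SIZE, HIDDEN, num_layers=1, batch_first=True)
out = nn.Linear(HIDDEN, 1)

x = torch.randn(BATCH, TIME_STEP, INPUT_SIZE)
y = torch.randn(BATCH, TIME_STEP, 1)            # target
r_out, h_state = rnn(x, None)                   # (batch, time_step, 32)

# reshape() rather than view(): with batch_first=True and batch > 1,
# r_out may not be contiguous, in which case view() would fail.
flat = out(r_out.reshape(-1, HIDDEN))           # (batch*time_step, 1)

# Proposed fix: prepend a batch dimension of size 1.
a = torch.unsqueeze(flat, 0)                    # (1, batch*time_step, 1)

# Alternative 1: restore the batch-first layout explicitly.
b = flat.view(BATCH, TIME_STEP, 1)              # (batch, time_step, 1)

# Alternative 2: nn.Linear acts on the last dim, so skip the reshape.
c = out(r_out)                                  # (batch, time_step, 1)

loss = nn.MSELoss()(c, y)                       # shapes agree, no broadcasting
print(tuple(a.shape), tuple(b.shape), tuple(c.shape))  # → (1, 30, 1) (3, 10, 1) (3, 10, 1)
print(torch.allclose(b, c, atol=1e-6))                 # → True
```

Alternative 2 keeps forward() shortest and works for any batch size, because the linear layer is applied independently at every (batch, time_step) position.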
