Bug description
Running the computation with the `net` function from Section 3.9.4 raises an error in my environment:
```python
def net(X):
    X = X.view((-1, num_inputs))
    H = relu(torch.matmul(X, W1) + b1)
    return torch.matmul(H, W2) + b2

loss = torch.nn.CrossEntropyLoss()

num_epochs, lr = 5, 100.0
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, params, lr)
```
The error is as follows:
```
RuntimeError                              Traceback (most recent call last)
<ipython-input-52-c1201a53ebe9> in <module>
      1 num_epochs, lr = 5, 100.0
----> 2 d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, params, lr)

~/liang/d2lzh_pytorch.py in train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, params, lr, optimizer)
     84         train_l_sum, train_acc_sum, n = 0.0, 0.0, 0
     85         for X, y in train_iter:
---> 86             y_hat = net(X)
     87             l = loss(y_hat, y).sum()
     88 

<ipython-input-50-c182b51c4bb0> in net(X)
      1 def net(X):
      2     X = X.view((-1, num_inputs))
----> 3     H = relu(torch.matmul(X, W1) + b1)
      4     return torch.matmul(H, W2) + b2

RuntimeError: Expected object of scalar type Float but got scalar type Double for argument #2 'mat2' in call to _th_mm
```
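The message says the second operand of the matrix multiply (`mat2`, i.e. `W1`) is `Double` (float64) while `X` is `Float` (float32). `torch.matmul` does not promote mixed dtypes, so the two must match. A minimal reproduction could look like the sketch below. It assumes, since the issue does not show the parameter initialization, that `W1` was built from a NumPy array without a `dtype` argument. The sizes `784`/`256` and the random batch are illustrative only.

```python
import numpy as np
import torch

# Illustrative sizes; the issue does not show the real values.
num_inputs, num_hiddens = 784, 256

X = torch.rand(4, num_inputs)  # float32, like image batches from ToTensor
# Assumption: weights created from NumPy without a dtype -> float64 ("Double").
W1 = torch.tensor(np.random.normal(0, 0.01, (num_inputs, num_hiddens)))
print(X.dtype, W1.dtype)  # torch.float32 torch.float64

try:
    torch.matmul(X, W1)  # mixed float32 @ float64 is rejected
    mismatch_raised = False
except RuntimeError as err:
    mismatch_raised = True
    print("RuntimeError:", err)  # exact wording varies by PyTorch version

# Casting the weight to float32 makes both operands agree.
H = torch.matmul(X, W1.float())
print(H.dtype, H.shape)  # torch.float32 torch.Size([4, 256])
```

The exact error text differs between PyTorch releases, but a dtype mismatch in `matmul` raises a `RuntimeError` in all of them.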
Version information
pytorch: 1.4.0
torchvision: 0.5.0
torchtext:
...
I ran into the same problem. Did you manage to solve it?
#156 (comment)
Adding a `float()` conversion to `W1` and `W2` inside `net()` fixed it for me:
```python
def net(X):
    X = X.view((-1, num_inputs))
    H = relu(torch.matmul(X, W1.float()) + b1)
    return torch.matmul(H, W2.float()) + b2
```
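An alternative is to make the parameters float32 once, when they are created, so `net()` needs no casts. The sketch below assumes the same likely cause as above (weights sampled with NumPy, which defaults to float64). The layer sizes, the fake batch, and `relu` are hypothetical stand-ins, since the issue does not show them. Here `relu` uses `clamp` rather than whatever definition the original notebook had.

```python
import numpy as np
import torch

# Hypothetical sizes for a Fashion-MNIST-style MLP.
num_inputs, num_outputs, num_hiddens = 784, 10, 256

# dtype=torch.float converts the float64 NumPy samples to float32 at creation.
W1 = torch.tensor(np.random.normal(0, 0.01, (num_inputs, num_hiddens)), dtype=torch.float)
b1 = torch.zeros(num_hiddens, dtype=torch.float)
W2 = torch.tensor(np.random.normal(0, 0.01, (num_hiddens, num_outputs)), dtype=torch.float)
b2 = torch.zeros(num_outputs, dtype=torch.float)
params = [W1, b1, W2, b2]
for param in params:
    param.requires_grad_(requires_grad=True)

def relu(X):
    # Element-wise max(x, 0); a stand-in for the notebook's relu.
    return X.clamp(min=0)

def net(X):
    X = X.view((-1, num_inputs))
    H = relu(torch.matmul(X, W1) + b1)  # no .float() needed: all float32
    return torch.matmul(H, W2) + b2

X = torch.rand(4, 1, 28, 28)  # fake float32 image batch
y = torch.tensor([0, 1, 2, 3])  # int64 class labels
loss = torch.nn.CrossEntropyLoss()
l = loss(net(X), y)
l.backward()
print(all(p.grad is not None and p.dtype == torch.float32 for p in params))  # True
```

The per-call `.float()` fix from the comment also trains correctly, because the cast is recorded by autograd and gradients flow back to the float64 `W1`/`W2`. However, it allocates a float32 copy of each weight on every forward pass, which casting once at creation avoids.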
With your help it runs successfully now. Thank you very much!