Section 3.9.4: dtype problem with W1 and W2 #156

Closed
Liang-Liao opened this issue Sep 2, 2020 · 3 comments

Comments

@Liang-Liao

Bug description
Running the net function from Section 3.9.4 raises an error in my environment:

def net(X):
    X = X.view((-1, num_inputs))
    H = relu(torch.matmul(X, W1) + b1)
    return torch.matmul(H, W2) + b2

loss = torch.nn.CrossEntropyLoss()

num_epochs, lr = 5, 100.0
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, params, lr)

The error is as follows:

RuntimeError  Traceback (most recent call last)
<ipython-input-52-c1201a53ebe9> in <module>
      1 num_epochs, lr = 5, 100.0
----> 2 d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, params, lr)

~/liang/d2lzh_pytorch.py in train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, params, lr, optimizer)
     84         train_l_sum, train_acc_sum, n = 0.0, 0.0, 0
     85         for X, y in train_iter:
---> 86             y_hat = net(X)
     87             l = loss(y_hat, y).sum()
     88 

<ipython-input-50-c182b51c4bb0> in net(X)
      1 def net(X):
      2     X = X.view((-1, num_inputs))
----> 3     H = relu(torch.matmul(X, W1) + b1)
      4     return torch.matmul(H, W2) + b2

RuntimeError: Expected object of scalar type Float but got scalar type Double for argument #2 'mat2' in call to _th_mm

Version information
pytorch: 1.4.0
torchvision: 0.5.0
torchtext:
...
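The error message above says that the second argument of the matrix multiply (`W1`) is `float64` ("Double") while the input `X` is `float32` ("Float"). A likely cause, assuming `W1` was built from a NumPy array without an explicit `dtype`, is that `np.random.normal` returns `float64` and `torch.tensor` keeps that dtype. The minimal sketch below reproduces the mismatch with small, hypothetical shapes; the exact error text differs between PyTorch versions, so it only checks that a `RuntimeError` is raised.

```python
import numpy as np
import torch

# Hypothetical small shapes, chosen only for illustration.
num_inputs, num_hiddens = 4, 3

# torch.tensor keeps NumPy's default float64 dtype.
W1 = torch.tensor(np.random.normal(0, 0.01, (num_inputs, num_hiddens)))
# Batches from the image data loader arrive as float32.
X = torch.rand(2, num_inputs)

print(X.dtype, W1.dtype)  # torch.float32 torch.float64

def try_matmul(a, b):
    """Return the product, or the RuntimeError raised by a dtype mismatch."""
    try:
        return torch.matmul(a, b)
    except RuntimeError as err:
        return err

result = try_matmul(X, W1)
print(type(result).__name__)  # RuntimeError
```

Casting either operand so both sides share one dtype (for example `W1.float()`) makes the same call succeed.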

@LuckyyySTA

I ran into the same problem. Did you manage to solve it?

@Liang-Liao
Author

Liang-Liao commented Sep 14, 2020

#156 (comment)

I got it working by adding a float() conversion to W1 and W2 inside net(), as follows:

def net(X):
    X = X.view((-1, num_inputs))
    H = relu(torch.matmul(X, W1.float()) + b1)
    return torch.matmul(H, W2.float()) + b2
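Casting inside `net` works because `.float()` is differentiable, so gradients still flow back to the `float64` parameters that the optimizer updates; the cost is that both matrices are converted on every forward pass. An alternative, sketched below under the assumption that the parameters are initialized from NumPy arrays, is to pass `dtype=torch.float` when creating them so the original `net` runs unchanged. The sizes, the `make_param` helper and the `relu` stand-in are hypothetical illustrations, not the book's exact code.

```python
import numpy as np
import torch

# Hypothetical small sizes; the book's MLP uses larger ones.
num_inputs, num_hiddens, num_outputs = 4, 3, 2

def make_param(shape):
    # Hypothetical helper: cast to float32 once, at creation, and track gradients.
    return torch.tensor(np.random.normal(0, 0.01, shape), dtype=torch.float,
                        requires_grad=True)

W1 = make_param((num_inputs, num_hiddens))
b1 = torch.zeros(num_hiddens, dtype=torch.float, requires_grad=True)
W2 = make_param((num_hiddens, num_outputs))
b2 = torch.zeros(num_outputs, dtype=torch.float, requires_grad=True)
params = [W1, b1, W2, b2]

def relu(X):
    # Element-wise max(X, 0), standing in for the section's relu.
    return torch.clamp(X, min=0)

def net(X):
    # Unchanged forward pass: no per-call casts are needed now.
    X = X.view((-1, num_inputs))
    H = relu(torch.matmul(X, W1) + b1)
    return torch.matmul(H, W2) + b2

X = torch.rand(5, num_inputs)          # float32, like the loader's images
y = torch.tensor([0, 1, 0, 1, 1])      # made-up class labels
loss = torch.nn.CrossEntropyLoss()
l = loss(net(X), y)
l.backward()
print(W1.dtype, W1.grad.dtype)  # torch.float32 torch.float32
```

Fixing the dtype at creation keeps the parameters, their gradients and the inputs in one precision, which is also what the rest of the training loop expects.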

@LuckyyySTA

#156 (comment)

With your help it runs successfully now. Thank you very much!
