View size not compatible error in rtho.py #8

stevenyu530 opened this issue May 5, 2021 · 2 comments

@stevenyu530

Team,

While running the sample script ./bin/rtho.py with the command provided in the README, the following error occurs:
python bin/baselines.py --network vgg --dataset cifar_10 --optimizer sgd --momentum 0.9 --lr-scheduler cyclic

Traceback (most recent call last):
  File "bin/rtho.py", line 125, in <module>
    train_rtho(args.network, args.dataset, args.num_epoch, args.batch_size, args.optimizer, args.lr, args.momentum,
  File "bin/rtho.py", line 92, in train_rtho
    hyper_optim.compute_hg(net, first_grad)
  File "/home/tyu/sandbox/original_adatune/adatune/adatune/mu_sgd.py", line 68, in compute_hg
    hvp_flatten = torch.cat([h.view(-1) for h in hvp])
  File "/home/tyu/sandbox/original_adatune/adatune/adatune/mu_sgd.py", line 68, in <listcomp>
    hvp_flatten = torch.cat([h.view(-1) for h in hvp])
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

The same error is reported when using Adam, in mu_adam.py.
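A minimal reproduction sketch of this error, assuming (the thread does not say so) that at least one tensor in `hvp` is non-contiguous, for example because of a transpose. The `hvp` list below is a hypothetical stand-in, not the tensors adatune actually computes:

```python
import torch

# Hypothetical stand-in for `hvp` in mu_sgd.py (an assumption, not adatune's
# real data): one contiguous tensor and one made non-contiguous by .t().
hvp = [
    torch.ones(4),
    torch.arange(6.0).reshape(2, 3).t(),  # shape (3, 2), strides (1, 3)
]

print([h.is_contiguous() for h in hvp])  # [True, False]

error_message = None
try:
    # Same expression as line 68 of mu_sgd.py.
    hvp_flatten = torch.cat([h.view(-1) for h in hvp])
except RuntimeError as err:
    # view(-1) cannot merge dimensions whose strides are incompatible,
    # so PyTorch raises the "view size is not compatible" error.
    error_message = str(err)

print(error_message)
```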

Changing line 68 in mu_sgd.py to the following (adding .contiguous()) fixed the problem, and the script now runs:
hvp_flatten = torch.cat([h.contiguous().view(-1) for h in hvp])

Could you advise whether this is a correct fix?

@orchidmajumder
Contributor

I think this is caused by a PyTorch version difference. Feel free to replace view(-1) with reshape(-1) and let me know if it works.
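A minimal sketch comparing the two fixes discussed in this thread, again using a hypothetical `hvp` list (an assumption, not adatune's real data) whose second tensor is non-contiguous. Both produce the same flattened vector; `reshape(-1)` is the shorter spelling because it returns a view when the strides allow one and copies only when they do not:

```python
import torch

# Hypothetical hvp list: the second tensor is non-contiguous after .t().
hvp = [torch.ones(4), torch.arange(6.0).reshape(2, 3).t()]

# Fix proposed in the issue: copy into contiguous memory, then view.
flat_contiguous = torch.cat([h.contiguous().view(-1) for h in hvp])

# Fix suggested by the maintainer: reshape() returns a view when possible
# and otherwise copies, so it does not hit the stride error.
flat_reshape = torch.cat([h.reshape(-1) for h in hvp])

print(torch.equal(flat_contiguous, flat_reshape))  # True
# The transposed block flattens in logical (row-major) order: 0, 3, 1, 4, 2, 5.
print(flat_reshape.tolist())  # [1.0, 1.0, 1.0, 1.0, 0.0, 3.0, 1.0, 4.0, 2.0, 5.0]
```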

@MightyElemental
Copy link

I'm also getting this error, and I can confirm that it was resolved by changing view(-1) to reshape(-1).
For the record, my PyTorch version is:

>>> torch.__version__
'2.0.0.dev20230115+cu118'
