AttributeError: module 'torchvision.edgeailite.xnn.model_surgery' has no attribute 'get_replacements_dict' #7

Open
sathyapatel opened this issue Jan 3, 2022 · 15 comments

Comments

@sathyapatel

I'm getting the following error when I try to run ./run_detection_train.sh:

work_dir = './work_dirs/yolov3_regnet_bgr_lite'
gpu_ids = range(0, 1)

2022-01-03 09:13:58,990 - mmdet - INFO - Set random seed to 886029822, deterministic: False
2022-01-03 09:13:59,511 - mmdet - INFO - initialize RegNet with init_cfg {'type': 'Pretrained', 'checkpoint': 'open-mmlab://regnetx_1.6gf'}
2022-01-03 09:13:59,512 - mmcv - INFO - load model from: open-mmlab://regnetx_1.6gf
2022-01-03 09:13:59,512 - mmcv - INFO - load checkpoint from openmmlab path: open-mmlab://regnetx_1.6gf
2022-01-03 09:13:59,562 - mmcv - WARNING - The model and loaded state dict do not match exactly

unexpected key in source state_dict: fc.weight, fc.bias

Traceback (most recent call last):
File "./scripts/train_detection_main.py", line 65, in <module>
train_mmdet.main(args)
File "/home/ubuntu/edgeai-mmdetection/tools/train.py", line 172, in main
model = convert_to_lite_model(model, cfg)
File "/home/ubuntu/edgeai-mmdetection/mmdet/utils/model_surgery.py", line 38, in convert_to_lite_model
replacements_dict = copy.deepcopy(xnn.model_surgery.get_replacements_dict())
AttributeError: module 'torchvision.edgeailite.xnn.model_surgery' has no attribute 'get_replacements_dict'
Done.

@sathyapatel
Author

I commented out line 172 in tools/train.py:

if hasattr(cfg, 'convert_to_lite_model'):
   model = convert_to_lite_model(model, cfg)

It works now, and training has started successfully.

I encountered some bugs in model_surgery.py. Can you please help me out?

@mathmanu
Collaborator

mathmanu commented Jan 3, 2022

You need to pull the edgeai-torchvision repository, as it has been updated. Once you pull it, the error will go away.
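
A mismatch like this, where edgeai-mmdetection calls a function that the installed edgeai-torchvision does not provide yet, can be reported with a clearer message before training starts. The following is a minimal sketch and not part of either repository; `require_attr` and its hint text are hypothetical names chosen for illustration.

```python
import importlib


def require_attr(module_name, attr_name, hint):
    """Import module_name and return its attribute attr_name.

    Raises ImportError with a readable hint when the attribute is missing,
    which usually means the installed package is older than the caller expects.
    """
    module = importlib.import_module(module_name)
    if not hasattr(module, attr_name):
        raise ImportError(
            f"{module_name!r} has no attribute {attr_name!r}. {hint}"
        )
    return getattr(module, attr_name)


# Hypothetical usage for the error in this issue (commented out because it
# needs edgeai-torchvision to be installed):
# get_replacements_dict = require_attr(
#     "torchvision.edgeailite.xnn.model_surgery",
#     "get_replacements_dict",
#     "Pull the latest edgeai-torchvision and reinstall it.",
# )
```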

@sathyapatel
Author

Thanks, it worked! Another quick question: is it possible to do QAT (quantization-aware training) starting from pretrained weights? I trained CenterNet with a customized model and dataset, and I'm trying to run QAT on the saved pretrained weights.

@mathmanu
Collaborator

@ginamathew

For the pretrained weights, can we use the original PyTorch weights (without QAT), or do we need weights saved after some epochs of QAT? As I understand it, the model layers change after QAT, so I would like to know whether the original model (before QAT) can be used as the pretrained weights for QAT.

@ginamathew

I would also like to know whether QAT needs all the training images, or whether a small percentage of them is enough.

@mathmanu
Collaborator

It is possible to load the original floating-point weights when doing QAT. A small number of epochs is sufficient for QAT, maybe 10. If the dataset is large (like ImageNet), a small portion of it may be sufficient for QAT.
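
If only a portion of a large dataset is used for QAT, one way to pick it is to draw a fixed random set of sample indices. This is a minimal sketch in plain Python; `pick_subset_indices`, the 10% fraction and the seed are illustrative assumptions, not values recommended in this thread.

```python
import random


def pick_subset_indices(dataset_size, fraction, seed=0):
    """Return a sorted, reproducible random subset of sample indices.

    fraction is the share of the dataset to keep, in the range (0, 1].
    """
    if not 0.0 < fraction <= 1.0:
        raise ValueError("fraction must be in the range (0, 1]")
    count = max(1, round(dataset_size * fraction))
    rng = random.Random(seed)  # local generator, so global random state is untouched
    return sorted(rng.sample(range(dataset_size), count))


# Example: keep 10% of a 1000-sample dataset (both numbers are illustrative).
indices = pick_subset_indices(1000, 0.1)
print(len(indices))
```

The returned indices could then be passed to `torch.utils.data.Subset(dataset, indices)` to build the reduced training set.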

See the following link:
https://github.com/TexasInstruments/edgeai-torchvision/blob/master/docs/pixel2pixel/Quantization.md

Also see the following example code snippet there:

import os

import torch
from torchvision.edgeailite import xnn

# create your model here:
model = ...

# create a dummy input - this is required to analyze the model - fill in the input image size expected by your model.
dummy_input = torch.rand((1,3,384,768))

# wrap your model in xnn.quantize.QuantTrainModule. 
# once it is wrapped, the actual model is in model.module
model = xnn.quantize.QuantTrainModule(model, dummy_input=dummy_input)

# load your pretrained weights here into model.module
pretrained_data = torch.load(pretrained_path)
model.module.load_state_dict(pretrained_data)

# your training loop here with loss, backward, optimizer and scheduler.
# this is the usual training loop - but use a lower learning rate such as 1e-5
model.train()
for images, target in my_dataset_train:
    output = model(images)
    # loss, backward(), optimizer step etc comes here as usual in training

# save the model - the trained module is in model.module
# QAT model can export a clean onnx graph with clips in eval mode.
model.eval()
torch.onnx.export(model.module, dummy_input, os.path.join(save_path,'model.onnx'), export_params=True, verbose=False, do_constant_folding=True, opset_version=9)
torch.save(model.module.state_dict(), os.path.join(save_path,'model.pth'))
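
The snippet above passes the result of `torch.load` directly to `load_state_dict`. Checkpoints are sometimes saved with the weights nested under a key, or with a `module.` prefix on every name (for example when saved from a wrapped model), in which case the names may not match. The helper below is a minimal sketch for normalizing such a checkpoint; `extract_state_dict` and the key names `state_dict` and `model` are assumptions about common conventions, not something this thread or the library specifies.

```python
from collections import OrderedDict


def extract_state_dict(checkpoint, prefix="module."):
    """Return a flat name -> weight mapping from a loaded checkpoint.

    Unwraps one level of nesting under 'state_dict' or 'model' (assumed
    conventions) and strips a leading prefix from every parameter name.
    """
    state = checkpoint
    for key in ("state_dict", "model"):
        if isinstance(state, dict) and isinstance(state.get(key), dict):
            state = state[key]
            break

    cleaned = OrderedDict()
    for name, value in state.items():
        if name.startswith(prefix):
            name = name[len(prefix):]
        cleaned[name] = value
    return cleaned


# Illustrative checkpoint with placeholder values instead of tensors.
example = {"epoch": 3, "state_dict": {"module.conv.weight": 1, "module.conv.bias": 2}}
print(list(extract_state_dict(example)))
```

The result could then be loaded with `model.module.load_state_dict(extract_state_dict(pretrained_data))`. `load_state_dict` also accepts `strict=False`, which tolerates extra keys such as the unexpected `fc.weight` and `fc.bias` reported earlier in this issue, at the cost of hiding genuine mismatches.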

For better accuracy, we have seen that it is better to freeze the BN layers and also the quantization ranges after around half the number of epochs. An example is here:
https://github.com/TexasInstruments/edgeai-torchvision/blob/master/references/edgeailite/engine/train_classification.py#L530
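
The freezing schedule described above can be sketched independently of any library. In this sketch, `run_qat_epochs`, `train_one_epoch` and `freeze_fns` are hypothetical names; the actual freezing calls are whatever the linked training script uses, and they are passed in here as plain callables.

```python
def run_qat_epochs(num_epochs, train_one_epoch, freeze_fns):
    """Train for num_epochs and call each freeze function once at the halfway point.

    train_one_epoch: callable taking the epoch index.
    freeze_fns: callables that freeze BN statistics or quantization ranges.
    Returns the epoch at which freezing happened, or None if it never did.
    """
    freeze_epoch = num_epochs // 2
    frozen_at = None
    for epoch in range(num_epochs):
        if frozen_at is None and epoch >= freeze_epoch:
            for freeze in freeze_fns:
                freeze()
            frozen_at = epoch
        train_one_epoch(epoch)
    return frozen_at
```

With `num_epochs=10`, the freeze functions run once, just before epoch 5 is trained. If the per-epoch code calls `model.train()`, BatchNorm layers that were frozen by switching them to eval mode return to training mode, so the freeze may need to be re-applied after that call.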

@sathyapatel
Author

How much GPU memory do we need for this training? I tried a single-GPU instance with 16 GB, and it throws CUDA out of memory.

I'm moving the model, the dummy input, and xnn.quantize.QuantTrainModule to CUDA memory. Do you have any solution for this?

@mathmanu
Collaborator

GPU memory usage depends on the batch size. Reduce the batch size if you get CUDA out of memory.

@sathyapatel
Author

I reduced the batch size to 16, 8, 4, 2, and 1. I face the same memory issue even with batch size 1.

If I comment out xnn.quantize.QuantTrainModule in the code, training starts without the quantization module.

@mathmanu
Collaborator

mathmanu commented Feb 17, 2022

Can you share the exact error?

Which model are you using? What is the input image size being used?

You can also try running the line
model = xnn.quantize.QuantTrainModule(model, dummy_input=dummy_input)
on the CPU, so that the GPU is not used at this stage. But I'm not sure if and why this should help.

@sathyapatel
Author

I'm using a CenterNet model with pretrained weights.
Input image size is 512 x 320.
Total of 1062 image samples.

ERROR:

Loaded train 929 samples
/home/ubuntu/anaconda3/envs/edge-ai/lib/python3.7/site-packages/torch/utils/data/dataloader.py:481: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 8, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
cpuset_checked))
Traceback (most recent call last):
File "src/pretrain.py", line 103, in <module>
log_dict_train, _ = trainer.train(epoch, train_loader)
File "/home/ubuntu/edgeai-mmdetection/src/lib/trains/base_trainer.py", line 119, in train
return self.run_epoch('train', epoch, data_loader)
File "/home/ubuntu/edgeai-mmdetection/src/lib/trains/base_trainer.py", line 69, in run_epoch
output, loss, loss_stats = model_with_loss(batch)
File "/home/ubuntu/anaconda3/envs/edge-ai/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/home/ubuntu/edgeai-mmdetection/src/lib/trains/base_trainer.py", line 19, in forward
outputs = self.model(batch['input'])
File "/home/ubuntu/anaconda3/envs/edge-ai/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/home/ubuntu/edge-ai/edgeai-mmdetection/edgeai-torchvision/torchvision/edgeailite/xnn/quantize/quant_train_module.py", line 70, in forward
outputs = self.module(inputs, *args, **kwargs)
File "/home/ubuntu/anaconda3/envs/edge-ai/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/home/ubuntu/edgeai-mmdetection/src/lib/models/networks/mobilev2.py", line 170, in forward
x = self.base.forward_det_features(x)
File "/home/ubuntu/edgeai-mmdetection/src/lib/models/networks/mobilev2.py", line 110, in forward_det_features
x = feature(x)
File "/home/ubuntu/anaconda3/envs/edge-ai/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/home/ubuntu/anaconda3/envs/edge-ai/lib/python3.7/site-packages/torch/nn/modules/container.py", line 141, in forward
input = module(input)
File "/home/ubuntu/anaconda3/envs/edge-ai/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/home/ubuntu/edge-ai/edgeai-mmdetection/edgeai-torchvision/torchvision/edgeailite/xnn/quantize/quant_train_module.py", line 328, in forward
y = super().forward(x, update_activation_range=self.update_activation_range, enable=self.quantize_enable)
File "/home/ubuntu/edge-ai/edgeai-mmdetection/edgeai-torchvision/torchvision/edgeailite/xnn/layers/activation.py", line 87, in forward
self.clips_batch = self.update_clips_act(x.data)
File "/home/ubuntu/edge-ai/edgeai-mmdetection/edgeai-torchvision/torchvision/edgeailite/xnn/layers/activation.py", line 109, in update_clips_act
x_min, x_max = utils.extrema_fast(x, range_shrink_percentile=self.range_shrink_activations)
File "/home/ubuntu/edge-ai/edgeai-mmdetection/edgeai-torchvision/torchvision/edgeailite/xnn/utils/range_utils.py", line 42, in extrema_fast
return extrema(src, range_shrink_percentile, channel_mean, sigma, fast_mode)
File "/home/ubuntu/edge-ai/edgeai-mmdetection/edgeai-torchvision/torchvision/edgeailite/xnn/utils/range_utils.py", line 52, in extrema
hist_array, mn, mx, mult_factor, offset = tensor_histogram(src, fast_mode=fast_mode)
File "/home/ubuntu/edge-ai/edgeai-mmdetection/edgeai-torchvision/torchvision/edgeailite/xnn/utils/range_utils.py", line 111, in tensor_histogram
hist = torch.bincount(tensor_int)
RuntimeError: CUDA out of memory. Tried to allocate 16.00 GiB (GPU 0; 14.76 GiB total capacity; 8.00 GiB already allocated; 5.02 GiB free; 8.43 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

@mathmanu
Collaborator

Can you try reducing the input image size - let's see if this is really related to the memory usage.

@sathyapatel
Author

The issue is insufficient GPU memory. Changing the input size of the images doesn't work; it still throws the same runtime error. I then removed the section of the code that loads the pretrained weights.

Training has now started from scratch, with xnn.quantize.QuantTrainModule moved to CUDA memory:

creating index...
index created!
Loaded train 929 samples
/home/ubuntu/anaconda3/envs/edge-ai/lib/python3.7/site-packages/torch/utils/data/dataloader.py:481: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 8, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary. cpuset_checked))

/home/ubuntu/anaconda3/envs/edge-ai/lib/python3.7/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.
warnings.warn(warning.format(ret))

ctdet/train_test |# | train: [1][0/29]|Tot: 0:00:11 |ETA: 0:00:00 |loss 15.7850 |hm_loss 14.5027 |wh_loss 8.0129 |off_loss 0.4810 |Data

ctdet/train_test |## | train: [1][1/29]|Tot: 0:00:12 |ETA: 0:05:15 |loss 15.6908 |hm_loss 14.3788 |wh_loss 8.2404 |off_loss 0.4879 |Data

ctdet/train_test |### | train: [1][2/29]|Tot: 0:00:14 |ETA: 0:02:57 |loss 15.6135 |hm_loss 14.3530 |wh_loss 7.8958 |off_loss 0.4709 |Data

ctdet/train_test |#### | train: [1][3/29]|Tot: 0:01:25 |ETA: 0:02:10 |loss 15.0860 |hm_loss 13.8389 |wh_loss 7.7483 |off_loss 0.4723 |Data

ctdet/train_test |##### | train: [1][4/29]|Tot: 0:02:20 |ETA: 0:08:59 |loss 14.8916 |hm_loss 13.6098 |wh_loss 8.0658 |off_loss 0.4752 |Data

ctdet/train_test |###### | train: [1][5/29]|Tot: 0:03:21 |ETA: 0:11:15 |loss 14.7081 |hm_loss 13.4047 |wh_loss 8.2868 |off_loss 0.4747 |Data

ctdet/train_test |####### | train: [1][6/29]|Tot: 0:04:20 |ETA: 0:12:55 |loss 14.6586 |hm_loss 13.3553 |wh_loss 8.3268 |off_loss 0.4706 |Data 0.353s(1.196s) |Net 37.158s
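
The DataLoader warning in this log says that 16 worker processes were requested while the suggested maximum on this machine is 8. A minimal sketch for capping the worker count is below; `capped_num_workers` is a hypothetical helper, and the CPU count it computes is only an approximation of the figure the warning reports.

```python
import os


def capped_num_workers(requested):
    """Limit a requested worker count to the CPUs available to this process."""
    if hasattr(os, "sched_getaffinity"):
        # Linux: respects CPU affinity masks and container limits.
        available = len(os.sched_getaffinity(0))
    else:
        available = os.cpu_count() or 1
    return max(0, min(requested, available))


print(capped_num_workers(16))
```

The result could be passed as the `num_workers` argument of `torch.utils.data.DataLoader`.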

@mathmanu
Collaborator

mathmanu commented Feb 17, 2022

"I removed loading pretrained weights section in the code"

Is that the change that reduced the memory requirement significantly? Surprising!
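
One detail in the traceback may be relevant here: the failed allocation is 16.00 GiB inside `torch.bincount`. `torch.bincount` returns one count per integer value up to the largest value in its input, so the size of the request reflects the largest histogram index rather than the number of pixels in the batch. The arithmetic below is a sketch that assumes 8-byte (int64) counts; `bins_for_allocation` is a hypothetical helper.

```python
GIB = 2**30  # bytes in one GiB


def bins_for_allocation(size_gib, bytes_per_bin=8):
    """Number of histogram bins that fit in size_gib, assuming int64 counts."""
    return (size_gib * GIB) // bytes_per_bin


bins = bins_for_allocation(16)
print(bins)           # 2147483648
print(bins == 2**31)  # True
```

Under that assumption the largest index would be 2**31 - 1, the maximum signed 32-bit integer. That would be consistent with the report that smaller batches and images made no difference, and it may point to extremely large or non-finite activation values reaching the range estimation, for example from weights that did not load as expected, rather than to a genuine shortage of GPU memory. This is a guess that the thread does not confirm.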
