Hi, thanks a lot for your kind and great contribution.
I am currently trying to prune a custom resnet18 model that was trained for face recognition.
The model is pretty much the same as the standard resnet18, with some minor differences (you can see the actual model definition here).
I used your prune_model() function from examples/prune_resnet18_cifar10.py#L83 and changed only two things: resnet.BasicBlock became IRBlock, and the input size went from 32 to 112 to match my model. The rest is the same.
Here is the whole script:
import numpy as np
import torch
import torch_pruning as pruning
from models import resnet18, load_model, BasicBlock, Bottleneck, IRBlock, SEBlock, ResNet

def prune_model(model):
    model.cpu()
    # my resnet18 was trained on 112x112 images, so we changed 32 to 112
    DG = pruning.DependencyGraph().build_dependency(model, torch.randn(1, 3, 112, 112))

    def prune_conv(conv, pruned_prob):
        weight = conv.weight.detach().cpu().numpy()
        out_channels = weight.shape[0]
        # L1 norm per output filter: sum of absolute values over (in_channels, kH, kW)
        L1_norm = np.sum(np.abs(weight), axis=(1, 2, 3))
        num_pruned = int(out_channels * pruned_prob)
        prune_index = np.argsort(L1_norm)[:num_pruned].tolist()  # remove filters with small L1-Norm
        plan = DG.get_pruning_plan(conv, pruning.prune_conv, prune_index)
        plan.exec()

    block_prune_probs = [0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.3]
    blk_id = 0
    for m in model.modules():
        if isinstance(m, IRBlock):
            prune_conv(m.conv1, block_prune_probs[blk_id])
            prune_conv(m.conv2, block_prune_probs[blk_id])
            blk_id += 1
    return model

# load the resnet18 model
model = resnet18(pretrained=False, use_se=True)
model = load_model(model, 'BEST_checkpoint_r18.tar')
model.eval()

# prune the model
prune_model(model)
But upon running this snippet, I get this error:
Could you please tell me what I'm missing here?
Thanks a lot in advance.
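In case it helps narrow things down, the filter-selection step from prune_conv in the script above can be checked in isolation, without the model or torch_pruning. This is only a minimal sketch: the helper name select_filters_to_prune and the toy weight tensor are made up for illustration, and it assumes the intended criterion is the L1 norm (sum of absolute values), as the comment in the script says.

```python
import numpy as np

def select_filters_to_prune(weight, pruned_prob):
    """Return the indices of the output filters with the smallest L1 norm.

    `weight` is assumed to have the Conv2d layout
    (out_channels, in_channels, kH, kW).
    """
    out_channels = weight.shape[0]
    # L1 norm of each output filter: sum of absolute values over the other axes.
    l1_norm = np.sum(np.abs(weight), axis=(1, 2, 3))
    num_pruned = int(out_channels * pruned_prob)
    return np.argsort(l1_norm)[:num_pruned].tolist()

# Hypothetical 4-filter weight; every entry of filter i equals the i-th value.
values = (0.5, -0.1, 2.0, -1.0)
weight = np.stack([np.full((2, 3, 3), v, dtype=np.float32) for v in values])

prune_index = select_filters_to_prune(weight, 0.5)
print(prune_index)  # [1, 0]: the two filters with the smallest absolute magnitude
```

Without np.abs, the negative-valued filters would sort first regardless of their magnitude, so the selected indices would differ.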