Trouble with pure inference #4
Comments
Thanks for the kind words. We use multiple GPUs to train the model, so the saved model is a `torch.nn.DataParallel` object. Even if you want to do single-GPU inference, you need to do the following:

```python
input_tdim = 1024
ast_mdl = ASTModel(label_dim=2,
                   fshape=16,
                   tshape=16,
                   fstride=10,
                   tstride=10,
                   input_fdim=128,
                   input_tdim=input_tdim,
                   model_size='tiny',
                   pretrain_stage=False,
                   load_pretrained_mdl_path=MODEL)
# wrap it in a torch.nn.DataParallel object
ast_mdl = torch.nn.DataParallel(ast_mdl)
# then do inference as normal
output = ast_mdl(input)
```

Another method is to convert the `torch.nn.DataParallel` model back to a plain torch module; you can search online for how to do that. -Yuan
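That second method (converting a `DataParallel` checkpoint back to a plain module) can be sketched as follows. This is not code from the repo, just a common approach: `nn.DataParallel` prefixes every parameter name with `module.`, so stripping that prefix lets a plain model load the state dict.

```python
# Sketch (not repo code): undo the "module." prefix that
# torch.nn.DataParallel adds to every key in a saved state_dict,
# so the checkpoint can be loaded into a plain (non-parallel) model.
def strip_dataparallel_prefix(state_dict):
    """Return a copy of state_dict with any leading 'module.' removed."""
    return {k[len('module.'):] if k.startswith('module.') else k: v
            for k, v in state_dict.items()}
```

Usage would look like `plain_model.load_state_dict(strip_dataparallel_prefix(torch.load(path, map_location='cpu')))`.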
Also, the model input should be a spectrogram processed with the same normalization and feature-extraction function (see Line 195 and Lines 126 to 127 in 35ae7ab).
You can also refer to https://github.com/YuanGongND/ast/blob/master/egs/audioset/inference.py.
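As a rough sketch of what "same normalization" means here (this is my illustration, not the repo's code; the `dataset_mean`, `dataset_std`, and target length below are placeholders you should take from your own training config, and the fbank itself would come from `torchaudio.compliance.kaldi.fbank`, as in the repo's dataloader):

```python
import torch

def prepare_input(fbank, target_tdim=1024, dataset_mean=0.0, dataset_std=0.5):
    """Normalize a (time, freq) fbank and pad/truncate it to target_tdim frames."""
    # the (std * 2) convention mirrors the AST-style dataloader;
    # verify it against your own training code
    fbank = (fbank - dataset_mean) / (dataset_std * 2)
    n_frames = fbank.shape[0]
    if n_frames < target_tdim:
        # zero-pad short clips along the time axis
        fbank = torch.nn.functional.pad(fbank, (0, 0, 0, target_tdim - n_frames))
    else:
        # truncate long clips
        fbank = fbank[:target_tdim, :]
    return fbank.unsqueeze(0)  # add a batch dimension
```

The point is that inference-time inputs must go through exactly the same statistics and length handling as training-time inputs, or the predictions will be off.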
Thanks Yuan for your suggestions. To be clear, where do I add the `DataParallel` wrapper? In your example you put it after the `ASTModel` call; however, it is that initial call where the load fails, so I suspect I need to modify
You should do something like this (I don't suggest changing it):

```python
input_tdim = 1024
ast_mdl = ASTModel(label_dim=2,
                   fshape=16,
                   tshape=16,
                   fstride=10,
                   tstride=10,
                   input_fdim=128,
                   input_tdim=input_tdim,
                   model_size='tiny',
                   pretrain_stage=False,
                   load_pretrained_mdl_path=MODEL)
# wrap it in a torch.nn.DataParallel object
ast_mdl = torch.nn.DataParallel(ast_mdl)
# load the checkpoint
checkpoint = torch.load(checkpoint_path, map_location='cuda')
ast_mdl.load_state_dict(checkpoint)
# then do inference as normal
output = ast_mdl(input)
```
Sorry, I might not have been clear here. The … In that call, … Or are you suggesting that I load the pre-trained model provided by this repo (…)? Does that make sense?
That's weird; if you used my recipe to fine-tune the model, the saved model should already be a `DataParallel` object.
Yes... I most certainly used your code. I essentially used the AudioSet fine-tuning script; here is the full …
I see. It might be caused by a bug in the code; I didn't consider your use case. If the model is not too large, can you send the .pth file to me at yuangong@mit.edu? I can take a look, but not immediately; I will need to find some spare time.
FWIW, I got some kind of inference pipeline running, although the results do not match the output originally generated by your fine-tuning recipes, so I'm guessing there are major bugs in what I got working. But I thought it might be relevant anyway. This is all done after I successfully ran the fine-tuning scripts on a new dataset for binary classification.

First, use the same JSON-style approach to make a dataloader (using your …).

Next, load the original (pre-trained) model from which I fine-tuned: …

Then, load into this the state checkpoint (no idea if this works as expected, but it is the only way I got anything to run without errors): …

Then, copying bits and pieces from the supplied …

For single samples, the prediction probabilities do not sum to 1, nor do they match my expected values from the supplied fine-tuning recipe.
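One likely explanation for scores not summing to 1 (an assumption on my part, not confirmed in this thread) is that the model returns raw logits, and the recipe applies a sigmoid or softmax only inside its loss/evaluation code. A minimal sketch of converting logits to probabilities:

```python
import torch

def to_probs(logits, multi_label=False):
    """Turn raw logits into probabilities.
    multi_label=True  -> independent sigmoids (scores need not sum to 1)
    multi_label=False -> softmax over classes (scores sum to 1)
    """
    if multi_label:
        return torch.sigmoid(logits)
    return torch.softmax(logits, dim=-1)
```

For a single-label binary task, the softmax branch gives scores that sum to 1 per sample; which branch matches the training objective depends on the loss used in the fine-tuning recipe.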
The problem is that I used a trick to encode the pretraining hyperparameters in the model, and I use the existence of those hyperparameters to check whether the model is a `DataParallel` object. The SSL pretraining code does save the hyperparameters, but the fine-tuning code does not, so when you do another round of testing, the code cannot find the hyperparameters and thinks the model is not `DataParallel` (ssast/src/models/ast_models.py, Lines 146 to 147 in 35ae7ab).

As a temporary workaround, you can change those two lines of code (ssast/src/models/ast_models.py, Lines 146 to 147 in 35ae7ab).

I will find time to fix it.
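A more direct check than relying on saved hyperparameters (a hypothetical alternative, not the repo's current code) is to inspect the checkpoint's keys: `nn.DataParallel` prefixes every parameter name with `module.`.

```python
# Sketch (hypothetical alternative, not the repo's code): detect a
# DataParallel checkpoint by the "module." prefix on its parameter names.
def is_dataparallel_checkpoint(state_dict):
    return len(state_dict) > 0 and all(k.startswith('module.') for k in state_dict)
```

This works on any checkpoint regardless of which training script saved it, which sidesteps the mismatch between the pretraining and fine-tuning save paths.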
Even after I change these two lines, I still cannot load the fine-tuned model.
Hi, I am running into exactly the same error and am having trouble loading a fine-tuned model. I am wondering if @beyondbeneath ever found a solution to this?
Hello!
Firstly, thanks for this great work!
I managed to modify the AudioSet fine-tuning script and fine-tuned a model on a new audio binary classification task. I started with the "Tiny" patch model and used a batch size of 2. The resulting predictions on the evaluation set looked very promising!
I'm now trying to write an inference script to take that saved model and perform inference, and I'm running into some trouble. Which method do I actually need to call for pure inference? The documentation seems to describe only pre-training and fine-tuning, not inference.
More pressing, I can't actually get the model to load. I am trying to load `best_audio_model.pth` as follows: … however, this results in the errors: …
Is there anything obvious I'm missing or doing wrong? I would appreciate any guidance on how to load this model, and also how to perform inference on a new `.wav` file. Thanks!