From b7086755ebfd9f2ab018c0a40722b8418d9d41fe Mon Sep 17 00:00:00 2001 From: ygong Date: Sun, 8 May 2022 17:56:59 -0400 Subject: [PATCH] fix a bug --- README.md | 2 +- src/dataloader.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 8e2a8e5..2fd66d2 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ ## News -May, 2022: It was found that newer `torchaudio` package has different behavior with older ones in time and frequency masking and will cause a bug. Please stick to the version in `requirement.txt`. +May, 2022: It was found that newer `torchaudio` package has different behavior with older ones in time and frequency masking and will cause a bug. We find a workaround and fixed it. March, 2022: We released a new preprint [*CMKD: CNN/Transformer-Based Cross-Model Knowledge Distillation for Audio Classification*](https://arxiv.org/abs/2203.06760), where we proposed a knowledge distillation based method to further improve the AST model performance without changing its architecture. diff --git a/src/dataloader.py b/src/dataloader.py index 517387f..e8f29ee 100644 --- a/src/dataloader.py +++ b/src/dataloader.py @@ -187,10 +187,14 @@ def __getitem__(self, index): freqm = torchaudio.transforms.FrequencyMasking(self.freqm) timem = torchaudio.transforms.TimeMasking(self.timem) fbank = torch.transpose(fbank, 0, 1) + # this is just to satisfy new torchaudio version, which only accept [1, freq, time] + fbank = fbank.unsqueeze(0) if self.freqm != 0: fbank = freqm(fbank) if self.timem != 0: fbank = timem(fbank) + # squeeze it back, it is just a trick to satisfy new torchaudio version + fbank = fbank.squeeze(0) fbank = torch.transpose(fbank, 0, 1) # normalize the input for both training and test