In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from argparse import Namespace
from copy import deepcopy

![](../imgs/framework.png)

As shown in the image, our method processes the input data using the following steps:
1. Convert the input speech into a log-frequency spectrogram.

# Torch model

In [3]:
from model import AudioModel

  from .autonotebook import tqdm as notebook_tqdm


The settings of `AudioModel` are defined in `cfg` as key-value pairs, where each key is the name of a setting. The settings are:

1. **Network Structure Parameters**:
   - `one_stem=False`: When `True`, uses only one stem to extract audio features and make predictions.
   - `feature_extractor="ResNet"`: Uses ResNet as the feature extractor. Set it to `"transformer"` to use transformer-based features instead.
   - `pretrain_transformer_path`: Path or name of the pre-trained transformer model.
   - `vocoder_classes=8`: Sets the number of vocoder classes to 8

2. **Loss Function Parameters**:
   - `use_f0_loss=False`: When `True`, uses the F0 loss as the pseudo loss in the content stream instead of the speed loss and compression loss. Leave it `False` to use the speed loss and compression loss.
   - `use_speed_loss=True`: Enables the speed loss
   - `use_compression_loss=True`: Enables the compression loss
   - `use_adversarial_loss=True`: Enables the adversarial loss used to train the content stream
   - `feat_con_loss=True`: Enables the feature contrastive loss

3. **Data Augmentation and Training Strategy**:
   - `style_shuffle=True`: Enables style shuffling
   - `feat_shuffle=True`: Enables feature shuffling
   - `aug_policy="ss"`: Sets the data augmentation policy to "ss"
   - `betas=[1, 0.5, 0.5, 0.5]`: Not documented further here; likely loss weight coefficients or optimizer beta parameters


In [4]:
cfg = Namespace(
    one_stem=False,
    use_f0_loss=False,
    use_speed_loss=True,
    use_compression_loss=True,
    use_adversarial_loss=True,
    style_shuffle=True,
    feat_shuffle=True,
    feature_extractor="ResNet",
    pretrain_transformer_path="facebook/wav2vec2-base-960h",
    vocoder_classes=8,
    betas=[1, 0.5, 0.5, 0.5],
    aug_policy="ss",
    feat_con_loss=True
)
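The description of `use_f0_loss` above says it replaces the speed and compression losses, so enabling all three at once is probably unintended. A minimal sketch of a sanity check for that case follows; `check_loss_flags` is a hypothetical helper written for this notebook, not part of the `model` package.

```python
from argparse import Namespace


def check_loss_flags(cfg: Namespace) -> list:
    """Return warnings about loss flags that look contradictory (hypothetical helper)."""
    problems = []
    uses_speed_or_compression = cfg.use_speed_loss or cfg.use_compression_loss
    # Assumption taken from the setting descriptions: F0 loss is an alternative
    # to the speed and compression losses, not an addition to them.
    if cfg.use_f0_loss and uses_speed_or_compression:
        problems.append(
            "use_f0_loss=True is described as replacing the speed and compression losses"
        )
    return problems


ok_cfg = Namespace(use_f0_loss=False, use_speed_loss=True, use_compression_loss=True)
mixed_cfg = Namespace(use_f0_loss=True, use_speed_loss=True, use_compression_loss=False)

print(check_loss_flags(ok_cfg))     # no warnings expected
print(check_loss_flags(mixed_cfg))  # one warning expected
```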

## Use ResNet as the feature extractor

In [5]:
demo_model = AudioModel(feature_extractor='ResNet', cfg=cfg)

  WeightNorm.apply(module, name, dim)


In [6]:
x = torch.randn(3, 1, 48000)
spectrogram = demo_model.feature_model.preprocess(x)
print("Spectrogram shape:", spectrogram.shape)
test_res = demo_model.forward(spectrogram, stage='test')
for k, v in test_res.items():
    print(k, v.shape)

Spectrogram shape: torch.Size([3, 1, 257, 257])
hidden_states torch.Size([3, 256, 17, 17])
content_feature torch.Size([3, 512])
speed_logit torch.Size([3, 16])
compression_logit torch.Size([3, 10])
vocoder_feature torch.Size([3, 512])
vocoder_logit torch.Size([3, 9])
content_voc_logit torch.Size([3, 9])
feature torch.Size([3, 1024])
logit torch.Size([3])


## Use Transformer as the feature extractor

By default, we use `"facebook/wav2vec2-base-960h"` as the feature extractor when `feature_extractor = "transformer"`. The checkpoint is selected by `pretrain_transformer_path` in `cfg`.

In [7]:
cfg2 = deepcopy(cfg)  # copy cfg so the original settings stay unchanged
cfg2.feature_extractor = "transformer"
demo_model2 = AudioModel(feature_extractor=cfg2.feature_extractor, cfg=cfg2)

Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
x = torch.randn(3, 1, 48000)
spectrogram = demo_model2.feature_model.preprocess(x)
print("Audio shape:", spectrogram.shape)
test_res = demo_model2.forward(spectrogram, stage='test')
for k, v in test_res.items():
    print(k, v.shape)

Audio shape: torch.Size([3, 48000])
hidden_states torch.Size([3, 149, 768])
content_feature torch.Size([3, 768])
speed_logit torch.Size([3, 16])
compression_logit torch.Size([3, 10])
vocoder_feature torch.Size([3, 768])
vocoder_logit torch.Size([3, 9])
content_voc_logit torch.Size([3, 9])
feature torch.Size([3, 1536])
logit torch.Size([3])
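The shapes printed for both extractors are consistent with `feature` being the concatenation of `content_feature` and `vocoder_feature` (512 + 512 = 1024 for ResNet, 768 + 768 = 1536 for the transformer). This is inferred from the shapes only and is not confirmed by the model code shown here. A minimal sketch of that relationship with random tensors:

```python
import torch

batch_size = 3
shapes = {}
for dim in (512, 768):  # per-stream feature sizes printed for ResNet and the transformer
    content_feature = torch.randn(batch_size, dim)
    vocoder_feature = torch.randn(batch_size, dim)
    # Assumption: the two stream features are joined along the last dimension
    feature = torch.cat([content_feature, vocoder_feature], dim=-1)
    shapes[dim] = tuple(feature.shape)
    print(dim, shapes[dim])
```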


## For training

When using the model for training, you must pass the batch (a dict) to the model. The model then generates `shuffle_label` from the ground-truth label and adds it to the batch for the feature shuffle loss.

In [9]:
x = torch.randn(3, 1, 48000)
batch = {
    'label': torch.randint(0, 2, (3,)),
}
spectrogram = demo_model.feature_model.preprocess(x)
train_res = demo_model.forward(spectrogram, stage="train", batch=batch)
print("print train res")
print("#"*10)
for k, v in train_res.items():
    print(k, v.shape)
print("#"*10, '\n', "print batch res")
print("#"*10)
for k, v in batch.items():
    print(k, v.shape)

print train res
##########
hidden_states torch.Size([3, 256, 17, 17])
content_feature torch.Size([3, 512])
speed_logit torch.Size([3, 16])
compression_logit torch.Size([3, 10])
vocoder_feature torch.Size([3, 512])
vocoder_logit torch.Size([3, 9])
content_voc_logit torch.Size([3, 9])
feature torch.Size([3, 1024])
logit torch.Size([3])
shuffle_logit torch.Size([3])
########## 
 print batch res
##########
label torch.Size([3])
shuffle_label torch.Size([3])


# Lit Model

We use PyTorch Lightning (`pytorch_lightning`) to handle the data flow, compute the losses, and train the model.

In [10]:
from model import AudioModel_lit
from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.loggers import CSVLogger

In [11]:
lit_model = AudioModel_lit(cfg=cfg)

## Test forwarding

First, we randomly generate a batch of audio samples. Warning: the batch must contain `speed_label` and `compression_label` if you use the speed loss and compression loss.

In [12]:
x = torch.randn(3, 1, 48000)
batch = {
    "label": torch.randint(0, 2, (3,)),
    "audio": x,
    "sample_rate": 16000,
    "speed_label": torch.randint(0, 10, (3,)),
    "compression_label": torch.randint(0, 10, (3,)),
}

In our speed transformation, an audio clip whose speed is unchanged gets speed label 5, and an uncompressed clip gets compression label 0. When preparing your dataloader, you can use the following code to generate the speed label and compression label for each audio clip:

In [13]:
from model import RandomAudioCompressionSpeedChanging

speed_compression_transform = RandomAudioCompressionSpeedChanging(p_compression=0.9, sample_rate=16000, p_speed=1.0, min_speed=0.5, max_speed=2.0)

Assume that we read an audio sample from the dataset and get the following data and metadata:

In [14]:
x = torch.randn(1, 48000)
metadata = {
    "label": 1,
    "audio": x,
    "sample_rate": 16000,
}

Use the following code to randomly change the speed of the audio and compress it:

In [15]:
x = speed_compression_transform(x, metadata)
for k, v in metadata.items():
    print(k, v)

label 1
audio tensor([[-0.6149, -0.0170,  1.6779,  ..., -0.2504,  0.3575,  1.1373]])
sample_rate 16000
compression_label 0
speed_label 10
speed 1.5


Note that you must pass `metadata` to `speed_compression_transform`: the transform writes the speed and compression labels into it.
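The values in this notebook (speed 1.0 gives label 5, speed 1.5 gives label 10, `min_speed=0.5`, `max_speed=2.0`, and 16 speed logits) are consistent with one class per 0.1 step of speed. The sketch below reproduces that mapping; `speed_to_label` and the 0.1 step are assumptions inferred from these numbers, not the actual implementation inside `RandomAudioCompressionSpeedChanging`.

```python
def speed_to_label(speed: float, min_speed: float = 0.5, step: float = 0.1) -> int:
    """Map a speed factor to a class index (hypothetical helper, assumed 0.1 step)."""
    # round() guards against floating-point error such as 1.5 / 0.1 = 15.000000000000002
    return int(round((speed - min_speed) / step))


for speed in (0.5, 1.0, 1.5, 2.0):
    print(speed, speed_to_label(speed))
```

Under this assumption the labels run from 0 (speed 0.5) to 15 (speed 2.0), which matches the 16-way `speed_logit` output.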


---

In the lit model, the `_shared_pred` method predicts the logits for the input batch. If the stage is `train` and the feature extractor is ResNet, it also applies `audio_transform` to augment the spectrogram.


```python
    def _shared_pred(self, batch, batch_idx, stage="train"):
        """common predict step for train/val/test

        Note that the data augmentation is done in the self.model.feature_extractor.

        """
        audio, sample_rate = batch["audio"], batch["sample_rate"]


        audio = self.model.feature_model.preprocess(audio, stage=stage)
        if stage == "train" and self.cfg.feature_extractor == "ResNet":
            audio = self.audio_transform.batch_apply(audio)


        batch_res = self.model(
            audio,
            stage=stage,
            batch=batch if stage == "train" else None,
            one_stem=self.one_stem,
        )

        batch_res["pred"] = (torch.sigmoid(batch_res["logit"]) + 0.5).int()

        return batch_res
```
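The last step of `_shared_pred` computes `pred` as `(sigmoid(logit) + 0.5).int()`. Because `.int()` truncates, this rounds the probability at 0.5, which is the same as thresholding the raw logit at 0. A small standalone check with made-up logits:

```python
import torch

logit = torch.tensor([-2.0, -0.1, 0.0, 0.1, 3.0])

# Same expression as in _shared_pred: add 0.5, then truncate to an integer
pred = (torch.sigmoid(logit) + 0.5).int()

# Equivalent formulation: probability >= 0.5 exactly when logit >= 0
threshold_pred = (logit >= 0).int()

print(pred.tolist())
print(torch.equal(pred, threshold_pred))
```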

In [16]:
batch_res = lit_model._shared_pred(batch=batch, batch_idx=0)
for key, value in batch_res.items():
    print(key, value.shape)

hidden_states torch.Size([3, 256, 17, 17])
content_feature torch.Size([3, 512])
speed_logit torch.Size([3, 16])
compression_logit torch.Size([3, 10])
vocoder_feature torch.Size([3, 512])
vocoder_logit torch.Size([3, 9])
content_voc_logit torch.Size([3, 9])
feature torch.Size([3, 1024])
logit torch.Size([3])
shuffle_logit torch.Size([3])
pred torch.Size([3])


## Demo training

We first build a simple dataloader for training, where all the samples are randomly generated.

In [17]:
from model import EER_Callback, BinaryAUC_Callback, BinaryACC_Callback

In [18]:
import torch
from torch.utils.data import Dataset, DataLoader

In [19]:

class SimpleTestDataset(Dataset):
    def __init__(self, num_samples=10):
        # Generate synthetic samples with the same keys as the batch above
        self.samples = []
        for _ in range(num_samples):
            self.samples.append({
                "audio": torch.randn(1, 48000),
                "label": torch.randint(0, 2, (1,)).item(),
                "sample_rate": 16000,
                "speed_label": torch.randint(0, 10, (1,)).item(),
                "compression_label": torch.randint(0, 10, (1,)).item(),
            })
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        return self.samples[idx]

# Create a simple collate function
def simple_collate_fn(batch):
    audio = torch.stack([item["audio"] for item in batch])
    label = torch.tensor([item["label"] for item in batch])
    speed_label = torch.tensor([item["speed_label"] for item in batch])
    compression_label = torch.tensor([item["compression_label"] for item in batch])
    
    return {
        "audio": audio,
        "label": label,
        "sample_rate": 16000,
        "speed_label": speed_label,
        "compression_label": compression_label,
    }

# Create the dataset and dataloader
test_dataset = SimpleTestDataset(num_samples=20)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=3,
    shuffle=False,
    collate_fn=simple_collate_fn,
)
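One reason to write a custom collate function here is the scalar `sample_rate`. As a point of comparison, PyTorch's built-in `default_collate` (exported from `torch.utils.data` in recent PyTorch versions) would turn every integer field into a tensor with one entry per sample. The sketch below uses two made-up samples to show the difference; whether the lit model strictly requires a scalar `sample_rate` is not verified here.

```python
import torch
from torch.utils.data import default_collate

samples = [
    {"audio": torch.randn(1, 48000), "label": 1, "sample_rate": 16000},
    {"audio": torch.randn(1, 48000), "label": 0, "sample_rate": 16000},
]

collated = default_collate(samples)

print(collated["audio"].shape)   # audio tensors are stacked along a new batch dimension
print(collated["label"])         # integer labels become a 1-D tensor
print(collated["sample_rate"])   # sample_rate also becomes a tensor, not a scalar
```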

We build a simple trainer to train and test our model.

In [20]:
trainer = Trainer(
    logger=CSVLogger(save_dir="./logs", version=0),
    max_epochs=4,
    callbacks=[
        BinaryACC_Callback(batch_key="label", output_key="logit"),
        BinaryAUC_Callback(batch_key="label", output_key="logit"),
        EER_Callback(batch_key="label", output_key="logit"),
    ],
)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Start training

In [21]:
trainer.fit(lit_model, test_dataloader)

/Volumes/GEIL2T/Softwares/anaconda3/envs/RobustSpeechDetection/lib/python3.9/site-packages/pytorch_lightning/trainer/configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
/Volumes/GEIL2T/Softwares/anaconda3/envs/RobustSpeechDetection/lib/python3.9/site-packages/lightning_fabric/loggers/csv_logs.py:268: Experiment logs directory ./logs/lightning_logs/version_0 exists and is not empty. Previous log files in this directory will be deleted when the new ones are saved!
/Volumes/GEIL2T/Softwares/anaconda3/envs/RobustSpeechDetection/lib/python3.9/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory ./logs/lightning_logs/version_0/checkpoints exists and is not empty.

  | Name           | Type                    | Params | Mode 
-------------------------------------------------------------------
0 | model          | AudioModel              | 21.7 M | train
1 | bce_loss       | BCEWithLogitsLoss       |

Epoch 3: 100%|██████████| 7/7 [00:01<00:00,  3.79it/s, v_num=0, train-cls_loss=0.743, train-feat_shuffle_loss=0.0672, train-vocoder_stem_loss=0.000, train-compression_loss=2.190, train-speed_loss=2.540, train-f0_loss=0.000, train-feat_contrast_loss=0.310, train-loss=3.270, train-acc=0.550, train-auc=0.770, train-eer=0.200]

`Trainer.fit` stopped: `max_epochs=4` reached.


Epoch 3: 100%|██████████| 7/7 [00:02<00:00,  2.78it/s, v_num=0, train-cls_loss=0.743, train-feat_shuffle_loss=0.0672, train-vocoder_stem_loss=0.000, train-compression_loss=2.190, train-speed_loss=2.540, train-f0_loss=0.000, train-feat_contrast_loss=0.310, train-loss=3.270, train-acc=0.550, train-auc=0.770, train-eer=0.200]


After training, you can view the logged losses in the logger file, for example `logs/lightning_logs/version_0/metrics.csv`.
![](imgs/loss.png)
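A minimal sketch of reading one metric back from such a CSV file with the standard library. The column names (`step`, `train-loss`) are assumptions based on the progress bar above, so check the header of your own `metrics.csv`; `load_metric` is a hypothetical helper, and the demo file below contains made-up numbers so the example runs on its own.

```python
import csv
import tempfile
from pathlib import Path


def load_metric(csv_path, name):
    """Return (step, value) pairs for one metric column, skipping empty cells."""
    points = []
    with open(csv_path, newline="") as f:
        for row in csv.DictReader(f):
            value = row.get(name, "")
            if value:  # a blank cell means the metric was not logged on this row
                points.append((int(row["step"]), float(value)))
    return points


# Stand-in file with made-up values; replace with logs/lightning_logs/version_0/metrics.csv
demo_csv = Path(tempfile.mkdtemp()) / "metrics.csv"
demo_csv.write_text(
    "epoch,step,train-loss,test-loss\n"
    "0,6,3.5,\n"
    "1,13,3.4,\n"
    "1,14,,3.2\n"
)

train_loss = load_metric(demo_csv, "train-loss")
print(train_loss)
```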

## Demo Testing

After testing, the results are also saved in the logger file.

In [22]:
trainer.test(lit_model, test_dataloader)

/Volumes/GEIL2T/Softwares/anaconda3/envs/RobustSpeechDetection/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 7/7 [00:00<00:00, 33.88it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test-acc                    0.5
        test-auc            0.9800000190734863
      test-cls_loss         0.6732991933822632
  test-compression_loss     2.1219122409820557
        test-eer            0.10000000149011612
      test-f0_loss                  0.0
 test-feat_contrast_loss    0.32317954301834106
        test-loss           3.1648013591766357
     test-speed_loss        2.5379128456115723
 test-vocoder_stem_loss             0.0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test-cls_loss': 0.6732991933822632,
  'test-vocoder_stem_loss': 0.0,
  'test-compression_loss': 2.1219122409820557,
  'test-speed_loss': 2.5379128456115723,
  'test-f0_loss': 0.0,
  'test-feat_contrast_loss': 0.32317954301834106,
  'test-loss': 3.1648013591766357,
  'test-acc': 0.5,
  'test-auc': 0.9800000190734863,
  'test-eer': 0.10000000149011612}]
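`trainer.test` returns a list with one dictionary of metrics per test dataloader, as shown in the output above. A short sketch of pulling a few values out of that structure, using numbers copied from this run (the data is random, so the values themselves carry no meaning):

```python
# Values copied from the test output above
results = [
    {
        "test-acc": 0.5,
        "test-auc": 0.9800000190734863,
        "test-eer": 0.10000000149011612,
    }
]

metrics = results[0]  # only one test dataloader was passed
summary = ", ".join(
    f"{name.split('-', 1)[1]}={value:.3f}" for name, value in metrics.items()
)
print(summary)
```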