In [1]:
import torch
from argparse import Namespace

## Phoneme Recognition Model

In [2]:
from phoneme_GAT.phoneme_model import BaseModule, load_phoneme_model, optim_param



In [3]:
from ay2.tools.text._phonemes import Phonemer_Tokenizer_Recombination

1. Download the pretrained phoneme recognition model from [Google Drive](https://drive.google.com/file/d/1SbqynkUQxxlhazklZz9OgcVK7Fl2aT-z/view?usp=drive_link).
2. Change `pretrained_path` to your own path.
3. Remember to also change `pretrained_path` and `vocab_path` in the `load_phoneme_model` function of `phoneme_GAT.phoneme_model`.

In [4]:
network_param = Namespace(
    network_name="WavLM",
    pretrained_path = "/home/ay/data/phonemes/wavlm/best-epoch=42-val-per=0.407000.ckpt",
    freeze=True,
    freeze_transformer=True,
    eos_token="</s>",
    bos_token="<s>",
    unk_token="<unk>",
    pad_token="<pad>",
    word_delimiter_token="|",
    vocab_size=200,
)

To build the phoneme recognition model:
1. You must specify `pretrained_path`. Either download the provided pretrained phoneme model or train your own model with `train_phoneme_model.py`.
2. In the `load_phoneme_model` function, set `vocab_path` to the correct location.
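
Since a wrong path only shows up once loading starts, it can help to check both locations first. The sketch below is a hypothetical helper (`check_model_paths` is not part of `phoneme_GAT`). It assumes the checkpoint is a single file and that `vocab_path` is a directory containing `.json` files; adjust it if your layout differs.

```python
from pathlib import Path


def check_model_paths(pretrained_path, vocab_path):
    """Return a list of problems found; an empty list means both paths look usable."""
    problems = []

    # Assumption: the checkpoint is one file, e.g. a Lightning ``.ckpt``.
    ckpt = Path(pretrained_path).expanduser()
    if not ckpt.is_file():
        problems.append(f"checkpoint not found: {ckpt}")

    # Assumption: the vocabulary is a directory holding JSON files.
    vocab_dir = Path(vocab_path).expanduser()
    if not vocab_dir.is_dir():
        problems.append(f"vocab directory not found: {vocab_dir}")
    elif not any(vocab_dir.glob("*.json")):
        problems.append(f"no .json files in vocab directory: {vocab_dir}")

    return problems
```

Calling `check_model_paths(network_param.pretrained_path, "<your vocab_path>")` before `load_phoneme_model` and printing the returned list gives a readable error instead of a stack trace from deep inside the loader.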

In [5]:
total_num_phonemes = 687  ## 198, or 687

phoneme_model = load_phoneme_model(
    network_name=network_param.network_name,
    pretrained_path=network_param.pretrained_path,
    total_num_phonemes=total_num_phonemes,
)
assert len(phoneme_model.tokenizer.total_phonemes) == total_num_phonemes

Now, load vocab json files from  /home/ay/tmp/PLFD-ADD/vocab_phoneme Please make sure the vocab files are correct
Load WavLM model!!!!!!!


Some weights of WavLMForCTC were not initialized from the model checkpoint at /home/ay/.cache/huggingface/hub/models--microsoft--wavlm-base/snapshots/efa81aae7ff777e464159e0f877d54eac5b84f81/ and are newly initialized: ['lm_head.bias', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


torch.Size([687, 768])


# Load model

## Audio model

In [6]:
from phoneme_GAT.modules import Phoneme_GAT_lit, Phoneme_GAT

In [7]:
audio_model = Phoneme_GAT(
    backbone='wavlm',
    use_raw=0,
    use_GAT=1,
    n_edges=10,
)

Now, load vocab json files from  /home/ay/tmp/PLFD-ADD/vocab_phoneme Please make sure the vocab files are correct
Load WavLM model!!!!!!!


Some weights of WavLMForCTC were not initialized from the model checkpoint at /home/ay/.cache/huggingface/hub/models--microsoft--wavlm-base/snapshots/efa81aae7ff777e464159e0f877d54eac5b84f81/ and are newly initialized: ['lm_head.bias', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


torch.Size([687, 768])


Generate a random audio batch to test the model.

In [8]:
x = torch.randn(3, 1, 48000)
num_frames = torch.full((x.shape[0],), 48000 // 320 - 1) # (batch_size,)
res = audio_model(x, num_frames=num_frames)

In [9]:
for key, value in res.items():
    print(key, value.shape)

logit torch.Size([3])
hidden_states torch.Size([3, 768])
phoneme_feat torch.Size([3, 149, 768])
encoder_feat torch.Size([3, 149, 768])
phoneme_cls_logit torch.Size([3, 687])
phoneme_cls_label torch.Size([3])
aug_logit torch.Size([3])
aug_frame_logit torch.Size([3])
aug_labels torch.Size([3])
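
The `num_frames` value of 149 matches the second dimension of `phoneme_feat` and `encoder_feat`. As a sketch of where this number comes from, the snippet below applies the usual 1-D convolution length formula layer by layer. The kernel sizes and strides are an assumption: they are the defaults of the wav2vec 2.0 / WavLM base feature encoder and are not stated in this notebook.

```python
# Assumed feature-encoder layout (wav2vec 2.0 / WavLM base defaults).
CONV_KERNELS = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDES = (5, 2, 2, 2, 2, 2, 2)


def conv_output_frames(num_samples):
    """Frame count after the convolutional feature encoder (no padding)."""
    length = num_samples
    for kernel, stride in zip(CONV_KERNELS, CONV_STRIDES):
        length = (length - kernel) // stride + 1
    return length


def shortcut_frames(num_samples, hop=320):
    """The shortcut used in the cell above: one frame per 320 samples, minus one."""
    return num_samples // hop - 1


print(conv_output_frames(48000), shortcut_frames(48000))  # 149 149
```

The two functions agree for the 48000-sample input used here. For lengths that are not a multiple of 320 the shortcut can be one frame lower than the layer-by-layer count (for example, 48080 samples give 150 versus 149), so the layer-by-layer version is the safer choice for variable-length audio under these assumptions.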


# Lit model

The settings of the `AudioModel` are defined in `cfg` as key-value pairs. Each entry below lists the setting name, its value in this demo, and its meaning:

1. **Network Structure Parameters**:
   - `backbone`: `"wavlm"`, the backbone of the phoneme recognition model.
   - `use_raw`: `False`, whether to use the raw transformer as the backbone.
   - `use_GAT`: `True`, whether to use GAT.
   - `n_edges`: `10`, the number of edges for each node in the GAT.
   - `use_pool`: `True`, whether to use pooling.


2. **Loss Function Parameters**:
   - `use_clip`: `True`, whether to use clip loss


3. **Data Augmentation and Training Strategy**:
   - `use_aug`: `True`, whether to use data augmentation in the training


In [10]:
from argparse import Namespace

# Construct the configuration using Namespace
cfg = Namespace(
    PhonemeGAT=Namespace(
        backbone="wavlm",  # wavlm or wav2vec
        use_raw=False,              # whether to use raw transformer as the backbone
        use_GAT=True,              # whether to use GAT
        n_edges=10,                # the number of edges for each node in the GAT
        use_aug=True,              # whether to use data augmentation in the training
        use_pool=True,            # whether to use pooling
        use_clip=True,             # whether to use clip loss
    )
)
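
If the same settings are kept in a plain dictionary (for example, after parsing a configuration file), they can be converted into the nested `Namespace` layout used above. The helper `dict_to_namespace` is a hypothetical convenience and not part of this repository.

```python
from argparse import Namespace


def dict_to_namespace(obj):
    """Recursively convert nested dicts into nested Namespace objects."""
    if isinstance(obj, dict):
        return Namespace(**{key: dict_to_namespace(value) for key, value in obj.items()})
    return obj


settings = {
    "PhonemeGAT": {
        "backbone": "wavlm",
        "use_raw": False,
        "use_GAT": True,
        "n_edges": 10,
        "use_aug": True,
        "use_pool": True,
        "use_clip": True,
    }
}

cfg_from_dict = dict_to_namespace(settings)
print(cfg_from_dict.PhonemeGAT.n_edges)  # 10
```

`vars(cfg_from_dict.PhonemeGAT)` converts the inner namespace back to a dictionary, which is handy for logging the configuration.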

In [11]:
from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.loggers import CSVLogger

We use a PyTorch Lightning module to train the model; it defines the training step, the validation/prediction step, the loss functions, and the optimizer.

In [12]:
audio_model_lit = Phoneme_GAT_lit(cfg=cfg)

Now, load vocab json files from  /home/ay/tmp/PLFD-ADD/vocab_phoneme Please make sure the vocab files are correct
Load WavLM model!!!!!!!


Some weights of WavLMForCTC were not initialized from the model checkpoint at /home/ay/.cache/huggingface/hub/models--microsoft--wavlm-base/snapshots/efa81aae7ff777e464159e0f877d54eac5b84f81/ and are newly initialized: ['lm_head.bias', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


torch.Size([687, 768])


## Test forwarding 

In the lit model, the `_shared_pred` method predicts the logits for the input batch. If the stage is train, it also applies the `audio_transform` to augment the spectrogram.

Generate a random batch:

In [13]:
x = torch.randn(3, 1, 48000)
batch = {
    "label": torch.randint(0, 2, (3,)),
    "audio": x,
    "sample_rate": 16000,
}

Note that your batch must be a dict with the keys shown above.
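
A quick check before calling the model can catch a malformed batch early. The sketch below only tests for the three keys shown in the example; `missing_batch_keys` is a hypothetical helper, and the real model may place further requirements on shapes and dtypes.

```python
# Keys taken from the example batch above.
REQUIRED_BATCH_KEYS = ("label", "audio", "sample_rate")


def missing_batch_keys(batch):
    """Return the required keys that are absent from ``batch``."""
    if not isinstance(batch, dict):
        raise TypeError(f"batch must be a dict, got {type(batch).__name__}")
    return [key for key in REQUIRED_BATCH_KEYS if key not in batch]


print(missing_batch_keys({"audio": None, "label": None}))  # ['sample_rate']
```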

In [14]:
batch_res = audio_model_lit._shared_pred(batch=batch, batch_idx=0)
for key, value in batch_res.items():
    print(key, value.shape)

logit torch.Size([3])
hidden_states torch.Size([3, 768])
phoneme_feat torch.Size([3, 149, 768])
encoder_feat torch.Size([3, 149, 768])
phoneme_cls_logit torch.Size([3, 687])
phoneme_cls_label torch.Size([3])
aug_logit torch.Size([3])
aug_frame_logit torch.Size([3])
aug_labels torch.Size([3])


## Demo training

We first build a simple dataloader for training, in which all samples are randomly generated.

In [15]:
from callbacks import EER_Callback, BinaryAUC_Callback, BinaryACC_Callback

In [16]:
import torch
from torch.utils.data import Dataset, DataLoader

In [None]:
class SimpleTestDataset(Dataset):
    def __init__(self, num_samples=10):
        # Generate synthetic samples with the same keys as the example batch above
        self.samples = []
        for _ in range(num_samples):
            self.samples.append({
                "audio": torch.randn(1, 48000),
                "label": torch.randint(0, 2, (1,)).item(),
                "sample_rate": 16000,
            })
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        return self.samples[idx]



# Create the dataset and dataloader
test_dataset = SimpleTestDataset(num_samples=20)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=3,
    shuffle=False,
)
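
With 20 samples and a batch size of 3, this dataloader yields seven batches per epoch, the last one holding only two samples. The sketch below reproduces that arithmetic in plain Python; `batch_sizes` is an illustrative helper, and `drop_last` mirrors the `DataLoader` argument of the same name (which defaults to `False`).

```python
def batch_sizes(num_samples, batch_size, drop_last=False):
    """List the size of every batch a sequential loader would produce."""
    full_batches, remainder = divmod(num_samples, batch_size)
    sizes = [batch_size] * full_batches
    if remainder and not drop_last:
        sizes.append(remainder)  # the final, smaller batch
    return sizes


print(batch_sizes(20, 3))       # [3, 3, 3, 3, 3, 3, 2]
print(len(batch_sizes(20, 3)))  # 7
```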

We build a simple trainer to train and test our model.

In [18]:
trainer = Trainer(
    logger=CSVLogger(save_dir="./logs", version=0),
    max_epochs=4,
    callbacks=[
        BinaryACC_Callback(batch_key="label", output_key="logit"),
        BinaryAUC_Callback(batch_key="label", output_key="logit"),
        EER_Callback(batch_key="label", output_key="logit"),
    ],
)

Trainer will use only 1 of 4 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=4)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Start training

In [19]:
trainer.fit(audio_model_lit, test_dataloader)

/home/ay/anaconda3/envs/phoneme_deepfake_detection/lib/python3.9/site-packages/pytorch_lightning/trainer/configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
You are using a CUDA device ('NVIDIA GeForce RTX 4090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name          | Type                    | Params | Mode 
------------------------------------------------------------------
0 | model         | Phoneme_GAT             | 196 M  | train
1 | bce_loss      | BCEWithLogitsLoss       | 0      | train
2 | ce_loss       | CrossEntropyLoss        | 0      | train
3 | contrast_loss | BinaryTokenContrastLoss | 0      | train

Epoch 3: 100%|██████████| 7/7 [00:01<00:00,  4.11it/s, v_num=0, train-loss=2.250, train-cls_loss=0.666, train-clip_loss=2.880, train-aug_loss=0.295, train-acc=0.600, train-auc=0.495, train-eer=0.462]

`Trainer.fit` stopped: `max_epochs=4` reached.


Epoch 3: 100%|██████████| 7/7 [00:07<00:00,  0.97it/s, v_num=0, train-loss=2.250, train-cls_loss=0.666, train-clip_loss=2.880, train-aug_loss=0.295, train-acc=0.600, train-auc=0.495, train-eer=0.462]


After training, you can view the logged losses in the logger's output file, for example `logs/lightning_logs/version_0/metrics.csv`.
![](imgs/loss.png)
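
To inspect the logged values without plotting, the metrics file can be read with the standard `csv` module. This is a minimal sketch: it assumes the file has a header row with one column per metric (such as `train-loss`) and that a cell is left empty when a metric was not logged at that step. Check the header of your own file for the exact column names.

```python
import csv


def read_metric(csv_path, column):
    """Return all non-empty values of ``column`` as floats, in file order."""
    values = []
    with open(csv_path, newline="") as f:
        for row in csv.DictReader(f):
            cell = row.get(column)
            if cell not in (None, ""):  # skip steps where the metric is absent
                values.append(float(cell))
    return values


# Example (the column name is an assumption based on the progress bar):
# losses = read_metric("logs/lightning_logs/version_0/metrics.csv", "train-loss")
```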

## Demo Testing

After testing, the results are also saved in the logger file.

In [20]:
trainer.test(audio_model_lit, test_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
/home/ay/anaconda3/envs/phoneme_deepfake_detection/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 7/7 [00:00<00:00,  8.32it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test-acc            0.6000000238418579
        test-auc             0.791208803653717
      test-aug_loss                 0.0
     test-clip_loss          2.873018741607666
      test-cls_loss         0.6275677680969238
        test-eer            0.3076923191547394
        test-loss            2.064077138900757
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test-loss': 2.064077138900757,
  'test-cls_loss': 0.6275677680969238,
  'test-clip_loss': 2.873018741607666,
  'test-aug_loss': 0.0,
  'test-acc': 0.6000000238418579,
  'test-auc': 0.791208803653717,
  'test-eer': 0.3076923191547394}]
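
For readers unfamiliar with `test-eer`: the equal error rate is the operating point at which the false-acceptance rate equals the false-rejection rate. The following is a simple threshold-sweep sketch in pure Python. It is not necessarily how `EER_Callback` computes the metric, and implementations that interpolate the ROC curve can return slightly different values. It assumes label 1 is the positive class and that a higher score means "more positive".

```python
import math


def equal_error_rate(labels, scores):
    """Approximate the EER by trying every observed score as a threshold."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one sample of each class")

    best_gap, best_eer = math.inf, None
    # The extra +inf threshold covers the case where nothing is accepted.
    for threshold in sorted(set(scores)) + [math.inf]:
        far = sum(s >= threshold for s in neg) / len(neg)  # false acceptance rate
        frr = sum(s < threshold for s in pos) / len(pos)   # false rejection rate
        gap = abs(far - frr)
        if gap < best_gap:
            best_gap, best_eer = gap, (far + frr) / 2
    return best_eer


print(equal_error_rate([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.5
```

With randomly generated audio and labels, as in this demo, the reported accuracy, AUC, and EER carry no meaning beyond confirming that the pipeline runs end to end.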