In [1]:
import torch
from argparse import Namespace
import warnings
import os

warnings.filterwarnings("ignore")
torch.manual_seed(42)
torch.set_float32_matmul_precision("medium")

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

PyTorch version: 2.5.1
CUDA available: False


## Phoneme Recognition Model

In [2]:
from phoneme_GAT.phoneme_model import BaseModule, load_phoneme_model, optim_param

In [3]:
from ay2.tools.text._phonemes import Phonemer_Tokenizer_Recombination

**Setup Instructions:**

1. The pretrained phoneme recognition model should be in the project root: `Best Epoch 42 Validation 0.407.ckpt`
2. The vocab files should be in the `vocab_phoneme/` directory, which must contain all 9 language JSON files
3. The paths are configured automatically to use these local files, so no manual changes are needed

If the checkpoint is missing, download it from [Google Drive](https://drive.google.com/file/d/1SbqynkUQxxlhazklZz9OgcVK7Fl2aT-z/view?usp=drive_link).
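
The setup requirements above can be checked before any model is loaded. The following is a minimal sketch, not part of the project: `check_setup` is a hypothetical helper, and it assumes the vocab files use a lowercase `.json` extension.

```python
from pathlib import Path

CHECKPOINT_NAME = "Best Epoch 42 Validation 0.407.ckpt"  # filename from step 1
VOCAB_DIR_NAME = "vocab_phoneme"                          # directory from step 2
EXPECTED_VOCAB_FILES = 9                                  # "all 9 language JSON files"


def check_setup(project_root="."):
    """Return a list of human-readable problems; an empty list means nothing is missing."""
    root = Path(project_root)
    problems = []

    checkpoint = root / CHECKPOINT_NAME
    if not checkpoint.is_file():
        problems.append(f"missing checkpoint: {checkpoint}")

    vocab_dir = root / VOCAB_DIR_NAME
    if not vocab_dir.is_dir():
        problems.append(f"missing vocab directory: {vocab_dir}")
    else:
        # Assumption: vocab files end in ".json" (lowercase).
        n_json = len(list(vocab_dir.glob("*.json")))
        if n_json < EXPECTED_VOCAB_FILES:
            problems.append(
                f"expected {EXPECTED_VOCAB_FILES} vocab JSON files, found {n_json}"
            )
    return problems


# Report anything that is missing in the current working directory.
for problem in check_setup("."):
    print("Setup problem:", problem)
```

If nothing is printed, the checkpoint and the vocab directory were found where the following cells expect them.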

In [4]:
import os

# Use local checkpoint path (automatically finds it in project root)
project_root = os.path.abspath(".")
pretrained_path = os.path.join(project_root, "Best Epoch 42 Validation 0.407.ckpt")

network_param = Namespace(
    network_name="WavLM",
    pretrained_path=pretrained_path,
    freeze=True,
    freeze_transformer=True,
    eos_token="</s>",
    bos_token="<s>",
    unk_token="<unk>",
    pad_token="<pad>",
    word_delimiter_token="|",
    vocab_size=200,
)

print(f"✓ Using checkpoint: {pretrained_path}")
print(f"✓ Checkpoint exists: {os.path.exists(pretrained_path)}")

✓ Using checkpoint: /Users/arjunjindal/Desktop/PLFD-ADD/Best Epoch 42 Validation 0.407.ckpt
✓ Checkpoint exists: True


**Building the Phoneme Recognition Model:**

The `load_phoneme_model` function now automatically:
1. Uses the local checkpoint path specified above
2. Finds vocab files in the `vocab_phoneme/` directory (relative path)
3. Downloads the WavLM base model from HuggingFace if not cached

✅ Everything is configured automatically - just run the cells!

In [5]:
total_num_phonemes = 687  # 198 or 687

phoneme_model = load_phoneme_model(
    network_name=network_param.network_name,
    pretrained_path=network_param.pretrained_path,
    total_num_phonemes=total_num_phonemes,
)
assert len(phoneme_model.tokenizer.total_phonemes) == total_num_phonemes

Now, load vocab json files from  /Users/arjunjindal/Desktop/PLFD-ADD/vocab_phoneme Please make sure the vocab files are correct
Load WavLM model!!!!!!!


Some weights of WavLMForCTC were not initialized from the model checkpoint at microsoft/wavlm-base and are newly initialized: ['lm_head.weight', 'lm_head.bias', 'encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'encoder.pos_conv_embed.conv.parametrizations.weight.original0']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


torch.Size([687, 768])


# Load model

## Audio model

In [6]:
from phoneme_GAT.modules import Phoneme_GAT_lit, Phoneme_GAT

In [7]:
audio_model = Phoneme_GAT(
    backbone='wavlm',
    use_raw=0,
    use_GAT=1,
    n_edges=10,
)

Now, load vocab json files from  /Users/arjunjindal/Desktop/PLFD-ADD/vocab_phoneme Please make sure the vocab files are correct
Load WavLM model!!!!!!!


Some weights of WavLMForCTC were not initialized from the model checkpoint at microsoft/wavlm-base and are newly initialized: ['lm_head.weight', 'lm_head.bias', 'encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'encoder.pos_conv_embed.conv.parametrizations.weight.original0']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


torch.Size([687, 768])


Generate a random audio batch to test the model.

In [8]:
x = torch.randn(3, 1, 48000)  # (batch_size, channels, samples)
# One frame count per sample: 48000 // 320 - 1 = 149
num_frames = torch.full((x.shape[0],), 48000 // 320 - 1)  # (batch_size,)
res = audio_model(x, num_frames=num_frames)

In [9]:
for key, value in res.items():
    print(key, value.shape)

logit torch.Size([3])
hidden_states torch.Size([3, 768])
phoneme_feat torch.Size([3, 149, 768])
encoder_feat torch.Size([3, 149, 768])
phoneme_cls_logit torch.Size([3, 687])
phoneme_cls_label torch.Size([3])
aug_logit torch.Size([3])
aug_frame_logit torch.Size([3])
aug_labels torch.Size([3])
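
The second dimension of `phoneme_feat` and `encoder_feat` (149) is the number of encoder frames for 48000 input samples. The sketch below reproduces that number, assuming the backbone uses the standard convolutional feature extractor of WavLM base (seven layers with the kernel sizes and strides listed in the code, which are the Hugging Face `WavLMConfig` defaults). `num_encoder_frames` is a hypothetical helper, not a function from this project.

```python
# Kernel sizes and strides of the 7 conv layers in the WavLM base feature
# extractor (assumed: the Hugging Face WavLMConfig defaults).
CONV_KERNELS = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDES = (5, 2, 2, 2, 2, 2, 2)


def num_encoder_frames(num_samples: int) -> int:
    """Apply the unpadded conv length formula layer by layer."""
    length = num_samples
    for kernel, stride in zip(CONV_KERNELS, CONV_STRIDES):
        length = (length - kernel) // stride + 1
    return length


print(num_encoder_frames(48000))  # 149
print(48000 // 320 - 1)           # 149, the shortcut used in the cell above
```

The total stride is 5 × 2⁶ = 320 samples, which is where the `320` in the shortcut comes from. For 48000 samples both computations agree; for lengths that are not a multiple of 320 the layer-by-layer result can be one frame larger than the shortcut (for example, 48080 samples give 150 frames).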


# Lit model

The audio model's settings are defined in `cfg` as key-value pairs. The settings, with the values used in this demo, are:

1. **Network Structure Parameters**:
   - `backbone`: `"wavlm"`, the backbone of the phoneme recognition model
   - `use_raw`: `False`, whether to use the raw transformer as the backbone
   - `use_GAT`: `True`, whether to use GAT
   - `n_edges`: `10`, the number of edges for each node in the GAT
   - `use_pool`: `True`, whether to use pooling


2. **Loss Function Parameters**:
   - `use_clip`: `True`, whether to use clip loss


3. **Data Augmentation and Training Strategy**:
   - `use_aug`: `True`, whether to use data augmentation in the training


In [10]:
from argparse import Namespace

# Construct the configuration using Namespace
cfg = Namespace(
    PhonemeGAT=Namespace(
        backbone="wavlm",  # wavlm or wav2vec
        use_raw=False,     # whether to use the raw transformer as the backbone
        use_GAT=True,      # whether to use GAT
        n_edges=10,        # the number of edges for each node in the GAT
        use_aug=True,      # whether to use data augmentation during training
        use_pool=True,     # whether to use pooling
        use_clip=True,     # whether to use clip loss
    )
)
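
Because the configuration is a nested `Namespace`, individual settings are read with attribute access, and the whole tree can be flattened into plain dictionaries for logging or saving. A small sketch follows; `namespace_to_dict` and `cfg_demo` are hypothetical names introduced for illustration and are not part of the project.

```python
from argparse import Namespace


def namespace_to_dict(ns):
    """Recursively convert a (possibly nested) Namespace into plain dicts."""
    return {
        key: namespace_to_dict(value) if isinstance(value, Namespace) else value
        for key, value in vars(ns).items()
    }


# A reduced copy of the configuration, used only for this illustration.
cfg_demo = Namespace(PhonemeGAT=Namespace(backbone="wavlm", n_edges=10))

print(cfg_demo.PhonemeGAT.n_edges)  # attribute access to a single setting
as_dict = namespace_to_dict(cfg_demo)
print(as_dict)
```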

In [11]:
from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.loggers import CSVLogger

We use a PyTorch Lightning module to train the model. It defines the training step, the validation/predict step, the loss functions, and the optimizer.

In [12]:
audio_model_lit = Phoneme_GAT_lit(cfg=cfg)

Now, load vocab json files from  /Users/arjunjindal/Desktop/PLFD-ADD/vocab_phoneme Please make sure the vocab files are correct
Load WavLM model!!!!!!!


Some weights of WavLMForCTC were not initialized from the model checkpoint at microsoft/wavlm-base and are newly initialized: ['lm_head.weight', 'lm_head.bias', 'encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'encoder.pos_conv_embed.conv.parametrizations.weight.original0']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


torch.Size([687, 768])


## Test forwarding 

In the lit model, the `_shared_pred` method predicts the logits for an input batch. If the stage is train, it also uses `audio_transform` to augment the spectrogram.

Generate a random batch:

In [13]:
x = torch.randn(3, 1, 48000)
batch = {
    "label": torch.randint(0, 2, (3,)),
    "audio": x,
    "sample_rate": 16000,
}

Note: your batch must be a dict with the keys above.
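
A quick structural check can catch a malformed batch before it reaches the model. The sketch below is a hypothetical helper, not part of the project. It only assumes what the demo batch shows: the three keys, an `audio` tensor of shape `(batch_size, 1, num_samples)`, and a `label` tensor of shape `(batch_size,)`. It relies on the `.shape` attribute only, so it works with any tensor-like object.

```python
REQUIRED_KEYS = ("audio", "label", "sample_rate")


def check_batch(batch):
    """Validate the batch layout and return the batch size."""
    if not isinstance(batch, dict):
        raise TypeError(f"batch must be a dict, got {type(batch).__name__}")

    missing = [key for key in REQUIRED_KEYS if key not in batch]
    if missing:
        raise KeyError(f"batch is missing keys: {missing}")

    audio_shape = tuple(batch["audio"].shape)
    label_shape = tuple(batch["label"].shape)
    if len(audio_shape) != 3:
        raise ValueError(f"audio must be 3-D (batch, channel, samples), got {audio_shape}")
    if label_shape != (audio_shape[0],):
        raise ValueError(
            f"label shape {label_shape} does not match batch size {audio_shape[0]}"
        )
    return audio_shape[0]
```

In the notebook, `check_batch(batch)` can be called on the dict defined above before passing it to `_shared_pred`.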

In [14]:
batch_res = audio_model_lit._shared_pred(batch=batch, batch_idx=0)
for key, value in batch_res.items():
    print(key, value.shape)

logit torch.Size([3])
hidden_states torch.Size([3, 768])
phoneme_feat torch.Size([3, 149, 768])
encoder_feat torch.Size([3, 149, 768])
phoneme_cls_logit torch.Size([3, 687])
phoneme_cls_label torch.Size([3])
aug_logit torch.Size([3])
aug_frame_logit torch.Size([3])
aug_labels torch.Size([3])


## Demo training

We first build a simple dataloader for training, in which all samples are randomly generated.

In [15]:
from callbacks import EER_Callback, BinaryAUC_Callback, BinaryACC_Callback

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader

In [None]:
class SimpleTestDataset(Dataset):
    def __init__(self, num_samples=10):
        # Generate synthetic samples with the same keys as the demo batch above
        self.samples = []
        for _ in range(num_samples):
            self.samples.append({
                "audio": torch.randn(1, 48000),
                "label": torch.randint(0, 2, (1,)).item(),
                "sample_rate": 16000,
            })
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        return self.samples[idx]



# Create the dataset and dataloader
test_dataset = SimpleTestDataset(num_samples=20)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=3,
    shuffle=False,
)

We build a simple trainer to train and test our model.

In [18]:
# Auto-detect GPU or CPU
if torch.cuda.is_available():
    accelerator = "gpu"
    devices = 1
else:
    accelerator = "cpu"
    devices = "auto"

print(f"Using accelerator: {accelerator}")

trainer = Trainer(
    logger=CSVLogger(save_dir="./logs", version=None),
    max_epochs=4,
    accelerator=accelerator,
    devices=devices,
    callbacks=[
        BinaryACC_Callback(batch_key="label", output_key="logit"),
        BinaryAUC_Callback(batch_key="label", output_key="logit"),
        EER_Callback(batch_key="label", output_key="logit"),
    ],
)

print(f"Logger path: {trainer.logger.log_dir}")

GPU available: True (mps), used: False


Using accelerator: cpu


TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Logger path: ./logs/lightning_logs/version_5
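
The Lightning output above reports `GPU available: True (mps), used: False`: the cell only checks for CUDA, so on Apple silicon it falls back to the CPU. A possible extension of the selection logic is sketched below. `pick_accelerator` is a hypothetical helper, and whether every operation in `Phoneme_GAT` runs on MPS is not verified here, so treat `"mps"` as something to try rather than a guaranteed speed-up.

```python
def pick_accelerator(cuda_available: bool, mps_available: bool) -> str:
    """Prefer CUDA, then Apple MPS, then CPU."""
    if cuda_available:
        return "gpu"
    if mps_available:
        return "mps"
    return "cpu"


# In the notebook (requires torch), the flags would come from:
#   torch.cuda.is_available() and torch.backends.mps.is_available()
print(pick_accelerator(cuda_available=False, mps_available=True))  # mps
```

The returned string can be passed as `accelerator=` to `Trainer` in place of the hard-coded `"gpu"` / `"cpu"` choice.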


Start training

In [19]:
trainer.fit(audio_model_lit, test_dataloader)


  | Name          | Type                    | Params | Mode 
------------------------------------------------------------------
0 | model         | Phoneme_GAT             | 196 M  | train
1 | bce_loss      | BCEWithLogitsLoss       | 0      | train
2 | ce_loss       | CrossEntropyLoss        | 0      | train
3 | contrast_loss | BinaryTokenContrastLoss | 0      | train
4 | clip_head     | Sequential              | 1.2 M  | train
5 | clip_loss     | CLIPLoss1D              | 1      | train
------------------------------------------------------------------
102 M     Trainable params
94.9 M    Non-trainable params
197 M     Total params
790.544   Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=4` reached.


After training, you can view the logged losses in the logger's `metrics.csv` file, for example `logs/lightning_logs/version_0/metrics.csv`. The version number increases with each run; the run above wrote to `version_5`.
![](imgs/loss.png)
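
The logged values can also be read programmatically. The sketch below uses only the standard library and assumes the usual `CSVLogger` layout: one row per logging call, a `step` column, and empty cells for metrics that were not logged in that row. The column name `train-loss` is an assumption modelled on the `test-loss` key shown in the test results; check the header of your own `metrics.csv` for the actual names.

```python
import csv
from pathlib import Path


def read_metric(csv_path, column):
    """Return (step, value) pairs for every row where `column` is non-empty."""
    points = []
    with open(csv_path, newline="") as f:
        for row in csv.DictReader(f):
            value = row.get(column, "")
            if value:  # skip rows in which this metric was not logged
                points.append((int(row["step"]), float(value)))
    return points


# Example path from the text above; adjust the version to match your run.
metrics_path = Path("logs/lightning_logs/version_0/metrics.csv")
if metrics_path.exists():
    for step, loss in read_metric(metrics_path, "train-loss"):  # assumed column name
        print(step, loss)
```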

## Demo Testing

After testing, the results will also be saved in the logger file.

In [20]:
trainer.test(audio_model_lit, test_dataloader)

Testing: |          | 0/? [00:00<?, ?it/s]

[{'test-loss': 1.7813193798065186,
  'test-cls_loss': 0.701637864112854,
  'test-clip_loss': 2.159363269805908,
  'test-aug_loss': 0.0,
  'test-acc': 0.44999998807907104,
  'test-auc': 0.59375,
  'test-eer': 0.5}]
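
For reference, `test-eer` is an equal error rate: the error at the decision threshold where the false acceptance rate and the false rejection rate are equal. With random inputs and random labels, as in this demo, a value near 0.5 is what chance-level predictions would give. The sketch below is a simple threshold sweep for illustration only. It is not the implementation used by `EER_Callback`, and it assumes that label `1` is the positive class and that higher scores mean "more positive".

```python
def equal_error_rate(labels, scores):
    """Approximate the EER by sweeping every observed score as a threshold."""
    positives = [s for y, s in zip(labels, scores) if y == 1]
    negatives = [s for y, s in zip(labels, scores) if y == 0]
    if not positives or not negatives:
        raise ValueError("need at least one sample of each class")

    best_gap, best_eer = None, None
    # Candidate thresholds: every observed score, plus one above the maximum.
    for threshold in sorted(set(scores)) + [float("inf")]:
        far = sum(s >= threshold for s in negatives) / len(negatives)  # false accepts
        frr = sum(s < threshold for s in positives) / len(positives)   # false rejects
        gap = abs(far - frr)
        if best_gap is None or gap < best_gap:
            best_gap, best_eer = gap, (far + frr) / 2
    return best_eer


# Perfectly separated scores give an EER of 0.
print(equal_error_rate([0, 0, 1, 1], [0.1, 0.2, 0.8, 0.9]))  # 0.0
```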