# WFST with Neural LCS Alignment

This notebook integrates the Neural LCS alignment functionality from the Neural-LCS project with WFST decoding for dysfluency detection.


In [12]:
# Import required libraries
import warnings
import numpy as np
import torch
import torch.nn as nn
import torchaudio
import json
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2PhonemeCTCTokenizer, PreTrainedTokenizerFast, T5EncoderModel, T5Config
from IPython.display import Audio, display
import matplotlib.pyplot as plt
from copy import deepcopy
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import os

warnings.filterwarnings("ignore")


In [13]:
# CMU Phoneme to ID mapping
# Ids are assigned by position in this inventory: the 39 CMU phonemes first
# (ids 0-38), followed by the four special tokens (ids 39-42).
# NOTE(review): the order presumably has to match the vocabulary the alignment
# checkpoint was trained with -- do not reorder without confirming.
phoneme_to_id = {
    symbol: index
    for index, symbol in enumerate(
        (
            "AA AE AH AO AW AY "
            "B CH D DH EH ER "
            "EY F G HH IH IY "
            "JH K L M N NG "
            "OW OY P R S SH "
            "T TH UH UW V W "
            "Y Z ZH "
            "<pad> <unk> <cls> <sep>"
        ).split()
    )
}


In [14]:
# Custom Phoneme Tokenizer
from transformers import PreTrainedTokenizer

class PhonemeTokenizer(PreTrainedTokenizer):
    """Whitespace tokenizer over a fixed phoneme -> id vocabulary.

    Input text is a space-separated phoneme string (e.g. ``"EH V R IY"``).
    Phonemes that are not in the vocabulary map to the ``<unk>`` id.

    Args:
        phoneme_to_id: dict mapping every phoneme and the special tokens
            ``<pad>``, ``<unk>``, ``<cls>``, ``<sep>`` to integer ids.
        **kwargs: forwarded unchanged to ``PreTrainedTokenizer.__init__``.
    """

    def __init__(self, phoneme_to_id, **kwargs):
        # Both lookup tables are assigned before the base constructor runs.
        # NOTE(review): presumably the base constructor may consult the
        # vocabulary while initialising -- confirm against the installed
        # transformers version before reordering.
        self.phoneme_to_id = phoneme_to_id
        self.id_to_phoneme = {v: k for k, v in phoneme_to_id.items()}
        super().__init__(**kwargs)
        self.pad_token = "<pad>"
        self.unk_token = "<unk>"
        self.cls_token = "<cls>"
        self.sep_token = "<sep>"

    def get_vocab(self):
        """Return the phoneme -> id table (the same dict object, not a copy)."""
        return self.phoneme_to_id

    def _convert_token_to_id(self, token):
        """Map one phoneme to its id, falling back to the ``<unk>`` id."""
        return self.phoneme_to_id.get(token, self.phoneme_to_id.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Map one id back to its phoneme, falling back to the ``<unk>`` token."""
        return self.id_to_phoneme.get(index, self.unk_token)

    def _tokenize(self, text):
        """Split ``text`` on whitespace and return the phoneme *ids*.

        Unlike the usual ``_tokenize`` contract this returns ids rather than
        token strings; ``encode`` below relies on that.
        """
        unk_id = self.phoneme_to_id.get(self.unk_token)
        return [self.phoneme_to_id.get(phoneme, unk_id) for phoneme in text.split()]

    def encode(self, text, max_length = 120, add_special_tokens = True, padding = True):
        """Encode a phoneme string into id and attention-mask tensors.

        Args:
            text: space-separated phoneme string.
            max_length: length to pad to when ``padding`` is true.
            add_special_tokens: append the ``<sep>`` id after the phonemes.
            padding: right-pad with the ``<pad>`` id up to ``max_length``.

        Returns:
            ``{"input_ids": tensor, "attention_mask": tensor or None}``.
            The mask is ``None`` when ``padding`` is false.

        Note:
            Inputs longer than ``max_length`` are NOT truncated: the ids are
            returned at their full length with an all-ones mask.
        """
        token_ids = self._tokenize(text)
        if add_special_tokens:
            token_ids = token_ids + [self.phoneme_to_id[self.sep_token]]

        if not padding:
            return {"input_ids": torch.tensor(token_ids), "attention_mask": None}

        real_len = len(token_ids)
        # A negative pad length multiplies to an empty list, so over-long
        # input simply receives no padding.
        pad_len = max_length - real_len
        token_ids = token_ids + [self.phoneme_to_id[self.pad_token]] * pad_len
        mask = [1] * real_len + [0] * pad_len
        return {"input_ids": torch.tensor(token_ids), "attention_mask": torch.tensor(mask)}

    def decode(self, token_ids, skip_special_tokens=True):
        """Turn a sequence of ids back into a space-separated phoneme string.

        Args:
            token_ids: iterable of integer ids, or a ``torch.Tensor`` of ids
                (for example the ``input_ids`` returned by ``encode``).
            skip_special_tokens: drop ``<pad>``, ``<cls>`` and ``<sep>``
                (``<unk>`` is kept).

        Ids that are not in the vocabulary are skipped.
        """
        # Iterating a tensor yields 0-d tensors, which never compare equal to
        # the int keys of ``id_to_phoneme`` in a dict lookup; convert to plain
        # ints first so that encode() -> decode() round-trips.
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.reshape(-1).tolist()
        tokens = [self.id_to_phoneme[token_id] for token_id in token_ids if token_id in self.id_to_phoneme]
        if skip_special_tokens:
            special = (self.pad_token, self.cls_token, self.sep_token)
            tokens = [token for token in tokens if token not in special]
        return " ".join(tokens)

    def __len__(self):
        """Vocabulary size, special tokens included."""
        return len(self.phoneme_to_id)

tokenizer = PhonemeTokenizer(phoneme_to_id)


In [15]:
# Hard LCS Algorithm
def find_lcs_new(seq1, seq2):
    """Align a source sequence to a target sequence along their LCS.

    Elements are compared with plain ``==`` (exact match).

    Args:
        seq1: target (reference) sequence.
        seq2: source sequence.

    Returns:
        List of ``(source_item, target_item)`` tuples in left-to-right order.
        Source items that are not part of the LCS are paired with a
        neighbouring target item; target items skipped during the back-trace
        do not appear in the result. Returns ``[]`` if either input is empty.
    """
    n_tgt, n_src = len(seq1), len(seq2)

    # table[r][s] = LCS length of seq1[:r] and seq2[:s]
    table = [[0] * (n_src + 1) for _ in range(n_tgt + 1)]
    for r in range(1, n_tgt + 1):
        for s in range(1, n_src + 1):
            if seq1[r - 1] == seq2[s - 1]:
                table[r][s] = table[r - 1][s - 1] + 1
            else:
                table[r][s] = max(table[r][s - 1], table[r - 1][s])

    # Walk back from the bottom-right corner, collecting pairs in reverse.
    pairs = []  # tuples of (src, target)
    r, s = n_tgt, n_src
    while r and s:
        here = table[r][s]
        if here == table[r - 1][s]:
            # Target item contributes nothing here: skip it.
            r -= 1
        else:
            pairs.append((seq2[s - 1], seq1[r - 1]))
            if here != table[r][s - 1]:
                # Genuine match: consume the target as well as the source.
                r -= 1
            s -= 1

        if r == 0 and s != 0:
            # Targets exhausted: attach every remaining source item to the
            # first target item.
            pairs.extend((seq2[k], seq1[0]) for k in range(s - 1, -1, -1))
            break
        if s == 0 and r != 0:
            # Sources exhausted: pair the first source item with the current
            # target item once, then stop.
            pairs.append((seq2[0], seq1[r - 1]))
            break

    return pairs[::-1]

def insert_missing_tuples(ref_words, align_result):
    """Re-insert reference items that an alignment left out.

    Args:
        ref_words: the full reference (target) sequence.
        align_result: list of ``(source, target)`` tuples, e.g. the output of
            ``find_lcs_new``.

    Returns:
        A new list of tuples in reference order. Every reference item judged
        missing from ``align_result`` is represented by ``(None, item)``;
        the tuples of ``align_result`` are copied over as their target item
        is reached.
    """
    # Targets present in the alignment, with consecutive repeats collapsed.
    aligned_targets = []
    for pair in align_result:
        if not aligned_targets or aligned_targets[-1] != pair[1]:
            aligned_targets.append(pair[1])

    # Walk the reference once; `matched` counts the reference items already
    # found in `aligned_targets`, i.e. the next position to compare against.
    is_missing = []
    matched = 0
    for word in ref_words:
        found = matched < len(aligned_targets) and aligned_targets[matched] == word
        is_missing.append(not found)
        if found:
            matched += 1

    # Merge: placeholder for missing items, then every aligned tuple whose
    # target equals the current reference item.
    merged = []
    cursor = 0  # next unread position in align_result
    total = len(align_result)
    for word, missing in zip(ref_words, is_missing):
        if missing:
            merged.append((None, word))
        while cursor < total and align_result[cursor][1] == word:
            merged.append(align_result[cursor])
            cursor += 1

    return merged


In [16]:
# Neural LCS Model Architecture
class PhonemeBoundaryAlignerT5(nn.Module):
    """Per-position alignment classifier built on a T5 encoder.

    The reference and source id sequences are encoded by the same T5 encoder,
    concatenated feature-wise, passed through a bidirectional GRU and a 1D
    convolution, and finally mapped by an MLP to 4 logits per position.

    Args:
        pretrained_model_name: name or path given to
            ``T5EncoderModel.from_pretrained``.
        phoneme_vocab_size: size the encoder's token embedding is resized to.
        gru_hidden_dim: GRU hidden size per direction.
        num_filters: number of Conv1d output channels.
        hidden_dim: feature size of the encoder output; the GRU input size is
            ``2 * hidden_dim``.
            NOTE(review): this must agree with the chosen encoder's hidden
            size (assumed 512 for the default ``t5-small``) -- confirm when
            using another checkpoint.
    """

    def __init__(self, pretrained_model_name="t5-small", phoneme_vocab_size=50, gru_hidden_dim = 128, num_filters = 16, hidden_dim=512):
        super().__init__()

        # Shared encoder, with its embedding table resized to the phoneme vocabulary.
        self.encoder = T5EncoderModel.from_pretrained(pretrained_model_name)
        self.encoder.resize_token_embeddings(phoneme_vocab_size)

        # Bidirectional GRU over the concatenated (ref ++ src) encoder features.
        self.gru = nn.GRU(
            input_size=2 * hidden_dim,
            hidden_size=gru_hidden_dim,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )

        # Length-preserving convolution (kernel 3, padding 1) over the GRU output.
        self.conv1d = nn.Conv1d(
            in_channels=2 * gru_hidden_dim,
            out_channels=num_filters,
            kernel_size=3,
            stride=1,
            padding=1,
        )

        # Classification head: 4 logits per position.
        reduced = num_filters // 2
        self.MLP = nn.Sequential(
            nn.ReLU(),
            nn.Linear(num_filters, reduced),
            nn.ReLU(),
            nn.Linear(reduced, 4),
        )

        self._init_weights()

    def forward(self, ref, src, ref_mask=None, src_mask=None):
        """Return per-position logits of shape ``(batch, seq_len, 4)``.

        ``ref`` and ``src`` are id tensors fed to the encoder as ``input_ids``
        with the optional attention masks. Their encoder outputs are
        concatenated on the last dimension, so both must have the same batch
        size and sequence length.
        """
        ref_states = self.encoder(input_ids=ref, attention_mask=ref_mask).last_hidden_state
        src_states = self.encoder(input_ids=src, attention_mask=src_mask).last_hidden_state

        # (batch, seq_len, 2 * hidden_dim) -> GRU -> (batch, seq_len, 2 * gru_hidden_dim)
        fused = torch.cat((ref_states, src_states), dim=-1)
        fused, _ = self.gru(fused)

        # Conv1d expects channels before length; swap in and back out.
        conv_out = self.conv1d(fused.transpose(1, 2)).transpose(1, 2)  # (batch, seq_len, num_filters)

        return self.MLP(conv_out)

    def _init_weights(self):
        """Re-initialise Linear, Conv1d and Embedding weights.

        Linear and Embedding weights are drawn from N(0, 0.02); Linear biases
        are zeroed; Conv1d weights use Kaiming-normal (fan_out, relu).

        NOTE(review): ``self.apply`` visits every submodule, so the Linear and
        Embedding layers inside the pretrained encoder are re-initialised as
        well. Confirm that discarding the pretrained encoder weights is
        intended; loading a checkpoint afterwards overwrites them anyway.
        """
        def _reset(module):
            if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if isinstance(module, torch.nn.Linear) and module.bias is not None:
                    torch.nn.init.zeros_(module.bias)
            elif isinstance(module, torch.nn.Conv1d):
                torch.nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')

        self.apply(_reset)


In [17]:
# Alignment Correction Function
def correct(align_result, ref, src):
    """Re-align the tail of an alignment with the hard LCS when it looks wrong.

    Args:
        align_result: list of ``(source, target)`` tuples; either element may
            be ``None``.
        ref: reference (target) sequence the alignment was built from.
        src: source sequence the alignment was built from.

    Returns:
        ``align_result`` itself when no correction applies, otherwise a new
        list made of the untouched head of ``align_result`` followed by a
        hard-LCS re-alignment of the tail.

    No correction is applied when any input is empty, when the last tuple
    pairs the last source item with the last reference item, when the last
    tuple has equal source and target, or when fewer than two exact-match
    tuples are found while scanning backwards from the end.
    """
    # Empty inputs leave nothing to inspect; indexing [-1] below would raise.
    if len(align_result) == 0 or len(ref) == 0 or len(src) == 0:
        return align_result

    if align_result[-1][1] == ref[-1] and align_result[-1][0] == src[-1]: return align_result
    elif align_result[-1][1] == align_result[-1][0]: return align_result
    else:
        # Scan backwards for the second tuple whose source equals its target,
        # then step left past every tuple that shares that tuple's target.
        # The position reached (split_sign) is the last tuple that is kept.
        mark = 0
        split_bd = None
        split_sign = None
        idx = len(align_result) - 1
        while idx > 0:
            if align_result[idx][0] == align_result[idx][1]:
                mark += 1
            if mark == 2:
                j = idx - 1
                while j >= 0 and align_result[j][1] == align_result[idx][1]:
                    j -= 1
                # NOTE(review): when the run reaches the start of the list,
                # j is -1 here and this reads the LAST tuple; every tuple is
                # then re-aligned below. Confirm this is intended.
                split_bd = align_result[j][1]
                split_sign = j
                break
            idx -= 1
        if mark != 2 or split_bd is None: return align_result

        align_right = deepcopy(align_result)

        # Drop every tuple to the right of split_sign from the kept head and
        # collect its non-missing source items, in original order.
        idx = len(align_result) - 1
        new_src = []
        new_ref = []
        while idx >= 0 and idx != split_sign:
            align_right.pop()
            if align_result[idx][0] is not None:
                new_src.insert(0, align_result[idx][0])
            idx -= 1

        # Collect the target items of the same tail, one per run of equal
        # consecutive targets, skipping tuples whose target is None.
        idx = len(align_result) - 1
        while idx > 0 and idx != split_sign:
            if align_result[idx][1] is None:
                idx -= 1
                continue
            jdx = idx
            while jdx >= 0 and align_result[jdx][1] == align_result[idx][1]:
                jdx -= 1
            jdx += 1
            idx = jdx
            new_ref.insert(0, align_result[jdx][1])
            idx -= 1

        # Hard-LCS re-alignment of the tail, with missing targets re-inserted.
        new_result = align_right + insert_missing_tuples(new_ref, find_lcs_new(new_ref, new_src))
        return new_result


In [18]:
# Neural LCS Function
def neuralLCS(seq1, seq2, model):
    """Align a source phoneme sequence to a target one using the neural model.

    Args:
        seq1: target (reference) phoneme list.
        seq2: source (decoded) phoneme list.
        model: alignment model called as ``model(ref, src, ref_mask, src_mask)``
            and returning per-position logits over 4 labels.

    Returns:
        List of ``(source, target)`` tuples after post-processing by
        ``correct``. ``(None, target)`` marks a target with no source and
        ``(source, None)`` a source left over after all targets.

    Labels are interpreted as follows by this function:
        0 -- the next source item belongs to the current target and more
             labels follow for the same target;
        1 -- the next source item belongs to the current target, after which
             the next target starts;
        2 -- the current target has no source item;
        3 -- end of the alignment; all remaining targets are missing.

    NOTE(review): inputs are padded to 120 ids and the tokenizer does not
    truncate, so sequences longer than 119 phonemes presumably fail inside the
    model -- confirm before using on long utterances.
    """
    ## seq1 is target
    ## seq2 is source

    # Degenerate inputs need no model call: everything on the other side is
    # unmatched.
    if len(seq1) == 0:
        return [(item, None) for item in seq2]
    if len(seq2) == 0:
        return [(None, item) for item in seq1]

    tokenizer = PhonemeTokenizer(phoneme_to_id)
    ref_enc = tokenizer.encode(" ".join(seq1), max_length = 120)
    src_enc = tokenizer.encode(" ".join(seq2), max_length = 120)

    model.eval()

    # Run on whatever device the model lives on rather than depending on a
    # notebook-level `device` variable.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    ref = ref_enc["input_ids"].unsqueeze(0).to(model_device)
    src = src_enc["input_ids"].unsqueeze(0).to(model_device)
    ref_mask = ref_enc["attention_mask"].unsqueeze(0).to(model_device)
    src_mask = src_enc["attention_mask"].unsqueeze(0).to(model_device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output = model(ref, src, ref_mask, src_mask).squeeze(0)
    predicted_classes = torch.argmax(output, dim=1).tolist()

    # Keep labels up to and including the first end label. If the model never
    # predicts one, append it so the walk below always terminates.
    if 3 in predicted_classes:
        prediction = predicted_classes[: predicted_classes.index(3) + 1]
    else:
        prediction = predicted_classes + [3]

    aln_mark = 0   # position in `prediction`
    src_mark = 0   # position in seq2
    align_result = []
    left = None    # indices of targets remaining after an end label
    for jdx, target in enumerate(seq1):
        label = prediction[aln_mark]
        if label == 3:
            left = list(range(jdx, len(seq1)))
            break
        if label == 2:
            align_result.append((None, target))
            aln_mark += 1
            continue

        if src_mark >= len(seq2):
            # The label asks for a source item but none is left. Record the
            # target as missing instead of silently dropping it.
            align_result.append((None, target))
            continue

        if label == 0:
            # Consume the run of 0 labels, one source item per label.
            while prediction[aln_mark] == 0 and src_mark < len(seq2):
                align_result.append((seq2[src_mark], target))
                aln_mark += 1
                src_mark += 1

        if prediction[aln_mark] == 1 and src_mark < len(seq2):
            align_result.append((seq2[src_mark], target))
            aln_mark += 1
            src_mark += 1

    if left is not None:
        for item in left: align_result.append((None, seq1[item]))

    # Source items never consumed are appended unmatched, each one once.
    for item in range(src_mark, len(seq2)):
        align_result.append((seq2[item], None))

    return correct(align_result, seq1, seq2)


# Load Neural LCS Model


In [19]:
# Load the pre-trained Neural LCS model
# Checkpoint path is relative to the notebook's working directory.
model_path = "Neural-LCS/phn_lcs/model/phn_align_1.pth"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the architecture sized to the phoneme vocabulary, restore the trained
# weights directly onto the chosen device, then move the module there.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
model = PhonemeBoundaryAlignerT5(phoneme_vocab_size=len(tokenizer))
model.load_state_dict(torch.load(model_path, map_location=device))
model = model.to(device)

print(f"Model loaded on device: {device}")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")


Model loaded on device: cuda
Model parameters: 19,802,044


# Example Usage: Neural LCS Alignment


In [20]:
# Example: Compare Hard LCS vs Neural LCS
# seq1 is the reference (target) and seq2 the decoded (source) phoneme
# sequence; they differ around "W AH" / "AE" and by an extra "D".
seq1 = 'EH V R IY W AH N IH Z T UW AH P S EH T UW K AA M EH N T'.split()
seq2 = 'EH V R IY AE N IH Z T UW AH P S EH D T UW K AA M EH N T'.split()

print("Reference sequence:", seq1)
print("Source sequence:", seq2)
print()

# Baseline: exact-match LCS alignment.
hard_align = find_lcs_new(seq1, seq2)
print("Hard LCS alignment:")
for i, (src, ref) in enumerate(hard_align):
    # str(None) is "None", so unmatched entries print as that placeholder.
    src_str, ref_str = str(src), str(ref)
    print(f"{i:2d}: {src_str:3s} -> {ref_str:3s}")

print()

# Model-based alignment of the same pair.
neural_align = neuralLCS(seq1, seq2, model)
print("Neural LCS alignment:")
for i, (src, ref) in enumerate(neural_align):
    src_str, ref_str = str(src), str(ref)
    print(f"{i:2d}: {src_str:3s} -> {ref_str:3s}")


Reference sequence: ['EH', 'V', 'R', 'IY', 'W', 'AH', 'N', 'IH', 'Z', 'T', 'UW', 'AH', 'P', 'S', 'EH', 'T', 'UW', 'K', 'AA', 'M', 'EH', 'N', 'T']
Source sequence: ['EH', 'V', 'R', 'IY', 'AE', 'N', 'IH', 'Z', 'T', 'UW', 'AH', 'P', 'S', 'EH', 'D', 'T', 'UW', 'K', 'AA', 'M', 'EH', 'N', 'T']

Hard LCS alignment:
 0: EH  -> EH 
 1: V   -> V  
 2: R   -> R  
 3: IY  -> IY 
 4: AE  -> IY 
 5: N   -> N  
 6: IH  -> IH 
 7: Z   -> Z  
 8: T   -> T  
 9: UW  -> UW 
10: AH  -> AH 
11: P   -> P  
12: S   -> S  
13: EH  -> EH 
14: D   -> EH 
15: T   -> T  
16: UW  -> UW 
17: K   -> K  
18: AA  -> AA 
19: M   -> M  
20: EH  -> EH 
21: N   -> N  
22: T   -> T  

Neural LCS alignment:
 0: EH  -> EH 
 1: V   -> V  
 2: R   -> R  
 3: IY  -> IY 
 4: AE  -> IY 
 5: None -> W  
 6: None -> AH 
 7: N   -> N  
 8: IH  -> IH 
 9: Z   -> Z  
10: T   -> T  
11: UW  -> UW 
12: AH  -> AH 
13: P   -> P  
14: S   -> S  
15: EH  -> EH 
16: D   -> EH 
17: T   -> T  
18: UW  -> UW 
19: K   -> K  
20: AA  -> AA 
21: M

# Integration with WFST Decoding


In [21]:
# Enhanced WFST Decoding with Neural LCS Alignment
def enhanced_wfst_decode_with_neural_lcs(audio_file, ref_text_file, use_neural_lcs=True):
    """
    Enhanced WFST decoding that integrates Neural LCS alignment for better dysfluency detection.

    The WFST decoding step itself is still a placeholder: the Wav2Vec2 logits
    are computed and their shape printed, but the decoding result used for the
    alignment is a hard-coded simulated example.

    Args:
        audio_file: Path to audio file
        ref_text_file: Path to reference text file
        use_neural_lcs: Whether to use Neural LCS alignment (True) or Hard LCS (False)

    Returns:
        Dictionary with keys 'original_result' (the simulated decoding
        result), 'alignment_type', 'alignment_result' (list of
        (source, target) tuples), 'ref_cmu' and 'decode_cmu'.

    Note:
        The Wav2Vec2 processor and model are loaded on every call, and the
        neural path uses the notebook-level ``model``.
    """

    # Load audio and text
    waveform, sample_rate = torchaudio.load(audio_file)
    with open(ref_text_file, 'r') as f:
        ref_text = f.read().strip()

    print(f"Processing: {audio_file}")
    print(f"Reference text: {ref_text}")
    print(f"Sample rate: {sample_rate}")

    # Load Wav2Vec2 model
    processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-xlsr-53-phon-cv-ft")
    wav2vec_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-xlsr-53-phon-cv-ft")

    # Prefer the second GPU when one exists, otherwise fall back to the first
    # GPU or the CPU, so single-GPU machines do not request a missing device.
    if torch.cuda.device_count() > 1:
        device_wav2vec = "cuda:1"
    elif torch.cuda.is_available():
        device_wav2vec = "cuda:0"
    else:
        device_wav2vec = "cpu"
    wav2vec_model = wav2vec_model.to(device_wav2vec)
    wav2vec_model.eval()

    # Resample audio to the 16 kHz passed to the processor as sampling_rate
    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    # Down-mix to mono: for a single channel this equals squeeze(); for
    # multi-channel audio it avoids the channels being treated as a batch.
    if waveform.dim() > 1:
        waveform = waveform.mean(dim=0)
    input_values = processor(waveform.numpy(), return_tensors="pt", sampling_rate=16000).input_values
    input_values = input_values.to(device_wav2vec)

    # Get logits from Wav2Vec2
    with torch.no_grad():
        logits = wav2vec_model(input_values).logits
        print(f'Logits shape: {logits.shape}')

    # WFST Decoding (assuming you have the WFST decoder from main.ipynb)
    # This is a placeholder - you would need to import your WFST decoder
    # from utils.fst import WFSTdecoder
    # from utils.wper import W_PER

    # For now, we'll simulate the WFST decoding result
    # In practice, you would use your actual WFST decoder here
    simulated_result = {
        'id': 'example',
        'ref_phonemes': ['ð', 'e', 'l', 'l', 'ɛ', 'f', 't', 'ɚ', 'l', 'i'],
        'dys_detect': [
            {'phoneme': 'l', 'dysfluency_type': 'repetition'},
            {'phoneme': 'l', 'dysfluency_type': 'normal'},
            {'phoneme': 'ɛ', 'dysfluency_type': 'normal'},
            {'phoneme': 'f', 'dysfluency_type': 'normal'},
            {'phoneme': 't', 'dysfluency_type': 'normal'},
            {'phoneme': 'ɚ', 'dysfluency_type': 'normal'},
            {'phoneme': 'l', 'dysfluency_type': 'normal'},
            {'phoneme': 'i', 'dysfluency_type': 'normal'}
        ],
        'decode_phonemes': ['ð', 'e', 'l', 'l', 'ɛ', 'f', 't', 'ɚ', 'l', 'i'],
        'lattice': []
    }

    # Convert IPA to CMU for alignment. The mapping file is optional; without
    # it every symbol falls back to its IPA form.
    ipa2cmu = {}
    if os.path.exists('config/ipa2cmu.json'):
        # The file holds IPA characters, so read it explicitly as UTF-8.
        with open('config/ipa2cmu.json', 'r', encoding='utf-8') as mapping_file:
            ipa2cmu = json.load(mapping_file)

    def ipa_to_cmu(ipa_list):
        """Map each IPA symbol to the first CMU symbol of its mapping entry."""
        cmu_list = []
        for ipa in ipa_list:
            candidates = ipa2cmu[ipa].split() if ipa in ipa2cmu else []
            # Unmapped symbols (or empty mapping entries) fall back to the original
            cmu_list.append(candidates[0] if candidates else ipa)
        return cmu_list

    ref_cmu = ipa_to_cmu(simulated_result['ref_phonemes'])
    decode_cmu = ipa_to_cmu(simulated_result['decode_phonemes'])

    print(f"Reference CMU: {ref_cmu}")
    print(f"Decoded CMU: {decode_cmu}")

    # Apply Neural LCS or Hard LCS alignment
    if use_neural_lcs:
        print("\nUsing Neural LCS alignment...")
        alignment_result = neuralLCS(ref_cmu, decode_cmu, model)
        alignment_type = "Neural LCS"
    else:
        print("\nUsing Hard LCS alignment...")
        alignment_result = find_lcs_new(ref_cmu, decode_cmu)
        alignment_type = "Hard LCS"

    print(f"\n{alignment_type} alignment result:")
    for i, (src, ref) in enumerate(alignment_result):
        src_str = str(src) if src is not None else "None"
        ref_str = str(ref) if ref is not None else "None"
        print(f"{i:2d}: {src_str:3s} -> {ref_str:3s}")

    # Enhanced result with alignment information
    enhanced_result = {
        'original_result': simulated_result,
        'alignment_type': alignment_type,
        'alignment_result': alignment_result,
        'ref_cmu': ref_cmu,
        'decode_cmu': decode_cmu
    }

    return enhanced_result

# Example usage
# result = enhanced_wfst_decode_with_neural_lcs("data/audio/p088_4067.wav", "data/gt_text/p088_4067.txt", use_neural_lcs=True)


# Summary

This notebook integrates the Neural LCS alignment functionality from the Neural-LCS project with WFST decoding for enhanced dysfluency detection. The key components include:

1. **Hard LCS Algorithm**: Traditional longest common subsequence alignment
2. **Neural LCS Model**: T5-based neural alignment model for improved accuracy
3. **Phoneme Tokenizer**: Custom tokenizer for CMU phoneme sequences
4. **Enhanced WFST Decoding**: Integration function that combines WFST decoding with Neural LCS alignment

The integration provides:
- Better alignment accuracy through neural learning
- Enhanced dysfluency detection capabilities
- Flexible choice between hard and neural alignment methods
- Comprehensive result reporting with alignment information
