### you can just run the code step by step

In [7]:
import warnings
import numpy as np
import torch
import torch.nn as nn
from transformers import PreTrainedTokenizerFast, T5EncoderModel, T5Config
import os
import json
from copy import deepcopy

#### hard lcs

In [1]:
def find_lcs_new(seq1, seq2):
        """Find the longest common subsequence between two sequences using embedding distance."""
        ## seq1 is target
        ## seq2 is source
        
        lengths = [[0] * (len(seq2) + 1) for _ in range(len(seq1) + 1)]
        for i, x in enumerate(seq1):
            for j, y in enumerate(seq2):
                if  x == y:  # Adjust the threshold as needed
                    ### Similar
                    lengths[i+1][j+1] = lengths[i][j] + 1
                else:
                    ### Not Similar
                    lengths[i+1][j+1] = max(lengths[i+1][j], lengths[i][j+1])
        
        # Reconstruct the LCS
        align_lcs_result = [] #tuple (src, target)
        x, y = len(seq1), len(seq2)
        
        
        while x != 0 and y != 0:

            if lengths[x][y] == lengths[x-1][y]:
                x -= 1
            elif lengths[x][y] == lengths[x][y-1]:
                align_lcs_result.append((seq2[y - 1], seq1[x - 1]))
                y -= 1
            else:
                align_lcs_result.append((seq2[y - 1], seq1[x - 1]))
                x -= 1
                y -= 1
                
            if x == 0 and y != 0:
                while y > 0: 
                    align_lcs_result.append((seq2[y - 1], seq1[0]))
                    y -= 1
                break
            if x != 0 and y == 0:
                align_lcs_result.append((seq2[0], seq1[x - 1]))
                break
        
        align_lcs_result.reverse()

        return align_lcs_result

def insert_missing_tuples(ref_words, align_result):
    # Extract the second elements from the align_result tuples
    align_words = []
    for t in align_result:
        if len(align_words) == 0:
            align_words.append(t[1])
        elif align_words[-1] != t[1]:
            align_words.append(t[1])
            
    missing_flg = [1] * len(ref_words)

    for i, word in enumerate(ref_words):
        missing_num = sum(missing_flg[:i])
        j = i - missing_num
        if j <len(align_words) and align_words[j] == word:
            missing_flg[i] = 0

    # Create a copy of align_result to insert missing words
    new_align_result = []
    ref_idx = 0  # Index to track the current position in ref_words

    for idx, word in enumerate(ref_words):
        # Insert missing words
        if missing_flg[idx] == 1:
            new_align_result.append((None, word))
        # Add the next word from align_result if it matches
        while ref_idx < len(align_result) and align_result[ref_idx][1] == word:
            new_align_result.append(align_result[ref_idx])
            ref_idx += 1

    return new_align_result

In [2]:
# CMU pho2dict
phoneme_to_id = {
    "AA": 0, "AE": 1, "AH": 2, "AO": 3, "AW": 4, "AY": 5,
    "B": 6, "CH": 7, "D": 8, "DH": 9, "EH": 10, "ER": 11,
    "EY": 12, "F": 13, "G": 14, "HH": 15, "IH": 16, "IY": 17,
    "JH": 18, "K": 19, "L": 20, "M": 21, "N": 22, "NG": 23,
    "OW": 24, "OY": 25, "P": 26, "R": 27, "S": 28, "SH": 29,
    "T": 30, "TH": 31, "UH": 32, "UW": 33, "V": 34, "W": 35,
    "Y": 36, "Z": 37, "ZH": 38,
    "<pad>": 39, "<unk>": 40, "<cls>": 41, "<sep>": 42
}

### customed tokenizer

In [4]:
from transformers import PreTrainedTokenizer

class PhonemeTokenizer(PreTrainedTokenizer):
    def __init__(self, phoneme_to_id, **kwargs):
        self.phoneme_to_id = phoneme_to_id
        super().__init__(**kwargs)
        self.id_to_phoneme = {v: k for k, v in phoneme_to_id.items()}
        self.pad_token = "<pad>"
        self.unk_token = "<unk>"
        self.cls_token = "<cls>"
        self.sep_token = "<sep>"
        
    def get_vocab(self):
        # return vocabulary table
        return self.phoneme_to_id
        
    def _convert_token_to_id(self, token):
        # Convert a single phoneme to its ID
        return self.phoneme_to_id.get(token, self.phoneme_to_id.get(self.unk_token))

    def _convert_id_to_token(self, index):
        # Convert a single ID back to a phoneme
        return self.id_to_phoneme.get(index, self.unk_token)

    def _tokenize(self, text):
        # Split text into phonemes and map them to IDs
        return [self.phoneme_to_id.get(phoneme, self.phoneme_to_id.get(self.unk_token)) for phoneme in text.split()]

    def encode(self, text, max_length = 120, add_special_tokens = True, padding = True):
        max_len = max_length
        token_ids = self._tokenize(text)
        if add_special_tokens:
            token_ids = token_ids + [self.phoneme_to_id[self.sep_token]]
        if padding:
            prev_len = len(token_ids)
            token_ids = token_ids + [self.phoneme_to_id[self.pad_token]] * (max_len - prev_len)
            mask = [1] * prev_len + [0] * (max_len - prev_len)
        if padding:
            return {"input_ids": torch.tensor(token_ids), "attention_mask": torch.tensor(mask)}
        else:
            return {"input_ids": torch.tensor(token_ids), "attention_mask": None}
                

    def decode(self, token_ids, skip_special_tokens=True):
        tokens = [self.id_to_phoneme[token_id] for token_id in token_ids if token_id in self.id_to_phoneme]
        if skip_special_tokens:
            tokens = [token for token in tokens if token not in [self.pad_token, self.cls_token, self.sep_token]]
        return " ".join(tokens)

    def __len__(self):
        return len(self.phoneme_to_id)

tokenizer = PhonemeTokenizer(phoneme_to_id)

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
import json
from torch.utils.data import Dataset, DataLoader, random_split

### structure of Phoneme soft aligner

In [8]:
class PhonemeBoundaryAlignerT5(nn.Module):
    def __init__(self, pretrained_model_name="t5-small", phoneme_vocab_size=50, hidden_dim=512, num_filters = 16):
        super(PhonemeBoundaryAlignerT5, self).__init__()
        # Load pretrained T5 model
        self.encoder = T5EncoderModel.from_pretrained(pretrained_model_name)
        
        # Resize token embeddings to fit phoneme vocab size
        self.encoder.resize_token_embeddings(phoneme_vocab_size)

        # 1D CNN Layer
        self.conv1d = nn.Conv1d(
            in_channels = hidden_dim * 2,
            out_channels = num_filters,
            kernel_size = 3,
            stride = 1,
            padding = 1
        )
                
        # Boundary Predictor: Fully connected layers
        self.MLP = nn.Sequential(
            nn.ReLU(),
            nn.Linear(num_filters, num_filters//2),
            nn.ReLU(),
            nn.Linear(num_filters//2, 4),
        )
        
        # Re-initialize weights if needed
        self._init_weights()

    def forward(self, ref, src, ref_mask=None, src_mask=None):
        # Encode reference and source phonemes using T5
        ref_output = self.encoder(input_ids=ref, attention_mask=ref_mask).last_hidden_state  # (batch_size, src_len, hidden_dim)
        src_output = self.encoder(input_ids=src, attention_mask=src_mask).last_hidden_state  # (batch_size, src_len, hidden_dim)
        
        # Combine ref and src features
        alignment_features = torch.cat((ref_output, src_output), dim=-1)  # (batch_size, src_len, hidden_dim * 2)
        alignment_features = alignment_features.permute(0, 2, 1)                               # (batch_size, hidden_dim * 2, src_len)
        
        # print(alignment_features.shape)
        
        # Process with Conv layer to capture contextual information
        conv_features = self.conv1d(alignment_features)     # (batch_size, num_filters, src_len) 
        conv_features = conv_features.permute(0, 2, 1)                      # (batch_size, src_len, num_filters)
        
        # Through MLP layer
        mlp_features = self.MLP(conv_features) # (batch_size, src_len, 3)
        
        return mlp_features
    
    def _init_weights(self):
        """
        Reinitialize model parameters for layers other than embeddings.
        """
        def init_weights(m):
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.normal_(m.weight, mean=0.0, std=0.02)
                if m.bias is not None:
                    torch.nn.init.zeros_(m.bias)
            elif isinstance(m, torch.nn.Conv1d):
                torch.nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, torch.nn.Embedding):
                torch.nn.init.normal_(m.weight, mean=0.0, std=0.02)
        
        # Apply custom weight initialization to layers
        self.apply(init_weights)

#### Fine-tune Neural lcs with hard lcs to enhance robustness

In [16]:
def correct(align_result, ref, src):
    if align_result[-1][1] == ref[-1] and align_result[-1][0] == src[-1]: return align_result
    elif align_result[-1][1] == align_result[-1][0]: return align_result
    else:
        mark = 0
        split_bd = None
        split_sign = None
        idx = len(align_result) - 1
        while idx > 0:
            if align_result[idx][0] == align_result[idx][1]: 
                mark += 1
            if mark == 2:
                j = idx - 1
                while j >= 0 and align_result[j][1] == align_result[idx][1]:
                    j -= 1
                split_bd = align_result[j][1]
                split_sign = j
                break
            idx -= 1 
        # print(split_bd)
        if mark != 2 or split_bd == None: return align_result

        align_right = deepcopy(align_result)
        
        idx = len(align_result) - 1
        new_src = []
        new_ref = []
        while idx >= 0 and idx != split_sign:
            align_right.pop()
            if align_result[idx][0] != None:
                new_src.insert(0, align_result[idx][0])
            idx -= 1
        
        idx = len(align_result) - 1
        while idx > 0 and idx != split_sign:
            if align_result[idx][1] == None:
                idx -= 1
                continue
            jdx = idx
            while jdx >= 0 and align_result[jdx][1] == align_result[idx][1]:
                jdx -= 1
            jdx += 1
            idx = jdx
            new_ref.insert(0, align_result[jdx][1])
            idx -= 1

        # print(new_ref, new_src)

        new_result = align_right + insert_missing_tuples(new_ref, find_lcs_new(new_ref, new_src))
        return new_result

In [17]:
model_path = "model/phn_align_2.pth"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = PhonemeBoundaryAlignerT5(phoneme_vocab_size=len(tokenizer))
model.load_state_dict(torch.load(model_path, map_location = device))
model.to(device)

def neuralLCS(seq1, seq2 ,model):
    ## seq1 is target
    ## seq2 is source
    tokenizer = PhonemeTokenizer(phoneme_to_id)
    ref = tokenizer.encode(" ".join(seq1), max_length = 120)["input_ids"]
    src = tokenizer.encode(" ".join(seq2), max_length = 120)["input_ids"]
    ref_mask = tokenizer.encode(" ".join(seq1), max_length = 120)["attention_mask"]
    src_mask = tokenizer.encode(" ".join(seq2), max_length = 120)["attention_mask"]

    model.eval()

    ref = torch.tensor(ref).unsqueeze(0).to(device)
    src = torch.tensor(src).unsqueeze(0).to(device)
    ref_mask = torch.tensor(ref_mask).unsqueeze(0).to(device)
    src_mask = torch.tensor(src_mask).unsqueeze(0).to(device)

    output = model(ref, src, ref_mask, src_mask).squeeze(0)
    predicted_classes = torch.argmax(output, dim=1).tolist()

    prediction = predicted_classes[: predicted_classes.index(3) + 1]
    # prediction = prediction[: -1]
    # print(prediction)

    aln_mark = 0
    src_mark = 0
    align_result = []
    left = None
    for jdx, target in enumerate(seq1):
        if prediction[aln_mark] == 3: 
            left = list(range(jdx, len(seq1)))
            break
        elif prediction[aln_mark] == 2: 
            align_result.append((None, target))
            aln_mark += 1
            continue
        elif prediction[aln_mark] == 0:
            i = aln_mark
            while True:
                if prediction[i] == 0 and src_mark < len(seq2): 
                    align_result.append((seq2[src_mark], target))
                    i += 1
                    src_mark += 1
                else: break
            aln_mark = i
        
        if prediction[aln_mark] == 1:
            if src_mark < len(seq2):
                align_result.append((seq2[src_mark], target))
                aln_mark += 1
                src_mark += 1

    if left == None and jdx < len(seq1) - 1:
        left = list(range(jdx, len(seq1)))

    if left != None:
        for item in left: align_result.append((None, seq1[item]))

    if src_mark - 1 < len(seq2) - 1:
        for item in list(range(src_mark, len(seq2))): align_result.append((seq2[src_mark], None))

    return correct(align_result, seq1, seq2)

seq1 = 'EH V R IY W AH N IH Z T UW AH P S EH T UW K AA M EH N T'.split(" ")
seq2 = 'EH V R IY AE N IH Z T UW AH P S EH D T UW K AA M EH N T'.split(" ")

align_nn_lcs = neuralLCS(seq1, seq2, model)
print(align_nn_lcs)

  model.load_state_dict(torch.load(model_path, map_location = device))


[('EH', 'EH'), ('V', 'V'), ('R', 'R'), ('IY', 'IY'), ('AE', 'IY'), (None, 'W'), (None, 'AH'), ('N', 'N'), ('IH', 'IH'), ('Z', 'Z'), ('T', 'T'), ('UW', 'UW'), ('AH', 'AH'), ('P', 'P'), ('S', 'S'), ('EH', 'EH'), ('D', 'EH'), ('T', 'T'), ('UW', 'UW'), ('K', 'K'), ('AA', 'AA'), ('M', 'M'), ('EH', 'EH'), ('N', 'N'), ('T', 'T')]


  ref = torch.tensor(ref).unsqueeze(0).to(device)
  src = torch.tensor(src).unsqueeze(0).to(device)
  ref_mask = torch.tensor(ref_mask).unsqueeze(0).to(device)
  src_mask = torch.tensor(src_mask).unsqueeze(0).to(device)
