In [None]:
## Forced alignment for char frames

In [1]:
from transformers import Wav2Vec2ForCTC
from transformers import Wav2Vec2FeatureExtractor
from transformers import Wav2Vec2CTCTokenizer
from transformers import Wav2Vec2Processor
from datasets import DatasetDict
from dataclasses import dataclass
import re,json
import torch
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import pickle


In [2]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
SAMPLERATE=16000
#VOCABJSON='IndicTimit_vocab.json'
VOCABJSON='vocab_960h.json'
PKLDIR='/home/ubuntu/manifold/new6/pkldata'
DATASETDIR='/home/ubuntu/manifold/datasets'
MODELDIR='/home/ubuntu/manifold/new6/models'

vocabjson=f'{PKLDIR}/{VOCABJSON}'

with open(vocabjson,'r') as f:
    vocabs=json.load(f)


In [3]:
def get_trellis(emission, tokens, blank_id=0):
    num_frame = emission.size(0)
    num_tokens = len(tokens)
    #print(f'num_tokens:{num_tokens}')
    # Trellis has extra diemsions for both time axis and tokens.
    # The extra dim for tokens represents <SoS> (start-of-sentence)
    # The extra dim for time axis is for simplification of the code.
    trellis = torch.empty((num_frame + 1, num_tokens + 1))
    trellis[0, 0] = 0
    trellis[1:, 0] = torch.cumsum(emission[:, 0], 0)
    trellis[0, -num_tokens:] = -float("inf")
    trellis[-num_tokens:, 0] = float("inf")

    for t in range(num_frame):
        trellis[t + 1, 1:] = torch.maximum(
            # Score for staying at the same token
            trellis[t, 1:] + emission[t, blank_id],
            # Score for changing to the next token
            trellis[t, :-1] + emission[t, tokens],
        )            
        
    return trellis

@dataclass
class Point:
    token_index: int
    time_index: int
    score: float


def backtrack(trellis, emission, tokens, blank_id=0):
    j = trellis.size(1) - 1
    t_start = torch.argmax(trellis[:, j]).item()

    path = []
    for t in range(t_start, 0, -1):
        stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
        changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]
        prob = emission[t - 1, tokens[j - 1] if changed > stayed else 0].exp().item()
        path.append(Point(j - 1, t - 1, prob))

        if changed > stayed:
            j -= 1
            if j == 0:
                break
    else:
        raise ValueError("Failed to align")
    return path[::-1]

    
# Merge the labels
@dataclass
class Segment:
    label: str
    start: int
    end: int
    score: float

    def __repr__(self):
        return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"

    @property
    def length(self):
        return self.end - self.start


def merge_repeats(path,transcript):
    i1, i2 = 0, 0
    segments = []
    while i1 < len(path):
        while i2 < len(path) and path[i1].token_index == path[i2].token_index:
            i2 += 1
        score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
        segments.append(
            Segment(
                transcript[path[i1].token_index],
                path[i1].time_index,
                path[i2 - 1].time_index + 1,
                score,
            )
        )
        i1 = i2
    return segments

# Merge words
def merge_words(segments, separator="|"):
    words = []
    i1, i2 = 0, 0
    while i1 < len(segments):
        if i2 >= len(segments) or segments[i2].label == separator:
            if i1 != i2:
                segs = segments[i1:i2]
                word = "".join([seg.label for seg in segs])
                score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs)
                words.append(Segment(word, segments[i1].start, segments[i2 - 1].end, score))
            i1 = i2 + 1
            i2 = i1
        else:
            i2 += 1
    return words
    


# In[ ]:


def plot_alignments(trellis, segments, word_segments, waveform, sampling_rate=16000):
    trellis_with_path = trellis.clone()
    for i, seg in enumerate(segments):
        if seg.label != "|":
            trellis_with_path[seg.start + 1 : seg.end + 1, i + 1] = float("nan")

    fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9.5))

    ax1.imshow(trellis_with_path[1:, 1:].T, origin="lower")
    ax1.set_xticks([])
    ax1.set_yticks([])

    for word in word_segments:
        ax1.axvline(word.start - 0.5)
        ax1.axvline(word.end - 0.5)

    # The original waveform
    ratio = waveform.size(0) / (trellis.size(0) - 1)
    ax2.plot(waveform)
    for word in word_segments:
        x0 = ratio * word.start
        x1 = ratio * word.end
        #ax2.axvspan(x0, x1, alpha=0.1, edgecolor="blue", ls='--',lw=2)
        
        ax2.axvline(x0,mfc='red',ls='--',lw=2)
        ax2.axvline(x1,mfc='red',ls='--',lw=2)
    xticks = ax2.get_xticks()
    plt.xticks(xticks, xticks / sampling_rate)
    ax2.set_xlabel("time [second]")
    ax2.set_yticks([])
    ax2.set_ylim(-1.0, 1.0)
    ax2.set_xlim(0, waveform.size(-1))
    plt.show()


In [8]:
LANG_SET=['HIN']  # 

for LANG in LANG_SET:
    DATASET=f'dataset_indic_timit_{LANG}'
    # modelname=f'model_w2v2base_indic_timit_{LANG}'   
    # MODELFILE=f'{MODELDIR}/{modelname}'
    # print(f'Preparing pkldata for {LANG} using ref model {MODELFILE}...')
    
    for DATATYPE in ['train','val','test']:
        
        #PKLFILE=f'{PKLDIR}/timit_support_set_{LANG}_w2v2base_{DATATYPE}.pkl'
        PKLFILE=f'{PKLDIR}/timit_support_set_{LANG}_960h_{DATATYPE}_1.pkl'

        xxx_dataset=f'{DATASETDIR}/{DATASET}'
        
        dset=DatasetDict().load_from_disk(xxx_dataset)

        with torch.no_grad():
            #model = Wav2Vec2ForCTC.from_pretrained(MODELFILE).to(device)

            processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h")
            model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h").to(device)

        #tokenizer = Wav2Vec2CTCTokenizer(vocabjson, unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")
        #feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=False)
        #processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

        charfeats={}
        
        checklen={}
        for ky in vocabs.keys():
            if ky in ["<pad>","<s>","</s>","<unk>","'"]: continue
            checklen[ky]=False
        lenupto=10000
        cflag=False
        num_rows=dset[DATATYPE].num_rows
        print(f'--> {DATATYPE} : {num_rows} rows')
        for idx in range(num_rows):
            print(f'\t--> {idx}', end='\r')
            #for ky in vocabs.keys():
            for ky in checklen.keys():
                if checklen[ky]:
                    cflag=False
                    continue
                else:
                    cflag=True
                    break
            if not cflag:
                break
            else:
                waveform=torch.from_numpy(dset[DATATYPE][idx]['audio']['array'].astype(np.float32)).unsqueeze(0)
                text=dset[DATATYPE]['text'][idx]
                text_lst=text.split(' ')
                ntext='|'.join(text_lst)

                with torch.no_grad():
                    emissions = model(waveform.to(device)).logits
                emits=torch.log_softmax(emissions,dim=-1)
                emits=emits[0].cpu().detach()
                labels = processor(text=ntext.upper()).input_ids
                
                trellis=get_trellis(emits,labels)
                path = backtrack(trellis, emits, labels)
                segments = merge_repeats(path,ntext.upper())

                with torch.no_grad():
                    #features
                    fout=model.wav2vec2.feature_extractor(waveform.to(device))
                    #projection
                    fpout=model.wav2vec2.feature_projection(fout.permute(0,2,1))[0]
                    #encoder out
                    eout=model.wav2vec2.encoder(fpout).last_hidden_state

                # prepare char-encfeat dict
                for segm in segments:
                    #print(segm.label,segm.start,segm.end)
                    if not checklen[segm.label]:
                        for fr in range(segm.start,segm.end):
                            try:
                                charfeats[segm.label].append(eout[:,fr,:])
                            except:
                                charfeats[segm.label]=[eout[:,fr,:]]
                        if len(charfeats[segm.label])>lenupto:
                            checklen[segm.label]=True


        ncharfeats={}
        charlens={}
        for kys in charfeats.keys():
            charlens[kys.lower()]=len(charfeats[kys])
            feats=torch.stack(charfeats[kys]).squeeze(1).cpu()
            ncharfeats[kys.lower()]=feats

        with open(PKLFILE,'wb') as f:
            pickle.dump((ncharfeats,charlens),f)

        print(f'{LANG}-{DATATYPE} Completed...\nsaved:{PKLFILE}')
    print('------------')


--> train


Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


HIN-train Completed...
saved:/home/ubuntu/manifold/new6/pkldata/timit_support_set_HIN_960h_train_1.pkl
--> val


Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


HIN-val Completed...
saved:/home/ubuntu/manifold/new6/pkldata/timit_support_set_HIN_960h_val_1.pkl
--> test


Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


HIN-test Completed...
saved:/home/ubuntu/manifold/new6/pkldata/timit_support_set_HIN_960h_test_1.pkl
------------


In [10]:
print(charlens),len(charlens)

{'T': 10001, 'H': 10001, 'E': 10005, 'Y': 10003, '|': 10005, 'P': 9209, 'O': 10001, 'L': 10004, 'I': 10004, 'S': 10003, 'D': 10001, 'W': 8835, 'N': 10001, 'R': 10001, 'C': 10002, 'A': 10006, 'F': 9239, 'B': 6476, 'U': 10003, 'M': 10001, 'V': 3616, 'G': 9956, 'Z': 557, 'K': 3823, 'J': 666, 'X': 2075, 'Q': 345}


(None, 27)

In [7]:
vocabs.keys()

dict_keys(['<pad>', '<s>', '</s>', '<unk>', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z'])

In [None]:
ntext

In [None]:
print(labels)

In [None]:
ntext.upper()

In [None]:
labels1 = processor(text=ntext.upper()).input_ids
print(labels1)

In [None]:
vocabs_960h={"<pad>": 0, "<s>": 1, "</s>": 2, "<unk>": 3, "|": 4, "E": 5, "T": 6, "A": 7, "O": 8, "N": 9, "I": 10, "H": 11, "S": 12, "R": 13, "D": 14, "L": 15, "U": 16, "M": 17, "W": 18, "C": 19, "F": 20, "G": 21, "Y": 22, "P": 23, "B": 24, "V": 25, "K": 26, "'": 27, "X": 28, "J": 29, "Q": 30, "Z": 31}
