<a href="https://colab.research.google.com/github/alanjding/old-chinese-g2p/blob/main/components.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install suparkanbun


In [None]:
import json
import numpy as np
import pandas as pd
import os

from suparkanbun.simplify import simplify

os.environ['CUDA_LAUNCH_BLOCKING'] = '1'  # report CUDA errors synchronously (debugging aid; slows GPU execution)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split

from tqdm.auto import tqdm

from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModel, get_scheduler
from torch.optim import AdamW  # transformers.AdamW is deprecated in favour of PyTorch's implementation

In [None]:
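# `simplify` maps traditional -> simplified characters. Several traditional characters can share
# one simplified form, so this reverse mapping is lossy: it keeps only one traditional character
# (the last one seen) for each simplified character.
traditionalize = {v: k for k, v in simplify.items()}

def simplify_with_default(c):
    # characters with no simplified form are returned unchanged
    return simplify.get(c, c)

def trad_to_simp(sequence):
    return ''.join(simplify_with_default(c) for c in sequence)

def traditionalize_with_default(c):
    # characters with no traditional form are returned unchanged
    return traditionalize.get(c, c)

def simp_to_trad(sequence):
    return ''.join(traditionalize_with_default(c) for c in sequence)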

In [None]:
def one_hot(i, dim=5):
    if i < 0 or i >= dim:
        return np.zeros(dim)
    else:
        return np.eye(dim)[i]
    
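def encode(sequence):
    # text is converted to simplified characters before tokenizing with GuwenBERT;
    # max_length=66 leaves room for 64 characters plus the start and end tokens.
    # Relies on the global `tokenizer`, which is loaded in the next section.
    sequence = trad_to_simp(sequence)
    return tokenizer.encode(sequence, return_tensors='pt', max_length=66, padding='max_length', truncation=True)[0]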

# Masked LM for unseen characters

A word with the same meaning and pronunciation is often written with different characters (for example, variant forms or phonetic loans). Although this is far from a complete solution to assigning every character a pronunciation, we can exploit it to make an imperfect guess for any character that has no reading in the Baxter-Sagart reconstruction. For each such character, we replace it with a mask token and let GuwenBERT estimate, for each character in its vocabulary, the probability that it belongs in the masked position. We then take the highest-probability character that does appear in Baxter-Sagart and, if that character has several readings, perform polyphone disambiguation. We demonstrate this functionality below.
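Once the model has ranked its predictions, the replacement rule itself is simple. The sketch below isolates it from GuwenBERT: `pick_replacement` and `toy_table` are hypothetical, the ranked list stands in for the model's decoded top-k predictions, and the readings are placeholders rather than real Baxter-Sagart data.

```python
def pick_replacement(ranked_candidates, table, to_traditional=lambda c: c):
    """Return the first candidate, in ranked order, that has an entry in `table`.

    ranked_candidates: characters sorted from most to least probable, e.g. the
        decoded top-k predictions of a masked language model.
    table: mapping from traditional characters to their readings.
    to_traditional: converts a (simplified) prediction back to traditional form
        before the lookup; the identity by default.
    Returns None if no candidate is in the table.
    """
    for candidate in ranked_candidates:
        trad = to_traditional(candidate)
        if trad in table:
            return trad
    return None


# Toy example with placeholder readings: the top prediction has no entry, the second does.
toy_table = {'法': ['reading-1'], '也': ['reading-2']}
print(pick_replacement(['统', '法', '也'], toy_table))  # prints 法
```

In the notebook's `masked_lm`, the equivalent loop runs over the 100 most probable tokens and converts each back to traditional form with `simp_to_trad` before the lookup.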

In [None]:
tokenizer = AutoTokenizer.from_pretrained("ethanyt/guwenbert-base")
masked_lm_model = AutoModelForMaskedLM.from_pretrained("ethanyt/guwenbert-base")


In [None]:
baxter_path = './sound_table_v2.json'
sample_text_path = './KR1a0001_001.txt'

with open(baxter_path, 'r', encoding='utf-8') as f:
    table = json.load(f)

with open(sample_text_path, 'r', encoding='utf-8') as f:
    sequence = ''.join(f).replace('\n', '')

print(sequence)

䷀乾下乾上乾元亨利貞初九潛龍勿用九二見龍在田利見大人九三君子終日乾乾夕惕若厲无咎九四或躍在淵无咎九五飛龍在天利見大人上九亢龍有悔用九見群龍无首吉彖曰大哉乾元萬物資始乃統天雲行雨施品物流形大明終始六位時成時乘六龍以御天乾道變化各正性命保合大和乃利貞首出庶物萬國咸寧象曰天行健君子以自強不息潛龍勿用陽在下也見龍在田德施普也終日乾乾反復道也或躍在淵進无咎也飛龍在天大人造也亢龍有悔盈不可久也用九天德不可為首也文言曰元者善之長也亨者嘉之會也利者義之和也貞者事之幹也君子體仁足以長人嘉會足以合禮利物足以和義貞固足以幹事君子行此四德者故曰乾元亨利貞初九曰潛龍勿用何謂也子曰龍德而隱者也不易乎世不成乎名遯世无悶不見是而无悶樂則行之憂則違之確乎其不可拔潛龍也九二曰見龍在田利見大人何謂也子曰龍德而正中者也庸言之信庸行之謹閑邪存其誠善世而不伐德博而化易曰見龍在田利見大人君德也九三曰君子終日乾乾夕惕若厲无咎何謂也子曰君子進德脩業忠信所以進德也脩辭立其誠所以居業也知至至之可與幾也知終終之可與存義也是故居上位而不驕在下位而不憂故乾乾因其時而惕雖危无咎矣九四曰或躍在淵无咎何謂也子曰上下无常非為邪也進退无恆非離群也君子進德脩業欲及時也故无咎九五曰飛龍在天利見大人何謂也子曰同聲相應同氣相求水流濕火就燥雲從龍風從虎聖人作而萬物覩本乎天者親上本乎地者親下則各從其類也上九曰亢龍有悔何謂也子曰貴而无位高而无民賢人在下位而无輔是以動而有悔也潛龍勿用下也見龍在田時舍也終日乾乾行事也或躍在淵自試也飛龍在天上治也亢龍有悔窮之災也乾元用九天下治也潛龍勿用陽氣潛藏見龍在田天下文明終日乾乾與時偕行或躍在淵乾道乃革飛龍在天乃位乎天德亢龍有悔與時偕極乾元用九乃見天則乾元亨者始而亨者也利貞者性情也乾始能以美利利天下不言所利大矣哉大哉乾乎剛健中正純粹精也六爻發揮旁通情也時乘六龍以御天也雲行雨施天下平也君子以成德為行日可見之行也潛之為言也隱而未見行而未成是以君子弗用也君子學以聚之問以辯之寬以居之仁以行之易曰見龍在田利見大人君德也九三重剛而不中上不在天下不在田故乾乾因其時而惕雖危无咎矣九四重剛而不中上不在天下不在田中不在人故或之或之者疑之也故无咎夫大人者與天地合其德與日月合其明與四時合其序與鬼神合其吉凶先天而天弗違後天而奉天時天且弗違而況於人乎況於鬼神乎亢之為言也知進而不知退知存而不知亡知得而不知喪其唯聖人乎知進退存亡而不失其正者其唯聖人

In [None]:
def masked_lm(seq, table, debug_print=False, window_radius=32, top_k=100):
    # indices where the character has no Baxter-Sagart reading
    no_baxter_idx = np.where([c not in table for c in seq])[0]
    if debug_print:
        print('Indices with no Baxter reading:', no_baxter_idx)

    simp_sequence = trad_to_simp(seq)

    # indices where the character is not in the GuwenBERT lexicon, i.e. it encodes to one of the
    # RoBERTa-style special tokens with id <= 3 (<s>, <pad>, </s>, <unk>); these are likely
    # special characters with no pronunciation, especially if the index is also in no_baxter_idx
    if debug_print:
        not_bert_lex_idx = np.where(np.array(tokenizer.encode(simp_sequence)[1:-1]) <= 3)[0]
        print('Indices containing characters outside of the GuwenBERT lexicon:', not_bert_lex_idx)
        print()

    mask_token = tokenizer.mask_token
    for mask_idx in no_baxter_idx:
        character = seq[mask_idx]
        if debug_print:
            print('Index:', mask_idx, '\nCharacter:', character)
            print()

        # skip characters GuwenBERT cannot represent (they encode to a special token)
        if tokenizer.encode(trad_to_simp(character))[1] <= 3:
            if debug_print:
                print(character, 'is outside of the GuwenBERT lexicon, moving on to next character\n')
            continue

        # replace the character with the mask token and keep a window of context around it;
        # the right bound is extended by len(mask_token) because the mask token string is
        # longer than one character, so both sides get window_radius characters of context
        masked_sequence = simp_sequence[:mask_idx] + mask_token + simp_sequence[mask_idx + 1:]
        left = max(0, mask_idx - window_radius)
        right = min(len(masked_sequence), mask_idx + len(mask_token) + window_radius)
        subset = masked_sequence[left:right]

        inputs = tokenizer.encode(subset, return_tensors='pt')
        mask_token_index = torch.where(inputs == tokenizer.mask_token_id)[1]
        with torch.no_grad():  # inference only, so no gradients are needed
            mask_token_logits = masked_lm_model(inputs).logits[0, mask_token_index, :]

        # the top_k most probable replacement tokens, best first
        top_tokens = torch.topk(mask_token_logits, top_k, dim=1).indices[0].tolist()

        if debug_print:
            for token in top_tokens[:5]:
                pred = tokenizer.decode([token])
                print(subset.replace(mask_token, pred))
                trad_pred = simp_to_trad(pred)
                print('Predicted character (traditional):', trad_pred)
                print('Has Baxter reading?', trad_pred in table)
                print()

        # substitute the most probable prediction that has a Baxter-Sagart reading
        for token in top_tokens:
            trad_pred = simp_to_trad(tokenizer.decode([token]))
            if trad_pred in table:
                seq = seq[:mask_idx] + trad_pred + seq[mask_idx + 1:]
                break

    return seq

masked_lm(sequence, table, debug_print=False)




'䷀乾下乾上乾元亨利貞初九潛龍勿用九二見龍在田利見大人九三君子終日乾乾夕惕若厲无咎九四或躍在淵无咎九五飛龍在天利見大人上九亢龍有悔用九見群龍无首吉彖曰大哉乾元萬物資始乃法天雲行雨施品物流形大明終始六位時成時乘六龍以御天乾道變化各正性命保合大和乃利貞首出庶物萬國咸寧象曰天行也君子以自強不息潛龍勿用陽在下也見龍在田德施普也終日乾乾反復道也或躍在淵進无咎也飛龍在天大人造也亢龍有悔盈不可久也用九天德不可為首也文言曰元者善之長也亨者嘉之會也利者義之和也貞者事之幹也君子體仁足以長人嘉會足以合禮利物足以和義貞固足以幹事君子行此四德者故曰乾元亨利貞初九曰潛龍勿用何謂也子曰龍德而隱者也不易乎世不成乎名遯世无悶不見是而无悶樂則行之憂則違之確乎其不可拔潛龍也九二曰見龍在田利見大人何謂也子曰龍德而正中者也庸言之信庸行之謹閑邪存其誠善世而不伐德博而化易曰見龍在田利見大人君德也九三曰君子終日乾乾夕惕若厲无咎何謂也子曰君子進德脩業忠信所以進德也脩辭立其誠所以居業也知至至之可與幾也知終終之可與存義也是故居上位而不驕在下位而不憂故乾乾因其時而惕雖危无咎矣九四曰或躍在淵无咎何謂也子曰上下无常非為邪也進退无恆非離群也君子進德脩業欲及時也故无咎九五曰飛龍在天利見大人何謂也子曰同聲相應同氣相求水流濕火就燥雲從龍風從虎聖人作而萬物覩本乎天者親上本乎地者親下則各從其類也上九曰亢龍有悔何謂也子曰貴而无位高而无民賢人在下位而无輔是以動而有悔也潛龍勿用下也見龍在田時舍也終日乾乾行事也或躍在淵自試也飛龍在天上治也亢龍有悔窮之災也乾元用九天下治也潛龍勿用陽氣潛藏見龍在田天下文明終日乾乾與時偕行或躍在淵乾道乃革飛龍在天乃位乎天德亢龍有悔與時偕極乾元用九乃見天則乾元亨者始而亨者也利貞者性情也乾始能以美利利天下不言所利大矣哉大哉乾乎剛柔中正純粹精也六爻發揮旁通情也時乘六龍以御天也雲行雨施天下平也君子以成德為行日可見之行也潛之為言也隱而未見行而未成是以君子弗用也君子學以聚之問以辯之寬以居之仁以行之易曰見龍在田利見大人君德也九三重剛而不中上不在天下不在田故乾乾因其時而惕雖危无咎矣九四重剛而不中上不在天下不在田中不在人故或之或之者疑之也故无咎夫大人者與天地合其德與日月合其明與四時合其序與鬼神合其吉凶先天而天弗違後天而奉天時天且弗違而況於人乎況於鬼神乎亢之為言也知進而不知退知存而不知亡知得而不知喪其唯聖人乎知進退存亡而不失其正者其唯聖

# Polyphone disambiguation

The following code trains and evaluates a model for polyphone disambiguation. This model is an implementation of the LSTM classifier from [Dai et al. (2019)](https://www.isca-speech.org/archive/Interspeech_2019/pdfs/2292.pdf).

With a larger and more reliable set of labelled data, we would of course want to tune the hyperparameters further. As usual, good candidates include the learning rate and its schedule, as well as the number of training epochs. The LSTM hidden size can also be adjusted depending on whether the model appears to be overfitting or underfitting. Finally, it may be interesting to see how the model trains with other loss functions, such as the Modified Focal Loss proposed in [Zhang et al. (2020)](http://www.interspeech2020.org/uploadfile/2020/1021/20201021034849937.pdf), which appears to be better at picking up less common pronunciations.
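As a starting point for experimenting with the loss, the standard (unmodified) multi-class focal loss, FL = -(1 - p_t)^γ log p_t, is easy to write by hand in PyTorch. This is *not* the Modified Focal Loss of Zhang et al. (2020), whose changes are not reproduced here; the class name `FocalLoss` and the default `gamma=2.0` are illustrative choices. Because it accepts the same masked logits (with `-inf` for invalid readings) that `nn.CrossEntropyLoss` receives, it could be dropped in as `loss_function`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """Standard multi-class focal loss, averaged over the batch.

    A sketch for experimentation, not the Modified Focal Loss of Zhang et al. (2020).
    With gamma = 0 it reduces to ordinary cross-entropy.
    """

    def __init__(self, gamma: float = 2.0):
        super().__init__()
        self.gamma = gamma

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # log-probabilities over classes; masked (-inf) logits get probability 0
        log_probs = F.log_softmax(logits, dim=-1)
        # log p_t: the log-probability assigned to each example's correct class
        log_pt = log_probs.gather(-1, target.unsqueeze(-1)).squeeze(-1)
        pt = log_pt.exp()
        # confident, correct predictions (p_t close to 1) are down-weighted
        return (-((1 - pt) ** self.gamma) * log_pt).mean()


# two examples with five reading slots, the last three masked out as in the model's output
logits = torch.tensor([[2.0, 0.5, float('-inf'), float('-inf'), float('-inf')],
                       [0.1, 1.0, float('-inf'), float('-inf'), float('-inf')]])
target = torch.tensor([0, 1])
print(FocalLoss(gamma=2.0)(logits, target))
```

Larger values of `gamma` shift more of the loss onto examples the model gets wrong or is unsure about, which is the intuition behind using focal-style losses for rare pronunciations.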

In [None]:
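df = pd.read_csv('./polyphones-10k-partiallylabelled.csv')

# keep only rows with a reading label (rows whose `index` is -1 are presumably unlabelled);
# .copy() avoids pandas warnings when new columns are added below
df = df[df['index'] != -1].copy()

# reconstruct the position of the polyphone in `context` from the labelled dataset:
# first try the 6th character of narrow_context, then the 6th character from its end;
# returns -1 if neither position holds the polyphone
def get_context_index(context, narrow_context, polyphone):
    narrow_context_start = context.find(narrow_context[:6])
    if narrow_context_start == -1:
        return -1
    for index in (narrow_context_start + 5, narrow_context_start + len(narrow_context) - 6):
        if 0 <= index < len(context) and context[index] == polyphone:
            return index
    return -1

df['location'] = df.apply(lambda row: get_context_index(row.context, row.narrow_context, row.polyphone), axis=1)

# drop rows whose polyphone position could not be recovered
df = df[df['location'] != -1].copy()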

# mask over the (at most 5) reading slots: 1 for each reading the character has, 0 otherwise
def get_mask(num_pronunciations, max_readings=5):
    arr = torch.zeros(max_readings)
    arr[:num_pronunciations] = 1
    return arr

df['mask'] = df.polyphone.apply(lambda p: get_mask(len(table[p])))

In [None]:
df[:5]

| Unnamed: 0.1 | Unnamed: 0 | source_text | polyphone | location | narrow_context | context | index | reading | gloss | mask |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 735015 | 中論 | 風 | 22 | 都託之乎觀風然而好變易 | 無倫而辭察託之乎通理居必人才遊必帝都託之乎觀風然而好變易姓名求之難獲託之乎能靜卑屈其體輯柔其... | 0 | prəm | wind (n.) | [tensor(1.), tensor(1.), tensor(0.), tensor(0.... |
| 1 | 733961 | 中論 | 深 | 18 | 可誣哉故根深而枝葉茂行 | 者莫不見也有耳者莫不聞也其可誣哉故根深而枝葉茂行久而名譽逺易曰恒亨無咎利貞言久於其道也伊尹放... | 1 | ◦ləms | | [tensor(1.), tensor(1.), tensor(0.), tensor(0.... |
| 2 | 734361 | 中論 | 質 | 30 | 藝乎旣脩其質且加其文文 | 菁菁者莪在彼中阿旣見君子樂且有儀美育■材其猶人之於藝乎旣脩其質且加其文文質著然後體全體全然後... | 1 | t-lit | substance, solid part | [tensor(1.), tensor(1.), tensor(0.), tensor(0.... |
| 3 | 734976 | 中論 | 於 | 26 | 狎之斯術之於斯民也猶内 | 定䘮其故性而不自知其迷也咸相與祖述其業而寵狎之斯術之於斯民也猶内關之疾也非有痛癢煩苛於身情志... | 1 | ʔa | at (locative preposition) | [tensor(1.), tensor(1.), tensor(0.), tensor(0.... |
| 4 | 737071 | 中論 | 所 | 26 | 卿六遂之法所以維持其民 | 欲樹藝也雖有良農安所措其疆力乎是以先王制六卿六遂之法所以維持其民而爲之綱目也 | 0 | s-qʰraʔ | place (n.); that which | [tensor(1.), tensor(1.), tensor(0.), tensor(0.... |


In [None]:
class PolyphoneDisambiguationDataset(Dataset):
    def __init__(self, contexts, locations, masks, labels):
        self.contexts = contexts
        self.locations = locations
        self.masks = masks
        self.labels = labels
        
    def __len__(self):
        return len(self.contexts)
    
    def __getitem__(self, i):
        context = self.contexts.iloc[i]
        location = self.locations.iloc[i]
        mask = self.masks.iloc[i]
        label = self.labels.iloc[i]
        sample = {'context': context, 'location': location, 'mask': mask, 'labels': label}
        return sample

In [None]:
dataset = PolyphoneDisambiguationDataset(
    df['context'].apply(encode), 
    df['location'], 
    df['mask'],
    df['index']
)

next(iter(dataset))

{'context': tensor([   0,   18, 1425,   13,  527,  619, 1237,    6,  217,  191,  230,  231,
          174,   11,  453,  431,  174,  167,  221, 1237,    6,  217,  261,  141,
           85,   13,  305,  447,  312,  625,   74,  342,    6,  268,  794, 1237,
            6,  217,   78,  758, 1502, 1277,   15,  393, 2281, 1228,   15, 1007,
         1237,    6,  217, 7242,  888,   85,    2,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1]),
 'labels': 0,
 'location': 22,
 'mask': tensor([1., 1., 0., 0., 0.])}

In [None]:
# random 80/20 split of the labelled rows into training and test sets
train_len = int(len(dataset) * 0.8)
test_len = len(dataset) - train_len

train_dataset, test_dataset = random_split(dataset, [train_len, test_len])

batch_size = 1  # PolyphoneDisambiguationModel.forward currently only supports a batch size of 1
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
next(iter(train_loader))

{'context': tensor([[   0,   11, 1520,   92,   19,  249, 5160,    6,  429,   21,   12, 2803,
          1180,   21,   16,   37, 5160,    6,   47,   78,    8,  378,  621,   13,
           437,  478,  787,  325,  186,    6,   22,  834,  108,   32,  517,  842,
           169,   13,    2,    1,    1,    1,    1,    1,    1,    1,    1,    1,
             1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
             1,    1,    1,    1,    1,    1]]),
 'labels': tensor([2]),
 'location': tensor([18]),
 'mask': tensor([[1., 1., 1., 0., 0.]])}

In [None]:
guwenbert = AutoModel.from_pretrained('ethanyt/guwenbert-base')

Some weights of the model checkpoint at ethanyt/guwenbert-base were not used when initializing RobertaModel: ['lm_head.dense.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.bias', 'lm_head.decoder.weight', 'lm_head.dense.bias', 'lm_head.bias', 'lm_head.layer_norm.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
class PolyphoneDisambiguationModel(nn.Module):
    def __init__(self, num_readings=5):
        super().__init__()

        self.lstm_hidden_size = 128  # tuneable hyperparameter: reduce if the model shows evidence of overfitting

        self.bert = guwenbert

        # 2-layer bidirectional LSTM over GuwenBERT's contextual embeddings
        self.blstm = nn.LSTM(input_size=self.bert.config.hidden_size,  # 768 for guwenbert-base
                             hidden_size=self.lstm_hidden_size,
                             num_layers=2,
                             bidirectional=True,
                             batch_first=True)

        # forward and backward LSTM states are concatenated, hence the factor of 2
        self.fc = nn.Linear(self.lstm_hidden_size * 2, num_readings)

    # mask should be a vector where mask[i] = 1 iff i is a valid Baxter index for the character to predict (and 0 otherwise)
    def forward(self, context, location, mask):
        # exclude padding tokens from self-attention
        attention_mask = (context != self.bert.config.pad_token_id).long()
        embedding = self.bert(context, attention_mask=attention_mask).last_hidden_state

        # select the LSTM output at the polyphone's position (+1 skips the start token);
        # the trailing [0][0] indexing only works for a batch size of 1
        lstm_out = self.blstm(embedding)[0][:, location + 1, :][0][0]

        fc_out = self.fc(lstm_out)
        # invalid readings get -inf logits, i.e. zero probability after softmax;
        # raw logits are returned because nn.CrossEntropyLoss applies log-softmax itself
        output = fc_out.masked_fill(mask == 0, float('-inf'))

        return output

In [None]:
model = PolyphoneDisambiguationModel()

# freeze the GuwenBERT layers (probably better given our small dataset); only the BLSTM and the linear layer are trained
for param in model.bert.parameters():
    param.requires_grad = False

optimizer = AdamW([p for p in model.parameters() if p.requires_grad], lr=5e-4, weight_decay=0.0)  # adjust lr as needed
loss_function = nn.CrossEntropyLoss()  # could be swapped for a focal loss; PyTorch has no built-in Modified Focal Loss

num_epochs = 10
num_training_steps = num_epochs * len(train_loader)
lr_scheduler = get_scheduler(
    'linear',
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps
)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device);

In [None]:
progress_bar = tqdm(range(num_training_steps))

model.train()
for epoch in range(num_epochs):
    for batch in train_loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        output = model(batch['context'], batch['location'], batch['mask'])
        loss = loss_function(output, batch['labels'])
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)


In [None]:
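result = []

# most frequent reading of each polyphone, counted over all labelled rows (training and test),
# keyed by the polyphone's GuwenBERT token id
freqs = {}
for _, row in df.iterrows():
    if row.polyphone not in freqs:
        freqs[row.polyphone] = [0] * 5
    freqs[row.polyphone][row['index']] += 1

most_freq_index = {int(encode(k)[1]): np.argmax(v) for k, v in freqs.items()}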

model.eval()
for batch in test_loader:
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        logits = model(batch['context'], batch['location'], batch['mask'])

        pred = torch.argmax(logits)

        polyphone_token = int(batch['context'][0][batch['location'][0] + 1])

        if polyphone_token in most_freq_index:
            most_freq = most_freq_index[polyphone_token]
        else:
            most_freq = None

        result.append({'prediction': pred, 
                       'reference': batch['labels'][0], 
                       'most_freq': most_freq})

In [None]:
# confusion matrix for the model, treating the most common pronunciation of a character as a "positive" result

tp = 0
tn = 0
fp = 0
fn = 0
for pair in result:
    pred = int(pair['prediction'])
    ref = int(pair['reference'])
    most_freq = pair['most_freq']
    tp += pred == ref and ref == most_freq
    tn += pred == ref and ref != most_freq
    fp += pred != ref and ref != most_freq
    fn += pred != ref and ref == most_freq

print('True positives (predicts the most common pronunciation correctly):', tp)
print('True negatives (predicts a less common pronunciation correctly):', tn)
print('False positives (predicts the most common pronunciation for a character that uses a different pronunciation):', fp)
print('False negatives (predicts a less common pronunciation for a character that uses the most common pronunciation):', fn)
print('Accuracy:', (tp + tn) / len(result))
precision = tp / (tp + fp)
print('Precision:', precision)
recall = tp / (tp + fn)
print('Recall:', recall)
print('F1 score:', 2 / (1/precision + 1/recall))
# print('Diagnostic odds ratio:', (tp * tn) / (fp * fn))

# comparing precision and F1 to a naive model (always predicts the most common pronunciation)
print('\nComparison to baseline:')
baseline_prec = (tp + fn) / len(result)
print('Precision (for a model that always predicts the most common pronunciation):', baseline_prec)
print('F1 score (for a model that always predicts the most common pronunciation):', 2 / (1/baseline_prec + 1))

True positives (predicts the most common pronunciation correctly): 102
True negatives (predicts a less common pronunciation correctly): 0
False positives (predicts the most common pronunciation for a character that uses a different pronunciation): 3
False negatives (predicts a less common pronunciation for a character that uses the most common pronunciation): 4
Accuracy: 0.9357798165137615
Precision: 0.9714285714285714
Recall: 0.9622641509433962
F1 score: 0.9668246445497629

Comparison to baseline:
Precision (for a model that always predicts the most common pronunciation): 0.9724770642201835
F1 score (for a model that always predicts the most common pronunciation): 0.986046511627907
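As a sanity check, the printed metrics follow directly from the four counts. The helpers below are hypothetical (not part of the notebook). Plugging in the counts above (102, 0, 3, 4; 109 test examples) reproduces every printed value. They also make the baseline explicit: a model that always predicts the most common reading has recall 1, and its precision is the share of test examples whose reference is the most common reading, (tp + fn) / n = 106 / 109, which is also its accuracy.

```python
def binary_metrics(tp, tn, fp, fn):
    """Accuracy, precision, recall and F1 from confusion-matrix counts."""
    n = tp + tn + fp + fn
    accuracy = (tp + tn) / n
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)  # harmonic mean of precision and recall
    return accuracy, precision, recall, f1


def most_common_baseline(tp, tn, fp, fn):
    """Precision and F1 of always predicting the most common reading (recall = 1)."""
    n = tp + tn + fp + fn
    precision = (tp + fn) / n  # examples whose reference is the most common reading
    recall = 1.0
    f1 = 2 * precision * recall / (precision + recall)
    return precision, f1


# counts printed above
print(binary_metrics(102, 0, 3, 4))
print(most_common_baseline(102, 0, 3, 4))
```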


# Complete g2p use case

This section demonstrates how to combine the masked LM and the polyphone disambiguation model into a more robust g2p pipeline, one that assigns some pronunciation to every character in the GuwenBERT lexicon.


In [None]:
def get_context(text, poly_index, num_left=32, num_right=32):
    # window of up to num_left characters before and num_right after poly_index,
    # plus the polyphone's position within that window
    left = max(0, poly_index - num_left)
    right = min(len(text), poly_index + num_right + 1)
    return text[left:right], poly_index - left

# replace characters without a Baxter-Sagart reading with the masked LM's best guess
all_baxter = masked_lm(sequence, table)

model.eval()
phoneme_seq = []
for i, c in enumerate(tqdm(all_baxter)):
    if c not in table:
        phoneme_seq.append(None)  # default behavior for characters outside the GuwenBERT lexicon
    elif len(table[c]) == 1:
        phoneme_seq.append(table[c][0])  # a single reading needs no disambiguation
    else:
        # polyphone: let the disambiguation model choose among the character's readings
        context, location = get_context(all_baxter, i)
        mask = get_mask(len(table[c]))[None, :].to(device)
        context = encode(context)[None, :].to(device)
        with torch.no_grad():
            logits = model(context, torch.tensor([location], device=device), mask)
        pred = torch.argmax(logits).item()
        phoneme_seq.append(table[c][pred])





In [None]:
for p in phoneme_seq[:10]:
    print(p)

None
['', '', '', 'k', 'ˤ', '', 'a', 'r', '', '', '', 'B&SOC2015-10-13', '']
['', '', '', 'g', 'ˤ', 'r', 'a', '', 'ʔ', '', '', 'B&SOC2015-10-13', '']
['', '', '', 'k', 'ˤ', '', 'a', 'r', '', '', '', 'B&SOC2015-10-13', '']
['m-', '', '', 'd', '', '', 'a', 'ŋ', 'ʔ', '', '', 'B&SOC2015-10-13', '']
['', '', '', 'g', '', 'r', 'a', 'r', '', '', '', 'B&SOC2015-10-13', '']
['', '', '', 'ŋ', '', '', 'o', 'r', '', '', '', 'B&SOC2015-10-13', '']
['', '', '', 'qʰ', 'ˤ', 'r', 'a', 'ŋ', '', '', '', 'B&SOC2015-10-13', '']
['C-', '', '', 'r', '', '', 'i', 't', '', 's', '', 'B&SOC2015-10-13', '']
['', '', '', 't', '', 'r', 'e', 'ŋ', '', '', '', 'B&SOC2015-10-13', '']
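Each reading is printed as a list of 13 fields. Judging only from the rows above (the field layout is an assumption, not documented here), concatenating the fields while skipping index 11, which holds the `B&SOC2015-10-13` source tag, gives a compact form such as `kˤar` or `m-daŋʔ`. The helper `format_reading` and the constant `SOURCE_FIELD` are hypothetical, and the result is a plain concatenation of the fields, which may not match the notation Baxter and Sagart use in print.

```python
SOURCE_FIELD = 11  # assumption: position of the 'B&SOC2015-10-13' source tag in each reading


def format_reading(fields):
    """Concatenate one reading's phonological fields into a single string.

    fields: one reading as printed above (a list of strings), or None for a
    character with no reading, which is returned unchanged.
    """
    if fields is None:
        return None
    # empty fields contribute nothing; the source tag is skipped
    return ''.join(f for i, f in enumerate(fields) if i != SOURCE_FIELD)


example = ['m-', '', '', 'd', '', '', 'a', 'ŋ', 'ʔ', '', '', 'B&SOC2015-10-13', '']
print(format_reading(example))  # prints m-daŋʔ
```

Applied to the whole output, `[format_reading(p) for p in phoneme_seq]` would give one string (or `None`) per character.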


One thing that will likely matter down the road, but that I did not have time to figure out, is getting batching to work; it would speed up both training and inference significantly. Given the time allotted, I prioritized implementing the logic and documenting it here.
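The comments in `PolyphoneDisambiguationModel.forward` point to the position lookup (`[:, location + 1, :][0][0]`) as the sticking point. A minimal sketch of a batched alternative, assuming `lstm_out` has shape `(batch, seq_len, features)` and `location` has shape `(batch,)`: advanced indexing with a batch index picks each sequence's own polyphone position in one step. The function name `batched_polyphone_logits` is hypothetical and not part of the notebook.

```python
import torch
import torch.nn as nn


def batched_polyphone_logits(lstm_out, location, mask, fc):
    """Masked reading logits for a whole batch.

    lstm_out: (batch, seq_len, features) output of the BLSTM
    location: (batch,) polyphone index in each context, not counting the start token
    mask:     (batch, num_readings), 1 for valid readings and 0 otherwise
    fc:       linear layer mapping features -> num_readings
    """
    batch_idx = torch.arange(lstm_out.size(0), device=lstm_out.device)
    # lstm_out[b, location[b] + 1] for every b; +1 skips the start token
    selected = lstm_out[batch_idx, location.to(lstm_out.device) + 1]  # (batch, features)
    logits = fc(selected)  # (batch, num_readings)
    return logits.masked_fill(mask == 0, float('-inf'))


# toy check: 3 contexts of 66 tokens with 256 features (2 x 128 LSTM hidden units)
lstm_out = torch.randn(3, 66, 256)
location = torch.tensor([22, 18, 0])
mask = torch.tensor([[1., 1., 0., 0., 0.],
                     [1., 1., 1., 0., 0.],
                     [1., 1., 0., 0., 0.]])
fc = nn.Linear(256, 5)
logits = batched_polyphone_logits(lstm_out, location, mask, fc)
loss = nn.CrossEntropyLoss()(logits, torch.tensor([0, 2, 1]))
print(logits.shape, loss.item())
```

With `forward` written this way, the `DataLoader` could use `batch_size > 1`: as the `train_loader` output above shows, the default collate function already stacks the fixed-length (66-token) contexts, masks, locations, and labels along a leading batch dimension.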