In [1]:
import os
import sys
from pathlib import Path

# Make the project's `src` directory (two levels above the working directory) importable.
src_path = Path(os.getcwd()).parent.parent.joinpath('src').absolute()
# Guard the append so that re-running this cell does not stack duplicate sys.path entries.
if str(src_path) not in sys.path:
    sys.path.append(str(src_path))

In [2]:
from huggingsound import TrainingArguments, ModelArguments, SpeechRecognitionModel, TokenSet
from typing import NamedTuple, TypedDict
import numpy as np
from persistence.model import FileMetadata
from pathlib import Path
from persistence.db import Database
from persistence.file_metadata_repository import FileMetadataRepository
from features.transcriber import Transcriber

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
ROOT_DIR = Path('/home/flok3n/konrad'); DB_DIR = ROOT_DIR
db = Database(DB_DIR)
await db.init_db()

repo = FileMetadataRepository(db)

files = await repo.load_all_files()

In [None]:
# Keep only the records flagged as both transcript-analyzed and transcript-fixed.
files_with_fixed_transcript = [
    metadata
    for metadata in files
    if metadata.is_transcript_analyzed and metadata.is_transcript_fixed
]
len(files_with_fixed_transcript)

In [6]:
async def preprocess_data(files: list[FileMetadata]):
    """Write a preprocessed copy of every given file into PREPROCESSED_DATA_DIR.

    For each record the source file `ROOT_DIR/<name>` is run through the
    transcriber's audio preprocessing and the resulting in-memory buffer is
    written to `PREPROCESSED_DATA_DIR/<name>.wav`.
    """
    transcriber = Transcriber(None)
    for metadata in files:
        source = ROOT_DIR.joinpath(metadata.name)
        destination = PREPROCESSED_DATA_DIR.joinpath(metadata.name + '.wav')
        audio_buffer = await transcriber._get_preprocessed_audio_file(source)
        with open(destination, 'wb') as out_file:
            audio_buffer.seek(0)
            out_file.write(audio_buffer.getvalue())

In [7]:
# await preprocess_data(files_with_fixed_transcript)

In [8]:
# Seed NumPy's global RNG; get_train_eval_dataset draws its split and shuffles from it.
np.random.seed(1234)

class TrainItem(TypedDict):
    """One training example: an audio file path paired with its transcription."""
    # Path of the audio file, stored as a string.
    path: str
    # Text the audio is expected to be transcribed to (may be empty).
    transcription: str

class Dataset(NamedTuple):
    """Train/eval partition of the training examples."""
    # Examples used for training.
    train: list[TrainItem]
    # Examples held out for evaluation.
    eval: list[TrainItem]


def get_train_eval_dataset(files: list[FileMetadata], eval_split=0.2) -> Dataset:
    """Split files into train and eval sets, stratified by "has a transcript".

    Files with an empty transcript and files with a non-empty one are split
    separately, so both kinds end up in both partitions. At least one item of
    each kind always goes to eval. Randomness comes from NumPy's global RNG.

    Args:
        files: metadata records to split; at least two of each kind are required.
        eval_split: fraction of each group that is held out for evaluation.

    Returns:
        A Dataset whose `train` and `eval` lists are shuffled.
    """
    with_transcript = [f for f in files if f.transcript != '']
    without_transcript = [f for f in files if f.transcript == '']

    assert len(without_transcript) > 1 and len(with_transcript) > 1

    train_items, eval_items = [], []
    for group in (with_transcript, without_transcript):
        group_items = [
            TrainItem(
                path=str(PREPROCESSED_DATA_DIR.joinpath(f.name + '.wav').absolute()),
                transcription=str(f.transcript),
            )
            for f in group
        ]
        eval_count = max(1, int(eval_split * len(group_items)))
        held_out = set(np.random.choice(range(len(group_items)), eval_count, replace=False))
        for position, item in enumerate(group_items):
            destination = eval_items if position in held_out else train_items
            destination.append(item)

    np.random.shuffle(train_items)
    np.random.shuffle(eval_items)
    return Dataset(train=train_items, eval=eval_items)

In [9]:
# Build the train/eval split (default 20% eval) from the files with fixed transcripts.
dataset = get_train_eval_dataset(files_with_fixed_transcript)

In [None]:
# Presumably a deliberate stop marker: a bare `raise` with no active exception fails with
# RuntimeError, so "Run All" halts here and the cells below must be run by hand.
raise

In [18]:
from huggingsound import GreedyDecoder, Decoder
import torch

class DecoderAdapter(Decoder):
    """Arg-max CTC decoder that can also dump the logits it receives to disk.

    With `save_for_dev` enabled, every decoding call overwrites `./logits.np`
    with the logits it was given, so other decoders can later be tried on the
    same input without re-running the model.
    """

    def __init__(self, token_set, skip_special_tokens = True, ms_per_timestep = 20, probability_offset = 1, save_for_dev = True):
        super().__init__(token_set, skip_special_tokens, ms_per_timestep, probability_offset)
        self.save_for_dev = save_for_dev

    def _get_predictions(self, logits: torch.Tensor):
        """Decode the most probable token of every frame; optionally save the logits."""
        decoded = self._ctc_decode(torch.argmax(logits, dim=-1))

        if not self.save_for_dev:
            return decoded

        with open('./logits.np', 'wb') as dump:
            np.save(dump, logits.detach().cpu().numpy())
        return decoded

    @staticmethod
    def load_dev_logits() -> np.ndarray:
        """Read back the logits written by the last `save_for_dev` decoding call."""
        with open('./logits.np', 'rb') as dump:
            saved_logits = np.load(dump)
        return saved_logits


In [19]:
class Trie:
    """Prefix tree over words written as sequences of integer token ids.

    Every node owns a fixed-size child table with one slot per token id in
    `[0, num_tokens)`.
    """

    class Node:
        """Single trie node: a child slot per token plus an end-of-word flag."""

        def __init__(self, num_tokens: int, is_terminal: bool):
            self.children = [None] * num_tokens
            self.is_terminal = is_terminal

    def __init__(self, num_tokens: int):
        self.num_tokens = num_tokens
        self.root = self.Node(num_tokens=num_tokens, is_terminal=False)

    def _is_valid_token(self, token: int) -> bool:
        """True iff `token` addresses a real child slot.

        Negative ids must be rejected explicitly: Python list indexing would
        otherwise silently map them onto the last tokens of the alphabet.
        """
        return 0 <= token < self.num_tokens

    def add(self, word_token_ids: list[int]):
        """Insert a word; empty words are ignored.

        Raises:
            IndexError: if any token id is outside `[0, num_tokens)`. The word
                is validated up front, so the trie is left unchanged.
        """
        word_token_ids = list(word_token_ids)
        if not word_token_ids:
            return
        for token in word_token_ids:
            if not self._is_valid_token(token):
                raise IndexError(f'token id {token} is outside of [0, {self.num_tokens})')

        cur = self.root
        for token in word_token_ids:
            if cur.children[token] is None:
                cur.children[token] = self.Node(self.num_tokens, is_terminal=False)
            cur = cur.children[token]
        cur.is_terminal = True

    def search(self, word_token_ids: list[int]) -> "tuple[bool, int, Trie.Node]":
        '''Returns: True iff word exists, length of the longest prefix, last Node on the path.

        A token id outside `[0, num_tokens)` can never be part of a stored
        word, so it ends the match like any other missing edge.
        '''
        cur = self.root
        for i, token in enumerate(word_token_ids):
            child = cur.children[token] if self._is_valid_token(token) else None
            if child is None:
                return False, i, cur
            cur = child
        return cur.is_terminal, len(word_token_ids), cur

    def has(self, word_token_ids: list[int]) -> bool:
        """True iff exactly this token sequence was added as a word."""
        return self.search(word_token_ids)[0]

    def get_possible_next_tokens(self, node: "Trie.Node") -> list[int]:
        """Token ids that extend the prefix ending at `node` (empty for `None`)."""
        if node is None:
            return []
        return [i for i, x in enumerate(node.children) if x is not None]

In [20]:
from typing import Generator
import editdistance
from collections import deque

class BKTree:
    """BK-tree over words, keyed by edit distance as computed by `editdistance.eval`.

    Each child edge of a node is labelled with the distance between the child's
    word and the node's word; a node has at most one child per distance.
    """

    class Node:
        """Tree node holding one word and its children keyed by distance."""

        def __init__(self, word: str):
            self.children: dict[int, "BKTree.Node"] = {}
            self.word = word

    def __init__(self, root_word: str):
        self.root = self.Node(root_word)

    def add(self, word: str):
        """Insert `word`; a word that is already present is ignored."""
        node = self.root
        while True:
            distance = editdistance.eval(word, node.word)
            if distance == 0:
                # Same word as this node - nothing to insert.
                return
            if node.children.get(distance) is None:
                node.children[distance] = self.Node(word)
                return
            node = node.children[distance]

    def search(self, word: str, max_distance: int=1) -> Generator[tuple[str, int], None, None]:
        """Yield `(stored_word, distance)` for stored words within `max_distance` of `word`.

        Nodes are visited breadth-first. Only child edges whose label lies in
        `[distance - max_distance, distance + max_distance]` are followed.
        """
        pending: deque["BKTree.Node"] = deque([self.root])
        while pending:
            node = pending.popleft()
            distance = editdistance.eval(word, node.word)
            if distance <= max_distance:
                yield node.word, distance
            lowest_edge = max(distance - max_distance, 0)
            highest_edge = distance + max_distance
            for edge in range(lowest_edge, highest_edge + 1):
                child = node.children.get(edge)
                if child:
                    pending.append(child)


In [None]:
# Smoke test: store the word [1, 3] in a 5-token trie, then look up a longer sequence.
t = Trie(5)
t.add([1, 3])
# Displayed result: (word exists?, length of the matched prefix, last node reached).
t.search([1, 3, 1])

In [None]:
# Smoke test: build a small BK-tree and list every stored word within edit distance 2 of 'name'.
# Note: `bkt` is rebound to the full dictionary tree by a later cell.
bkt = BKTree('book')
words = ['books', 'cake', 'boo', 'cape', 'boon', 'cook', 'cart']
for w in words:
    bkt.add(w)

# search() yields (word, distance) tuples.
for w in bkt.search('name', max_distance=2):
    print(w[0], w[1])

In [23]:
import msgpack
import gzip
import wordfreq
# Read wordfreq's bundled 'large_pl' word list (gzip-compressed msgpack) into `data`.
# NOTE(review): later cells skip data[0] and treat the remaining entries as lists of
# words - presumably data[0] is a format header; confirm against the wordfreq docs.
with gzip.open(wordfreq.DATA_PATH.joinpath('large_pl.msgpack.gz'), 'rb') as f:
    data = msgpack.load(f, raw=False)

In [24]:
# Alphabet of the dictionary; a token's id is its position in this list.
tokens = ['-', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', 'ó', 'ą', 'ć', 'ę', 'ł', 'ń', 'ś', 'ź', 'ż']
diacritics = ['ó', 'ą', 'ć', 'ę', 'ł', 'ń', 'ś', 'ź', 'ż']

# Token -> id lookup table.
tm = {x: i for i, x in enumerate(tokens)}

def get_token_id(token: str) -> int:
    """Return the id of `token` in the dictionary alphabet.

    Raises:
        KeyError: if the token is not in the alphabet. The message ends with
            the offending token, which callers in this notebook rely on.
    """
    try:
        return tm[token]
    except KeyError:
        # Only a failed lookup is translated. The original bare `except:` also
        # swallowed unrelated errors (TypeError, KeyboardInterrupt, ...).
        raise KeyError(f'unknown token: {token}') from None

In [25]:
# Fill the trie with every word list entry that can be written in the `tokens` alphabet.
unknown_tokens = set()
trie = Trie(num_tokens=len(tokens))
# data[0] is skipped; every remaining entry is iterated as a collection of words.
for bucket in data[1:]:
    for word in bucket:
        try:
            tokenized_word = [get_token_id(x) for x in word]
            trie.add(tokenized_word)
        except KeyError as e:
            # The message is 'unknown token: <char>', so its last character is the
            # unsupported character; the word itself is skipped.
            unknown_tokens.add(e.args[0][-1])

In [None]:
# Build the BK-tree used for edit-distance lookups from the same word list.
# This rebinds `bkt`, replacing the small demo tree created in an earlier cell.
num_added_words = 0
bkt = BKTree('kurwa')
for bucket in data[1:]:
    for word in bucket:
        try:
            # Computed only for its side effect: it raises KeyError for words containing
            # characters outside the alphabet, so those words are not added to the tree.
            tokenized_word = [get_token_id(x) for x in word]
            bkt.add(word)
            num_added_words += 1
            if num_added_words % 10000 == 0:
                # '\r' rewrites the same output line with the running count.
                print(f'\r{num_added_words}', end='')
        except KeyError as e:
            # Last character of 'unknown token: <char>' is the unsupported character.
            unknown_tokens.add(e.args[0][-1])

In [27]:
class DictionaryAssistedDecoder(Decoder):
    """CTC decoder that snaps every decoded word to a dictionary.

    Frames are decoded greedily (most probable token per frame) one word at a
    time. When a word separator is reached, the word collected so far is:

    1. accepted as is when the dictionary contains it;
    2. glued to the previously emitted word when the concatenation is a
       dictionary word;
    3. extended past a low-confidence separator when the dictionary knows a
       continuation that is probable in one of the next frames;
    4. otherwise replaced by the most probable of: a split into two dictionary
       words, or a dictionary word within a small edit distance. When no
       alternative exists the raw word is kept.

    Dictionary (trie) token ids are assumed to be the model's token ids shifted
    down by the number of special tokens (see `_convert_to_dictionary_tokens`).
    """

    # A separator predicted with less than this probability may hide a letter.
    SEPARATOR_CONFIDENCE_THRESHOLD = 0.9
    # Probability a dictionary-compatible letter needs in order to continue a word.
    CONTINUATION_PROBABILITY_THRESHOLD = 0.25
    # Number of frames, starting at the separator, scanned for such a letter.
    LOOK_AHEAD = 5

    def __init__(self, token_set, dictionary: Trie, edit_distance_search_tree: BKTree, disctionary_token_id_lut: dict[str, int], average_word_length=8):
        """
        Args:
            token_set: token set of the acoustic model.
            dictionary: trie of valid words, in dictionary token ids.
            edit_distance_search_tree: BK-tree of the same words, as strings.
            disctionary_token_id_lut: character -> dictionary token id table
                (stored, currently not consulted by the decoder).
            average_word_length: sizes the scratch buffer that is reused by the
                alignment scoring of typical words.
        """
        super().__init__(token_set)
        self.dictionary = dictionary
        self.edit_distance_search_tree = edit_distance_search_tree
        self.disctionary_token_id_lut = disctionary_token_id_lut
        self.best_configuration_compute_buff = np.zeros((2 * average_word_length + 1, average_word_length*10), dtype=np.float32)

    def _get_predictions(self, logits: torch.Tensor):
        """Decode a single utterance; `logits` must have a leading batch dimension of 1."""
        assert logits.shape[0] == 1
        logits = logits[0, ...]
        num_frames = logits.shape[0]
        probs = torch.nn.functional.softmax(logits, dim=-1)
        # log_softmax avoids the -inf that log(softmax(x)) produces on underflow.
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
        # Plain ints keep the emitted id list homogeneous (no 0-d tensors mixed with ints).
        best_token_per_frame: list[int] = probs.argmax(dim=-1).tolist()
        blank_id = self.token_set.blank_token_id
        silence_id = self.token_set.silence_token_id

        predicted_ids: list[int] = []
        previous_word = None
        word_start = 0
        while word_start < num_frames:
            i = word_start
            word_decoding_state: list[int] = []
            while i < num_frames:
                most_likely_token = best_token_per_frame[i]
                if not word_decoding_state and most_likely_token in (blank_id, silence_id):
                    # Nothing collected yet: skip leading blanks / separators.
                    word_start += 1
                    break
                if most_likely_token == blank_id:
                    i += 1
                    continue
                if most_likely_token == silence_id:
                    continuation = self._find_continuation(probs, word_decoding_state, previous_word, i)
                    if continuation is not None:
                        # The separator was probably a mis-detected letter: keep building the word.
                        next_token, frame = continuation
                        word_decoding_state.append(next_token)
                        i = frame + 1
                        continue
                    previous_word = self._emit_word(
                        predicted_ids, word_decoding_state, previous_word, log_probs, word_start, i - 1)
                    word_start = i + 1
                    break

                # CTC: a token repeated in consecutive frames is a single emission.
                if not word_decoding_state or most_likely_token != best_token_per_frame[i - 1]:
                    word_decoding_state.append(most_likely_token)
                i += 1
            else:
                # The audio ended in the middle of a word. Without this branch `word_start`
                # would never advance and the outer loop would spin forever.
                if word_decoding_state:
                    previous_word = self._emit_word(
                        predicted_ids, word_decoding_state, previous_word, log_probs, word_start, num_frames - 1)
                word_start = num_frames

        predictions = self._ctc_decode([predicted_ids])

        return predictions

    def _joins_previous_word(self, previous_word, word_tokens: list[int]) -> bool:
        """True iff the previous word followed directly by `word_tokens` is a dictionary word."""
        if not previous_word:
            return False
        combined_word = list(previous_word) + list(word_tokens)
        return self.dictionary.has(self._convert_to_dictionary_tokens(combined_word))

    def _find_continuation(self, probs: torch.Tensor, word_tokens: list[int], previous_word,
                           separator_frame: int) -> tuple[int, int] | None:
        """Look for a letter hidden behind a low-confidence separator.

        Applies only when the separator is uncertain, the collected tokens are
        a proper dictionary prefix (not a word on their own or together with
        the previous word) and the dictionary offers continuations.

        Returns:
            `(model_token_id, frame)` of the first frame within the look-ahead
            window where the best dictionary-compatible token is probable
            enough, or None.
        """
        silence_id = self.token_set.silence_token_id
        if float(probs[separator_frame, silence_id]) >= self.SEPARATOR_CONFIDENCE_THRESHOLD:
            return None
        word_exists, existing_prefix_size, last_node = self.dictionary.search(
            self._convert_to_dictionary_tokens(word_tokens))
        if word_exists or existing_prefix_size != len(word_tokens):
            return None
        if self._joins_previous_word(previous_word, word_tokens):
            return None

        # The trie speaks dictionary ids; `probs` and the decoding state use model ids.
        offset = len(self.token_set.special_tokens)
        vocabulary_size = probs.shape[-1]
        candidates = [
            token + offset
            for token in self.dictionary.get_possible_next_tokens(last_node)
            if token + offset < vocabulary_size
        ]
        if not candidates:
            return None

        last_frame = min(separator_frame + self.LOOK_AHEAD, probs.shape[0])
        for frame in range(separator_frame, last_frame):
            best_candidate = max(candidates, key=lambda token: float(probs[frame, token]))
            if float(probs[frame, best_candidate]) > self.CONTINUATION_PROBABILITY_THRESHOLD:
                return best_candidate, frame
        return None

    def _emit_word(self, predicted_ids: list[int], word_tokens: list[int], previous_word,
                   log_probs: torch.Tensor, first_frame: int, last_frame: int) -> list[int]:
        """Append the best dictionary-backed version of a finished word to `predicted_ids`.

        `first_frame..last_frame` (inclusive) is the frame span the word was
        decoded from.

        Returns:
            The tokens emitted last; they become the next "previous word".
        """
        separator_frame = last_frame + 1
        if self.dictionary.has(self._convert_to_dictionary_tokens(word_tokens)):
            # greedily accept word
            return self._accept(predicted_ids, word_tokens, separator_frame)[1]

        if self._joins_previous_word(previous_word, word_tokens):
            # Drop the separator emitted after the previous word to glue both parts.
            predicted_ids.pop()
            if predicted_ids and predicted_ids[-1] == word_tokens[0]:
                # Same letter on both sides of the joint: CTC decoding would merge them.
                predicted_ids.append(self.token_set.blank_token_id)
            return self._accept(predicted_ids, word_tokens, separator_frame)[1]

        word = self._join_word(word_tokens)
        # try splitting into two words and correct them separately
        best_split, split_best_log_prob = self._split_word_and_attempt_correcting(
            word_tokens, log_probs, first_frame, last_frame)
        # find existing word with small levenshtein distance and highest probability
        whole_best_alternative_tokens, whole_best_log_prob = self._correct_word(
            log_probs, word, first_frame, last_frame, max_dist=1 if len(word) <= 3 else 2)

        if best_split is not None and (whole_best_log_prob is None or split_best_log_prob > whole_best_log_prob):
            self._accept(predicted_ids, best_split[0], separator_frame)
            return self._accept(predicted_ids, best_split[1], separator_frame)[1]
        if whole_best_alternative_tokens is not None:
            return self._accept(predicted_ids, whole_best_alternative_tokens, separator_frame)[1]
        # No dictionary alternative at all: keep what was decoded.
        return self._accept(predicted_ids, word_tokens, separator_frame)[1]

    def _split_word_and_attempt_correcting(self, word_decoding_state: list[int],
            log_probs: torch.Tensor, word_start: int, word_end: int) -> tuple[tuple[list[int], list[int]] | None, float | None]:
        """Try to read the tokens as two dictionary words.

        Every split position is tried; each half is kept when it is a
        dictionary word and otherwise corrected within edit distance 1.

        Returns:
            `((first_tokens, second_tokens), log_prob)` of the most probable
            split, or `(None, None)` when no split works.
        """
        if len(word_decoding_state) < 3:
            return None, None

        best_split, best_log_prob = None, None

        # NOTE(review): `split_pos` is a token index but is used as a frame offset below, so
        # the first half is always scored on exactly `split_pos` frames. Mapping tokens to
        # their frames would need extra bookkeeping in the decoding loop - confirm intent.
        for split_pos in range(2, len(word_decoding_state) - 2):
            first_end = word_start + split_pos - 1
            first_tokens, first_log_prob = self._score_or_correct(
                log_probs, word_decoding_state[:split_pos], word_start, first_end)
            if first_log_prob is None:
                continue
            second_tokens, second_log_prob = self._score_or_correct(
                log_probs, word_decoding_state[split_pos:], first_end + 1, word_end)
            if second_log_prob is None:
                continue

            partial_log_prob = first_log_prob + second_log_prob
            if best_log_prob is None or partial_log_prob > best_log_prob:
                best_split, best_log_prob = (first_tokens, second_tokens), partial_log_prob

        return best_split, best_log_prob

    def _score_or_correct(self, log_probs: torch.Tensor, tokens: list[int],
                          start_idx: int, end_idx: int) -> tuple[list[int] | None, float | None]:
        """Score `tokens` when they form a dictionary word, otherwise correct them (distance 1).

        Returns `(tokens, log_prob)`; `log_prob` is None when the frame span
        admits no usable word.
        """
        if self.dictionary.has(self._convert_to_dictionary_tokens(tokens)):
            log_prob = self._get_log_probability_of_best_configuration(log_probs, tokens, start_idx, end_idx)
            return tokens, (None if log_prob == -np.inf else log_prob)
        return self._correct_word(log_probs, self._join_word(tokens), start_idx, end_idx, max_dist=1)

    def _correct_word(self, log_probs: torch.Tensor, word: str, start_idx: int, end_idx: int, max_dist: int) -> tuple[list[int] | None, float | None]:
        """Find the most probable dictionary word within `max_dist` edits of `word`.

        Candidates are scored against frames `start_idx..end_idx` (inclusive).
        Ties go to the smaller edit distance. Candidates that cannot fit into
        the frame span (log-probability of -inf) are ignored.

        Returns:
            `(tokens, log_prob)` of the winner, or `(None, None)`.
        """
        alternatives = sorted(
            (alt for alt in self.edit_distance_search_tree.search(word, max_distance=max_dist)
             if 1 <= alt[1] <= max_dist),
            key=lambda alt: alt[1],
        )
        best_tokens, best_log_prob = None, None
        for alternative, _ in alternatives:
            tokens = self._tokenize_word(alternative)
            lp = self._get_log_probability_of_best_configuration(log_probs, tokens, start_idx, end_idx)
            if lp == -np.inf:
                continue
            if best_log_prob is None or lp > best_log_prob:
                best_tokens, best_log_prob = tokens, lp
        return best_tokens, best_log_prob

    def _get_log_probability_of_best_configuration(self, log_probs: torch.Tensor, tokens: list[int], start_idx: int, end_idx: int) -> float:
        """Log-probability of the best placement of `tokens` on frames `start_idx..end_idx`.

        Every token occupies exactly one frame, in order, and every other
        frame emits the blank token. `F[i, j]` is the best score of placing
        the first `i` tokens on the first `j` frames.

        Returns:
            The best log-probability, or `-inf` when there are more tokens than frames.
        """
        N = end_idx - start_idx + 1
        if len(tokens) > N:
            return -np.inf

        rows, cols = len(tokens) + 1, N + 1
        buff = self.best_configuration_compute_buff
        if buff.shape[0] >= rows and buff.shape[1] >= cols:
            F = buff
        else:
            F = np.zeros((rows, cols), dtype=np.float32)

        blank_id = self.token_set.blank_token_id
        # The buffer is shared between calls, so do not rely on its initial content.
        F[0, 0] = 0.
        for j in range(1, N + 1):
            F[0, j] = F[0, j - 1] + float(log_probs[start_idx + j - 1, blank_id])

        for i in range(1, rows):
            token = tokens[i - 1]
            for j in range(i, N + 1):
                frame = start_idx + j - 1
                take_cur = F[i - 1, j - 1] + float(log_probs[frame, token])
                if i == j:
                    # As many tokens as frames so far: no room for a blank.
                    F[i, j] = take_cur
                else:
                    F[i, j] = max(F[i, j - 1] + float(log_probs[frame, blank_id]), take_cur)

        return float(F[len(tokens), N])

    def _accept(self, predicted_ids: list[int], token_ids: list[int], i: int) -> tuple[int, list[int]]:
        """Append a word followed by a separator to `predicted_ids`.

        Returns:
            `(i + 1, token_ids)`: the frame after the separator and the emitted word.
        """
        if token_ids:
            edited_token_ids = [token_ids[0]]
            for token in token_ids[1:]:
                if token == edited_token_ids[-1]:
                    # ctc decoding filters same consecutive tokens that are not separated by blank
                    edited_token_ids.append(self.token_set.blank_token_id)
                edited_token_ids.append(token)
            predicted_ids.extend(edited_token_ids)
        predicted_ids.append(self.token_set.silence_token_id)
        return i + 1, token_ids

    def _tokenize_word(self, word: str) -> list[int]:
        """Characters of `word` -> model token ids."""
        return [self.token_set.id_by_token[x] for x in word]

    def _convert_to_dictionary_tokens(self, token_ids: list[int]) -> list[int]:
        """Model token ids -> dictionary ids (shifted down by the number of special tokens)."""
        return [x - len(self.token_set.special_tokens) for x in token_ids]

    def _join_word(self, token_ids: list[int]) -> str:
        """Model token ids -> the string they spell."""
        return ''.join([self.token_set.tokens[x] for x in token_ids])


In [None]:
# Load the model stored in OUTPUT_DIR (the directory the fine-tuning cell passes to
# `finetune`). The commented-out line loads the pretrained base model instead.
# model = SpeechRecognitionModel("jonatasgrosman/wav2vec2-large-xlsr-53-polish", device='cuda')
model = SpeechRecognitionModel(str(OUTPUT_DIR), device='cuda')

In [29]:
# The engine receives a factory returning the (model, decoder) pair. DecoderAdapter is
# created with its default save_for_dev=True, so decoding dumps the logits to ./logits.np.
tb = Transcriber(None)
engine = tb.Engine(tb, lambda: (model, DecoderAdapter(model.token_set)))

In [None]:
# Transcribe one sample file; the commented-out lines are alternative samples.
# transcription = await engine.transcribe(ROOT_DIR.joinpath('317290092_5660838990665621_2090424565105818065_n.mp4'))
# transcription = await engine.transcribe(ROOT_DIR.joinpath('video-1591823465.mp4'))
transcription = await engine.transcribe(ROOT_DIR.joinpath('353634787_6536568179715415_2948274603283924016_n.mp4'))

In [None]:
# Reload the logits dumped to ./logits.np by the last DecoderAdapter call.
logits = DecoderAdapter.load_dev_logits()

# Baseline: arg-max decoding of those logits (save_for_dev=False keeps the dump untouched).
DecoderAdapter(model.token_set, save_for_dev=False)(torch.from_numpy(logits))[0]['transcription']

In [None]:
# Decode the same logits with the dictionary-assisted decoder, for comparison with the baseline.
DictionaryAssistedDecoder(model.token_set, dictionary=trie, edit_distance_search_tree=bkt, disctionary_token_id_lut=tm)(torch.from_numpy(logits))[0]['transcription']
# DictionaryAssistedDecoder(model.token_set, dictionary=trie, edit_distance_search_tree=bkt)(torch.from_numpy(logits))[0]['transcription']

In [13]:
# Subset of the files whose name contains 'firany'.
firany_files = [metadata for metadata in files if 'firany' in metadata.name]

In [None]:
# Spot check: look up 'poddawać'; displays (word exists?, matched prefix length, last node).
trie.search([get_token_id(x) for x in 'poddawać'])

In [None]:
# Spot check: is 'ćme' stored in the dictionary trie as a complete word?
trie.has([get_token_id(x) for x in 'ćme'])

In [None]:
# Re-transcribe a slice of the 'firany' files (currently only the second one).
transcription_results = []
for file in firany_files[1:2]:
    # NOTE(review): here a decoder is passed as a second argument, while the earlier
    # single-file call passes only the path - confirm against Engine.transcribe's signature.
    transcription = await engine.transcribe(ROOT_DIR.joinpath(file.name), DecoderAdapter(model.token_set))
    transcription_results.append(transcription)

In [None]:
# Only firany_files[1:2] was transcribed, so pair the results with that same slice.
# Zipping against the full list would print the stored transcript of the first file
# next to the new transcription of the second one.
for f, t in zip(firany_files[1:2], transcription_results):
    print(f'before: {f.transcript}')
    print(f'after: {t}')
    print()

In [None]:
# Presumably a deliberate stop marker: a bare `raise` with no active exception fails with
# RuntimeError, so "Run All" never reaches the fine-tuning cells below by accident.
raise

In [None]:
# Rebind `model` to the pretrained base checkpoint, which the fine-tuning cell starts from.
model = SpeechRecognitionModel("jonatasgrosman/wav2vec2-large-xlsr-53-polish", device='cuda')

In [12]:
# Fine-tuning hyper-parameters, passed to `finetune` as `training_args`.
args = TrainingArguments(
    overwrite_output_dir=True,
    per_device_train_batch_size=4,
    learning_rate=3e-4,
    num_train_epochs=5,
    fp16=False,
)

In [None]:
# Fine-tune with OUTPUT_DIR as the first argument. The eval split is merged into the
# training data and no eval data is passed, so this run has no held-out evaluation.
finetuned = model.finetune(
    str(OUTPUT_DIR),
    train_data=[*dataset.train, *dataset.eval],
    # eval_data=dataset.eval,
    eval_data=None,
    token_set=model.token_set,
    training_args=args
)

In [None]:
# Load the TensorBoard notebook extension and open it on the logs under finetuning/models.
%load_ext tensorboard
%tensorboard --logdir finetuning/models