In [None]:
from __future__ import annotations

from dataclasses import dataclass, field
# from functools import lru_cache
from typing import Any
import uuid

from cachetools import cached
from cachetools.keys import hashkey
import numpy as np

class Anything:
    def __eq__(self, other: Any) -> bool: return True
    def __repr__(self) -> str: return '<?>'
    __str__ = __repr__

@dataclass
class Token:
    value: str | Anything
    uid: str = field(default_factory=lambda: uuid.uuid4().hex)
    
    def __repr__(self) -> str:
        return str(self.value)

@dataclass
class MultiVariant:
    options: list[list[Token]]
    uid: str = field(default_factory=lambda: uuid.uuid4().hex)
    
    def __repr__(self) -> str:
        return str(self.options)

def tokens_from_text(text: str) -> list[Token]:
    return [Token(t) for t in text.split()]
    
def multi_variant_from_texts(texts: list[str]) -> MultiVariant:
    return MultiVariant([tokens_from_text(text) for text in texts])

@dataclass(kw_only=True)
class Match:
    true: list[Token]
    pred: list[Token]
    true_len: int
    n_errs: int
    
    @classmethod
    def from_pair(cls, true: list[Token], pred: list[Token]) -> Match:
        assert len(true) > 0 or len(pred) > 0
        return Match(
            true=true,
            pred=pred,
            true_len=len(true),
            n_errs=(
                0
                if [t.value for t in true] == [t.value for t in pred]
                or (len(true) == 1 and isinstance(true[0].value, Anything))
                else max(len(true), len(pred))
            ),
        )
    
    def __repr__(self) -> str:
        first = ' '.join([str(x) for x in self.true])
        second = ' '.join([str(x) for x in self.pred])
        return f'({first}, {second})'

@dataclass
class MatchesList:
    matches: list[Match]
    total_true_len: int
    total_n_errs: int
    total_n_correct: int
    
    @classmethod
    def from_list(cls, matches: list[Match]) -> MatchesList:
        return MatchesList(
            matches=matches,
            total_true_len=sum(m.true_len for m in matches),
            total_n_errs=sum(m.n_errs for m in matches),
            total_n_correct=sum(m.n_errs == 0 for m in matches),
        )
    
    def prepend(self, match: Match) -> MatchesList:
        return MatchesList(
            matches=[match] + self.matches,
            total_true_len=match.true_len + self.total_true_len,
            total_n_errs=match.n_errs + self.total_n_errs,
            total_n_correct=(match.n_errs == 0) + self.total_n_correct,
        )
    
    @property
    def value(self) -> int:
        return self.total_n_errs * 1_000_000 - self.total_n_correct
    

def select_shortest_multi_variants(seq: list[Token | MultiVariant]) -> list[Token]:
    result: list[Token] = []
    for x in seq:
        if isinstance(x, MultiVariant):
            result += min(x.options, key=len)
        else:
            result.append(x)
    return result

def align(
    true: list[Token | MultiVariant],
    pred: list[Token],
    try_speedup: bool = False,
) -> MatchesList:
    multivariant_prefixes: dict[tuple[str, int], list[Token]] = {}
    for x in true:
        if isinstance(x, MultiVariant):
            for i, option in enumerate(x.options):
                multivariant_prefixes[x.uid, i] = option
    
    err_cap = np.inf
    if len(true) > 50 and len(pred) > 50 and try_speedup:
        # block-wise
        n_blocks = 10
        true_points = np.linspace(0, len(true), num=n_blocks + 1, endpoint=True, dtype=int)
        pred_points = np.linspace(0, len(pred), num=n_blocks + 1, endpoint=True, dtype=int)
        matches_per_block: list[MatchesList] = []
        for i in range(n_blocks):
            matches_per_block.append(align(
                true[true_points[i]:true_points[i + 1]],
                pred[pred_points[i]:pred_points[i + 1]],
            ))
        err_cap = sum([m.total_n_errs for m in matches_per_block])
        print('err_cap', err_cap)
    
    @cached(cache={}, key=lambda *args: hashkey(*args[:-1]), info=True)  # do not cache the last argument
    def _align_recursive(
        true_pos: int,
        pred_pos: int,
        multivariant_prefix_id: tuple[str, int] | None,
        multivariant_prefix_pos: int,
        prev_total_err: int,
    ) -> MatchesList:
        # print('call', true_pos, pred_pos, multivariant_prefix_idx, multivariant_prefix_pos)
        _true = true[true_pos:]
        _pred = pred[pred_pos:]
        
        if multivariant_prefix_id is not None:
            prefix = multivariant_prefixes[multivariant_prefix_id][multivariant_prefix_pos:]
            _true = prefix + _true
        
        if len(_pred) == 0 and len(_true) == 0:
            return MatchesList.from_list([])
        elif len(_pred) == 0 and len(_true) > 0:
            _matches: list[Match] = []
            for token in _true:
                if len(shortest := select_shortest_multi_variants([token])):
                    _matches.append(Match.from_pair(shortest, []))
            return MatchesList.from_list(_matches)
        elif len(_pred) > 0 and len(_true) == 0:
            return MatchesList.from_list([
                Match.from_pair([], [token])
                for token in _pred
            ])
        elif not isinstance(_true[0], MultiVariant):
            options: list[MatchesList] = []
            current_match_options = [
                # option 1: match true[0] with pred[0]
                (1, 1, Match.from_pair(_true[:1], _pred[:1])), # type: ignore
                # option 2: match pred[0] with nothing
                (0, 1, Match.from_pair([], _pred[:1])),
                # option 3: match true[0] with nothing
                (1, 0, Match.from_pair(_true[:1], [])), # type: ignore
            ]
            filtered_match_options = [
                (i, j, current_match)
                for i, j, current_match in current_match_options
                if prev_total_err + current_match.n_errs <= err_cap
            ]
            if len(filtered_match_options) == 0:
                filtered_match_options = current_match_options[:1]
            for i, j, current_match in filtered_match_options:
                new_true_pos = true_pos
                new_multivariant_prefix_idx = multivariant_prefix_id
                new_multivariant_prefix_pos = multivariant_prefix_pos
                if i == 1:
                    if multivariant_prefix_id is not None:
                        if len(prefix) > 1: # type: ignore
                            new_multivariant_prefix_pos += 1
                        else:
                            new_multivariant_prefix_idx = None
                            new_multivariant_prefix_pos = 0
                    else:
                        new_true_pos += 1
                _results = _align_recursive(
                        new_true_pos,
                        pred_pos + j,
                        new_multivariant_prefix_idx,
                        new_multivariant_prefix_pos,
                        prev_total_err + current_match.n_errs
                    )
                options.append(
                    _results.prepend(current_match)
                )
            if isinstance(_true[0].value, Anything):
                current_match = Match.from_pair(_true[:1], _pred[:1]) # type: ignore
                options.append(
                    # option 4: match Anything with pred[0], but keep Anything in the true tokens
                    _align_recursive(
                        true_pos,
                        pred_pos + 1,
                        multivariant_prefix_id,
                        multivariant_prefix_pos,
                        prev_total_err,
                    ).prepend(current_match)
                )
            # print('_true', _true)
            # print('_pred', _pred)
            # print('OPTIONS')
            # for opt in options:
            #     print(opt.value, opt)
            
            return min(options, key=lambda x: x.value)
        else:
            assert multivariant_prefix_id is None
            options = [
                _align_recursive(
                    true_pos + 1,
                    pred_pos,
                    (_true[0].uid, i),
                    0,
                    prev_total_err,
                )
                for i in range(len(_true[0].options))
            ]
            return min(options, key=lambda x: x.value)
    
    result = _align_recursive(0, 0, None, 0, 0)
    print(_align_recursive.cache_info()) # type: ignore
    return result

# matches_list = align(
#     [
#         Token(Anything()),
#     ],
#     tokens_from_text('a b')
# )

# matches_list = align(
#     [
#         Token(Anything()),
#         multi_variant_from_texts(['a b', 'a']),
#         multi_variant_from_texts(['a', 'b a']),
#         multi_variant_from_texts(['x', 'y']),
#     ],
#     tokens_from_text('a b a')
# )

# matches_list = align(
#     [
#         multi_variant_from_texts(['a']),
#         Token('x'),
#     ],
#     [Token('a')],
# )

# matches_list



CacheInfo(hits=79600, misses=40401, maxsize=None, currsize=40401)
errs 67


In [80]:
true = [Token(str(x)) if np.random.rand() > 0.05 else Token(Anything()) for x in np.random.randint(0, 2, size=1000)]
pred = [Token(str(x)) for x in np.random.randint(0, 2, size=1000)]

%time result = align(true, pred, try_speedup=True) # type: ignore

print('errs', result.total_n_errs)

CacheInfo(hits=20000, misses=10201, maxsize=None, currsize=10201)
CacheInfo(hits=20200, misses=10201, maxsize=None, currsize=10201)
CacheInfo(hits=20400, misses=10201, maxsize=None, currsize=10201)
CacheInfo(hits=20700, misses=10201, maxsize=None, currsize=10201)
CacheInfo(hits=20500, misses=10201, maxsize=None, currsize=10201)
CacheInfo(hits=20000, misses=10201, maxsize=None, currsize=10201)
CacheInfo(hits=20100, misses=10201, maxsize=None, currsize=10201)
CacheInfo(hits=20300, misses=10201, maxsize=None, currsize=10201)
CacheInfo(hits=20400, misses=10201, maxsize=None, currsize=10201)
CacheInfo(hits=19900, misses=10201, maxsize=None, currsize=10201)
err_cap 258
CacheInfo(hits=352297, misses=706962, maxsize=None, currsize=706962)
CPU times: user 29.2 s, sys: 138 ms, total: 29.3 s
Wall time: 29.3 s
errs 270


In [78]:
import sys
sys.getrecursionlimit()

3000

In [81]:
import pandas as pd
df  = pd.read_csv('../eval_common_voice_17_0test_200samples.csv', skiprows=1) # type: ignore
true = [
    Token(word) if np.random.rand() > 0.05 else Token(Anything())
    for word in ' '.join(df.iloc[:100]['true_texts']).split() # type: ignore
]
pred = [
    Token(word)
    for word in ' '.join(df.iloc[:100]['pred_texts']).split() # type: ignore
]
print(len(true), len(pred))

%time result = align(true, pred, try_speedup=False) # type: ignore
print('errs', result.total_n_errs)

854 831
CacheInfo(hits=1450072, misses=711360, maxsize=None, currsize=711360)
CPU times: user 52 s, sys: 177 ms, total: 52.2 s
Wall time: 52.2 s
errs 712


In [82]:
%time result = align(true, pred, try_speedup=True) # type: ignore
print('errs', result.total_n_errs)

CacheInfo(hits=14274, misses=7224, maxsize=None, currsize=7224)
CacheInfo(hits=14357, misses=7224, maxsize=None, currsize=7224)
CacheInfo(hits=14522, misses=7308, maxsize=None, currsize=7308)
CacheInfo(hits=14108, misses=7224, maxsize=None, currsize=7224)
CacheInfo(hits=14522, misses=7308, maxsize=None, currsize=7308)
CacheInfo(hits=14357, misses=7224, maxsize=None, currsize=7224)
CacheInfo(hits=14108, misses=7224, maxsize=None, currsize=7224)
CacheInfo(hits=14439, misses=7308, maxsize=None, currsize=7308)
CacheInfo(hits=14357, misses=7224, maxsize=None, currsize=7224)
CacheInfo(hits=14446, misses=7395, maxsize=None, currsize=7395)
err_cap 735
CacheInfo(hits=1212441, misses=706304, maxsize=None, currsize=706304)
CPU times: user 47 s, sys: 205 ms, total: 47.2 s
Wall time: 47.2 s
errs 717
