In [28]:
# Load the novel used as sample text. The encoding is explicit so the result
# does not depend on the platform's locale default; UTF-8 is a superset of
# ASCII, which the file name suggests the file contains.
with open('../Jack_London_-_The_Sea_Wolf_ascii.txt', 'r', encoding='utf-8') as file:
    text = file.read()


In [31]:
# sorted(list(set(text)))

In [1]:
# Implementation based on https://www.cs.helsinki.fi/u/tpkarkka/publications/jacm05-revised.pdf and https://mailund.dk/posts/skew-python-go/

import numpy as np
import numba
from typing import Tuple


@numba.jit()
def merge(x: np.array, SA12: np.array, SA3: np.array) -> np.array:
    """Merge the sorted mod-1/2 suffixes (SA12) with the sorted mod-0 ones (SA3).

    ``ISA`` stores, for every position listed in SA12, its rank in SA12, so
    that ``less`` can settle a tie between two mod-1/2 suffixes with a single
    lookup. Returns an int array of length ``len(x)``.
    """
    n12 = len(SA12)
    n3 = len(SA3)
    ISA = np.zeros((len(x),), dtype='int')
    for rank in range(n12):
        ISA[SA12[rank]] = rank
    SA = np.zeros((len(x),), dtype='int')
    out = 0
    a = 0
    b = 0
    while a < n12 and b < n3:
        if less(x, SA12[a], SA3[b], ISA):
            SA[out] = SA12[a]
            a += 1
        else:
            SA[out] = SA3[b]
            b += 1
        out += 1
    # At most one input still has entries left; copy its tail verbatim.
    if a < n12:
        SA[out:] = SA12[a:]
    elif b < n3:
        SA[out:] = SA3[b:]
    return SA


@numba.jit()
def u_idx(i: int, m: int) -> int:
    "Map indices in u back to indices in the original string."
    if i < m:
        return 1 + 3 * i
    else:
        return 2 + 3 * (i - m - 1)


@numba.jit()
def safe_idx(x: np.array, i: int) -> int:
    "Hack to get zero if we index beyond the end."
    return 0 if i >= len(x) else x[i]


@numba.jit()
def symbcount(x: np.array, asize: int) -> np.array:
    "Count how often we see each character in the alphabet."
    counts = np.zeros((asize,), dtype="int")
    for c in x:
        counts[c] += 1
    return counts


@numba.jit()
def cumsum(counts: np.array) -> np.array:
    "Compute the cumulative sum from the character count."
    res = np.zeros((len(counts, )), dtype='int')
    acc = 0
    for i, k in enumerate(counts):
        res[i] = acc
        acc += k
    return res


@numba.jit()
def bucket_sort(x: np.array, asize: int,
                idx: np.array, offset: int = 0) -> np.array:
    "Sort indices in idx according to x[i + offset]."
    sort_symbs = np.array([safe_idx(x, i + offset) for i in idx])
    counts = symbcount(sort_symbs, asize)
    buckets = cumsum(counts)
    out = np.zeros((len(idx),), dtype='int')
    for i in idx:
        bucket = safe_idx(x, i + offset)
        out[buckets[bucket]] = i
        buckets[bucket] += 1
    return out


@numba.jit()
def radix3(x: np.array, asize: int, idx: np.array) -> np.array:
    "Sort indices in idx according to their first three letters in x."
    idx = bucket_sort(x, asize, idx, 2)
    idx = bucket_sort(x, asize, idx, 1)
    return bucket_sort(x, asize, idx)


@numba.jit()
def triplet(x: np.array, i: int) -> Tuple[int, int, int]:
    "Extract the triplet (x[i],x[i+1],x[i+2])."
    return safe_idx(x, i), safe_idx(x, i + 1), safe_idx(x, i + 2)


@numba.jit()
def collect_alphabet(x: np.array, idx: np.array) -> Tuple[np.array, int]:
    "Map the triplets starting at idx to a new alphabet."
    alpha = np.zeros((len(x),), dtype='int')
    value = 1
    last_trip = -1, -1, -1
    for i in idx:
        trip = triplet(x, i)
        if trip != last_trip:
            value += 1
            last_trip = trip
        alpha[i] = value
    return alpha, value - 1


@numba.jit()
def build_u(x: np.array, alpha: np.array) -> np.array:
    "Construct u string, using 1 as central sentinel."
    a = np.array([alpha[i] for i in range(1, len(x), 3)] +
                 [1] +
                 [alpha[i] for i in range(2, len(x), 3)])
    return a


@numba.jit()
def less(x: np.array, i: int, j: int, ISA: np.array) -> bool:
    "Check if x[i:] < x[j:] using the inverse suffix array for SA12."
    a: int = safe_idx(x, i)
    b: int = safe_idx(x, j)
    if a < b:
        return True
    if a > b:
        return False
    if i % 3 != 0 and j % 3 != 0:
        return ISA[i] < ISA[j]
    return less(x, i + 1, j + 1, ISA)


@numba.jit()
def skew_rec(x: np.array, asize: int) -> np.array:
    "skew/DC3 SA construction algorithm."

    SA12 = np.array([i for i in range(len(x)) if i % 3 != 0])

    SA12 = radix3(x, asize, SA12)
    new_alpha, new_asize = collect_alphabet(x, SA12)
    if new_asize < len(SA12):
        # Recursively sort SA12
        u = build_u(x, new_alpha)
        sa_u = skew_rec(u, new_asize + 2)
        m = len(sa_u) // 2
        SA12 = np.array([u_idx(i, m) for i in sa_u if i != m])

    if len(x) % 3 == 1:
        SA3 = np.array([len(x) - 1] + [i - 1 for i in SA12 if i % 3 == 1])
    else:
        SA3 = np.array([i - 1 for i in SA12 if i % 3 == 1])
    SA3 = bucket_sort(x, asize, SA3)
    return merge(x, SA12, SA3)


def get_suffix_array(x: str) -> np.array:
    """Build the suffix array of ``x`` with the skew/DC3 algorithm.

    Characters are ranked by code point and mapped to 1..k. The value 0 is
    reserved for the implicit end-of-string symbol that ``safe_idx`` returns
    past the end of the text.

    Args:
        x: text to index. It must not contain ``"$"`` (the name historically
            reserved for the end-of-string symbol).

    Returns:
        Integer numpy array ``sa`` with ``x[sa[0]:] < x[sa[1]:] < ...``.
        An empty array is returned for an empty string.

    Raises:
        ValueError: if ``x`` contains ``"$"``.
    """
    if "$" in x:
        raise ValueError('Text should not contain $')
    if not x:
        # np.array([]) would be float64, which the jitted helpers cannot use
        # as indices; the suffix array of "" is simply empty.
        return np.zeros((0,), dtype=np.int64)
    # Real characters start at 1; 0 is the end of string.
    str_to_int = {c: n + 1 for n, c in enumerate(sorted(set(x)))}
    codes = np.array([str_to_int[c] for c in x], dtype=np.int64)
    # +1 accounts for the reserved end-of-string symbol 0.
    return skew_rec(codes, len(str_to_int) + 1)


In [2]:
def print_sa(sa, t):
    """Print, in suffix-array order, the rotation of ``t`` starting at each entry.

    A line of ten dashes is printed after the rows.
    """
    for start in sa:
        rotation = t[start:] + t[0:start]
        print(rotation)
    print('-' * 10)


# Demo: suffix array of a short string. print_sa shows each entry as the
# rotation starting at that position; the order is suffix order (the text has
# no terminator, so e.g. the suffix "N" is printed as the rotation "NAGCTN4ACTG").
dna_string = 'AGCTN4ACTGN'
suffix_array = get_suffix_array(dna_string)
print_sa(suffix_array, dna_string)


4ACTGNAGCTN
ACTGNAGCTN4
AGCTN4ACTGN
CTGNAGCTN4A
CTN4ACTGNAG
GCTN4ACTGNA
GNAGCTN4ACT
NAGCTN4ACTG
N4ACTGNAGCT
TGNAGCTN4AC
TN4ACTGNAGC
----------


In [11]:
def get_sort_canon_repr(s):
    """Return the canonical representation of sorting ``s``: its stable sort ranks.

    ``result[i]`` is the position that ``s[i]`` takes when ``s`` is sorted;
    equal items keep their original relative order.
    Example: ``get_sort_canon_repr('bca') == [1, 2, 0]``.
    """
    ranked = sorted(zip(s, range(len(s))))
    result = [None] * len(s)
    for rank, (_, origin) in enumerate(ranked):
        result[origin] = rank
    return result


def apply_permutation(s, perm):
    """Return a new list where element ``s[k]`` is moved to position ``perm[k]``."""
    placed = [None] * len(s)
    position = 0
    for target in perm:
        placed[target] = s[position]
        position += 1
    return placed


def inverse_permutation(canon_repr):
    """Return the inverse permutation: if ``canon_repr[k] == t`` then ``result[t] == k``."""
    inverse = [None] * len(canon_repr)
    for source, target in enumerate(canon_repr):
        inverse[target] = source
    return inverse


# Example BWT string (Cyrillic letters with a '#' terminator). The result lists,
# for each position of the stably sorted string, the index in ban_bwt that
# the character came from.
ban_bwt = 'а#ннБннБаааа'
inverse_permutation(get_sort_canon_repr(ban_bwt))


[1, 4, 7, 0, 8, 9, 10, 11, 2, 3, 5, 6]

In [61]:
SHARP = '#'


class BWT:
    """Burrows-Wheeler transform built on ``get_suffix_array``.

    The text is expected to end with a single ``SHARP`` terminator that occurs
    nowhere else; ``decode`` uses it to find the original rotation.
    All methods are static; the class only serves as a namespace.
    """

    @staticmethod
    def encode(t: str):
        """Return the BWT of ``t``.

        For each suffix in sorted order, emit the character preceding it,
        cyclically (the whole-text suffix yields ``t[-1]``). With a unique
        terminator at the end of ``t``, suffix order equals rotation order,
        so this is the classic rotation-based BWT.
        """
        if not t:
            return ''
        sa = get_suffix_array(t)
        # start - 1 == -1 for the suffix at 0, which wraps to the last char.
        return ''.join(t[start - 1] for start in sa)

    @staticmethod
    def decode(bwt: str):
        """Invert ``encode``.

        Raises:
            ValueError: if ``bwt`` does not contain exactly one ``SHARP``;
                without a unique terminator the starting row is ambiguous.
        """
        n_sharps = bwt.count(SHARP)
        if n_sharps != 1:
            raise ValueError(
                f'BWT must contain exactly one {SHARP!r} terminator, found {n_sharps}')
        # sigma[k]: rank of bwt[k] in a stable sort of bwt (its row in the
        # first column).
        sigma = get_sort_canon_repr(bwt)
        # inversed_sigma[r]: index in bwt of the character occurrence that
        # sits at row r of the sorted first column.
        inversed_sigma = inverse_permutation(sigma)
        res = [None] * len(bwt)
        # The row whose last char is the terminator (t[-1]) is the rotation
        # starting at text position 0; following inversed_sigma from it walks
        # the text forwards one character at a time.
        row = inversed_sigma[bwt.index(SHARP)]
        for j in range(len(bwt)):
            res[j] = bwt[row]
            row = inversed_sigma[row]
        return ''.join(res)


def _shift(alphabet, up, lo):
    for i in range(lo, up-1, -1):
        alphabet[i+1] = alphabet[i]
    return alphabet


class mtf:
    """Move-to-front style coder, in the "MTF-1" variant.

    The alphabet is the 123 characters ``chr(0)`` .. ``'z'`` in code-point
    order. After coding a symbol found at index ``ind``:

    * ``ind == 0``: the alphabet is unchanged;
    * ``ind == 1``: the symbol moves to the front (swapped with index 0);
    * ``ind > 1``: the symbol moves to index 1, not to the front.

    All methods are static; the class only serves as a namespace.
    """

    @staticmethod
    def get_alphabet():
        """Return a fresh initial alphabet: ``chr(0)`` .. ``'z'``."""
        return [chr(i) for i in range(ord('z') + 1)]

    @staticmethod
    def update_alphabet(alphabet, ind, c):
        """Update ``alphabet`` in place after coding ``c`` found at index ``ind``.

        ``c`` is expected to equal ``alphabet[ind]``; it is written to its new
        slot and the entries in between move one step towards the back.
        """
        if ind > 1:
            # Shift alphabet[1:ind] one slot right, overwriting alphabet[ind].
            for k in range(ind - 1, 0, -1):
                alphabet[k + 1] = alphabet[k]
            alphabet[1] = c
        elif ind == 1:
            alphabet[1] = alphabet[0]
            alphabet[0] = c

    @staticmethod
    def encode(t: str):
        """Encode ``t`` as a list of alphabet indices.

        Raises:
            ValueError: if ``t`` has characters outside the alphabet
                (e.g. non-ASCII letters).
        """
        alphabet = mtf.get_alphabet()
        diff = set(t) - set(alphabet)
        if diff:
            raise ValueError(
                f'Found chars in text that are not presented in alphabet: {diff}')
        res = []
        for c in t:
            ind = alphabet.index(c)
            res.append(ind)
            mtf.update_alphabet(alphabet, ind, c)
        return res

    @staticmethod
    def decode(encoded):
        """Invert ``encode``: map a list of indices back to the text.

        Raises:
            IndexError: if an index is outside ``[0, len(alphabet))``.
                Negative indices are rejected explicitly; plain list indexing
                would silently read from the end of the alphabet.
        """
        alphabet = mtf.get_alphabet()
        res = []
        for ind in encoded:
            if not 0 <= ind < len(alphabet):
                raise IndexError(f'mtf index out of range: {ind}')
            c = alphabet[ind]
            res.append(c)
            mtf.update_alphabet(alphabet, ind, c)
        return ''.join(res)


In [62]:
# mtf's alphabet only covers chr(0)..'z', so the Cyrillic word in the trailing
# comment would raise ValueError in mtf.encode; an ASCII stand-in is used.
eng_example = 'BanBanaxxxxxxxxnana#'#'БанБананана#'
mtf_encoded = mtf.encode(eng_example)
mtf.decode(mtf_encoded)

'BanBanaxxxxxxxxnana#'

In [63]:
import random
import string


def randomword(length):
    letters = string.ascii_lowercase
    return ''.join(random.choice(letters) for i in range(length))


def test_bwt_encode_decode(n_trials=1000, word_length=10):
    """BWT round trip on random lowercase words terminated by SHARP.

    On failure the assertion message shows the offending word, so failures of
    this randomized check can be reproduced.
    """
    for _ in range(n_trials):
        w = randomword(word_length) + SHARP
        decoded = BWT.decode(BWT.encode(w))
        assert decoded == w, f'BWT round trip failed for {w!r}: got {decoded!r}'


def test_mtf(n_trials=1000, word_length=10):
    """MTF round trip on random lowercase words terminated by SHARP.

    On failure the assertion message shows the offending word.
    """
    for _ in range(n_trials):
        w = randomword(word_length) + SHARP
        decoded = mtf.decode(mtf.encode(w))
        assert decoded == w, f'MTF round trip failed for {w!r}: got {decoded!r}'

# Run both randomized round-trip checks, then round-trip a Cyrillic word through
# the BWT; unlike mtf, the suffix-array based BWT maps whatever characters occur.
test_bwt_encode_decode()
test_mtf()
ban = 'БанБананана#'
BWT.decode(BWT.encode(ban))
# Note: without the trailing SHARP, BWT.decode has no terminator to start from.
# BWT.encode('БанБананана')


'БанБананана#'

In [9]:
from itertools import permutations


def test_permutation():
    """Applying a permutation and then its inverse must restore the input.

    Checks the first 100 permutations (in itertools order) of the 6 positions.
    """
    letters = list('abbced')
    for count, perm in enumerate(permutations(range(len(letters)))):
        if count >= 100:
            break
        shuffled = apply_permutation(letters, perm)
        assert letters == apply_permutation(shuffled, inverse_permutation(perm))


test_permutation()
