In [31]:
from collections import defaultdict
# Small hand-written set of DNA strings; passed to train_tokenizer further
# down in the notebook as a toy training corpus.
sequences = ["ACTCGATCGACTCAG",
             "ACGACTCGACTACGCAGACAT",
             "ACGACTGTATATTAGCGACTA",
             "ACTACGCATCATGCACGACTA",
             "TATATCATTCTCCTTACTCTCAG",
             "GAGTAGTGCATCGTAGCTA"]



def kmerize(seq, k = 6):
    """Split ``seq`` into consecutive, non-overlapping chunks of length ``k``.

    The final chunk is kept even when it is shorter than ``k``.

    Args:
        seq: Sliceable sequence, e.g. a DNA string.
        k: Chunk length; must be at least 1.

    Returns:
        List of slices of ``seq`` in their original order; empty when
        ``seq`` is empty.

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """
    if k < 1:
        # A negative step makes range() empty, which would silently yield no
        # k-mers at all; a zero step fails with an unrelated range() message.
        raise ValueError(f"k must be a positive integer, got {k!r}")
    return [seq[start:start + k] for start in range(0, len(seq), k)]


def get_word_freqs(sequences, k = 6):
    """Count how often each k-mer occurs across all sequences.

    Args:
        sequences: Iterable of sequences; each one is split with ``kmerize``.
        k: k-mer length forwarded to ``kmerize``.

    Returns:
        ``defaultdict(int)`` mapping every k-mer to its number of
        occurrences, in first-seen order.
    """
    word_freqs = defaultdict(int)
    for seq in sequences:
        for token in kmerize(seq, k = k):
            word_freqs[token] += 1
    # The table is returned without being printed: on a real genome it holds
    # thousands of entries and would flood the cell output.
    return word_freqs
    
def get_alphabet(word_freqs):
    """Return the sorted list of distinct symbols that occur in the k-mers.

    Args:
        word_freqs: Mapping whose keys are the k-mers; the values are not
            read.

    Returns:
        Sorted list with one entry per distinct symbol; empty when there are
        no k-mers.
    """
    # A set de-duplicates in constant time per symbol instead of scanning a
    # growing list for every character.
    letters = set()
    for word in word_freqs.keys():
        letters.update(word)
    return sorted(letters)

def compute_pair_freqs(splits, word_freqs):
    """Tally adjacent symbol pairs, weighting each by its word's frequency.

    ``splits`` maps every word to its current list of symbols and
    ``word_freqs`` maps every word to its count. The result is a
    ``defaultdict(int)`` keyed by ``(left, right)`` tuples.
    """
    totals = defaultdict(int)
    for word, weight in word_freqs.items():
        symbols = splits[word]
        # zip() produces nothing for a single-symbol word, so such words
        # contribute no pairs.
        for left, right in zip(symbols, symbols[1:]):
            totals[(left, right)] += weight
    return totals


def merge_pair(a, b, splits, word_freqs):
    """Fuse every adjacent ``a``, ``b`` occurrence into the token ``a + b``.

    Each word listed in ``word_freqs`` has its entry in ``splits`` replaced
    by the merged symbol list. ``splits`` is updated in place and also
    returned.
    """
    fused = a + b
    for word in word_freqs:
        pieces = splits[word]
        if len(pieces) == 1:
            # Nothing to pair up in a word that is already one symbol.
            continue

        pos = 0
        while pos + 1 < len(pieces):
            if (pieces[pos], pieces[pos + 1]) == (a, b):
                # Stay on the same position: the list just got shorter.
                pieces = pieces[:pos] + [fused] + pieces[pos + 2:]
            else:
                pos += 1
        splits[word] = pieces
    return splits

def tokenize(text, merges, k = 6):
    """Tokenize ``text`` by replaying learned merges on each of its k-mers.

    Args:
        text: Sequence to tokenize.
        merges: Mapping ``(left, right) -> merged`` as returned by
            ``train_tokenizer``; the merges are applied in the mapping's
            iteration order.
        k: k-mer length used to pre-split ``text``. Pass the same value the
            merges were trained with; the default of 6 matches the default
            of ``kmerize``.

    Returns:
        Flat list of tokens in the order they occur in ``text``.
    """
    # Every k-mer starts out as a list of single characters.
    splits = [list(word) for word in kmerize(text, k = k)]
    for pair, merge in merges.items():
        for idx, split in enumerate(splits):
            i = 0
            while i < len(split) - 1:
                if split[i] == pair[0] and split[i + 1] == pair[1]:
                    # Do not advance: the list shrank by one element.
                    split = split[:i] + [merge] + split[i + 2 :]
                else:
                    i += 1
            splits[idx] = split

    # Flatten in one pass; sum(splits, []) would re-copy the accumulated list
    # once per k-mer, which is quadratic on genome-sized input.
    return [token for split in splits for token in split]


def train_tokenizer(corpus, vocab_size = 100, k = 6):
    """Learn pair merges over the k-mers of ``corpus``.

    Starting from the single-character alphabet, the most frequent adjacent
    pair of symbols is merged repeatedly. Training stops when the vocabulary
    reaches ``vocab_size`` or when there are no adjacent pairs left to merge.

    Args:
        corpus: Iterable of sequences to train on.
        vocab_size: Target number of vocabulary entries (alphabet symbols
            plus merged tokens).
        k: k-mer length used to pre-split every sequence.

    Returns:
        dict mapping each merged pair ``(left, right)`` to the merged token,
        in the order the merges were learned.
    """
    word_freqs = get_word_freqs(corpus, k = k)
    vocab = get_alphabet(word_freqs)

    merges = dict()
    splits = {word: [c for c in word] for word in word_freqs.keys()}

    while len(vocab) < vocab_size:
        pair_freqs = compute_pair_freqs(splits, word_freqs)
        if not pair_freqs:
            # Every k-mer has collapsed into a single token, so there is
            # nothing left to merge even though vocab_size was not reached.
            break
        # max() returns the first pair with the highest count, so ties keep
        # going to the pair that was encountered first.
        best_pair = max(pair_freqs, key=pair_freqs.get)
        splits = merge_pair(*best_pair, splits, word_freqs)
        merged = best_pair[0] + best_pair[1]
        merges[best_pair] = merged
        vocab.append(merged)

    return merges
    

In [1]:
from collections import Counter
from typing import Dict, List, Tuple, Iterable

def kmerize(seq: str, k: int = 6) -> List[str]:
    """Split ``seq`` into consecutive, non-overlapping k-mers.

    The last chunk is kept even when it is shorter than ``k``.

    Args:
        seq: Sequence to split.
        k: k-mer length; must be at least 1.

    Returns:
        The chunks of ``seq`` in their original order; empty when ``seq`` is
        empty.

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """
    if k < 1:
        # A negative step makes range() empty, which would silently yield no
        # k-mers at all; a zero step fails with an unrelated range() message.
        raise ValueError(f"k must be a positive integer, got {k!r}")
    return [seq[i:i + k] for i in range(0, len(seq), k)]

def get_word_freqs(sequences: Iterable[str], k: int = 6) -> Dict[str, int]:
    """Count k-mer occurrences over every sequence.

    Each item of ``sequences`` is converted with ``str()`` and cut into
    k-mers by ``kmerize``. The result is a plain ``dict`` mapping
    k-mer -> count, in first-seen order.
    """
    counts = Counter(
        kmer
        for seq in sequences
        for kmer in kmerize(str(seq), k=k)
    )
    return dict(counts)

def get_alphabet(word_freqs: Dict[str, int]) -> List[str]:
    """Return every distinct character used by the tokens, sorted."""
    seen = set()
    for word in word_freqs:
        # update() adds each character of the token individually.
        seen.update(word)
    alphabet = list(seen)
    alphabet.sort()
    return alphabet

def _init_splits(word_freqs: Dict[str, int]) -> Dict[str, List[str]]:
    """Build the starting segmentation: each token broken into single characters."""
    initial: Dict[str, List[str]] = {}
    for word in word_freqs:
        initial[word] = [ch for ch in word]
    return initial

def compute_pair_freqs(
    splits: Dict[str, List[str]],
    word_freqs: Dict[str, int],
) -> Counter:
    """Count adjacent piece pairs, each weighted by its token's frequency.

    Tokens with a frequency of zero add nothing, and tokens whose split has
    fewer than two pieces have no pairs to count.
    """
    totals = Counter()
    for word, weight in word_freqs.items():
        pieces = splits[word]
        if weight == 0:
            continue
        # range() is empty when fewer than two pieces remain.
        for idx in range(len(pieces) - 1):
            totals[(pieces[idx], pieces[idx + 1])] += weight
    return totals

def merge_pair(
    a: str,
    b: str,
    splits: Dict[str, List[str]],
    word_freqs: Dict[str, int],
) -> None:
    """Replace each adjacent ``a``, ``b`` in every split with ``a + b``.

    ``splits`` is modified in place for the tokens listed in ``word_freqs``;
    nothing is returned.
    """
    fused = a + b
    for word in word_freqs:
        pieces = splits[word]
        if len(pieces) < 2:
            continue
        rebuilt = []
        pos = 0
        last = len(pieces) - 1
        while pos <= last:
            hit = pos < last and pieces[pos] == a and pieces[pos + 1] == b
            rebuilt.append(fused if hit else pieces[pos])
            # A merge consumes two pieces, anything else consumes one.
            pos += 2 if hit else 1
        splits[word] = rebuilt

def train_tokenizer(
    corpus: Iterable[str],
    vocab_size: int = 100,
    k: int = 6,
):
    """Learn pair merges from the k-mers of ``corpus``.

    The vocabulary starts as the character alphabet. On every round the most
    frequent adjacent pair is merged, until the vocabulary holds
    ``vocab_size`` entries or no pairs remain. Returns a dict mapping
    ``(a, b)`` to ``"ab"`` in the order the merges were learned.
    """
    word_freqs = get_word_freqs(corpus, k=k)
    vocab = get_alphabet(word_freqs)
    splits = _init_splits(word_freqs)
    merges: Dict[Tuple[str, str], str] = {}  # dict keeps the learned order

    while len(vocab) < vocab_size:
        pair_freqs = compute_pair_freqs(splits, word_freqs)
        if not pair_freqs:
            break
        # First pair with the highest count wins ties.
        top_pair, _count = max(pair_freqs.items(), key=lambda item: item[1])
        left, right = top_pair
        joined = left + right
        merges[(left, right)] = joined
        vocab.append(joined)
        merge_pair(left, right, splits, word_freqs)

    return merges

def tokenize(
    text: str,
    merges: Dict[Tuple[str, str], str],
    k: int = 6,
) -> List[str]:
    """Encode ``text`` with a learned merge table.

    ``text`` is cut into k-mers, each k-mer starts as a list of single
    characters, and the merges are replayed on it in the mapping's iteration
    order. The pieces of all k-mers are returned as one flat list.
    """
    encoded: List[str] = []
    # k-mers never interact, so each one is taken through the whole merge
    # table before moving on to the next.
    for word in kmerize(text, k=k):
        pieces = list(word)
        for (left, right), fused in merges.items():
            if len(pieces) < 2:
                # A single piece has no neighbour left to merge with.
                break
            rebuilt = []
            pos = 0
            last = len(pieces) - 1
            while pos <= last:
                if pos < last and pieces[pos] == left and pieces[pos + 1] == right:
                    rebuilt.append(fused)
                    pos += 2
                else:
                    rebuilt.append(pieces[pos])
                    pos += 1
            pieces = rebuilt
        encoded.extend(pieces)
    return encoded

In [53]:
# Train on the toy `sequences` corpus with 12-character k-mers and
# vocab_size=10, then show the learned merges (train_tokenizer returns the
# merge table, so that is what `tokenizer` holds here).
tokenizer = train_tokenizer(sequences, vocab_size=10, k = 12)
print(tokenizer)

{('A', 'C'): 'AC', ('T', 'A'): 'TA', ('T', 'C'): 'TC', ('G', 'AC'): 'GAC', ('G', 'C'): 'GC', ('A', 'TC'): 'ATC'}


In [None]:

#
#splits = merge_pair('A','C', splits)


TypeError: tokenize() missing 2 required positional arguments: 'vocab' and 'merges'

In [1]:
from Bio import SeqIO
import dna_bpe as bpe

# Read the FASTA file and keep the sequence of every record in `corpus`.
# NOTE(review): the path below is an absolute, machine-specific location, so
# this cell cannot run elsewhere as written - consider deriving it from a
# configurable data directory.
corpus = []
with open("/Users/torbjornbak/Library/CloudStorage/OneDrive-Personal/Master Thesis/git/ML-bacterial-phenotyping/data/28901.2926.fna") as handle:
    for record in SeqIO.parse(handle, "fasta"):
        # record.seq is appended as-is (not converted to a plain string).
        corpus.append(record.seq)

In [2]:
# Train with the dna_bpe module on the loaded corpus. This call unpacks two
# return values, unlike the in-notebook train_tokenizer above, which returns
# a single merge table. Presumably `tokenizer` is the merge table (it is
# passed as `merges=` below) and `vocab` the token list - confirm in dna_bpe.
tokenizer, vocab = bpe.train_tokenizer(corpus, vocab_size=1000, k = 6)



In [3]:
# Order the vocabulary in place, then number each token by its sorted position.
vocab.sort()
vocab_dict = dict(zip(vocab, range(len(vocab))))

In [None]:
# Tokenize a short lowercase test string under the id "1". With vocab_dict
# supplied, the recorded output below shows integer ids keyed by that id.
seq = bpe.tokenize(id = "1", text="tagatagaagtagcgcatcgcat", merges=tokenizer, vocab_dict=vocab_dict, k = 6)



In [6]:
# Show the tokenization result of the short test string.
print(seq)


{'1': [788, 162, 512, 767, 616, 375, 161]}


In [6]:
# Tokenize the first corpus record under the id "1". No vocab_dict is passed
# here; a later cell maps the entries of tokenized_sequence["1"] through
# vocab_dict, so they are presumably token strings - confirm in dna_bpe.
tokenized_sequence = bpe.tokenize("1", corpus[0], tokenizer)

In [7]:
# Compare the raw sequence with its tokenization.
print(f'{len(corpus[0])=}')
# NOTE(review): the recorded output for this line is 1, and a later cell
# indexes tokenized_sequence["1"], so this appears to measure the number of
# ids in the result rather than the number of tokens - confirm and consider
# len(tokenized_sequence["1"]) instead.
print(f'{len(tokenized_sequence)=}')
# NOTE(review): prints the entire sequence (408539 characters according to
# the recorded length); a short slice would keep the output readable.
print(corpus[0])

# NOTE(review): same caveat as above - set() over the result iterates its
# keys, not the tokens of the sequence.
print(f'{len(set(tokenized_sequence))=}')
print(f'{tokenized_sequence=}')



len(corpus[0])=408539
len(tokenized_sequence)=1
ttgataggccgggtgtgtaagcgcagcgatgcgttgagctaaccggtactaatgaaccgtgaggcttaaccttacaacgccgaagatgttttggcggatgagagacgattttcagcactgattcagagttgagtacgcaataatttgcgcagcagcaaggcggcaagcgaaggaaaggaaggcgcatacagaagtatgtgactgactttgcgagcgcaggcaacgccgctggtgcgataaagaattgcgtacagagcacaaagaatttgcctggcggcagtagcgcggtggtcccacctgaccccatgccgaactcagaagtgaaacgccgtagcgccgatggtagtgtggggtctccccatgcgagagtagggaactgccaggcatcaaataaaacgaagggccctgtcgaaagacagggccttttgttttatctgtagtctgtcggtgagcactcgttgcgcagcaacggcccgtagggtggcgggcaggacgcccgccataaactgccaggcatcagacaagtgaagaagcccgtccgtcaggatgggcttttttgcgtgtgtagtaactgtaaaaatgtaggcctgataagcgcagcgccatcaggcaacctgcaccgccggatggcagcttcgccttatccggcccacagcaattattccgcctacgaagctggtgctcttgtagaactgataagagcacaacccacccaatcgttaacgatgaactatcctttacacgtgcttatataagcagtgaggattttcattggctatcaaaccttttaactaccagcaagacttttccagcattgacttccgccagcagcctgaattgtatcaggttggacgaggcgagcagggggtgctactggttgaaccctacaaaagcgaaattcttcctttctggcgctataaagatgaagcatcggcgatgaaatccgcagaac

In [8]:
# Translate every token of sequence "1" into its index in the sorted vocabulary.
print([vocab_dict[s] for s in tokenized_sequence["1"]])

[951, 162, 681, 665, 742, 907, 59, 140, 644, 129, 768, 344, 457, 845, 1, 347, 525, 646, 768, 351, 780, 37, 582, 539, 914, 961, 539, 124, 124, 104, 989, 556, 845, 229, 124, 952, 727, 550, 69, 958, 132, 64, 552, 550, 503, 666, 63, 1, 690, 780, 125, 159, 207, 852, 845, 82, 980, 360, 138, 280, 108, 488, 369, 126, 914, 600, 780, 124, 130, 247, 125, 240, 484, 557, 767, 622, 886, 816, 256, 845, 82, 316, 870, 82, 810, 55, 846, 47, 730, 326, 207, 729, 845, 886, 715, 461, 315, 161, 601, 501, 159, 124, 665, 52, 574, 290, 1, 164, 360, 63, 578, 845, 826, 2, 517, 152, 599, 913, 181, 907, 160, 845, 835, 845, 129, 118, 452, 560, 103, 569, 441, 152, 749, 678, 105, 299, 569, 162, 53, 280, 810, 82, 55, 846, 126, 550, 326, 816, 734, 218, 550, 999, 600, 907, 729, 767, 121, 767, 15, 788, 569, 860, 615, 387, 280, 101, 96, 665, 218, 276, 395, 194, 315, 83, 552, 914, 221, 326, 351, 103, 62, 745, 802, 951, 788, 37, 846, 768, 525, 255, 37, 299, 94, 1, 838, 768, 361, 845, 52, 181, 493, 780, 103, 865, 915, 792, 55