## Create a BPE vocabulary from DNA sequences

In [None]:
from collections import defaultdict, Counter
import re

In [None]:
class DNABPE:
    """Byte-pair-encoding (BPE) vocabulary learner for DNA sequences.

    Training starts from single-character tokens and repeatedly merges the
    most frequent adjacent token pair until ``vocab_size`` tokens exist or
    no pair occurs at least ``min_frequency`` times.

    Attributes:
        vocab_size: Target number of tokens in the vocabulary.
        min_frequency: Minimum pair count required for a merge.
        vocab: Set of token strings learned so far.
        merges: Mapping ``(left, right) -> merged token``. Insertion order
            is the order in which the merges were learned, and ``tokenize``
            relies on it.
    """

    def __init__(self, vocab_size=4000, min_frequency=3):
        self.vocab_size = vocab_size
        self.min_frequency = min_frequency
        self.vocab = set()
        self.merges = {}

    @staticmethod
    def _read_fasta(fasta_file):
        """Return the sequences of a FASTA file as a list of strings.

        Lines starting with ``>`` delimit records; all other lines are
        stripped and concatenated. Records without sequence data are dropped.
        """
        sequences = []
        parts = []
        with open(fasta_file, 'r') as f:
            for line in f:
                if line.startswith('>'):
                    if parts:
                        sequences.append(''.join(parts))
                    parts = []
                else:
                    chunk = line.strip()
                    if chunk:
                        parts.append(chunk)
        if parts:
            sequences.append(''.join(parts))
        return sequences

    def get_stats(self, sequences):
        """Count adjacent token pairs.

        Args:
            sequences: Iterable of strings whose tokens are separated by
                whitespace.

        Returns:
            A ``defaultdict(int)`` mapping ``(left, right)`` to its count.
        """
        pairs = defaultdict(int)
        for sequence in sequences:
            symbols = sequence.split()
            for left, right in zip(symbols, symbols[1:]):
                pairs[left, right] += 1
        return pairs

    def merge_vocab(self, pair, sequences):
        """Merge every occurrence of ``pair`` in the tokenized sequences.

        Args:
            pair: ``(left, right)`` tuple of adjacent tokens to merge.
            sequences: List of space-separated token strings.

        Returns:
            ``(new_sequences, merged_token)`` where ``merged_token`` is
            ``left + right``.
        """
        first, second = pair
        replacement = first + second
        # The lookarounds anchor the match on token boundaries so that a
        # pair is never matched inside a longer token.
        pattern = re.compile(
            r'(?<!\S)' + re.escape(first + ' ' + second) + r'(?!\S)'
        )
        new_sequences = [pattern.sub(replacement, seq) for seq in sequences]
        return new_sequences, replacement

    def fit(self, fasta_file):
        """Train the BPE vocabulary from a FASTA file.

        Any previously learned vocabulary and merges are discarded.

        Args:
            fasta_file: Path to the FASTA file used for training.

        Returns:
            The training sequences in their final tokenized form (one
            space-separated token string per sequence).
        """
        sequences = self._read_fasta(fasta_file)
        print(f"Loaded {len(sequences)} sequences")

        # Start from the four nucleotides plus any other symbol actually
        # present in the data (e.g. 'N'), so every token that tokenize()
        # can emit is part of the vocabulary.
        self.vocab = {'A', 'T', 'C', 'G'}
        for seq in sequences:
            self.vocab.update(ch for ch in seq if not ch.isspace())
        # Reset merges so that refitting does not keep stale rules.
        self.merges = {}

        # Transform sequences to tokenized form: one token per character.
        tokenized_seqs = [' '.join(seq) for seq in sequences]

        # Training loop.
        num_merges = 0
        while len(self.vocab) < self.vocab_size:
            # stats: { (token1, token2): frequency }
            stats = self.get_stats(tokenized_seqs)
            if not stats:
                break

            # Most frequent pair; ties resolve to the first pair counted.
            pair = max(stats, key=stats.get)
            freq = stats[pair]

            # Stop once even the best pair is rarer than the cutoff.
            if freq < self.min_frequency:
                print(f"Stopping: no pair frequency >= {self.min_frequency}")
                break

            tokenized_seqs, new_token = self.merge_vocab(pair, tokenized_seqs)
            self.vocab.add(new_token)
            self.merges[pair] = new_token
            num_merges += 1

            if num_merges % 100 == 0:
                print(f"Merges: {num_merges}, Vocab size: {len(self.vocab)}, "
                      f"Last merge: {pair} -> {new_token} (freq: {freq})")

        print(f"Training completed. Final vocab size: {len(self.vocab)}")
        print(f"Total merges: {num_merges}")
        return tokenized_seqs

    def tokenize(self, sequence):
        """Tokenize a single sequence with the learned merges.

        Merges are replayed in the order they were learned, each one applied
        left to right over the whole sequence, mirroring what ``fit`` did to
        the training data.

        Args:
            sequence: Raw sequence string.

        Returns:
            List of token strings whose concatenation equals ``sequence``.
        """
        tokens = list(sequence)  # initial tokens are single characters

        # dict preserves insertion order, i.e. the learned merge order.
        for (first, second), merged in self.merges.items():
            if len(tokens) < 2:
                break
            merged_tokens = []
            i = 0
            last = len(tokens) - 1
            while i <= last:
                if i < last and tokens[i] == first and tokens[i + 1] == second:
                    merged_tokens.append(merged)
                    i += 2
                else:
                    merged_tokens.append(tokens[i])
                    i += 1
            tokens = merged_tokens

        return tokens

    def analyze_compression(self, fasta_file, max_sequences=40):
        """Print and return compression statistics for a FASTA file.

        Args:
            fasta_file: Path to the FASTA file to analyze.
            max_sequences: Number of leading sequences to analyze; ``None``
                analyzes all of them.

        Returns:
            Dict with keys ``'vocab_size'``, ``'merges'`` and
            ``'avg_compression_ratio'`` (characters per token over the
            analyzed sequences, or 0 if nothing was analyzed).
        """
        sequences = self._read_fasta(fasta_file)

        print("\n" + "="*50)
        print("Compression Analysis")
        print("="*50)

        total_original_len = 0
        total_token_len = 0

        for i, seq in enumerate(sequences[:max_sequences], start=1):
            tokens = self.tokenize(seq)
            original_len = len(seq)
            token_len = len(tokens)
            # Sequences from _read_fasta are non-empty, so token_len > 0.
            compression_ratio = original_len / token_len

            total_original_len += original_len
            total_token_len += token_len

            print(f"Seq {i}: Original={original_len}, "
                  f"Tokens={token_len}, Compression={compression_ratio:.2f}x")

        # Summary; skipped when nothing was analyzed to avoid dividing by 0.
        avg_compression = 0
        if total_token_len:
            avg_compression = total_original_len / total_token_len
            print(f"\nOverall: Average compression ratio = {avg_compression:.2f}x")
            print(f"Reduced from {total_original_len} to {total_token_len} tokens "
                  f"({total_token_len/total_original_len*100:.1f}% of original)")

        return {
            'vocab_size': len(self.vocab),
            'merges': len(self.merges),
            'avg_compression_ratio': avg_compression
        }

def main(fasta_path="./data/test3.fasta", vocab_path="bpe_vocab.txt"):
    """Train a DNA BPE vocabulary, report compression, and save the tokens.

    Args:
        fasta_path: FASTA file used for both training and the compression
            report. Replace the default with your own FASTA file.
        vocab_path: Text file that receives one token per line.
    """
    # Usage example.
    bpe = DNABPE(vocab_size=4000, min_frequency=5)

    # BPE training.
    bpe.fit(fasta_path)

    # Order by (length, token): sorting a set by length alone leaves
    # equal-length tokens in hash order, which differs between runs.
    vocab_list = sorted(bpe.vocab, key=lambda token: (len(token), token))

    # Show part of the vocabulary.
    print("\nFirst 20 tokens in vocabulary:")
    for i, token in enumerate(vocab_list[:20], start=1):
        print(f"  {i:2d}. '{token}' (length: {len(token)})")

    # Analyze the compression ratio.
    bpe.analyze_compression(fasta_path)

    # Save the vocabulary to file, shortest tokens first.
    with open(vocab_path, "w") as f:
        f.writelines(f"{token}\n" for token in vocab_list)

    print(f"\nVocabulary saved to {vocab_path}")
# Run the example pipeline only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

Loaded 55 sequences
Merges: 100, Vocab size: 104, Last merge: ('G', 'TT') -> GTT (freq: 37)
Merges: 200, Vocab size: 204, Last merge: ('TGGTTACCTGCAACCGGTAAAGTATATCTACCACCATCGACCCCAGTTGCAAGGGTACAAAGCACGGATGAGTACATACAAAGGACTGACATCTTTTACCACGCTAATAGTGATCGGTTACTCACAGTAGGACATCCATACTTTGAGGTTCGAGCCACAACAGAACCATATCAGGTGACAGTACCTAAAGTTAGTGGAAATCAGTTTAGAGCTTTCAGACTTAAATTACCTGATCCTAATAGGTTTGCATTGGTAGATACTACAGTGTATAATCCTGACAAGGAAAGATTAGTC', 'TGGG') -> TGGTTACCTGCAACCGGTAAAGTATATCTACCACCATCGACCCCAGTTGCAAGGGTACAAAGCACGGATGAGTACATACAAAGGACTGACATCTTTTACCACGCTAATAGTGATCGGTTACTCACAGTAGGACATCCATACTTTGAGGTTCGAGCCACAACAGAACCATATCAGGTGACAGTACCTAAAGTTAGTGGAAATCAGTTTAGAGCTTTCAGACTTAAATTACCTGATCCTAATAGGTTTGCATTGGTAGATACTACAGTGTATAATCCTGACAAGGAAAGATTAGTCTGGG (freq: 25)
Merges: 300, Vocab size: 304, Last merge: ('TACTGGGTACATTACCAATTCCAAGGATGACAGACAAGATACATCCTTTGATCCCAAACAGGTACAAATGTTTATAATTGGCTGTACCCCTTGCTGGGGAGAGCATTGGGATATTGCTCCACGCTGTGATGATGATCAACCTATCCAAGGGGCCTGTCCTCCATTAGAATTAAGAAATACTATTATTGAGGATGGCGATATG