In [1]:
import pandas as pd
import numpy as np
import torch
from onehotencoder import OneHotEncoder
from typing import List
from collections import Counter

In [2]:
def count_sequence_lengths(
    seq_filepath: str,
    token_list: list[str],
    max_len: int = 57,
    allow_unknown: bool = False
) -> dict[int, int]:
    """
    Histogram of token counts for the sequences in a text file.

    Each line has its surrounding whitespace stripped (blank lines are
    skipped) and is split by greedy longest-match against ``token_list``.

    Args:
        seq_filepath: Path to a UTF-8 text file with one sequence per line.
        token_list: Token vocabulary. Empty strings are ignored.
        max_len: Largest token count reported; sequences with more tokens
            are not counted.
        allow_unknown: If True, a character at which no token starts is
            counted as a single-character token instead of raising.

    Returns:
        Dict with a key for every length 1..max_len (inclusive) mapping to
        the number of sequences with exactly that many tokens (0 if none).

    Raises:
        ValueError: If ``allow_unknown`` is False and a line contains a
            position at which no token starts.
    """
    # Longest tokens first, so the first hit at a position is the longest
    # match. Empty tokens are dropped: they always "match" without advancing,
    # which would loop forever on an otherwise-unmatched character.
    tokens_sorted = sorted((tok for tok in token_list if tok), key=len, reverse=True)
    counts = Counter()

    with open(seq_filepath, 'r', encoding='utf-8') as f:
        for lineno, line in enumerate(f, 1):
            # strip() rather than rstrip('\n'): stray spaces/tabs are not
            # tokens and would otherwise raise or be counted as tokens.
            seq = line.strip()
            if not seq:
                continue
            i = 0
            n_tokens = 0
            while i < len(seq):
                for tok in tokens_sorted:
                    if seq.startswith(tok, i):
                        i += len(tok)
                        break
                else:
                    if not allow_unknown:
                        raise ValueError(
                            f"Unknown token at line {lineno}, position {i}: {seq[i:]!r}"
                        )
                    # Fallback: treat the unmatched character as one token.
                    i += 1
                n_tokens += 1

            if 1 <= n_tokens <= max_len:
                counts[n_tokens] += 1

    # Report every length 1..max_len, including those with no sequences.
    return {length: counts.get(length, 0) for length in range(1, max_len + 1)}


import time
from typing import List, Dict, Optional

class TokenTrieNode:
    """
    A single node of a character trie over the token vocabulary.

    ``children`` maps the next character to its child node. ``token_end``
    holds the complete token string when the path from the root to this
    node spells a token, and is ``None`` otherwise.
    """
    # Fixed attribute set; instances get no per-instance __dict__.
    __slots__ = ("children", "token_end")

    def __init__(self):
        self.token_end: Optional[str] = None
        self.children: Dict[str, TokenTrieNode] = {}

def build_token_trie(tokens: List[str]) -> TokenTrieNode:
    """
    Insert every token into a fresh character trie and return its root.

    Intermediate nodes are created on demand; the node reached by spelling
    a token from the root gets ``token_end`` set to that token.
    """
    root = TokenTrieNode()
    for token in tokens:
        cursor = root
        for char in token:
            child = cursor.children.get(char)
            if child is None:
                child = TokenTrieNode()
                cursor.children[char] = child
            cursor = child
        cursor.token_end = token
    return root

def tokenize_sequence(seq: str, trie: TokenTrieNode, allow_unknown: bool=False) -> List[str]:
    """
    Split ``seq`` into tokens by greedy longest match against ``trie``.

    At each position the trie is descended as deep as the input allows, and
    the longest complete token seen during that descent is emitted.

    Args:
        seq: String to tokenize.
        trie: Root node, as returned by ``build_token_trie``.
        allow_unknown: When no token starts at a position, emit that single
            character as its own token instead of raising.

    Returns:
        The tokens of ``seq`` in order.

    Raises:
        ValueError: If no token starts at some position and ``allow_unknown``
            is False.
    """
    out: List[str] = []
    pos = 0
    length = len(seq)
    while pos < length:
        best: Optional[str] = None
        best_end = pos
        node = trie
        cur = pos
        # Descend while the next character continues some token, remembering
        # the last (i.e. longest) complete token passed on the way down.
        while cur < length:
            nxt = node.children.get(seq[cur])
            if nxt is None:
                break
            node = nxt
            cur += 1
            if node.token_end:
                best, best_end = node.token_end, cur

        if best:
            out.append(best)
            pos = best_end
            continue

        if not allow_unknown:
            # Callers may catch this to skip bad lines.
            raise ValueError(f"Unknown token at pos {pos} of {seq!r}")
        out.append(seq[pos])
        pos += 1
    return out

def filter_sequences_by_token_length(
    input_path: str,
    token_list: List[str],
    target_len: int,
    output_path: str,
    allow_unknown: bool = False
) -> None:
    """
    Copy the sequences whose token count equals ``target_len`` to a new file.

    Reads raw sequences (one per line), strips surrounding whitespace, skips
    blank lines, tokenizes each line by greedy longest match against
    ``token_list``, and writes the matching (stripped) lines to
    ``output_path``, one per line. Lines that cannot be tokenized (unknown
    token with ``allow_unknown=False``) are skipped, and the number skipped
    is included in the printed summary.

    Args:
        input_path: UTF-8 text file with one sequence per line.
        token_list: Token vocabulary.
        target_len: Exact number of tokens a sequence must have to be kept.
        output_path: Destination file; overwritten if it exists.
        allow_unknown: Passed through to ``tokenize_sequence``; when True an
            unmatched character becomes its own token instead of failing.
    """
    trie = build_token_trie(token_list)
    processed = 0
    matched = 0
    skipped = 0
    # perf_counter is monotonic and meant for measuring elapsed time.
    start = time.perf_counter()

    with open(input_path, "r", encoding="utf-8") as fin, \
         open(output_path, "w", encoding="utf-8") as fout:
        for line in fin:
            seq = line.strip()
            if not seq:
                continue
            processed += 1
            try:
                toks = tokenize_sequence(seq, trie, allow_unknown)
            except ValueError:
                # Best-effort: drop untokenizable lines, but count them so
                # the loss is visible in the summary.
                skipped += 1
                continue

            if len(toks) == target_len:
                fout.write(seq + "\n")
                matched += 1

    elapsed = time.perf_counter() - start
    # Summary is unchanged from before when no lines were skipped.
    skipped_note = f" ({skipped} lines skipped: could not be tokenized)" if skipped else ""
    print(
        f"Processed {processed} lines in {elapsed:.2f}s, "
        f"wrote {matched} sequences of token‐length {target_len} to {output_path}"
        f"{skipped_note}"
    )

In [None]:
# Histogram of token counts (lengths 1..57) over the training sequences.
filepath = "data/train.csv"
token_list = ['Br', 'N', ')', 'c', 'o', '6', 's', 'Cl', '=', '2', ']', 'C', 'n', 'O', '4', '1', '#', 'S', 'F', '3', '[', '5', 'H', '(', '-', '[BOS]', '[EOS]', '[PAD]']
# NOTE(review): not used in this cell; kept in case later cells rely on it.
valid_tokens = set(token_list)

# Raises ValueError on the first line containing an unknown token.
length_counts = count_sequence_lengths(
    seq_filepath=filepath,
    token_list=token_list,
    max_len=57,
    allow_unknown=False,
)

# One line per length, including lengths with a zero count.
for length, cnt in length_counts.items():
    print(f"Length {length:2d}: {cnt}")

In [10]:
# Keep only sequences that tokenize to exactly `target_length` tokens.
target_length = 26
# Derived from target_length so the file name cannot drift from the length.
out_file = f"data/seqs_len{target_length}.txt"
filepath = "data/train.csv"
token_list = ['Br', 'N', ')', 'c', 'o', '6', 's', 'Cl', '=', '2', ']', 'C', 'n', 'O', '4', '1', '#', 'S', 'F', '3', '[', '5', 'H', '(', '-', '[BOS]', '[EOS]', '[PAD]']

filter_sequences_by_token_length(
    input_path=filepath,
    token_list=token_list,
    target_len=target_length,
    output_path=out_file,
    allow_unknown=False
)

Processed 1584663 lines in 5.99s, wrote 101709 sequences of token‐length 40 to data/seqs_len40.txt


In [14]:
# Write one file per token length, for lengths 26..45 inclusive.
# NOTE: every call re-reads and re-tokenizes the whole input file, so this
# cell makes one full pass over the data per length.
filepath = "data/train.csv"
token_list = ['Br', 'N', ')', 'c', 'o', '6', 's', 'Cl', '=', '2', ']', 'C', 'n', 'O', '4', '1', '#', 'S', 'F', '3', '[', '5', 'H', '(', '-', '[BOS]', '[EOS]', '[PAD]']
for i in range(26, 46):
    target_length = i
    out_file = f"data/seqs_len{i}.txt"
    filter_sequences_by_token_length(
        input_path=filepath,
        token_list=token_list,
        target_len=target_length,
        output_path=out_file,
        allow_unknown=False
    )

Processed 1584663 lines in 5.93s, wrote 15152 sequences of token‐length 26 to data/seqs_len26.txt
Processed 1584663 lines in 5.97s, wrote 25381 sequences of token‐length 27 to data/seqs_len27.txt
Processed 1584663 lines in 5.99s, wrote 38700 sequences of token‐length 28 to data/seqs_len28.txt
Processed 1584663 lines in 6.03s, wrote 53108 sequences of token‐length 29 to data/seqs_len29.txt
Processed 1584663 lines in 5.96s, wrote 71316 sequences of token‐length 30 to data/seqs_len30.txt
Processed 1584663 lines in 6.06s, wrote 84954 sequences of token‐length 31 to data/seqs_len31.txt
Processed 1584663 lines in 6.16s, wrote 95481 sequences of token‐length 32 to data/seqs_len32.txt
Processed 1584663 lines in 6.09s, wrote 106526 sequences of token‐length 33 to data/seqs_len33.txt
Processed 1584663 lines in 6.08s, wrote 113634 sequences of token‐length 34 to data/seqs_len34.txt
Processed 1584663 lines in 6.10s, wrote 120363 sequences of token‐length 35 to data/seqs_len35.txt
Processed 1584663