In [2]:
import pandas as pd
import numpy as np
import torch
from onehotencoder import OneHotEncoder
from typing import List
from collections import Counter

In [3]:
def count_sequence_lengths(
    seq_filepath: str,
    token_list: list[str],
    max_len: int = 57,
    allow_unknown: bool = False
) -> dict[int, int]:

    # Sort tokens by descending length so we always match the longest possible first
    tokens_sorted = sorted(token_list, key=len, reverse=True)
    counts = Counter()

    with open(seq_filepath, 'r', encoding='utf-8') as f:
        for lineno, line in enumerate(f, 1):
            seq = line.rstrip('\n')
            i = 0
            tokenized = []
            while i < len(seq):
                for tok in tokens_sorted:
                    if seq.startswith(tok, i):
                        tokenized.append(tok)
                        i += len(tok)
                        break
                else:
                    if allow_unknown:
                        # emit single-char as fallback
                        tokenized.append(seq[i])
                        i += 1
                    else:
                        raise ValueError(
                            f"Unknown token at line {lineno}, position {i}: {seq[i:]!r}"
                        )

            L = len(tokenized)
            if 1 <= L <= max_len:
                counts[L] += 1

    # make sure every length from 1..max_len is present
    return {length: counts.get(length, 0) for length in range(1, max_len + 1)}


import time
from typing import List, Dict, Optional

class TokenTrieNode:
    __slots__ = ("children", "token_end")
    def __init__(self):
        self.children: Dict[str, TokenTrieNode] = {}
        self.token_end: Optional[str] = None

def build_token_trie(tokens: List[str]) -> TokenTrieNode:
    root = TokenTrieNode()
    for tok in tokens:
        node = root
        for ch in tok:
            node = node.children.setdefault(ch, TokenTrieNode())
        node.token_end = tok
    return root

def tokenize_sequence(seq: str, trie: TokenTrieNode, allow_unknown: bool=False) -> List[str]:
    """
    Greedy longest‐match tokenization using the trie.
    """
    tokens = []
    i = 0
    n = len(seq)
    while i < n:
        node = trie
        last_match: Optional[str] = None
        last_pos = i
        j = i
        # walk as far as possible in the trie
        while j < n and seq[j] in node.children:
            node = node.children[seq[j]]
            j += 1
            if node.token_end:
                last_match = node.token_end
                last_pos = j
        if last_match:
            tokens.append(last_match)
            i = last_pos
        else:
            if allow_unknown:
                tokens.append(seq[i])
                i += 1
            else:
                # you could also `continue` here to skip bad lines
                raise ValueError(f"Unknown token at pos {i} of {seq!r}")
    return tokens

def filter_sequences_by_token_length(
    input_path: str,
    token_list: List[str],
    target_len: int,
    output_path: str,
    allow_unknown: bool = False
) -> None:
    """
    Reads raw sequences (one per line), tokenizes each, and writes only
    those whose token-count == target_len into output_path.
    """
    trie = build_token_trie(token_list)
    processed = 0
    matched = 0
    start = time.time()

    with open(input_path, "r", encoding="utf-8") as fin, \
         open(output_path, "w", encoding="utf-8") as fout:
        for line in fin:
            seq = line.strip()
            if not seq:
                continue
            processed += 1
            try:
                toks = tokenize_sequence(seq, trie, allow_unknown)
            except ValueError as e:
                # you can log or skip; here we skip any bad lines
                # print(f"Skipping line {processed}: {e}")
                continue

            if len(toks) == target_len:
                fout.write(seq + "\n")
                matched += 1

    elapsed = time.time() - start
    print(
        f"Processed {processed} lines in {elapsed:.2f}s, "
        f"wrote {matched} sequences of token‐length {target_len} to {output_path}"
    )

In [4]:
filepath = "data/train.csv"
token_list = ['Br', 'N', ')', 'c', 'o', '6', 's', 'Cl', '=', '2', ']', 'C', 'n', 'O', '4', '1', '#', 'S', 'F', '3', '[', '5', 'H', '(', '-', '[BOS]', '[EOS]', '[PAD]']
valid_tokens = set(token_list)

length_counts = count_sequence_lengths(
        seq_filepath=filepath,
        token_list=token_list,
        max_len=57,
        allow_unknown=False
    )

for length, cnt in length_counts.items():
    print(f"Length {length:2d}: {cnt}")

Length  1: 0
Length  2: 0
Length  3: 0
Length  4: 0
Length  5: 0
Length  6: 0
Length  7: 0
Length  8: 0
Length  9: 0
Length 10: 0
Length 11: 0
Length 12: 0
Length 13: 21
Length 14: 13
Length 15: 19
Length 16: 72
Length 17: 63
Length 18: 138
Length 19: 230
Length 20: 497
Length 21: 1073
Length 22: 1792
Length 23: 3247
Length 24: 5108
Length 25: 9155
Length 26: 15152
Length 27: 25381
Length 28: 38700
Length 29: 53108
Length 30: 71316
Length 31: 84954
Length 32: 95481
Length 33: 106526
Length 34: 113634
Length 35: 120363
Length 36: 124303
Length 37: 126212
Length 38: 124680
Length 39: 113970
Length 40: 101709
Length 41: 83418
Length 42: 60024
Length 43: 43895
Length 44: 25990
Length 45: 15992
Length 46: 8943
Length 47: 4834
Length 48: 2511
Length 49: 1161
Length 50: 557
Length 51: 256
Length 52: 101
Length 53: 39
Length 54: 14
Length 55: 7
Length 56: 2
Length 57: 2


In [7]:
for i in (list(range(18,52))):
    target_length = i
    out_file = f"data/seqs_len{i}.txt"
    filepath = "data/train.csv"
    token_list = ['Br', 'N', ')', 'c', 'o', '6', 's', 'Cl', '=', '2', ']', 'C', 'n', 'O', '4', '1', '#', 'S', 'F', '3', '[', '5', 'H', '(', '-', '[BOS]', '[EOS]', '[PAD]']
    filter_sequences_by_token_length(
    input_path=filepath,
    token_list=token_list,
    target_len=target_length,
    output_path=out_file,
    allow_unknown=False
    )

Processed 1584663 lines in 6.21s, wrote 138 sequences of token‐length 18 to data/seqs_len18.txt
Processed 1584663 lines in 6.28s, wrote 230 sequences of token‐length 19 to data/seqs_len19.txt
Processed 1584663 lines in 6.23s, wrote 497 sequences of token‐length 20 to data/seqs_len20.txt
Processed 1584663 lines in 6.20s, wrote 1073 sequences of token‐length 21 to data/seqs_len21.txt
Processed 1584663 lines in 6.23s, wrote 1792 sequences of token‐length 22 to data/seqs_len22.txt
Processed 1584663 lines in 6.23s, wrote 3247 sequences of token‐length 23 to data/seqs_len23.txt
Processed 1584663 lines in 6.34s, wrote 5108 sequences of token‐length 24 to data/seqs_len24.txt
Processed 1584663 lines in 6.20s, wrote 9155 sequences of token‐length 25 to data/seqs_len25.txt
Processed 1584663 lines in 6.29s, wrote 15152 sequences of token‐length 26 to data/seqs_len26.txt
Processed 1584663 lines in 6.19s, wrote 25381 sequences of token‐length 27 to data/seqs_len27.txt
Processed 1584663 lines in 6.11

In [None]:
criterion = nn.CrossEntropyLoss()

#AdamW


charRNN.train()
#Typical training loop
print(f'Training for {num_epochs} epochs with {len(train_loader)} batches of size {batch_size}, {n_gram}-gram encoding, and {warmup_steps} lr warmup steps')
for epoch in range(num_epochs):
    start_time = time.time()
    total_epoch_loss = 0.0
    if epoch < anneal_epochs:
        current_beta = b_start + (b_end - b_start) * (epoch / anneal_epochs)
    else:
        current_beta = b_end
    for idx, (batch_inputs, batch_targets) in enumerate(train_loader):
        print(f'Batch {idx + 1}/{len(train_loader)}', end='\r')
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.squeeze(2).to(device)
        current_batch_size = batch_inputs.size(0)
        seq_len = batch_inputs.size(1)
        batch_inputs = batch_inputs.view(current_batch_size, seq_len, n_gram * vocab_size).to(device)
        target_indices = torch.argmax(batch_targets, dim=2).long().to(device)

        hidden = charRNN.init_hidden(current_batch_size).to(device)

        logits, hidden = charRNN(batch_inputs, hidden)

        logits_permuted = logits.permute(0, 2, 1)

        reconstruction_loss = criterion(logits_permuted, target_indices)

        loss = reconstruction_loss
        loss.backward()
        nn.utils.clip_grad_norm_(charRNN.parameters(), 5.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        total_epoch_loss += loss.item()

    avg_epoch_loss = total_epoch_loss / len(train_loader)

    end_time = time.time()
    epoch_duration = end_time - start_time
    epoch_duration_minutes = int(epoch_duration // 60)
    epoch_duration_seconds = int(epoch_duration % 60)

    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {avg_epoch_loss}, Time: {epoch_duration_minutes}m {epoch_duration_seconds}s")
