In [2]:
import random
import string
from tqdm import tqdm

# Sequence lengths
MAX_LENGTH = 500  # generate_valid_string() emits strings of exactly this length
OOD_MAX_LENGTH = 1500  # in this cell, only used by the commented-out OOD section

# Vocabulary: start / end markers, padding, and the two payload characters
START_TOKEN = "s"
END_TOKEN = "e"
PADDING_TOKEN = "p"
MAIN_CHARACTERS = ["a", "b"]
VALID_CHARACTERS = [START_TOKEN, *MAIN_CHARACTERS, END_TOKEN, PADDING_TOKEN]

# Class balance
VALID_RATIO = 0.5  # Half of the dataset should be valid a*b* strings


def generate_random_numbers_with_sum(target_sum):
    """Split ``target_sum`` into four non-negative integer parts.

    Three cut points are drawn independently from ``[0, target_sum]``
    (inclusive) and sorted; the parts are the gaps between consecutive
    boundaries ``0, cut1, cut2, cut3, target_sum``.

    Returns:
        A list of four integers, each >= 0, whose sum is ``target_sum``.
    """
    cuts = [random.randint(0, target_sum) for _ in range(3)]
    cuts.sort()
    bounds = [0, *cuts, target_sum]
    # Pair every boundary with its successor and take the difference
    return [upper - lower for lower, upper in zip(bounds, bounds[1:])]


# Function to generate valid a*b* strings
def generate_valid_string():
    """Return a random positive sample of total length ``MAX_LENGTH``.

    Layout: [..p..]s[..a..][..b..]e[..p..] -- left padding, start token,
    a run of "a", a run of "b", end token, right padding. Any of the four
    runs may have length zero.
    """
    # Four run lengths that add up to MAX_LENGTH minus the two marker tokens
    left_pad, a_count, b_count, right_pad = generate_random_numbers_with_sum(
        MAX_LENGTH - 2
    )
    pieces = [
        PADDING_TOKEN * left_pad,
        START_TOKEN,
        "a" * a_count,
        "b" * b_count,
        END_TOKEN,
        PADDING_TOKEN * right_pad,
    ]
    return "".join(pieces)


def is_valid_string(s):
    """Return True if ``s`` has the form [..p..]s[..a..][..b..]e[..p..].

    Every bracketed run may be empty, which matches what
    ``generate_valid_string`` can produce (its "a" and "b" run lengths can
    be zero). The whole string must fit the pattern: anything before the
    leading padding or after the trailing padding makes it invalid.

    Args:
        s: The candidate string.

    Returns:
        bool: True when ``s`` is a well-formed padded a*b* sample.
    """
    import re

    # fullmatch (not match) so trailing garbage such as "sabeab" is rejected;
    # a*b* (not a+b+) so samples with an empty "a" or "b" run count as valid.
    return re.fullmatch(r"p*sa*b*ep*", s) is not None


# Function to generate invalid strings with random content
def generate_invalid_string_1():
    """Return a negative sample whose a/b body is random and NOT in a*b*.

    The result has the same outer layout and total length as a valid sample
    (padding, start token, body, end token, padding), but the body between
    the start and end tokens is guaranteed to contain a "b" followed
    directly by an "a".

    Raises:
        ValueError: if ``MAX_LENGTH`` leaves fewer than 2 body characters,
            in which case no invalid body exists.
    """
    if MAX_LENGTH - 2 < 2:
        raise ValueError("MAX_LENGTH must be at least 4 to build an invalid string")

    # Every a/b string of length 0 or 1 is in a*b*, so resample the layout
    # until the body has room for at least one "ba".
    while True:
        numbers = generate_random_numbers_with_sum(MAX_LENGTH - 2)
        ab_length = numbers[1] + numbers[2]
        if ab_length >= 2:
            break

    while True:
        invalid_str = "".join(random.choices(["a", "b"], k=ab_length))
        # A string over {a, b} is outside a*b* exactly when it contains "ba".
        # The bare body is tested directly: it has no start/end tokens, so a
        # full-sample checker cannot be applied to it.
        if "ba" in invalid_str:
            break

    return (
        PADDING_TOKEN * numbers[0]
        + START_TOKEN
        + invalid_str
        + END_TOKEN
        + PADDING_TOKEN * numbers[3]
    )


# Function to slightly modify valid strings so that they become invalid
def generate_invalid_string_2():
    """Return a negative sample made by flipping characters of a valid one.

    A valid sample is generated, then randomly chosen body characters
    (between the start and end tokens) are flipped a<->b until the body is
    no longer in a*b*. Padding, tokens and total length are left untouched.

    Raises:
        ValueError: if ``MAX_LENGTH`` is below 4, since the body then cannot
            hold the 2 characters an invalid body needs.
    """
    if MAX_LENGTH < 4:
        raise ValueError("MAX_LENGTH must be at least 4 to build an invalid string")

    # A body of length 0 has nothing to flip (randint would get an empty
    # range) and a body of length 1 is valid whatever it holds, so draw
    # valid samples until the body has at least 2 characters.
    while True:
        chars = list(generate_valid_string())
        s_index = chars.index(START_TOKEN)
        e_index = chars.index(END_TOKEN)
        if e_index - s_index - 1 >= 2:
            break

    while True:
        index = random.randint(s_index + 1, e_index - 1)
        chars[index] = "b" if chars[index] == "a" else "a"
        # A body over {a, b} is outside a*b* exactly when it contains "ba";
        # all-"a" and all-"b" bodies are still valid and must not be returned.
        if "ba" in "".join(chars[s_index + 1 : e_index]):
            break

    return "".join(chars)


def generate_invalid_string():
    """Return one negative sample.

    With probability 0.8 the sample comes from ``generate_invalid_string_1``
    (random body); otherwise from ``generate_invalid_string_2`` (perturbed
    valid sample).
    """
    use_random_body = random.random() < 0.8
    maker = generate_invalid_string_1 if use_random_body else generate_invalid_string_2
    return maker()


# Generate dataset
num_samples = 200000  # Total number of samples

# Positive class: (string, 1) pairs
dataset = [(generate_valid_string(), 1) for _ in tqdm(range(num_samples))]

# remove duplicates
valid_dataset = list(set(dataset))

# Negative class: (string, 0) pairs
dataset = [(generate_invalid_string(), 0) for _ in tqdm(range(num_samples))]

# Remove all duplicates, then keep at most as many negatives as positives
invalid_dataset = list(set(dataset))
invalid_dataset = invalid_dataset[: len(valid_dataset)]
print(len(valid_dataset))
print(len(invalid_dataset))

# 80/20 train/test split; the same cut index is applied to both classes
split = (4 * len(valid_dataset)) // 5
train_dataset = valid_dataset[:split] + invalid_dataset[:split]
test_dataset = valid_dataset[split:] + invalid_dataset[split:]

# num_ood_samples = 50000
# dataset = []
# for _ in tqdm(range(num_ood_samples)):
#     while True:
#         x = generate_valid_string(min_length=MAX_LENGTH + 2, max_length=OOD_MAX_LENGTH)
#         if not (len(x) == OOD_MAX_LENGTH):
#             continue
#         dataset.append((x, 1))
#         break

# # remove duplicates
# ood_valid_dataset = list(set(dataset))
# dataset = []

# for _ in tqdm(range(num_ood_samples)):
#     while True:
#         x = generate_invalid_string(
#             min_length=MAX_LENGTH + 2, max_length=OOD_MAX_LENGTH
#         )
#         if not (len(x) == OOD_MAX_LENGTH):
#             continue
#         dataset.append((x, 0))
#         break

# # Remove all duplicates
# ood_invalid_dataset = list(set(dataset))[: len(ood_valid_dataset)]
# print(len(ood_valid_dataset), len(ood_invalid_dataset))

# ood_dataset = ood_valid_dataset + ood_invalid_dataset
# print(len(train_dataset))
# print(len(test_dataset))
# print(len(ood_dataset))


# Write to file: one "<string> <label>" record per line
with open("train_dataset.txt", "w") as f:
    f.writelines(f"{data} {label}\n" for data, label in train_dataset)

with open("test_dataset.txt", "w") as f:
    f.writelines(f"{data} {label}\n" for data, label in test_dataset)

# with open("ood_dataset.txt", "w") as f:
#     for data, label in ood_dataset:
#         f.write(f"{data} {label}\n")

100%|██████████| 200000/200000 [00:00<00:00, 326082.83it/s]
100%|██████████| 200000/200000 [00:03<00:00, 50679.24it/s]


199082
199082
