In [4]:
import random
import string
from pathlib import Path

from tqdm import tqdm

# Define constants
# Generated samples are exactly MAX_LENGTH characters long (padding + start
# token + body + end token + padding).
MAX_LENGTH = 50
# NOTE(review): VALID_CHARACTERS / MAIN_CHARACTERS list "a" and "b", which look
# left over from an a*b* version of this generator; the samples built in this
# cell use "(" and ")" as body characters instead. Neither list is referenced
# in the code visible here -- confirm against other cells (e.g. if used as a
# model vocabulary) before relying on them.
VALID_CHARACTERS = ["s", "a", "b", "e", "p"]
MAIN_CHARACTERS = ["a", "b"]
START_TOKEN = "s"  # Marks the start of the parenthesis body
END_TOKEN = "e"  # Marks the end of the parenthesis body
PADDING_TOKEN = "p"  # Fills the space before START_TOKEN and after END_TOKEN
# Intended fraction of positive (balanced) samples.
# NOTE(review): not referenced in the visible code; class balance is enforced
# by truncation in the dataset-building code instead.
VALID_RATIO = 0.5


def generate_random_numbers_with_sum(target_sum, num_parts=4):
    """Split ``target_sum`` into ``num_parts`` random positive integers.

    Picks ``num_parts - 1`` distinct cut points uniformly at random from
    ``1 .. target_sum - 1`` and returns the gaps between consecutive cuts,
    using 0 and ``target_sum`` as the outer bounds. Every part is >= 1 and
    the parts sum to ``target_sum``.

    Args:
        target_sum: Integer to split.
        num_parts: Number of parts to return. The default of 4 matches the
            layout the string generators unpack (left padding, two body
            segments, right padding).

    Returns:
        A list of ``num_parts`` positive ints summing to ``target_sum``.

    Raises:
        ValueError: If ``num_parts < 1`` or ``target_sum < num_parts``. No
            split into positive parts exists then; the previous draw-and-retry
            loop spun forever for such inputs.
    """
    if num_parts < 1:
        raise ValueError(f"num_parts must be >= 1, got {num_parts}")
    if target_sum < num_parts:
        raise ValueError(
            f"cannot split {target_sum} into {num_parts} positive parts"
        )
    # random.sample draws distinct cut points directly. Every
    # (num_parts - 1)-subset is equally likely, which is the same distribution
    # the old "draw, reject duplicates, retry" loop produced.
    cuts = sorted(random.sample(range(1, target_sum), num_parts - 1))
    bounds = [0] + cuts + [target_sum]
    return [hi - lo for lo, hi in zip(bounds, bounds[1:])]


def generate_balanced_parentheses(n):
    """Return a random balanced string made of ``n`` pairs of parentheses.

    The string is built recursively. With probability 1/2 it is ``n - 1``
    pairs wrapped in an outer pair. Otherwise it is two independently
    generated balanced strings of ``k`` and ``n - k`` pairs concatenated,
    with ``k`` uniform in ``1 .. n - 1``. The resulting distribution over
    balanced strings is not uniform, and the recursion depth can reach ``n``.

    Args:
        n: Number of "()" pairs; the result has length ``2 * n``.

    Returns:
        A balanced parenthesis string of length ``2 * n`` ("" when n == 0).

    Raises:
        ValueError: If ``n`` is negative.
    """
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")
    # Zero pairs is the empty string; without this base case n == 0 recursed
    # until RecursionError.
    if n == 0:
        return ""
    if n == 1:
        return "()"
    # 50% chance: wrap a smaller balanced string in one outer pair.
    if random.random() < 0.5:
        return "(" + generate_balanced_parentheses(n - 1) + ")"
    # Otherwise concatenate two balanced strings at a random split point.
    split = random.randint(1, n - 1)
    return generate_balanced_parentheses(split) + generate_balanced_parentheses(
        n - split
    )


def is_parens_balanced(s):
    """Return True if the parentheses in ``s`` are balanced.

    Characters other than "(" and ")" are ignored. The string is balanced
    when the running depth never drops below zero and ends at zero.
    """
    step = {"(": 1, ")": -1}
    depth = 0
    for ch in s:
        depth += step.get(ch, 0)
        if depth < 0:
            # A ")" closed more groups than were open.
            return False
    return not depth


# Positive samples: a balanced parenthesis body between START_TOKEN and END_TOKEN.
def generate_valid_string():
    """Build one positive sample: padding, "s", balanced parens, "e", padding.

    The layout is ``p * left + "s" + body + "e" + p * right``. The four
    segment lengths come from ``generate_random_numbers_with_sum`` and sum to
    ``MAX_LENGTH - 2`` (the two extra characters are the start and end
    tokens), so every sample is exactly ``MAX_LENGTH`` characters long. The
    two middle segments are merged into one balanced body. The layout is
    re-drawn until their combined length is even, because a balanced body
    must have even length.
    """
    while True:
        left, mid_a, mid_b, right = generate_random_numbers_with_sum(MAX_LENGTH - 2)
        body_len = mid_a + mid_b
        if body_len % 2 == 0:
            break

    body = generate_balanced_parentheses(body_len // 2)
    return "".join(
        (PADDING_TOKEN * left, START_TOKEN, body, END_TOKEN, PADDING_TOKEN * right)
    )


def generate_invalid_string():
    """Build one negative sample with the same layout as a positive one.

    The layout is ``p * left + "s" + body + "e" + p * right`` and the total
    length is ``MAX_LENGTH``, as in ``generate_valid_string``. The body has
    equally many "(" and ")", so the classes cannot be told apart by
    counting. It is nevertheless guaranteed to be unbalanced.
    """
    # Re-draw the layout until the body length is even (equal open/close counts).
    while True:
        left, mid_a, mid_b, right = generate_random_numbers_with_sum(MAX_LENGTH - 2)
        if (mid_a + mid_b) % 2 == 0:
            break
    pairs = (mid_a + mid_b) // 2

    # random.shuffle is uniform over orderings, so shuffling "((..))" has the
    # same distribution as shuffling a random balanced string (what the old
    # code did) without having to generate one. A uniform shuffle of
    # `pairs` pairs is balanced with probability 1 / (pairs + 1), so the
    # retry loop finishes quickly.
    chars = ["("] * pairs + [")"] * pairs
    while True:
        random.shuffle(chars)
        body = "".join(chars)  # named `body`, not `str`, to avoid shadowing the builtin
        if not is_parens_balanced(body):
            break

    return (
        PADDING_TOKEN * left
        + START_TOKEN
        + body
        + END_TOKEN
        + PADDING_TOKEN * right
    )


# Generate dataset
# Each class gets `num_samples` random draws. Duplicates are removed within
# each class, both classes are cut to the same size, and each class is split
# 80/20 into train/test.
num_samples = 200000  # Draws per class, before deduplication

# dict.fromkeys removes duplicates but keeps first-seen (generation) order.
# The previous list(set(...)) ordered samples by string hash, which Python
# randomizes per process (PYTHONHASHSEED). The train/test split therefore
# changed on every run, even with a seeded `random` module.
dataset = [(generate_valid_string(), 1) for _ in tqdm(range(num_samples))]
valid_dataset = list(dict.fromkeys(dataset))

dataset = [(generate_invalid_string(), 0) for _ in tqdm(range(num_samples))]
invalid_dataset = list(dict.fromkeys(dataset))

# Truncate both classes to a common size. Previously only the invalid list was
# truncated, so a shorter invalid list produced train/test sets with unequal
# class counts.
class_size = min(len(valid_dataset), len(invalid_dataset))
valid_dataset = valid_dataset[:class_size]
invalid_dataset = invalid_dataset[:class_size]
print(len(valid_dataset))
print(len(invalid_dataset))

# 80/20 split within each class, so both splits stay class-balanced.
# Rows are grouped by label (all positives first); shuffle downstream if the
# consumer is order-sensitive.
split = class_size * 4 // 5
train_dataset = valid_dataset[:split] + invalid_dataset[:split]
test_dataset = valid_dataset[split:] + invalid_dataset[split:]


# Write to file: one "<sample> <label>" line per example.
# NOTE(review): machine-specific absolute path; adjust when running elsewhere.
folder = "/Users/sambhav/Documents/vscode_projects/notebookplayground/cs224n_final_project/balanced_parens_data/"
out_dir = Path(folder)
# Create the directory if needed; the old code raised FileNotFoundError when it
# was missing. Path joining also no longer depends on the trailing slash.
out_dir.mkdir(parents=True, exist_ok=True)

for filename, rows in (
    ("train_dataset.txt", train_dataset),
    ("test_dataset.txt", test_dataset),
):
    with open(out_dir / filename, "w", encoding="utf-8") as f:
        f.writelines(f"{data} {label}\n" for data, label in rows)

100%|██████████| 200000/200000 [00:02<00:00, 93794.21it/s]
100%|██████████| 200000/200000 [00:03<00:00, 57852.02it/s]


163360
163360
