In [7]:
import random
import string
from tqdm import tqdm

# Define constants
MAX_LENGTH = 500
OOD_MAX_LENGTH = 1500
VALID_CHARACTERS = ["s", "a", "b", "c", "d", "x", "e", "p"]
MAIN_CHARACTERS = ["a", "b", "c", "d", "x"]
START_TOKEN = "s"
END_TOKEN = "e"
PADDING_TOKEN = "p"
VALID_RATIO = 0.5  # Half of the dataset should be valid


# Function to generate valid
# min_length is without padding, max_length is with padding
def generate_valid_string(min_length=4, max_length=MAX_LENGTH):
    length = random.randint(min_length, max_length - 2)

    # 80% chance of generating dense. 20% chance of generating sparse
    sparse = random.random() < 0.2
    if not sparse:
        # Generate random string of a,b,c,d,x. Ensure that a,b,c,d occur at least once
        while True:
            valid_str = ""
            for _ in range(length):
                valid_str += random.choice(MAIN_CHARACTERS)
            if (
                "a" not in valid_str
                or "b" not in valid_str
                or "c" not in valid_str
                or "d" not in valid_str
            ):
                continue
            break
    else:
        valid_str = "x" * length
        # Pick 4 random indices to replace with a,b,c,d
        indices = random.sample(range(length), 4)
        # Randomly permute indices
        random.shuffle(indices)
        valid_str = list(valid_str)
        for i in range(4):
            valid_str[indices[i]] = MAIN_CHARACTERS[i]
        valid_str = "".join(valid_str)

    return (
        START_TOKEN + valid_str + END_TOKEN + PADDING_TOKEN * (max_length - length - 2)
    )


# Function to generate invalid strings
def generate_invalid_string(min_length=1, max_length=MAX_LENGTH):
    length = random.randint(min_length, max_length - 2)

    # 80% chance of generating dense. 20% chance of generating sparse
    sparse = random.random() < 0.2
    if not sparse:
        # Generate random string of a,b,c,d,x. Pick one to not appear
        forbidden = random.choice(MAIN_CHARACTERS[:-1])
        new_characters = [c for c in MAIN_CHARACTERS if c != forbidden]
        while True:
            valid_str = ""
            for _ in range(length):
                valid_str += random.choice(new_characters)
            break
    else:
        valid_str = "x" * length
        num_included = random.randint(0, min(3, length))
        abcd = MAIN_CHARACTERS[:-1]
        random.shuffle(abcd)
        abcd = abcd[:num_included]
        # Pick num_included random indices to replace with a,b,c,d
        indices = random.sample(range(length), num_included)
        # Randomly permute indices
        random.shuffle(indices)
        valid_str = list(valid_str)
        for i in range(num_included):
            valid_str[indices[i]] = abcd[i]
        valid_str = "".join(valid_str)

    return (
        START_TOKEN + valid_str + END_TOKEN + PADDING_TOKEN * (max_length - length - 2)
    )

In [None]:
# Generate dataset
dataset = []
num_samples = 2000000  # Total number of samples

for _ in tqdm(range(num_samples)):
    while True:
        x = generate_valid_string()
        if not (len(x) == MAX_LENGTH):
            continue
        x += PADDING_TOKEN * (OOD_MAX_LENGTH - MAX_LENGTH)
        if not (len(x) == OOD_MAX_LENGTH):
            continue
        dataset.append((x, 1))
        break

# remove duplicates
valid_dataset = list(set(dataset))
dataset = []

for _ in tqdm(range(num_samples)):
    while True:
        x = generate_invalid_string()
        if not (len(x) == MAX_LENGTH):
            continue
        x += PADDING_TOKEN * (OOD_MAX_LENGTH - MAX_LENGTH)
        if not (len(x) == OOD_MAX_LENGTH):
            continue
        dataset.append((x, 0))
        break

# Remove all duplicates
invalid_dataset = list(set(dataset))[: len(valid_dataset)]
print(len(valid_dataset))
print(len(invalid_dataset))

In [8]:
split = len(valid_dataset) * 4 // 5
train_dataset = valid_dataset[:split] + invalid_dataset[:split]
test_dataset = valid_dataset[split:] + invalid_dataset[split:]

num_ood_samples = 50000
dataset = []
for _ in tqdm(range(num_ood_samples)):
    while True:
        x = generate_valid_string(min_length=MAX_LENGTH + 2, max_length=OOD_MAX_LENGTH)
        if not (len(x) == OOD_MAX_LENGTH):
            continue
        dataset.append((x, 1))
        break

# remove duplicates
ood_valid_dataset = list(set(dataset))
dataset = []

for _ in tqdm(range(num_ood_samples)):
    while True:
        x = generate_invalid_string(
            min_length=MAX_LENGTH + 2, max_length=OOD_MAX_LENGTH
        )
        if not (len(x) == OOD_MAX_LENGTH):
            continue
        dataset.append((x, 0))
        break

# Remove all duplicates
ood_invalid_dataset = list(set(dataset))[: len(ood_valid_dataset)]
print(len(ood_valid_dataset), len(ood_invalid_dataset))

ood_dataset = ood_valid_dataset + ood_invalid_dataset
print(len(train_dataset))
print(len(test_dataset))
print(len(ood_dataset))
folder = "more_complex/"

# Write to file
with open(folder + "train_dataset.txt", "w") as f:
    for data, label in train_dataset:
        f.write(f"{data} {label}\n")

with open(folder + "test_dataset.txt", "w") as f:
    for data, label in test_dataset:
        f.write(f"{data} {label}\n")

with open(folder + "ood_dataset.txt", "w") as f:
    for data, label in ood_dataset:
        f.write(f"{data} {label}\n")

100%|██████████| 50000/50000 [00:12<00:00, 4133.18it/s]
100%|██████████| 50000/50000 [00:12<00:00, 4027.35it/s]


50000 48383
3183096
540931
98383
