In [1]:
import json, random, re

In [59]:
# Module-level configuration for the synthetic annotation generator.
# (Names assigned at module level are already global, so no ``global``
# statement is needed here.)

# Sentence templates. Every occurrence of an entity label (see
# ``entity_labels``) is a placeholder that gets replaced by an upper-cased
# value from ``label_value_options``.
sentence_templates = [
    "The TYPE had 3 COLOR lights around the edges with 1 COLOR light in the middle",
    "The TYPE was COLOR COLOR",
    "The TYPE was huge",
    "It was a COLOR TYPE, with three COLOR lights and a COLOR one in the center",
    "The TYPE object was glowing COLOR, like a sort of plasma",
    "There were three COLOR TYPE with a COLOR object behind them",
    "There were 10 COLOR TYPE with COLOR lights around the edges",
    "The TYPE was glowing COLOR",
    "I saw what looked like a TYPE hovering silently",
]

# Entity classes; each one doubles as the placeholder text in the templates.
entity_labels = ["COLOR", "TYPE"]

# Candidate values per entity class. Values must be unique within a class,
# otherwise the same span can be annotated more than once.
label_value_options = {
    "COLOR": (
        "red",
        "green",
        "blue",
        # Currently disabled colors.
        # NOTE(review): presumably trimmed to keep the generated set small
        # (the output directory is named "golds_small") -- confirm.
        # "white",
        # "silver",
        # "orange",
        # "grey",
        # "gray",
        # "purple",
        # "black",
        # "rainbow",
        # "pink",
        # "yellow",
        # "brown",
    ),
    "TYPE": (
        "Light",
        "Boomerang",
        "Triangle",
        "Wing",
        "Crescent",
        "Chevron",
        "Disk",
        "Saucer",
        "Cylinder",
        "Pyramid",
        "Tic tac",
        "Orb",
        "Globe",
        "Round",
        "Square",
        "Rectangle",
        "Cube",
        "Fireball",
        "Wheel",
        "Top",
        "Cigar",
        "Pill",
        "Starlike",
        "Rod",
        "Trapezoid",
        "Diamond",
        "Lightbulb",
        "Dome",
        "Dot",
        "Sphere",
        "Flying disk",
        "disks",
    ),
}


def print_debug(print_output, str):
    """Print ``str`` when ``print_output`` is truthy; otherwise do nothing.

    The second parameter keeps its original name (which shadows the builtin
    ``str`` inside this function) so existing callers are unaffected.
    Always returns ``None``.
    """
    if print_output:
        print(str)


def substring_range(text, substring):
    """Locate every literal occurrence of the upper-cased ``substring``.

    ``substring`` is upper-cased and regex-escaped before searching, so the
    search is case-sensitive against ``text`` and treats regex
    metacharacters literally.

    Returns:
        A list of ``(start, end)`` offset tuples, one per non-overlapping
        match, in order of appearance. Empty when nothing matches.
    """
    needle = re.escape(substring.upper())
    return [match.span() for match in re.finditer(needle, text)]


def shuffle_and_batch(batch_size, data, print_output):
    """Shuffle ``data`` in place and split it into annotation batches.

    Args:
        batch_size: Maximum number of sentences per batch; must be >= 1.
        data: List of annotated sentences. It is shuffled in place.
        print_output: When truthy, every batch is printed via ``print_debug``.

    Returns:
        A list of ``{"classes": entity_labels, "annotations": batch}`` dicts,
        one per batch. Every sentence of ``data`` ends up in exactly one
        batch; the last batch may hold fewer than ``batch_size`` sentences.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    random.shuffle(data)
    all_batches = []
    # Slicing in fixed steps also emits the trailing (possibly partial) batch,
    # which the previous counter-based loop silently dropped.
    for start in range(0, len(data), batch_size):
        batch = data[start : start + batch_size]
        print_debug(print_output, f"\n{batch}")
        all_batches.append({"classes": entity_labels, "annotations": batch})
    return all_batches


def update_sentence(label, sent, options):
    """Fill ``label`` placeholders in ``sent`` with upper-cased options.

    Options are consumed in iteration order, one per placeholder occurrence
    (left to right), until no placeholder is left or the options run out.

    Args:
        label: Placeholder text to replace (e.g. ``"COLOR"``).
        sent: Sentence that may contain ``label`` zero or more times.
        options: Iterable of replacement strings.

    Returns:
        A ``(new_sentence, options_used)`` tuple. ``options_used`` is the set
        of options consumed. The first option is counted as used even when
        ``sent`` contains no placeholder. With no options at all, ``sent`` is
        returned unchanged together with an empty set.
    """
    # Start from the input sentence so that an empty ``options`` iterable
    # yields the untouched sentence rather than an empty string.
    new_sent = sent
    options_used = set()
    for option in options:
        new_sent = new_sent.replace(label, option.upper(), 1)
        options_used.add(option)
        if label not in new_sent:
            break
    return new_sent, options_used


def get_all_for_label(generated_sentences, label_of_interest):
    """Expand each sentence into one variant per distinct option of a label.

    For every input sentence, as many variants are produced as there are
    distinct values in ``label_value_options[label_of_interest]``. Options
    are handed out without repetition until fewer unused options remain than
    the sentence has placeholders; then the pool is topped up with the other
    options in random order, so a variant is never left with a raw
    placeholder just because the current cycle ran dry.

    Args:
        generated_sentences: List of ``[sentence, {"entities": [...]}]``
            items; only the sentence text is read.
        label_of_interest: Key of ``label_value_options`` and placeholder
            text to replace.

    Returns:
        A list of ``[new_sentence, {"entities": []}]`` items. Entities are
        always empty here, as before; offsets are computed afterwards on the
        final text (``find_ents`` does that in this file).

    Note:
        A sentence without the placeholder still yields one unchanged copy
        per option. A sentence with more placeholders than there are distinct
        options keeps the surplus placeholders.
    """
    # De-duplicate while keeping the declared order.
    all_options = list(dict.fromkeys(label_value_options[label_of_interest]))
    output_for_label = []
    for old_gen_sentence in generated_sentences:
        template = old_gen_sentence[0]
        # update_sentence always consumes at least one option, even when the
        # sentence has no placeholder, hence the lower bound of 1.
        needed = max(template.count(label_of_interest), 1)
        label_options_used = set()
        for _ in all_options:
            candidates = [
                option for option in all_options if option not in label_options_used
            ]
            if len(candidates) < needed:
                # Not enough unused options: keep the leftovers first, then
                # start a new cycle with the remaining options shuffled.
                refill = [option for option in all_options if option not in candidates]
                random.shuffle(refill)
                candidates += refill
                label_options_used.clear()
            new_sentence, options_used = update_sentence(
                label_of_interest, template, candidates
            )
            label_options_used.update(options_used)
            output_for_label.append([new_sentence, {"entities": []}])
    return output_for_label


def find_ents(data, print_output):
    """Annotate each generated sentence with the entity spans it contains.

    Every value of every entity label is searched (upper-cased) in the
    sentence text. All occurrences are recorded, and when matches overlap
    (e.g. ``LIGHT`` inside ``LIGHTBULB``) only the longest span is kept.

    Args:
        data: List of ``[sentence, ...]`` items; only ``item[0]`` is read.
        print_output: When truthy, the full result is printed via
            ``print_debug``.

    Returns:
        A list with exactly one ``[text, {"entities": ents}]`` item per input
        sentence. ``text`` is the sentence lower-cased with a capitalised
        first letter, and ``ents`` is a sorted list of ``[start, end, label]``
        entries.
    """
    sents_with_ents = []
    for sent in data:
        text = sent[0]
        # A set removes duplicate spans (e.g. from a value listed twice).
        candidates = set()
        for entity_label in entity_labels:
            for label_value in label_value_options[entity_label]:
                for start, end in substring_range(text, label_value.upper()):
                    if start < end:  # ignore zero-width matches
                        candidates.add((start, end, entity_label))

        # Longest spans first, so a value nested inside a longer value is
        # discarded instead of producing overlapping entities.
        kept = []
        by_length = sorted(
            candidates, key=lambda span: (span[0] - span[1], span[0], span[2])
        )
        for start, end, entity_label in by_length:
            if all(end <= k_start or start >= k_end for k_start, k_end, _ in kept):
                kept.append((start, end, entity_label))

        ents = sorted([start, end, entity_label] for start, end, entity_label in kept)
        # Offsets were computed on the upper-cased text.
        # NOTE(review): assumes the case conversion below keeps the string
        # length (true for ASCII) -- confirm if non-ASCII values are added.
        sents_with_ents.append([text.lower().capitalize(), {"entities": ents}])
    print_debug(print_output, f"\nSents with ents: {sents_with_ents}")
    return sents_with_ents


def generate_annotations(print_output):
    """Yield the expanded sentence list for each template, one at a time.

    Each template in ``sentence_templates`` is run through
    ``get_all_for_label`` once per entry of ``entity_labels``; the output of
    one label pass is the input of the next.

    Args:
        print_output: When truthy, the number of generated sentences and the
            first one are printed via ``print_debug``.

    Yields:
        The list returned by the last ``get_all_for_label`` pass for the
        current template.
    """
    for template in sentence_templates:
        expanded = []
        for label in entity_labels:
            # Seed with the bare template while nothing has been generated.
            seed = expanded if expanded else [[template, {"entities": []}]]
            expanded = get_all_for_label(seed, label)
        print_debug(print_output, f"{len(expanded)} annotated sentences generated!")
        print_debug(print_output, f"Output example: {expanded[0]}")
        yield expanded


def run_generator(batch_size, print_output):
    """Generate, annotate, shuffle and batch the sentences of every template.

    Args:
        batch_size: Number of annotated sentences per batch; passed to
            ``shuffle_and_batch``.
        print_output: When truthy, progress and example output is printed via
            ``print_debug`` here and in the helpers this function calls.

    Returns:
        A list with one entry per sentence template; each entry is the list
        of batches returned by ``shuffle_and_batch`` for that template.
    """
    generated_data = []
    # Iterate the generator directly instead of calling next() a fixed
    # number of times.
    for sent_template_to_batch in generate_annotations(print_output):
        # The example messages index element 0 while the f-string is built,
        # i.e. even when printing is disabled, so guard against empty lists.
        if sent_template_to_batch:
            print_debug(
                print_output,
                f"\nSentence to batch & shuffle example: {sent_template_to_batch[0]}",
            )
        with_ents = find_ents(sent_template_to_batch, print_output)
        # All batches for one sentence template.
        training_data = shuffle_and_batch(batch_size, with_ents, print_output)
        # Show the first generated annotation document to verify success.
        if training_data:
            print_debug(
                print_output,
                f"\nShuffled training data example: {training_data[0]}",
            )
        # Count the annotated documents in the new corpus.
        print_debug(
            print_output,
            f"\nThe new training data has {len(training_data)} documents of {batch_size} annotated training lines each\n",
        )
        generated_data.append(training_data)
    return generated_data

def write_json(path, filename, data):
    """Serialise ``data`` as JSON to ``path + filename + '.json'``.

    Args:
        path: Directory prefix, concatenated as-is (include the trailing
            separator). May be empty to write into the working directory.
        filename: File name without the ``.json`` extension.
        data: JSON-serialisable object.

    Missing parent directories are created first, so a fresh checkout does
    not fail with ``FileNotFoundError``.
    """
    import os

    target = path + filename + '.json'
    parent = os.path.dirname(target)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(target, 'w', encoding='utf-8') as fp:
        json.dump(data, fp)

def make_annotations(
    print_output,
    save_path='./experiment_test_files/train/golds_small/',
    filename='golds',
    batch_size=5,
):
    """Generate all annotation batches and write each one to a JSON file.

    Args:
        print_output: When truthy, a summary and the batches of the first
            template are printed via ``print_debug``.
        save_path: Directory prefix handed to ``write_json`` (include the
            trailing separator).
        filename: Base name of the output files; each batch is written as
            ``{filename}_{template_index}_{batch_index}.json``.
        batch_size: Number of annotated sentences per batch.

    The defaults reproduce the previously hard-coded values, so
    ``make_annotations(print_output)`` behaves as before.
    """
    # The generator's own debug output stays switched off (hard-coded 0, as
    # before); only the summary below follows ``print_output``.
    generated_training_data = run_generator(batch_size, 0)
    # The summary indexes element 0 while the message is built, so skip it
    # when nothing was generated.
    if generated_training_data:
        print_debug(print_output, f"Result: {len(generated_training_data)} sentence annotations, with {len(generated_training_data[0])} batches each.")
        print_debug(print_output, generated_training_data[0])
    for s_idx, sent in enumerate(generated_training_data):
        for b_idx, batch in enumerate(sent):
            write_json(save_path, f'{filename}_{s_idx}_{b_idx}', batch)

# Run the whole pipeline with debug printing disabled (0 is falsy); this
# writes one JSON file per generated batch.
make_annotations(0)