# Generate bAbI-style data with the Python port

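This notebook demonstrates simple synthetic bAbI-style story generation with the Python port in `python/babi`, then writes a small dataset to `data/python/generated_stories.txt`. Run it from the repository root so that the local `python/` directory can be added to `sys.path`.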

In [None]:
from pathlib import Path
import random

# Make the local python package importable when running from repo root
import sys
sys.path.insert(0, str(Path.cwd() / 'python'))

from babi import World, Clause, actions

In [None]:
def build_world():
    """Create a fresh ``World`` populated with the demo entities.

    Registers, in order, four actors, four locations and three objects via
    ``World.create_entity``, each with its own property dict.

    Returns:
        The newly built ``World``.
    """
    world = World()

    entity_specs = (
        (('john', 'mary', 'daniel', 'sandra'),
         {'is_actor': True, 'is_thing': True, 'size': 2, 'is_gettable': False}),
        (('kitchen', 'garden', 'hallway', 'office'),
         {'is_location': True, 'is_thing': True, 'size': 100}),
        (('milk', 'apple', 'football'),
         {'is_gettable': True, 'is_thing': True, 'size': 1}),
    )

    for names, properties in entity_specs:
        for name in names:
            # Copy so every entity receives its own dict, exactly as a fresh
            # literal per call would give.
            world.create_entity(name, dict(properties))

    return world


def generate_story(num_steps=8, seed=0):
    """Generate one random story as the list of clauses performed on a world.

    The process-global ``random`` module is reseeded with ``seed`` on every
    call, so a given ``(num_steps, seed)`` pair always reproduces the same
    story. Each of the ``num_steps`` iterations proposes one candidate clause:

    * with probability 0.6, god sets a random actor's ``is_in`` to a random
      location;
    * otherwise, with equal odds, either god sets a random object's ``is_in``
      to a random location, or a random actor tries to ``get`` that object.

    Only candidates whose ``is_valid()`` is true are performed and kept, so
    the returned story can be shorter than ``num_steps``.

    Args:
        num_steps: Number of candidate clauses to propose.
        seed: Value passed to ``random.seed``.

    Returns:
        The performed ``Clause`` objects, in the order they were applied.
    """
    random.seed(seed)
    world = build_world()

    def lookup(names):
        return [world.entities[name] for name in names]

    actors = lookup(['john', 'mary', 'daniel', 'sandra'])
    locations = lookup(['kitchen', 'garden', 'hallway', 'office'])
    objects = lookup(['milk', 'apple', 'football'])

    performed = []
    for _step in range(num_steps):
        moves_actor = random.random() < 0.6
        if moves_actor:
            who = random.choice(actors)
            where = random.choice(locations)
            candidate = Clause(world, True, world.god(), actions['set'], who, 'is_in', where)
        else:
            # `who` is drawn on both sub-branches, even the one that ignores it.
            # Removing this draw would shift every later random draw and so
            # change the story produced for a given seed.
            who = random.choice(actors)
            item = random.choice(objects)
            if random.random() >= 0.5:
                candidate = Clause(world, True, who, actions['get'], item)
            else:
                candidate = Clause(world, True, world.god(), actions['set'], item, 'is_in', random.choice(locations))

        if not candidate.is_valid():
            continue
        candidate.perform()
        performed.append(candidate)

    return performed


In [None]:
story = generate_story(num_steps=12, seed=7)
for position, clause in enumerate(story, 1):
    # Arguments are shown by their .name; plain values without a .name
    # attribute fall back to str().
    arg_text = ' '.join(getattr(arg, 'name', str(arg)) for arg in clause.args)
    print(f'{position:>2} {clause.actor.name} {clause.action} {arg_text}')

In [None]:
def write_dataset(path, n_stories=100, steps=10):
    """Write ``n_stories`` generated stories to a UTF-8 text file.

    Story ``k`` (for ``k`` in ``0 .. n_stories - 1``) comes from
    ``generate_story(num_steps=steps, seed=k)``. Each clause is written on
    its own line as ``<index> <actor> <action> <args>``, with the index
    restarting at 1 for every story. Each story is followed by a blank line.
    Missing parent directories are created, and an existing file at ``path``
    is overwritten.

    Args:
        path: Destination file path (``str`` or ``Path``).
        n_stories: Number of stories to write.
        steps: ``num_steps`` passed to ``generate_story`` for each story.
    """
    def render(index, clause):
        # Same formatting as the demo cell, minus the right-aligned index.
        arg_text = ' '.join(getattr(arg, 'name', str(arg)) for arg in clause.args)
        return f'{index} {clause.actor.name} {clause.action} {arg_text}\n'

    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)

    with open(target, 'w', encoding='utf-8') as handle:
        for story_id in range(n_stories):
            clauses = generate_story(num_steps=steps, seed=story_id)
            handle.writelines(render(index, clause) for index, clause in enumerate(clauses, start=1))
            handle.write('\n')


# Bind the output path once so the confirmation message cannot drift from the
# file actually written. The path is relative to the current working directory.
dataset_path = 'data/python/generated_stories.txt'
write_dataset(dataset_path, n_stories=20, steps=12)
print(f'Wrote {dataset_path}')