### Partition

Separates elements in a collection into multiple output
collections. The partitioning function contains the logic that determines how
to separate the elements of the input collection into each resulting
partition output collection.


The number of partitions must be determined at graph construction time.
You cannot determine the number of partitions in mid-pipeline


`Partition` accepts a function that receives the number of partitions,
and returns the index of the desired partition for the element.
The number of partitions passed must be a positive integer,
and it must return an integer in the range `0` to `num_partitions-1`.

In [3]:
import apache_beam as beam

durations = ['annual', 'biennial', 'perennial']

def by_duration(plant, num_partitions):
    return durations.index(plant["duration"])

with beam.Pipeline() as pipeline:
    annuals, biennials, perennials = (
        pipeline
        | "Gardening plants" >> beam.Create([
            {'icon': '🍓', 'name': 'Strawberry', 'duration': 'perennial'},
            {'icon': '🥕', 'name': 'Carrot', 'duration': 'biennial'},
            {'icon': '🍆', 'name': 'Eggplant', 'duration': 'perennial'},
            {'icon': '🍅', 'name': 'Tomato', 'duration': 'annual'},
            {'icon': '🥔', 'name': 'Potato', 'duration': 'perennial'},
        ])
        | 'Partition' >> beam.Partition(by_duration, len(durations))
    )

    annuals | 'Annuals' >> beam.Map(lambda x: print(f'annual: {x}'))
    biennials | 'Biennials' >> beam.Map(lambda x: print(f'biennial: {x}'))
    perennials | 'Perennials' >> beam.Map(lambda x: print(f'perennial: {x}'))

perennial: {'icon': '🍓', 'name': 'Strawberry', 'duration': 'perennial'}
biennial: {'icon': '🥕', 'name': 'Carrot', 'duration': 'biennial'}
perennial: {'icon': '🍆', 'name': 'Eggplant', 'duration': 'perennial'}
annual: {'icon': '🍅', 'name': 'Tomato', 'duration': 'annual'}
perennial: {'icon': '🥔', 'name': 'Potato', 'duration': 'perennial'}


## Partition with a lambda function

In [4]:
import apache_beam as beam

durations = ['annual', 'biennial', 'perennial']

def by_duration(plant, num_partitions):
    return durations.index(plant["duration"])

with beam.Pipeline() as pipeline:
    annuals, biennials, perennials = (
        pipeline
        | "Gardening plants" >> beam.Create([
            {'icon': '🍓', 'name': 'Strawberry', 'duration': 'perennial'},
            {'icon': '🥕', 'name': 'Carrot', 'duration': 'biennial'},
            {'icon': '🍆', 'name': 'Eggplant', 'duration': 'perennial'},
            {'icon': '🍅', 'name': 'Tomato', 'duration': 'annual'},
            {'icon': '🥔', 'name': 'Potato', 'duration': 'perennial'},
        ])
        | 'Partition' >> beam.Partition(
            lambda plant, num_partitions: durations.index(plant['duration']),
            len(durations)
        )
    )

    annuals | 'Annuals' >> beam.Map(lambda x: print(f'annual: {x}'))
    biennials | 'Biennials' >> beam.Map(lambda x: print(f'biennial: {x}'))
    perennials | 'Perennials' >> beam.Map(lambda x: print(f'perennial: {x}'))

perennial: {'icon': '🍓', 'name': 'Strawberry', 'duration': 'perennial'}
biennial: {'icon': '🥕', 'name': 'Carrot', 'duration': 'biennial'}
perennial: {'icon': '🍆', 'name': 'Eggplant', 'duration': 'perennial'}
annual: {'icon': '🍅', 'name': 'Tomato', 'duration': 'annual'}
perennial: {'icon': '🥔', 'name': 'Potato', 'duration': 'perennial'}


## Partition with multiple arguments

In [7]:
import apache_beam as beam
import json

def split_dataset(plant, num_partitions, ratio):
    assert num_partitions == len(ratio)
    bucket = sum(map(ord, json.dumps(plant))) % sum(ratio)
    total = 0
    for i, part in enumerate(ratio):
        total += part
        if bucket < total:
            return i
    return len(ratio) - 1

with beam.Pipeline() as pipeline:
    train_dataset, test_dataset = (
        pipeline
        | "Gardening plants" >> beam.Create([
            {'icon': '🍓', 'name': 'Strawberry', 'duration': 'perennial'},
            {'icon': '🥕', 'name': 'Carrot', 'duration': 'biennial'},
            {'icon': '🍆', 'name': 'Eggplant', 'duration': 'perennial'},
            {'icon': '🍅', 'name': 'Tomato', 'duration': 'annual'},
            {'icon': '🥔', 'name': 'Potato', 'duration': 'perennial'},
        ])
        | "Partition" >> beam.Partition(split_dataset, 2, ratio=[8,2])
    )

    train_dataset | 'Train' >> beam.Map(lambda x: print(f'train: {x}'))
    test_dataset | 'Test' >> beam.Map(lambda x: print(f'train: {x}'))

train: {'icon': '🍓', 'name': 'Strawberry', 'duration': 'perennial'}
train: {'icon': '🥕', 'name': 'Carrot', 'duration': 'biennial'}
train: {'icon': '🍆', 'name': 'Eggplant', 'duration': 'perennial'}
train: {'icon': '🍅', 'name': 'Tomato', 'duration': 'annual'}
train: {'icon': '🥔', 'name': 'Potato', 'duration': 'perennial'}
