In [32]:
from typing import Iterator, List, Sequence

import numpy as np
from torch import Tensor

class PolyDataIterator:
    """Iterate over several datasets at once in size-proportional batches.

    Every batch contains one ``[inputs_slice, targets_slice]`` pair per dataset.
    ``batch_size`` is the step taken through the *combined* sample count; each
    dataset contributes a share of that step proportional to its own length, so
    all datasets reach their end together on the last batch. Because the slice
    bounds are rounded, a dataset that is small relative to the others may
    contribute an empty slice to some batches.
    """

    def __init__(self, batch_size: int = 32):
        """Store the combined number of samples to step through per batch."""
        self.batch_size = batch_size

    def __call__(self, inputs: Sequence[Tensor], targets: Sequence[Tensor]) -> Iterator[List[List[Tensor]]]:
        """Validate the datasets and return an iterator over stratified batches.

        Validation happens here, at call time, rather than on the first
        ``next()``: this method is deliberately not a generator function, so bad
        input is reported where the iterator is created.

        Args:
            inputs: One sliceable container of samples per dataset.
            targets: One sliceable container of targets per dataset, aligned
                with ``inputs`` (same number of datasets, same length per dataset).

        Returns:
            An iterator yielding, per batch, a list with one
            ``[inputs_slice, targets_slice]`` pair for each dataset, in the
            order the datasets were passed.

        Raises:
            ValueError: If ``batch_size`` is not positive, or if ``inputs`` and
                ``targets`` contain a different number of datasets.
            AttributeError: If a dataset and its targets differ in length.
        """
        if self.batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {self.batch_size}")
        PolyDataIterator._validate(inputs, targets)
        return self._iter_batches(inputs, targets)

    def _iter_batches(self, inputs, targets):
        """Generate the batches for already-validated ``inputs`` and ``targets``."""
        set_sizes = [len(tar) for tar in targets]
        total = sum(set_sizes)
        if total == 0:
            # Nothing to batch (no datasets, or only empty ones); this also
            # avoids dividing by zero when computing the proportions below.
            return

        # Fraction of the combined sample count that each dataset accounts for.
        strata_sizes = [size / total for size in set_sizes]

        # Walk the combined sample range with plain Python numbers (no NumPy
        # scalars end up in the slice indices).
        start = 0
        while start < total:
            end = start + self.batch_size

            outputs = []
            for inp, tar, strata in zip(inputs, targets, strata_sizes):
                # The end bound of one batch is computed from the same value as
                # the start bound of the next, so consecutive slices are
                # contiguous: no sample is skipped or repeated. On the last
                # batch ``end`` may overshoot; slicing clamps it.
                start_idx = round(start * strata)
                end_idx = round(end * strata)
                outputs.append([inp[start_idx:end_idx], tar[start_idx:end_idx]])

            yield outputs
            start = end

    @staticmethod
    def _validate(inputs, targets):
        """Check that ``inputs`` and ``targets`` describe the same datasets.

        Raises:
            ValueError: If the number of datasets differs between the two.
            AttributeError: If a dataset and its targets differ in length
                (exception type kept as-is for backward compatibility).
        """
        if len(inputs) != len(targets):
            # zip() below would silently drop the unmatched datasets.
            raise ValueError(f"Mismatching number of datasets and target sets "
                             f"({len(inputs)} and {len(targets)})")
        for inp, tar in zip(inputs, targets):
            if len(inp) != len(tar):
                raise AttributeError(f"Mismatching number of samples for a dataset "
                                     f"and the corresponding targets ({len(inp)} and {len(tar)})")

In [34]:
import torch

# Seed the global RNG so the tensors printed below are reproducible.
torch.manual_seed(0)

# Three toy datasets with 5, 10 and 15 rows of 10 columns each. For every
# dataset the inputs (values 0..20) are drawn before the targets (values 0..1),
# and the datasets are generated in order, so the random stream is consumed in
# a fixed sequence.
(inputs1, targets1), (inputs2, targets2), (inputs3, targets3) = [
    (
        torch.randint(low=0, high=21, size=(n_rows, 10)),
        torch.randint(low=0, high=2, size=(n_rows, 10)),
    )
    for n_rows in (5, 10, 15)
]

for dataset_inputs in (inputs1, inputs2, inputs3):
    print(dataset_inputs)

tensor([[11,  3, 14,  9,  1,  9, 13, 10,  4, 12],
        [ 2, 14, 15, 12, 16, 11, 18,  9, 11,  9],
        [20,  2, 11,  2, 11, 11,  1, 17, 15, 18],
        [ 1,  0, 11, 12, 10, 18, 10,  2,  8,  4],
        [11,  0,  7, 19, 13, 11, 19,  4, 11, 19]])
tensor([[17, 12,  7, 18,  2,  6, 15,  5,  2, 17],
        [11, 16, 20,  8, 16, 18, 10, 18, 16,  3],
        [18,  4,  1, 13,  7,  5,  1, 19,  7,  1],
        [ 7,  4, 13, 16, 14, 16, 11, 20,  9,  4],
        [ 0,  1, 12,  6,  3, 10, 12, 11, 13, 16],
        [18, 10,  6,  7,  2,  7, 11,  2, 11, 17],
        [16, 14,  8, 20,  8,  0,  9, 15,  7, 16],
        [ 9,  4,  5, 15, 18,  9,  7, 20,  7,  4],
        [ 3, 15,  6,  3, 12,  8, 17, 14, 11,  9],
        [ 5,  2,  0,  0, 11, 10, 11, 13, 12,  5]])
tensor([[18,  2,  2, 19,  6, 15,  2,  3, 14, 17],
        [11,  1,  0, 13, 20, 20,  2,  6, 13,  2],
        [ 6, 10,  9,  6, 12, 16, 12, 10, 15, 16],
        [ 3,  8,  7,  4,  8, 17, 16,  6,  4, 18],
        [11,  4,  6, 10, 17, 18,  1,  1, 14,  9]

In [35]:
# Step through the three demo datasets jointly, five combined samples at a time.
iterator = PolyDataIterator(batch_size=5)

# (heading, dataset position in the batch, 0 = inputs / 1 = targets)
sections = (
    ("set1 inputs:", 0, 0),
    ("set1 targets:", 0, 1),
    ("set2 inputs:", 1, 0),
    ("set2 targets:", 1, 1),
    ("set3 inputs:", 2, 0),
    ("set3 targets:", 2, 1),
)

batches = iterator(
    inputs=[inputs1, inputs2, inputs3],
    targets=[targets1, targets2, targets3],
)
for i, batch in enumerate(batches):
    print("\nBatch", i)
    for heading, set_pos, part_pos in sections:
        print(heading)
        print(batch[set_pos][part_pos])
    


Batch 0
set1 inputs:
tensor([[11,  3, 14,  9,  1,  9, 13, 10,  4, 12]])
set1 targets:
tensor([[0, 1, 1, 1, 1, 0, 1, 0, 0, 1]])
set2 inputs:
tensor([[17, 12,  7, 18,  2,  6, 15,  5,  2, 17],
        [11, 16, 20,  8, 16, 18, 10, 18, 16,  3]])
set2 targets:
tensor([[1, 0, 1, 0, 0, 1, 1, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 1, 0, 1]])
set3 inputs:
tensor([[18,  2,  2, 19,  6, 15,  2,  3, 14, 17],
        [11,  1,  0, 13, 20, 20,  2,  6, 13,  2]])
set3 targets:
tensor([[0, 1, 0, 1, 0, 0, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 0, 1, 1]])

Batch 1
set1 inputs:
tensor([[ 2, 14, 15, 12, 16, 11, 18,  9, 11,  9]])
set1 targets:
tensor([[1, 0, 1, 0, 1, 0, 0, 0, 0, 0]])
set2 inputs:
tensor([[18,  4,  1, 13,  7,  5,  1, 19,  7,  1]])
set2 targets:
tensor([[0, 0, 0, 1, 1, 1, 0, 0, 1, 1]])
set3 inputs:
tensor([[ 6, 10,  9,  6, 12, 16, 12, 10, 15, 16],
        [ 3,  8,  7,  4,  8, 17, 16,  6,  4, 18],
        [11,  4,  6, 10, 17, 18,  1,  1, 14,  9]])
set3 targets:
tensor([[0, 1, 1, 1, 0, 0, 1,

<a style='text-decoration:none;line-height:16px;display:flex;color:#5B5B62;padding:10px;justify-content:end;' href='https://deepnote.com?utm_source=created-in-deepnote-cell&projectId=b2f14aee-af04-4db5-af55-57a3a58b9f40' target="_blank">
 </img>
Created in <span style='font-weight:600;margin-left:4px;'>Deepnote</span></a>