In [23]:
import numpy as np
import torch
from torch import Tensor
from typing import List

In [38]:
def minibatches(*tensors: Tensor, batch_size: int = 32, shuffle: bool = True):
    """Yield aligned mini-batches sliced along the first dimension of each input.

    Every input is indexed with the same row indices, so the i-th row of each
    yielded tensor comes from the same row of the corresponding inputs.

    Because this is a generator, argument validation runs when the first
    batch is requested, not when the function is called.

    Args:
        *tensors: One or more tensors whose first dimension has the same size.
        batch_size: Maximum number of rows per batch. The last batch is
            smaller when the row count is not a multiple of ``batch_size``.
        shuffle: If True, rows are visited in a random order drawn once per
            call from numpy's global RNG. If False, rows keep their original
            order.

    Yields:
        A list with one tensor per input, each holding the rows selected for
        the current batch.

    Raises:
        ValueError: If no tensors are given or ``batch_size`` is not positive.
        AssertionError: If the inputs disagree on the size of dimension 0.
    """
    if not tensors:
        raise ValueError("minibatches() needs at least one tensor")
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    full_size = tensors[0].shape[0]
    for tensor in tensors:
        assert tensor.shape[0] == full_size, "One of the tensors has a different batch size"

    # Honour the `shuffle` flag: previously the permutation was unconditional.
    # The sequential case still goes through an index array so that both modes
    # use the same indexing path on the inputs.
    if shuffle:
        indices = np.random.permutation(full_size)
    else:
        indices = np.arange(full_size)

    for start in range(0, full_size, batch_size):
        # Slicing past the end is clamped, which produces the short last batch.
        idx = indices[start:start + batch_size]
        yield [tensor[idx, ...] for tensor in tensors]

In [29]:
# Scratch check of the [start, stop) bounds produced by striding over 100
# items in steps of 32; the final stop value overshoots 100.
for i in range(0, 100, 32):
    stop = i + 32
    print(i, stop)

0 32
32 64
64 96
96 128


In [25]:
# Random placeholder data for exercising `minibatches`: three tensors that
# share a leading dimension of 1000 but differ in their trailing shape.
obs_batch = torch.randn((1000, 10))
action_batch = torch.randn((1000,))
value_batch = torch.randn((1000, 1))

In [39]:
# Iterate the three placeholder tensors in lock-step and print each batch's
# shapes, followed by a blank line between batches.
batch_iter = minibatches(obs_batch, action_batch, value_batch, batch_size=256)
for obs, action, value in batch_iter:
    print(obs.shape, action.shape, value.shape, sep="\n")
    print()

torch.Size([256, 10])
torch.Size([256])
torch.Size([256, 1])

torch.Size([256, 10])
torch.Size([256])
torch.Size([256, 1])

torch.Size([256, 10])
torch.Size([256])
torch.Size([256, 1])

torch.Size([232, 10])
torch.Size([232])
torch.Size([232, 1])



In [16]:
# Scratch check: passing an integer n yields the integers 0..n-1 in random order,
# which is how `minibatches` builds its row indices.
np.random.permutation(10)

array([6, 9, 2, 5, 3, 7, 1, 4, 0, 8])

In [5]:
# Probe of how extra positional values are collected into *tensors while
# batch_size / shuffle must be passed by keyword.
# NOTE(review): the output recorded below (args tuple, batch_size, shuffle)
# appears to come from an earlier print-based draft of `minibatches`; the
# generator version defined above would print nothing here. TODO confirm,
# then re-run or remove this cell.
minibatches(1, 2, 3, 32, True, batch_size=16, shuffle=False)

(1, 2, 3, 32, True)
16
False
