<a href="https://colab.research.google.com/github/alexshtf/alexshtf.github.io/blob/master/assets/batch_iter.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch

In [2]:
# Benchmark device: the first CUDA GPU when one is present. Falling back to
# the CPU keeps the notebook runnable (just slower) on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [3]:
# Synthetic regression problem: standard-normal features and targets,
# allocated directly on the benchmark device.
n_features, n_samples = 1_000, 500_000
X = torch.randn((n_samples, n_features), device=device)
y = torch.randn((n_samples,), device=device)

In [4]:
from torch import nn

def make_network():
    """Build the MLP regressor used by the timing experiments.

    Layer widths are derived from the notebook-level `n_features`:
    n_features -> n_features // 2 -> n_features // 8 -> 1, with a ReLU after
    each of the two hidden layers.

    Returns:
      A freshly initialised `nn.Sequential` (on the default device).
    """
    hidden_wide = n_features // 2
    hidden_narrow = n_features // 8
    layers = [
        nn.Linear(n_features, hidden_wide),
        nn.ReLU(),
        nn.Linear(hidden_wide, hidden_narrow),
        nn.ReLU(),
        nn.Linear(hidden_narrow, 1),
    ]
    return nn.Sequential(*layers)

# Measure time with DataLoader

In [5]:
# The loss and the dataset wrapper do not depend on the model; build them first.
criterion = nn.MSELoss()
ds = torch.utils.data.TensorDataset(X, y)

# Fresh network moved to the benchmark device, optimised with plain SGD.
net = make_network()
net = net.to(device)
optim = torch.optim.SGD(params=net.parameters(), lr=1e-3)

In [6]:
%%time
# One epoch of mini-batch SGD with batches served by the stock DataLoader.
for Xb, yb in torch.utils.data.DataLoader(ds, batch_size=64, shuffle=True):
  # Drop only the trailing size-1 output axis. A bare squeeze() would also
  # remove the batch axis of a single-sample batch, so the prediction would no
  # longer have the same shape as yb.
  loss = criterion(net(Xb).squeeze(-1), yb)
  loss.backward()
  optim.step()
  optim.zero_grad()

CPU times: user 12.3 s, sys: 414 ms, total: 12.7 s
Wall time: 13.4 s


In [7]:
%%time
# Baseline: walk the DataLoader without doing any work per batch, so the
# measured time is the cost of producing the batches alone.
for Xb, yb in torch.utils.data.DataLoader(ds, batch_size=64, shuffle=True):
	pass

CPU times: user 4.78 s, sys: 25.5 ms, total: 4.8 s
Wall time: 4.86 s


# Measure time with manual iteration

## Without shuffling

In [8]:
def iter_tensors(*tensors, batch_size):
    """Yield consecutive mini-batches sliced from several tensors at once.

    Args:
      tensors: tensors sharing the size of their first dimension; all of them
        are assumed to live on the device of the first tensor.
      batch_size: number of rows per batch (the last batch may be smaller).

    Yields:
      A tuple holding, for every input tensor, the rows of the current batch,
      in the original row order.
    """
    lead = tensors[0]
    positions = torch.arange(lead.size(0), device=lead.device)
    for chunk in positions.split(batch_size):
        yield tuple(t[chunk, ...] for t in tensors)

In [9]:
%%time
# Time one full sequential pass with the hand-written iterator; the loop body
# is empty so only batch creation is measured.
# NOTE(review): if X and y are on a CUDA device the indexing kernels are
# launched asynchronously; the wall time may under-report unless
# torch.cuda.synchronize() is called before the cell ends - confirm.
for Xb, yb in iter_tensors(X, y, batch_size=64):
	pass

CPU times: user 221 ms, sys: 13.8 ms, total: 235 ms
Wall time: 270 ms


## With shuffling

In [10]:
def iter_tensors_with_shuffle(*tensors, batch_size, shuffle=False):
  """Yield mini-batches sliced from several tensors, optionally shuffled.

  Args:
    tensors: tensors sharing the size of their first dimension; all of them
      are assumed to live on the device of the first tensor.
    batch_size: number of rows per batch (the last batch may be smaller).
    shuffle: when True the rows are visited in a fresh random permutation,
      otherwise (default) in their original order.

  Yields:
    A tuple holding, for every input tensor, the rows of the current batch.
  """
  device = tensors[0].device  # we assume all tensors are on the same device
  n = tensors[0].size(0)
  # The two branches used to be swapped: shuffle=True produced the sequential
  # order and the default produced a random permutation.
  if shuffle:
    idxs = torch.randperm(n, device=device)
  else:
    idxs = torch.arange(n, device=device)
  for batch_idxs in idxs.split(batch_size):
    yield tuple(x[batch_idxs, ...] for x in tensors)

In [11]:
%%time
# Time one full pass with shuffle=True requested; the loop body is empty so
# only index generation and batch slicing are measured.
for Xb, yb in iter_tensors_with_shuffle(X, y, batch_size=64, shuffle=True):
	pass

CPU times: user 213 ms, sys: 0 ns, total: 213 ms
Wall time: 212 ms


## With BatchIter

In [12]:
class BatchIter:
    """Sized iterable over mini-batches of tensors sharing their first dimension.

    The visiting order is decided once, in the constructor, so iterating the
    same instance again replays exactly the same batches.
    """

    def __init__(self, *tensors, batch_size, shuffle=True):
        """
        tensors: feature tensors (each with shape: num_samples x *)
        batch_size: int
        shuffle: bool (default: True) whether to iterate over randomly shuffled samples.
        """
        self.tensors = tensors

        lead = tensors[0]  # device and length are taken from the first tensor
        make_order = torch.randperm if shuffle else torch.arange
        order = make_order(lead.size(0), device=lead.device)

        # tuple of index chunks, one per mini-batch
        self.idxs = order.split(batch_size)

    def __len__(self):
        """Number of mini-batches."""
        return len(self.idxs)

    def __iter__(self):
        """Yield one tuple per mini-batch with the selected rows of every tensor."""
        sources = self.tensors
        for chunk in self.idxs:
            yield tuple(t[chunk, ...] for t in sources)

In [13]:
from tqdm.auto import tqdm

In [14]:
%%time
# BatchIter defines __len__, which lets tqdm show a total (see the progress
# bar below); the loop body is empty so only iteration cost is timed.
for Xb, yb in tqdm(BatchIter(X, y, batch_size=64, shuffle=True)):
  pass

  0%|          | 0/7813 [00:00<?, ?it/s]

CPU times: user 309 ms, sys: 28.2 ms, total: 337 ms
Wall time: 425 ms


# With grouping

In [15]:
def lexsort(*keys, dim=-1):
    """Indirect stable sort by several keys along `dim`.

    The LAST key is the most significant one; earlier keys only break ties
    among equal values of the later keys.

    Args:
      keys: one or more tensors of identical shape.
      dim: dimension along which to sort (default: last).

    Returns:
      An index tensor that orders each key lexicographically when used to
      gather along `dim`.

    Raises:
      ValueError: if no key is given.
    """
    if not keys:
        raise ValueError(f"Must have at least 1 key, but {len(keys)=}.")

    least_significant, *more_significant = keys
    perm = least_significant.argsort(dim=dim, stable=True)
    for key in more_significant:
        # Stable-sort the already permuted key, then compose the permutations;
        # stability preserves the order established by the earlier keys.
        rank = key.gather(dim, perm).argsort(dim=dim, stable=True)
        perm = perm.gather(dim, rank)

    return perm

In [16]:
# lexsort treats its LAST key as the primary one: rows end up ordered by
# `first`, with ties broken by `second`.
first = torch.tensor([5, 3, 5, 3, 5, 5, 3])
second = torch.tensor([4, 1, 1, 3, 3, 2, 2])
order = lexsort(second, first)
print(first[order], second[order])

tensor([3, 3, 3, 5, 5, 5, 5]) tensor([1, 2, 3, 1, 2, 3, 4])


In [17]:
def view_as_bytes(x):
  element_bytes = x.dtype.itemsize
  bytes_tensor = x.view(torch.uint8).view(x.shape + (element_bytes,))
  return bytes_tensor.unbind(dim=-1)

In [18]:
# Powers of 4 stored as int16: every element splits into two uint8 tensors.
view_as_bytes(4 ** torch.arange(8, dtype=torch.int16))

(tensor([ 1,  4, 16, 64,  0,  0,  0,  0], dtype=torch.uint8),
 tensor([ 0,  0,  0,  0,  1,  4, 16, 64], dtype=torch.uint8))

In [19]:
def fnv_hash(tensor):
    """
    Element-wise FNV hash of a PyTorch tensor of integers.

    Every element is hashed independently over its own bytes (as produced by
    `view_as_bytes`): starting from the FNV offset basis, each step multiplies
    the running value by the FNV prime and then XORs in the next byte.

    Args:
      tensor: torch.tensor the tensor for which we compute element-wise FNV hash
    Returns:
      A PyTorch tensor with one hash value per element of the input.
    """
    # 32-bit FNV constants
    prime = torch.tensor(0x01000193, dtype=torch.uint32)
    offset_basis = torch.tensor(0x811c9dc5, dtype=torch.uint32)

    # The accumulator starts at the offset basis (not zero); full_like gives it
    # the shape, dtype and device of the input.
    acc = torch.full_like(tensor, offset_basis)
    for octet in view_as_bytes(tensor):
        acc = (acc * prime) ^ octet

    return acc

In [20]:
# Order rows by the hash of (group_id + seed), ties broken by group_id itself:
# rows of one group stay adjacent while the order of the groups is scrambled.
group_id = torch.tensor([5, 5, 8, 8, 8, 8, 1, 1])
seed = 1
order = lexsort(group_id, fnv_hash(group_id + seed))
print(group_id[order])

tensor([5, 5, 1, 1, 8, 8, 8, 8])


In [21]:
# A different seed gives a different order of the groups (compare the output
# with the previous cell).
seed = 2
order = lexsort(group_id, fnv_hash(group_id + seed))
print(group_id[order])

tensor([1, 1, 8, 8, 8, 8, 5, 5])


In [22]:
def group_idx(group_id):
    """Boundaries of the runs of equal consecutive values in `group_id`.

    Args:
      group_id: 1-D tensor in which the members of a group are adjacent.

    Returns:
      A 1-D tensor `b` with `b[0] == 0`, where run `i` occupies
      `group_id[b[i]:b[i + 1]]`; for a non-empty input the last entry equals
      `len(group_id)`.
    """
    _, run_lengths = group_id.unique_consecutive(return_counts=True)
    run_ends = run_lengths.cumsum(dim=-1)
    # prepend a 0 so that the start of the first run is included as well
    return torch.nn.functional.pad(run_ends, (1, 0))

In [23]:
# Three runs (8, 1, 7) of lengths 3, 2 and 4; the printed boundaries are the
# row offsets at which each run starts, followed by the total length.
group_id = torch.tensor([8, 8, 8, 1, 1, 7, 7, 7, 7])
indices = group_idx(group_id)
print(indices)

tensor([0, 3, 5, 9])


In [24]:
def batch_endpoint_indices(group_idx, batch_size):
  """Row ranges of mini-batches that each span `batch_size` whole groups.

  Args:
    group_idx: 1-D tensor of group boundaries as produced by `group_idx()`:
      G + 1 non-decreasing offsets describing G groups.
    batch_size: number of groups per batch (the last batch may hold fewer).

  Returns:
    Two Python lists `(starts, ends)` of equal length ceil(G / batch_size);
    batch `i` covers rows `starts[i]:ends[i]`.

  Raises:
    ValueError: if `batch_size` is smaller than 1.
  """
  if batch_size < 1:
    raise ValueError(f"batch_size must be a positive integer, but {batch_size=}.")

  n_groups = len(group_idx) - 1
  if n_groups < 1:
    # no complete group boundary pair -> no batches
    return [], []

  # Pad with copies of the last boundary until the number of GROUPS (not the
  # number of boundaries) is a multiple of batch_size. Padding the boundary
  # count instead silently dropped the trailing partial batch and left
  # `starts` one element longer than `ends`.
  padding_size = -n_groups % batch_size
  if padding_size > 0:
    padding = group_idx[-1].expand(padding_size)
    group_idx = torch.cat((group_idx, padding), dim=-1)

  # every batch_size-th boundary opens a batch, the following one closes it
  start_points = group_idx[0:-1:batch_size]
  end_points = group_idx[batch_size::batch_size]

  # return them as lists, so we can iterate over them
  return start_points.tolist(), end_points.tolist()

In [25]:
# Two groups per batch: print the (start, end) row offsets of every batch.
from_idx, to_idx = batch_endpoint_indices(group_idx(group_id), batch_size=2)
for start, end in zip(from_idx, to_idx):
  print(start, end)

0 5
5 9


In [26]:
# Three groups per batch: all three groups fit into a single batch.
from_idx, to_idx = batch_endpoint_indices(group_idx(group_id), batch_size=3)
for start, end in zip(from_idx, to_idx):
  print(start, end)

0 9


In [27]:
class GroupBatchIter:
  """Sized iterable over mini-batches made of whole groups of samples.

  Args:
    group_id: 1-D integer tensor with the group of every sample. With
      shuffle=False equal ids must already be adjacent, because groups are
      detected as runs of equal consecutive values.
    tensors: further tensors with one row per sample.
    batch_size: number of groups (not samples) per mini-batch.
    shuffle: when True the groups are visited in a pseudo-random order obtained
      by sorting on a hash of `group_id + shuffle_seed`.
    shuffle_seed: integer that selects the pseudo-random group order.

  Iterating yields tuples `(group_id_batch, *tensor_batches)`. The order is
  fixed in the constructor, so every pass over one instance is identical.
  """

  def __init__(self, group_id, *tensors, batch_size=1, shuffle=True, shuffle_seed=42):
    self.group_id = group_id
    self.tensors = tensors

    if shuffle:
      # Sort primarily by the hash and secondarily by the id itself, so rows of
      # one group stay adjacent even if two ids share a hash value.
      # Uses the `shuffle_seed` argument; the code used to read a global `seed`.
      self.idxs = lexsort(group_id, fnv_hash(group_id + shuffle_seed))
    else:
      self.idxs = torch.arange(len(group_id), device=group_id.device)

    group_start_indices = group_idx(group_id[self.idxs])
    starts, ends = batch_endpoint_indices(group_start_indices, batch_size)

    # Keep only non-empty row ranges so that len(self) always equals the
    # number of batches that __iter__ yields.
    spans = [(start, end) for start, end in zip(starts, ends) if end > start]
    self.batch_start = [start for start, _ in spans]
    self.batch_end = [end for _, end in spans]

  def __len__(self):
    """Number of mini-batches produced by one pass."""
    return len(self.batch_start)

  def __iter__(self):
    """Yield `(group_id_batch, *tensor_batches)` for every mini-batch."""
    # we create mini-batches containing both group-id, and the additional
    # tensors
    tensors = (self.group_id,) + self.tensors

    # iterate over batch endpoints, and yield tensors
    for start, end in zip(self.batch_start, self.batch_end):
      batch_idxs = self.idxs[start:end]
      yield tuple(x[batch_idxs, ...] for x in tensors)

In [28]:
import pandas as pd

# Toy data set: nine samples in three groups (8, 1, 7), three features per
# sample and alternating 0/1 labels, shown as a table for reference.
group_id = torch.tensor([8, 8, 8, 1, 1, 7, 7, 7, 7])
features = torch.arange(len(group_id) * 3).reshape(len(group_id), 3)
labels = torch.arange(len(group_id)) % 2

print(pd.DataFrame.from_dict({
    'group_id': group_id.tolist(),
    'features': features.tolist(),
    'labels': labels.tolist()
}))

   group_id      features  labels
0         8     [0, 1, 2]       0
1         8     [3, 4, 5]       1
2         8     [6, 7, 8]       0
3         1   [9, 10, 11]       1
4         1  [12, 13, 14]       0
5         7  [15, 16, 17]       1
6         7  [18, 19, 20]       0
7         7  [21, 22, 23]       1
8         7  [24, 25, 26]       0


In [29]:
# Print every mini-batch as a table. batch_size counts groups, so a batch
# holds up to two whole groups; the group ids come first in each yielded tuple.
for gb, Xb, yb in GroupBatchIter(group_id, features, labels, batch_size=2, shuffle=True):
  print(pd.DataFrame.from_dict({
    'group_id': gb.tolist(),
    'features': Xb.tolist(),
    'labels': yb.tolist()
}))

   group_id      features  labels
0         1   [9, 10, 11]       1
1         1  [12, 13, 14]       0
2         8     [0, 1, 2]       0
3         8     [3, 4, 5]       1
4         8     [6, 7, 8]       0
   group_id      features  labels
0         7  [15, 16, 17]       1
1         7  [18, 19, 20]       0
2         7  [21, 22, 23]       1
3         7  [24, 25, 26]       0


In [30]:
# Draw a group for every sample uniformly at random (about 8 samples per group
# on average) and sort, so that members of a group are adjacent.
# No `device=` is passed, so the ids are created on the default device.
n_groups = n_samples // 8
group_id, _ = torch.sort(
    torch.multinomial(torch.ones(n_groups) / n_groups, n_samples, replacement=True)
)
print(group_id[:50])

tensor([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5,
        6, 6])


In [31]:
%%time
# Time one full pass over the large data set in batches of 64 groups; the
# constructor (sorting, boundary computation) is included in the measurement.
for gb, Xb, yb in GroupBatchIter(group_id, X, y, batch_size=64, shuffle=True):
  pass

CPU times: user 180 ms, sys: 25 ms, total: 205 ms
Wall time: 206 ms
