In [1]:
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union

import torch

def split_batch(
    batch: Union[torch.Tensor, Iterable[torch.Tensor]],
    num_micro_batches: int,
) -> List[Union[torch.Tensor, Tuple[torch.Tensor, ...]]]:
    """Split a batch into a list of micro-batches along dimension 0.

    Args:
        batch: Either a single tensor, or an iterable of tensors that are
            each chunked along dimension 0 in the same way.
        num_micro_batches: Number of chunks requested from
            ``torch.Tensor.chunk``. The tensor may be split into fewer
            chunks than requested when dimension 0 is too small.

    Returns:
        For a single tensor, a list of its chunks. For an iterable of
        tensors, a list with one tuple per micro-batch; element ``j`` of
        each tuple is the corresponding chunk of the ``j``-th input tensor.

    Raises:
        ValueError: If the tensors of a multi-tensor batch split into
            different numbers of chunks, so they cannot be paired up
            without silently dropping data.
    """
    if isinstance(batch, torch.Tensor):
        return list(batch.chunk(num_micro_batches))

    # One tuple of chunks per input tensor.
    per_tensor_chunks = [tensor.chunk(num_micro_batches) for tensor in batch]

    # zip() stops at the shortest input, so unequal chunk counts would
    # silently discard the trailing chunks of the longer tensors.
    chunk_counts = {len(chunks) for chunks in per_tensor_chunks}
    if len(chunk_counts) > 1:
        raise ValueError(
            "tensors in batch split into different numbers of chunks: "
            f"{sorted(chunk_counts)}"
        )

    # Transpose from per-tensor chunks to per-micro-batch tuples.
    return list(zip(*per_tensor_chunks))

In [25]:
# Flatten one random 28x28 sample into a single row of 784 features, then
# request four chunks along dim 0.
# NOTE(review): assigning to `split_batch` rebinds the name of the
# `split_batch` function defined earlier in this notebook.
num_micro_batches = 4
batch = torch.rand(1, 28, 28)[0].view(-1, 28 * 28)
split_batch = torch.chunk(batch, num_micro_batches)

In [28]:
# Display the shape of the first chunk held in `split_batch`.
split_batch[0].shape

torch.Size([1, 784])

In [3]:
# Build a batch made of two separate tensors, each holding the integers
# 0..11 arranged as (4, 1, 3), and print the shape of each one.
batch = [torch.arange(12).view(4, 1, 3) for _ in range(2)]
for tensor in batch:
    print(tensor.size())

torch.Size([4, 1, 3])
torch.Size([4, 1, 3])


In [5]:
# Chunk every tensor of `batch` into two pieces along dim 0. The result
# holds one tuple of chunks per input tensor, in input order.
num_micro_batches = 2
split_batch = []
for tensor in batch:
    split_tensor = torch.chunk(tensor, num_micro_batches)
    split_batch += [split_tensor]

In [10]:
# Nesting sizes, outermost first: entries in `split_batch` (one per input
# tensor), chunks of the first tensor, and dim-0 length of its first chunk.
print(
    len(split_batch),
    len(split_batch[0]),
    len(split_batch[0][0]),
    sep="\n",
)
# Full nested structure for visual inspection.
print(split_batch)

2
2
2
[(tensor([[[0, 1, 2]],

        [[3, 4, 5]]]), tensor([[[ 6,  7,  8]],

        [[ 9, 10, 11]]])), (tensor([[[0, 1, 2]],

        [[3, 4, 5]]]), tensor([[[ 6,  7,  8]],

        [[ 9, 10, 11]]]))]


In [18]:
import os
import random
import numpy as np

import torch

def init_seed(seed):
    """Seed the `random`, NumPy and PyTorch global random number generators.

    Args:
        seed: Value passed unchanged to ``random.seed``, ``np.random.seed``
            and ``torch.manual_seed``, in that order.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

def build_model():
    """Build a small feed-forward network: Linear(784 -> 28), ReLU, Linear(28 -> 10).

    Returns:
        A ``torch.nn.Sequential`` holding the three modules in that order.
    """
    layers = [
        torch.nn.Linear(28 * 28, 28),
        torch.nn.ReLU(),
        torch.nn.Linear(28, 10),
    ]
    return torch.nn.Sequential(*layers)

In [19]:
# Instantiate the network and print how many modules the Sequential holds.
model = build_model()
print(len(model))

3


[(tensor([[[0, 1, 2]],
  
          [[3, 4, 5]]]),
  tensor([[[0, 1, 2]],
  
          [[3, 4, 5]]])),
 (tensor([[[ 6,  7,  8]],
  
          [[ 9, 10, 11]]]),
  tensor([[[ 6,  7,  8]],
  
          [[ 9, 10, 11]]]))]