In [1]:
import torch
from typing import List, Tuple, Callable, Any, Dict

def split_batch(batch: torch.Tensor, num_micro_batches: int) -> List[torch.Tensor]:
    """Cut a batch into microbatches using ``Tensor.chunk``.

    Args:
        batch: A single tensor. Although the annotation only mentions
            ``torch.Tensor``, any non-tensor argument is treated as an
            iterable of tensors (e.g. ``[inputs, targets]``).
        num_micro_batches: Number of chunks requested from ``chunk``; the
            result may be shorter when the chunked dimension is too small.

    Returns:
        For a tensor input, a list of tensor chunks. For an iterable of
        tensors, a list of tuples where tuple ``i`` holds the ``i``-th chunk
        of every tensor, in input order.
    """
    if not isinstance(batch, torch.Tensor):
        # Chunk every tensor separately, then transpose the grouping from
        # "per tensor" to "per microbatch".
        chunks_per_tensor = [item.chunk(num_micro_batches) for item in batch]
        return list(zip(*chunks_per_tensor))
    return list(batch.chunk(num_micro_batches))

In [25]:
# Scratch check of Tensor.chunk on one flattened 28x28 sample.
# NOTE: assigning to `split_batch` below rebinds the name of the
# split_batch() function defined in the first cell; re-run that cell before
# calling the function again.
num_micro_batches = 4
batch = torch.rand(1, 28, 28)
batch = batch[0].view(-1, 28*28)  # flattened to shape (1, 784)
# Only one row is available to chunk, so fewer than num_micro_batches
# pieces come back.
split_batch = batch.chunk(num_micro_batches)

In [28]:
# Shape of the first chunk; here `split_batch` is the tuple returned by chunk().
split_batch[0].shape

torch.Size([1, 784])

In [3]:
# Build a list-style batch: two identical integer tensors of shape (4, 1, 3),
# used to try out the list-of-tensors splitting path.
batch = [torch.arange(12).reshape(4,1,3) for _ in range(2)]
for tensor in batch:
    print(tensor.shape)

torch.Size([4, 1, 3])
torch.Size([4, 1, 3])


In [5]:
# Inline copy of the list branch of split_batch(), without the final
# zip(*...) step: the result stays grouped per tensor, not per microbatch.
# NOTE: this rebinds the name `split_batch`, hiding the function of the same name.
num_micro_batches = 2
split_batch = []
for tensor in batch:
    split_tensor = tensor.chunk(num_micro_batches)
    split_batch.append(split_tensor)

In [10]:
print(len(split_batch))        # number of tensors in the batch list
print(len(split_batch[0]))     # number of chunks produced for the first tensor
print(len(split_batch[0][0]))  # size of dim 0 of the first chunk
print(split_batch)

2
2
2
[(tensor([[[0, 1, 2]],

        [[3, 4, 5]]]), tensor([[[ 6,  7,  8]],

        [[ 9, 10, 11]]])), (tensor([[[0, 1, 2]],

        [[3, 4, 5]]]), tensor([[[ 6,  7,  8]],

        [[ 9, 10, 11]]]))]


In [18]:
import os
import random
import numpy as np

import torch

def init_seed(seed):
    """Seed the Python, NumPy and PyTorch global random number generators.

    Args:
        seed: Value forwarded unchanged to ``random.seed``,
            ``np.random.seed`` and ``torch.manual_seed``, in that order.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

def build_model():
    """Create a small two-layer MLP as an ``nn.Sequential``.

    Returns:
        ``Sequential(Linear(784, 28), ReLU(), Linear(28, 10))``. The layers
        are constructed in that order.
    """
    layers = [
        torch.nn.Linear(28 * 28, 28),
        torch.nn.ReLU(),
        torch.nn.Linear(28, 10),
    ]
    return torch.nn.Sequential(*layers)

In [19]:
# len() of an nn.Sequential counts its direct children: Linear, ReLU, Linear.
model = build_model()
print(len(model))

3


In [1]:
import torch
from torch import nn
from collections import OrderedDict

# nn.Sequential built from an OrderedDict: the keys ('conv1', 'relu1', ...)
# become the child module names instead of the default '0', '1', ...
# NOTE: rebinds `model`, replacing whatever an earlier cell stored there.
model = nn.Sequential(OrderedDict([
          ('conv1', nn.Conv2d(1,20,5)),
          ('relu1', nn.ReLU()),
          ('conv2', nn.Conv2d(20,64,5)),
          ('relu2', nn.ReLU())
        ]))

In [2]:
# Number of direct children of the Sequential stored in `model`.
len(model)

4

In [8]:
from typing import Any, Callable, Dict, List, Optional, Tuple, cast
from benchmarks.unet_memory.unet import unet


# Build a U-Net via the project-local `unet` factory.
# NOTE(review): B is passed as num_convs and C as base_channels; the exact
# meaning of those arguments is defined in benchmarks.unet_memory.unet — confirm there.
B, C = 11, 128
model = unet(depth=5, num_convs=B, base_channels=C,
             input_channels=3, output_channels=1)
# typing.cast is a no-op at runtime; it only tells type checkers to treat
# the result as nn.Sequential.
model = cast(nn.Sequential, model)

In [10]:
# Number of direct children of `model` as currently bound in the kernel.
len(model)

505

In [18]:
def get_model_size(model: nn.Module):
    """Print and return the memory taken by a module's parameters and buffers.

    Args:
        model: Module whose ``parameters()`` and ``buffers()`` are measured.

    Returns:
        Total size as a float, computed as bytes / 1024**2. The same value
        is also printed with three decimals.
    """
    def _total_bytes(tensors):
        # element count times bytes per element, summed over all tensors
        return sum(t.nelement() * t.element_size() for t in tensors)

    total_bytes = _total_bytes(model.parameters()) + _total_bytes(model.buffers())
    size_all_mb = total_bytes / 1024**2
    print('model size: {:.3f}MB'.format(size_all_mb))
    return size_all_mb

# The function both prints and returns the size, so as the last expression
# of the cell the value is shown twice (printed line plus cell result).
get_model_size(model)

model size: 8445.107MB


8445.107177734375

In [3]:
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms

# Rebinds `model` to a torchvision ResNet-18 with a 10-class output layer.
# NOTE(review): MNIST is imported above but its images are single-channel,
# while this model's first conv takes 3 input channels — confirm the intended preprocessing.
model = torchvision.models.resnet18(num_classes=10)

In [18]:
from collections import OrderedDict
from typing import Iterator, Tuple

from torch import nn
from typing import Iterator, Tuple, List, Dict, Any, Callable, Optional


def flatten_sequential(module: nn.Sequential) -> nn.Sequential:
    """Return a new, single-level ``nn.Sequential`` from a nested one.

    Nested ``nn.Sequential`` children are expanded in order; other modules
    are kept as they are. Child names come from ``_flatten_sequential``.

    Raises:
        TypeError: If ``module`` is not an ``nn.Sequential``.
    """
    if isinstance(module, nn.Sequential):
        flat_children = OrderedDict(_flatten_sequential(module))
        return nn.Sequential(flat_children)
    raise TypeError('not sequential')


def _flatten_sequential(module: nn.Sequential) -> Iterator[Tuple[str, nn.Module]]:
    for name, child in module.named_children():
        # flatten_sequential child sequential layers only.
        if isinstance(child, nn.Sequential):
            for sub_name, sub_child in _flatten_sequential(child):
                yield (f'{name}_{sub_name}', sub_child)
        else:
            yield (name, child)

def split_module(
      module: torch.nn.Module, 
      partition_sizes: List[int], 
    ) -> nn.ModuleList:
    """Group a module's direct children into consecutive partitions.

    Args:
        module: Module whose ``named_children()`` are distributed, in order.
        partition_sizes: Number of children per partition. Every entry must
            be positive and the entries must sum to the number of children.

    Returns:
        An ``nn.ModuleList`` with one ``nn.Sequential`` per entry of
        ``partition_sizes``; children keep their original names.

    Raises:
        ValueError: If a partition size is not positive, or if the sizes do
            not add up to the number of children. Without this check extra
            children caused an ``IndexError`` and missing or unreachable
            children were silently dropped.
    """
    print('-' * 80)
    print('Splitting the module in MyGPipe...')

    children = list(module.named_children())

    if any(size <= 0 for size in partition_sizes):
        raise ValueError(
            f'partition sizes must all be positive, got {list(partition_sizes)}')
    if sum(partition_sizes) != len(children):
        raise ValueError(
            f'module has {len(children)} children but partition sizes '
            f'{list(partition_sizes)} sum to {sum(partition_sizes)}')

    partitions = []
    start = 0
    for size in partition_sizes:
        # Each partition gets its own OrderedDict so no state is shared.
        partitions.append(nn.Sequential(OrderedDict(children[start:start + size])))
        start += size

    return torch.nn.ModuleList(partitions)


In [22]:
# Wrap the current model in a Sequential and flatten it.
# NOTE: this cell is not idempotent. It reassigns `model` from itself, so
# every re-run wraps the previous result again and each child name gains
# another '0_' prefix (the saved output shows '0_0_0_0').
# Only nn.Sequential children are expanded by flatten_sequential, so a
# non-Sequential model stays a single child.
model = flatten_sequential(nn.Sequential(model))
# partitions = split_module(model, [2, 2, 3])
for name, layer in model.named_children():
    print(name, layer)


0_0_0_0 ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=

In [21]:
# NOTE: in the cells shown, the only assignment to `partitions` is the
# commented-out split_module call, so on a fresh kernel this raises NameError.
partitions

ModuleList()

In [15]:
# Number of direct children of `model`.
len(model)

1