In [None]:
# | default_exp utils/pipeline_parallelism

In [None]:
# | export


from collections import defaultdict

import torch
from torch import nn

# Utils

In [None]:
# | export


def get_device(device: torch.device | str) -> torch.device:
    """Convert to torch.device object."""
    if isinstance(device, str):
        return torch.device(device)
    return device

In [None]:
# | export


def move_to_device(data, device):
    """
    Recursively move every tensor found in ``data`` to the specified device.

    Lists, tuples (including namedtuples), sets and dicts are traversed and rebuilt; tensors are moved with
    ``Tensor.to``; any other object is returned unchanged.

    Args:
        data: The data to move. May be a tensor, an arbitrarily nested list/tuple/set/dict of tensors, or any
            other object.
        device: The device to move the data to, as a ``torch.device`` or a device string.

    Returns:
        The data moved to the specified device. Lists, tuples and sets are rebuilt with their original type;
        dicts are returned as a plain ``dict`` with the same keys (keys themselves are not moved).
    """
    device = get_device(device)
    if isinstance(data, tuple) and hasattr(type(data), "_make"):
        # namedtuple classes take their fields as positional arguments, so passing a single generator to the
        # constructor would raise TypeError; `_make` builds an instance from an iterable instead.
        return type(data)._make(move_to_device(d, device) for d in data)
    elif isinstance(data, (list, tuple, set)):
        return type(data)(move_to_device(d, device) for d in data)
    elif isinstance(data, dict):
        return {k: move_to_device(v, device) for k, v in data.items()}
    elif isinstance(data, torch.Tensor):
        return data.to(device)
    else:
        # Not a container and not a tensor: nothing to move.
        return data

In [None]:
# Build a nested structure (dict -> list -> tuple -> set) of tensors so that each container branch of
# move_to_device is exercised.
test = {"a": [torch.randn(1), (torch.randn(1), {torch.randn(1)})]}
display(test)

# Move every tensor in the structure to "cuda:0" and show the result (requires a CUDA device).
test = move_to_device(test, "cuda:0")
display(test)

[1m{[0m[32m'a'[0m: [1m[[0m[1;35mtensor[0m[1m([0m[1m[[0m[1;36m0.7945[0m[1m][0m[1m)[0m, [1m([0m[1;35mtensor[0m[1m([0m[1m[[0m[1;36m0.5051[0m[1m][0m[1m)[0m, [1m{[0m[1;35mtensor[0m[1m([0m[1m[[0m[1;36m2.3460[0m[1m][0m[1m)[0m[1m}[0m[1m)[0m[1m][0m[1m}[0m


[1m{[0m
    [32m'a'[0m: [1m[[0m
        [1;35mtensor[0m[1m([0m[1m[[0m[1;36m0.7945[0m[1m][0m, [33mdevice[0m=[32m'cuda:0'[0m[1m)[0m,
        [1m([0m[1;35mtensor[0m[1m([0m[1m[[0m[1;36m0.5051[0m[1m][0m, [33mdevice[0m=[32m'cuda:0'[0m[1m)[0m, [1m{[0m[1;35mtensor[0m[1m([0m[1m[[0m[1;36m2.3460[0m[1m][0m, [33mdevice[0m=[32m'cuda:0'[0m[1m)[0m[1m}[0m[1m)[0m
    [1m][0m
[1m}[0m

# Parallelize

In [None]:
# | export


class PipelineModule(nn.Module):
    """
    Wrapper that runs a module on one device and delivers its output on another.

    On every call the positional and keyword arguments are moved to ``processing_device``, the wrapped module
    is invoked, and the result is moved to ``output_device``.
    """

    def __init__(
        self, module: nn.Module, processing_device: torch.device | str, output_device: torch.device | str | None = None
    ):
        """
        Args:
            module: The module to wrap; it is moved to ``processing_device``.
            processing_device: The device on which the wrapped module runs.
            output_device: The device on which the output is returned. Falls back to ``processing_device``
                when ``None``.
        """
        super().__init__()
        self.processing_device = get_device(processing_device)
        # Without an explicit output device, results stay where they were computed.
        self.output_device = get_device(processing_device if output_device is None else output_device)
        self.module = module.to(processing_device)

    def forward(self, *args, **kwargs):
        """Move the inputs to the processing device, run the module, and move the result to the output device."""
        local_args = move_to_device(args, self.processing_device)
        local_kwargs = move_to_device(kwargs, self.processing_device)
        result = self.module(*local_args, **local_kwargs)
        return move_to_device(result, self.output_device)

    def extra_repr(self):
        """Include both devices in the module's printed representation."""
        return f"processing_device={self.processing_device}, output_device={self.output_device}"

In [None]:
class MyModule(nn.Module):
    """Minimal example module consisting of a single 10 -> 10 linear layer."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(in_features=10, out_features=10)

    def forward(self, x):
        """Apply the linear layer to ``x``."""
        projected = self.linear(x)
        return projected


# Wrap the example module so it is processed on "cuda:0"; no output device is given, so the output stays on
# the processing device (requires a CUDA device).
module = MyModule()
test = PipelineModule(module, "cuda:0")
display(test)

# The input is created without an explicit device; PipelineModule.forward moves it to the processing device
# before calling the wrapped module.
sample_input = torch.randn(1, 10)
output = test(sample_input)
print(f"Output device: {output.device}")


[1;35mPipelineModule[0m[1m([0m
  [33mprocessing_device[0m=[35mcu[0m[1;92mda[0m[1;92m:0[0m, [33moutput_device[0m=[35mcu[0m[1;92mda[0m[1;92m:0[0m
  [1m([0mmodule[1m)[0m: [1;35mMyModule[0m[1m([0m
    [1m([0mlinear[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m10[0m, [33mout_features[0m=[1;36m10[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
  [1m)[0m
[1m)[0m

Output device: cuda:0


In [None]:
# | export


def paralellize_pipeline(
    model: nn.Module,
    module_to_device: dict[
        str, torch.device | str | list[torch.device | str] | tuple[torch.device | str, torch.device | str]
    ],
) -> nn.Module:
    """
    Parallelize a model across multiple devices.

    Every module named in ``module_to_device`` is replaced, in place on its parent, by a ``PipelineModule``
    wrapping it.

    Args:
        model: The model to parallelize; its submodules are replaced in place. If ``model`` is itself a
            ``PipelineModule``, the module it wraps is the one that is processed and returned.
        module_to_device: A dictionary mapping module names to devices. Keys are module names with nested modules
            separated by dots (e.g., "module.submodule"). A value is either a single device, used as both the
            processing and the output device, or a 2-element list/tuple ``(processing_device, output_device)``.
            When a module and one of its descendants are both listed, the descendant keeps its own processing
            device, but its output device is overridden with the processing device of its nearest listed
            ancestor so that the two can work together. This dictionary is not modified.

    Returns:
        The parallelized pipeline.

    Raises:
        ValueError: If a named module is not found on its parent, or if a list/tuple value does not contain
            exactly two devices.
    """
    # Do not allow parent pipeline parallelism to affect child pipeline parallelism
    if isinstance(model, PipelineModule):
        model = model.module

    # Normalise every value to a private [processing_device, output_device] list. Working on copies keeps the
    # caller's dictionary (and any lists inside it) untouched by the output-device overrides below.
    device_plan: dict[str, list[torch.device | str]] = {}
    for name, devices in module_to_device.items():
        if isinstance(devices, (list, tuple)):
            if len(devices) != 2:
                raise ValueError(
                    f"Expected (processing_device, output_device) for module {name}, got {len(devices)} values."
                )
            device_plan[name] = list(devices)
        else:
            device_plan[name] = [devices, devices]

    children_submodules_to_device: dict[str, dict[str, list[torch.device | str]]] = defaultdict(dict)
    # Sorted order guarantees that a module is visited before any of its dotted descendants, so the override
    # applied below is already in place when the descendant's entry is read.
    for module_name in sorted(device_plan):
        processing_device, output_device = device_plan[module_name]

        # Get the direct child's name and ensure it is present in the model
        child_name, separator, remainder = module_name.partition(".")
        if not hasattr(model, child_name):
            raise ValueError(f"Module {child_name} not found in the model.")

        if not separator:
            # The name refers to a direct child: replace it with a PipelineModule
            module = getattr(model, child_name)
            setattr(model, child_name, PipelineModule(module, processing_device, output_device))

            # Descendants listed in the plan must hand their output back to the device this module runs on
            prefix = child_name + "."
            for key, device_pair in device_plan.items():
                if key.startswith(prefix):
                    device_pair[1] = processing_device
        else:
            # The name refers to a deeper module: defer it to the recursive call on the child
            children_submodules_to_device[child_name][remainder] = [processing_device, output_device]

    # Recurse into the children and assign their submodules to the appropriate devices
    for child_name, submodules_to_device in children_submodules_to_device.items():
        paralellize_pipeline(getattr(model, child_name), submodules_to_device)

    return model

In [None]:
class MyModule1(nn.Module):
    """Example module that adds the output of a nested ``MyModule`` to that of its own linear layer."""

    def __init__(self):
        super().__init__()
        self.module = MyModule()
        self.linear = nn.Linear(in_features=10, out_features=10)

    def forward(self, x):
        """Return ``module(x) + linear(x)``."""
        nested_out = self.module(x)
        linear_out = self.linear(x)
        return nested_out + linear_out


class MyModule2(nn.Module):
    """Example module that sums two ``MyModule1`` branches and its own linear layer."""

    def __init__(self):
        super().__init__()
        self.module11 = MyModule1()
        self.module12 = MyModule1()
        self.linear = nn.Linear(in_features=10, out_features=10)

    def forward(self, x):
        """Return ``module11(x) + module12(x) + linear(x)``."""
        branches_out = self.module11(x) + self.module12(x)
        return branches_out + self.linear(x)


class MyModule3(nn.Module):
    """Top-level example module that adds a ``MyModule2`` branch to its own linear layer."""

    def __init__(self):
        super().__init__()
        self.module2 = MyModule2()
        self.linear = nn.Linear(in_features=10, out_features=10)

    def forward(self, x):
        """Return ``module2(x) + linear(x)``."""
        nested_out = self.module2(x)
        linear_out = self.linear(x)
        return nested_out + linear_out


# Start with the whole example model on "cuda:0" (this cell refers to cuda:0 through cuda:3).
test = MyModule3().to("cuda:0")
display(test)

# Keys are dotted submodule paths. A plain device is used for both processing and output; a two-element
# list is [processing_device, output_device].
test = paralellize_pipeline(
    test,
    {
        "module2.module11": "cuda:1",
        "module2.module12": "cuda:3",
        "module2.module11.module": "cuda:2",
        "module2": ["cuda:3", "cuda:0"],
    },
)
display(test)

# Run a forward pass with an input on "cuda:0" and report where the final output ends up.
sample_input = torch.randn(1, 10).to("cuda:0")
output = test(sample_input)
print(f"Output device: {output.device}")


[1;35mMyModule3[0m[1m([0m
  [1m([0mmodule2[1m)[0m: [1;35mMyModule2[0m[1m([0m
    [1m([0mmodule11[1m)[0m: [1;35mMyModule1[0m[1m([0m
      [1m([0mmodule[1m)[0m: [1;35mMyModule[0m[1m([0m
        [1m([0mlinear[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m10[0m, [33mout_features[0m=[1;36m10[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
      [1m)[0m
      [1m([0mlinear[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m10[0m, [33mout_features[0m=[1;36m10[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m)[0m
    [1m([0mmodule12[1m)[0m: [1;35mMyModule1[0m[1m([0m
      [1m([0mmodule[1m)[0m: [1;35mMyModule[0m[1m([0m
        [1m([0mlinear[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m10[0m, [33mout_features[0m=[1;36m10[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
      [1m)[0m
      [1m([0mlinear[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m10[0m, [33mout_f


[1;35mMyModule3[0m[1m([0m
  [1m([0mmodule2[1m)[0m: [1;35mPipelineModule[0m[1m([0m
    [33mprocessing_device[0m=[35mcu[0m[1;92mda[0m[1;92m:3[0m, [33moutput_device[0m=[35mcu[0m[1;92mda[0m[1;92m:0[0m
    [1m([0mmodule[1m)[0m: [1;35mMyModule2[0m[1m([0m
      [1m([0mmodule11[1m)[0m: [1;35mPipelineModule[0m[1m([0m
        [33mprocessing_device[0m=[35mcu[0m[1;92mda[0m[1;92m:1[0m, [33moutput_device[0m=[35mcu[0m[1;92mda[0m[1;92m:3[0m
        [1m([0mmodule[1m)[0m: [1;35mMyModule1[0m[1m([0m
          [1m([0mmodule[1m)[0m: [1;35mPipelineModule[0m[1m([0m
            [33mprocessing_device[0m=[35mcu[0m[1;92mda[0m[1;92m:2[0m, [33moutput_device[0m=[35mcu[0m[1;92mda[0m[1;92m:1[0m
            [1m([0mmodule[1m)[0m: [1;35mMyModule[0m[1m([0m
              [1m([0mlinear[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m10[0m, [33mout_features[0m=[1;36m10[0m, [33mbias[0m=[3;92mTrue[0m

Output device: cuda:0


# nbdev

In [None]:
!nbdev_export