In [1]:
import torch
import torch.nn as nn
# Shared element-wise MSE criterion. reduction="none" keeps one loss value per
# element, so each model below averages over the feature dimension itself.
loss_fn = nn.MSELoss(reduction="none")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class SingleInputSingleOutputModel(nn.Module):
    """Toy module: one input tensor in, one per-row loss tensor out.

    The loss is computed against freshly sampled Gaussian noise, so the values
    are meaningless; only the shapes flowing through the module matter.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(in_features=1024, out_features=512)

    def forward(self, x, random_arg=123):
        """Print the received shape and return the mean squared error per row.

        Args:
            x: Tensor whose last dimension is 1024. The notebook passes 2-D
                ``(batch, 1024)`` inputs.
            random_arg: Only echoed in the printed message.

        Returns:
            The element-wise loss from the module-level ``loss_fn`` averaged
            over dimension 1, i.e. one value per row for 2-D input.
        """
        print(f"x_shape={x.shape} random_arg={random_arg}")
        projected = self.layer1(x)
        # Random target with the same shape/dtype/device as the projection.
        noise_target = torch.randn_like(projected)
        elementwise = loss_fn(projected, noise_target)
        return torch.mean(elementwise, dim=1)

In [3]:
class SingleInputMultiOutputModel(nn.Module):
    """Toy module: one input tensor in, a tuple of two per-row losses out.

    Both heads read the same input; each loss is measured against freshly
    sampled Gaussian noise, so only the shapes are of interest.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(in_features=1024, out_features=512)
        self.layer2 = nn.Linear(in_features=1024, out_features=256)

    def forward(self, x):
        """Print the received shape and return one per-row loss per head.

        Args:
            x: Tensor whose last dimension is 1024. The notebook passes 2-D
                ``(batch, 1024)`` inputs.

        Returns:
            A 2-tuple ``(loss1, loss2)``: the element-wise loss of each head
            from the module-level ``loss_fn``, averaged over dimension 1.
        """
        print(f"x_shape={x.shape}")
        per_row_losses = []
        # Heads are visited in layer1, layer2 order so the noise targets are
        # drawn in the same sequence as the head outputs appear in the result.
        for head in (self.layer1, self.layer2):
            prediction = head(x)
            noise_target = torch.randn_like(prediction)
            per_row_losses.append(loss_fn(prediction, noise_target).mean(dim=1))
        return tuple(per_row_losses)

In [4]:
class MultiInputMultiOutputModel(nn.Module):
    """Toy module: two input tensors in, a tuple of two per-row losses out.

    ``x1`` feeds ``layer1`` and ``x2`` feeds ``layer2``; each loss is measured
    against freshly sampled Gaussian noise, so only the shapes are of interest.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(in_features=1024, out_features=512)
        self.layer2 = nn.Linear(in_features=1024, out_features=256)

    def forward(self, x1, x2):
        """Print both received shapes and return one per-row loss per input.

        Args:
            x1: Tensor whose last dimension is 1024; routed through ``layer1``.
            x2: Tensor whose last dimension is 1024; routed through ``layer2``.

        Returns:
            A 2-tuple ``(loss1, loss2)``: the element-wise loss of each head
            from the module-level ``loss_fn``, averaged over dimension 1.
        """
        print(f"x1_shape={x1.shape} x2_shape={x2.shape}")
        per_row_losses = []
        # Pair each head with its own input; order matches the returned tuple.
        for head, features in ((self.layer1, x1), (self.layer2, x2)):
            prediction = head(features)
            noise_target = torch.randn_like(prediction)
            per_row_losses.append(loss_fn(prediction, noise_target).mean(dim=1))
        return tuple(per_row_losses)

In [11]:
def run(device_ids, model_class, inputs=1):
    """Wrap a fresh model in ``nn.DataParallel`` and print the shapes it produces.

    Args:
        device_ids: Sequence of CUDA device indices handed to
            ``nn.DataParallel``. The first entry is also the device the
            wrapped model and the input tensors are moved to.
        model_class: Zero-argument callable returning the ``nn.Module`` to wrap.
        inputs: Number of random ``(32, 1024)`` tensors passed positionally to
            the model.

    Returns:
        None. The wrapped model's repr and the output shape(s) are printed.
    """
    parallel_model = nn.DataParallel(model_class(), device_ids=device_ids)
    print(f"model: {parallel_model}")

    primary_device = torch.device(f"cuda:{device_ids[0]}")
    parallel_model.to(primary_device)

    # Each batch is sampled on the CPU and then copied to the primary device.
    batches = []
    for _ in range(inputs):
        batches.append(torch.randn(32, 1024).to(primary_device))
    result = parallel_model(*batches)

    if not isinstance(result, tuple):
        print(f"output_shape={result.shape}")
        return
    for index, tensor in enumerate(result):
        print(f"output{index}_shape={tensor.shape}")

In [6]:
# CUDA device indices handed to nn.DataParallel by the run(...) calls below.
# NOTE(review): indices 6 and 7 only exist on a host exposing at least 8 CUDA
# devices — adjust for the machine at hand.
device_ids = [6, 7]

In [7]:
# One input, one output. In the recorded run each of the two replicas printed
# a [16, 1024] slice and the gathered output had shape [32].
run(device_ids=device_ids, model_class=SingleInputSingleOutputModel)

model: DataParallel(
  (module): SingleInputSingleOutputModel(
    (layer1): Linear(in_features=1024, out_features=512, bias=True)
  )
)
x_shape=torch.Size([16, 1024]) random_arg=123
x_shape=torch.Size([16, 1024]) random_arg=123
output_shape=torch.Size([32])


In [8]:
# One input, two outputs. In the recorded run each replica printed a
# [16, 1024] slice and both gathered outputs had shape [32].
run(device_ids=device_ids, model_class=SingleInputMultiOutputModel)

model: DataParallel(
  (module): SingleInputMultiOutputModel(
    (layer1): Linear(in_features=1024, out_features=512, bias=True)
    (layer2): Linear(in_features=1024, out_features=256, bias=True)
  )
)
x_shape=torch.Size([16, 1024])
x_shape=torch.Size([16, 1024])
output0_shape=torch.Size([32])
output1_shape=torch.Size([32])


In [12]:
# Two inputs, two outputs. In the recorded run each replica printed a
# [16, 1024] slice of both inputs and both gathered outputs had shape [32].
run(device_ids=device_ids, model_class=MultiInputMultiOutputModel, inputs=2)

model: DataParallel(
  (module): MultiInputMultiOutputModel(
    (layer1): Linear(in_features=1024, out_features=512, bias=True)
    (layer2): Linear(in_features=1024, out_features=256, bias=True)
  )
)
x1_shape=torch.Size([16, 1024]) x2_shape=torch.Size([16, 1024])
x1_shape=torch.Size([16, 1024]) x2_shape=torch.Size([16, 1024])
output0_shape=torch.Size([32])
output1_shape=torch.Size([32])
