Copyright (c) MONAI Consortium  
Licensed under the Apache License, Version 2.0 (the "License");  
you may not use this file except in compliance with the License.  
You may obtain a copy of the License at  
    http://www.apache.org/licenses/LICENSE-2.0  
Unless required by applicable law or agreed to in writing, software  
distributed under the License is distributed on an "AS IS" BASIS,  
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
See the License for the specific language governing permissions and  
limitations under the License.

# Multi GPU Test

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Project-MONAI/tutorials/blob/main/acceleration/multi_gpu_test.ipynb)

## Setup environment

In [None]:
!python -c "import monai" || pip install -q "monai-weekly[ignite]"

## Setup imports

In [None]:
import torch
from monai.config import print_config
from monai.engines import create_multigpu_supervised_trainer
from monai.networks.nets import UNet

print_config()

MONAI version: 1.1.0+2.g97918e46
Numpy version: 1.22.2
Pytorch version: 1.13.0a0+d0d6b1f
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 97918e46e0d2700c050e678d72e3edb35afbd737
MONAI __file__: /workspace/monai/monai-in-dev/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: 0.4.10
Nibabel version: 4.0.2
scikit-image version: 0.19.3
Pillow version: 9.0.1
Tensorboard version: 2.10.1
gdown version: 4.6.0
TorchVision version: 0.14.0a0
tqdm version: 4.64.1
lmdb version: 1.3.0
psutil version: 5.9.2
pandas version: 1.4.4
einops version: 0.6.0
transformers version: 4.21.3
mlflow version: 2.0.1
pynrrd version: 1.0.0

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies



## Test GPUs

In [None]:
max_epochs = 2
lr = 1e-3
device = torch.device("cuda:0")
net = UNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
).to(device)


def fake_loss(y_pred, y):
    # Placeholder loss: any scalar that depends on y_pred is enough to drive backpropagation.
    return (y_pred[0] + y).sum()


def fake_data_stream():
    # Endless stream of random (image, label) batches, each of shape (10, 1, 64, 64).
    while True:
        yield torch.rand((10, 1, 64, 64)), torch.rand((10, 1, 64, 64))
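`fake_loss` works even though it indexes `y_pred[0]`: that selects only the first sample, of shape `(1, 64, 64)`, which broadcasts against the full target batch `y` of shape `(10, 1, 64, 64)` before `.sum()` reduces everything to a scalar. The sketch below reproduces the standard NumPy/PyTorch broadcasting rule for shapes in plain Python so the result can be checked without a GPU; `broadcast_shape` is a hypothetical helper, not a MONAI or PyTorch function.

```python
from itertools import zip_longest
from typing import Tuple


def broadcast_shape(a: Tuple[int, ...], b: Tuple[int, ...]) -> Tuple[int, ...]:
    """Result shape of a + b under NumPy/PyTorch broadcasting rules (hypothetical helper)."""
    result = []
    # Compare dimensions from the right; a missing leading dimension counts as size 1.
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and x != 1 and y != 1:
            raise ValueError(f"shapes {a} and {b} are not broadcastable")
        # A size-1 dimension stretches to match the other one.
        result.append(y if x == 1 else x)
    return tuple(reversed(result))


batch = (10, 1, 64, 64)   # shape of y_pred and y for one fake batch
first_pred = batch[1:]    # y_pred[0] drops the batch dimension: (1, 64, 64)
print(broadcast_shape(first_pred, batch))  # (10, 1, 64, 64)
```

So the "loss" is a sum over 10 × 64 × 64 elements; its value is meaningless, but it keeps the gradient path through the network intact.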

### 1 GPU

In [None]:
opt = torch.optim.Adam(net.parameters(), lr)
# Pass a single explicit device: the model is trained on cuda:0 only.
trainer = create_multigpu_supervised_trainer(net, opt, fake_loss, [device])
trainer.run(fake_data_stream(), max_epochs=max_epochs, epoch_length=2)

2023-01-20 14:16:00,009 - Engine run starting with max_epochs=2.
2023-01-20 14:16:00,189 - Epoch[1] Complete. Time taken: 00:00:00.180
2023-01-20 14:16:00,201 - Epoch[2] Complete. Time taken: 00:00:00.011
2023-01-20 14:16:00,201 - Engine run complete. Time taken: 00:00:00.192


State:
	iteration: 4
	epoch: 2
	epoch_length: 2
	max_epochs: 2
	output: 23339.560546875
	batch: <class 'tuple'>
	metrics: <class 'dict'>
	dataloader: <class 'generator'>
	seed: <class 'NoneType'>
	times: <class 'dict'>
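Because `fake_data_stream()` never ends, the run depends on `epoch_length` to decide when an epoch is over; the resulting `iteration: 4` in the state above is simply `max_epochs × epoch_length = 2 × 2`. The sketch below is a deliberately simplified, hypothetical loop (not Ignite's actual `Engine.run`, which also fires events and tracks state) that shows the same bookkeeping, assuming batches keep being drawn from one shared iterator across epochs.

```python
from itertools import count, islice
from typing import Iterator, List, Tuple


def fake_stream() -> Iterator[int]:
    # Stand-in for fake_data_stream(): yields batch indices forever.
    yield from count()


def run(stream: Iterator[int], max_epochs: int, epoch_length: int) -> Tuple[int, List[List[int]]]:
    """Hypothetical simplified training loop over an infinite iterator."""
    iteration = 0
    seen = []
    for _ in range(max_epochs):
        # An infinite iterator has no natural end, so each epoch stops after epoch_length batches.
        epoch_batches = list(islice(stream, epoch_length))
        iteration += len(epoch_batches)
        seen.append(epoch_batches)
    return iteration, seen


iterations, batches = run(fake_stream(), max_epochs=2, epoch_length=2)
print(iterations)  # 4
print(batches)     # [[0, 1], [2, 3]]
```

Without `epoch_length`, a loop like this would never leave the first epoch.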

### All GPUs

In [None]:
opt = torch.optim.Adam(net.parameters(), lr)
# devices=None: use all visible GPUs.
trainer = create_multigpu_supervised_trainer(net, opt, fake_loss, None)
trainer.run(fake_data_stream(), max_epochs=max_epochs, epoch_length=2)

2023-01-20 14:16:00,208 - Engine run starting with max_epochs=2.
2023-01-20 14:16:01,364 - Epoch[1] Complete. Time taken: 00:00:01.154
2023-01-20 14:16:01,391 - Epoch[2] Complete. Time taken: 00:00:00.026
2023-01-20 14:16:01,391 - Engine run complete. Time taken: 00:00:01.181


State:
	iteration: 4
	epoch: 2
	epoch_length: 2
	max_epochs: 2
	output: 22608.560546875
	batch: <class 'tuple'>
	metrics: <class 'dict'>
	dataloader: <class 'generator'>
	seed: <class 'NoneType'>
	times: <class 'dict'>
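When more than one GPU is used, the work is shared by splitting each batch along the batch dimension. Assuming the multi-device path wraps the network in `torch.nn.DataParallel` (worth confirming against the MONAI version you run), each 10-sample batch is cut into pieces the way `torch.chunk(batch, num_devices, dim=0)` cuts it: every piece has `ceil(10 / num_devices)` samples except possibly the last, so some device counts leave GPUs idle. The helper `chunk_sizes` below is a hypothetical plain-Python sketch of that arithmetic.

```python
import math
from typing import List


def chunk_sizes(batch_size: int, num_devices: int) -> List[int]:
    """Per-device sizes when a batch is split like torch.chunk(batch, num_devices, dim=0).

    Hypothetical helper; num_devices must be at least 1.
    """
    step = math.ceil(batch_size / num_devices)
    sizes = []
    remaining = batch_size
    while remaining > 0:
        # Every chunk gets `step` samples except possibly the last one.
        sizes.append(min(step, remaining))
        remaining -= step
    return sizes


for n in (1, 2, 3, 4, 8):
    print(n, chunk_sizes(10, n))
# 1 [10]
# 2 [5, 5]
# 3 [4, 4, 2]
# 4 [3, 3, 3, 1]
# 8 [2, 2, 2, 2, 2]   <- only 5 of the 8 GPUs would receive data
```

With the small fake batch used here, the measured timings mostly reflect setup overhead rather than parallel speed-up.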

### CPU

In [None]:
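# Move the model to the CPU, then create a fresh optimizer for its parameters.
net = net.to(torch.device("cpu:0"))
opt = torch.optim.Adam(net.parameters(), lr)
# devices=[]: run on the CPU only.
trainer = create_multigpu_supervised_trainer(net, opt, fake_loss, [])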
trainer.run(fake_data_stream(), max_epochs=max_epochs, epoch_length=2)

2023-01-20 14:16:01,402 - Engine run starting with max_epochs=2.
2023-01-20 14:16:01,475 - Epoch[1] Complete. Time taken: 00:00:00.073
2023-01-20 14:16:01,575 - Epoch[2] Complete. Time taken: 00:00:00.100
2023-01-20 14:16:01,576 - Engine run complete. Time taken: 00:00:00.174


State:
	iteration: 4
	epoch: 2
	epoch_length: 2
	max_epochs: 2
	output: 21955.39453125
	batch: <class 'tuple'>
	metrics: <class 'dict'>
	dataloader: <class 'generator'>
	seed: <class 'NoneType'>
	times: <class 'dict'>
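The three runs differ only in the `devices` argument: an explicit list (`[device]`), `None` for all visible GPUs, and an empty list for the CPU. The sketch below restates that mapping in plain Python, using strings instead of `torch.device` so it runs anywhere; `resolve_devices` and `wraps_in_data_parallel` are hypothetical names, and the "more than one device means data-parallel wrapping" rule is an assumption about the trainer factory rather than something this notebook shows.

```python
from typing import List, Optional, Sequence


def resolve_devices(devices: Optional[Sequence[str]], gpu_count: int) -> List[str]:
    """Hypothetical stand-in for how the `devices` argument is interpreted."""
    if devices is None:
        # None: use every visible GPU, and fail if there are none.
        resolved = [f"cuda:{i}" for i in range(gpu_count)]
        if not resolved:
            raise RuntimeError("No GPU devices available.")
        return resolved
    if len(devices) == 0:
        # Empty list: fall back to the CPU.
        return ["cpu"]
    # Explicit list: use exactly the devices given.
    return list(devices)


def wraps_in_data_parallel(devices: List[str]) -> bool:
    # Assumption: the network is replicated only when more than one device is resolved.
    return len(devices) > 1


for label, spec in [("1 GPU", ["cuda:0"]), ("all GPUs", None), ("CPU", [])]:
    resolved = resolve_devices(spec, gpu_count=2)
    print(label, resolved, wraps_in_data_parallel(resolved))
# 1 GPU ['cuda:0'] False
# all GPUs ['cuda:0', 'cuda:1'] True
# CPU ['cpu'] False
```

On a machine with a single GPU, the "all GPUs" case resolves to one device, so it behaves like the "1 GPU" case.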