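### Multi-GPU Training Loop Using `DataParallel`

Trains a small MLP on MNIST with `torch.nn.DataParallel`, which splits each input batch across the visible GPUs; the `print` calls in the model and the training loop show how each batch is divided between devices.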

#### Imports

In [1]:
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision import transforms
from tqdm import tqdm
from torch.nn import DataParallel

#### Loading Datasets

In [2]:
# ToTensor converts each image to a float tensor in [0, 1].
# Normalize(mean=0.0, std=1.0) computes (x - 0) / 1, i.e. it leaves values unchanged.
# NOTE(review): if real standardization is intended, MNIST's commonly quoted
# statistics are mean=0.1307, std=0.3081 -- change the train and test cells together.
# The final Lambda flattens every image into a single 784-value vector.
training_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=0.0, std=1.0),
        transforms.Lambda(torch.flatten),
    ]),
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw



In [3]:
# Held-out split. It must use exactly the same preprocessing as the training split:
# ToTensor, an identity Normalize(mean=0.0, std=1.0), then flatten to 784 values.
test_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=0.0, std=1.0),
        transforms.Lambda(torch.flatten),
    ]),
)

In [4]:
from torch.utils.data import DataLoader

# Samples per optimizer step. DataParallel splits each batch along dim 0 across the GPUs.
BATCH_SIZE = 64

# Pinned host memory only helps, and is only supported without warnings,
# when there is a CUDA device to copy to.
PIN_MEMORY = torch.cuda.is_available()

train_dataloader = DataLoader(training_data, batch_size=BATCH_SIZE, shuffle=True, pin_memory=PIN_MEMORY)
# Accuracy does not depend on sample order, so keep evaluation deterministic.
test_dataloader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, pin_memory=PIN_MEMORY)

In [5]:
# Pull a single batch from a fresh iterator to sanity-check the tensor shapes.
first_batch_iter = iter(train_dataloader)
out = next(first_batch_iter)
del first_batch_iter

In [6]:
# Shapes of (images, labels) in the sampled batch.
tuple(part.shape for part in out[:2])

(torch.Size([64, 784]), torch.Size([64]))

In [7]:
class NN(torch.nn.Module):
  """Fully connected MNIST classifier: 784 -> 256 -> 128 -> 64 -> 10.

  The final layer is LogSoftmax(dim=1), so the output holds per-class
  log-probabilities and pairs with NLLLoss.

  Args:
    verbose: when True (the default, matching the original behaviour),
      forward() prints the shape of the input it receives. Under
      DataParallel each replica prints the shape of its own chunk of the batch.
  """

  def __init__(self, verbose: bool = True):
    super().__init__()
    self.verbose = verbose
    self.model = torch.nn.Sequential(
        torch.nn.Linear(784, 256),
        torch.nn.ReLU(),
        torch.nn.Linear(256, 128),
        torch.nn.ReLU(),
        torch.nn.Linear(128, 64),
        torch.nn.ReLU(),
        torch.nn.Linear(64, 10),
        torch.nn.LogSoftmax(dim=1)
    )

  def forward(self, x):
    """Return log-probabilities for a batch of flattened images (last dim 784)."""
    if self.verbose:
      print("Model Input: ", x.shape)
    return self.model(x)

### Checking CUDA Device Availability

In [8]:
# Decide once where the model and every batch will live for the rest of the notebook.
is_device_available = torch.cuda.is_available()
device = 'cuda' if is_device_available else 'cpu'
if not is_device_available:
  print("Device Name: CPU")
else:
  print(f"Device Name: {torch.cuda.get_device_name()}")
  print(f"Num Devices: {torch.cuda.device_count()}")

Device Name: Quadro P4000
Num Devices: 2


#### Wrap the model in `DataParallel` for multi-GPU execution and move it to the device

In [9]:
nn_gpu = NN()

# DataParallel scatters each input batch along dim 0 across the visible GPUs
# (e.g. a [30, ...] batch on 3 GPUs becomes three [10, ...] chunks) and gathers
# the outputs back. With a single GPU that only adds overhead, plus a "module."
# prefix on state_dict keys, so wrap only when there are 2+ devices.
if torch.cuda.device_count() > 1:
  print("Number of GPUs: ", torch.cuda.device_count())
  nn_gpu = DataParallel(nn_gpu)

nn_gpu.to(device)


Number of GPUs:  2


DataParallel(
  (module): NN(
    (model): Sequential(
      (0): Linear(in_features=784, out_features=256, bias=True)
      (1): ReLU()
      (2): Linear(in_features=256, out_features=128, bias=True)
      (3): ReLU()
      (4): Linear(in_features=128, out_features=64, bias=True)
      (5): ReLU()
      (6): Linear(in_features=64, out_features=10, bias=True)
      (7): LogSoftmax(dim=1)
    )
  )
)

In [10]:
optimizer = torch.optim.Adam(nn_gpu.parameters())
# NN ends in LogSoftmax, so its outputs are already log-probabilities. NLLLoss is the
# matching loss. CrossEntropyLoss would apply log_softmax a second time, which is
# mathematically idempotent but redundant work.
criterion = torch.nn.NLLLoss()

In [12]:
def train(model, data, epochs=10, opt=None, loss_fn=None, verbose=True):
  """Train `model` for `epochs` full passes over `data`.

  Args:
    model: module to optimise (may be wrapped in DataParallel).
    data: iterable of (inputs, labels) batches, e.g. a DataLoader.
    epochs: number of passes over `data`.
    opt: optimizer to step; defaults to the notebook-level `optimizer`.
    loss_fn: callable(output, labels) -> scalar loss; defaults to the
      notebook-level `criterion`.
    verbose: print each batch's shape before the model sees it. Compared with the
      model's own print, this shows how DataParallel splits the batch.
  """
  opt = optimizer if opt is None else opt
  loss_fn = criterion if loss_fn is None else loss_fn
  # Make sure layers run in training mode even if an earlier evaluation
  # switched the model to eval mode.
  model.train()
  for _ in tqdm(range(epochs)):
    for inputs, labels in data:
      # With a pinned-memory loader these host->GPU copies can overlap compute.
      # Otherwise non_blocking simply behaves like a normal copy.
      inputs = inputs.to(device, non_blocking=True)
      labels = labels.to(device, non_blocking=True)
      if verbose:
        print('Outside: ', inputs.shape)
      opt.zero_grad()
      output = model(inputs)
      loss = loss_fn(output, labels)
      loss.backward()
      opt.step()

In [13]:
# Ten epochs over the training split with the (possibly DataParallel-wrapped) model.
train(model=nn_gpu, data=train_dataloader, epochs=10)

  0%|                                                                                                                                      | 0/10 [00:00<?, ?it/s]

Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 78

 10%|████████████▌                                                                                                                 | 1/10 [00:35<05:18, 35.44s/it]

Model Input: Model Input:  torch.Size([16, 784])
 torch.Size([16, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32

 20%|█████████████████████████▏                                                                                                    | 2/10 [01:05<04:19, 32.45s/it]

Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([32, 784])
Model Input:  torch.Size([16, 784])
Model Input:  torch.Size([16, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 78

 30%|█████████████████████████████████████▊                                                                                        | 3/10 [01:36<03:40, 31.48s/it]

Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([32, 784])
Model Input:  torch.Size([16, 784])
Model Input:  torch.Size([16, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 78

 40%|██████████████████████████████████████████████████▍                                                                           | 4/10 [02:06<03:06, 31.07s/it]

Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([32, 784])
Model Input:  torch.Size([16, 784])
Model Input:  torch.Size([16, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 78

 50%|███████████████████████████████████████████████████████████████                                                               | 5/10 [02:37<02:34, 30.88s/it]

Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([32, 784])
Model Input:  torch.Size([16, 784])
Model Input:  torch.Size([16, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 78

 60%|███████████████████████████████████████████████████████████████████████████▌                                                  | 6/10 [03:07<02:03, 30.80s/it]

Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([32, 784])
Model Input:  torch.Size([16, 784])
Model Input:  torch.Size([16, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 78

 70%|████████████████████████████████████████████████████████████████████████████████████████▏                                     | 7/10 [03:38<01:32, 30.84s/it]

Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([32, 784])
Model Input:  torch.Size([16, 784])
Model Input:  torch.Size([16, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32

 80%|████████████████████████████████████████████████████████████████████████████████████████████████████▊                         | 8/10 [04:09<01:01, 30.96s/it]

Outside:  torch.Size([32, 784])
Model Input:  torch.Size([16, 784])
Model Input:  torch.Size([16, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 78

 90%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍            | 9/10 [04:41<00:31, 31.27s/it]

Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([32, 784])
Model Input:  torch.Size([16, 784])
Model Input:  torch.Size([16, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 78

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [05:14<00:00, 31.40s/it]

Outside:  torch.Size([64, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Outside:  torch.Size([32, 784])
Model Input:  torch.Size([16, 784])
Model Input:  torch.Size([16, 784])





In [14]:
def eval(model, data, target_device=None):
  """Compute, print and return the classification accuracy of `model` on `data`.

  NOTE: this name shadows Python's built-in `eval` in the notebook namespace.

  Args:
    model: module mapping an input batch to per-class scores of shape
      (batch, classes). Its previous train/eval mode is restored afterwards.
    data: iterable of (inputs, labels) batches.
    target_device: device to move batches to; defaults to the notebook-level `device`.

  Returns:
    Accuracy as a Python float in [0, 1].

  Raises:
    ValueError: if `data` yields no samples.
  """
  target_device = device if target_device is None else target_device
  was_training = model.training
  model.eval()
  total, correct = 0, 0
  try:
    # No gradients are needed for scoring; this saves memory and time.
    with torch.no_grad():
      for inputs, labels in data:
        inputs = inputs.to(target_device, non_blocking=True)
        labels = labels.to(target_device, non_blocking=True)
        predictions = model(inputs).argmax(dim=1)
        # Keep the running count on-device and sync only once at the end.
        correct += (predictions == labels).sum()
        total += labels.size(0)
  finally:
    model.train(was_training)
  if total == 0:
    raise ValueError("eval() received no samples")
  accuracy = int(correct) / total
  print(f"Accuracy: {accuracy}")
  return accuracy

In [15]:
# Score the trained model on the held-out split.
eval(model=nn_gpu, data=test_dataloader)

Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32, 784])
Model Input:  torch.Size([32