# Install PyTorch/XLA to use multiple TPU cores

In [1]:
import os
assert os.environ['COLAB_TPU_ADDR'], 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'

VERSION = "20200220"#"xrt==1.15.0"
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version $VERSION

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0100  3566  100  3566    0     0  49527      0 --:--:-- --:--:-- --:--:-- 50225
Updating TPU and VM. This may take around 2 minutes.
Updating TPU runtime to 20200220 ...
Uninstalling torch-1.4.0:
Done updating TPU runtime: <Response [200]>
  Successfully uninstalled torch-1.4.0
Uninstalling torchvision-0.5.0:
  Successfully uninstalled torchvision-0.5.0
Copying gs://tpu-pytorch/wheels/torch-nightly+20200220-cp36-cp36m-linux_x86_64.whl...
- [1 files][ 79.6 MiB/ 79.6 MiB]                                                
Operation completed over 1 objects/79.6 MiB.                                     
Copying gs://tpu-pytorch/wheels/torch_xla-nightly+20200220-cp36-cp36m-linux_x86_64.whl...
- [1 files][111.9 MiB/111.9 MiB]                                    

# Import libraries 

In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

import numpy as np

from time import time

## Import torch_xla to use PyTorch on TPU

In [0]:
# imports the torch_xla package
import torch_xla
import torch_xla.core.xla_model as xm

## Import torch_xla.distributed for multi-core TPU training

In [0]:
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.parallel_loader as pl

# Download MNIST dataset

In [5]:
# Fetch MNIST once here in the parent process, presumably so that the per-core
# training processes spawned later find the files already on disk. The download
# pulls both the training and the t10k test files.
train_dataset = datasets.MNIST(root='../data', train=True, download=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw
Processing...
Done!


# Simple example with MNIST dataset

## Define simple convolution network

In [0]:
class Net(nn.Module):
    """Small convolutional classifier for MNIST-sized images.

    Expects input of shape (N, 1, 28, 28). The two unpadded 3x3 convolutions
    take 28x28 to 24x24, and 2x2 max-pooling takes that to 12x12. The flattened
    feature size is therefore 64 * 12 * 12 = 9216, which matches ``fc1``.
    Returns (N, 10) log-probabilities.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        # Channel-wise dropout on the 4-D convolutional feature maps.
        self.dropout1 = nn.Dropout2d(0.25)
        # Plain element-wise dropout. This layer is applied to the 2-D
        # (N, 128) output of fc1. Dropout2d is meant for 3-D/4-D inputs, and
        # newer PyTorch versions warn on 2-D input and recommend Dropout,
        # which gives the same per-element behaviour here.
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)  # 9216 = 64 channels * 12 * 12
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        # NOTE(review): the reference PyTorch MNIST example applies F.relu
        # after conv2, and this model omits it. It is left unchanged so the
        # architecture matches the results recorded in this notebook.
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

## Define the training function as the per-core map function

In [0]:
# "Map function": run once per spawned process on that process's Cloud TPU
# core. It trains and evaluates the model there and prints per-core statistics.
def map_fn(index):
    """Train and evaluate ``Net`` on MNIST on the TPU core owned by this process.

    This is the function launched by ``xmp.spawn``. ``index`` is the process
    index that ``xmp.spawn`` passes in, and is only used in a log message.
    Each process trains on its own shard of the training set and evaluates on
    its own shard of the test set. The shards are chosen by
    ``DistributedSampler`` from ``xm.xrt_world_size()`` / ``xm.get_ordinal()``.

    Prints the following for this process's device:
      * per-epoch train loss and timing;
      * per-epoch test loss and accuracy on the local test shard;
      * overall per-iteration and per-epoch timing statistics.

    Returns ``None``.
    """
    # Training settings
    batch_size = 128
    epochs = 10
    lr = 1
    gamma = 0.7
    seed = 1
    num_workers = 4

    # Seed identically in every replica so all processes build the same
    # initial model weights. Gradients are combined across replicas in
    # xm.optimizer_step, so differing initial weights would make replicas
    # train different models. Without this, identical weights rely on each
    # forked child inheriting the parent's RNG state.
    torch.manual_seed(seed)

    device = xm.xla_device()
    device_name = xm.xla_real_devices([str(device)])[0]
    world_size = xm.xrt_world_size()
    rank = xm.get_ordinal()

    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))])
    # NOTE: every replica calls this with download=True. The earlier download
    # cell fetches the files once up front, so replicas are not all racing to
    # download them at the same time.
    train_dataset = datasets.MNIST('../data', train=True, download=True,
                                   transform=mnist_transform)
    test_dataset = datasets.MNIST('../data', train=False,
                                  transform=mnist_transform)

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
    test_sampler = torch.utils.data.distributed.DistributedSampler(
        test_dataset, num_replicas=world_size, rank=rank, shuffle=False)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, sampler=train_sampler,
        num_workers=num_workers)
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size, sampler=test_sampler,
        num_workers=num_workers)

    model = Net().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=lr)
    # Multiplies the learning rate by ``gamma`` after every epoch
    # (scheduler.step() is called once per epoch below).
    scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

    print("Process", index, "is using", device_name)
    device_times, epoch_times = [], []
    for epoch in range(1, epochs + 1):
        start_epoch_t = time()

        # Without set_epoch, DistributedSampler yields the same shuffled order
        # in every epoch. It must be called before the loader starts iterating.
        train_sampler.set_epoch(epoch)
        para_train_loader = pl.ParallelLoader(
            train_loader, [device]).per_device_loader(device)
        losses, times = _train_one_epoch(model, para_train_loader,
                                         optimizer, device)
        device_times.extend(times)
        print("[Epoch %d] %s [Train] Avg. loss: %.3f, Total time: %.3f(s/epoch), Avg. time: %.3f(s/iter)"
              % (epoch, device_name, np.mean(losses),
                 time() - start_epoch_t, np.mean(times)))

        para_test_loader = pl.ParallelLoader(
            test_loader, [device]).per_device_loader(device)
        test_loss, correct, len_data = _evaluate(model, para_test_loader,
                                                 device)
        accuracy = 100. * correct / len_data if len_data else float('nan')
        print("[Epoch %d] %s [Test] Avg. loss: %.4f, Accuracy: %d/%d (%.2f%%)"
              % (epoch, device_name, test_loss, correct, len_data, accuracy))

        scheduler.step()
        epoch_times.append(time() - start_epoch_t)

    print("%s, Avg. time: %.3f(s/iter), std:%.3f, Avg. time: %.3f(s/epoch), std:%.3f"
          % (device_name, np.mean(device_times), np.std(device_times),
             np.mean(epoch_times), np.std(epoch_times)))


def _train_one_epoch(model, loader, optimizer, device):
    """Run one training pass of ``model`` over ``loader``.

    Returns:
        (losses, times): lists with one entry per iteration. ``losses`` holds
        the training loss as a Python float. ``times`` holds the wall-clock
        duration of the iteration in seconds, excluding data fetching.
    """
    model.train()
    losses, times = [], []
    for data, target in loader:
        start_it_t = time()
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        # On XLA devices, reading a value back to the host forces pending
        # computation to execute. The measured iteration time therefore
        # includes device work, at the cost of a host/device sync per step.
        losses.append(loss.cpu().item())
        loss.backward()
        # PyTorch/XLA helper: reduces gradients across replicas, then steps.
        xm.optimizer_step(optimizer)
        times.append(time() - start_it_t)
    return losses, times


def _evaluate(model, loader, device):
    """Evaluate ``model`` on the samples yielded by ``loader``.

    Returns:
        (avg_loss, correct, n_samples):
          * avg_loss: mean NLL loss per sample, or NaN if the loader yields no
            samples;
          * correct: number of correct top-1 predictions;
          * n_samples: number of samples seen.
    """
    model.eval()
    total_loss, correct, n_samples = 0.0, 0, 0
    with torch.no_grad():
        for data, target in loader:
            n_samples += len(data)
            data, target = data.to(device), target.to(device)
            output = model(data)
            total_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
    avg_loss = total_loss / n_samples if n_samples else float('nan')
    return avg_loss, correct, n_samples

## Train the model on multiple TPU cores

In [8]:
# Launch map_fn in eight processes, one for each of the eight cores on the
# Cloud TPU, and report the total wall-clock time of the whole run.
# Note: Colab only supports start_method='fork'
start_trn_time = time()
xmp.spawn(map_fn, start_method='fork', nprocs=8)
print("Train time: {:.3f}s".format(time() - start_trn_time))

Process 0 is using TPU:0


	addcmul_(Number value, Tensor tensor1, Tensor tensor2)
Consider using one of the following signatures instead:
	addcmul_(Tensor tensor1, Tensor tensor2, Number value)


Process 5 is using TPU:5


	addcmul_(Number value, Tensor tensor1, Tensor tensor2)
Consider using one of the following signatures instead:
	addcmul_(Tensor tensor1, Tensor tensor2, Number value)


Process 7 is using TPU:7
Process 6 is using TPU:6


	addcmul_(Number value, Tensor tensor1, Tensor tensor2)
Consider using one of the following signatures instead:
	addcmul_(Tensor tensor1, Tensor tensor2, Number value)
	addcmul_(Number value, Tensor tensor1, Tensor tensor2)
Consider using one of the following signatures instead:
	addcmul_(Tensor tensor1, Tensor tensor2, Number value)


Process 4 is using TPU:4
Process 1 is using TPU:1
Process 2 is using TPU:2


	addcmul_(Number value, Tensor tensor1, Tensor tensor2)
Consider using one of the following signatures instead:
	addcmul_(Tensor tensor1, Tensor tensor2, Number value)


Process 3 is using TPU:3


	addcmul_(Number value, Tensor tensor1, Tensor tensor2)
Consider using one of the following signatures instead:
	addcmul_(Tensor tensor1, Tensor tensor2, Number value)
	addcmul_(Number value, Tensor tensor1, Tensor tensor2)
Consider using one of the following signatures instead:
	addcmul_(Tensor tensor1, Tensor tensor2, Number value)
	addcmul_(Number value, Tensor tensor1, Tensor tensor2)
Consider using one of the following signatures instead:
	addcmul_(Tensor tensor1, Tensor tensor2, Number value)


[Epoch 1] TPU:7 [Train] Avg. loss: 0.612, Total time: 12.713(s/epoch), Avg. time: 0.176(s/iter)
[Epoch 1] TPU:2 [Train] Avg. loss: 0.613, Total time: 10.733(s/epoch), Avg. time: 0.139(s/iter)
[Epoch 1] TPU:5 [Train] Avg. loss: 0.608, Total time: 13.629(s/epoch), Avg. time: 0.198(s/iter)
[Epoch 1] TPU:1 [Train] Avg. loss: 0.625, Total time: 10.870(s/epoch), Avg. time: 0.134(s/iter)
[Epoch 1] TPU:6 [Train] Avg. loss: 0.600, Total time: 12.708(s/epoch), Avg. time: 0.179(s/iter)
[Epoch 1] TPU:4 [Train] Avg. loss: 0.604, Total time: 11.212(s/epoch), Avg. time: 0.151(s/iter)
[Epoch 1] TPU:3 [Train] Avg. loss: 0.583, Total time: 10.387(s/epoch), Avg. time: 0.109(s/iter)
[Epoch 1] TPU:0 [Train] Avg. loss: 0.605, Total time: 15.750(s/epoch), Avg. time: 0.225(s/iter)
[Epoch 1] TPU:4 [Test] Avg. loss: 0.1195, Accuracy: 1209/1250 (96.72%)
[Epoch 1] TPU:2 [Test] Avg. loss: 0.1395, Accuracy: 1193/1250 (95.44%)
[Epoch 1] TPU:5 [Test] Avg. loss: 0.1482, Accuracy: 1190/1250 (95.20%)
[Epoch 1] TPU:7 [Te