# Install the PyTorch/XLA library

In [1]:
import os
assert os.environ['COLAB_TPU_ADDR'], 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'

VERSION = "20200220"
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version $VERSION

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0100  3566  100  3566    0     0  19811      0 --:--:-- --:--:-- --:--:-- 19811
Updating TPU and VM. This may take around 2 minutes.
Updating TPU runtime to 20200220 ...
Uninstalling torch-1.4.0:
Done updating TPU runtime: <Response [200]>
  Successfully uninstalled torch-1.4.0
Uninstalling torchvision-0.5.0:
  Successfully uninstalled torchvision-0.5.0
Copying gs://tpu-pytorch/wheels/torch-nightly+20200220-cp36-cp36m-linux_x86_64.whl...
- [1 files][ 79.6 MiB/ 79.6 MiB]                                                
Operation completed over 1 objects/79.6 MiB.                                     
Copying gs://tpu-pytorch/wheels/torch_xla-nightly+20200220-cp36-cp36m-linux_x86_64.whl...
- [1 files][111.9 MiB/111.9 MiB]                                    

# Import libraries 

In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

import numpy as np

from time import time

## Import torch_xla to use PyTorch on the TPU

In [0]:
# imports the torch_xla package
import torch_xla
import torch_xla.core.xla_model as xm

# Simple example with the MNIST dataset

## Define a simple convolutional network

In [0]:
class Net(nn.Module):
    """Small convolutional classifier for MNIST digits.

    Expects input of shape (N, 1, 28, 28). Two unpadded 3x3 convolutions
    (28 -> 26 -> 24) followed by a 2x2 max-pool (24 -> 12) leave
    64 * 12 * 12 = 9216 features, which is the hard-coded input size of ``fc1``.

    Returns:
        (N, 10) log-probabilities (``log_softmax`` over the class dimension),
        suitable for ``F.nll_loss``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        # Channel-wise dropout on the pooled (N, 64, 12, 12) feature maps.
        self.dropout1 = nn.Dropout2d(0.25)
        # Element-wise dropout on the flat (N, 128) hidden vector. Dropout2d is
        # meant for 3D/4D inputs and warns (deprecated, future error) on 2D input;
        # on a 2D tensor it already acted element-wise, so nn.Dropout keeps the
        # same behaviour without the warning.
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        # NOTE(review): no activation is applied to conv2's output before pooling;
        # confirm this is intended.
        x = self.conv2(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)  # keep the batch dim, flatten the rest -> (N, 9216)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

## Download the MNIST dataset

In [5]:
# Fetch the raw MNIST files into ../data up front; the training cell re-creates the datasets with transforms.
train_dataset = datasets.MNIST(root='../data', train=True, download=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw
Processing...
Done!


## Training & Testing

In [6]:
# "Map function": acquires a corresponding Cloud TPU core, creates a tensor on it,
# and prints its core
batch_size=128
epochs=10
lr=1
gamma=0.7
log_interval=10

device = xm.xla_device()  

train_dataset =  datasets.MNIST('../data', train=True, download=True,
                  transform=transforms.Compose([
                      transforms.ToTensor(),
                      transforms.Normalize((0.1307,), (0.3081,))]))

test_dataset = datasets.MNIST('../data', train=False, transform=transforms.Compose([
                      transforms.ToTensor(),
                      transforms.Normalize((0.1307,), (0.3081,))]))

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                           num_workers=4, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, 
                                          num_workers=4, shuffle=False)

model = Net().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=lr)

scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

print("Start with %s"%device)

start_trn_t, device_times, epoch_times = time(), [], []
for epoch in range(1, epochs + 1):
  model.train()
  start_epoch_t = time()
  for batch_idx, (data, target) in enumerate(train_loader):
      start_it_t = time()
      data, target = data.to(device), target.to(device)
      optimizer.zero_grad()
      output = model(data)
      loss = F.nll_loss(output, target)
      loss.backward()
      xm.optimizer_step(optimizer, barrier=True)
      device_times.append(time()-start_it_t)
  print("[Epoch #%d] [Train] Total time: %.3f(s/epoch)\t Avg. time: %.3f(s/iter)"%(epoch, time()-start_epoch_t, device_times[-1]))
 
  model.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
      for data, target in test_loader:
          data, target = data.to(device), target.to(device)
          output = model(data)
          test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
          pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
          correct += pred.eq(target.view_as(pred)).sum().item()

  test_loss /= len(test_loader.dataset)

  print('[Epoch #%d] [Test] Average loss: %.4f, Accuracy: %d/%d (%.1f%%)' %(
      epoch, test_loss, correct, len(test_loader.dataset),
      100. * correct / len(test_loader.dataset)))
  scheduler.step()
  epoch_times.append(time()-start_epoch_t)
print("Train time: %.3f, Avg. time: %.3f(s/iter), std:%.3f, Avg. time: %.3f(s/epoch), std:%.3f"\
        %(time()-start_trn_t, np.mean(device_times), np.std(device_times),  np.mean(epoch_times), np.std(epoch_times)))


Start with xla:1


	addcmul_(Number value, Tensor tensor1, Tensor tensor2)
Consider using one of the following signatures instead:
	addcmul_(Tensor tensor1, Tensor tensor2, Number value)


[Epoch #1] [Train] Total time: 12.238(s/epoch)	 Avg. time: 0.641(s/iter)
[Epoch #1] [Test] Average loss: 0.0542, Accuracy: 9825/10000 (98.2%)
[Epoch #2] [Train] Total time: 12.656(s/epoch)	 Avg. time: 0.667(s/iter)
[Epoch #2] [Test] Average loss: 0.0350, Accuracy: 9887/10000 (98.9%)
[Epoch #3] [Train] Total time: 10.086(s/epoch)	 Avg. time: 0.011(s/iter)
[Epoch #3] [Test] Average loss: 0.0289, Accuracy: 9900/10000 (99.0%)
[Epoch #4] [Train] Total time: 10.129(s/epoch)	 Avg. time: 0.010(s/iter)
[Epoch #4] [Test] Average loss: 0.0277, Accuracy: 9906/10000 (99.1%)
[Epoch #5] [Train] Total time: 10.122(s/epoch)	 Avg. time: 0.009(s/iter)
[Epoch #5] [Test] Average loss: 0.0276, Accuracy: 9913/10000 (99.1%)
[Epoch #6] [Train] Total time: 10.178(s/epoch)	 Avg. time: 0.010(s/iter)
[Epoch #6] [Test] Average loss: 0.0265, Accuracy: 9919/10000 (99.2%)
[Epoch #7] [Train] Total time: 10.008(s/epoch)	 Avg. time: 0.008(s/iter)
[Epoch #7] [Test] Average loss: 0.0238, Accuracy: 9927/10000 (99.3%)
[Epoch