<a href="https://colab.research.google.com/github/PyDataOsaka/handson_pytorch/blob/master/cnn_tpu.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# PyTorch hands-on (CNN on TPU)

Adapted from the official [PyTorch CIFAR-10 classifier tutorial](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html).

In [0]:
# Start from an empty TensorBoard log directory. -f keeps this cell from
# erroring on a fresh runtime where ./log does not exist yet.
!rm -rf ./log

In [0]:
# Colab-only magic selecting TensorFlow 2.x (presumably for the TensorBoard
# extension loaded below). NOTE(review): newer Colab runtimes ship only TF 2.x
# and may report that this magic has no effect.
%tensorflow_version 2.x
# Register the %tensorboard magic used in the next cell.
%load_ext tensorboard

In [0]:
# Show an inline TensorBoard over ./log; the training cell below writes its
# run into the subdirectory ./log/2.
%tensorboard --logdir ./log

In [0]:
import os
# Fail fast with a readable message when no TPU runtime is attached.
# `.get` (instead of indexing) makes a missing variable trigger the assertion
# message below rather than a bare KeyError.
assert os.environ.get('COLAB_TPU_ADDR'), 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'

In [0]:
# Must run before torch_xla is imported (the import cell comes later);
# presumably torch_xla reads XLA_USE_BF16 at import time to store float
# tensors as bfloat16 on the TPU.
os.environ.update(XLA_USE_BF16="1")

## Installing PyTorch/XLA

In [0]:
VERSION = "20200325"  #@param ["1.5" , "20200325", "nightly"]
# -f: fail on HTTP errors instead of saving an error page as the script;
# -sS: no progress bar but still report errors; -L: follow redirects.
# NOTE(review): the setup script is fetched from the unpinned `master` branch,
# so its contents (or availability) may change over time.
!curl -fsSL https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version $VERSION

In [0]:
from time import time
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import torchvision
import torchvision.transforms as transforms
import torch_xla
import torch_xla.core.xla_model as xm

## Load image data

In [0]:
def get_data(batch_size: int=64, test_batch_size: int=4, root: str='./data',
             num_workers: int=2):
  """Build CIFAR-10 train/test DataLoaders, downloading the dataset if needed.

  Args:
    batch_size: mini-batch size of the shuffled training loader.
    test_batch_size: mini-batch size of the unshuffled test loader. Defaults
      to 4 as before; a larger value reduces the number of evaluation steps.
    root: directory where torchvision stores / looks for the dataset.
    num_workers: number of worker processes used by each DataLoader.

  Returns:
    (trainloader, testloader, classes), where `classes` is a tuple of the 10
    CIFAR-10 class names (presumably in label-index order -- confirm against
    torchvision's CIFAR10.classes).
  """
  transform = transforms.Compose([
      transforms.ToTensor(),
      # (x - 0.5) / 0.5 per channel: rescales ToTensor's [0, 1] output to [-1, 1].
      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
  ])

  trainset = torchvision.datasets.CIFAR10(root=root, train=True,
                                          download=True, transform=transform)
  trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                            shuffle=True, num_workers=num_workers)

  testset = torchvision.datasets.CIFAR10(root=root, train=False,
                                         download=True, transform=transform)
  testloader = torch.utils.data.DataLoader(testset, batch_size=test_batch_size,
                                           shuffle=False, num_workers=num_workers)

  classes = ('plane', 'car', 'bird', 'cat',
             'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

  return trainloader, testloader, classes

## CNN model

In [0]:
class Net(nn.Module):
  """LeNet-style CNN: two conv/ReLU/max-pool stages followed by three linear layers.

  Expects 3-channel 32x32 inputs (fc1's 16 * 5 * 5 in-features follow from
  32 -> conv5 -> 28 -> pool -> 14 -> conv5 -> 10 -> pool -> 5) and returns 10
  unnormalised scores (logits) per sample.

  Args:
    dtype: floating-point dtype every parameter is cast to. Defaults to
      torch.bfloat16, matching the original per-layer .bfloat16() calls; pass
      torch.float32 to run the model on CPU/GPU with ordinary float inputs.
  """

  def __init__(self, dtype: torch.dtype = torch.bfloat16):
    super().__init__()
    self.conv1 = nn.Conv2d(3, 6, 5)
    self.pool = nn.MaxPool2d(2, 2)  # parameter-free, so there is nothing to cast
    self.conv2 = nn.Conv2d(6, 16, 5)
    self.fc1 = nn.Linear(16 * 5 * 5, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)
    # Cast all floating-point parameters in one place instead of layer by layer.
    self.to(dtype)

  def forward(self, x):
    x = self.pool(F.relu(self.conv1(x)))
    x = self.pool(F.relu(self.conv2(x)))
    # flatten(x, 1) keeps the batch dimension intact; view(-1, 400) could
    # silently regroup samples when the input is not 32x32.
    x = torch.flatten(x, 1)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    return self.fc3(x)

## Define functions

### Training

In [0]:
def train_tpu(model: nn.Module, trainloader, log_dir: str, device,
              n_epochs: int = 4, lr: float = 0.001, log_every: int = 100):
  """Train `model` with Adam and cross-entropy loss, stepping via torch_xla.

  Every mini-batch loss is written to TensorBoard under the tag "loss_value"
  in `log_dir`; the mean loss and elapsed wall time over the last `log_every`
  mini-batches are printed.

  Args:
    model: network to train; moved to `device` in place and put in train mode.
    trainloader: iterable of [inputs, labels] batches (e.g. a DataLoader).
    log_dir: TensorBoard log directory for this run.
    device: target device, e.g. the result of `xm.xla_device()`.
    n_epochs: number of passes over `trainloader` (default 4, as before).
    lr: Adam learning rate (default 0.001, as before).
    log_every: print interval in mini-batches (default 100, as before).
  """
  model.to(device)
  model.train()

  criterion = nn.CrossEntropyLoss()
  opt = optim.Adam(model.parameters(), lr=lr)

  writer = SummaryWriter(log_dir)
  n_minibatches = 0  # global step counter for TensorBoard
  try:
    for epoch in range(n_epochs):
      # Reset per epoch so the first report of an epoch does not include
      # leftover batches (and time) from the previous epoch.
      running_loss = 0.0
      prev_time = time()
      for i, data in enumerate(trainloader):
        # data is [inputs, labels]
        inputs = data[0].to(device)
        labels = data[1].to(device)

        opt.zero_grad()
        outputs = model(inputs)
        loss_value = criterion(outputs, labels)
        loss_value.backward()
        # On CPU/GPU this would be opt.step(). On TPU the step goes through
        # torch_xla; barrier=True is needed because no XLA ParallelLoader is
        # used here to issue the step barrier.
        xm.optimizer_step(opt, barrier=True)

        # Fetch the scalar once and reuse it for TensorBoard and the printout.
        loss_item = loss_value.item()
        writer.add_scalar("loss_value", loss_item, n_minibatches)
        n_minibatches += 1

        running_loss += loss_item
        if i % log_every == log_every - 1:
          # Average over the `log_every` batches actually accumulated
          # (the old code divided by 2000 while printing every 100 batches).
          print('[{}, {:5d}] loss: {:.3f}, elapsed time: {:.1f} [sec]'.format(
                epoch + 1, i + 1, running_loss / log_every, time() - prev_time))
          running_loss = 0.0
          prev_time = time()
  finally:
    # Flush pending events to disk even if training is interrupted.
    writer.close()

### Evaluation (test-set accuracy)

In [0]:
def evaluate(model: nn.Module, testloader, device) -> float:
  """Print and return the top-1 accuracy of `model` on `testloader`.

  Args:
    model: classifier producing one score per class along dim 1; assumed to
      already live on `device`. Its train/eval mode is restored on return.
    testloader: iterable of [inputs, labels] batches (e.g. a DataLoader).
    device: device each batch is moved to before the forward pass.

  Returns:
    Accuracy in percent (0-100), or NaN if `testloader` yields no samples.
  """
  was_training = model.training
  model.eval()
  correct = 0
  total = 0
  try:
    with torch.no_grad():
      for data in testloader:
        inputs = data[0].to(device)
        labels = data[1].to(device)
        predicted = model(inputs).argmax(dim=1)
        total += labels.size(0)
        # Convert to a Python int per batch (as before); NOTE(review): on XLA
        # this presumably also keeps each lazily-built graph to one batch.
        correct += (predicted == labels).sum().item()
  finally:
    model.train(was_training)

  if total == 0:
    print('No test images to evaluate.')
    return float('nan')

  accuracy = 100 * correct / total
  print('Accuracy of the network on the %d test images: %d %%' % (
        total, accuracy))
  return accuracy

## Training and evaluation of the model on TPU

In [0]:
# Seed the CPU-side RNG (weight initialisation, DataLoader shuffling) so runs
# are reproducible; TPU execution itself is not covered by this seed.
torch.manual_seed(0)

trainloader, testloader, classes = get_data()
model = Net()

dev = xm.xla_device()

train_tpu(model, trainloader, "./log/2", dev)
evaluate(model, testloader, dev)

# Move the weights back to the host before saving, so the checkpoint can be
# loaded without an XLA device.
model.to("cpu")
torch.save({
    "model": model.state_dict(),
}, "./model_tpu.pt")

## Load the trained model and re-evaluate it

In [0]:
trainloader, testloader, classes = get_data()
model = Net()

dev = xm.xla_device()

# Deserialize onto the CPU first; map_location keeps this working even if the
# checkpoint's tensors were saved on another device. torch.load unpickles the
# file, so only load checkpoints from a trusted source.
checkpoint = torch.load("./model_tpu.pt", map_location="cpu")
model.load_state_dict(checkpoint["model"])
model.to(dev)

evaluate(model, testloader, dev)