# Try a Google Tensor Processing Unit (TPU) to train models that need a lot of memory

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NAVIFOLIO/dl_intro/blob/main/notebooks/trial_TPU_cifar10.ipynb)

Try a Google TPU as the accelerator for training your model.
The code below is an example implementation that trains a small CNN on a single TPU unit.
As of January 2025, Google Colab offers TPU runtimes with a large amount of memory for free.

## Install the torch_xla library in your Colab environment

To use a TPU from PyTorch, you need the torch_xla library.
Run the command below in Colab to install torch_xla together with a PyTorch version that is compatible with it.

In [None]:
!pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
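
After installation, it can help to confirm that both packages landed in the same release series. The snippet below is a minimal sketch, assuming that matching major and minor numbers (for example 2.5.x) is the relevant compatibility check. The helpers `same_minor_series` and `installed_version` are hypothetical names introduced here for illustration; they are not part of PyTorch or torch_xla.

```python
from importlib.metadata import PackageNotFoundError, version


def same_minor_series(a: str, b: str) -> bool:
    """Return True when two version strings share major.minor (e.g. 2.5.x)."""
    def major_minor(v: str) -> tuple:
        # Drop local build tags such as "+cu121" before comparing.
        return tuple(v.split("+")[0].split(".")[:2])
    return major_minor(a) == major_minor(b)


def installed_version(package: str):
    """Return the installed version string, or None if the package is missing."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


torch_version = installed_version("torch")
xla_version = installed_version("torch_xla")
if torch_version is None or xla_version is None:
    print("torch and/or torch_xla is not installed yet")
else:
    print(torch_version, xla_version, same_minor_series(torch_version, xla_version))
```

If the final value printed is `False`, reinstalling both packages with the same `~=` pin is a reasonable first thing to try.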

## Load CIFAR-10 and create PyTorch DataLoaders

Load the [CIFAR-10 dataset](https://www.cs.toronto.edu/~kriz/cifar.html) from `torchvision.datasets` and wrap it in `DataLoader` objects to prepare for mini-batch training.

In [None]:
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Training images get light augmentation (random affine transform and horizontal flip).
transform_train = transforms.Compose([
    transforms.RandomAffine([-10, 10], scale=(0.8, 1.2)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    # Mean 0 and std 1 leave the values unchanged; substitute dataset
    # statistics here if you want to standardize the inputs.
    transforms.Normalize((0.0, 0.0, 0.0), (1.0, 1.0, 1.0)),
])
# Test images are only converted to tensors, without augmentation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.0, 0.0, 0.0), (1.0, 1.0, 1.0)),
])

cifar10_train = CIFAR10(root="./input", train=True, download=True, transform=transform_train)
cifar10_test = CIFAR10(root="./input", train=False, download=True, transform=transform_test)

batch_size = 500
train_loader = DataLoader(dataset=cifar10_train, batch_size=batch_size, shuffle=True)
# The whole test set is served as one batch, which is deliberately memory-hungry.
test_loader = DataLoader(dataset=cifar10_test, batch_size=len(cifar10_test), shuffle=False)
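
To get a feel for the sizes involved, the arithmetic below estimates the number of training steps per epoch and the memory taken by one input batch. It is a rough sketch that counts only the input tensors as float32 values, which is what `ToTensor()` produces; intermediate activations and gradients, which usually dominate memory use, are not included. The helper `batch_bytes` is a hypothetical name used only here.

```python
# CIFAR-10 has 50,000 training and 10,000 test images, each 3 x 32 x 32.
TRAIN_IMAGES, TEST_IMAGES = 50_000, 10_000
CHANNELS, HEIGHT, WIDTH = 3, 32, 32
BYTES_PER_FLOAT32 = 4


def batch_bytes(n_images: int) -> int:
    """Bytes needed to hold n_images input images as float32."""
    return n_images * CHANNELS * HEIGHT * WIDTH * BYTES_PER_FLOAT32


batch_size = 500
steps_per_epoch = -(-TRAIN_IMAGES // batch_size)  # ceiling division

print(f"steps per epoch: {steps_per_epoch}")
print(f"train batch: {batch_bytes(batch_size) / 2**20:.1f} MiB")
print(f"test batch:  {batch_bytes(TEST_IMAGES) / 2**20:.1f} MiB")
```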


## Create your model

Define the model, a small VGG-style CNN.

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class mini_vgg(nn.Module):
    def __init__(self, init_weights=True):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.conv5 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.conv6 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(256, 256, bias=True)
        self.fc2 = nn.Linear(256, 10, bias=True)
        self.norm1 = nn.BatchNorm2d(64)
        self.norm2 = nn.BatchNorm2d(128)
        self.norm3 = nn.BatchNorm2d(256)
        self.pool = nn.MaxPool2d(2, 2)
        self.globalAvgPool = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(p=0.5)
        
        if init_weights:
            # Kaiming-normal weights and zero biases for every conv and linear layer
            for module in self.modules():
                if isinstance(module, (nn.Conv2d, nn.Linear)):
                    nn.init.kaiming_normal_(module.weight)
                    if module.bias is not None:
                        nn.init.constant_(module.bias, 0)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.norm1(self.conv2(x)))
        x = self.pool(x)
        x = F.relu(self.conv3(x))
        x = F.relu(self.norm2(self.conv4(x)))
        x = self.pool(x)
        x = F.relu(self.conv5(x))
        x = F.relu(self.norm3(self.conv6(x)))
        x = self.globalAvgPool(x)
        
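        x = x.view(-1, 256)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        # Return raw logits: nn.CrossEntropyLoss applies log-softmax itself,
        # and a ReLU here would clamp every negative logit to zero.
        x = self.fc2(x)
        return x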
    
model = mini_vgg()
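
As a sanity check on the model size, the sketch below counts the trainable parameters by hand from the layer shapes listed in `mini_vgg`. The helper functions are hypothetical names used only for this calculation; they assume 3x3 kernels with biases, linear layers with biases, and two learnable values (scale and shift) per batch-norm channel, as in the class definition.

```python
def conv_params(c_in: int, c_out: int, kernel: int = 3) -> int:
    """Weights plus biases of a square-kernel Conv2d layer."""
    return c_in * c_out * kernel * kernel + c_out


def linear_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of a Linear layer."""
    return n_in * n_out + n_out


def batchnorm_params(channels: int) -> int:
    """Learnable scale and shift of a BatchNorm2d layer."""
    return 2 * channels


# (in, out) sizes copied from the mini_vgg definition
conv_shapes = [(3, 64), (64, 64), (64, 128), (128, 128), (128, 256), (256, 256)]
linear_shapes = [(256, 256), (256, 10)]
norm_channels = [64, 128, 256]

total_params = (
    sum(conv_params(i, o) for i, o in conv_shapes)
    + sum(linear_params(i, o) for i, o in linear_shapes)
    + sum(batchnorm_params(c) for c in norm_channels)
)
print(f"{total_params:,} trainable parameters")
```

With PyTorch available, `sum(p.numel() for p in model.parameters())` should report the same figure.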

## Train your model on a TPU device

Train the model:

1. Move your tensors and the model to the TPU device.
2. Call `xm.mark_step()` after every training step so that XLA compiles and executes the operations queued during that step.
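
PyTorch/XLA runs lazily: tensor operations are recorded first, and the recorded work is compiled and executed at a step boundary such as `xm.mark_step()`. The toy class below only illustrates that idea in plain Python. `LazyQueue` is a hypothetical stand-in written for this notebook, not the torch_xla API, and it does no compilation at all.

```python
class LazyQueue:
    """Toy model of lazy execution; NOT part of torch_xla."""

    def __init__(self):
        self.pending = []   # operations recorded but not yet run
        self.executed = []  # names of operations that have run

    def record(self, name, fn, *args):
        # Nothing runs here: the operation is only queued.
        self.pending.append((name, fn, args))

    def mark_step(self):
        # Run everything queued since the previous step boundary, then reset.
        results = [fn(*args) for _, fn, args in self.pending]
        self.executed.extend(name for name, _, _ in self.pending)
        self.pending.clear()
        return results


queue = LazyQueue()
queue.record("add", lambda a, b: a + b, 1, 2)
queue.record("mul", lambda a, b: a * b, 3, 4)
pending_before = len(queue.pending)  # both operations are still waiting
results = queue.mark_step()          # now they actually run
print(pending_before, results, len(queue.pending))
```

In the real training loop, marking the step once per batch keeps the queued work bounded to one training step at a time.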

In [None]:
import torch
from torch import optim
import torch_xla
import torch_xla.core.xla_model as xm

record_loss_train = []
record_loss_test = []

def training(model, epochs=15):
    device = xm.xla_device()
    # Move the model to the TPU before creating the optimizer.
    model = model.to(device)

    loss_func = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters())
    # The test loader yields the whole test set as a single batch.
    x_test, t_test = next(iter(test_loader))
    x_test, t_test = x_test.to(device), t_test.to(device)

    for epoch in range(epochs):
        model.train()
        loss_train = 0
        for j, (data, target) in enumerate(train_loader):
            data = data.to(device)
            target = target.to(device)
            y = model(data)
            loss = loss_func(y, target)
            # Accumulate a detached copy so the sum does not keep the autograd graph.
            loss_train += loss.detach()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # End of step: XLA compiles and runs the queued operations.
            xm.mark_step()

        # .item() copies the averaged loss to the host as a Python float.
        loss_train = (loss_train / (j + 1)).item()
        record_loss_train.append(loss_train)

        model.eval()
        with torch.no_grad():
            y_test = model(x_test)
            loss_test = loss_func(y_test, t_test).item()
        record_loss_test.append(loss_test)

        print(f"Epoch: {epoch}, Loss_Train: {loss_train}, Loss_Test: {loss_test}")

training(model)

## Plot the loss with pyplot

Use `matplotlib.pyplot` to plot the loss for each epoch.
The model in this notebook, `mini_vgg`, is too shallow to reach very high accuracy on CIFAR-10. Even so, the relatively large dataset and batch sizes let you check that memory-hungry training runs without crashing.

Note: loss values held as tensors in TPU memory must be copied to the CPU before plotting.

In [None]:
import matplotlib.pyplot as plt

# float() also copies any loss still stored as a TPU tensor back to the host.
train_losses = [float(v) for v in record_loss_train]
test_losses = [float(v) for v in record_loss_test]
plt.plot(range(len(train_losses)), train_losses, label="Train")
plt.plot(range(len(test_losses)), test_losses, label="Test")
plt.legend()

plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.show()

## Calculate the accuracy on the test data

Note: copy the model output from TPU memory to the CPU before comparing it with the labels.

In [None]:
import torch

correct = 0
total = 0
model.eval()
device = xm.xla_device()
with torch.no_grad():  # no gradients are needed for evaluation
    for x, t in test_loader:
        x = x.to(device)
        y = model(x)
        # Copy the predictions from TPU memory to the CPU, where the labels `t` live
        z = y.cpu()
        correct += (z.argmax(1) == t).sum().item()
        total += len(t)
print(f"Accuracy: {correct / total * 100:.2f}%")
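
The accuracy calculation above amounts to taking the index of the largest score in each row and comparing it with the label. The plain-Python sketch below shows the same logic on made-up numbers. The scores, the labels, and the helpers `argmax` and `accuracy` are illustrative assumptions, not output from the model.

```python
def argmax(row):
    """Index of the largest value in a list of scores."""
    return max(range(len(row)), key=row.__getitem__)


def accuracy(scores, labels):
    """Fraction of rows whose highest-scoring class matches the label."""
    correct = sum(argmax(row) == label for row, label in zip(scores, labels))
    return correct / len(labels)


# Made-up scores for four samples over three classes
scores = [
    [0.1, 2.0, -1.0],
    [1.5, 0.2, 0.3],
    [0.0, 0.1, 0.9],
    [0.4, 0.3, 0.2],
]
labels = [1, 0, 2, 1]

acc = accuracy(scores, labels)
print(f"Accuracy: {acc * 100:.2f}%")
```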

## References

- [The CIFAR-10 dataset](https://www.cs.toronto.edu/~kriz/cifar.html)
- [PyTorch/XLA documentation](https://pytorch.org/xla/release/r2.5/index.html)