<a href="https://colab.research.google.com/github/ShogoNoguchi/TPU-parallel-operation-on-PytorchXLA/blob/main/TPU%E4%B8%A6%E5%88%97.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

# データセットの準備
training_data = datasets.FashionMNIST(
    root="data", train=True, download=True, transform=ToTensor()
)
test_data = datasets.FashionMNIST(
    root="data", train=False, download=True, transform=ToTensor()
)

# データセットサイズを取得
train_size = len(training_data)
test_size = len(test_data)

# モデル定義
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

def train_loop(dataloader, model, loss_fn, optimizer, device, dataset_size):
    for batch_idx, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        pred = model(X)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        xm.optimizer_step(optimizer)

        # 主プロセスのみ出力
        if batch_idx % 100 == 0 and xm.is_master_ordinal():
            processed_samples = batch_idx * len(X)
            print(f"Batch {batch_idx}, Loss: {loss.item():>7f}, Processed Samples: {processed_samples}/{dataset_size}")

def test_loop(dataloader, model, loss_fn, device, dataset_size):
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item() * len(X)
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    accuracy = correct / dataset_size
    avg_loss = test_loss / dataset_size

    # 主プロセスのみ出力
    if xm.is_master_ordinal():
        print(f"Test Error: \n Accuracy: {(100 * accuracy):>0.1f}%, Avg loss: {avg_loss:>8f} \n")

def _mp_fn(rank, flags):
    device = xm.xla_device()

    train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
    test_dataloader = DataLoader(test_data, batch_size=64, shuffle=False)
    train_mp_loader = pl.MpDeviceLoader(train_dataloader, device)
    test_mp_loader = pl.MpDeviceLoader(test_dataloader, device)

    model = NeuralNetwork().to(device)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

    for epoch in range(5):
        if xm.is_master_ordinal():
            print(f"Epoch {epoch + 1}\n-------------------------------")
        train_loop(train_mp_loader, model, loss_fn, optimizer, device, train_size)
        test_loop(test_mp_loader, model, loss_fn, device, test_size)



In [None]:
if __name__ == "__main__":
    xmp.spawn(_mp_fn, args=(None,), nprocs=8, start_method="fork")



Epoch 1
-------------------------------
Batch 0, Loss: 2.308045, Processed Samples: 0/60000
Batch 100, Loss: 1.075573, Processed Samples: 6400/60000
Batch 200, Loss: 0.652618, Processed Samples: 12800/60000
Batch 300, Loss: 0.423205, Processed Samples: 19200/60000
Batch 400, Loss: 0.379817, Processed Samples: 25600/60000
Batch 500, Loss: 0.491438, Processed Samples: 32000/60000
Batch 600, Loss: 0.528216, Processed Samples: 38400/60000
Batch 700, Loss: 0.523964, Processed Samples: 44800/60000
Batch 800, Loss: 0.539412, Processed Samples: 51200/60000
Batch 900, Loss: 0.288072, Processed Samples: 57600/60000
Test Error: 
 Accuracy: 82.2%, Avg loss: 0.498026 

Epoch 2
-------------------------------
Batch 0, Loss: 0.614214, Processed Samples: 0/60000
Batch 100, Loss: 0.433167, Processed Samples: 6400/60000
Batch 200, Loss: 0.211252, Processed Samples: 12800/60000
Batch 300, Loss: 0.455947, Processed Samples: 19200/60000
Batch 400, Loss: 0.326837, Processed Samples: 25600/60000
Batch 500, L