<a href="https://colab.research.google.com/github/ShogoNoguchi/TPU-parallel-operation-on-PytorchXLA_Image-Multiclass-Classification/blob/main/%E9%9F%B3%E5%A3%B0%E5%88%86%E9%A1%9EonTPU.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip uninstall -y tensorflow
!pip install tensorflow-cpu


Found existing installation: tensorflow 2.15.0
Uninstalling tensorflow-2.15.0:
  Successfully uninstalled tensorflow-2.15.0
Collecting tensorflow-cpu
  Downloading tensorflow_cpu-2.18.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.1 kB)
Collecting tensorboard<2.19,>=2.18 (from tensorflow-cpu)
  Downloading tensorboard-2.18.0-py3-none-any.whl.metadata (1.6 kB)
Collecting keras>=3.5.0 (from tensorflow-cpu)
  Downloading keras-3.7.0-py3-none-any.whl.metadata (5.8 kB)
Collecting ml-dtypes<0.5.0,>=0.4.0 (from tensorflow-cpu)
  Downloading ml_dtypes-0.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB)
Collecting namex (from keras>=3.5.0->tensorflow-cpu)
  Downloading namex-0.0.8-py3-none-any.whl.metadata (246 bytes)
Collecting optree (from keras>=3.5.0->tensorflow-cpu)
  Downloading optree-0.13.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (47 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m47.

In [1]:
pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 -f https://storage.googleapis.com/libtpu-releases/index.html

Looking in links: https://storage.googleapis.com/libtpu-releases/index.html


In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from torch.utils.data import Dataset, DataLoader
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def main(rank):
    # デバイスの設定
    device = xm.xla_device()

    # データディレクトリの作成
    data_dir = "data"
    if not os.path.exists(data_dir):
        os.makedirs(data_dir)

    # ターゲットラベルの定義
    target_labels = ['yes', 'no', 'up', 'down', 'left', 'right', 'on', 'off', 'stop', 'go']
    unknown_label = 'unknown'
    unique_labels = target_labels + [unknown_label]
    label_map = {label: idx for idx, label in enumerate(unique_labels)}

    # データセットクラスの定義
    class SpeechCommandsDataset(Dataset):
        def __init__(self, dataset, sample_rate=16000, n_mfcc=40, max_length=16000, label_map=None):
            self.dataset = dataset
            self.sample_rate = sample_rate
            self.n_mfcc = n_mfcc
            self.max_length = max_length
            self.label_map = label_map

            # 前処理トランスフォーム
            self.resample_transform = torchaudio.transforms.Resample(orig_freq=16000, new_freq=sample_rate)
            self.mfcc_transform = torchaudio.transforms.MFCC(
                sample_rate=sample_rate,
                n_mfcc=n_mfcc,
                melkwargs={"n_fft": 400, "hop_length": 160, "n_mels": n_mfcc},
            )

        def __len__(self):
            return len(self.dataset)

        def __getitem__(self, idx):
            waveform, original_sample_rate, label, _, _ = self.dataset[idx]
            label = label if label in target_labels else unknown_label
            if original_sample_rate != self.sample_rate:
                waveform = self.resample_transform(waveform)
            waveform = (waveform - waveform.mean()) / waveform.std()
            if waveform.size(1) < self.max_length:
                padding = self.max_length - waveform.size(1)
                waveform = torch.nn.functional.pad(waveform, (0, padding))
            else:
                waveform = waveform[:, :self.max_length]
            mfcc = self.mfcc_transform(waveform)
            label_id = self.label_map[label] if self.label_map else label
            return mfcc, label_id

    # モデルクラスの定義
    class SpeechCommandClassifier(nn.Module):
        def __init__(self, n_mfcc=40, num_classes=len(label_map)):
            super(SpeechCommandClassifier, self).__init__()
            self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
            self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

            # フラット化後のサイズを計算
            with torch.no_grad():
                sample_input = torch.zeros(1, 1, n_mfcc, 101)
                out = self.pool(F.relu(self.conv1(sample_input)))
                out = self.pool(F.relu(self.conv2(out)))
                flattened_size = out.view(-1).shape[0]

            self.fc1 = nn.Linear(flattened_size, 128)
            self.fc2 = nn.Linear(128, num_classes)

        def forward(self, x):
            x = self.pool(F.relu(self.conv1(x)))
            x = self.pool(F.relu(self.conv2(x)))
            x = x.view(x.size(0), -1)  # フラット化
            x = F.relu(self.fc1(x))
            x = self.fc2(x)
            return x

    # ランク0のみがデータをダウンロード
    if rank == 0:
        train_dataset = torchaudio.datasets.SPEECHCOMMANDS(root=data_dir, subset='training', download=True)
        validation_dataset = torchaudio.datasets.SPEECHCOMMANDS(root=data_dir, subset='validation', download=True)
        test_dataset = torchaudio.datasets.SPEECHCOMMANDS(root=data_dir, subset='testing', download=True)
    xm.rendezvous('download_complete')

    # 他のプロセスはダウンロード済みのデータを使用
    if rank != 0:
        train_dataset = torchaudio.datasets.SPEECHCOMMANDS(root=data_dir, subset='training', download=False)
        validation_dataset = torchaudio.datasets.SPEECHCOMMANDS(root=data_dir, subset='validation', download=False)
        test_dataset = torchaudio.datasets.SPEECHCOMMANDS(root=data_dir, subset='testing', download=False)

    # データセットのインスタンスを作成
    train_data = SpeechCommandsDataset(train_dataset, label_map=label_map)
    validation_data = SpeechCommandsDataset(validation_dataset, label_map=label_map)

    # バッチサイズの設定
    batch_size = 64

    # DistributedSampler を使用してデータを分散
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_data,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True,
    )

    train_loader = DataLoader(train_data, batch_size=batch_size, sampler=train_sampler, num_workers=4)
    validation_loader = DataLoader(validation_data, batch_size=batch_size, shuffle=False, num_workers=4)

    # モデルの定義とデバイスへの移動
    model = SpeechCommandClassifier(n_mfcc=40, num_classes=len(label_map)).to(device)

    # 損失関数とオプティマイザーの定義
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # 訓練関数の定義
    def train(epoch):
        model.train()
        train_sampler.set_epoch(epoch)
        total_loss = 0
        total_correct = 0
        total_samples = 0

        for X, y in train_loader:
            X = X.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            pred = model(X)
            loss = loss_fn(pred, y)
            loss.backward()
            xm.optimizer_step(optimizer)
            total_loss += loss.item() * X.size(0)
            total_correct += (pred.argmax(1) == y).sum().item()
            total_samples += X.size(0)
            # ステップのマーク
            xm.mark_step()

        avg_loss = total_loss / total_samples
        accuracy = total_correct / total_samples * 100
        if xm.is_master_ordinal():
            print(f"エポック {epoch} 訓練 - 平均損失: {avg_loss:.4f}, 精度: {accuracy:.2f}%")

    # 検証関数の定義
    def validate(epoch):
        model.eval()
        total_loss = 0
        total_correct = 0
        total_samples = 0

        with torch.no_grad():
            for X, y in validation_loader:
                X = X.to(device)
                y = y.to(device)
                pred = model(X)
                loss = loss_fn(pred, y)
                total_loss += loss.item() * X.size(0)
                total_correct += (pred.argmax(1) == y).sum().item()
                total_samples += X.size(0)
                xm.mark_step()

        avg_loss = total_loss / total_samples
        accuracy = total_correct / total_samples * 100
        if xm.is_master_ordinal():
            print(f"エポック {epoch} 検証 - 平均損失: {avg_loss:.4f}, 精度: {accuracy:.2f}%")

    # エポック数の設定
    epochs = 5
    for epoch in range(1, epochs + 1):
        train(epoch)
        validate(epoch)

    if xm.is_master_ordinal():
        print("訓練完了！")

# xmp.spawn を使用して各 TPU コア上でプロセスを起動
if __name__ == '__main__':
    xmp.spawn(main, args=(), nprocs=8, start_method='fork')




エポック 1 訓練 - 平均損失: 1.5184, 精度: 62.30%
エポック 1 検証 - 平均損失: 1.2598, 精度: 63.77%
エポック 2 訓練 - 平均損失: 1.1109, 精度: 66.44%
エポック 2 検証 - 平均損失: 0.9926, 精度: 69.12%
エポック 3 訓練 - 平均損失: 0.8290, 精度: 73.39%
エポック 3 検証 - 平均損失: 0.7607, 精度: 75.66%
エポック 4 訓練 - 平均損失: 0.6904, 精度: 78.19%
エポック 4 検証 - 平均損失: 0.6577, 精度: 78.97%
エポック 5 訓練 - 平均損失: 0.5781, 精度: 81.13%
エポック 5 検証 - 平均損失: 0.5774, 精度: 81.32%
訓練完了！
