In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Lazily build the full path of every file found under the Kaggle input
# directory (walking it recursively), then print one path per line.
input_files = (
    os.path.join(dirname, filename)
    for dirname, _, filenames in os.walk('/kaggle/input')
    for filename in filenames
)
for input_path in input_files:
    print(input_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import os
import torchaudio
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F

# TPU用のモジュールをインポート
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
# import torch_xla.runtime as xr  # この行は不要

# Create the directory that the dataset is downloaded into.
data_dir = "data"
# exist_ok=True makes this a no-op when the directory is already there and
# avoids the race between a separate existence check and the creation call.
os.makedirs(data_dir, exist_ok=True)

# Target label definitions: the keywords that get their own class.
target_labels = [
    'yes', 'no', 'up', 'down', 'left',
    'right', 'on', 'off', 'stop', 'go',
]
# Name of the single catch-all class appended after the target keywords.
unknown_label = 'unknown'
# Full ordered class list: target keywords first, catch-all class last.
unique_labels = [*target_labels, unknown_label]
# Class name -> integer index, following the order of unique_labels.
label_map = dict(zip(unique_labels, range(len(unique_labels))))

# Dataset class: turns raw speech-command samples into fixed-size MFCC features.
class SpeechCommandsDataset(Dataset):
    """Yield ``(mfcc, label_id)`` pairs from a Speech Commands-style dataset.

    Every item of the wrapped dataset is unpacked as a 5-tuple
    ``(waveform, sample_rate, label, _, _)``; the last two fields are ignored.
    Labels that are not in the module-level ``target_labels`` list are replaced
    by the module-level ``unknown_label``.

    Args:
        dataset: Indexable, sized collection of 5-tuples as described above.
        sample_rate: Sampling rate (Hz) that waveforms are converted to and
            that the MFCC transform is configured for.
        n_mfcc: Number of MFCC coefficients; also used as the number of mel
            filter banks.
        max_length: Number of samples every waveform is zero-padded or
            truncated to (along dimension 1) before feature extraction.
        label_map: Optional mapping from label string to class index. When it
            is ``None`` or empty, ``__getitem__`` returns the label string.
    """

    # Lower bound applied to the standard deviation during normalisation so a
    # constant (e.g. silent) waveform does not cause a division by zero.
    _MIN_STD = 1e-8

    def __init__(self, dataset, sample_rate=16000, n_mfcc=40, max_length=16000, label_map=None):
        self.dataset = dataset
        self.sample_rate = sample_rate
        self.n_mfcc = n_mfcc
        self.max_length = max_length
        self.label_map = label_map

        # Preprocessing transforms.
        # Kept under its original name for backward compatibility; it only
        # covers sources recorded at 16 kHz.
        self.resample_transform = torchaudio.transforms.Resample(orig_freq=16000, new_freq=sample_rate)
        # Resamplers keyed by the source rate, so clips recorded at a rate
        # other than 16 kHz are converted with the correct ratio.
        self._resamplers = {16000: self.resample_transform}
        self.mfcc_transform = torchaudio.transforms.MFCC(
            sample_rate=sample_rate,
            n_mfcc=n_mfcc,
            melkwargs={"n_fft": 400, "hop_length": 160, "n_mels": n_mfcc},
        )

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def _resampler_for(self, orig_freq):
        """Return a cached resampler from ``orig_freq`` to ``self.sample_rate``."""
        resampler = self._resamplers.get(orig_freq)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(orig_freq=orig_freq, new_freq=self.sample_rate)
            self._resamplers[orig_freq] = resampler
        return resampler

    def __getitem__(self, idx):
        """Return ``(mfcc, label_id)`` for the item at position ``idx``.

        ``label_id`` is ``label_map[label]`` when a non-empty ``label_map`` was
        given, otherwise the (possibly replaced) label string itself.
        """
        waveform, original_sample_rate, label, _, _ = self.dataset[idx]
        if label not in target_labels:
            label = unknown_label

        # Resample using the clip's actual rate, not an assumed 16 kHz.
        if original_sample_rate != self.sample_rate:
            waveform = self._resampler_for(original_sample_rate)(waveform)

        # Zero-mean / unit-variance normalisation; the clamp keeps constant
        # waveforms finite instead of producing NaN or inf.
        centered = waveform - waveform.mean()
        waveform = centered / waveform.std().clamp(min=self._MIN_STD)

        # Force a fixed length along dimension 1: right-pad with zeros or cut.
        current_length = waveform.size(1)
        if current_length < self.max_length:
            padding = self.max_length - current_length
            waveform = torch.nn.functional.pad(waveform, (0, padding))
        else:
            waveform = waveform[:, :self.max_length]

        mfcc = self.mfcc_transform(waveform)
        label_id = self.label_map[label] if self.label_map else label
        return mfcc, label_id

# Model class: small CNN classifier over MFCC feature maps.
class SpeechCommandClassifier(nn.Module):
    """Two conv+ReLU+max-pool stages followed by two fully connected layers.

    The network takes single-channel inputs of shape
    ``(batch, 1, n_mfcc, n_frames)`` and returns raw logits of shape
    ``(batch, num_classes)``.

    Args:
        n_mfcc: Height of the input feature map.
        num_classes: Number of output logits. The default is evaluated once,
            when the class is defined, from the module-level ``label_map``.
        n_frames: Width of the input feature map. The default of 101 keeps the
            previously hard-coded value (presumably the frame count for
            16000 samples at a hop length of 160 -- confirm against the
            feature extractor if either changes).
    """

    def __init__(self, n_mfcc=40, num_classes=len(label_map), n_frames=101):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

        # Push a dummy input through the convolutional stack to find out how
        # many features reach the first fully connected layer.
        with torch.no_grad():
            sample_input = torch.zeros(1, 1, n_mfcc, n_frames)
            out = self.pool(F.relu(self.conv1(sample_input)))
            out = self.pool(F.relu(self.conv2(out)))
            flattened_size = out.numel()

        self.fc1 = nn.Linear(flattened_size, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return class logits for a batch of feature maps ``x``."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)  # flatten everything except the batch dim
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Training loop.
def train_loop(dataloader, model, loss_fn, optimizer, device, epoch):
    """Run one training pass over ``dataloader`` and print summary metrics.

    Args:
        dataloader: Iterable yielding ``(X, y)`` batches.
        model: Module to train; switched to training mode.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar tensor.
        optimizer: Optimizer stepped through ``xm.optimizer_step``.
        device: Device the batches are moved to.
        epoch: Epoch number, used only in the progress messages.

    Progress is printed every 100 batches and a summary at the end, both only
    on the master ordinal. An empty ``dataloader`` reports ``nan`` metrics
    instead of raising ``ZeroDivisionError``.
    """
    model.train()
    # Loss and hit counts are accumulated as tensors and converted to Python
    # numbers once at the end: calling .item() on every batch forces a
    # device-to-host synchronisation per step, which is costly on XLA devices.
    # NOTE: the loss sum therefore uses the tensor dtype rather than Python
    # floats, so the last printed digits may differ marginally.
    total_loss = 0
    correct = 0
    total_samples = 0
    for batch_idx, (X, y) in enumerate(dataloader):
        X = X.to(device)
        y = y.to(device)
        batch_size = X.size(0)
        pred = model(X)
        loss = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()
        # XLA-aware optimizer step.
        xm.optimizer_step(optimizer)
        total_loss = total_loss + loss.detach() * batch_size
        correct = correct + (pred.argmax(1) == y).sum()
        total_samples += batch_size
        if batch_idx % 100 == 0:
            # Read the scalar on every process (not only the master) so all
            # replicas materialise values at the same steps; only the master
            # prints.
            loss_value = loss.item()
            if xm.is_master_ordinal():
                print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss_value:>7f}")
    if total_samples == 0:
        avg_loss = accuracy = float("nan")
    else:
        avg_loss = float(total_loss) / total_samples
        accuracy = int(correct) / total_samples
    if xm.is_master_ordinal():
        print(f"Train - Avg loss: {avg_loss:>8f}, Accuracy: {(100 * accuracy):>0.1f}%")

# Evaluation loop.
def test_loop(dataloader, model, loss_fn, device, mode='Validation'):
    """Evaluate ``model`` on ``dataloader`` and print loss and accuracy.

    Args:
        dataloader: Iterable yielding ``(X, y)`` batches.
        model: Module to evaluate; switched to evaluation mode.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar tensor.
        device: Device the batches are moved to.
        mode: Prefix of the printed summary line (e.g. ``'Validation'``).

    The summary is printed only on the master ordinal. An empty ``dataloader``
    reports ``nan`` metrics instead of raising ``ZeroDivisionError``.
    """
    model.eval()
    # Accumulate as tensors and convert once after the loop; a per-batch
    # .item() would force a device-to-host synchronisation on every step.
    # NOTE: the loss sum therefore uses the tensor dtype rather than Python
    # floats, so the last printed digits may differ marginally.
    total_loss = 0
    correct = 0
    total_samples = 0
    with torch.no_grad():
        for X, y in dataloader:
            X = X.to(device)
            y = y.to(device)
            batch_size = X.size(0)
            pred = model(X)
            loss = loss_fn(pred, y)
            total_loss = total_loss + loss * batch_size
            correct = correct + (pred.argmax(1) == y).sum()
            total_samples += batch_size
    if total_samples == 0:
        avg_loss = accuracy = float("nan")
    else:
        avg_loss = float(total_loss) / total_samples
        accuracy = int(correct) / total_samples
    if xm.is_master_ordinal():
        print(f"{mode} - Avg loss: {avg_loss:>8f}, Accuracy: {(100 * accuracy):>0.1f}%")

# Per-process entry point for multiprocessing.
def _mp_fn(rank, flags):
    """Train and evaluate the classifier inside one spawned process.

    Args:
        rank: Process index supplied by ``xmp.spawn`` (unused; the replica
            ordinal is queried from ``xm`` instead).
        flags: Extra argument forwarded by ``xmp.spawn`` (unused).

    Reads the module-level ``train_dataset``, ``validation_dataset`` and
    ``test_dataset``, which the ``__main__`` block assigns before spawning.
    """
    # Local import keeps this block self-contained.
    from torch.utils.data.distributed import DistributedSampler

    # Same default dtype as the deprecated
    # torch.set_default_tensor_type('torch.FloatTensor').
    torch.set_default_dtype(torch.float32)
    # Device for this process.
    device = xm.xla_device()
    # Wrap the raw datasets with the MFCC preprocessing.
    train_data = SpeechCommandsDataset(train_dataset, label_map=label_map)
    validation_data = SpeechCommandsDataset(validation_dataset, label_map=label_map)
    test_data = SpeechCommandsDataset(test_dataset, label_map=label_map)

    # Batch size (adjust as needed).
    batch_size = 64
    # Shard the training set so each replica sees a distinct slice per epoch
    # instead of every replica iterating over the full dataset. With a single
    # replica this is just a shuffled pass over all samples.
    train_sampler = DistributedSampler(
        train_data,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True,
    )
    # Data loaders. Evaluation loaders are left unsharded so the metrics the
    # master prints cover the complete validation/test sets.
    train_loader = DataLoader(train_data, batch_size=batch_size, sampler=train_sampler)
    validation_loader = DataLoader(validation_data, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

    # Wrap the loaders for the TPU device.
    train_loader = pl.MpDeviceLoader(train_loader, device)
    validation_loader = pl.MpDeviceLoader(validation_loader, device)
    test_loader = pl.MpDeviceLoader(test_loader, device)

    # Build the model and move it to the device.
    model = SpeechCommandClassifier(n_mfcc=40, num_classes=len(label_map)).to(device)
    # Loss function and optimizer.
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    epochs = 5
    for epoch in range(1, epochs + 1):
        if xm.is_master_ordinal():
            print(f"Epoch {epoch}\n-------------------------------")
        # Re-seed the sampler so each epoch uses a different shuffle.
        train_sampler.set_epoch(epoch)
        train_loop(train_loader, model, loss_fn, optimizer, device, epoch)
        test_loop(validation_loader, model, loss_fn, device)
    if xm.is_master_ordinal():
        print("訓練完了！")
    # The test loader was previously built but never evaluated; report the
    # held-out test metrics once training has finished.
    test_loop(test_loader, model, loss_fn, device, mode='Test')

if __name__ == "__main__":
    # Load the datasets in the main process (downloading them if necessary);
    # with start_method='fork' the spawned workers inherit these globals.
    train_dataset = torchaudio.datasets.SPEECHCOMMANDS(root=data_dir, subset='training', download=True)
    validation_dataset = torchaudio.datasets.SPEECHCOMMANDS(root=data_dir, subset='validation', download=True)
    test_dataset = torchaudio.datasets.SPEECHCOMMANDS(root=data_dir, subset='testing', download=True)
    # Do not query xm.xrt_world_size() here: in the parent process no
    # replication has been set up yet, so it reports 1 and training would run
    # on a single core. nprocs=None lets xmp.spawn use every available device.
    print("Spawning one training process per available TPU core")
    # Start multiprocessing.
    xmp.spawn(_mp_fn, args=(None,), nprocs=None, start_method='fork')
