# SPEECH COMMAND RECOGNITION WITH PYTORCH LIGHTNING
[torchaudioのtutorial](https://pytorch.org/tutorials/intermediate/speech_command_recognition_with_torchaudio_tutorial.html)をpytorch lightningを使って書き直しました ([日本語訳](https://colab.research.google.com/github/YutaroOgawa/pytorch_tutorials_jp/blob/main/notebook/7_Audio/7_7_speech_command_recognition_with_torchaudio_jp.ipynb))  
その他の変更点 (覚えている範囲内)
- batch_size
- epoch数
- 学習率のスケジューラのパラメータ
- アーリーストッピングの導入
- Automatic Mixed Precision (AMP) の導入

In [1]:
import os
import sys
from omegaconf import DictConfig, OmegaConf

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim
import torchaudio
from torchaudio.datasets import SPEECHCOMMANDS

import matplotlib.pyplot as plt
import IPython.display as ipd

In [2]:
import pytorch_lightning as pl
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint

In [3]:
pl.seed_everything(0)
torch.backends.cudnn.determinstic = True
torch.backends.cudnn.benchmark = False

Global seed set to 0


In [4]:
LABELS = ['backward', 'bed', 'bird', 'cat','dog', 'down', 'eight', 'five', 'follow', 'forward', 'four', 'go', 'happy', 'house', 'learn',
          'left', 'marvin', 'nine', 'no', 'off', 'on', 'one', 'right', 'seven', 'sheila', 'six', 'stop', 'three', 'tree', 'two', 'up',
          'visual', 'wow', 'yes', 'zero']

##  データセットの準備

In [5]:
# そのまま
class SubsetSC(SPEECHCOMMANDS):
    def __init__(self, subset=None):
        super().__init__("./speech", download=True)

        def load_list(filename):
            filepath = os.path.join(self._path, filename)
            with open(filepath) as fileobj:
                return [os.path.join(self._path, line.strip()) for line in fileobj]

        if subset == "validation":
            self._walker = load_list("validation_list.txt")
            
        elif subset == "testing":
            self._walker = load_list("testing_list.txt")
            
        elif subset == "training":
            excludes = load_list("validation_list.txt") + load_list("testing_list.txt")
            excludes = set(excludes)
            self._walker = [w for w in self._walker if w not in excludes]

In [6]:
class SpeechDataModule(LightningDataModule):
    def __init__(self, batch_size, labels=LABELS, transform=None):
        super().__init__()
        self.batch_size = batch_size
        self.labels = labels
        self.transform = transform
        
    def setup(self, stage=None):
        if stage == "fit" or stage is None:
            self.train_set = SubsetSC("training")
            self.valid_set = SubsetSC("validation")
        
        if stage == "test" or stage is None:
            self.test_set = SubsetSC("testing")

    def train_dataloader(self):
        loader = DataLoader(self.train_set,
                            batch_size=self.batch_size,
                            shuffle=True,
                            drop_last=True,
                            collate_fn=self.collate_fn,
                            pin_memory=True,
                            num_workers=4
                           )
        return loader

    def val_dataloader(self):
        loader = DataLoader(self.valid_set,
                            batch_size=self.batch_size,
                            shuffle=False,
                            drop_last=False,
                            collate_fn=self.collate_fn,
                            pin_memory=True,
                            num_workers=4
                           )
        return loader

    def test_dataloader(self):
        loader = DataLoader(self.test_set,
                            batch_size=self.batch_size,
                            shuffle=False,
                            drop_last=False,
                            collate_fn=self.collate_fn,
                            pin_memory=True,
                            num_workers=4
                           )
        return loader
    
    def pad_sequence(self, batch):
        batch = [item.t() for item in batch]
        batch = torch.nn.utils.rnn.pad_sequence(batch, batch_first=True, padding_value=0.)
        batch = self.transform(batch.permute(0, 2, 1))  # ここでリサンプリング
        return batch
    
    def label_to_index(self, word):
        return torch.tensor(self.labels.index(word))

    def collate_fn(self, batch):
        tensors, targets = [], []

        for waveform, _, label, _, _ in batch:
            tensors += [waveform]
            targets += [self.label_to_index(label)]

        tensors = self.pad_sequence(tensors)
        targets = torch.stack(targets)

        return tensors, targets

## モデルの準備

In [7]:
# そのまま
class M5(nn.Module):
    def __init__(self, n_input=1, n_output=35, stride=16, n_channel=32):
        super().__init__()
        self.conv1 = nn.Conv1d(n_input, n_channel, kernel_size=80, stride=stride)
        self.bn1 = nn.BatchNorm1d(n_channel)
        self.pool1 = nn.MaxPool1d(4)
        self.conv2 = nn.Conv1d(n_channel, n_channel, kernel_size=3)
        self.bn2 = nn.BatchNorm1d(n_channel)
        self.pool2 = nn.MaxPool1d(4)
        self.conv3 = nn.Conv1d(n_channel, 2 * n_channel, kernel_size=3)
        self.bn3 = nn.BatchNorm1d(2 * n_channel)
        self.pool3 = nn.MaxPool1d(4)
        self.conv4 = nn.Conv1d(2 * n_channel, 2 * n_channel, kernel_size=3)
        self.bn4 = nn.BatchNorm1d(2 * n_channel)
        self.pool4 = nn.MaxPool1d(4)
        self.fc1 = nn.Linear(2 * n_channel, n_output)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(self.bn1(x))
        x = self.pool1(x)
        x = self.conv2(x)
        x = F.relu(self.bn2(x))
        x = self.pool2(x)
        x = self.conv3(x)
        x = F.relu(self.bn3(x))
        x = self.pool3(x)
        x = self.conv4(x)
        x = F.relu(self.bn4(x))
        x = self.pool4(x)
        x = F.avg_pool1d(x, x.shape[-1])
        x = x.permute(0, 2, 1)
        x = self.fc1(x)
        return x

## Lightning Moduleの準備

In [8]:
class LitModel(LightningModule):
    def __init__(self, model_hparams, optimizer_hparams):
        super().__init__()
        self.save_hyperparameters()
        self.model = M5(model_hparams.n_input,
                        model_hparams.n_output,
                        model_hparams.stride,
                        model_hparams.n_channel
                       )

    def forward(self, x):
        x = self.model(x)
        return F.log_softmax(x, dim=2)

    def training_step(self, batch, batch_idx):
        data, target = batch
        output = self(data).squeeze()
        loss = F.nll_loss(output, target)
        self.log("tr_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        data, target = batch
        output = self(data).squeeze()
        loss = F.nll_loss(output, target)
        
        preds = torch.argmax(output, dim=1)
        acc = (target == preds).float()
        
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)
        return loss
    
    def test_step(self, batch, batch_idx):
        data, target = batch
        output = self(data).squeeze()
        
        preds = torch.argmax(output, dim=1)
        acc = (target == preds).float().mean()
        self.log("test_acc", acc, prog_bar=True)
        return acc

    def configure_optimizers(self):
        params = self.hparams.optimizer_hparams
        optimizer = torch.optim.AdamW(self.parameters(),
                                      lr=params.learning_rate,
                                      weight_decay=params.weight_decay
                                     )
        
        schedular = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    gamma=params.gamma,
                                                    step_size=params.step_size
                                                   )
        scheduler = {"scheduler": schedular, "interval": "epoch", "frequency": 1}
        return [optimizer], [scheduler]

---

In [9]:
orig_freq = 16000
new_freq = 8000

max_epochs = 100
batch_size = 32

model_hparams = OmegaConf.create({"n_input": 1,
                                  "n_output": 35,
                                  "stride": 16,
                                  "n_channel": 32
                                 })

optimizer_hparams = OmegaConf.create({"learning_rate": 0.01,
                                      "weight_decay": 0.0001,
                                      "gamma": 0.5,
                                      "step_size": 10,
                                      "patience": 5
                                     })

In [10]:
transform = torchaudio.transforms.Resample(orig_freq=orig_freq, new_freq=new_freq)
dm = SpeechDataModule(batch_size=batch_size, transform=transform)
model = LitModel(model_hparams, optimizer_hparams)

In [11]:
early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=0.00, patience=optimizer_hparams.patience, verbose=False, mode="min")
checkpoint_callback = ModelCheckpoint(monitor="val_loss", dirpath='./speech/models/', filename='spcm_{epoch:03d}', save_top_k=5)

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")


In [12]:
trainer = Trainer(max_epochs=max_epochs,
                  progress_bar_refresh_rate=20,
                  gpus=1,
                  precision=16,  # Automatic Mixed Precision (AMP)
                  callbacks=[early_stop_callback, checkpoint_callback]
                 )

Using native 16bit precision.
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [13]:
trainer.fit(model, dm)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type | Params
-------------------------------
0 | model | M5   | 26.9 K
-------------------------------
26.9 K    Trainable params
0         Non-trainable params
26.9 K    Total params
0.108     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Global seed set to 0


Training: -1it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

In [14]:
results = trainer.test()
print(results)

  f"DataModule.{name} has already been called, so it will not be called again. "
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': 0.9198546409606934}
--------------------------------------------------------------------------------
[{'test_acc': 0.9198546409606934}]


---

In [19]:
ckpt_path = trainer.checkpoint_callback.best_model_path
model = LitModel(model_hparams, optimizer_hparams)
model = LitModel.load_from_checkpoint(ckpt_path)

In [20]:
results = trainer.test(model)
print(results)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': 0.9198546409606934}
--------------------------------------------------------------------------------
[{'test_acc': 0.9198546409606934}]


---