In [None]:
import copy
import math
import os

import numpy as np
import torch
from scipy.io.wavfile import read
from torch import nn as nn
from torch.utils.data import Dataset, DataLoader

In [None]:
# Pick the best available accelerator: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [None]:
class CustomAudioDataset(Dataset):
    """Power-spectrogram dataset over four speech-command classes.

    ``audio_dir`` must contain one sub-directory per entry of ``audio_classes``
    ("up", "down", "left", "right"); an item's label is that class's index in
    the list. Only ``.wav`` files whose STFT would have exactly
    ``n_fft // 2 + 1`` frequency bins and ``expected_frames`` frames are kept,
    so every item has the same shape (129 x 501 with the defaults). The sample
    rate returned by ``read`` is not checked.

    Args:
        audio_dir: root directory holding the per-class sub-directories.
        transform: optional callable applied to each spectrogram.
        target_transform: optional callable applied to each integer label.
        n_fft: FFT size; also used as the Hann window length.
        hop_length: samples between STFT frames; defaults to ``n_fft // 8``.
        expected_frames: number of STFT frames a file must produce to be kept.
    """

    def __init__(self, audio_dir, transform=None, target_transform=None,
                 n_fft=256, hop_length=None, expected_frames=501):
        self.audio_classes = ["up", "down", "left", "right"]
        self.audio_dir = audio_dir
        self.transform = transform
        self.target_transform = target_transform
        self.n_fft = n_fft
        self.hop_length = n_fft // 8 if hop_length is None else hop_length
        self.expected_frames = expected_frames
        # Built once and reused by every STFT instead of being recreated per call.
        self.window = torch.hann_window(n_fft)
        self.file_paths = []
        self.labels = []

        # Gather files for the speech commands.
        for class_idx, class_name in enumerate(self.audio_classes):
            class_dir = os.path.join(audio_dir, class_name)
            # Sorted so dataset indices are reproducible across file systems.
            for file_name in sorted(os.listdir(class_dir)):
                if not file_name.endswith('.wav'):
                    continue
                file_path = os.path.join(class_dir, file_name)
                _, data = read(file_path)
                if self._has_expected_shape(data):
                    self.file_paths.append(file_path)
                    self.labels.append(class_idx)

    def _has_expected_shape(self, data):
        """Return True if ``self._power_spectrogram`` of ``data`` has the expected size.

        Computed from the sample count instead of running the STFT: with
        ``center=True`` torch.stft reflect-pads ``n_fft // 2`` samples on each
        side, which yields ``1 + (L + 2*(n_fft//2) - n_fft) // hop_length``
        frames. Multi-channel audio and clips too short to reflect-pad (which
        would make torch.stft raise) are rejected.
        """
        if data.ndim != 1 or data.shape[0] <= self.n_fft // 2:
            return False
        num_samples = data.shape[0]
        padded = num_samples + 2 * (self.n_fft // 2)
        num_frames = 1 + (padded - self.n_fft) // self.hop_length
        return num_frames == self.expected_frames

    def _power_spectrogram(self, waveform):
        """Squared-magnitude STFT of a 1-D float waveform: (n_fft//2 + 1, frames)."""
        return torch.stft(
            waveform,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.n_fft,
            window=self.window,
            return_complex=True
        ).abs() ** 2

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Return ``(spectrogram, label)`` for item ``idx``, after any transforms."""
        file_path = self.file_paths[idx]
        label = self.labels[idx]
        _, data = read(file_path)
        audio = self._power_spectrogram(torch.tensor(data, dtype=torch.float32))

        if self.transform:
            audio = self.transform(audio)
        if self.target_transform:
            label = self.target_transform(label)

        return audio, label

In [None]:
def collate_gpu(batch):
    """Collate ``batch`` with PyTorch's default collate, then move both parts to ``device``.

    Expects each sample to be an ``(input, target)`` pair; returns the stacked
    ``(inputs, targets)`` tensors on the module-level ``device``.
    """
    collated = torch.utils.data.default_collate(batch)
    features, targets = collated
    features = features.to(device=device)
    targets = targets.to(device=device)
    return features, targets

In [None]:
audio_dir = '/kaggle/input/train-commands/commands/'
dataset = CustomAudioDataset(audio_dir=audio_dir)
assert len(dataset) > 0, f"no usable .wav files found under {audio_dir}"

train_dataloader = DataLoader(dataset, batch_size=64, shuffle=True, collate_fn=collate_gpu)

# Sanity-check a single batch. Looping over the whole loader here would run a
# full, unused pass over the dataset (STFT + device transfer of every file) and
# print the labels of every batch.
sample_specs, sample_labels = next(iter(train_dataloader))
print(f"Spectrogram shape: {sample_specs.shape}")
print(f"Labels: {sample_labels}")

In [None]:
class NeuralNetwork(nn.Module):
    """CNN feature extractor followed by an MLP head producing 4 class logits.

    The layer sizes imply a (batch, 129, 501) input, i.e. the spectrograms kept
    by CustomAudioDataset. Shape trace for a (B, 129, 501) input:

        Conv1d(129 -> 200, k=3, s=2)          -> (B, 200, 250)
        Conv1d(200 -> 300, k=3, s=2)          -> (B, 300, 124)
        MaxPool2d(2, 2) on a 3-D tensor       -> (B, 150, 62)
        ConvTranspose1d(150 -> 50, k=3, s=2)  -> (B, 50, 125)
        Flatten                               -> (B, 6250)
        Linear 6250 -> 3000 -> 512 -> 4

    NOTE(review): MaxPool2d is fed a 3-D tensor, which PyTorch treats as an
    unbatched (C, H, W) input. It therefore pools over the channel and time
    axes of each sample, halving both. The 150 input channels of the
    ConvTranspose1d rely on this.
    """

    def __init__(self):
        super().__init__()
        # Layer construction order is kept fixed: it determines parameter
        # initialisation (RNG draw order) and the state_dict indices.
        conv_layers = [
            nn.Conv1d(129, 200, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv1d(200, 300, kernel_size=3, stride=2),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.LeakyReLU(),
            nn.ConvTranspose1d(150, 50, kernel_size=3, stride=2),
            nn.BatchNorm1d(50),
            nn.LeakyReLU(),
        ]
        self.cnn_stack = nn.Sequential(*conv_layers)
        self.flatten = nn.Flatten()
        head_layers = [
            nn.Linear(6250, 3000),
            nn.ReLU(),
            nn.Linear(3000, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 4),
        ]
        self.linear_relu_stack = nn.Sequential(*head_layers)

    def forward(self, x):
        """Map a spectrogram batch to unnormalised class logits of shape (B, 4)."""
        return self.linear_relu_stack(self.flatten(self.cnn_stack(x)))

model = NeuralNetwork()
model = model.to(device)  # place parameters/buffers on the selected accelerator

In [None]:
# Training hyperparameters.
learning_rate = 1e-3
batch_size = 64  # matches the batch_size=64 given to the DataLoader in the data-loading cell
epochs = 10

loss_fn = nn.CrossEntropyLoss()
# Adam with a small L2 penalty via weight_decay.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
    weight_decay=1e-5,
)


def train_loop(dataloader, model, loss_fn, optimizer):
    """Run one training epoch over ``dataloader``, updating ``model`` in place.

    Every 100 batches prints the current batch loss and the number of samples
    processed so far out of ``len(dataloader.dataset)``.

    Args:
        dataloader: yields ``(X, y)`` batches already on the model's device.
        model: module to train; switched to train mode.
        loss_fn: callable ``loss_fn(pred, y)`` returning a scalar tensor.
        optimizer: optimizer over ``model``'s parameters.
    """
    size = len(dataloader.dataset)
    model.train()
    seen = 0
    for batch, (X, y) in enumerate(dataloader):
        pred = model(X)
        loss = loss_fn(pred, y)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Count the samples actually processed instead of multiplying by a
        # global batch size that may not match this loader's batch size.
        seen += len(X)
        if batch % 100 == 0:
            print(f"loss: {loss.item():>7f}  [{seen:>5d}/{size:>5d}]")


def test_loop(dataloader, model, loss_fn):
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")


for t in range(epochs):
    print(f"epoch #{t}")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    # NOTE(review): evaluation reuses the training loader, so the printed
    # "Test Error" accuracy is training accuracy, not held-out performance.
    test_loop(train_dataloader, model, loss_fn)

In [None]:
def HannWindow(M):
    """Return the symmetric Hann window of length ``M`` as a list of floats.

    w[n] = sin^2(pi * n / (M - 1)) for n = 0 .. M-1, which equals
    0.5 * (1 - cos(2 * pi * n / (M - 1))). Following ``numpy.hanning``, a
    length-1 window is ``[1.0]`` (the formula alone would divide by zero) and
    ``M <= 0`` gives an empty list.
    """
    if M <= 0:
        return []
    if M == 1:
        return [1.0]
    return [math.sin(math.pi * n / (M - 1)) ** 2 for n in range(M)]

In [None]:
def full_STFT(val_cos, win_length, hop_length, n_fft):
    """Short-time Fourier transform implemented from scratch with NumPy.

    Reflect-pads ``n_fft // 2`` samples on each side, takes frames of
    ``win_length`` samples every ``hop_length`` samples, weights each frame by
    a periodic Hann window, and evaluates its DFT at all ``n_fft`` bins
    (two-sided), summing over the ``win_length`` frame samples. When
    ``win_length == n_fft``, the first ``n_fft // 2 + 1`` columns equal the
    transpose of ``torch.stft(..., center=True, pad_mode="reflect",
    window=torch.hann_window(n_fft))``.

    Args:
        val_cos: 1-D sequence of samples.
        win_length: samples per frame and length of the Hann window.
        hop_length: stride in samples between consecutive frames.
        n_fft: number of DFT bins evaluated per frame.

    Returns:
        Complex ndarray of shape ``(num_frames, n_fft)``, with frames along
        axis 0. If the padded signal is shorter than one window, the shape is
        ``(0, n_fft)``.
    """
    pad_amount = n_fft // 2
    padded = np.pad(val_cos, (pad_amount, pad_amount), mode="reflect")

    num_frames = (len(padded) - win_length) // hop_length + 1
    if num_frames <= 0:
        return np.empty((0, n_fft), dtype=complex)

    # HannWindow(win_length + 1)[:-1] is sin^2(pi * n / win_length): the
    # periodic Hann window.
    window = np.asarray(HannWindow(win_length + 1)[:-1])

    # All frames at once: row m is padded[m*hop_length : m*hop_length + win_length].
    starts = np.arange(num_frames)[:, None] * hop_length
    frames = padded[starts + np.arange(win_length)[None, :]] * window

    # dft_matrix[omega, k] = exp(-2j*pi*omega*k / n_fft). One matrix product
    # computes the same sums as looping over every frame and every bin.
    omega = np.arange(n_fft)[:, None]
    k = np.arange(win_length)[None, :]
    dft_matrix = np.exp(-2j * math.pi * omega * k / n_fft)

    return frames @ dft_matrix.T

In [None]:
right_dir = "/kaggle/input/train-commands/commands/right"
# Same STFT settings as CustomAudioDataset; the window is built once, not per file.
stft_window = torch.hann_window(256)

# eval() disables Dropout and makes BatchNorm use its running statistics, so
# predictions do not depend on the mode training left the model in;
# no_grad() avoids building an autograd graph during inference.
model.eval()
with torch.no_grad():
    for file_name in sorted(os.listdir(right_dir)):
        if not file_name.endswith('.wav'):
            continue
        _, data = read(os.path.join(right_dir, file_name))
        # Only mono clips of 16000 samples give the 501 STFT frames the model's
        # first Linear layer (6250 inputs) was sized for.
        if data.ndim != 1 or data.shape[0] != 16000:
            continue

        audio = torch.stft(
            torch.tensor(data, dtype=torch.float32),
            n_fft=256,
            hop_length=256 // 8,
            win_length=256,
            window=stft_window,
            return_complex=True
        ).abs() ** 2

        if audio.shape[1] != 501:
            continue
        logits = model(audio.unsqueeze(0).to(device))

        # These clips come from the "right" folder, i.e. index 3 in
        # CustomAudioDataset.audio_classes.
        pred_probab = nn.Softmax(1)(logits)
        y_pred = pred_probab.argmax(1)
        print(f"Predicted class: {y_pred}")

In [None]:
# Export a CPU TorchScript copy of the trained model. nn.Module.to() moves
# parameters in place, so scripting model.to("cpu") directly would leave the
# notebook's `model` on the CPU, and re-running the inference cell (which sends
# inputs to `device`) would then fail. A deep copy avoids that. eval() fixes
# Dropout/BatchNorm to inference behaviour in the export only.
cpu_model = copy.deepcopy(model).to("cpu").eval()
model_scripted = torch.jit.script(cpu_model)
model_scripted.save('model_scripted_cpu.pt')  # Save