In [None]:
import os
import warnings

import librosa
import librosa.display
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import wandb
from sklearn.model_selection import train_test_split
from torch import nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

In [None]:
# Random identifier used as the display name of this W&B run, so that the
# value stored in `experiment_name` actually identifies the run.
experiment_name = wandb.util.generate_id()

# Start the run. The config dict records the hyperparameters next to the
# metrics that are logged during training.
wandb.init(
    project="audio-classification",
    name=experiment_name,
    config={
        "learning_rate": 5e-4,
        "epochs": 1100,
        "batch_size": 16,
        "architecture": "CoAtNet",
    },
)

In [None]:
def extract_keystroke(y, sr, keystroke_interval=0.6, energy_threshold=0.0015):
    """Cut fixed-length keystroke segments out of an audio signal.

    The signal is analysed in half-overlapping frames that are
    ``keystroke_interval`` seconds long. Every detected onset is mapped to
    the frame index ``onset_sample // hop_length``; the frame is kept when
    its RMS energy exceeds ``energy_threshold``, and the segment starting at
    that frame's offset is returned.

    Args:
        y: 1-D audio signal.
        sr: Sampling rate of ``y`` in Hz.
        keystroke_interval: Length of each extracted segment, in seconds.
        energy_threshold: RMS energy a frame must exceed to be kept.

    Returns:
        A list of 1-D slices of ``y`` in chronological order, one per
        distinct retained frame. Each slice holds ``int(sr *
        keystroke_interval)`` samples, except that a slice starting close to
        the end of ``y`` is shorter. An empty signal yields ``[]``.
    """
    # Nothing to analyse; librosa cannot process a zero-length signal.
    if len(y) == 0:
        return []

    frame_size = int(sr * keystroke_interval)
    hop_length = int(frame_size / 2)

    energy = librosa.feature.rms(y=y,
                                 frame_length=frame_size,
                                 hop_length=hop_length)[0]
    is_loud = energy > energy_threshold

    # Ask for onsets directly in samples: this avoids the float round trip
    # through seconds when converting to indices of the energy frames.
    onset_samples = librosa.onset.onset_detect(y=y,
                                               sr=sr,
                                               units='samples')
    onset_frames = (np.asarray(onset_samples) // hop_length).astype(int)

    # Several onsets can land in the same energy frame (e.g. key press and
    # release). Without de-duplication each of them would emit the very same
    # segment, and identical copies could end up on both sides of a
    # train/test split.
    onset_frames = np.unique(onset_frames)

    loud_onset_frames = onset_frames[is_loud[onset_frames]]

    keystrokes = []
    for frame_index in loud_onset_frames:
        start_sample = int(frame_index) * hop_length
        # Guard against a zero-length slice at the very end of the signal.
        if start_sample >= len(y):
            continue
        keystrokes.append(y[start_sample:start_sample + frame_size])

    return keystrokes

In [None]:
def get_mel_spectrograms(keystrokes, sr, n_mels, hop_length, n_fft):
    """Turn audio segments into dB-scaled mel spectrograms.

    Args:
        keystrokes: Iterable of 1-D audio segments.
        sr: Sampling rate of the segments in Hz.
        n_mels: Number of mel bands.
        hop_length: Hop between successive STFT columns, in samples.
        n_fft: FFT window length, in samples.

    Returns:
        A list with one 2-D array per input segment, in input order.
    """
    def _db_mel(segment):
        # Mel power spectrogram of a single segment.
        power = librosa.feature.melspectrogram(y=segment,
                                               sr=sr,
                                               n_mels=n_mels,
                                               hop_length=hop_length,
                                               n_fft=n_fft)
        # Convert to decibels using the spectrogram's own maximum
        # (ref=np.max) as the reference.
        return librosa.power_to_db(power, ref=np.max)

    return [_db_mel(segment) for segment in keystrokes]

In [None]:
# Label vocabulary: the letters 'A'-'Z' map to class indices 0-25 and the
# digits '0'-'9' map to class indices 26-35.
CLASS_MAPPING = {
    symbol: index
    for index, symbol in enumerate("ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
}


class KeystrokeDataset(torch.utils.data.Dataset):
    """Keystroke mel spectrograms extracted from a directory of ``.wav`` files.

    Each file name is expected to look like ``<label>_<anything>.wav``; the
    part before the first underscore is looked up in the class mapping. All
    keystrokes extracted from a file share that file's label.

    Args:
        audio_path: Directory that contains the ``.wav`` recordings.
        n_mels: Number of mel bands per spectrogram.
        hop_length: STFT hop length in samples.
        n_fft: FFT window length in samples.
        keystroke_interval: Length of one keystroke segment in seconds.
        energy_threshold: RMS energy threshold used to keep a segment.
        max_length: Number of time frames every returned spectrogram is
            padded (or cropped) to. The default of 52 is the frame count of
            a 0.6 s segment at 22,050 Hz with ``hop_length=256``
            (presumably chosen for librosa's default sampling rate).
        class_mapping: Optional ``{label: class_index}`` dict. Defaults to
            the module-level ``CLASS_MAPPING``.
    """

    def __init__(self, audio_path, n_mels=64, hop_length=256, n_fft=1024,
                 keystroke_interval=0.6, energy_threshold=0.0015,
                 max_length=52, class_mapping=None):
        self.audio_path = audio_path
        self.n_mels = n_mels
        self.hop_length = hop_length
        self.n_fft = n_fft
        self.keystroke_interval = keystroke_interval
        self.energy_threshold = energy_threshold
        self.max_length = max_length
        self.class_mapping = CLASS_MAPPING if class_mapping is None else class_mapping

        self.mel_spectrograms = []
        self.labels = []

        # os.listdir returns entries in arbitrary order; sorting keeps the
        # sample order (and therefore any seeded split) reproducible.
        for f in sorted(os.listdir(self.audio_path)):
            if not f.endswith('.wav'):
                continue

            label = f.split('_')[0]
            if label not in self.class_mapping:
                # An unmapped label has no valid class index and would make
                # the loss fail on an out-of-range target, so skip the file
                # (before paying for the audio decode) and say so.
                warnings.warn(
                    f"Skipping {f!r}: label {label!r} is not in the class mapping",
                    stacklevel=2)
                continue
            numerical_label = self.class_mapping[label]

            wav_file = os.path.join(self.audio_path, f)
            y, sr = librosa.load(wav_file, mono=True)

            keystrokes = extract_keystroke(y,
                                           sr,
                                           self.keystroke_interval,
                                           self.energy_threshold)

            mel_spectrograms = get_mel_spectrograms(keystrokes,
                                                    sr,
                                                    self.n_mels,
                                                    self.hop_length,
                                                    self.n_fft)

            self.mel_spectrograms.extend(mel_spectrograms)
            self.labels.extend([numerical_label] * len(mel_spectrograms))

    def __len__(self):
        """Return the number of extracted keystroke spectrograms."""
        return len(self.mel_spectrograms)

    def __getitem__(self, idx):
        """Return ``(spectrogram, label)`` for sample ``idx``.

        The spectrogram is a tensor whose last dimension is exactly
        ``max_length``: shorter spectrograms are right-padded with zeros and
        longer ones are cropped on the right (negative padding).
        """
        mel_spectrogram = torch.tensor(self.mel_spectrograms[idx])

        # NOTE(review): the pad value is 0. If the spectrograms are dB values
        # relative to their own peak (as get_mel_spectrograms produces), 0 is
        # the loudest level rather than silence - confirm this is intended.
        pad_frames = self.max_length - mel_spectrogram.shape[1]
        mel_spectrogram = F.pad(mel_spectrogram, (0, pad_frames))

        return mel_spectrogram, torch.tensor(self.labels[idx])

In [None]:
class CoAtNet(nn.Module):
    """Convolution + transformer classifier for single-channel 2-D inputs.

    Pipeline: a conv / batch-norm / ReLU / max-pool stem producing 64
    channels, a 2-layer transformer encoder over the spatial positions of
    that feature map, a max over all positions, and a linear classifier.

    Args:
        n_classes: Number of output classes (size of the logit vector).
    """

    def __init__(self, n_classes=36):
        super().__init__()

        self.conv_layers = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=(3, 3),
                      stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
        )

        # batch_first is left at its default (False): the encoder expects
        # input shaped (sequence, batch, features).
        encoder_layer = nn.TransformerEncoderLayer(d_model=64, nhead=8)
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=2)

        self.fc = nn.Linear(64, n_classes)

    def forward(self, x):
        """Compute class logits.

        Args:
            x: Tensor of shape ``(batch, 1, height, width)``.

        Returns:
            Tensor of shape ``(batch, n_classes)``.
        """
        x = self.conv_layers(x)  # (batch, 64, H', W')

        # One token per spatial position with the 64 channels as its
        # embedding. A plain view(batch, -1, 64) would only reinterpret the
        # channel-major memory and mix values from different channels and
        # positions into each token, so flatten + permute is required.
        x = x.flatten(2).permute(2, 0, 1)  # (H'*W', batch, 64)

        x = self.transformer_encoder(x)

        # Max over the token dimension -> (batch, 64).
        x = x.max(dim=0).values

        return self.fc(x)

In [None]:
AUDIO_PATH = 'data/Phone_Recording/'
SEED = 42
EVAL_EVERY = 10  # evaluate on the held-out split every this many epochs

# Read the hyperparameters from the W&B run config so that the values that
# get logged are the values that are actually used for training.
batch_size = wandb.config["batch_size"]
learning_rate = wandb.config["learning_rate"]
n_epochs = wandb.config["epochs"]

# Seed torch so that weight initialisation and batch shuffling are repeatable.
torch.manual_seed(SEED)

dataset = KeystrokeDataset(AUDIO_PATH)
if len(dataset) == 0:
    raise RuntimeError(f'No keystrokes were extracted from {AUDIO_PATH!r}')

# train_test_split indexes the dataset item by item, so both splits are plain
# lists of (spectrogram, label) pairs.
train_set, test_set = train_test_split(dataset, test_size=0.2, random_state=SEED)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = CoAtNet().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(n_epochs):
    model.train()
    running_loss = 0.0
    n_seen = 0

    # leave=False keeps the output from filling up with one bar per epoch.
    for inputs, labels in tqdm(train_loader,
                               desc=f'Epoch {epoch + 1}/{n_epochs}',
                               leave=False):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        # Add the channel dimension expected by the conv stem.
        reshaped_inputs = inputs.unsqueeze(1)
        output = model(reshaped_inputs)

        loss = criterion(output, labels)

        loss.backward()
        optimizer.step()

        # Weight by batch size so the epoch mean is a per-sample mean even
        # when the last batch is smaller.
        running_loss += loss.item() * labels.size(0)
        n_seen += labels.size(0)

    # Report the mean loss over the epoch rather than the last batch's loss.
    epoch_loss = running_loss / n_seen
    print(f'Epoch {epoch + 1}/{n_epochs}, Loss: {epoch_loss:.4f}')
    metrics = {"Epoch": epoch, "loss": epoch_loss}

    if (epoch + 1) % EVAL_EVERY == 0:
        model.eval()
        with torch.no_grad():
            correct = 0
            total = 0
            for inputs, labels in test_loader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                reshaped_inputs = inputs.unsqueeze(1)
                output = model(reshaped_inputs)
                predicted = output.argmax(dim=1)

                total += labels.size(0)
                correct += (predicted == labels).sum().item()

            accuracy = 100 * correct / total
            print(f'Accuracy: {accuracy:.2f}%')
            metrics["accuracy"] = accuracy

    wandb.log(metrics)