In [12]:
import torch
import numpy as np
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn as nn

from model import KeywordSpottingModel
from data_loader import load_speech_commands_dataset #TFDatasetAdapter
from utils import set_memory_GB,print_model_size, log_to_file, plot_learning_curves
from augmentations import add_time_shift_noise_and_align, add_noise
from train import trainig_loop






In [13]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device: ",device)

device:  cuda


In [14]:
set_memory_GB(1)

Memory fraction set to 0.022458079576498518
Memory fraction in GB: 1.0


In [15]:
train_ds, val_ds, test_ds, info = load_speech_commands_dataset()

In [16]:
# maintain seed for repructablity
np.seed = 42
# tf.random.set_seed(42)
torch.manual_seed(0)

<torch._C.Generator at 0x7fc1ec09e3f0>

In [17]:
label_names = info.features['label'].names[:10]
print(label_names)

['down', 'go', 'left', 'no', 'off', 'on', 'right', 'stop', 'up', 'yes']


In [18]:
augmentations = [
    lambda x: add_time_shift_noise_and_align(x),
    lambda x: add_noise(x,noise_level = 0.01)
]

In [29]:
from librosa.feature import mfcc, delta
from torch.utils.data import Dataset
# Define the dataset adapter:
class TFDatasetAdapter(Dataset):
    def __init__(self, tf_dataset, fixed_length, n_mfcc, n_fft, hop_length, n_mels, augmentations=None):
        self.tf_dataset = tf_dataset
        self.data = list(tf_dataset)
        self.fixed_length = fixed_length
        self.n_mfcc = n_mfcc
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.n_mels = n_mels
        self.augmentations = augmentations

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        audio, label = self.data[idx]
        audio = audio.numpy()

        #convert to float
        audio = audio.astype(np.float32)

        # Ensure the audio tensor has the correct shape (1D array)
        if audio.ndim > 1:
            audio = np.squeeze(audio)
            
        # Apply augmentations if any
        if self.augmentations:
            for aug in self.augmentations:
                audio = aug(audio)

        # Pad or trim the audio to the fixed length
        if len(audio) < self.fixed_length:
            audio = np.pad(audio, (0, self.fixed_length - len(audio)), mode='constant')
        else:
            audio = audio[:self.fixed_length]

        # Create MFCCs from an audio tensor using Librosa.
        audio = audio.astype(np.float32)
        MFCC = mfcc(y=audio, sr=16000, n_mfcc=self.n_mfcc, n_fft=self.n_fft, hop_length=self.hop_length, n_mels=self.n_mels)

        # Create MFCC second, first order delta
        MFCC_delta = delta(MFCC)
        MFCC_delta2 = delta(MFCC, order=2)

        # Stack the three MFCCs together
        MFCC = np.vstack([MFCC, MFCC_delta, MFCC_delta2])

        # Remove extra dimension if it exists
        if MFCC.ndim == 3:
            MFCC = MFCC.squeeze(-1)

        return torch.tensor(MFCC, dtype=torch.float32), torch.tensor(label.numpy(), dtype=torch.long)

In [30]:
# Convert the TFDS dataset to a PyTorch Dataset
fixed_length = 16000
n_mfcc = 13
n_fft = 640
hop_length = 320
n_mels = 40
pytorch_train_dataset = TFDatasetAdapter(train_ds, fixed_length, n_mfcc, n_fft, hop_length, n_mels, augmentations)
pytorch_val_dataset = TFDatasetAdapter(val_ds, fixed_length, n_mfcc, n_fft, hop_length, n_mels, augmentations=None)

In [31]:
# Create a DataLoader to feed the data into the model
batch_size = 16
train_loader = DataLoader(pytorch_train_dataset, batch_size=batch_size, shuffle=True,num_workers=4)
val_loader = DataLoader(pytorch_val_dataset, batch_size=batch_size, shuffle=False,num_workers=4)

In [32]:
for audio, label in train_loader:
    print(audio.shape, label.shape)
    break


torch.Size([16, 39, 51]) torch.Size([16])


# Compute model size

In [53]:
# Initialize model, loss function, and optimizer
input_dim = 39  # Number of MFCC features
d_model = 51  # Number of frames
d_state = 256
d_conv = 64
expand = 4

model = KeywordSpottingModel(input_dim=input_dim, d_model=d_model, d_state=d_state, d_conv=d_conv, expand=expand, label_names=label_names, num_mamba_layers=2).to("cuda")


In [54]:

# # Register custom operation
inputs = torch.randn(batch_size, input_dim, d_model).to("cuda")
        
# macs, params, ret_layer_info = thop.profile(model, inputs=(torch.randn(batch_size, 13, 101).to("cuda"),)
# ,custom_ops={Mamba: calculate_MAMBA_flops},report_missing=True, ret_layer_info=True)
# print()
# print(f"MACs: {macs} Which are {macs/1e9} Giga-MACs, Params: {params}")

print_model_size(model,input_size=inputs)

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv1d'>.
[INFO] Customize rule calculate_MAMBA_flops() <class 'mamba_ssm.modules.mamba_simple.Mamba'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.

MACs: 578523192.0 Which are 0.578523192 Giga-MACs, Params: 408928.0



(578523192.0, 408928.0)

# Training loop

# With L2 regulariztion AND Droput layer

In [47]:
criterion = nn.CrossEntropyLoss().to("cuda")
base_optimizer = optim.Adam(model.parameters(), lr=0.0024, weight_decay=2.80475e-05) # weight_decay for L2 regulariztopn

from torch_optimizer import Lookahead

optimizer = Lookahead(base_optimizer, k=5, alpha=0.5)  # Wrap around Adam


In [48]:
import torch.optim as optim
# Adding learning rate scheduler
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)

In [49]:
num_epochs = 100
train_accuracies, val_accuracies, train_losses, val_losses = trainig_loop(model, num_epochs, train_loader, val_loader, criterion, optimizer, scheduler)

100%|██████████████████████████████████████████████████████████████████| 1924/1924 [00:48<00:00, 39.93it/s]

Epoch 1/100, Training Loss: 1.3170638059627986, Training Accuracy: 52.809646072345544%





Validation Loss: 0.858671082761781, Validation Accuracy: 71.1315149878477%
Learning rate after epoch 1: [0.0024]
Best model saved


100%|██████████████████████████████████████████████████████████████████| 1924/1924 [00:48<00:00, 39.97it/s]

Epoch 2/100, Training Loss: 0.7837926069208713, Training Accuracy: 72.77129578471839%





Validation Loss: 0.7097556083623705, Validation Accuracy: 76.31650013502565%
Learning rate after epoch 2: [0.0024]
Best model saved


100%|██████████████████████████████████████████████████████████████████| 1924/1924 [00:48<00:00, 40.04it/s]

Epoch 3/100, Training Loss: 0.6298098378914472, Training Accuracy: 78.53033897754233%





Validation Loss: 0.5843236756607376, Validation Accuracy: 79.77315689981097%
Learning rate after epoch 3: [0.0024]
Best model saved


100%|██████████████████████████████████████████████████████████████████| 1924/1924 [00:45<00:00, 42.42it/s]

Epoch 4/100, Training Loss: 0.5579245019231666, Training Accuracy: 80.96135721017907%





Validation Loss: 0.48109459758190243, Validation Accuracy: 83.95895220091818%
Learning rate after epoch 4: [0.0024]
Best model saved


100%|██████████████████████████████████████████████████████████████████| 1924/1924 [00:46<00:00, 41.81it/s]

Epoch 5/100, Training Loss: 0.5013808943505501, Training Accuracy: 82.93412200591504%





Validation Loss: 0.47726810187229823, Validation Accuracy: 84.03996759384283%
Learning rate after epoch 5: [0.0024]
Best model saved


100%|██████████████████████████████████████████████████████████████████| 1924/1924 [00:45<00:00, 41.83it/s]

Epoch 6/100, Training Loss: 0.45968605917946215, Training Accuracy: 84.26013195098963%





Validation Loss: 0.45348124148259905, Validation Accuracy: 85.5522549284364%
Learning rate after epoch 6: [0.0024]
Best model saved


100%|██████████████████████████████████████████████████████████████████| 1924/1924 [00:43<00:00, 44.51it/s]

Epoch 7/100, Training Loss: 0.42430258631940276, Training Accuracy: 85.59589196918976%





Validation Loss: 0.40401902301076414, Validation Accuracy: 87.01053200108021%
Learning rate after epoch 7: [0.0024]
Best model saved


100%|██████████████████████████████████████████████████████████████████| 1924/1924 [00:45<00:00, 41.86it/s]

Epoch 8/100, Training Loss: 0.39923949862526353, Training Accuracy: 86.37589781923364%





Validation Loss: 0.38933842750426767, Validation Accuracy: 87.1995679179044%
Learning rate after epoch 8: [0.0024]
Best model saved


100%|██████████████████████████████████████████████████████████████████| 1924/1924 [00:47<00:00, 40.17it/s]

Epoch 9/100, Training Loss: 0.3763978103836011, Training Accuracy: 87.14615359615198%





Validation Loss: 0.36525809988861196, Validation Accuracy: 88.81987577639751%
Learning rate after epoch 9: [0.0024]
Best model saved


100%|██████████████████████████████████████████████████████████████████| 1924/1924 [00:46<00:00, 41.41it/s]

Epoch 10/100, Training Loss: 0.3645204806665497, Training Accuracy: 87.7441580811856%





In [None]:
# Plot the learning curves
plot_learning_curves(train_accuracies, val_accuracies, train_losses, val_losses)

In [None]:
# calc weights for unbalanced dataset
from utils import compute_label_distribution
label_distribution = compute_label_distribution(train_ds)
# Convert label counts to a list of counts in the order of label indices
counts_list = [label_distribution[i] for i in sorted(label_distribution.keys())]
# Compute total number of samples
total_samples = sum(counts_list)

# Compute the number of classes
num_classes = len(counts_list)

# Calculate class weights
class_weights = [total_samples / (num_classes * count) for count in counts_list]




In [None]:
class_weights_normalized = class_weights / np.sum(class_weights)  # Normalize class weights
# print(class_weights_normalized)
class_weights_tensor = torch.tensor(class_weights_normalized).to("cuda")
# fine tune the model with balanced dataset
max_samples = int(len(pytorch_train_dataset) * min(class_weights_normalized) *len(class_weights_normalized))
sampler = torch.utils.data.WeightedRandomSampler(
    weights=class_weights_tensor,
    num_samples=len(pytorch_train_dataset),
    replacement=True
)
train_loader = DataLoader(pytorch_train_dataset, batch_size=batch_size,num_workers=4,prefetch_factor=2,sampler=sampler)


In [None]:
model = KeywordSpottingModel(input_dim=input_dim, d_model=d_model, d_state=d_state, d_conv=d_conv, expand=expand, label_names=info.features['label'].names).to("cuda")
model.load_state_dict(torch.load("best_model.pth"),strict=False)
criterion = nn.CrossEntropyLoss()
base_optimizer = optim.Adam(model.parameters(), lr=0.00001, weight_decay=1e-6) # weight_decay for L2 regulariztopn
optimizer = Lookahead(base_optimizer, k=5, alpha=0.5)  # Wrap around Adam
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)

In [None]:
num_epochs = 10
train_accuracies, val_accuracies, train_losses, val_losses = trainig_loop(model, num_epochs, train_loader, val_loader, criterion, optimizer, scheduler, save_best_model=False)

In [51]:
# load test data
model.load_state_dict(torch.load("best_model.pth"))
pytorch_test_dataset = TFDatasetAdapter(test_ds, fixed_length, n_mfcc, n_fft, hop_length, n_mels, augmentations=None)
test_loader = DataLoader(pytorch_test_dataset, batch_size=batch_size, shuffle=False,num_workers=4,prefetch_factor=2)
# model.load_state_dict(torch.load("best_model.pth"),strict=False)
# Evaluate the model on the test set
accuracy = 0
total = 0
model.eval()

with torch.no_grad():
    for audio, labels in test_loader:
        audio, labels = audio.to("cuda"), labels.to("cuda")
        outputs = model(audio)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        accuracy += (predicted == labels).sum().item()

print(f'Test Accuracy: {100 * accuracy / total}%')


  model.load_state_dict(torch.load("best_model.pth"))


Test Accuracy: 90.40255277368679%


In [None]:
def compute_inference_GPU_mem(model, input):
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    m0 = torch.cuda.max_memory_allocated()
    model(input)
    m1 = torch.cuda.max_memory_allocated()
    # Compute total memory used
    total_memory = (m1 - m0) / 1e6  # Convert to MB
    
    return total_memory

In [None]:
import pandas as pd
# from utils import compute_inference_GPU_mem
#save model size(macs, params) and accuracy
macs, params = print_model_size(model,input_size=inputs,verbose=True)
macs = macs/1e9
accuracy = max(val_accuracies)
data = {'Model': ['KeywordSpottingModel_RSM_Norm_0-1-2_order'], 'GMACs': [macs], 'Params': [params], 'Accuracy': [accuracy]}
model_config = {'input_dim': input_dim, 'd_model': d_model, 'd_state': d_state, 'd_conv': d_conv, 'expand': expand}
data.update(model_config)
inf_GPU_mem = compute_inference_GPU_mem(model, input=torch.randn(1, input_dim, d_model).to("cuda"))
#inference macs and params
inf_macs, inf_params = print_model_size(model,input_size=torch.randn(1, input_dim, d_model).to("cuda"))
inference_data = {'Inference CUDA Mem in MB': [inf_GPU_mem], 'Inference GMACs': [inf_macs/1e9], 'Inference Params': [inf_params]}
data.update(inference_data)
df = pd.DataFrame(data, index=[0])
df.to_csv('results.csv', mode='a', header=False)


[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv1d'>.
[INFO] Customize rule calculate_MAMBA_flops() <class 'mamba_ssm.modules.mamba_simple.Mamba'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.

MACs: 39179832.0 Which are 0.039179832 Giga-MACs, Params: 46624.0

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv1d'>.
[INFO] Customize rule calculate_MAMBA_flops() <class 'mamba_ssm.modules.mamba_simple.Mamba'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.

MACs: 3941637.0 Which are 0.003941637 Giga-MACs, Params: 46624.0



[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv1d'>.
[91m[WARN] Cannot find rule for <class 'torch.nn.modules.activation.SiLU'>. Treat it as zero Macs and zero Params.
[00m[INFO] Customize rule calculate_MAMBA_flops() <class 'mamba_ssm.modules.mamba_simple.Mamba'>.
[91m[WARN] Cannot find rule for <class 'torch.nn.modules.container.ModuleList'>. Treat it as zero Macs and zero Params.
[00m[91m[WARN] Cannot find rule for <class 'torch.nn.modules.normalization.RMSNorm'>. Treat it as zero Macs and zero Params.
[00m[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[91m[WARN] Cannot find rule for <class 'model.KeywordSpottingModel'>. Treat it as zero Macs and zero Params.
[00m
MACs: 578523192.0 Which are 0.578523192 Giga-MACs, Params: 408928.0

Layer-wise information:
Layer: proj
Total FLOPs: 1623024.0, Total Params: 2040.0

Layer: mamba_layers
Total FLOPs: 57689