In [None]:
!pip install git+https://github.com/asteroid-team/torch-audiomentations.git

In [1]:
from typing import List 
from argparse import ArgumentParser
parser = ArgumentParser()
from pathlib import Path 
import datetime 

from torch_audiomentations import (
    Gain,
    PolarityInversion,
    TimeInversion,
    AddBackgroundNoise,
    BandPassFilter,
    BandStopFilter, 
    AddColoredNoise,
    HighPassFilter,
    ApplyImpulseResponse,
    LowPassFilter,
    PitchShift,
    RandomCrop,
    SpliceOut,
    Compose
)


PROBABILITY_OF_APPLICATION = 0.011

parser.add_argument("--prob-aug-gain", type=float,default=PROBABILITY_OF_APPLICATION)
parser.add_argument("--prob-aug-band-pass-filter", type=float,default=PROBABILITY_OF_APPLICATION)
parser.add_argument("--prob-aug-band-stop-filter", type=float,default=PROBABILITY_OF_APPLICATION)
parser.add_argument("--prob-aug-polarity-inversion", type=float,default=PROBABILITY_OF_APPLICATION)
parser.add_argument("--prob-aug-time-inversion", type=float,default=PROBABILITY_OF_APPLICATION)

parser.add_argument("--prob-aug-background-noise", type=float,default=PROBABILITY_OF_APPLICATION)
parser.add_argument("--prob-aug-coloured-noise", type=float,default=PROBABILITY_OF_APPLICATION)
parser.add_argument("--prob-aug-high-pass-filter", type=float,default=PROBABILITY_OF_APPLICATION)
parser.add_argument("--prob-aug-low-pass-filter", type=float,default=PROBABILITY_OF_APPLICATION)
parser.add_argument("--prob-aug-impulse-response", type=float,default=PROBABILITY_OF_APPLICATION)
parser.add_argument("--prob-aug-pitch-shift", type=float,default=PROBABILITY_OF_APPLICATION)


args, _ = parser.parse_known_args()
    
apply_augmentation = Compose(
    transforms=[
        
        Gain(p=args.prob_aug_gain,
            min_gain_in_db=-15.0,
            max_gain_in_db=5.0),
        BandPassFilter(p=args.prob_aug_band_pass_filter,
                        min_center_frequency=200,
                        max_center_frequency=4000,
                        min_bandwidth_fraction=0.5,
                        max_bandwidth_fraction=1.99),
        BandStopFilter(p=args.prob_aug_band_stop_filter,
                       min_center_frequency=200,
                       max_center_frequency=4000,
                       min_bandwidth_fraction=0.5,
                       max_bandwidth_fraction=1.99),
        
        PolarityInversion(p=args.prob_aug_polarity_inversion),
        TimeInversion(p=args.prob_aug_time_inversion),
    
        AddColoredNoise(p=args.prob_aug_coloured_noise, 
                        min_snr_in_db = 3.0,
                        max_snr_in_db = 30.0,
                        min_f_decay = -2.0,
                        max_f_decay = 2.0),
        
        HighPassFilter(p=args.prob_aug_high_pass_filter,
                       min_cutoff_freq=20,
                       max_cutoff_freq=2400), 
        
        LowPassFilter(p=args.prob_aug_low_pass_filter,
                      min_cutoff_freq=150,
                      max_cutoff_freq=7500),
    
        PitchShift(p=args.prob_aug_pitch_shift,
                   min_transpose_semitones= -4.0,
                   max_transpose_semitones= 4.0,
                  sample_rate=8000)
    ]
)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import warnings 
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactive = "all"
warnings.filterwarnings("ignore")
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:90% !important; }</style>"))


In [3]:
import sys, os
from pprint import pprint
from pathlib import Path 
import json

from datetime import datetime 

import os 
import matplotlib.pyplot as plt 
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torchaudio as ta 
import torch


import torch
from torch import optim
import torch.nn as nn 
from torch_audiomentations import Compose, Gain, PolarityInversion
import pytorch_lightning as pl
from typing import List
from sklearn.metrics import precision_score, recall_score, accuracy_score





## Data preparation

In [None]:
def collate_to_tensor_batch(batch):
    x = [ x for (x,_) in batch ]
    y = [ y for (_,y) in batch ]

    x_batched = torch.stack(x).float() 
    y_batched = torch.stack(y).long()
    # return dictionary for unpacking easily as args 
    return {"x": x_batched, "y": y_batched}



class AudioDataset(Dataset):
    def __init__(self,
                 x_path_trg_list,
                 x_transforms=None,
                 x_aug_transforms= None, 
                 y_transforms=None):
    
        self.path_list = x_path_trg_list 
        self.x_transforms = x_transforms
        self.y_transforms = y_transforms 
        self.x_aug_transforms = x_aug_transforms
        
        
    def __len__(self):
        return len(self.path_list)
    

    def __getitem__(self, idx):
        x_path, y =  self.path_list[idx]
        y =  torch.tensor(int(y))
        x,sample_rate = ta.load(x_path)

        if self.x_aug_transforms  is not None:
            x_aug = None
            while x_aug is None:
                try:
                    x_aug = self.x_aug_transforms(torch.unsqueeze(x,1), sample_rate=sample_rate)            
                except ValueError:
                    pass 
        x = pad_sequence(x)
        return x,y


class AudioDataModule(pl.LightningDataModule):
    def __init__(self,train_df, val_df, test_df, train_batch_size=128, val_batch_size=128, test_batch_size=128):
        super().__init__()

        self.train_df = train_df
        self.val_df =  val_df
        self.test_df =  test_df

        
        self.train_paths = list(zip(self.train_df.wav_path.tolist(), self.train_df.label.tolist())) 
        self.val_paths = list(zip(self.val_df.wav_path.tolist(), self.val_df.label.tolist())) 
        self.test_paths = list(zip(self.test_df.wav_path.tolist(), self.test_df.label.tolist())) 

        self.pin_memory =  False # True if torch.cuda.is_available() else False 
        
        self.train_batch_size = train_batch_size
        self.val_batch_size= val_batch_size
        self.test_batch_size =  test_batch_size
        


    def train_dataloader(self):
        ds_train = AudioDataset(path_list = self.train_paths,
                                x_aug_transforms = apply_augmentation)
        return DataLoader(ds_train,
                          batch_size=self.train_batch_size,
                          shuffle=True,
                          drop_last=True,
                          pin_memory= self.pin_memory,
                          collate_fn= collate_to_tensor_batch)

    
    
    def val_dataloader(self):
        ds_val = AudioDataset(path_list = self.val_paths,
                              x_aug_transforms = None)
        return  DataLoader(ds_val,
                          batch_size=self.val_batch_size,
                          shuffle=True,
                          drop_last=True,
                          pin_memory= self.pin_memory,
                          collate_fn= collate_to_tensor_batch)
    
    
    def test_dataloader(self):
        ds_test = AudioDataset(path_list = self.test_paths,
                              x_aug_transforms = None)
        
        return  DataLoader(ds_test,
                          batch_size=self.test_batch_size,
                          shuffle=True,
                          drop_last=True,
                          pin_memory= self.pin_memory,
                          collate_fn= collate_to_tensor_batch)



In [4]:
MAX_SAMPLE_LEN = 20000

def pad_sequence(x:torch.Tensor, pad_to_len:int=MAX_SAMPLE_LEN):
    '''
    Pads audio file upto length pad_to_len
    returns padded audiofile 
    '''
    if x.size()[-1] > pad_to_len: # case longer than pad_to_len
        x_new = x[:,:pad_to_len] 
    else: # case shorted than pad_to_len
        padding = torch.tensor([0.0]).repeat([1,pad_to_len - x.size()[-1]])
        x_new = torch.hstack([x, padding])
        
    assert x_new.size()[-1] == pad_to_len, f"Incorrect padded length, Should be {pad_to_len}, got {x_new.size()[-1]}"
    return x_new # (1 ,1 , pad_to_len)


class AudioDataset(Dataset):
    def __init__(self,
                 path_list:List[str],
                 x_transforms=None,
                 x_aug_transforms= None, 
                 y_transforms=None):
    
        self.path_list = path_list 
        self.x_transforms = x_transforms
        self.y_transforms = y_transforms 
        self.x_aug_transforms = x_aug_transforms
        
        
    def __len__(self):
        return len(self.path_list)
    

    def __getitem__(self, idx):
        x_path =  self.path_list[idx]
        y =  torch.tensor(int(x_path.split("/")[-1].split("_")[0]))
        x,sample_rate = ta.load(x_path)

        if self.x_aug_transforms  is not None:
            x_aug = None
            while x_aug is None:
                try:
                    x_aug = self.x_aug_transforms(torch.unsqueeze(x,1), sample_rate=sample_rate)            
                except ValueError:
                    pass 
        x = pad_sequence(x)
        return x,y





def collate_to_tensor_batch(batch):
    x = [ x for (x,_) in batch ]
    y = [ y for (_,y) in batch ]

    x_batched = torch.stack(x).float() 
    y_batched = torch.stack(y).long()
    # return dictionary for unpacking easily as args 
    return {"x": x_batched, "y": y_batched}




class AudioDataModule(pl.LightningDataModule):
    def __init__(self, train_batch_size=128, val_batch_size=128, test_batch_size=128):
        super().__init__()
        self.all_ds_paths = self.get_all_sample_paths()
        
        train_paths, val_paths, test_paths = self.split_train_val_test(self.all_ds_paths)
        self.train_paths = train_paths
        self.val_paths = val_paths
        self.test_paths = test_paths 
        self.pin_memory =  False # True if torch.cuda.is_available() else False 
        
        self.train_batch_size = train_batch_size
        self.val_batch_size= val_batch_size
        self.test_batch_size =  test_batch_size
        
    def split_train_val_test(self, list_of_paths):
        '''
        splits list of all paths into train, val and testing sets
        '''
        xy_train, xy_val_and_test =  train_test_split(list_of_paths, test_size=0.33, random_state=42)
        xy_val, xy_test =  train_test_split(xy_val_and_test, test_size=0.33, random_state=42)
        return xy_train, xy_val, xy_test 


    def get_all_sample_paths(self):
        '''
        function lists all samples available for training.
        '''
        data_path = Path().cwd() / "datasets/free-spoken-digit-dataset-master/recordings"
        all_file_paths = [str(x) for x in [ data_path /  x for x in os.listdir(data_path)]]
        return all_file_paths 


        
    def train_dataloader(self):
        ds_train = AudioDataset(path_list = self.train_paths,
                                x_aug_transforms = apply_augmentation)
        return DataLoader(ds_train,
                          batch_size=self.train_batch_size,
                          shuffle=True,
                          drop_last=True,
                          pin_memory= self.pin_memory,
                          collate_fn= collate_to_tensor_batch)

    
    
    def val_dataloader(self):
        ds_val = AudioDataset(path_list = self.val_paths,
                              x_aug_transforms = None)
        return  DataLoader(ds_val,
                          batch_size=self.val_batch_size,
                          shuffle=True,
                          drop_last=True,
                          pin_memory= self.pin_memory,
                          collate_fn= collate_to_tensor_batch)
    
    
    def test_dataloader(self):
        ds_test = AudioDataset(path_list = self.test_paths,
                              x_aug_transforms = None)
        
        return  DataLoader(ds_test,
                          batch_size=self.test_batch_size,
                          shuffle=True,
                          drop_last=True,
                          pin_memory= self.pin_memory,
                          collate_fn= collate_to_tensor_batch)


x = AudioDataModule().train_dataloader()
b = next(iter(x))
b['x'].size()

torch.Size([128, 1, 20000])

## Network M5 

In [7]:
'''
Architecture definition

We will define the model architecture in this file

'''

import torch 
import torch.nn.functional as F 
import torch.nn as nn 

class M5(nn.Module):
    '''
    Model origins: 
        https://arxiv.org/pdf/1610.00087.pdf
    '''
    def __init__(self, n_input=1, n_output=10, stride=16, n_channel=32):
        super().__init__()
        self.conv1 = nn.Conv1d(n_input, n_channel, kernel_size=80, stride=stride)
        self.bn1 = nn.BatchNorm1d(n_channel)
        self.pool1 = nn.MaxPool1d(4)
        self.conv2 = nn.Conv1d(n_channel, n_channel, kernel_size=3)
        self.bn2 = nn.BatchNorm1d(n_channel)
        self.pool2 = nn.MaxPool1d(4)
        self.conv3 = nn.Conv1d(n_channel, 2 * n_channel, kernel_size=3)
        self.bn3 = nn.BatchNorm1d(2 * n_channel)
        self.pool3 = nn.MaxPool1d(4)
        self.conv4 = nn.Conv1d(2 * n_channel, 2 * n_channel, kernel_size=3)
        self.bn4 = nn.BatchNorm1d(2 * n_channel)
        self.pool4 = nn.MaxPool1d(4)
        self.fc1 = nn.Linear(2 * n_channel, n_output)

    def forward(self, x, y):
        # y is a dummy variable, it is just such that the dictionary 
        # can easily be unpacked via M5(**batch)
        x = self.conv1(x)
        x = F.relu(self.bn1(x))
        x = self.pool1(x)
        x = self.conv2(x)
        x = F.relu(self.bn2(x))
        x = self.pool2(x)
        x = self.conv3(x)
        x = F.relu(self.bn3(x))
        x = self.pool3(x)
        x = self.conv4(x)
        x = F.relu(self.bn4(x))
        x = self.pool4(x)
        x = F.avg_pool1d(x, x.shape[-1])
        x = x.permute(0, 2, 1)
        x = self.fc1(x)
        return F.log_softmax(x, dim=2)


# device = "cuda" if torch.cuda.is_available() else "cpu"
# model = M5()
# model = model.to(device)


def count_trainable_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

n = count_trainable_params(model)


print("Model\n",model)
print("Number of trainable layers\n", n)


model = M5()
model_class_name = type(model).__name__




# loading trained model
# filename = [file for file in dir_content if model_class_name in file][0]
path = "/home/akinwilson/Code/pytorch-example/model/2022_07_27-09:49:59_PM-model-M5-val-acc-0.98.pt"

state_dict = torch.load(path)
model.load_state_dict(state_dict,  strict=False)

model = model.to("cuda")

Model
 M5(
  (conv1): Conv1d(1, 32, kernel_size=(80,), stride=(16,))
  (bn1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool1): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv1d(32, 32, kernel_size=(3,), stride=(1,))
  (bn2): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool2): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv1d(32, 64, kernel_size=(3,), stride=(1,))
  (bn3): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool3): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
  (conv4): Conv1d(64, 64, kernel_size=(3,), stride=(1,))
  (bn4): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool4): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=64, out_features=10, bias=True)


### Training, validation and testing loop

In [9]:


def train_one_step(model, data, optimizer):
    optimizer.zero_grad()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    data = {k:v.to(device) for (k,v) in data.items()}
    y_hat = model(**data)
    y_hat = y_hat.squeeze()
    y = data['y'].squeeze()
    loss = F.nll_loss(y_hat , y, reduction='mean')
    loss.backward()
    optimizer.step()
    return loss 


def train_one_epoch(model, data_loader, optimizer):
    # model put into training mode 
    model.train()
    total_loss = 0 
    for batch_idx, data in enumerate(data_loader):
        loss = train_one_step(model, data, optimizer)
        total_loss += loss.item()
    return total_loss



def validate_one_step(model, data):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    
    data = {k:v.to(device) for (k,v) in data.items()}
    y_hat = model(**data)

    y_hat = y_hat.squeeze()
    y = data['y'].squeeze()
    
    loss = F.nll_loss(y_hat, y , reduction='mean')

    return loss 

def validate_one_epoch(model,  data_loader):
    all_predictions = [ ]
    all_targets = [ ]
    model.eval()
    total_loss = 0 
    for batch_idx, data in enumerate(data_loader):
        loss = validate_one_step(model, data)
        total_loss += loss.item()


    output = model(**data)

    pred = get_likely_index(output)
    all_predictions.append(pred.cpu().numpy().tolist())
    all_targets.append(data['y'].cpu().numpy().tolist())

    flatten =  lambda l: sum(l , [])
    
    tot_preds = flatten(all_predictions)
    tot_trgs = flatten(all_targets)


    acc = accuracy_score(tot_trgs, tot_preds)
    recall = recall_score(tot_trgs, tot_preds, average='macro')
    precision = precision_score(tot_trgs, tot_preds, average='macro')


    return total_loss, acc, recall, precision




def number_of_correct(pred, target):
    # count number of correct predictions
    return pred.squeeze().eq(target).sum().item()


def get_likely_index(tensor):
    # find most likely label index for each element in the batch
    return tensor.argmax(dim=-1)


def test(model, epoch):
    model.eval()

    correct = 0
    test_loader = data_module.test_dataloader()
    num_samples = test_loader.batch_size * len(test_loader)
    
    all_predictions = [] 
    all_targets = []

    for data in test_loader:

        device = "cuda" if torch.cuda.is_available() else "cpu"
        data = {k:v.to(device) for (k,v) in data.items()}

        output = model(**data)

        pred = get_likely_index(output)
        all_predictions.append(pred.cpu().numpy().tolist())
        all_targets.append(data['y'].cpu().numpy().tolist())

        correct += number_of_correct(pred, data['y'])


    flatten =  lambda l: sum(l , [])
    
    tot_preds = flatten(all_predictions)
    tot_trgs = flatten(all_targets)

    acc = correct / num_samples
    recall = recall_score(tot_trgs, tot_preds, average='macro')
    precision = precision_score(tot_trgs, tot_preds, average='macro')

    print(f"Epoch: {epoch}")
    print(f"========>Test accuracy: {acc:.2f}")
    print(f"========>Test recall: {recall}")
    print(f"========>Test precision: {precision}")    

    return acc, recall, precision


LR = 0.01
WEIGHT_DECAY = 0.0001
MAX_EPOCH = 1000
GAMMA = 0.1

TRAIN_BATCH_SIZE = 64
VAL_BATCH_SIZE = 256


model_class_name = type(model).__name__

# model = torch.nn.DataParallel(model)

optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

opt_class_name = type(optimizer).__name__ 

scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=GAMMA) 

data_module = AudioDataModule(train_batch_size=TRAIN_BATCH_SIZE,val_batch_size=VAL_BATCH_SIZE)




early_stopping_condition_met = False 

EARLY_STOPPING_PATIENCE = 20
metric_history = []
epoch = 0 


def check_es(hist, min_or_max='max', es_patience=5):
    extrema = lambda x : max(x) if min_or_max=='max' else min(x)
    return len(hist) - hist.index(extrema(hist)) >= es_patience , extrema(hist)


train_losses, val_losses = [ ], [ ]
for epoch in range(1, MAX_EPOCH + 1):

    data_loader = data_module.train_dataloader()

    total_train_epoch_losss = train_one_epoch(model, data_loader, optimizer)
    avg_per_sample_loss = total_train_epoch_losss/len(data_loader)
    train_losses.append((epoch, avg_per_sample_loss))
    print(f"Epoch: {epoch}.\n======>Avg loss: {avg_per_sample_loss}")
    scheduler.step()


    data_loader = data_module.val_dataloader()

    total_val_epoch_losss, val_acc, val_recall, val_precision = validate_one_epoch(model, data_loader)

    avg_per_sample_loss = total_val_epoch_losss/len(data_loader)
    val_losses.append((epoch, avg_per_sample_loss))
    print(f"Epoch: {epoch}.\n======>Avg loss: {avg_per_sample_loss}")

    print(f"Current learning rate: {scheduler.get_lr()[0]}")
    
    test_acc, test_recall, test_precision = test(model, epoch)  

    metric_history.append(val_acc)
    early_stopping_condition_met, extrema_val = check_es(metric_history, es_patience=EARLY_STOPPING_PATIENCE)

    if early_stopping_condition_met:
        print("Early stopping condition met")
        break 


print(f"Final epoch: {epoch}")
# mat


def save_torch_object(obj, path):
    torch.save(obj.state_dict(), path)

model_dir = Path().cwd() / "model" 
(model_dir).mkdir(parents=True, exist_ok=True)

date = datetime.now().strftime("%Y_%m_%d-%I:%M:%S_%p")


model_file_name = f"{date}-model-{model_class_name}-val-acc-{extrema_val:.2f}.pt"
optimizer_file_name = f"{date}-opt-{opt_class_name}-epoch-{epoch}.pt"


save_torch_object( model , model_dir / model_file_name)
save_torch_object( optimizer , model_dir / optimizer_file_name)

Epoch: 1.


RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor

RuntimeError: Error(s) in loading state_dict for M5:
	Missing key(s) in state_dict: "conv1.weight", "conv1.bias", "bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var", "conv2.weight", "conv2.bias", "bn2.weight", "bn2.bias", "bn2.running_mean", "bn2.running_var", "conv3.weight", "conv3.bias", "bn3.weight", "bn3.bias", "bn3.running_mean", "bn3.running_var", "conv4.weight", "conv4.bias", "bn4.weight", "bn4.bias", "bn4.running_mean", "bn4.running_var", "fc1.weight", "fc1.bias". 
	Unexpected key(s) in state_dict: "module.conv1.weight", "module.conv1.bias", "module.bn1.weight", "module.bn1.bias", "module.bn1.running_mean", "module.bn1.running_var", "module.bn1.num_batches_tracked", "module.conv2.weight", "module.conv2.bias", "module.bn2.weight", "module.bn2.bias", "module.bn2.running_mean", "module.bn2.running_var", "module.bn2.num_batches_tracked", "module.conv3.weight", "module.conv3.bias", "module.bn3.weight", "module.bn3.bias", "module.bn3.running_mean", "module.bn3.running_var", "module.bn3.num_batches_tracked", "module.conv4.weight", "module.conv4.bias", "module.bn4.weight", "module.bn4.bias", "module.bn4.running_mean", "module.bn4.running_var", "module.bn4.num_batches_tracked", "module.fc1.weight", "module.fc1.bias". 

[1.0000000000000003e-05]

In [None]:


def split_train_val_test(list_of_paths):
    # train_test_split(list_path_paths test_size=0.33, random_state=42)
    xy_train, xy_val_and_test =  train_test_split(list_of_paths, test_size=0.33, random_state=42)
    xy_val, xy_test =  train_test_split(xy_val_and_test, test_size=0.33, random_state=42)
    return xy_train, xy_val, xy_test 


def get_all_sample_paths():
    '''
    function lists all samples available for training.
    '''
    data_path = Path().cwd() / "datasets/free-spoken-digit-dataset-master/recordings"
    all_file_paths = [str(x) for x in [ data_path /  x for x in os.listdir(data_path)]]
    return all_file_paths 



# waveform, sample_rate = ta.load(all_file_paths[0])

def plot_waveform(waveform, sample_rate):
    waveform = waveform.numpy()

    num_channels, num_frames = waveform.shape
    time_axis = torch.arange(0, num_frames) / sample_rate

    figure, axes = plt.subplots(num_channels, 1)
    if num_channels == 1:
        axes = [axes]
    for c in range(num_channels):
        axes[c].plot(time_axis, waveform[c], linewidth=1)
        axes[c].grid(True)
        if num_channels > 1:
            axes[c].set_ylabel(f"Channel {c+1}")
    figure.suptitle("waveform")
    plt.show(block=False)
    
# plot_waveform(waveform, sample_rate)

# waveform.var(axis=1)


def get_avg_mean_and_var(full_path_file_list):
    means, variances, sample_rates = [], [], []
    for f in full_path_file_list: 
        waveform, x  = ta.load(f)
        means.append(waveform.mean(axis=1))
        variances.append(waveform.var(axis=1))
        sample_rates.append(x)
        
    
    return torch.hstack(means).mean(axis=0), torch.hstack(variances).mean(axis=0), set(sample_rates)
        
# get_avg_mean_and_var(get_all_sample_paths())
# train_paths, val_paths, test_paths  = split_train_val_test(get_all_sample_paths())