In [11]:
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader, random_split
import torchaudio
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import random
from torchinfo import summary
import torch.nn.functional as F
from torch.nn import init
from H_10_models import SmallNetwork, MediumNetwork, LargeNetwork, ResLinearNetwork, LSTMNetwork
from model_configs import ModelDimConfigs, TrainingConfigs
from misc_tools import get_timestamp, ARPABET
from model_dataset import DS_Tools, Padder, TokenMap, NormalizerKeepShape
from model_dataset import SingleRecSelectBalanceDatasetPrecombine as ThisDataset
from model_filter import XpassFilter
from paths import *
from ssd_paths import *
from misc_progress_bar import draw_progress_bar
from misc_recorder import *
from H_11_drawer import draw_learning_curve_and_accuracy
import argparse
from tqdm import tqdm

In [13]:
class TrainingConfigs: 
    """Notebook-local constants for audio feature extraction and data loading.

    Defining this class here rebinds the name ``TrainingConfigs`` that the
    import cell brings in from ``model_configs``; code executed after this
    cell reads the values below, not the imported ones.
    """
    # BATCH_SIZE = 64
    # Batch size handed to the DataLoader.
    # NOTE(review): an earlier note dated 20240813 said this was "changed to 32
    # due to smaller data size", but the value in effect is 256 — confirm which
    # one is intended.
    BATCH_SIZE = 256

    # Sample rate passed to Padder and MelSpectrogram.
    REC_SAMPLE_RATE = 16000
    # Passed as n_fft to MelSpectrogram.
    N_FFT = 400
    # Passed as n_mels to MelSpectrogram.
    N_MELS = 64

    # Not referenced by the code visible in this notebook.
    N_MFCC = 13

    # Not referenced by the code visible in this notebook.
    # Presumably the linear-spectrogram bin count N_FFT // 2 + 1 — TODO confirm.
    N_SPEC = 201
    
    # Passed as num_workers to the DataLoader.
    LOADER_WORKER = 32

In [14]:
def _build_mel_transform(type="f"):
    """Build the waveform -> dB mel-spectrogram transform pipeline.

    Every variant pads the input, computes a power mel-spectrogram and
    converts it to decibels. ``type`` only decides whether an ``XpassFilter``
    is inserted between the padder and the mel-spectrogram:

    * ``"l"``: ``XpassFilter(cut_off_upper=500)``
    * ``"h"``: ``XpassFilter(cut_off_upper=10000, cut_off_lower=4000)``
    * anything else: no filter stage

    No normalizer is appended; the output is the raw dB-scaled spectrogram.
    """
    stages = [
        Padder(sample_rate=TrainingConfigs.REC_SAMPLE_RATE, pad_len_ms=250, noise_level=1e-4),
    ]
    if type == "l":
        stages.append(XpassFilter(cut_off_upper=500))
    elif type == "h":
        # NOTE(review): an upper cut-off of 10000 is above half of
        # REC_SAMPLE_RATE when that is 16000 — confirm XpassFilter copes.
        stages.append(XpassFilter(cut_off_upper=10000, cut_off_lower=4000))
    stages.append(
        torchaudio.transforms.MelSpectrogram(
            TrainingConfigs.REC_SAMPLE_RATE,
            n_mels=TrainingConfigs.N_MELS,
            n_fft=TrainingConfigs.N_FFT,
            power=2,
        )
    )
    stages.append(torchaudio.transforms.AmplitudeToDB(stype="power", top_db=80))
    # NormalizerKeepShape(NormalizerKeepShape.norm_mvn) is deliberately left out.
    return nn.Sequential(*stages)


def load_data(type="f", sel="full", load="train", model_save_dir=""):
    """Create a DataLoader over the subset of a split used by one model run.

    Args:
        type: Transform variant: ``"l"`` and ``"h"`` insert an ``XpassFilter``
            stage (see ``_build_mel_transform``); any other value uses no filter.
        sel: ``"c"`` restricts the selected tokens to consonants, ``"v"`` to
            vowels; any other value keeps every token from the pickled list.
            The value is also part of the index file name, ``{load}_{sel}.use``.
        load: ``"train"`` (guide_train.csv, shuffled) or ``"valid"``
            (guide_validation.csv, not shuffled).
        model_save_dir: Directory containing the ``{load}_{sel}.use`` index file.

    Returns:
        A ``DataLoader`` over a ``torch.utils.data.Subset`` of the dataset,
        restricted to the indices read from the index file.

    Raises:
        ValueError: If ``load`` is neither ``"train"`` nor ``"valid"``.
            Previously the function fell through and returned ``None``.
    """
    # load -> (guide csv file name, shuffle flag)
    split_settings = {
        "train": ("guide_train.csv", True),
        "valid": ("guide_validation.csv", False),
    }
    if load not in split_settings:
        # Fail before any file is opened instead of silently returning None.
        raise ValueError(f"load must be 'train' or 'valid', got {load!r}")
    guide_name, shuffle = split_settings[load]

    mytrans = _build_mel_transform(type)

    with open(os.path.join(src_, "no-stress-seg.dict"), "rb") as file:
        # SECURITY: pickle.load can execute arbitrary code; only read trusted files.
        mylist = pickle.load(file)
    mylist.remove('AH') # we don't include this, it is too mixed. 

    if sel == "c":
        select = ARPABET.intersect_lists(mylist, ARPABET.list_consonants())
    elif sel == "v":
        select = ARPABET.intersect_lists(mylist, ARPABET.list_vowels())
    else:
        select = mylist

    # The token map is always built from the full list (minus 'AH'), not from
    # the consonant/vowel selection.
    mymap = TokenMap(mylist)

    # NOTE(review): both splits read audio from strain_cut_audio_; only the
    # guide csv differs — confirm this is intended for validation.
    dataset = ThisDataset(strain_cut_audio_,
                          os.path.join(suse_, guide_name),
                          select=select,
                          mapper=mymap,
                          transform=mytrans)

    indices = DS_Tools.read_indices(os.path.join(model_save_dir, f"{load}_{sel}.use"))
    use_ds = torch.utils.data.Subset(dataset, indices)

    return DataLoader(use_ds,
                      batch_size=TrainingConfigs.BATCH_SIZE,
                      shuffle=shuffle,
                      num_workers=TrainingConfigs.LOADER_WORKER)

In [15]:
# Parallel per-run accumulators: position i of each list belongs to the same
# run (its mean, its variance, and the number of elements they were computed on).
means, variances, sample_counts = [], [], []

In [18]:
# Identify the training session whose per-run index files are read.
# These do not change between runs, so they are set once outside the loop.
ts = "0905160507"
train_name = "H21"

# Start from empty accumulators inside this cell: re-running the cell would
# otherwise append a second copy of every run's statistics to lists left over
# from a previous execution and silently skew the pooled result.
# NOTE(review): the range starts at run 5 while the output file written at the
# end of the notebook is named "mv_config_20.pkl" — confirm whether runs 1-4
# were meant to be included.
means = []
variances = []
sample_counts = []

for run in range(5, 21):
    # Read in the dataset
    model_save_dir = os.path.join(model_save_, f"{train_name}-{ts}-{run}")

    train_loader = load_data(type="f",
                            load="train",
                            model_save_dir=model_save_dir)
    # We use normal dataset, but just deleted the normalizer from my_trans
    # In this way, we just loop over the batches and collect the mels.
    # Labels are not needed for the statistics, so they are discarded.
    all_mels = [x for x, _ in tqdm(train_loader)]

    all_mels_cat = torch.cat(all_mels, dim=0)

    # Compute mean and variance for this iteration (over every element,
    # not per mel band or per frame).
    iteration_mean = all_mels_cat.mean()
    iteration_var = all_mels_cat.var(unbiased=True)  # Unbiased variance
    iteration_samples = all_mels_cat.numel()         # Total number of elements

    means.append(iteration_mean)
    variances.append(iteration_var)
    sample_counts.append(iteration_samples)

100%|██████████| 98/98 [00:02<00:00, 36.30it/s]
100%|██████████| 98/98 [00:02<00:00, 47.79it/s] 
100%|██████████| 98/98 [00:02<00:00, 48.74it/s]
100%|██████████| 98/98 [00:01<00:00, 52.85it/s]
100%|██████████| 98/98 [00:02<00:00, 48.79it/s]
100%|██████████| 98/98 [00:02<00:00, 40.75it/s]
100%|██████████| 98/98 [00:02<00:00, 48.21it/s]
100%|██████████| 98/98 [00:02<00:00, 47.32it/s]
100%|██████████| 98/98 [00:02<00:00, 43.70it/s]
100%|██████████| 98/98 [00:02<00:00, 47.83it/s]
100%|██████████| 98/98 [00:01<00:00, 49.32it/s]
100%|██████████| 98/98 [00:01<00:00, 49.05it/s]
100%|██████████| 98/98 [00:02<00:00, 46.59it/s]
100%|██████████| 98/98 [00:02<00:00, 45.74it/s]
100%|██████████| 98/98 [00:02<00:00, 45.05it/s]
100%|██████████| 98/98 [00:02<00:00, 48.93it/s]


In [19]:
# Aggregate the per-run statistics into one overall mean and standard deviation.
if not (len(sample_counts) == len(means) == len(variances)):
    # zip() would silently drop the unmatched tail; fail loudly instead.
    raise ValueError("sample_counts, means and variances must have the same length")
if not sample_counts:
    raise ValueError("No per-run statistics available; run the collection cell first")

total_samples = sum(sample_counts)

# Pool in float64 and cast back at the end: the sums involve very large
# element counts and would lose precision in float32.
_stacked_means = torch.stack([torch.as_tensor(m) for m in means])
_out_dtype = _stacked_means.dtype
_run_means = _stacked_means.to(torch.float64)
_run_vars = torch.stack([torch.as_tensor(v) for v in variances]).to(torch.float64)
_counts = torch.tensor(sample_counts, dtype=torch.float64, device=_run_means.device)

# Compute total mean (element-count weighted average of the run means)
total_mean = (_counts * _run_means).sum() / total_samples

# Compute total variance.
# The per-run variances were computed with unbiased=True, so each run's sum of
# squared deviations about its own mean is (n - 1) * v, not n * v. Adding the
# between-run term n * (m - total_mean)**2 gives the exact sum of squares about
# the overall mean, and dividing by (N - 1) yields the unbiased variance of all
# elements pooled together. This also avoids the cancellation-prone
# E[x^2] - E[x]^2 form.
_within_ss = ((_counts - 1) * _run_vars).sum()
_between_ss = (_counts * (_run_means - total_mean) ** 2).sum()
total_variance = (_within_ss + _between_ss) / (total_samples - 1)

# Compute total standard deviation
total_std = torch.sqrt(total_variance)

# Hand back 0-d tensors of the same dtype the per-run statistics had.
total_mean = total_mean.to(_out_dtype)
total_variance = total_variance.to(_out_dtype)
total_std = total_std.to(_out_dtype)

print(f"Estimated Mean: {total_mean}, Estimated Std: {total_std}")

Estimated Mean: -41.01996612548828, Estimated Std: 20.34762954711914


In [None]:
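# Recorded results. The second line matches the output printed by the
# aggregation cell; the first is presumably from an earlier execution —
# TODO confirm which runs it covered.
# Estimated Mean: -41.035614013671875, Estimated Std: 20.338979721069336
# Estimated Mean: -41.01996612548828, Estimated Std: 20.34762954711914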

In [21]:
# Persist the pooled statistics under the keys 'mean' and 'std'.
# The values are pickled as they are (whatever type total_mean / total_std have).
mean_std_dict = {
    'mean': total_mean,
    'std': total_std,
}

mv_config_path = os.path.join(src_, "mv_config_20.pkl")
with open(mv_config_path, "wb") as out_file:
    pickle.dump(mean_std_dict, out_file)