In [1]:
# Standard library. `os` and `pickle` are used by the data-loading cell; import
# them explicitly instead of relying on a wildcard import leaking them.
import os
import pickle
import random

# Third-party
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torchaudio
from torch import nn
from torch import optim
from torch.nn import init
from torch.utils.data import DataLoader, random_split
from torchinfo import summary

# Project-local. Kept in their original order: the wildcard imports can rebind
# names, so their relative position decides which definition wins.
from model_model import SelfPackLSTM
from model_configs import ModelDimConfigs, TrainingConfigs
from misc_tools import get_timestamp
from model_dataset import DS_Tools, Padder, TokenMap, NormalizerKeepShape
from model_dataset import SingleRecSelectBalanceDatasetPrecombine as ThisDataset
from model_filter import XpassFilter
from paths import *
from ssd_paths import *
from misc_progress_bar import draw_progress_bar
from misc_recorder import *

from A_00_model import *

In [2]:
# Audio preprocessing pipeline; the stages run in the order they are listed.
mytrans = nn.Sequential(
    # 1) Project-defined Padder, configured with a 250 ms pad length and a
    #    noise level of 1e-4.
    Padder(
        sample_rate=TrainingConfigs.REC_SAMPLE_RATE,
        pad_len_ms=250,
        noise_level=1e-4,
    ),
    # 2) Mel spectrogram with power=2 (i.e. a power spectrogram).
    torchaudio.transforms.MelSpectrogram(
        sample_rate=TrainingConfigs.REC_SAMPLE_RATE,
        n_mels=TrainingConfigs.N_MELS,
        n_fft=TrainingConfigs.N_FFT,
        power=2,
    ),
    # 3) Convert to decibels; stype="power" matches the power=2 spectrogram
    #    above, and top_db=80 caps the dynamic range.
    torchaudio.transforms.AmplitudeToDB(stype="power", top_db=80),
    # 4) Project-defined normalizer driven by `norm_mvn`
    #    (presumably mean-variance normalization -- TODO confirm).
    NormalizerKeepShape(NormalizerKeepShape.norm_mvn),
)

# Load the pickled collection of segment labels that drives both the dataset
# selection (`select`) and the label encoding (`mymap`).
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load dictionaries produced by this project.
with open(os.path.join(src_, "no-stress-seg.dict"), "rb") as file:
    mylist = pickle.load(file)

# 'AH' is deliberately excluded: it is too mixed.
# remove() raises ValueError if 'AH' is absent, which keeps an unexpected
# dictionary file from passing silently.
mylist.remove('AH')
# remove() only drops the first occurrence, so verify the label is really gone
# before it is used to build the token map and the datasets.
assert 'AH' not in mylist, "'AH' is still present after removal (duplicate entry?)"

# `select` is the same object as `mylist` (an alias, not a copy).
select = mylist

# Build the token mapping from the filtered label list.
mymap = TokenMap(mylist)
# Keyword arguments common to both splits: the same label selection, token
# mapping and audio transform are applied to training and validation data.
_shared_ds_kwargs = dict(select=select, mapper=mymap, transform=mytrans)

# Training split, described by guide_train.csv.
use_train_ds = ThisDataset(
    strain_cut_audio_,
    os.path.join(suse_, "guide_train.csv"),
    **_shared_ds_kwargs,
)

# Validation split, described by guide_validation.csv.
use_valid_ds = ThisDataset(
    strain_cut_audio_,
    os.path.join(suse_, "guide_validation.csv"),
    **_shared_ds_kwargs,
)

# Not referenced in this cell; presumably consumed by a later cell -- TODO confirm.
use_proportion = 0.01

# Settings shared by both loaders, taken from the training configuration.
_loader_kwargs = dict(
    batch_size=TrainingConfigs.BATCH_SIZE,
    num_workers=TrainingConfigs.LOADER_WORKER,
)

# Training batches are reshuffled every epoch.
train_loader = DataLoader(use_train_ds, shuffle=True, **_loader_kwargs)
train_num = len(train_loader.dataset)  # number of training samples

# Validation batches keep a fixed order.
valid_loader = DataLoader(use_valid_ds, shuffle=False, **_loader_kwargs)
valid_num = len(valid_loader.dataset)  # number of validation samples