In [None]:
import os
import torch
import datetime
import time

import torch

import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import KFold
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import subprocess
import os
from utils import *
from models import *
from my_loss import *
from data_process import *

In [None]:
# Create the training configuration (verbose mode on), then build the
# CPCdataBinaural dataset from the metadata that configuration exposes.
CONSTANTS = InitializationTrain(verbose=True)
dataset = CPCdataBinaural(metadata=CONSTANTS.metadata)

In [None]:
# Instantiate the encoder/predictor network and move it to the configured device.
model = EncoderPredictor().to(CONSTANTS.device)
# Keep a handle on the model's `logmel` attribute; a later cell calls it as
# mel(input_signal=..., length=...).
mel = model.logmel

In [None]:
# Wrap the dataset in a DataLoader that yields batches of three items.
train_loader = DataLoader(dataset=dataset, batch_size=3)

## Listener Info (Audiogram)

In [None]:
# Look up two listeners and gather, per audiogram field, one entry per listener.
listener_info = ListenerInfo(['L0231', 'L0201'])
n_listeners = len(listener_info.info)
audiogram_l, audiogram_r, audiogram_cfs = (
    [listener_info.info[i][field] for i in range(n_listeners)]
    for field in ('audiogram_l', 'audiogram_r', 'audiogram_cfs')
)

In [None]:
# Display the `audiogram_l` entries gathered for the listeners above.
audiogram_l

In [None]:
# Dump the raw listener records, then the three audiogram fields of the first one.
listener_info = ListenerInfo(['L0231', 'L0201'])
print(listener_info.info)
for field in ('audiogram_l', 'audiogram_r', 'audiogram_cfs'):
    print(listener_info.info[0][field])

In [None]:
# Smoke-test a single batch: compute log-mel features for the left channel and
# look up the audiograms of the listeners in that batch. Stops after one batch.
for speech_input_l, speech_input_r, info_dict in tqdm(train_loader, desc="Training:"):
    batch_size, num_samples = speech_input_l.shape[0], speech_input_l.shape[1]
    mel_feature_l, mel_feature_length = mel(
        input_signal=speech_input_l.to(device),
        # Every signal in the batch is given the same length: the second dimension of the input.
        length=torch.full((batch_size,), num_samples).to(device),
    )
    listener_info = ListenerInfo(info_dict['listener'])
    # `info` is indexed by position everywhere else in this notebook
    # (info[i]['audiogram_l']); indexing it directly with the field name fails,
    # and the looked-up value was previously discarded. Keep it in a variable.
    batch_audiogram_l = [
        listener_info.info[i]['audiogram_l'] for i in range(len(listener_info.info))
    ]
    break

In [None]:
# Print the full configuration object attached to the model's ASR sub-model.
print(model.asr_model.cfg)

In [None]:
# Walk the top-level entries of the ASR configuration, one "key value" line each.
asr_cfg = model.asr_model.cfg
for key, entry in asr_cfg.items():
    print(key, entry)

In [None]:
# Same listing, restricted to the 'preprocessor' section of the ASR configuration.
preprocessor_cfg = model.asr_model.cfg['preprocessor']
for key, entry in preprocessor_cfg.items():
    print(key, entry)

In [None]:
import numpy as np
from scipy.interpolate import interp1d

# Example data
a = np.arange(8)              # 8-element frequency array: 0..7
b = np.arange(10, 90, 10)     # 8-element value array: 10..80
c = np.linspace(0, 7, 80)     # 80-element frequency array to resample onto

# Build the linear interpolation function from the 8 known points
linear_interpolation = interp1d(a, b)

# Evaluate it at every frequency in c
result = linear_interpolation(c)

# Print the result
print(result)


In [None]:
import numpy as np

def mel_to_hz(mel):
    """Convert a mel value (scalar or NumPy array) to Hz: 700 * (10^(mel/2595) - 1)."""
    exponent = mel / 2595
    return 700 * (10 ** exponent - 1)

def hz_to_mel(hz):
    """Convert a frequency in Hz (scalar or NumPy array) to mel: 2595 * log10(1 + hz/700)."""
    ratio = 1 + hz / 700
    return 2595 * np.log10(ratio)

def get_central_frequencies(nfilt, lowfreq, highfreq):
    """Return `nfilt` frequencies (Hz) equally spaced on the mel scale.

    `nfilt + 2` points are laid out evenly in mel between `lowfreq` and
    `highfreq`, converted back to Hz, and the two boundary points are dropped.
    """
    mel_lo, mel_hi = hz_to_mel(lowfreq), hz_to_mel(highfreq)
    # nfilt + 2 points so that both band edges are part of the grid
    edges_hz = mel_to_hz(np.linspace(mel_lo, mel_hi, nfilt + 2))
    # Strip the first and last point (the band edges themselves)
    return edges_hz[1:-1]


# Filterbank layout to inspect: 80 filters between 0 Hz and 8000 Hz.
nfilt, lowfreq, highfreq = 80, 0, 8000

central_frequencies = get_central_frequencies(nfilt=nfilt, lowfreq=lowfreq, highfreq=highfreq)
print(central_frequencies)


In [None]:
# Check how many centre frequencies were produced.
len(central_frequencies)

In [None]:
class HurricaneData(Dataset):
    """Draft dataset that indexes every entry found under ``<root_dir>/<label>/ssn``.

    Only the path list (``self.samples``) is built; ``__len__`` and
    ``__getitem__`` are not defined in this draft. A later cell in this
    notebook redefines ``HurricaneData`` with a different constructor.

    Args:
        root_dir: Directory containing one sub-folder per label.
        transform: Stored on the instance; not applied anywhere in this draft.
    """
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []

        for label_folder in os.listdir(root_dir):
            ssn_dir = os.path.join(root_dir, label_folder, 'ssn')
            # Skip plain files and label folders that have no 'ssn' sub-folder
            # instead of letting os.listdir raise on them.
            if not os.path.isdir(ssn_dir):
                continue
            for audio_file in os.listdir(ssn_dir):
                # Join against the folder that was actually listed; joining
                # against the label folder produced paths that do not exist.
                self.samples.append(os.path.join(ssn_dir, audio_file))

In [None]:
import os

# Peek at the dataset root and build the path of the first entry's 'ssn' folder.
root_dir = '/home/ubuntu/elec823/hurricane'
print(os.listdir(root_dir))
first_entry = os.listdir(root_dir)[0]
label_folder = first_entry
label_path = os.path.join(root_dir, label_folder)
print(label_path)
ssn_folder = os.path.join(label_path, 'ssn')
print(ssn_folder)

In [None]:
# List the contents of every '<entry>/ssn' folder under root_dir.
for mod_folder in os.listdir(root_dir):
    ssn_path = os.path.join(root_dir, mod_folder, 'ssn')
    # Entries without an 'ssn' folder (plain files, hidden entries) are skipped.
    if not os.path.isdir(ssn_path):
        continue
    # Print inside the loop: previously only the last entry's folder was
    # listed, and the cell crashed when that entry had no 'ssn' folder.
    print(os.listdir(ssn_path))

In [None]:

import os

# Collect every file found at <root_dir>/<mod>/ssn/<snr>/<file>.
root_dir = '/home/ubuntu/elec823/hurricane'
samples = []
for mod_folder in os.listdir(root_dir):
    if mod_folder.startswith("."):
        continue
    ssn_path = os.path.join(root_dir, mod_folder, 'ssn')
    # Root-level entries without an 'ssn' folder (e.g. the scores.mat file the
    # dataset class reads from this directory) would make os.listdir raise.
    if not os.path.isdir(ssn_path):
        continue
    for snr in os.listdir(ssn_path):
        if snr.startswith("."):
            continue
        snr_path = os.path.join(ssn_path, snr)
        for audio_file in os.listdir(snr_path):
            audio_path = os.path.join(snr_path, audio_file)
            samples.append(audio_path)

In [None]:
# Sanity check: how many paths were collected, and what one of them looks like.
print(len(samples))
print(samples[1])

In [None]:
import torchaudio
# Load the second collected file; the call returns the waveform and its sample rate.
waveform, sample_rate = torchaudio.load(samples[1])

In [None]:
# NOTE(review): a bare expression in the middle of a cell is not displayed;
# only the final `a.shape` below shows up as this cell's output.
waveform.shape
import torch
# Average over dim 0 (presumably the channel axis of the loaded waveform — TODO confirm).
a = torch.mean(waveform, dim=0)
a.shape

In [None]:
# Display the 'ssn' sub-folder path for the label folder chosen earlier.
os.path.join(label_path,'ssn')

In [None]:
# NOTE(review): this displays the entire path list; a slice such as samples[:5]
# would give a lighter preview.
samples

In [27]:
class HurricaneData(Dataset):
    """Hurricane recordings under ``<root>/<mod>/ssn/<snr>/<file>`` paired with scores.

    Scores are read from the ``'intell'`` entry of ``<root>/scores.mat`` and
    indexed as ``[mod number - 1][noise type][snr][utterance number - 1]``,
    where all four indices are parsed from the audio file path. Only the
    ``ssn`` noise folders are scanned, although ``noise_types`` also maps ``cs``.

    The file list is built in sorted order so the train/valid split is
    reproducible (``os.listdir`` returns entries in arbitrary order). Out of
    every consecutive chunk of ``FILES_PER_CONDITION`` files, the first
    ``VALID_PER_CONDITION`` go to the validation split and the rest to training.

    Args:
        state: ``'train'`` or ``'valid'``; selects which split ``samples`` holds.
        root_dir: Dataset root containing ``scores.mat`` and the ``mod*`` folders.
        transform: Stored on the instance; currently not applied in ``__getitem__``.

    Raises:
        ValueError: If ``state`` is neither ``'train'`` nor ``'valid'``.
    """

    # Chunk size used for the split and the number of files held out per chunk.
    FILES_PER_CONDITION = 180
    VALID_PER_CONDITION = 36
    # Every clip is padded or trimmed to this many seconds.
    CLIP_SECONDS = 3

    def __init__(self, state, root_dir='/home/ubuntu/elec823/hurricane', transform=None):
        if state not in ('train', 'valid'):
            # Previously an unknown state left `self.samples` undefined and only
            # failed later, inside __len__/__getitem__.
            raise ValueError(f"state must be 'train' or 'valid', got {state!r}")
        self.state = state
        self.root_dir = root_dir
        self.transform = transform
        self.scores = scipy.io.loadmat(os.path.join(root_dir, 'scores.mat'))['intell']
        self.noise_types = {"cs":0, "ssn":1}
        self.snrs = {"snrHi":0, "snrMid":1, "snrLo":2}
        self.all_samples = self._collect_samples(root_dir)

        val_list = []
        for start in range(0, len(self.all_samples), self.FILES_PER_CONDITION):
            val_list.extend(self.all_samples[start:start + self.VALID_PER_CONDITION])
        if state == 'train':
            # Set membership keeps the filtering linear in the number of files.
            held_out = set(val_list)
            self.samples = [path for path in self.all_samples if path not in held_out]
        else:
            self.samples = val_list

    @staticmethod
    def _collect_samples(root_dir):
        """Return every ``<root>/<mod>/ssn/<snr>/<file>`` path, in sorted order."""
        collected = []
        for mod_folder in sorted(os.listdir(root_dir)):
            if mod_folder.startswith("."):
                continue
            ssn_path = os.path.join(root_dir, mod_folder, 'ssn')
            if not os.path.isdir(ssn_path):
                continue
            for snr in sorted(os.listdir(ssn_path)):
                if snr.startswith("."):
                    continue
                snr_path = os.path.join(ssn_path, snr)
                for audio_file in sorted(os.listdir(snr_path)):
                    # Hidden files would shift the fixed-size chunks used for the split.
                    if audio_file.startswith("."):
                        continue
                    collected.append(os.path.join(snr_path, audio_file))
        return collected

    def _parse_audio_path(self, audio_path):
        """Map ``.../<mod>/<noise>/<snr>/<name>_<utt>.<ext>`` to zero-based score indices.

        Returns:
            ``(mod_index, noise_index, snr_index, utt_index)``.
        """
        parts = os.path.normpath(audio_path).split(os.sep)
        mod_name, noise_name, snr_name, file_name = parts[-4:]
        utt_number = int(file_name.split('_')[-1].split('.')[0])
        mod_number = int(''.join(re.findall(r'\d+', mod_name)))
        return mod_number - 1, self.noise_types[noise_name], self.snrs[snr_name], utt_number - 1

    def __len__(self):
        """Number of files in the selected split."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(waveform, waveform, score)`` for the ``idx``-th file of the split.

        The waveform is averaged over dim 0 and padded with zeros or trimmed to
        ``CLIP_SECONDS`` seconds. The same tensor is returned twice
        (presumably to mirror the two-channel batches used elsewhere in this
        notebook — TODO confirm).
        """
        audio_path = self.samples[idx]
        mod_idx, noise_idx, snr_idx, utt_idx = self._parse_audio_path(audio_path)
        score = self.scores[mod_idx][noise_idx][snr_idx][utt_idx]

        waveform, sample_rate = torchaudio.load(audio_path)
        waveform = torch.mean(waveform, dim=0)

        # Pad or trim the audio to a fixed duration
        desired_length = sample_rate * self.CLIP_SECONDS
        current_length = waveform.size(-1)
        if current_length < desired_length:
            waveform = torch.nn.functional.pad(
                waveform, (0, desired_length - current_length), "constant"
            )
        elif current_length > desired_length:
            waveform = waveform[..., :desired_length]
        # if self.transform:
        #     waveform = self.transform(waveform)

        return waveform, waveform, score
# Build both splits and pull one batch from the training split as a smoke test.
dataset_train = HurricaneData('train')
dataset_valid = HurricaneData('valid')
# Wrap the split that was just built; `dataset` is a different object created
# in an earlier cell.
train_loader = DataLoader(dataset=dataset_train, batch_size=3)
# Each item is (waveform, waveform, score), so a batch unpacks into three values.
for waveform_l, waveform_r, score in train_loader:
    break

In [28]:
# Report the size of the training split, then of the validation split.
for split in (dataset_train, dataset_valid):
    print(len(split))


8640
2160


In [34]:
# Scratch check: concatenating onto an empty list yields a new list equal to `a`.
a = ["1", "2", "3"]
c = [4, 5, 6]
b = [] + a
print(b)

['1', '2', '3']


In [20]:
# Display the sample paths held by `dataset`.
# NOTE(review): this displays the whole list; a slice would keep the output short.
dataset.samples

['/home/ubuntu/elec823/hurricane/mod10/ssn/snrLo/hvd_009.wav',
 '/home/ubuntu/elec823/hurricane/mod10/ssn/snrLo/hvd_029.wav',
 '/home/ubuntu/elec823/hurricane/mod10/ssn/snrLo/hvd_133.wav',
 '/home/ubuntu/elec823/hurricane/mod10/ssn/snrLo/hvd_041.wav',
 '/home/ubuntu/elec823/hurricane/mod10/ssn/snrLo/hvd_154.wav',
 '/home/ubuntu/elec823/hurricane/mod10/ssn/snrLo/hvd_023.wav',
 '/home/ubuntu/elec823/hurricane/mod10/ssn/snrLo/hvd_120.wav',
 '/home/ubuntu/elec823/hurricane/mod10/ssn/snrLo/hvd_063.wav',
 '/home/ubuntu/elec823/hurricane/mod10/ssn/snrLo/hvd_174.wav',
 '/home/ubuntu/elec823/hurricane/mod10/ssn/snrLo/hvd_178.wav',
 '/home/ubuntu/elec823/hurricane/mod10/ssn/snrLo/hvd_160.wav',
 '/home/ubuntu/elec823/hurricane/mod10/ssn/snrLo/hvd_025.wav',
 '/home/ubuntu/elec823/hurricane/mod10/ssn/snrLo/hvd_001.wav',
 '/home/ubuntu/elec823/hurricane/mod10/ssn/snrLo/hvd_024.wav',
 '/home/ubuntu/elec823/hurricane/mod10/ssn/snrLo/hvd_004.wav',
 '/home/ubuntu/elec823/hurricane/mod10/ssn/snrLo/hvd_17

In [None]:
# Print the shape(s) of the first batch only.
for batch in train_loader:
    # The datasets in this notebook return tuples, so a batch is a list of
    # collated elements rather than a single tensor; `.shape` on the list fails.
    elements = batch if isinstance(batch, (list, tuple)) else [batch]
    for element in elements:
        # Elements without a shape (e.g. dicts of metadata) are reported by type.
        print(getattr(element, 'shape', type(element)))
    break

In [None]:
import scipy.io

# Read the MATLAB file and show what is stored under its 'intell' key.
mat_contents = scipy.io.loadmat('/home/ubuntu/elec823/cache.mat')
data = mat_contents['intell']

print(data)

In [12]:
import re

input_str = '*/123/123/hurricane/mod10/ssn/snrLo/hvd_009.wav'

# Split the path on '/'
split_str = input_str.split('/')

# The last four components carry the modification, noise type, SNR level and
# file name; the utterance id is the part of the file name after '_' and before '.'
mod_name, noise_name, snr_name, file_name = split_str[-4:]
utt_id = file_name.split('_')[-1].split('.')[0]
output = [mod_name, noise_name, snr_name, utt_id]

print(output)

# Translate each component into the integer used to index the score table
noise_types = {"cs":0, "ssn":1}
snrs = {"snrHi":0, "snrMid":1, "snrLo":2}
numbers = int(''.join(re.findall(r'\d+', mod_name)))
noise_type = noise_types[noise_name]
snr = snrs[snr_name]
utt = int(utt_id)
print(numbers, noise_type, snr, utt)

['mod10', 'ssn', 'snrLo', '009']
10 1 2 9


In [16]:
import os
import torch
import datetime
import time

import torch

import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import KFold
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import subprocess
import os
from utils import *
from models import *
from my_loss import *
from data_process import *

# Build the EncoderPredictorHI_v3 network and move it to `device`.
# NOTE(review): `device` is not assigned anywhere in this notebook; presumably it
# comes from one of the star imports above — TODO confirm.
model = EncoderPredictorHI_v3().to(device)

[NeMo W 2023-04-24 23:19:49 optimizers:54] Apex was not found. Using the lamb or fused_adam optimizer will error out.
    
[NeMo W 2023-04-24 23:19:49 experimental:27] Module <class 'nemo.collections.asr.modules.audio_modules.SpectrogramToMultichannelFeatures'> is experimental, not ready for production and is not fully supported. Use at your own risk.


[NeMo I 2023-04-24 23:19:54 mixins:170] Tokenizer SentencePieceTokenizer initialized with 128 tokens


[NeMo W 2023-04-24 23:19:55 modelPT:161] If you intend to do training or fine-tuning, please call the ModelPT.setup_training_data() method and provide a valid configuration file to setup the train data loader.
    Train config : 
    manifest_filepath: /data/NeMo_ASR_SET/English/v2.0/train/tarred_audio_manifest.json
    sample_rate: 16000
    batch_size: 32
    shuffle: true
    num_workers: 8
    pin_memory: true
    use_start_end_token: false
    trim_silence: false
    max_duration: 20.0
    min_duration: 0.1
    shuffle_n: 2048
    is_tarred: true
    tarred_audio_filepaths: /data/NeMo_ASR_SET/English/v2.0/train/audio__OP_0..4095_CL_.tar
    
[NeMo W 2023-04-24 23:19:55 modelPT:168] If you intend to do validation, please call the ModelPT.setup_validation_data() or ModelPT.setup_multiple_validation_data() method and provide a valid configuration file to setup the validation data loader(s). 
    Validation config : 
    manifest_filepath:
    - /data/ASR/LibriSpeech/librispeech_withs

[NeMo I 2023-04-24 23:19:55 features:287] PADDING: 0
[NeMo I 2023-04-24 23:19:57 save_restore_connector:247] Model EncDecCTCModelBPE was successfully restored from /home/ubuntu/.cache/huggingface/hub/models--nvidia--stt_en_conformer_ctc_large/snapshots/2c8326e4e43ae5b994612cfea3f3029818fb23c6/stt_en_conformer_ctc_large.nemo.


In [17]:
# List every named parameter of the model together with its tensor size.
for param_name, tensor in model.named_parameters():
    print(param_name, tensor.size())

hearing_impairment.conv1.0.weight torch.Size([32, 160, 7])
hearing_impairment.conv1.0.bias torch.Size([32])
hearing_impairment.conv1.2.weight torch.Size([32])
hearing_impairment.conv1.2.bias torch.Size([32])
hearing_impairment.conv2.0.weight torch.Size([64, 32, 5])
hearing_impairment.conv2.0.bias torch.Size([64])
hearing_impairment.conv2.2.weight torch.Size([64])
hearing_impairment.conv2.2.bias torch.Size([64])
hearing_impairment.conv3.0.weight torch.Size([128, 64, 3])
hearing_impairment.conv3.0.bias torch.Size([128])
hearing_impairment.conv3.2.weight torch.Size([128])
hearing_impairment.conv3.2.bias torch.Size([128])
hearing_impairment.fc.0.weight torch.Size([512, 4736])
hearing_impairment.fc.0.bias torch.Size([512])
conformer_encoder.pre_encode.out.weight torch.Size([512, 10240])
conformer_encoder.pre_encode.out.bias torch.Size([512])
conformer_encoder.pre_encode.conv.0.weight torch.Size([512, 1, 3, 3])
conformer_encoder.pre_encode.conv.0.bias torch.Size([512])
conformer_encoder.pre_