In [10]:
import numpy as np
import soundfile as sf
from onlinebss import OnlineAuxIvaIss, OnlineAecWpeIva, OnlineBssCpp
import logging
from functools import wraps
import os
from typing import Optional
import matplotlib.pyplot as plt
import tempfile
import pickle
import json
import time
import wave

In [18]:


# Logging setup: every record from DEBUG upwards goes to iva_process.log,
# which is truncated each time this code runs (filemode='w').
logging.basicConfig(
    filename='iva_process.log',
    filemode='w',
    level=logging.DEBUG,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)

logger = logging.getLogger(__name__)

# Sampling periods, in frames, at which matrix norms are recorded.
LOG_EVERY_N_FRAMES_10000 = 10000
LOG_EVERY_N_FRAMES_5000 = 5000
LOG_EVERY_N_FRAMES_1000 = 1000

# Build an empty norms dictionary
def reset_norms():
    """Return a fresh, empty container for recorded matrix norms.

    The result maps each sampling period ('10000', '5000', '1000') to a
    dictionary with one independent empty list per tracked quantity
    ('invW', 'W', 'G', 'S', 'invS', 'R', 'buffer_x').
    """
    periods = ('10000', '5000', '1000')
    # 'buffer_h' is currently not tracked.
    tracked = ('invW', 'W', 'G', 'S', 'invS', 'R', 'buffer_x')
    return {period: {name: [] for name in tracked} for period in periods}

# Persist the norms dictionary as JSON
def save_norms_to_json(norms, filename):
    """Write ``norms`` to ``filename`` as JSON indented by four spaces.

    The target file is created, or overwritten if it already exists.
    """
    with open(filename, 'w') as out_file:
        json.dump(obj=norms, fp=out_file, indent=4)

# Decorator that records matrix norms after each wrapped call
def log_matrices(func):
    """Wrap a method so that norms of the object's matrices are sampled.

    After every call the wrapper increments ``self.frame_idx`` (starting
    from 0 when the attribute does not exist yet). Whenever the new counter
    is a multiple of one of the ``LOG_EVERY_N_FRAMES_*`` periods, the
    ``np.linalg.norm`` of ``self._invW``, ``_W``, ``_G``, ``_S``, ``_invS``,
    ``_R`` and ``_buffer_x`` is appended to the matching lists in
    ``self.norms``. The wrapped method's return value is passed through.
    """
    # 'buffer_h' is currently not tracked.
    tracked = ('invW', 'W', 'G', 'S', 'invS', 'R', 'buffer_x')

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        outcome = func(self, *args, **kwargs)

        # Objects that have never been counted start from zero.
        self.frame_idx = getattr(self, 'frame_idx', 0) + 1

        # The periods are looked up on every call, exactly like a direct
        # reference to the module-level constants would be.
        schedule = (
            ('10000', LOG_EVERY_N_FRAMES_10000),
            ('5000', LOG_EVERY_N_FRAMES_5000),
            ('1000', LOG_EVERY_N_FRAMES_1000),
        )
        for bucket, period in schedule:
            if self.frame_idx % period != 0:
                continue
            for name in tracked:
                self.norms[bucket][name].append(
                    np.linalg.norm(getattr(self, '_' + name))
                )

        # Flush whatever handlers are attached directly to this module's logger.
        for handler in logger.handlers:
            handler.flush()

        return outcome
    return wrapper


# Matrix "regularization" helper
def regularize(matrix, lambda_reg=1e-5):
    """Return ``matrix`` minus ``lambda_reg * matrix``.

    The input is not modified; a new value is returned.
    """
    shrinkage = lambda_reg * matrix
    return matrix - shrinkage

# Limit matrix entries to a fixed range
def clip(matrix, min_value=-1e10, max_value=1e10):
    """Return ``matrix`` with its entries limited to ``[min_value, max_value]``.

    This is a thin wrapper around ``np.clip`` with wide default bounds.
    """
    lower, upper = min_value, max_value
    return np.clip(matrix, lower, upper)

# Scale a matrix to unit norm
def normalize(matrix):
    """Return ``matrix`` divided by its ``np.linalg.norm``.

    When the norm is below 1e-20 the input object itself is returned
    unchanged, which avoids dividing by a value that is effectively zero.
    """
    magnitude = np.linalg.norm(matrix)
    return matrix if magnitude < 1e-20 else matrix / magnitude

# Subclass that records matrix norms and renormalizes state after every frame
class LoggedOnlineAecWpeIva(OnlineAecWpeIva):
    """``OnlineAecWpeIva`` instrumented for numerical-stability experiments.

    ``process`` is wrapped with ``log_matrices`` (norm sampling into
    ``self.norms``) and, after producing its output, rescales ``_invW``,
    ``_S`` and ``_invS`` with ``normalize``. The accumulated time spent in
    that rescaling step is kept in ``total_normalization_time`` (seconds).
    """

    def __init__(self, num_channels, norms):
        """Create the separator.

        Args:
            num_channels: Forwarded unchanged to ``OnlineAecWpeIva.__init__``.
            norms: Dictionary that receives the sampled norms; see
                ``reset_norms`` for the expected layout.
        """
        super().__init__(num_channels)
        self.norms = norms
        self.total_normalization_time = 0.0

    @log_matrices
    def process(
        self,
        frame: np.ndarray,
        aec_frame: Optional[np.ndarray] = None,
        alpha: float = 0.99,
        beta: float = 0.996,
        iter: int = 1,
        mdp: bool = False,
        denoise_db_reduction: float = 11,
        denoise_beta: float = 3,
        denoise_alpha: float = 1.05,
    ) -> np.ndarray:
        """Process one frame and then renormalize selected state matrices.

        NOTE(review): the pipeline below presumably mirrors
        ``OnlineAecWpeIva.process`` - confirm against the parent class.

        Args:
            frame: Input frame; column 0 is handed to the aligner when one
                is configured.
            aec_frame: Reference frame, used only when ``self._N > 0``.
            alpha, beta, iter: Forwarded to ``self._process``. ``alpha`` is
                also passed to ``self._mdp.process`` when ``mdp`` is true.
            mdp: Use ``self._mdp.process`` instead of
                ``self._projection_back`` to obtain the output spectra.
            denoise_db_reduction, denoise_beta, denoise_alpha: Forwarded to
                ``self._denoise.process`` when denoising is enabled.

        Returns:
            The result of ``self._stft.synthesis`` for this frame.
        """
        if self._N > 0:
            if self._aligner:
                aec_frame = self._aligner.process(aec_frame, frame[:, 0])
            self._update_aec(aec_frame)

        spectrum = self._stft.analysis(frame)
        separated = self._process(spectrum, alpha, beta, iter)
        if mdp:
            projected = self._mdp.process(spectrum, separated, alpha)
        else:
            projected = self._projection_back(separated)
        if self._denoise:
            projected = self._denoise.process(
                projected, denoise_db_reduction, denoise_beta, denoise_alpha
            )
        result = self._stft.synthesis(projected)

        # perf_counter() is monotonic and high-resolution, so the measured
        # interval cannot be distorted by system clock adjustments.
        start_time = time.perf_counter()

        # Experiment: rescale selected state matrices to unit norm after
        # every frame. Notes left by the original author:
        #   _S    - worth trying, but W, invS and R grow
        #   _invS - worth trying, but W, S and R grow
        #   _R    - not normalized (the note says it "drifts")
        # _W, _G, _buffer_x and _buffer_h are not normalized either.
        # NOTE(review): _S and _invS are scaled independently; if they are
        # meant to remain mutual inverses this breaks that relation - confirm.
        self._invW = normalize(self._invW)
        self._S = normalize(self._S)
        self._invS = normalize(self._invS)

        self.total_normalization_time += time.perf_counter() - start_time
        return result

def save_plot(data, title, filename, LOG_EVERY_N_FRAMES):
    """Plot ``data`` on a logarithmic Y axis and save the figure to ``filename``.

    Args:
        data: Sequence of norm values to plot.
        title: Figure title.
        filename: Destination path handed to ``savefig``.
        LOG_EVERY_N_FRAMES: Sampling period shown in the X-axis label.
    """
    fig, ax = plt.subplots(figsize=(12, 8))
    ax.plot(data)
    ax.set_xlabel('Frame Index (every {} frames)'.format(LOG_EVERY_N_FRAMES))
    ax.set_ylabel('Norm')
    ax.set_yscale('log')  # logarithmic scale on the Y axis
    ax.set_title(title)
    ax.grid(True)
    fig.savefig(filename)
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

# Parameters
dir_file = "/home/pyatanin/VSproj/30.09.2024/august_records/"
dir_out = "/home/pyatanin/VSproj/30.09.2024/write_records2/"
png_dir = '/home/pyatanin/VSproj/30.09.2024/png_graf_9_10/'
json_dir = '/home/pyatanin/VSproj/30.09.2024/json_data/'
filename_List = [
    'record_3',
]
sample_rate = 16000
frame_size = 512
num_channels = 4
bits_per_sample = 64  # 64-bit samples (the reader decodes them as float64)
bytes_per_sample = bits_per_sample // 8
# Size in bytes of one interleaved frame of frame_size samples per channel.
frame_byte_size = frame_size * num_channels * bytes_per_sample

# A single LoggedOnlineAecWpeIva instance is shared by all files and runs.
# It is built from num_channels rather than a literal so that the separator
# and the frame reshaping cannot silently disagree.
norms = reset_norms()
iva = LoggedOnlineAecWpeIva(num_channels, norms)
iva.enable_denoise()

def _write_int16_wav(path, frames, channels, rate):
    """Stack ``frames`` row-wise and write them to ``path`` as 16-bit PCM WAV."""
    scaled = np.vstack(frames) * 32767
    # Saturate instead of letting astype() wrap out-of-range samples around.
    int16_data = np.clip(scaled, -32768, 32767).astype(np.int16)
    with wave.open(path, 'wb') as wav_file:
        wav_file.setnchannels(channels)
        wav_file.setsampwidth(2)  # 2 bytes per int16 sample
        wav_file.setframerate(rate)
        wav_file.writeframes(int16_data.tobytes())


# Number of processed frames collected before a WAV part is written out.
frames_per_part = 50000

for file_ in filename_List:
    to_read_ = dir_file + file_ + ".pcm"
    print(to_read_)
    logger.debug(f"Обработка {to_read_}")

    try:
        # Fails early with FileNotFoundError when the input is missing.
        total_frames = os.path.getsize(to_read_) // frame_byte_size
        logger.debug(f"{file_}: {total_frames} complete frames expected")

        # Make sure the output locations exist, so that a missing directory
        # is not later reported as a missing input file.
        for out_dir in (dir_out, png_dir, json_dir):
            os.makedirs(out_dir, exist_ok=True)

        for run_idx in range(5):  # Replay the file five times through the same iva instance
            print(run_idx)
            # Start every run with empty norm records and a zeroed frame counter.
            iva.norms = reset_norms()
            iva.frame_idx = 0
            processed_frames_buffer = []
            part_idx = 0  # consecutive, 0-based index of the next WAV part
            frame_idx = 0

            with open(to_read_, 'rb') as pcm_file:
                while True:
                    frame = pcm_file.read(frame_byte_size)
                    if len(frame) < frame_byte_size:
                        # End of file, or a truncated trailing frame. Only the
                        # byte count is logged: decoding a partial buffer can
                        # raise when its length is not a whole number of rows.
                        logger.debug(
                            f"!!!!!!! {frame_idx + 1}: incomplete final frame, "
                            f"{len(frame)} of {frame_byte_size} bytes"
                        )
                        break
                    frame_reshape = np.frombuffer(frame, dtype=np.float64).reshape(-1, num_channels)
                    try:
                        processed_frames_buffer.append(iva.process(frame_reshape))
                        if len(processed_frames_buffer) == frames_per_part:
                            # Save the processed frames collected so far.
                            _write_int16_wav(
                                f'{dir_out}{file_}_run{run_idx + 1}_part{part_idx}_processed.wav',
                                processed_frames_buffer,
                                num_channels,
                                sample_rate,
                            )
                            part_idx += 1
                            processed_frames_buffer = []  # start a new part
                    except Exception as e:
                        logger.error(f"Ошибка при обработке фрейма {frame_idx + 1}: {e}")
                        raise
                    frame_idx += 1

            # Save whatever is left in the buffer as the last part.
            if processed_frames_buffer:
                _write_int16_wav(
                    f'{dir_out}{file_}_run{run_idx + 1}_part{part_idx}_processed.wav',
                    processed_frames_buffer,
                    num_channels,
                    sample_rate,
                )

            # Save the norms recorded during this run as JSON.
            json_filename = f'{json_dir}norms_{file_}_run{run_idx + 1}.json'
            save_norms_to_json(iva.norms, json_filename)

            # Save one plot per tracked quantity and sampling period.
            for interval in ['10000', '5000', '1000']:
                for key in iva.norms[interval]:
                    save_plot(
                        iva.norms[interval][key],
                        f'Norm of {key} (every {interval} frames) - Run {run_idx + 1}',
                        f'{png_dir}norm_{key}_{interval}_{file_}_run{run_idx + 1}.png',
                        int(interval),
                    )

        # Log the normalization time accumulated by the shared iva instance.
        logger.info(f"Общее время нормализации для {file_}: {iva.total_normalization_time} секунд")
    except FileNotFoundError:
        logger.error(f"Файл {to_read_} не найден.")
    except Exception as e:
        # logger.exception keeps the traceback that logger.error would drop.
        logger.exception(f"Произошла ошибка при чтении файла {to_read_}: {e}")
print(iva.total_normalization_time)

/home/pyatanin/VSproj/30.09.2024/august_records/record_3.pcm
0
1
2
3
4
1138.2960669994354
