# In [None]:
import pyaudio
import wave
import numpy as np
import torch
import socket
import torchaudio
import librosa


def minmaxscaler(data):
    """Rescale *data* linearly so its minimum maps to 0 and its maximum to 1.

    Works on any array-like object that provides ``min()``/``max()`` and
    element-wise arithmetic (e.g. NumPy arrays or torch tensors).

    Args:
        data: Non-empty array/tensor of numeric values.

    Returns:
        An array/tensor of the same shape with values in ``[0, 1]``. When every
        element of *data* is identical (zero range) the result is all zeros
        instead of NaN.
    """
    lowest = data.min()
    span = data.max() - lowest
    # A constant signal (e.g. pure silence) has zero range; dividing by it
    # would yield NaN everywhere. Dividing by 1 keeps the float result dtype
    # while returning zeros.
    if span == 0:
        span = 1
    return (data - lowest) / span


class STFT_Feature:
    """Holder for a torchaudio mel-spectrogram transform.

    Despite the class name, the stored transform is
    ``torchaudio.transforms.MelSpectrogram``; it is exposed as the
    ``transformation`` attribute.
    """

    def __init__(self, sample_rate=13000, n_fft=1024, hop_length=512, n_mel=64):
        """Build the mel-spectrogram transform.

        Args:
            sample_rate: Sample rate forwarded to ``MelSpectrogram``.
            n_fft: FFT size forwarded to ``MelSpectrogram``.
            hop_length: Hop length forwarded to ``MelSpectrogram``.
            n_mel: Number of mel bins, forwarded as ``n_mels``.
        """
        mel_options = {
            "sample_rate": sample_rate,
            "n_fft": n_fft,
            "hop_length": hop_length,
            "n_mels": n_mel,
        }
        self.transformation = torchaudio.transforms.MelSpectrogram(**mel_options)


class Soft_weights_predictor:
    """Predict soft class weights for a noise clip with a ShuffleNetV2 model.

    The classifier is ``Modified_ShufflenetV2`` with six output classes. Its
    weights are loaded from a saved state dict and the model is kept in eval
    mode. Inputs are min-max scaled, converted to a mel spectrogram and then
    to decibels before being passed to the model.
    """

    def __init__(self, MODEL_PATH):
        """Load the classifier weights and build the feature extractor.

        Args:
            MODEL_PATH: Location of the saved state dict, passed to
                ``torch.load``.
        """
        # Project-local model definition; imported here rather than at module top.
        from Simple_ShufflenetV2 import Modified_ShufflenetV2
        model = Modified_ShufflenetV2(num_classes=6)
        # The model is never moved off the CPU, so map the checkpoint to CPU.
        # Without this, a state dict saved from a GPU fails to load on a
        # CPU-only machine.
        model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
        model.eval()
        self.model = model
        self.feature_extractor = STFT_Feature(sample_rate=13000, n_fft=1024, hop_length=512, n_mel=64)

    def predict_weights(self, noise):
        """Run the model on *noise* and report the soft labels if they changed.

        Args:
            noise: Audio samples to classify. Presumably a torch waveform
                tensor such as the one returned by ``torchaudio.load`` —
                TODO confirm against the caller.

        Returns:
            Whatever ``Compare_now_pre`` returns for the new prediction: the
            rounded soft labels when they differ enough from the previously
            stored ones, otherwise ``None``.
        """
        noise = minmaxscaler(noise)
        noise = self.feature_extractor.transformation(noise)
        # librosa returns a NumPy array, hence the conversion back to a tensor.
        noise = librosa.core.power_to_db(noise)
        noise = torch.from_numpy(noise).unsqueeze(0)  # add a leading batch axis
        # Pure inference: skip autograd bookkeeping.
        with torch.no_grad():
            soft_labels_now = self.model(noise).squeeze()
        return Compare_now_pre(soft_labels_now.detach().numpy())


def Compare_now_pre(soft_labels_now, threshold=0.3):
    """Update the stored soft labels when the new ones differ enough.

    The relative change is measured as
    ``||now - previous||_2 / ||previous||_2`` and compared with *threshold*.
    The module-level ``soft_labels_pre`` holds the previously accepted
    (rounded) labels and is overwritten when an update is accepted.

    Args:
        soft_labels_now: NumPy array with the newly predicted soft labels.
        threshold: Minimum relative L2 change that triggers an update.

    Returns:
        The new labels rounded to one decimal place when an update is
        accepted, otherwise ``None``. The first call (no previous labels
        stored yet, or ``soft_labels_pre`` is ``None``) always accepts.
    """
    global soft_labels_pre
    try:
        previous = soft_labels_pre
    except NameError:
        # Nothing has been stored yet: treat the first prediction as a change.
        previous = None

    if previous is None:
        changed = True
    else:
        # L2 norm
        prev_norm = np.linalg.norm(previous)
        diff_norm = np.linalg.norm(soft_labels_now - previous)
        if prev_norm == 0:
            # Avoid 0/0: with an all-zero reference any non-zero difference is
            # an infinite relative change, and no difference is no change.
            changed = diff_norm > 0
        else:
            changed = diff_norm / prev_norm >= threshold  # Update threshold

    if not changed:
        return None
    soft_labels_pre = np.round(soft_labels_now, 1)
    return soft_labels_pre


def UDP_sender(message, ip, port):
    """Send *message* as a single UDP datagram to ``(ip, port)``.

    Sending is best-effort: any exception raised while encoding or sending
    is printed and swallowed. The socket is always closed afterwards.

    Args:
        message: Text to send; encoded with ``str.encode()``.
        ip: Destination host address.
        port: Destination UDP port.
    """
    destination = (ip, port)
    # The context manager closes the socket on every exit path.
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as udp_socket:
        try:
            payload = message.encode()
            udp_socket.sendto(payload, destination)
        except Exception as e:
            print(f"An error occurred: {e}")

        
class AudioRecorder:
    """Record one fixed-length clip from a PyAudio input stream to a WAV file.

    The input stream is opened in the constructor. ``record`` stops and
    closes the stream and terminates the PyAudio instance, so each recorder
    can record only once.
    """

    def __init__(self, seconds=1, chunk=1000, sample_format=pyaudio.paInt24, channels=1, fs=13000, input_device_index=1):
        """Open the PyAudio input stream.

        Args:
            seconds: Clip length used to compute how many chunks to read.
            chunk: Frames per buffer and frames requested per read.
            sample_format: PyAudio sample format constant.
            channels: Number of input channels.
            fs: Sample rate passed to PyAudio as ``rate``.
            input_device_index: Index of the PyAudio input device.
        """
        self.p = pyaudio.PyAudio()
        try:
            self.stream = self.p.open(format=sample_format, channels=channels, rate=fs, frames_per_buffer=chunk, input=True, input_device_index=input_device_index)
        except Exception:
            # Do not leak the PyAudio instance when the device cannot be opened.
            self.p.terminate()
            raise
        self.fs = fs
        self.chunk = chunk
        self.seconds = seconds
        self.channels = channels
        self.sample_format = sample_format

    def record(self, filename):
        """Record the clip, write it to *filename* and load it back.

        Reads ``int(fs / chunk * seconds)`` chunks from the stream, so the
        clip can be slightly shorter than ``seconds`` when that product is
        not a whole number. The stream and PyAudio instance are released
        even if reading fails.

        Args:
            filename: Path of the WAV file to write.

        Returns:
            The ``(waveform, sample_rate)`` pair from ``torchaudio.load``.
        """
        reads = int(self.fs / self.chunk * self.seconds)
        try:
            frames = [self.stream.read(self.chunk) for _ in range(reads)]
            # Query the sample width while the PyAudio instance is still alive.
            sample_width = self.p.get_sample_size(self.sample_format)
        finally:
            try:
                self.stream.stop_stream()
                self.stream.close()
            finally:
                self.p.terminate()
        # The context manager closes the file even if a write fails.
        with wave.open(filename, 'wb') as wf:
            wf.setnchannels(self.channels)
            wf.setsampwidth(sample_width)
            wf.setframerate(self.fs)
            wf.writeframes(b''.join(frames))
        waveform, sample_rate = torchaudio.load(filename)
        return waveform, sample_rate