In [None]:
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import librosa
import os
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F
from audio_process import *
from matplotlib import pyplot as plt
from torch import optim

# Sampling rate (Hz) passed to librosa.load and soundfile.write in the cells below.
samplerate = 16000
# data path
# SRP_audio_path = r'F:\audio\SRP_segmented\Voice'
# SRP_egg_path = r'F:\audio\SRP_segmented\EGG'
# VRP_audio_path = r'F:\audio\10s_segment\VRP_segmented\voice_test'
# VRP_egg_path = r'F:\audio\10s_segment\VRP_segmented\egg_test'
# Multi-channel recording: the loading cell reads channel 0 as the voice signal
# and channel 1 as the EGG signal.
# NOTE(review): absolute, machine-specific path — consider a configurable data directory.
VRP_path = r'F:\audio\test_VRP_F02\test_Voice_EGG.wav'

# device
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [None]:
def segment_audio(audio, sr, frame_length_ms=12, hop_length_samples=1):
    """Cut a 1-D signal into fixed-length, possibly overlapping frames.

    :param audio: 1-D sequence of samples (NumPy array or anything
        ``np.asarray`` can convert).
    :param sr: Sampling rate in Hz, used to convert ``frame_length_ms`` to samples.
    :param frame_length_ms: Frame length in milliseconds.
    :param hop_length_samples: Distance in samples between the starts of
        consecutive frames.
    :return: A new float64 array of shape ``(num_frames, frame_length_samples)``
        where ``num_frames = 1 + (len(audio) - frame_length_samples) // hop_length_samples``.
        When the signal is shorter than one frame, an empty array of shape
        ``(0, frame_length_samples)`` is returned instead of raising.
    :raises ValueError: If the frame length rounds down to zero samples, the hop
        is smaller than one sample, or ``audio`` is not one-dimensional.
    """
    # Calculate frame length in samples
    frame_length_samples = int(sr * frame_length_ms / 1000)
    if frame_length_samples < 1:
        raise ValueError("frame_length_ms is too short: the frame would contain no samples")
    if hop_length_samples < 1:
        raise ValueError("hop_length_samples must be at least 1")

    samples = np.asarray(audio, dtype=np.float64)
    if samples.ndim != 1:
        raise ValueError("audio must be a one-dimensional signal")

    # Not even one complete frame fits: return an empty (0, frame_length) array
    if samples.shape[0] < frame_length_samples:
        return np.zeros((0, frame_length_samples))

    # Every window of frame_length_samples (a zero-copy strided view); keeping
    # every hop-th window gives exactly the frames the index loop would produce.
    windows = np.lib.stride_tricks.sliding_window_view(samples, frame_length_samples)

    # Materialise a writable, contiguous copy so callers never alias the input
    return np.array(windows[::hop_length_samples], dtype=np.float64)

class AudioEGGDataset(Dataset):
    """Dataset of index-aligned (audio frame, EGG frame) pairs held in memory."""

    def __init__(self, audio_frames, egg_frames):
        """
        Initializes the dataset with pre-loaded data.
        :param audio_frames: A list or array of preprocessed and segmented audio frames.
        :param egg_frames: A list or array of preprocessed and segmented EGG frames,
            aligned index-by-index with ``audio_frames``.
        """
        assert len(audio_frames) == len(egg_frames), "Audio and EGG frames must be the same length"
        self.audio_frames = audio_frames
        self.egg_frames = egg_frames

    def __len__(self):
        """Number of frame pairs."""
        return len(self.audio_frames)

    def __getitem__(self, idx):
        """Return the ``idx``-th pair as two float32 tensors with a leading axis of size 1."""
        # torch.as_tensor accepts NumPy rows as well as plain Python sequences,
        # so both input kinds named in the constructor docstring are supported.
        audio_tensor = torch.as_tensor(self.audio_frames[idx], dtype=torch.float32).unsqueeze(0)
        egg_tensor = torch.as_tensor(self.egg_frames[idx], dtype=torch.float32).unsqueeze(0)

        return audio_tensor, egg_tensor

class WaveNet(nn.Module):
    """Stack of gated, dilated 1-D convolutions mapping a signal to a same-length signal.

    Six kernel-size-3 convolutions with dilations 1, 2, 4, 8, 16, 32 are applied
    in sequence. Each produces ``2 * dilation_channels`` feature maps that are
    split into a filter half and a gate half and combined as
    ``tanh(filter) * sigmoid(gate)``. A final 1x1 convolution reduces the result
    to one output channel. Padding is chosen so the time axis keeps its length.

    :param input_channels: Number of channels of the input signal.
    :param dilation_channels: Number of channels carried between gated layers.
    """

    def __init__(self, input_channels, dilation_channels):
        super(WaveNet, self).__init__()
        self.dilation_channels = dilation_channels
        self.dilated_convs = nn.ModuleList()

        kernel_size = 3
        dilations = [2**i for i in range(6)]

        # Each layer widens the receptive field by dilation * (kernel_size - 1)
        # samples. The first layer (dilation 1) counts as well, so the total is
        # 1 + 2 * (1 + 2 + 4 + 8 + 16 + 32) = 127.
        self.receptive_field_size = 1
        in_channels = input_channels
        for dilation in dilations:
            padding = dilation * (kernel_size - 1) // 2  # "same" padding for an odd kernel
            self.dilated_convs.append(
                nn.Conv1d(in_channels, 2 * dilation_channels, kernel_size=kernel_size,
                          padding=padding, dilation=dilation)
            )
            self.receptive_field_size += dilation * (kernel_size - 1)
            in_channels = dilation_channels  # later layers consume the gated output

        self.output_conv = nn.Conv1d(dilation_channels, 1, kernel_size=1)

    def forward(self, x):
        """Map ``x`` of shape (batch, input_channels, time) to (batch, 1, time)."""
        for conv in self.dilated_convs:
            # Split the channel axis into the filter half and the gate half
            filter_part, gate_part = torch.chunk(conv(x), 2, dim=1)
            x = torch.tanh(filter_part) * torch.sigmoid(gate_part)

        return self.output_conv(x)
    
class CosineSimilarityLoss(nn.Module):
    """Loss equal to ``1 - cosine similarity`` between each prediction and its target.

    Every sample is flattened to a single vector (all non-batch axes together)
    before the similarity is taken. For 2-D ``(batch, time)`` input this is the
    same as normalising over dim 1. For ``(batch, 1, time)`` input it compares
    the waveforms over time; normalising over the singleton channel axis would
    reduce every element to its sign and leave the loss with zero gradient.
    """

    def __init__(self):
        super(CosineSimilarityLoss, self).__init__()

    def forward(self, outputs, targets):
        """
        :param outputs: Predictions of shape (batch, ...).
        :param targets: Targets with the same shape as ``outputs``.
        :return: Scalar tensor, the batch mean of ``1 - cos(outputs_i, targets_i)``.
        """
        # One vector per sample, whatever the number of trailing axes
        outputs_flat = outputs.flatten(start_dim=1)
        targets_flat = targets.flatten(start_dim=1)

        # Normalize outputs and targets to unit vectors
        outputs_norm = F.normalize(outputs_flat, p=2, dim=1)
        targets_norm = F.normalize(targets_flat, p=2, dim=1)

        # Dot product of unit vectors is the cosine similarity
        cosine_loss = 1 - torch.sum(outputs_norm * targets_norm, dim=1).mean()
        return cosine_loss
    
class EarlyStopping:
    """Track validation loss, checkpoint the best model and signal when to stop.

    Call the instance once per epoch with the validation loss and the model.
    ``early_stop`` becomes True after ``patience`` consecutive calls without an
    improvement of more than ``delta``.

    :param patience: Number of non-improving calls tolerated before stopping.
    :param verbose: Print a message on every checkpoint and counter increase.
    :param delta: Minimum decrease of the loss that counts as an improvement.
    :param path: File the best ``state_dict`` is written to. Defaults to
        ``'checkpoint_model.pt'`` in the working directory.
    """

    def __init__(self, patience=5, verbose=False, delta=0, path='checkpoint_model.pt'):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = float('inf')
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model):
        """Record one epoch's validation loss and update the stopping state."""
        # Higher score is better, so the loss is negated
        score = -val_loss

        if self.best_score is None:
            # First call: whatever we have is the best so far
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score > self.best_score + self.delta:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0
        else:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        """Write ``model.state_dict()`` to ``self.path`` if ``val_loss`` beats the stored minimum."""
        if self.val_loss_min > val_loss:
            if self.verbose:
                print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model...')
            torch.save(model.state_dict(), self.path)
            self.val_loss_min = val_loss


In [None]:
# load data
# Read the recording resampled to `samplerate`, keeping its channels separate
# (mono=False) so voice and EGG can be taken from different channels.
wave, sr = librosa.load(VRP_path, sr=samplerate, mono=False) 
audio = wave[0]  # channel 0 is treated as the voice signal
egg = wave[1]    # channel 1 is treated as the EGG signal
# voice_preprocess / process_EGG_signal presumably come from the star-imported
# audio_process module; their behaviour is not visible in this notebook.
audio = voice_preprocess(audio, samplerate)
egg = process_EGG_signal(egg, samplerate)
# Peak-normalise each signal so its largest absolute value is 1.
# NOTE(review): an all-zero signal would divide by zero here.
audio = audio / np.max(np.abs(audio))
egg = egg / np.max(np.abs(egg))

# segment audio
# Uses segment_audio's defaults: 12 ms frames with a hop of one sample.
# NOTE(review): the dataset requires both frame arrays to have the same length;
# this assumes the two preprocessing helpers preserve signal length — confirm.
audio_frames = segment_audio(audio, samplerate)
egg_frames = segment_audio(egg, samplerate)

print(audio_frames.shape)


In [None]:
# Plot one audio frame together with the EGG frame at the same index
# Clamp the index so recordings with fewer frames do not raise IndexError
frame_idx = min(135248, len(audio_frames) - 1)
plt.figure(figsize=(12, 6))
plt.subplot(2, 1, 1)
plt.plot(audio_frames[frame_idx])
plt.title(f'Audio frame {frame_idx}')  # title reports the frame actually shown
plt.subplot(2, 1, 2)
plt.plot(egg_frames[frame_idx])
plt.title(f'EGG frame {frame_idx}')
plt.tight_layout()
plt.show()


In [None]:
dataset = AudioEGGDataset(audio_frames, egg_frames)
# Create train and validation and test sets
batch_size = 512  # Adjust as necessary
train_size = int(0.85 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size  # remainder, so the three sizes always sum to len(dataset)

# Split the dataset
# A seeded generator makes the split identical on every run, so validation and
# test numbers are comparable between runs.
# NOTE(review): the frames were cut with segment_audio's default hop of one
# sample, so neighbouring frames overlap almost entirely. A random split
# therefore puts near-duplicate frames in different subsets; a contiguous
# (time-ordered) split would give a stricter evaluation.
split_generator = torch.Generator().manual_seed(0)
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size, test_size], generator=split_generator
)

# Create data loaders
dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Plot one (audio, EGG) pair taken from the first training batch
# Dedicated names keep the full-length `audio` / `egg` signals loaded earlier
# from being overwritten by batch tensors.
audio_batch, egg_batch = next(iter(dataloader))
print(audio_batch.shape, egg_batch.shape)
# Clamp the index in case the batch holds fewer than 241 items
sample_idx = min(240, len(audio_batch) - 1)
plt.figure(figsize=(12, 6))
plt.subplot(2, 1, 1)
plt.plot(audio_batch[sample_idx].squeeze())
plt.title(f'Audio frame {sample_idx}')  # title reports the item actually shown
plt.subplot(2, 1, 2)
plt.plot(egg_batch[sample_idx].squeeze())
plt.title(f'EGG frame {sample_idx}')
plt.tight_layout()
plt.show()


In [None]:
# Instantiate the model
channels = 32  # You may need to tune this based on your dataset
model = WaveNet(input_channels=1, dilation_channels=channels)

# cuda
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr=0.001)
# Instantiate the Cosine Similarity Loss
criterion = CosineSimilarityLoss()
# Multiply the learning rate by 0.1 every 10 epochs
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

early_stopping = EarlyStopping(patience=10, verbose=True)

# Periodic checkpoints go here; create the folder up front so torch.save
# does not fail on a fresh working directory.
checkpoint_dir = 'chkpt'
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(100):  # Adjust the number of epochs based on your needs
    # Training phase
    model.train()
    running_loss = 0.0
    for i, data in enumerate(dataloader):
        audio, egg = data
        audio = audio.to(device)
        egg = egg.to(device)

        optimizer.zero_grad()
        output = model(audio)
        loss = criterion(output, egg)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        if i % 10 == 0:  # Log every 10 batches
            print(f'Epoch {epoch}, Iteration {i}, Loss: {loss.item()}')

    average_loss = running_loss / len(dataloader)
    print(f'Epoch {epoch}, Average Training Loss: {average_loss}')

    # Validation phase
    model.eval()
    val_running_loss = 0.0
    with torch.no_grad():
        for data in val_dataloader:
            audio, egg = data
            audio = audio.to(device)
            egg = egg.to(device)
            output = model(audio)
            loss = criterion(output, egg)
            val_running_loss += loss.item()

    val_loss = val_running_loss / len(val_dataloader)
    print(f'Epoch {epoch}, Validation Loss: {val_loss}')

    # Early stopping and saving best model based on validation loss.
    # The best weights are written to the EarlyStopping checkpoint file; the
    # in-memory `model` keeps the weights of the last epoch that ran.
    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print("Early stopping")
        break

    # StepLR follows a fixed epoch schedule and must be stepped without
    # arguments. Its only optional argument is an epoch index, so passing the
    # validation loss there would pin the schedule near epoch 0 and the
    # learning rate would never decay.
    scheduler.step()

    # Save checkpoint
    if epoch % 10 == 0:  # Save every 10 epochs in chkpt folder
        checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_{epoch}.pt')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': average_loss,
            'val_loss': val_loss,
        }, checkpoint_path)


In [None]:
import matplotlib.pyplot as plt

# Evaluate on the held-out test set and report the mean loss over its batches.
# `egg` and `output` keep the values of the last batch; the plotting and export
# cells further down read them.
model.eval()
test_running_loss = 0.0
with torch.no_grad():
    for data in test_dataloader:
        audio, egg = data
        audio = audio.to(device)
        egg = egg.to(device)
        output = model(audio)
        test_running_loss += criterion(output, egg).item()

# max(..., 1) guards against an empty test loader
test_loss = test_running_loss / max(len(test_dataloader), 1)
print(f'Test Loss: {test_loss}')

In [None]:
# save the model
# Writes the current parameters of `model` (its state_dict) to 'WaveNet.pth',
# a path relative to the working directory.
torch.save(model.state_dict(), 'WaveNet.pth')


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import savgol_filter
# Move `egg` and `output` to the CPU so they can be plotted and converted to NumPy.
# Assumes they still hold the last batch from the test-loop cell — run that cell first.
egg = egg.cpu()
output = output.cpu()


# Plot the item at index 1 of the batch (i.e. the second sample): target vs. prediction
# NOTE(review): fails with IndexError if that batch holds a single item.
plt.figure(figsize=(12, 6))
plt.plot(egg[1].squeeze(), label='Ground Truth')
plt.plot(output[1].squeeze(), label='Prediction')
plt.legend()

plt.show()

In [None]:
import soundfile as sf
# save the first predicted EGG as wav
output_wav = output[0]  # item 0 of the batch currently held in `output`
# save wav in the current folder using soundfile
# .numpy() requires a CPU tensor; `output` was moved to the CPU in the plotting cell above.

sf.write('output.wav', output_wav.squeeze().numpy(), samplerate)

# save the first ground truth EGG as wav
egg_wav = egg[0]  # item 0 of the matching target batch
sf.write('egg.wav', egg_wav.squeeze().numpy(), samplerate)


In [None]:
import librosa
import matplotlib.pyplot as plt
from audio_process import *
from scipy.signal import firwin, filtfilt, medfilt
import numpy as np


# load wav in dir named output_wav
# Both files are read at `samplerate`. This rebinds `output_wav` and `egg_wav`,
# which the export cell above had bound to single-frame tensors.
# NOTE(review): these file names differ from the 'output.wav' / 'egg.wav'
# written by the export cell; they are presumably produced elsewhere — confirm.
output_wav, sr = librosa.load('output_wav/WaveNet_output.wav', sr=samplerate)
egg_wav, sr = librosa.load('output_wav/WaveNet_egg_FFT_filtered.wav', sr=samplerate)
# smooth egg_wav
def smooth_signal(signal, sample_rate=1000, cutoff_hz=50, numtaps=101):
    """Smooth ``signal`` with a zero-phase, Hamming-windowed low-pass FIR filter.

    :param signal: 1-D array of samples.
    :param sample_rate: Sampling rate of ``signal`` in Hz.
    :param cutoff_hz: Cut-off frequency of the low-pass filter in Hz.
    :param numtaps: Number of FIR coefficients.
    :return: Filtered signal with the same length as the input.
    """
    # Low-pass taps (pass_zero=True keeps the band that includes 0 Hz)
    taps = firwin(numtaps, cutoff_hz, window='hamming', pass_zero=True, fs=sample_rate)

    # Forward-backward filtering cancels the phase delay of the FIR filter
    return filtfilt(taps, 1.0, signal)

def adjust_lower_part(prediction, ground_truth, threshold=-0.1):
    """Shift ``prediction`` so its low part matches ``ground_truth`` on average.

    The offset is the mean of ``ground_truth - prediction`` over the samples
    where ``prediction`` is below ``threshold``, and is added to every sample.

    :param prediction: Array of predicted samples.
    :param ground_truth: Array of reference samples, aligned with ``prediction``.
    :param threshold: Samples of ``prediction`` below this value define the low part.
    :return: New array ``prediction + offset``. When no sample falls below the
        threshold the offset is 0, so the values come back unchanged rather
        than turning into NaN.
    """
    # Find the indices where the prediction signal is below the threshold
    below = prediction < threshold

    # The mean of an empty selection is NaN and would poison the whole signal,
    # so fall back to a zero offset when nothing is below the threshold.
    if np.any(below):
        offset = np.mean(ground_truth[below] - prediction[below])
    else:
        offset = 0.0

    # Apply the offset to the prediction signal
    return prediction + offset

def low_pass_filter(signal, sample_rate=samplerate, cutoff_hz=50, numtaps=4):
    """Apply a zero-phase, Hamming-windowed low-pass FIR filter to ``signal``.

    :param signal: 1-D array of samples.
    :param sample_rate: Sampling rate in Hz; the default is the value the
        module-level ``samplerate`` has when this function is defined.
    :param cutoff_hz: Cut-off frequency of the low-pass filter in Hz.
    :param numtaps: Number of FIR coefficients.
    :return: Filtered signal with the same length as the input.
    """
    # Low-pass taps (pass_zero=True keeps the band that includes 0 Hz)
    taps = firwin(numtaps, cutoff_hz, window='hamming', pass_zero=True, fs=sample_rate)

    # Forward-backward filtering cancels the phase delay of the FIR filter
    return filtfilt(taps, 1.0, signal)

# Low-pass the reloaded prediction with low_pass_filter's defaults
# (50 Hz cut-off, 4 taps). The ground-truth signal is left unfiltered.
output_wav = low_pass_filter(output_wav)

# plot output and ground truth
# Shows a 200-sample window starting at sample 22000 of both signals.
plt.figure(figsize=(12, 6))
plt.plot(output_wav[22000:22200], label='Prediction')
plt.plot(egg_wav[22000:22200], label='Ground Truth')
plt.legend()
plt.show()


In [None]:
# find the cycles in the output and ground truth
def find_cycles(signal, threshold=0.1):
    """Group the samples of ``signal`` that exceed ``threshold`` into contiguous runs.

    :param signal: 1-D signal. Arrays with extra singleton axes, such as a
        ``(1, T)`` frame, are squeezed to 1-D first so the returned indices
        always refer to the time axis.
    :param threshold: Samples strictly above this value belong to a run.
    :return: List of integer index arrays, one per run of consecutive
        above-threshold samples. Empty list when no sample exceeds the threshold.
    :raises ValueError: If ``signal`` is not one-dimensional after squeezing.
    """
    samples = np.atleast_1d(np.squeeze(np.asarray(signal)))
    if samples.ndim != 1:
        raise ValueError("signal must be one-dimensional (singleton axes are allowed)")

    # Find the indices where the signal is above the threshold
    indices = np.flatnonzero(samples > threshold)
    if indices.size == 0:
        return []

    # A gap larger than one sample between consecutive indices starts a new run
    run_starts = np.flatnonzero(np.diff(indices) > 1) + 1

    # Split the indices into cycles
    return np.split(indices, run_starts)

# Item 0 of each batch is converted to a plain 1-D array before the search, so
# the indices returned by find_cycles refer to sample positions in time.
# Assumes `output` / `egg` are the CPU tensors left by the plotting cell, where
# each item carries a leading channel axis of size 1 — TODO confirm.
output_cycles = find_cycles(np.asarray(output[0]).squeeze())
egg_cycles = find_cycles(np.asarray(egg[0]).squeeze())

def find_qci(EGG):
    """Return the area under the unit-normalised EGG cycle.

    :param EGG: Samples of one EGG cycle.
    :return: Trapezoidal integral of the normalised amplitude over the
        normalised time axis, both produced by ``unit_EGG``.
    """
    normalized_time, normalized_amplitude = unit_EGG(EGG)
    # NumPy 2.0 renamed trapz to trapezoid and deprecated the old name;
    # prefer the new function and fall back on older NumPy versions.
    integrate = getattr(np, 'trapezoid', None) or np.trapz
    # qci is the area under the curve of the unit EGG signal
    qci = integrate(normalized_amplitude, normalized_time)
    return qci

def unit_EGG(EGG):
    '''
    Rescale one EGG cycle to unit amplitude and unit duration.

    Input: each cycle of EGG signal (1-D, non-empty)
    Output: tuple (normalized_time, normalized_amplitude) for computing qci.
        normalized_time runs from 0 (inclusive) to 1 (exclusive) with one
        point per sample. normalized_amplitude is the cycle shifted to a
        minimum of 0 and divided by its maximum. A flat cycle has no
        amplitude range, so its normalized amplitude is all NaN.
    '''
    cycle = np.asarray(EGG, dtype=float)
    EGG_shifted = cycle - np.min(cycle)

    # Normalize the amplitude to have a maximum of 1
    peak = np.max(EGG_shifted)
    if peak > 0:
        normalized_amplitude = EGG_shifted / peak
    else:
        # Constant (or NaN) cycle: make the undefined 0/0 result explicit
        # instead of relying on a division that emits a RuntimeWarning.
        normalized_amplitude = np.full(EGG_shifted.shape, np.nan)

    # Normalize the time axis
    num_samples = len(cycle)
    normalized_time = np.linspace(0, 1, num_samples, endpoint=False)

    return normalized_time, normalized_amplitude

def find_dEGGmax(signal):
    """Normalised peak sample-to-sample change (QD) of one EGG cycle.

    The cycle is quantised to 16-bit integer levels and QD is computed as
    ``2 * dmax / (Ap_p * sin(2 * pi / T))`` where ``dmax`` is the largest
    absolute difference between consecutive samples, ``Ap_p`` the peak-to-peak
    amplitude and ``T`` the number of samples. A sinusoid spanning exactly one
    period yields a value close to 1.

    :param signal: Samples of one cycle, nominally within [-1, 1].
    :return: QD as a float, or NaN when the cycle has fewer than 3 samples or
        is constant.
    """
    samples = np.asarray(signal, dtype=np.float64)

    # Quantise to 16-bit levels. Values outside [-1, 1] are clipped rather than
    # wrapped around by the integer cast.
    scaled_signal = np.clip(np.round(samples * 32767), -32768, 32767)
    # Widen before any arithmetic: differences of int16 values can reach
    # +/-65535 and would silently overflow in 16-bit arithmetic.
    rounded_signal = scaled_signal.astype(np.int16).astype(np.int64)

    # Check if the signal has sufficient length. With T == 2 the factor
    # sin(2*pi/T) is sin(pi), zero up to rounding, so QD is undefined.
    if len(rounded_signal) < 3:
        return np.nan  # or another appropriate value indicating an issue

    # Largest absolute sample-to-sample difference over the period
    dmax = np.max(np.abs(np.diff(rounded_signal)))

    # Peak-to-peak amplitude in quantised units
    Ap_p = np.max(rounded_signal) - np.min(rounded_signal)

    # Check if Ap_p is non-zero to avoid division by zero
    if Ap_p == 0:
        return np.nan  # or another appropriate value indicating an issue

    # Calculate QD
    period_length_T = len(rounded_signal)
    QD = 2 * dmax / (Ap_p * np.sin(2 * np.pi / period_length_T))

    return QD

# in each cycle, calculate qci and dEGGmax
# egg_cycles holds sample indices into item 0 of the batch, so the metrics are
# taken from that same item as a 1-D array. Indexing the whole batch tensor
# with these indices would select batch entries instead of samples.
# Assumes `output` / `egg` are CPU tensors whose items carry a leading channel
# axis of size 1 — TODO confirm.
pred_frame = np.asarray(output[0]).squeeze()
egg_frame = np.asarray(egg[0]).squeeze()

# The EGG-derived segments are applied to both signals, so the two metric
# series are aligned entry by entry (output_cycles is not used here).
output_qci = []
output_dEGGmax = []
for cycle in egg_cycles:
    output_qci.append(find_qci(pred_frame[cycle]))
    output_dEGGmax.append(find_dEGGmax(pred_frame[cycle]))

egg_qci = []
egg_dEGGmax = []
for cycle in egg_cycles:
    egg_qci.append(find_qci(egg_frame[cycle]))
    egg_dEGGmax.append(find_dEGGmax(egg_frame[cycle]))

# plot qci and dEGGmax
plt.figure(figsize=(12, 6))
plt.subplot(2, 1, 1)
plt.plot(output_qci, label='Prediction QCI')
plt.plot(egg_qci, label='Ground Truth QCI')
plt.legend()
plt.subplot(2, 1, 2)
plt.plot(output_dEGGmax, label='Prediction dEGGmax')
plt.plot(egg_dEGGmax, label='Ground Truth dEGGmax')
plt.legend()
plt.tight_layout()
plt.show()



In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Read the exported true/predicted EGG samples
data = pd.read_csv('EGG_Signal_Data.csv')

# Pull out the three columns used for the figure
sample_number, true_egg, predicted_egg = (
    data['SampleNumber'],
    data['TrueEGG'],
    data['PredictedEGG'],
)

# Overlay both signals on one wide axis
fig, ax = plt.subplots(figsize=(12, 3))
for series, series_label, series_color in (
    (true_egg, 'True EGG', 'green'),
    (predicted_egg, 'Predicted EGG', 'red'),
):
    ax.plot(sample_number, series, label=series_label, color=series_color, alpha=0.6)
ax.set_xlabel('Sample Number')
ax.set_ylabel('Amplitude')
ax.set_title('EGG')
ax.legend()
ax.grid(True)
plt.show()
