In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import librosa
import os

from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
samplerate = 16000

# Data locations. The defaults are the original author's local Windows paths;
# set the AUDIO_DIR / EGG_DIR environment variables to run on another machine.
audio_path = os.environ.get("AUDIO_DIR", r'F:\Audio\Audio')
egg_path = os.environ.get("EGG_DIR", r'F:\Audio\EGG')
# os.listdir order is arbitrary; sort so the listings are deterministic.
audio_list = sorted(os.listdir(audio_path))
egg_list = sorted(os.listdir(egg_path))

class AudioEGGDataset(Dataset):
    """Dataset of paired audio / EGG recordings stored in two directories.

    Files are paired by position after sorting each directory listing by
    filename, so the i-th audio file (in name order) is matched with the i-th
    EGG file. Both signals are loaded with librosa at the notebook-wide
    ``samplerate``, padded/truncated to ``max_length`` samples, optionally
    transformed, and returned as float32 tensors of shape ``(1, max_length)``.

    Args:
        audio_path: Directory containing the audio recordings.
        egg_path: Directory containing the matching EGG recordings.
        transform: Optional callable applied separately to the audio and EGG
            numpy arrays after length fixing.
        max_length: Number of samples each signal is padded/truncated to
            (default 160000, i.e. 10 s at 16 kHz).
    """

    def __init__(self, audio_path, egg_path, transform=None, max_length=160000):
        self.audio_path = audio_path
        self.egg_path = egg_path
        # os.listdir returns entries in arbitrary, filesystem-dependent order.
        # Sorting makes the index-based audio/EGG pairing deterministic.
        self.audio_list = sorted(os.listdir(audio_path))
        self.egg_list = sorted(os.listdir(egg_path))
        self.transform = transform
        self.max_length = max_length

    def __len__(self):
        return len(self.audio_list)

    def __getitem__(self, idx):
        """Load the idx-th (audio, egg) pair.

        Returns:
            ``(audio, egg)`` tensors of shape ``(1, max_length)``, or ``None``
            if anything fails (the error is printed). ``None`` samples cannot
            be batched by the default collate function; use a ``collate_fn``
            that filters them out.
        """
        # Pre-bind both names so the error message below cannot itself raise
        # UnboundLocalError when the failure happens while resolving a name
        # (e.g. IndexError because the EGG directory has fewer files).
        audio_file = egg_file = f"<index {idx}>"
        try:
            audio_file = os.path.join(self.audio_path, self.audio_list[idx])
            egg_file = os.path.join(self.egg_path, self.egg_list[idx])

            # Both signals are resampled to the module-level `samplerate`.
            audio, _ = librosa.load(audio_file, sr=samplerate)
            egg, _ = librosa.load(egg_file, sr=samplerate)

            # Pad or truncate so every sample has the same length and can be batched.
            audio = librosa.util.fix_length(audio, size=self.max_length)
            egg = librosa.util.fix_length(egg, size=self.max_length)

            if self.transform:
                audio = self.transform(audio)
                egg = self.transform(egg)

            # Convert to float tensors and add a leading channel dimension.
            audio = torch.from_numpy(audio).float().unsqueeze(0)
            egg = torch.from_numpy(egg).float().unsqueeze(0)

        except Exception as e:
            print(f"Error loading {audio_file} and {egg_file}: {e}")
            return None

        return audio, egg

from torch.utils.data.dataloader import default_collate


def _collate_skip_none(batch):
    """Collate a list of dataset samples, dropping ``None`` entries.

    ``AudioEGGDataset.__getitem__`` returns ``None`` when a file fails to load,
    and ``default_collate`` cannot batch ``None``. Failed samples are therefore
    removed first. Returns ``None`` if every sample in the batch failed.
    """
    samples = [sample for sample in batch if sample is not None]
    if not samples:
        return None
    return default_collate(samples)


dataset = AudioEGGDataset(audio_path, egg_path)
# Create DataLoader
batch_size = 2  # Adjust as necessary
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0,
                        collate_fn=_collate_skip_none)


In [3]:


class WaveNet(nn.Module):
    """Stack of causal, dilated, gated 1-D convolutions (WaveNet-style).

    Each layer applies a dilated ``Conv1d`` producing ``2 * dilation_channels``
    channels. These are split into a filter half and a gate half, combined as
    ``tanh(filter) * sigmoid(gate)``. A final 1x1 convolution maps the result
    to a single output channel. The output has the same length as the input,
    and output sample ``t`` depends only on input samples ``<= t``.

    Args:
        input_channels: Number of channels of the input signal.
        dilation_channels: Number of channels produced by each gated layer.
        num_layers: Number of dilated layers; dilations are
            ``1, 2, 4, ..., 2**(num_layers - 1)`` (default 6).
        kernel_size: Kernel size of every dilated convolution (default 2).
    """

    def __init__(self, input_channels, dilation_channels, num_layers=6, kernel_size=2):
        super().__init__()
        self.dilation_channels = dilation_channels
        self.dilated_convs = nn.ModuleList()

        dilations = [2**i for i in range(num_layers)]
        # Number of input samples (including the current one) that can
        # influence a single output sample.
        self.receptive_field_size = 1 + (kernel_size - 1) * sum(dilations)
        for dilation in dilations:
            # Conv1d pads `padding` on BOTH sides, so its output is
            # `padding` samples longer than its input. forward() trims the
            # tail, which yields a causal convolution with unchanged length.
            padding = dilation * (kernel_size - 1)
            self.dilated_convs.append(
                nn.Conv1d(input_channels, 2 * dilation_channels,
                          kernel_size=kernel_size, padding=padding, dilation=dilation)
            )
            # After the first layer, inputs come from the previous gated layer.
            input_channels = dilation_channels

        self.output_conv = nn.Conv1d(dilation_channels, 1, kernel_size=1)

    def forward(self, x):
        """Map ``x`` of shape ``(batch, input_channels, length)`` to ``(batch, 1, length)``."""
        length = x.size(-1)
        for conv in self.dilated_convs:
            # Drop the extra right-hand samples introduced by symmetric padding.
            out = conv(x)[..., :length]
            # Split channels into filter and gate halves (dim 1 = channels).
            filt, gate = out.chunk(2, dim=1)
            x = torch.tanh(filt) * torch.sigmoid(gate)

        return self.output_conv(x)

# Instantiate the model
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Width of each gated layer; tune for the dataset at hand.
channels = 32
model = WaveNet(input_channels=1, dilation_channels=channels).to(device)



In [5]:
from torch import optim

# Training hyper-parameters.
num_epochs = 50  # Adjust the number of epochs based on your needs
learning_rate = 0.001
log_every = 10  # print the loss every `log_every` iterations

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Put the model in training mode explicitly, in case another cell called eval().
model.train()

# Training loop
for epoch in range(num_epochs):
    for i, data in enumerate(dataloader):
        # A collate_fn that drops failed samples may return None for a batch
        # in which every sample failed to load; there is nothing to train on.
        if data is None:
            continue
        audio, egg = data
        audio = audio.to(device)
        egg = egg.to(device)

        optimizer.zero_grad()
        output = model(audio)
        # MSELoss would otherwise try to broadcast mismatched shapes and fail
        # with an opaque error; report the actual shapes instead.
        if output.shape != egg.shape:
            raise RuntimeError(
                f"Model output shape {tuple(output.shape)} does not match "
                f"target shape {tuple(egg.shape)}"
            )
        loss = criterion(output, egg)
        loss.backward()
        optimizer.step()

        if i % log_every == 0:
            print(f'Epoch {epoch}, Iteration {i}, Loss: {loss.item()}')