In [1]:
from pathlib import Path

In [2]:
# Directory holding the piano recordings, resolved relative to the notebook's working directory.
# NOTE(review): hard-coded relative path; consider promoting it to a DATA_DIR config constant.
path = Path("piano_music/")

In [3]:
import os

In [4]:
# Keep only the .mp3 entries, in a deterministic order. os.listdir returns every
# directory entry in an arbitrary, platform-dependent order, so without this
# music_files[0] (loaded below) could be a non-audio file or differ between runs.
music_files = sorted(name for name in os.listdir(path) if name.lower().endswith(".mp3"))
len(music_files)

351

In [5]:
import librosa

In [6]:
# Decode the first recording, resampled by librosa to `sample_rate` Hz (16 kHz).
# The rate librosa reports back is discarded.
sample_rate = 16000
waveform, _ = librosa.load(path=path / music_files[0], sr=sample_rate)

In [7]:
import torch

In [8]:
# Number of samples in the decoded recording, viewed as a torch tensor.
torch.as_tensor(waveform).size()

torch.Size([4372907])

In [9]:
from torch.utils.data import Dataset, DataLoader
from torch.nn.functional import pad
from pathlib import Path
import librosa
import os

class AudioMP3Dataset(Dataset):
    """Map-style dataset over the ``.mp3`` files directly inside one directory.

    Each item is decoded with ``librosa.load`` (resampled to ``sample_rate``),
    converted to a ``torch.Tensor``, optionally transformed, and returned
    together with the sample rate reported by ``librosa.load``.

    Args:
        root_dir: Directory containing the audio files. Only its direct entries
            whose name ends in ``.mp3`` (case-insensitive) are used, sorted by
            name so that index -> file is reproducible.
        transforms: Optional callable applied to the waveform tensor.
        max_len: Stored as ``self.max_len`` but not used by this class; any
            cropping/chunking has to be done elsewhere (e.g. in a collate_fn).
        sample_rate: Target sampling rate passed to ``librosa.load``.
    """

    def __init__(self, root_dir, transforms=None, max_len=50000, sample_rate=16000):
        self.root_dir = Path(root_dir)
        self.max_len = max_len
        self.sample_rate = sample_rate
        self.transforms = transforms
        # Match the extension (a substring test would also accept e.g. "x.mp3.part")
        # and sort, because os.listdir order is arbitrary and would make indices
        # refer to different files across runs/platforms.
        self.music_files = sorted(
            name for name in os.listdir(self.root_dir)
            if name.lower().endswith(".mp3")
        )

    def __len__(self):
        """Number of ``.mp3`` files found in ``root_dir``."""
        return len(self.music_files)

    def __getitem__(self, i):
        """Load, resample and (optionally) transform the ``i``-th file.

        Args:
            i: Integer index. Python ints, NumPy integers and integer scalar
                tensors are accepted; negative indices count from the end.

        Returns:
            ``(waveform, sample_rate)``: ``waveform`` is a ``torch.Tensor`` built
            from the array returned by ``librosa.load`` (then passed through
            ``transforms`` if given); ``sample_rate`` is the rate returned by
            ``librosa.load``.

        Raises:
            TypeError: if ``i`` is not an integer.
            IndexError: if ``i`` is out of range.
        """
        import operator

        # operator.index accepts any integer-like index (unlike `type(i) == int`)
        # and, unlike an assert, is not stripped when Python runs with -O.
        i = operator.index(i)

        # Resolve against this dataset's own directory, not a notebook-global path.
        waveform, sample_rate = librosa.load(self.root_dir / self.music_files[i], sr=self.sample_rate)
        waveform = torch.tensor(waveform)

        if self.transforms is not None:
            waveform = self.transforms(waveform)

        return waveform, sample_rate

In [10]:
# One item per .mp3 file under `path`, decoded at the dataset's default sample rate.
dataset = AudioMP3Dataset(root_dir=path)

In [11]:
def collate_fn(batch, max_len=50000):
    """Cut every waveform in ``batch`` into fixed-length chunks and stack them.

    Each waveform is split along its last axis into consecutive, non-overlapping
    windows of ``max_len`` samples; a window shorter than ``max_len`` (only ever
    the last one of a waveform) is right-padded with zeros. Chunks from all items
    are concatenated in batch order, so the output's first dimension is the total
    number of chunks, not the batch size. Zero-length waveforms contribute no chunks.

    Args:
        batch: Iterable of ``(waveform, sample_rate)`` pairs; ``sample_rate`` is
            ignored. ``waveform`` is expected to be a 1-D tensor of samples.
        max_len: Chunk length in samples; must be positive.

    Returns:
        Tensor of shape ``(num_chunks, 1, max_len)``. ``torch.stack`` fails if the
        batch contains no samples at all.

    Raises:
        ValueError: if ``max_len`` is not positive.
    """
    if max_len <= 0:
        raise ValueError(f"max_len must be positive, got {max_len}")

    chunks = []
    for waveform, _ in batch:
        for start in range(0, waveform.size(-1), max_len):
            chunk = waveform[..., start:start + max_len]
            # Pad each short window as it is produced, so an empty waveform can never
            # cause a previously stored chunk (or a missing one) to be touched.
            if chunk.size(-1) < max_len:
                chunk = pad(chunk, (0, max_len - chunk.size(-1)))
            chunks.append(chunk)

    # Add the channel axis expected by a 1-channel Conv1d-style model: (N, 1, L).
    return torch.stack(chunks).unsqueeze(1)

In [12]:
# One file per batch; collate_fn turns it into all of that file's fixed-length chunks,
# so the leading dimension of each batch varies with the file's duration.
dl = DataLoader(dataset, collate_fn=collate_fn, batch_size=1)

In [13]:
from model import WaveNet

In [14]:
# Small WaveNet instance for a smoke test.
# NOTE(review): WaveNet's signature is defined in `model`, not here; TODO confirm what
# each positional `2` configures and pass them by keyword / config constants.
model = WaveNet(2, 2, 2)

In [15]:
# Take a single batch for a smoke test. next(iter(...)) raises StopIteration on an
# empty loader instead of silently leaving `inputs` undefined (or stale from an
# earlier run) as the for/break idiom does.
inputs = next(iter(dl))

In [16]:
# First 5000 samples of every chunk, keeping the channel axis by slicing it with :1.
inputs[:, :1, :5000].shape

torch.Size([88, 1, 5000])

In [18]:
# Forward pass on the first 5000 samples of every chunk (channel axis kept via :1).
model_input = inputs[:, :1, :5000]
out = model(model_input)
out.shape

torch.Size([88, 256, 5000])


torch.Size([88, 256, 1])

In [23]:
# FIXME(review): the saved output (88, 1) only matches an `out` of shape (88, 256, 1),
# which was produced by a since-deleted cell (orphan output after In [18]; execution
# count jumps to In [23]). On a fresh kernel `out` is (88, 256, 5000) and the original
# argmax gives (88, 5000). Selecting the last time step explicitly removes that hidden
# dependency and is a no-op if `out` already has one time step.
# Assumes the deleted cell kept the final step — confirm.
pred = torch.argmax(out[:, :, -1:], dim=1)
pred.shape

torch.Size([88, 1])

In [24]:
from preprocess import decodeMuLaw

In [26]:
# Decode the predicted class indices with decodeMuLaw and show the first five rows.
# NOTE(review): presumably decodeMuLaw maps mu-law class indices back to sample values
# in [-1, 1]; confirm against its definition in `preprocess`.
decodeMuLaw(pred)[:5]

tensor([[-0.6602],
        [-0.0077],
        [-0.6602],
        [-0.6602],
        [-0.0077]])