# WaveNet - Fit a Sample

In [35]:
import sys
# Make the project's `models` package (imported in a later cell as
# `models.wavenet...`) resolvable by putting ../../network/ on the search path.
sys.path += ['../../network/']

In [36]:
import os
import torch
from types import SimpleNamespace
# Release unoccupied GPU memory held by PyTorch's caching allocator; mainly
# useful when re-running this notebook in a kernel that already trained a model.
torch.cuda.empty_cache()

In [37]:
from models.wavenet.model import WaveNet
from models.wavenet.utils.data import DataLoader

In [38]:
# Training configuration. Unpacked verbatim into `Trainer(**vars(params))`, so
# every key must match a `Trainer.__init__` keyword. `data_dir` and
# `model_name` are attached in a later cell.
params = SimpleNamespace(**{
    'layer_size': 10,
    'stack_size': 5,
    'in_channels': 256,
    'res_channels': 512,
    'lr': 2e-3,
    'sample_size': 10_000,
    'sample_rate': 22_050,
    'epochs': 100_000,  # number of training steps (batches), not full data passes
    'model_dir': '../../network/weights/wavenet/',
})

In [42]:
class Trainer:
    """Fit a WaveNet to the audio in ``data_dir`` for a fixed number of steps."""

    def __init__(self,
                 layer_size: int = 10,
                 stack_size: int = 5,
                 in_channels: int = 256,
                 res_channels: int = 512,
                 lr: float = 2e-3,
                 sample_size: int = 100_000,
                 sample_rate: int = 22_050,
                 epochs: int = 10_000,
                 data_dir: str = '.',
                 model_dir: str = './',
                 model_name: str = None,
                 checkpoint_every: int = 2_500,
                 log_every: int = 1):
        """Build the WaveNet model and the data loader that feeds it.

        Args:
            layer_size, stack_size, in_channels, res_channels: Forwarded
                positionally, in this order, to ``WaveNet``.
            lr: Passed to ``WaveNet`` as ``lr=``.
            sample_size, sample_rate: Forwarded to ``DataLoader`` together with
                ``data_dir``, the model's ``receptive_fields`` and ``in_channels``.
            epochs: Total number of training *steps* performed by ``run``.
                Despite the name, one unit is a single ``wavenet.train`` call on
                one ``(inputs, targets)`` batch, not a full pass over the data.
            data_dir: Directory handed to ``DataLoader``.
            model_dir, model_name: Passed to ``wavenet.save`` at each checkpoint.
            checkpoint_every: Save every this many steps. The final step is
                always saved. A falsy value (0/None) disables periodic saves.
            log_every: Print the loss every this many steps. The final step is
                always printed. A falsy value (0/None) disables periodic prints.
        """
        self.epochs = epochs
        self.model_dir = model_dir
        self.model_name = model_name
        self.checkpoint_every = checkpoint_every
        self.log_every = log_every

        self.wavenet = WaveNet(layer_size, stack_size, in_channels, res_channels, lr=lr)

        # The loader needs the network's receptive field to size its windows.
        self.data_loader = DataLoader(data_dir, self.wavenet.receptive_fields,
                                      sample_size, sample_rate, in_channels)

    def infinite_batch(self):
        """Yield ``(inputs, targets)`` pairs forever by re-iterating the loader.

        Each item of ``self.data_loader`` is itself iterated for
        ``(inputs, targets)`` pairs. When the loader is exhausted, iteration
        restarts from the beginning.

        Raises:
            RuntimeError: if a complete pass over the loader produced no
                batches. Without this check the ``while True`` loop would spin
                forever without yielding.
        """
        while True:
            produced_any = False
            for dataset in self.data_loader:
                for inputs, targets in dataset:
                    produced_any = True
                    yield inputs, targets
            if not produced_any:
                raise RuntimeError('data loader yielded no batches; '
                                   'check data_dir and sample_size')

    def run(self):
        """Train for exactly ``self.epochs`` steps, logging and checkpointing.

        Each step takes one batch from :meth:`infinite_batch` and calls
        ``self.wavenet.train`` on it. The weights are saved every
        ``checkpoint_every`` steps and after the final step. When the two
        coincide, only one save is made. Does nothing when ``epochs <= 0``.

        Raises:
            RuntimeError: propagated from :meth:`infinite_batch` when the data
                loader yields no batches.
        """
        if self.epochs <= 0:
            return

        total_steps = 0

        for inputs, targets in self.infinite_batch():
            loss = self.wavenet.train(inputs, targets)

            total_steps += 1
            # The stop test must come *before* the next train call. The original
            # `>` check trained one extra step (epochs + 1) in total.
            is_last = total_steps >= self.epochs

            if is_last or (self.log_every and total_steps % self.log_every == 0):
                print('[{0}/{1}] loss: {2}'.format(total_steps, self.epochs, loss))

            # Always checkpoint the final weights, not only on multiples of the
            # interval. Otherwise training that stops off-interval is lost.
            if is_last or (self.checkpoint_every
                           and total_steps % self.checkpoint_every == 0):
                self.wavenet.save(self.model_dir, self.model_name)

            if is_last:
                break

In [46]:
# Point this run at the tapping-glass "partial" clip folder and name the
# checkpoint after it; both keys are forwarded to Trainer.__init__.
vars(params).update(
    data_dir='../../data/processed/tapping/tapping-glass/partial/',
    model_name='wavenet-tapping-glass-tiny-jar',
)

In [None]:
# Unpack the config namespace into Trainer keywords and start training.
trainer = Trainer(**vars(params))
trainer.run()

2 GPUs are detected.
../../data/processed/tapping/tapping-glass/partial/PLhDdb5CgZ4-tiny-jar-Copy.wav
[1/100000] loss: 5.54508638381958
[2/100000] loss: 5.5447845458984375
[3/100000] loss: 5.544338703155518
[4/100000] loss: 5.5435590744018555
[5/100000] loss: 5.542016983032227
[6/100000] loss: 5.541465759277344
[7/100000] loss: 5.537476539611816
[8/100000] loss: 5.532314300537109
[9/100000] loss: 5.529397964477539
[10/100000] loss: 5.519846439361572
[11/100000] loss: 5.521372318267822
[12/100000] loss: 5.500758171081543
[13/100000] loss: 5.4805755615234375
[14/100000] loss: 5.482760429382324
[15/100000] loss: 5.462929725646973
[16/100000] loss: 5.461312294006348
[17/100000] loss: 5.443474292755127
[18/100000] loss: 5.410279273986816
[19/100000] loss: 5.437704563140869
[20/100000] loss: 5.420593738555908
[21/100000] loss: 5.478623390197754
[22/100000] loss: 5.472309112548828
[23/100000] loss: 5.4334025382995605
[24/100000] loss: 5.479824066162109
[25/100000] loss: 5.475582122802734
[26/

[221/100000] loss: 4.666549205780029
[222/100000] loss: 4.853813171386719
[223/100000] loss: 4.7692413330078125
[224/100000] loss: 5.018270969390869
[225/100000] loss: 4.79360294342041
[226/100000] loss: 4.771384239196777
[227/100000] loss: 4.876296520233154
[228/100000] loss: 4.7889909744262695
[229/100000] loss: 4.7998881340026855
[230/100000] loss: 4.747874736785889
[231/100000] loss: 4.631435871124268
[232/100000] loss: 4.847328186035156
[233/100000] loss: 4.633153915405273
[234/100000] loss: 4.8001909255981445
[235/100000] loss: 4.9401326179504395
[236/100000] loss: 4.819965362548828
[237/100000] loss: 5.015168190002441
[238/100000] loss: 4.954274654388428
[239/100000] loss: 4.919366836547852
[240/100000] loss: 5.0643310546875
[241/100000] loss: 4.95130729675293
[242/100000] loss: 4.995399475097656
[243/100000] loss: 4.900243282318115
[244/100000] loss: 4.851797580718994
[245/100000] loss: 4.714431285858154
[246/100000] loss: 4.809416770935059
[247/100000] loss: 4.868912220001221


[441/100000] loss: 4.752132415771484
[442/100000] loss: 4.766637325286865
[443/100000] loss: 4.680840015411377
[444/100000] loss: 4.597692966461182
[445/100000] loss: 4.7862372398376465
[446/100000] loss: 4.572147846221924
[447/100000] loss: 4.748859405517578
[448/100000] loss: 4.9024786949157715
[449/100000] loss: 4.780944347381592
[450/100000] loss: 4.939645767211914
[451/100000] loss: 4.893917083740234
[452/100000] loss: 4.843239784240723
[453/100000] loss: 4.971950531005859
[454/100000] loss: 4.915539741516113
[455/100000] loss: 4.936765193939209
[456/100000] loss: 4.843847751617432
[457/100000] loss: 4.813598155975342
[458/100000] loss: 4.676341533660889
[459/100000] loss: 4.779689311981201
[460/100000] loss: 4.770943641662598
[461/100000] loss: 4.8603901863098145
[462/100000] loss: 4.747424602508545
[463/100000] loss: 4.757218837738037
[464/100000] loss: 4.824472427368164
[465/100000] loss: 4.77073860168457
[466/100000] loss: 4.868961334228516
[467/100000] loss: 4.899574756622314

[661/100000] loss: 4.902342319488525
[662/100000] loss: 4.780849933624268
[663/100000] loss: 4.939516544342041
[664/100000] loss: 4.893784046173096
[665/100000] loss: 4.843124866485596
[666/100000] loss: 4.971796035766602
[667/100000] loss: 4.915420055389404
[668/100000] loss: 4.936652183532715
[669/100000] loss: 4.843741416931152
[670/100000] loss: 4.813511848449707
[671/100000] loss: 4.676276683807373
[672/100000] loss: 4.779609680175781
[673/100000] loss: 4.7708516120910645
[674/100000] loss: 4.860287189483643
