# Main WaveNet Script
### Imports

In [31]:
try: 
    import librosa
except:
    !pip install librosa


#Set Dir 
import sys, os
sys.path.append(os.path.abspath('..'))

# Torch
import torch
from torch import nn, Tensor
from torch.utils.data import DataLoader, Subset
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import Parameter
from torch.autograd import Variable, Function

# Utils
import numpy as np
from numpy import ndarray
import logging, math
from typing import Sequence, Optional, Callable
from tqdm import tqdm

# Base Scripts
from Libraries.Utils import *
from MainScripts.Conf import conf

General

In [2]:
# Set to True when using a remote kernel (this changes the file structure).
remote_kernel: bool = True
# Base names (without extension) of the model checkpoint and the training data file.
model_name: str = "wave_net_v3"
training_data_name: str = "training_full_wave"
# Location of the "<model_name>.pth" checkpoint, resolved through path_to_remote_path.
full_model_path: str = path_to_remote_path(f"{conf['paths'].model_path}/{model_name}.pth", remote_kernel)

Logging

In [3]:
# LIGHT_DEBUG is a project-defined logging level (not part of the stdlib logging module).
logging_level: int = LIGHT_DEBUG
logger: logging.Logger = logging.getLogger(__name__)
# Configure the root logger: timestamp, level name and message.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging_level,
)

Training Params

In [None]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device: str = "cuda"
else:
    device: str = "cpu"
batch_size = 32


# Optimizer step size and number of epochs per run.
learning_rate: float = 1e-3
epochs: int = 30
# When True, optimizer/scheduler state and the stored epoch of a checkpoint are not restored.
restart_training: bool = True
# A checkpoint is written every `checkpoint_freq` epochs (values <= 0 disable checkpoints).
checkpoint_freq: int = 10

### Data Loading

In [22]:
def mu_law_encode(audio: ndarray, quantization_channels: int = 256) -> ndarray:
    """Compand a waveform with the mu-law curve and quantize it to integer codes.

    Args:
        audio: Waveform samples; the mu-law curve maps [-1, 1] onto [-1, 1].
            Samples outside that range end up clipped to the first/last code.
        quantization_channels: Number of discrete codes (mu = channels - 1).

    Returns:
        int32 array of the same shape with values in [0, quantization_channels - 1].
    """
    mu = quantization_channels - 1
    audio = np.asarray(audio)
    companded = np.sign(audio) * np.log1p(mu * np.abs(audio)) / np.log1p(mu)
    # Map [-1, 1] onto [0, mu] and round to the nearest code exactly once.
    # floor(x + 0.5) is round-half-up; the previous round(x + 0.5) rounded twice,
    # shifting every code by half a bin and breaking encode(decode(k)) == k.
    scaled = (companded + 1) / 2 * mu
    codes = np.floor(scaled + 0.5).astype(np.int32)
    return np.clip(codes, 0, mu)

def mu_law_decode(encoded: ndarray, quantization_channels: int = 256) -> ndarray:
    """Invert mu-law quantization: turn integer codes back into waveform samples.

    Args:
        encoded: Array of codes in [0, quantization_channels - 1].
        quantization_channels: Number of discrete codes (mu = channels - 1).

    Returns:
        float32 array of the same shape; code 0 maps to -1 and code mu maps to +1.
    """
    mu = quantization_channels - 1
    codes = encoded.astype(np.float32)
    # Rescale the codes from [0, mu] to the companded range [-1, 1].
    companded = 2 * (codes / mu) - 1
    # Undo the logarithmic companding on the magnitude, then restore the sign.
    expanded = (1 + mu) ** np.abs(companded) - 1
    return np.sign(companded) * (1.0 / mu) * expanded

def prepare_data(data: ndarray, length: int = 64000, num_samples_per_file: int = 5) -> tuple[ndarray, ndarray]:
    """Cut random next-sample-prediction training windows out of each waveform.

    Every file is mu-law encoded, then `num_samples_per_file` random windows of
    `length` codes are drawn from it. The label window is the input window
    shifted one sample into the future.

    Args:
        data: Iterable of 1-D waveforms (assumed to lie in [-1, 1], as expected
            by mu_law_encode — TODO confirm against the loader).
        length: Number of samples per window.
        num_samples_per_file: Number of windows drawn from every file.

    Returns:
        (inputs, labels): two arrays of shape
        (n_files * num_samples_per_file, length) holding mu-law codes.

    Raises:
        ValueError: If a file has fewer than `length + 1` samples.
    """
    inputs, labels = [], []
    for file in data:
        audio = mu_law_encode(file)
        # Use this file's own length (not data.shape[1]) so lists / ragged inputs work.
        last_start = len(audio) - length - 1
        if last_start < 0:
            raise ValueError(
                f"File with {len(audio)} samples is too short for windows of length {length} (+1 label sample)"
            )
        # randint's upper bound is exclusive; +1 keeps the last valid window reachable.
        starts = np.random.randint(0, last_start + 1, size=num_samples_per_file)
        for start in starts:
            inputs.append(audio[start: start + length])
            labels.append(audio[start + 1: start + length + 1])

    return np.array(inputs), np.array(labels)


In [23]:
# Load the raw waveforms and cut training windows from the first 10 files.
file: ndarray = load_training_data(path_to_remote_path("{}/{}".format(conf["paths"].data_path, training_data_name + ".npy"), remote_kernel))
input, label = prepare_data(file[:10], length=32000)

# One-hot encode in chunks of 10 windows. Stepping by the chunk size (instead of
# iterating over len(input) // 10) also covers a trailing partial chunk, so the
# number of inputs always matches the number of labels.
inputs = []
for start in range(0, len(input), 10):
    x = torch.tensor(input[start: start + 10]).long()
    # [chunk, time] codes -> [chunk, 256, time] one-hot channels
    x = torch.nn.functional.one_hot(x, 256).float().permute(0, 2, 1)
    inputs.append(x)
input = torch.cat(inputs)

data_loader = create_dataloader(Audio_Data(input, label), batch_size)

logger.info(f"Data loaded with shape: {input.shape}")

2025-06-05 06:27:00,879 - LIGHT_DEBUG - Ndarray loaded from Data/training_full_wave.npy of shape: (5906, 147200)
2025-06-05 06:27:02,970 - INFO - Data loaded with shape: torch.Size([50, 256, 32000])


### Model Creation

In [27]:
class DilatedConv(nn.Module):
    """Causal dilated 1-D convolution that preserves the sequence length.

    The wrapped Conv1d pads both ends by (kernel_size - 1) * dilation; the
    trailing `padding` frames are cut off again in `forward`, so output frame t
    only depends on input frames <= t.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, dilation: int = 1) -> None:
        super().__init__()
        self.padding = (kernel_size - 1) * dilation
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            padding=self.padding,
            dilation=dilation,
        )

    def forward(self, x: Tensor) -> Tensor:
        """Convolve `x` ([batch, channels, time]) and drop the look-ahead frames."""
        convolved = self.conv(x)
        if self.padding > 0:
            # Remove the frames produced from the right-hand padding.
            return convolved[:, :, :-self.padding]
        return convolved

class DilatedConvStep(nn.Module):
    """Single-step (incremental) causal dilated convolution.

    Processes one time frame per call and keeps the previous
    `receptive_field - 1` input frames in a cache handed back to the caller.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation):
        super().__init__()
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.receptive_field = dilation * (kernel_size - 1) + 1

        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation, bias=True)

    def forward(self, x, cache):
        """
        x: [B, C, 1] - the current input sample
        cache: [B, C, receptive_field - 1] - buffer of past samples, or None to
            start from an all-zero history

        Returns (out, new_cache) with out of shape [B, C_out, 1] and new_cache
        of the same shape as the cache.
        """
        if cache is None:
            B, C, _ = x.shape
            # Match x's dtype as well as its device so half/double inputs work.
            cache = torch.zeros(B, C, self.receptive_field - 1, device=x.device, dtype=x.dtype)

        # Append x to the right of the history: [B, C, receptive_field]
        x_full = torch.cat([cache, x], dim=2)

        # Unpadded Conv1d over exactly one receptive field yields one output frame
        out = self.conv(x_full)  # Output shape: [B, C_out, 1]

        # Drop the oldest frame. Slicing x_full (instead of concatenating
        # cache[:, :, 1:] with x) keeps the cache width constant even when the
        # cache is empty (kernel_size == 1).
        new_cache = x_full[:, :, 1:]

        return out, new_cache


class ResBlock(nn.Module):
    """Gated residual block: tanh(filter) * sigmoid(gate) with residual and skip outputs."""

    def __init__(self, res_channels: int, skip_channels: int, kernel_size: int = 3, dilation: int = 1) -> None:
        super().__init__()
        # Two parallel causal dilated convolutions feed the gated activation.
        self.filter_conv = DilatedConv(res_channels, res_channels, kernel_size, dilation)
        self.gate_conv = DilatedConv(res_channels, res_channels, kernel_size, dilation)
        # 1x1 projections for the residual path and the skip path.
        self.res_conv = nn.Conv1d(res_channels, res_channels, kernel_size=1)
        self.skip_conv = nn.Conv1d(res_channels, skip_channels, kernel_size=1)
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: Tensor, skip_size: int) -> tuple[Tensor, Tensor]:
        """Return (residual output, skip output limited to its last `skip_size` frames)."""
        gated = self.tanh(self.filter_conv(x)) * self.sigmoid(self.gate_conv(x))
        # Residual connection: projected activation plus the block input.
        residual = self.res_conv(gated) + x
        skip = self.skip_conv(gated)
        return residual, skip[:, :, -skip_size:]
    
class ResBlockStep(nn.Module):
    """Single-step variant of the gated residual block, driven by per-conv caches."""

    def __init__(self, res_channels, skip_channels, kernel_size, dilation):
        super().__init__()
        # Incremental filter/gate convolutions; each keeps its own history cache.
        self.filter_conv = DilatedConvStep(res_channels, res_channels, kernel_size, dilation)
        self.gate_conv = DilatedConvStep(res_channels, res_channels, kernel_size, dilation)
        self.res_conv = nn.Conv1d(res_channels, res_channels, 1)
        self.skip_conv = nn.Conv1d(res_channels, skip_channels, 1)
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()
        # Learnable scalar applied to the residual output (initialised to 1.0).
        self.res_scale = nn.Parameter(torch.tensor(1.0))

    def forward(self, x: Tensor, filter_cache: Tensor, gate_cache: Tensor) -> tuple[Tensor]:
        """Process one frame; returns (output, skip, new filter cache, new gate cache)."""
        filt, filter_cache = self.filter_conv(x, filter_cache)
        gate, gate_cache = self.gate_conv(x, gate_cache)
        activated = self.tanh(filt) * self.sigmoid(gate)

        residual = self.res_conv(activated)
        skip = self.skip_conv(activated)

        # Residual connection, scaled by the learnable factor.
        out = (x + residual) * self.res_scale
        return out, skip, filter_cache, gate_cache


class DenseLayer(nn.Module):
    """Output head: two 1x1 convolutions, each followed by ReLU.

    Because the second convolution is also passed through ReLU, every output
    value is non-negative.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels, in_channels, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels, out_channels, kernel_size=1)
        self.relu = nn.ReLU()

    def forward(self, x: Tensor) -> Tensor:
        """Map [batch, in_channels, time] to [batch, out_channels, time]."""
        hidden = self.relu(self.conv1(x))
        projected = self.conv2(hidden)
        return self.relu(projected)

class ResStack(nn.Module):
    """`stack_size` repetitions of `layer_size` ResBlocks with dilations 1, 2, 4, ..."""

    def __init__(self, stack_size: int, layer_size: int, res_channels: int, skip_channels: int, kernel_size: int) -> None:
        super().__init__()
        # The dilation doubles with every layer and restarts at 1 for every stack.
        self.res_blocks = nn.ModuleList(
            ResBlock(res_channels, skip_channels, kernel_size, 2 ** layer)
            for _ in range(stack_size)
            for layer in range(layer_size)
        )

    def forward(self, x: Tensor, skip_size: int) -> tuple[Tensor, Tensor]:
        """Run all blocks in sequence; return (last residual output, sum of all skip outputs)."""
        hidden = x
        collected = []
        for block in self.res_blocks:
            hidden, skip = block(hidden, skip_size)
            collected.append(skip)
        return hidden, torch.stack(collected).sum(dim=0)


class WaveNet(nn.Module):
    """WaveNet-style autoregressive model over one-hot mu-law codes.

    `forward` runs the full-sequence (training) path through `res_blocks`.
    `fast_sample` generates audio one sample at a time through the cached
    `res_blocks_step` modules. The step modules own separate parameters that
    `forward` never touches, so `fast_sample` copies the trained weights into
    them before generating. They stay registered as submodules so existing
    checkpoints (which contain their keys) keep loading.

    NOTE(review): `forward` applies tanh to the head output, so the values
    handed to the loss are bounded. If they are used as cross-entropy logits
    this limits how peaked the predicted distribution can get — consider
    dropping the tanh when retraining from scratch.
    """

    def __init__(self, in_channels: int, res_channels: int, skip_channels: int, out_channels: int, stack_size: int = 2, layer_size: int = 5, kernel_size: int = 2):
        super().__init__()
        self.receptive_field = (kernel_size - 1) * (2 ** layer_size - 1) * stack_size + 1

        self.input_proj = nn.Conv1d(in_channels, res_channels, 1)
        self.causal_conv = DilatedConv(res_channels, res_channels, kernel_size, dilation=1)

        self.res_blocks = nn.ModuleList()
        self.res_blocks_step = nn.ModuleList()

        for s in range(stack_size):
            for l in range(layer_size):
                d = 2 ** l
                self.res_blocks.append(ResBlock(res_channels, skip_channels, kernel_size, d))
                self.res_blocks_step.append(ResBlockStep(res_channels, skip_channels, kernel_size, d))

        self.dense = DenseLayer(skip_channels, out_channels)
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Map one-hot input [B, in_channels, T] to per-frame scores [B, out_channels, T]."""
        x = self.input_proj(x)
        x = self.causal_conv(x)
        skips = 0
        for block in self.res_blocks:
            x, skip = block(x, x.size(2))
            skips += skip
        out = self.dense(skips)
        return self.tanh(out)

    def _sync_step_blocks(self):
        """Copy the trained ResBlock weights into the matching ResBlockStep modules."""
        for block, step in zip(self.res_blocks, self.res_blocks_step):
            step.filter_conv.conv.load_state_dict(block.filter_conv.conv.state_dict())
            step.gate_conv.conv.load_state_dict(block.gate_conv.conv.state_dict())
            step.res_conv.load_state_dict(block.res_conv.state_dict())
            step.skip_conv.load_state_dict(block.skip_conv.state_dict())
            # ResBlock has no residual scale, i.e. an implicit factor of 1.
            step.res_scale.fill_(1.0)

    @torch.no_grad()
    def fast_sample(self, num_samples, temp=1.0, device="cpu"):
        """Autoregressively generate `num_samples` samples and return the decoded waveform.

        Args:
            num_samples: Number of audio samples to generate.
            temp: Softmax temperature applied to the scores before sampling.
            device: Device the seed input is created on (must match the model's device).

        Returns:
            Waveform array produced by mu_law_decode from the sampled codes.
        """
        was_training = self.training
        self.eval()
        self._sync_step_blocks()

        num_classes = self.dense.conv2.out_channels
        # Seed with a one-hot frame at the middle code (127 for 256 classes).
        x = torch.zeros(1, num_classes, 1, device=device)
        x[:, (num_classes - 1) // 2, 0] = 1.0

        front_conv = self.causal_conv.conv
        front_history = None  # past projected frames seen by the front causal conv
        filter_caches = [None] * len(self.res_blocks_step)
        gate_caches = [None] * len(self.res_blocks_step)

        outputs = []
        try:
            for _ in tqdm(range(num_samples), desc="Fast Sampling"):
                projected = self.input_proj(x)
                if front_history is None:
                    # Zero history mirrors the zero left-padding used by DilatedConv.
                    front_history = projected.new_zeros(projected.size(0), projected.size(1), self.causal_conv.padding)
                # Apply the front conv without padding over [history, current frame],
                # so it sees the previous samples exactly as in `forward`.
                window = torch.cat([front_history, projected], dim=2)
                x_in = torch.nn.functional.conv1d(window, front_conv.weight, front_conv.bias, dilation=front_conv.dilation)
                front_history = window[:, :, 1:]

                skips = []
                for idx, block in enumerate(self.res_blocks_step):
                    x_in, skip, filter_caches[idx], gate_caches[idx] = block(x_in, filter_caches[idx], gate_caches[idx])
                    skips.append(skip)

                # Same output head as `forward` (dense followed by tanh).
                y = self.tanh(self.dense(sum(skips)))
                logits = y[:, :, -1]
                probs = torch.nn.functional.softmax(logits / temp, dim=1)
                next_token = torch.multinomial(probs, 1)

                outputs.append(next_token.item())
                x = torch.nn.functional.one_hot(next_token, num_classes).permute(0, 2, 1).float()
        finally:
            # Restore whichever mode the model was in before sampling.
            self.train(was_training)
        return mu_law_decode(np.array(outputs), num_classes)



In [28]:
# 256 one-hot input/output classes (8-bit mu-law codes), 2 stacks of 8 dilated layers.
wave_net = WaveNet(in_channels=256, res_channels=128, skip_channels=128, out_channels=256,
                   stack_size=2, layer_size=8, kernel_size=2)
wave_net = wave_net.to(device)

### Model Loading

In [20]:
optimizer = optim.AdamW(
    wave_net.parameters(),
    lr=learning_rate,
    weight_decay=1e-3,
    betas=[0.95, 0.999],
    eps=1e-6,
)
# gamma=1.0 keeps the learning rate constant.
lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=1.0)
start_epoch: int = 0
if not os.path.exists(full_model_path):
    logger.info(f"Model {model_name} created with {count_parameters(wave_net)} Parameters")
else:
    # `model` is the checkpoint dict ("model", "optim", "scheduler", "epoch").
    model = torch.load(full_model_path, map_location=device)
    wave_net.load_state_dict(model["model"])
    if not restart_training:
        # Resume: also restore optimizer / scheduler state and the stored epoch.
        optimizer.load_state_dict(model["optim"])
        lr_scheduler.load_state_dict(model["scheduler"])
        start_epoch = model.get("epoch", 0)
    logger.info(f"Model {model_name} loaded with {count_parameters(wave_net)} Parameters")

2025-06-05 06:25:51,902 - INFO - Model wave_net_v3 loaded with ~3.277M Parameters


### Training

In [22]:
# Train wave_net for `epochs` epochs with mixed precision, periodic checkpoints
# and a final save. Epoch numbering continues from `start_epoch`, so a resumed
# run (restart_training == False) keeps counting where the checkpoint stopped
# instead of starting at 1 again; with start_epoch == 0 nothing changes.
loss_fn = nn.CrossEntropyLoss()
logger.info(f"Training started on {device}")
if device == "cuda":
    scaler = torch.cuda.amp.GradScaler()
else:
    scaler = torch.amp.GradScaler(device=device)
loss_list: list = []
val_loss_list: list = []
total_time: float = 0.0
best_loss = float('inf')
epochs_no_improve: int = 0

first_epoch: int = start_epoch
end_epoch: int = first_epoch + epochs
# Number of fully finished epochs; also valid when the loop body never runs.
completed_epochs: int = first_epoch

wave_net.train()
for e in range(first_epoch, end_epoch):
    total_loss: float = 0
    validation_loss: float = 0
    start_time: float = time.time()
    nan_loss: bool = False

    for b_idx, (x, y) in enumerate(data_loader):
        optimizer.zero_grad()
        x, y = x.to(device), y.to(device).long()
        with torch.autocast(device_type=device):
            output = wave_net(x)
            loss = loss_fn(output, y)
        batch_loss: float = loss.item()
        # Stop before backward/step so a NaN loss never reaches the optimizer.
        if np.isnan(batch_loss):
            logger.info("Breaking due to NaN loss.")
            nan_loss = True
            break
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_loss += batch_loss

        if logger.getEffectiveLevel() == LIGHT_DEBUG:
            current_batch = b_idx + 1
            # Diagnostics only: keep the parameter statistics out of autograd.
            with torch.no_grad():
                all_params = torch.cat([param.view(-1) for param in wave_net.parameters()])
            print(f"\r{time.strftime('%Y-%m-%d %H:%M:%S')},000 - LIGHT_DEBUG - Batch {current_batch:03d}/{len(data_loader):03d} Loss: {loss.item():.3f} Min/Max params: {torch.min(all_params):.3f}, {torch.max(all_params):.3f}", end='', flush=True)

    if nan_loss:
        # Abort training; the final save below still runs.
        break

    if logger.getEffectiveLevel() == LIGHT_DEBUG:
        print(flush=True)

    avg_loss = total_loss / len(data_loader)
    loss_list.append(avg_loss)

    if lr_scheduler is not None:
        if isinstance(lr_scheduler, (optim.lr_scheduler.ReduceLROnPlateau)):
            lr_scheduler.step(avg_loss)
        else:
            lr_scheduler.step()

    completed_epochs = e + 1
    epoch_time = time.time() - start_time
    total_time += epoch_time
    # ETA: average time per epoch of this run times the epochs still to go.
    remaining_time = int((total_time / (completed_epochs - first_epoch)) * (end_epoch - completed_epochs))

    logger.info(f"Epoch {e + 1:03d}: Avg. Loss: {avg_loss:.5e} Remaining Time: {remaining_time // 3600:02d}h {(remaining_time % 3600) // 60:02d}min {round(remaining_time % 60):02d}s LR: {optimizer.param_groups[0]['lr']:.5e} ")

    if checkpoint_freq > 0 and (e + 1) % checkpoint_freq == 0:
        checkpoint_path: str = f"{full_model_path[:-4]}_epoch_{e + 1:03d}.pth"
        torch.save({"model": wave_net.state_dict(), "optim": optimizer.state_dict(), "scheduler": lr_scheduler.state_dict(), "epoch": e + 1}, checkpoint_path)
        # Keep only the most recent checkpoint on disk.
        if e + 1 != checkpoint_freq:
            last_path: str = f"{full_model_path[:-4]}_epoch_{(e + 1) - checkpoint_freq:03d}.pth"
            del_if_exists(last_path)
        logger.light_debug(f"Checkpoint saved model to {checkpoint_path}")


torch.save({"model": wave_net.state_dict(), "optim": optimizer.state_dict(), "scheduler": lr_scheduler.state_dict(), "epoch": completed_epochs}, full_model_path)

logger.light_debug(f"Saved model to {full_model_path}")

# The final model supersedes the last periodic checkpoint, so remove it.
if checkpoint_freq > 0:
    checkpoint_path: str = f"{full_model_path[:-4]}_epoch_{completed_epochs - (completed_epochs % checkpoint_freq):03d}.pth"
    del_if_exists(checkpoint_path)

2025-06-04 21:11:03,764 - INFO - Training started on cuda


2025-06-04 21:11:05,000 - LIGHT_DEBUG - Batch 013/013 Loss: 5.083 Min/Max params: -0.126, 1.000


2025-06-04 21:11:05,667 - INFO - Epoch 001: Avg. Loss: 5.16331e+00 Remaining Time: 00h 00min 55s LR: 1.00000e-03 


2025-06-04 21:11:07,000 - LIGHT_DEBUG - Batch 013/013 Loss: 5.043 Min/Max params: -0.138, 1.000


2025-06-04 21:11:07,511 - INFO - Epoch 002: Avg. Loss: 5.10552e+00 Remaining Time: 00h 00min 52s LR: 1.00000e-03 


2025-06-04 21:11:09,000 - LIGHT_DEBUG - Batch 013/013 Loss: 5.011 Min/Max params: -0.151, 1.000


2025-06-04 21:11:09,388 - INFO - Epoch 003: Avg. Loss: 5.07164e+00 Remaining Time: 00h 00min 50s LR: 1.00000e-03 


2025-06-04 21:11:11,000 - LIGHT_DEBUG - Batch 013/013 Loss: 4.988 Min/Max params: -0.159, 1.000


2025-06-04 21:11:11,297 - INFO - Epoch 004: Avg. Loss: 5.04567e+00 Remaining Time: 00h 00min 48s LR: 1.00000e-03 


2025-06-04 21:11:13,000 - LIGHT_DEBUG - Batch 013/013 Loss: 4.970 Min/Max params: -0.164, 1.000


2025-06-04 21:11:13,178 - INFO - Epoch 005: Avg. Loss: 5.01918e+00 Remaining Time: 00h 00min 47s LR: 1.00000e-03 


2025-06-04 21:11:15,000 - LIGHT_DEBUG - Batch 013/013 Loss: 4.953 Min/Max params: -0.173, 1.000


2025-06-04 21:11:15,017 - INFO - Epoch 006: Avg. Loss: 4.99784e+00 Remaining Time: 00h 00min 44s LR: 1.00000e-03 


2025-06-04 21:11:16,000 - LIGHT_DEBUG - Batch 013/013 Loss: 4.942 Min/Max params: -0.169, 1.000


2025-06-04 21:11:16,847 - INFO - Epoch 007: Avg. Loss: 4.98178e+00 Remaining Time: 00h 00min 42s LR: 1.00000e-03 


2025-06-04 21:11:18,000 - LIGHT_DEBUG - Batch 013/013 Loss: 4.931 Min/Max params: -0.182, 1.000


2025-06-04 21:11:18,688 - INFO - Epoch 008: Avg. Loss: 4.96913e+00 Remaining Time: 00h 00min 41s LR: 1.00000e-03 


2025-06-04 21:11:20,000 - LIGHT_DEBUG - Batch 013/013 Loss: 4.923 Min/Max params: -0.194, 1.000


2025-06-04 21:11:20,554 - INFO - Epoch 009: Avg. Loss: 4.95718e+00 Remaining Time: 00h 00min 39s LR: 1.00000e-03 


2025-06-04 21:11:22,000 - LIGHT_DEBUG - Batch 013/013 Loss: 4.915 Min/Max params: -0.208, 1.000


2025-06-04 21:11:22,406 - INFO - Epoch 010: Avg. Loss: 4.94809e+00 Remaining Time: 00h 00min 37s LR: 1.00000e-03 
2025-06-04 21:11:22,591 - LIGHT_DEBUG - Checkpoint saved model to Models/wave_net_v3_epoch_010.pth


2025-06-04 21:11:24,000 - LIGHT_DEBUG - Batch 013/013 Loss: 4.912 Min/Max params: -0.217, 1.000


2025-06-04 21:11:24,410 - INFO - Epoch 011: Avg. Loss: 4.94041e+00 Remaining Time: 00h 00min 35s LR: 1.00000e-03 


2025-06-04 21:11:26,000 - LIGHT_DEBUG - Batch 013/013 Loss: 4.905 Min/Max params: -0.215, 1.000


2025-06-04 21:11:26,255 - INFO - Epoch 012: Avg. Loss: 4.93467e+00 Remaining Time: 00h 00min 33s LR: 1.00000e-03 


2025-06-04 21:11:28,000 - LIGHT_DEBUG - Batch 013/013 Loss: 4.899 Min/Max params: -0.230, 1.000


2025-06-04 21:11:28,139 - INFO - Epoch 013: Avg. Loss: 4.92926e+00 Remaining Time: 00h 00min 31s LR: 1.00000e-03 


2025-06-04 21:11:29,000 - LIGHT_DEBUG - Batch 013/013 Loss: 4.893 Min/Max params: -0.248, 1.000


2025-06-04 21:11:29,978 - INFO - Epoch 014: Avg. Loss: 4.92247e+00 Remaining Time: 00h 00min 29s LR: 1.00000e-03 


2025-06-04 21:11:31,000 - LIGHT_DEBUG - Batch 013/013 Loss: 4.889 Min/Max params: -0.257, 1.000


2025-06-04 21:11:31,852 - INFO - Epoch 015: Avg. Loss: 4.91405e+00 Remaining Time: 00h 00min 27s LR: 1.00000e-03 


2025-06-04 21:11:33,000 - LIGHT_DEBUG - Batch 013/013 Loss: 4.887 Min/Max params: -0.266, 1.000


2025-06-04 21:11:33,717 - INFO - Epoch 016: Avg. Loss: 4.90972e+00 Remaining Time: 00h 00min 26s LR: 1.00000e-03 


2025-06-04 21:11:35,000 - LIGHT_DEBUG - Batch 013/013 Loss: 4.883 Min/Max params: -0.278, 1.000


2025-06-04 21:11:35,592 - INFO - Epoch 017: Avg. Loss: 4.90563e+00 Remaining Time: 00h 00min 24s LR: 1.00000e-03 


2025-06-04 21:11:37,000 - LIGHT_DEBUG - Batch 013/013 Loss: 4.879 Min/Max params: -0.294, 1.000


2025-06-04 21:11:37,641 - INFO - Epoch 018: Avg. Loss: 4.90322e+00 Remaining Time: 00h 00min 22s LR: 1.00000e-03 


2025-06-04 21:11:39,000 - LIGHT_DEBUG - Batch 013/013 Loss: 4.882 Min/Max params: -0.305, 1.000


2025-06-04 21:11:39,897 - INFO - Epoch 019: Avg. Loss: 4.89913e+00 Remaining Time: 00h 00min 20s LR: 1.00000e-03 


2025-06-04 21:11:41,000 - LIGHT_DEBUG - Batch 013/013 Loss: 4.879 Min/Max params: -0.316, 1.000


2025-06-04 21:11:41,881 - INFO - Epoch 020: Avg. Loss: 4.89819e+00 Remaining Time: 00h 00min 18s LR: 1.00000e-03 
2025-06-04 21:11:41,939 - LIGHT_DEBUG - Models/wave_net_v3_epoch_010.pth deleted
2025-06-04 21:11:41,940 - LIGHT_DEBUG - Checkpoint saved model to Models/wave_net_v3_epoch_020.pth


2025-06-04 21:11:43,000 - LIGHT_DEBUG - Batch 013/013 Loss: 4.874 Min/Max params: -0.333, 1.000


2025-06-04 21:11:43,799 - INFO - Epoch 021: Avg. Loss: 4.89429e+00 Remaining Time: 00h 00min 17s LR: 1.00000e-03 


2025-06-04 21:11:45,000 - LIGHT_DEBUG - Batch 013/013 Loss: 4.873 Min/Max params: -0.350, 1.000


2025-06-04 21:11:45,695 - INFO - Epoch 022: Avg. Loss: 4.89217e+00 Remaining Time: 00h 00min 15s LR: 1.00000e-03 


2025-06-04 21:11:47,000 - LIGHT_DEBUG - Batch 013/013 Loss: 4.872 Min/Max params: -0.358, 1.000


2025-06-04 21:11:47,696 - INFO - Epoch 023: Avg. Loss: 4.89105e+00 Remaining Time: 00h 00min 13s LR: 1.00000e-03 


2025-06-04 21:11:49,000 - LIGHT_DEBUG - Batch 013/013 Loss: 4.904 Min/Max params: -0.365, 1.000


2025-06-04 21:11:49,722 - INFO - Epoch 024: Avg. Loss: 4.89625e+00 Remaining Time: 00h 00min 11s LR: 1.00000e-03 


2025-06-04 21:11:51,000 - LIGHT_DEBUG - Batch 013/013 Loss: 4.892 Min/Max params: -0.384, 1.000


2025-06-04 21:11:51,597 - INFO - Epoch 025: Avg. Loss: 4.90933e+00 Remaining Time: 00h 00min 09s LR: 1.00000e-03 


2025-06-04 21:11:53,000 - LIGHT_DEBUG - Batch 013/013 Loss: 4.883 Min/Max params: -0.409, 1.000


2025-06-04 21:11:53,437 - INFO - Epoch 026: Avg. Loss: 4.90658e+00 Remaining Time: 00h 00min 07s LR: 1.00000e-03 


2025-06-04 21:11:55,000 - LIGHT_DEBUG - Batch 013/013 Loss: 4.881 Min/Max params: -0.430, 1.000


2025-06-04 21:11:55,313 - INFO - Epoch 027: Avg. Loss: 4.90219e+00 Remaining Time: 00h 00min 05s LR: 1.00000e-03 


2025-06-04 21:11:57,000 - LIGHT_DEBUG - Batch 013/013 Loss: 4.876 Min/Max params: -0.433, 1.000


2025-06-04 21:11:57,162 - INFO - Epoch 028: Avg. Loss: 4.89627e+00 Remaining Time: 00h 00min 03s LR: 1.00000e-03 


2025-06-04 21:11:58,000 - LIGHT_DEBUG - Batch 013/013 Loss: 4.873 Min/Max params: -0.437, 1.000


2025-06-04 21:11:58,993 - INFO - Epoch 029: Avg. Loss: 4.89494e+00 Remaining Time: 00h 00min 01s LR: 1.00000e-03 


2025-06-04 21:12:00,000 - LIGHT_DEBUG - Batch 013/013 Loss: 4.867 Min/Max params: -0.445, 1.000


2025-06-04 21:12:00,841 - INFO - Epoch 030: Avg. Loss: 4.89075e+00 Remaining Time: 00h 00min 00s LR: 1.00000e-03 
2025-06-04 21:12:00,895 - LIGHT_DEBUG - Models/wave_net_v3_epoch_020.pth deleted
2025-06-04 21:12:00,896 - LIGHT_DEBUG - Checkpoint saved model to Models/wave_net_v3_epoch_030.pth
2025-06-04 21:12:00,948 - LIGHT_DEBUG - Saved model to Models/wave_net_v3.pth
2025-06-04 21:12:00,956 - LIGHT_DEBUG - Models/wave_net_v3_epoch_030.pth deleted


In [29]:
# Generate 32000 samples at temperature 1.0 on the training device.
waveform = wave_net.fast_sample(num_samples=32000, temp=1.0, device=device)
print(waveform)

Fast Sampling: 100%|██████████| 32000/32000 [02:55<00:00, 182.23it/s]

[ 9.5727378e-01 -4.6014797e-02  6.4593613e-01 ... -2.8403198e-02
  8.6213098e-05 -2.5710301e-02]





In [30]:
# Write the generated waveform to test.wav.
# NOTE(review): the third argument (32000) is presumably the sample rate — confirm against save_audio_file.
save_audio_file(waveform, "test.wav", 32000)

2025-06-05 06:31:04,343 - LIGHT_DEBUG - Normalized to range: [-0.99999,0.99999]
2025-06-05 06:31:04,352 - LIGHT_DEBUG - Saved file to:test.wav
