In [86]:
# import dependencies
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch
import math
from scipy.io import wavfile
from scipy.signal import butter, filtfilt
from scipy.io.wavfile import write
# cuDNN configuration:
#  - benchmark: let cuDNN time candidate convolution algorithms for each new
#    input shape and reuse the fastest one afterwards.
#  - enabled: use cuDNN kernels where supported (already PyTorch's default).
torch.backends.cudnn.benchmark, torch.backends.cudnn.enabled = True, True

In [87]:
def preprocess_audio(path, target_rms=0.1, highpass_hz=20.0):
    """Load a WAV file and condition it for the model.

    Steps: downmix to mono, remove the mean, peak-normalise, apply a
    zero-phase 1st-order Butterworth high-pass, then scale to ``target_rms``.

    Args:
        path: path to a WAV file readable by ``scipy.io.wavfile.read``.
        target_rms: RMS level of the returned waveform.
        highpass_hz: high-pass cutoff in Hz; must be below ``sr / 2``.

    Returns:
        ``(sr, waveform)``: the file's sample rate and a 1-D float32 array.

    Note:
        The final RMS scaling sets the output level, so peaks are NOT bounded
        to [-1, 1]. A signal whose crest factor exceeds ``1 / target_rms`` will
        go past full scale.
    """
    sr, waveform = wavfile.read(path)
    if waveform.ndim > 1:
        # Downmix multi-channel audio by averaging the channels.
        waveform = waveform.mean(axis=1)
    waveform = waveform.astype(np.float32)
    if waveform.size == 0:
        # Nothing to normalise; np.max below would raise on an empty array.
        return sr, waveform
    waveform -= np.mean(waveform)

    # Peak-normalise so the filter runs on a well-scaled signal. This is not a
    # clipping guard: the RMS scaling below determines the final level.
    waveform /= (np.max(np.abs(waveform)) + 1e-7)
    waveform *= 0.99

    # Zero-phase high-pass to remove residual DC / subsonic content.
    b, a = butter(1, highpass_hz / (sr / 2), btype='highpass')
    # filtfilt's default edge padding (3 * max(len(a), len(b))) must be shorter
    # than the signal. Shrink it for very short inputs instead of raising;
    # normal-length signals use the default padding unchanged.
    padlen = min(3 * max(len(a), len(b)), waveform.size - 1)
    waveform = filtfilt(b, a, waveform, padlen=padlen)

    # RMS normalisation (the epsilon keeps silent input at zero instead of NaN).
    rms = np.sqrt(np.mean(waveform ** 2))
    waveform *= target_rms / (rms + 1e-9)

    # filtfilt returns float64; cast back so the dtype matches the float32
    # conversion above.
    return sr, waveform.astype(np.float32)

In [88]:
class CausalLayer(nn.Module):
    """One dilated causal convolution with residual and skip outputs.

    The input is left-padded by ``(kernel_size - 1) * dilation`` samples, so the
    convolution output at time t depends only on inputs at times <= t, and the
    time length is preserved. The tanh-activated features feed two 1x1
    convolutions:
      * residual path: its output is added back onto the layer input;
      * skip path: its output is returned separately.

    Args:
        channels: input/output channel count (both paths preserve it).
        kernel_size: temporal kernel size of the dilated convolution.
        dilation: dilation factor of the convolution.
    """

    def __init__(self, channels, kernel_size, dilation):
        super().__init__()
        # History reached by the dilated kernel; padded on the left only.
        self.left_pad = dilation * (kernel_size - 1)
        self.conv = nn.Conv1d(channels, channels, kernel_size, dilation=dilation)
        # Both 1x1 projections are always created. Creation order is kept so
        # parameter initialisation and state_dict keys match saved checkpoints.
        self.res_conv = nn.Conv1d(channels, channels, kernel_size=1)
        self.skip_conv = nn.Conv1d(channels, channels, kernel_size=1)

    def forward(self, x):
        """Return ``(residual, skip)``, both with the same shape as ``x``."""
        hidden = torch.tanh(self.conv(F.pad(x, (self.left_pad, 0))))
        skip = self.skip_conv(hidden)
        return x + self.res_conv(hidden), skip

class WaveNet(nn.Module):
    """Dilated causal conv stack whose summed skip outputs form a 1-channel signal.

    Pipeline:
      1. A 1x1 input projection.
      2. One CausalLayer per dilation, each fed the previous layer's residual output.
      3. The sum of every layer's skip output.
      4. A 1x1 projection down to one output channel.
    No nonlinearity is applied after the skip sum.

    With ``kernel_size`` k and dilations d_i, each output sample depends on the
    current input sample plus the previous ``(k - 1) * sum(d_i)`` samples. That
    is 2047 samples in total for the defaults.

    Args:
        in_channels: number of input channels (1 for mono audio).
        channels: hidden width shared by all CausalLayers.
        kernel_size: temporal kernel size of every CausalLayer.
        dilations: per-layer dilation factors; ``None`` means 1, 2, 4, ..., 512.
    """

    def __init__(self, in_channels=1, channels=16, kernel_size=3, dilations=None):
        super().__init__()
        if dilations is None:
            dilations = [2 ** i for i in range(10)]
        self.dilations = dilations
        # Submodules are created in this order (input, layers, output) so
        # parameter initialisation and state_dict keys match saved checkpoints.
        self.input_proj = nn.Conv1d(in_channels, channels, 1)
        self.blocks = nn.ModuleList(
            CausalLayer(channels, kernel_size, dilation) for dilation in self.dilations
        )
        self.output_proj = nn.Conv1d(channels, 1, 1)

    def forward(self, x):
        """Map ``x`` to a single-channel signal.

        ``x`` is presumably ``(batch, in_channels, time)`` or unbatched
        ``(in_channels, time)``, i.e. whatever ``nn.Conv1d`` accepts. The
        notebook passes ``(1, time)``.
        """
        hidden = self.input_proj(x)
        # Running total of skip outputs; starting from 0 mirrors Python's sum().
        skip_total = 0
        for layer in self.blocks:
            hidden, skip = layer(hidden)
            skip_total = skip_total + skip
        return self.output_proj(skip_total)

In [95]:
# Load the dry (clean) take and the amp reference.
# NOTE: despite the `_quant` suffix, these are float waveforms from
# preprocess_audio, not mu-law codes. The names are kept because later cells
# use them.
sr, clean_quant = preprocess_audio("riff_clean.wav")
sr_amp, amp_quant = preprocess_audio("scale_amp.wav")
# `sr` is used later to write the model output for the clean take, so it must
# stay the clean file's rate; it used to be overwritten by the amp file's rate.
if sr_amp != sr:
    print(f"WARNING: sample-rate mismatch: riff_clean.wav is {sr} Hz, scale_amp.wav is {sr_amp} Hz")

clean_quant = torch.from_numpy(clean_quant.copy()).to(torch.float32)
amp_quant = torch.from_numpy(amp_quant.copy()).to(torch.float32)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = WaveNet().to(device)

# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code. Only load trusted checkpoints; newer PyTorch supports weights_only=True.
state_dict = torch.load("model.pth", map_location=device)
model_dict = model.state_dict()

# Keep only checkpoint tensors whose name and shape match the current model,
# and report what was dropped. load_state_dict below always says "All keys
# matched" because model_dict is built from the model itself.
filtered_state_dict = {
    k: v for k, v in state_dict.items()
    if k in model_dict and v.size() == model_dict[k].size()
}
skipped_keys = sorted(set(state_dict) - set(filtered_state_dict))
untrained_keys = sorted(set(model_dict) - set(filtered_state_dict))
if skipped_keys or untrained_keys:
    print(
        "WARNING: checkpoint only partially applied. "
        f"Ignored from checkpoint: {skipped_keys}; left at random init: {untrained_keys}"
    )

model_dict.update(filtered_state_dict)
model.eval()
model.load_state_dict(model_dict)



  sr, waveform = wavfile.read(path)


<All keys matched successfully>

In [96]:
# Inference on the full clean take; no gradients are needed.
with torch.no_grad():
    res = model(clean_quant.unsqueeze(0).to(device))

# Convert to 16-bit PCM. The output layer is an unbounded 1x1 conv, so clip to
# [-1, 1] first. Out-of-range floats are not saturated by the int16 cast and
# would turn into garbage samples (loud clicks).
# NOTE(review): the output name "riff_clean.wav.wav" looks like a typo and sits
# next to the input file; confirm the intended path.
pred_pcm = (np.clip(res.squeeze(0).cpu().numpy(), -1.0, 1.0) * 32767).astype(np.int16)
write("riff_clean.wav.wav", sr, pred_pcm)


In [None]:
import matplotlib.pyplot as plt

# `q` used to come from `mu_law_decode(amp_quant)`, but that line is commented
# out and `mu_law_decode` is not defined, so this cell raised NameError on a
# fresh kernel. preprocess_audio already yields float samples, so plot them
# directly.
# NOTE(review): confirm amp_quant (not clean_quant or the model output `res`)
# is the signal meant to be shown here.
q = amp_quant.detach().cpu().numpy()
print(q.shape)

fig, ax = plt.subplots(figsize=(12, 4))
ax.plot(q, color='blue')
ax.set_title("Amp Reference Waveform")
ax.set_xlabel("Sample Index")
ax.set_ylabel("Amplitude")
ax.grid(True)
plt.show()
