In [None]:
# import dependencies
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch
import math

# cuDNN: make sure it is used, and let it autotune the convolution algorithm
# per input shape (benchmark mode; can trade bit-exact reproducibility for speed).
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

In [None]:
import numpy as np
from scipy.io import wavfile
from scipy.signal import butter, filtfilt

def preprocess_audio(path, target_rms=0.1, mu_law=False, peak_limit=None):
    """Load a WAV file and normalize it for the model.

    Pipeline: downmix to mono -> remove mean -> peak-normalize to 0.99 ->
    zero-phase first-order 20 Hz high-pass -> scale to ``target_rms`` ->
    optional peak limiting -> optional mu-law encoding.

    Args:
        path: WAV file path, read with ``scipy.io.wavfile.read``.
        target_rms: RMS level the signal is scaled to.
        mu_law: if True, return ``mu_law_encode(waveform)`` instead of floats.
            ``mu_law_encode`` is defined in a later notebook cell, so that cell
            must have been executed before calling with ``mu_law=True``.
        peak_limit: optional positive ceiling for ``max(|waveform|)``. RMS
            scaling does not bound the peak: a signal whose crest factor
            (peak / RMS) exceeds ``1 / target_rms`` ends up outside [-1, 1],
            where ``mu_law_encode`` clips it and an int16 cast overflows.
            When set and exceeded, the whole signal is scaled down so its peak
            equals ``peak_limit`` (its RMS then drops below ``target_rms``).
            ``None`` (default) keeps pure RMS normalization.

    Returns:
        ``(sr, waveform)``: the file's sample rate and a float32 waveform, or
        whatever ``mu_law_encode`` returns when ``mu_law`` is True.

    Raises:
        ValueError: if ``peak_limit`` is given and is not positive.
    """
    if peak_limit is not None and peak_limit <= 0:
        raise ValueError(f"peak_limit must be positive, got {peak_limit}")

    sr, waveform = wavfile.read(path)
    if waveform.ndim > 1:
        waveform = waveform.mean(axis=1)  # downmix channels to mono
    waveform = waveform.astype(np.float32)
    waveform -= np.mean(waveform)

    # Peak-normalize to 0.99 of full scale. The RMS step below rescales the
    # signal again, so this alone does not keep the final output in range.
    waveform /= (np.max(np.abs(waveform)) + 1e-7)
    waveform *= 0.99

    # High-pass at 20 Hz to remove DC / subsonic content; filtfilt runs the
    # filter forward and backward, so no phase shift is introduced.
    b, a = butter(1, 20 / (sr / 2), btype='highpass')
    waveform = filtfilt(b, a, waveform)  # filtfilt returns float64

    # RMS normalization
    rms = np.sqrt(np.mean(waveform ** 2))
    waveform *= target_rms / (rms + 1e-9)

    if peak_limit is not None:
        peak = np.max(np.abs(waveform))
        if peak > peak_limit:
            waveform *= peak_limit / peak

    if mu_law:
        # Encode at full float64 precision.
        return sr, mu_law_encode(waveform)

    # Cast back so the float output really is float32, as the early cast intends.
    return sr, waveform.astype(np.float32)


In [None]:
# Load
import pickle
from scipy.io import wavfile


class CausalConv1d(nn.Module):
    """Dilated 1-D convolution whose output at step t only sees inputs at steps <= t.

    The wrapped ``nn.Conv1d`` pads ``(kernel_size - 1) * dilation`` zeros on
    both ends; dropping that many trailing outputs leaves a result the same
    length as the input, with no future samples leaking in.
    """
    def __init__(self, in_channels, out_channels, kernel_size, dilation=1):
        super().__init__()
        self.pad = dilation * (kernel_size - 1)
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            dilation=dilation,
            padding=self.pad,
        )

    def forward(self, x):
        y = self.conv(x)
        if not self.pad:
            return y
        # Discard the outputs produced by the right-hand padding.
        keep = y.shape[-1] - self.pad
        return y[:, :, :keep]

class ResidualBlock(nn.Module):
    """Gated residual layer.

    Computes ``z = tanh(filter_conv(x)) * sigmoid(gate_conv(x))`` and returns
    ``(residual_conv(z) + x, skip_conv(z))``: the first goes to the next layer,
    the second is summed into the network's skip path.
    """
    def __init__(self, channels, kernel_size, dilation):
        super().__init__()
        # Registration order is kept as-is so parameter / state_dict order is unchanged.
        self.filter_conv = CausalConv1d(channels, channels, kernel_size, dilation)
        self.gate_conv = CausalConv1d(channels, channels, kernel_size, dilation)
        self.residual_conv = nn.Conv1d(channels, channels, 1)
        self.skip_conv = nn.Conv1d(channels, channels, 1)

    def forward(self, x):
        filtered = torch.tanh(self.filter_conv(x))
        gate = torch.sigmoid(self.gate_conv(x))
        z = filtered * gate
        return self.residual_conv(z) + x, self.skip_conv(z)

class WaveNet(nn.Module):
    """WaveNet-style stack: 1x1 input conv, gated residual blocks, summed skips, 1x1 output head.

    Args:
        in_channels: number of input channels (1 for mono audio).
        channels: width of the residual and skip paths.
        kernel_size: kernel size of each dilated causal conv.
        num_blocks: number of times the ``dilations`` sequence is repeated.
        dilations: dilation per residual layer; defaults to 1, 2, 4, ..., 512.

    ``forward`` takes ``nn.Conv1d``-style input (batch, in_channels, time) and
    returns (batch, 1, time). The first ``receptive_field - 1`` outputs depend
    on the implicit causal zero padding rather than on real input history.
    """
    def __init__(self, in_channels=1, channels=16, kernel_size=3, num_blocks=1, dilations=None):
        super().__init__()
        self.causal_in = CausalConv1d(in_channels, channels, kernel_size=1)
        self.dilations = dilations if dilations is not None else [2 ** i for i in range(10)]
        self.blocks = nn.ModuleList([
            ResidualBlock(channels, kernel_size, d)
            for _ in range(num_blocks)
            for d in self.dilations
        ])
        self.relu = nn.ReLU()
        self.out1 = nn.Conv1d(channels, channels, 1)
        self.out2 = nn.Conv1d(channels, 1, 1)

    def forward(self, x):
        x = self.causal_in(x)
        skip_sum = None
        for block in self.blocks:
            x, skip = block(x)
            skip_sum = skip if skip_sum is None else skip_sum + skip
        if skip_sum is None:
            # Previously this fell through to relu(0) and failed with an opaque TypeError.
            raise ValueError("WaveNet has no residual blocks (empty dilations or num_blocks=0)")
        out = self.relu(skip_sum)
        out = self.relu(self.out1(out))
        out = self.out2(out)
        return out

    @property
    def receptive_field(self):
        """Number of input samples that influence one output sample.

        Each causal conv widens the window by ``(kernel_size - 1) * dilation``;
        1x1 convs add nothing. The value is read from the conv modules actually
        present, so it accounts for any ``kernel_size`` and for all
        ``num_blocks`` repetitions of ``dilations``. It also works on instances
        unpickled from model.pkl, which store no ``kernel_size`` attribute.
        For the default configuration it is 2047.
        """
        def span(causal):
            conv = causal.conv
            return (conv.kernel_size[0] - 1) * conv.dilation[0]

        # filter_conv and gate_conv run in parallel with identical geometry,
        # so only one of them is counted per block.
        return 1 + span(self.causal_in) + sum(span(block.filter_conv) for block in self.blocks)

class AmpDatasetVectorized(torch.utils.data.Dataset):
    """
    Single-item dataset holding one whole (clean, amp) recording pair:
    - Feed entire waveform to model at once
    - No Python slicing loops
    - The target drops its first ``rf - 1`` samples (rf = model.receptive_field),
      because the matching outputs depend on causal zero padding; the caller
      must trim the model output the same way, e.g. ``pred[..., rf - 1:]``.

    Args:
        clean_wave: input waveform, converted with ``torch.tensor(..., float32)``
            and given a leading axis (presumably 1-D, giving (1, L)).
        amp_wave: target waveform, converted the same way.
        model: any object exposing ``receptive_field``; nothing else is read.
    """
    def __init__(self, clean_wave, amp_wave, model: "WaveNet"):
        self.x = torch.tensor(clean_wave, dtype=torch.float32).unsqueeze(0)  # (1, L)
        self.y = torch.tensor(amp_wave, dtype=torch.float32).unsqueeze(0)    # (1, L)
        self.rf = model.receptive_field
        assert self.x.shape[-1] >= self.rf, "Waveform too short for receptive field"
        # The trimmed target is identical on every access, so slice it once.
        self.y_target = self.y[:, self.rf - 1:]

    def __len__(self):
        return 1  # Entire waveform in one pass

    def __getitem__(self, idx):
        # torch's map-style Dataset defines no __iter__, so `for item in ds`
        # falls back to the __getitem__ protocol, which stops only on
        # IndexError; without this check plain iteration never ends.
        if idx not in (0, -1):
            raise IndexError(f"AmpDatasetVectorized has a single item; got index {idx}")
        return self.x, self.y_target


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# SECURITY: weights_only=False unpickles arbitrary Python objects and can execute
# code, so only load model.pkl from a trusted source. Unpickling a whole model
# also needs its classes (defined above in this cell) to exist under the same names.
model = torch.load("model.pkl", weights_only=False, map_location="cpu").to(device)
model.eval()  # inference mode

sr, waveform = wavfile.read("riff_clean.wav")
if waveform.ndim > 1:
    waveform = waveform.mean(axis=1)  # downmix to mono

waveform = waveform.astype(np.float32)
# Peak-normalize to [-1, 1]; skip for an all-zero file to avoid 0/0 -> NaN.
peak = np.abs(waveform).max()
if peak > 0:
    waveform /= peak
# clean_quant = mu_law_encode(waveform)

# with torch.no_grad():
#     # x: 1D tensor of input waveform
#     x = torch.tensor(waveform, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)  # (1, L)
    
#     # forward pass
#     y_pred = model(x)[:, :, model.receptive_field - 1:]  # remove initial causal padding
    
#     y_pred = y_pred.squeeze(0).cpu()  # (L - rf + 1,)

# # q = mu_law_decode(clean_quant)
# q = y_pred.squeeze(0).numpy()

# from scipy.io.wavfile import write
# write("outpu11.wav", sr, (q * 32767).astype(np.int16))

In [None]:
MU = 2048  # mu-law constant; encoded codes span 0..MU inclusive, i.e. MU + 1 levels

def mu_law_encode(x, mu=MU):
    """Mu-law compand ``x`` and quantize it to integer codes.

    ``x`` is clipped to [-1, 1], companded as
    ``sign(x) * log1p(mu * |x|) / log1p(mu)``, mapped from [-1, 1] to
    [0, mu] and rounded. So -1 -> 0, 0 -> round(mu / 2), +1 -> mu: there are
    ``mu + 1`` possible codes, not ``mu``.

    Returns:
        int64 code(s) (an array for array input).
    """
    clipped = np.clip(x, -1.0, 1.0)
    companded = np.sign(clipped) * (np.log1p(mu * np.abs(clipped)) / np.log1p(mu))
    scaled = (companded + 1) / 2 * mu
    return np.round(scaled).astype(np.int64)
def mu_law_decode(encoded, mu=MU):
    """Invert ``mu_law_encode``: map integer codes in [0, mu] back to [-1, 1].

    Codes are mapped linearly to [-1, 1] (0 -> -1, mu -> +1) and expanded with
    ``sign(y) * ((1 + mu) ** |y| - 1) / mu``, the inverse of the encoder's
    companding, so a round trip is exact up to quantization error.

    Args:
        encoded: integer code(s); any array-like is accepted (lists included),
            not just ndarrays, since it is passed through ``np.asarray`` first.
        mu: mu-law constant; must match the one used for encoding.

    Returns:
        float32 value(s) in [-1, 1].
    """
    codes = np.asarray(encoded)
    x = (codes.astype(np.float32) / mu) * 2 - 1
    sign = np.sign(x)
    mag = (1 / mu) * ((1 + mu) ** np.abs(x) - 1)
    return sign * mag

In [13]:
from scipy.io import wavfile

# mu_law=False here, so despite the *_quant names these are RMS-normalized
# float waveforms, not mu-law codes.
sr_clean, clean_quant = preprocess_audio("scale_clean.wav")
sr_amp, amp_quant = preprocess_audio("scale_amp.wav")
# The clean/amp pair is compared sample-for-sample, so the rates must match.
if sr_clean != sr_amp:
    raise ValueError(f"Sample-rate mismatch: clean={sr_clean} Hz, amp={sr_amp} Hz")
sr = sr_clean

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# q = mu_law_decode(amp_quant)
from scipy.io.wavfile import write
# RMS normalization does not bound the peak, so clip to [-1, 1] first: float
# values outside the int16 range do not saturate when cast with astype.
write("outpu11.wav", sr, (np.clip(clean_quant, -1.0, 1.0) * 32767).astype(np.int16))

  sr, waveform = wavfile.read(path)


In [None]:
import torch
import matplotlib.pyplot as plt

# `q` was only ever assigned in commented-out code, so this cell raised
# NameError on a fresh kernel. Plot the preprocessed clean signal from the
# cell above instead, which is what the title refers to.
signal_to_plot = clean_quant
print(signal_to_plot.shape)

fig, ax = plt.subplots(figsize=(12, 4))
ax.plot(signal_to_plot, color='blue')
ax.set_title("Clean Tensor Waveform")
ax.set_xlabel("Sample Index")
ax.set_ylabel("Amplitude")  # preprocess_audio ran with mu_law=False: floats, not codes
ax.grid(True)
plt.show()
