In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import librosa
import soundfile as sf
import IPython.display as ipd
import os
import time
from tqdm.notebook import tqdm

In [None]:
# Limit PyTorch's intra-op CPU parallelism to 10 threads.
torch.set_num_threads(10)

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
# `device` is read as a notebook-wide global by the cells below.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

In [None]:
class AudioDS(Dataset):
    """Sliding-window next-sample prediction dataset built from one audio file.

    The file is loaded as mono at ``sample_rate``, clipped to [-1, 1] and
    mapped linearly onto the integer levels 0..255 (the float-to-int cast
    truncates). Item ``idx`` is a pair ``(x, y)`` of ``sequence_length``
    consecutive levels, where ``y`` is ``x`` shifted one sample into the future.
    """

    def __init__(self, audio_file, sample_rate=16000, sequence_length=800):
        waveform, _ = librosa.load(audio_file, sr=sample_rate, mono=True)

        # Keep every sample inside the range the quantiser expects.
        waveform = np.clip(waveform, -1.0, 1.0)

        # [-1, 1] -> [0, 255]; astype() drops the fractional part.
        levels = (waveform + 1) / 2 * 255
        self.audio = levels.astype(np.int64)
        self.sequence_length = sequence_length

    def __len__(self):
        # One window per start position that still leaves room for the
        # one-step-shifted target; never negative for very short clips.
        n_windows = len(self.audio) - self.sequence_length
        return n_windows if n_windows > 0 else 0

    def __getitem__(self, idx):
        stop = idx + self.sequence_length
        inputs = self.audio[idx:stop]
        targets = self.audio[idx + 1:stop + 1]

        return torch.LongTensor(inputs), torch.LongTensor(targets)

class WaveNet(nn.Module):
    """Small WaveNet-style autoregressive model over 256 sample classes.

    Input is a ``(batch, time)`` tensor of integer classes in ``[0, 255]``;
    output is a ``(batch, 256, time)`` tensor of logits. Each layer is a
    dilated convolution (dilation ``2 ** i``) that is padded on both sides and
    then truncated to the input length, so the output at position ``t`` only
    depends on inputs at positions ``<= t``.

    Args:
        n_layers: Number of gated dilated layers (must be >= 1).
        channels: Width of the embedding and of every residual/skip path.
        kernel_size: Kernel size of the dilated convolutions.

    Attributes:
        receptive_field: Number of most-recent input samples that can
            influence a single output position.

    Raises:
        ValueError: If ``n_layers`` is smaller than 1.
    """

    def __init__(self, n_layers=3, channels=8, kernel_size=2):
        super().__init__()

        if n_layers < 1:
            # With no layers forward() has no skip connections to sum.
            raise ValueError(f"n_layers must be >= 1, got {n_layers}")

        self.embedding = nn.Embedding(256, channels)

        self.dilated_convs = nn.ModuleList()
        self.residual_convs = nn.ModuleList()
        self.skip_convs = nn.ModuleList()

        # Each layer lets the output look (kernel_size - 1) * dilation samples
        # further into the past; the current sample itself accounts for the 1.
        receptive_field = 1

        for i in range(n_layers):
            dilation = 2 ** i
            padding = (kernel_size - 1) * dilation
            receptive_field += padding

            # Twice the channels: one half is the filter, the other the gate.
            self.dilated_convs.append(
                nn.Conv1d(channels, channels * 2, kernel_size,
                          padding=padding, dilation=dilation)
            )
            self.residual_convs.append(nn.Conv1d(channels, channels, 1))
            self.skip_convs.append(nn.Conv1d(channels, channels, 1))

        self.receptive_field = receptive_field
        self.output_conv = nn.Conv1d(channels, 256, 1)

    def forward(self, x):
        """Return ``(batch, 256, time)`` logits for ``(batch, time)`` class indices."""
        # (batch, time) -> (batch, time, channels) -> (batch, channels, time)
        x = self.embedding(x).transpose(1, 2)

        skip_connections = 0

        for dilated, residual_conv, skip_conv in zip(
            self.dilated_convs, self.residual_convs, self.skip_convs
        ):
            residual = x

            # The convolution pads both ends; keeping only the first `time`
            # outputs discards the positions that would see future samples.
            d = dilated(x)[:, :, :x.size(2)]

            filter_gate, gate = torch.chunk(d, 2, dim=1)
            gated = torch.tanh(filter_gate) * torch.sigmoid(gate)

            x = residual + residual_conv(gated)
            skip_connections = skip_connections + skip_conv(gated)

        return self.output_conv(F.relu(skip_connections))

    def generate(self, num_samples=1000, initial_samples=None, temperature=1.0):
        """Sample ``num_samples`` new classes autoregressively.

        Args:
            num_samples: Number of samples to append to the seed.
            initial_samples: Optional ``(batch, time)`` long tensor used as the
                seed. Defaults to 16 samples of class 0 (the lowest
                quantisation level, not the mid-scale level).
            temperature: Positive divisor applied to the logits before
                sampling; values below 1 sharpen the distribution.

        Returns:
            A numpy array holding the seed followed by the newly generated
            samples, with the batch dimension removed when it has size 1.

        Raises:
            ValueError: If ``temperature`` is not strictly positive.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")

        # Follow the model's own device rather than a notebook-level global,
        # so generation works wherever the weights currently live.
        model_device = next(self.parameters()).device

        if initial_samples is None:
            seed = torch.zeros(1, 16, dtype=torch.long, device=model_device)
        else:
            seed = initial_samples.to(model_device)

        generated = seed.clone()

        # Feeding fewer samples than the receptive field would silently hide
        # context from the deeper layers; feeding more is wasted work.
        context = self.receptive_field

        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for _ in range(num_samples):
                    output = self(generated[:, -context:])

                    # Only the final time step predicts the next sample.
                    logits = output[:, :, -1] / temperature
                    probs = F.softmax(logits, dim=1)
                    next_sample = torch.multinomial(probs, 1)

                    generated = torch.cat((generated, next_sample), dim=1)
        finally:
            # Leave the module in whatever mode the caller had it in.
            self.train(was_training)

        # squeeze(0) drops only the batch axis, so a length-1 result stays 1-D.
        return generated.squeeze(0).cpu().numpy()

def create_test_audio(filename="test_audio.wav", duration=2, sample_rate=16000):
    """Make sure a synthetic test clip exists at ``filename`` and return its path.

    An existing file is left untouched. Otherwise a ``duration``-second clip is
    written at ``sample_rate``: a 440 Hz sine (amplitude 0.5) plus an 880 Hz
    sine (amplitude 0.3).
    """
    if not os.path.exists(filename):
        n_samples = int(sample_rate * duration)
        times = np.linspace(0, duration, n_samples, endpoint=False)

        fundamental = 0.5 * np.sin(2 * np.pi * 440 * times)
        overtone = 0.3 * np.sin(2 * np.pi * 880 * times)

        sf.write(filename, fundamental + overtone, sample_rate)

    return filename

def train_model(model, dataloader, epochs=2, learning_rate=0.001, checkpoint_path="wavenet_checkpoint.pt"):
    """Train ``model`` with Adam and cross-entropy, resuming from a checkpoint if present.

    Args:
        model: Network mapping ``(batch, time)`` class indices to
            ``(batch, classes, time)`` logits.
        dataloader: Iterable of ``(x, y)`` batches of class indices.
        epochs: Total number of epochs to reach, counting any epochs already
            recorded in the checkpoint.
        learning_rate: Adam learning rate.
        checkpoint_path: Rolling checkpoint file. It is loaded at start-up when
            it exists, rewritten every 500 batches and again at the end of
            every epoch. A per-epoch snapshot is also written to
            ``wavenet_epoch_<n>.pt``.

    Returns:
        The trained model, on the notebook-level ``device``.
    """
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    start_epoch = 0
    if os.path.exists(checkpoint_path):
        # map_location lets a checkpoint written on a GPU be restored on a
        # CPU-only machine (and vice versa).
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch']
        print(f"Resuming from epoch {start_epoch}")

    for epoch in range(start_epoch, epochs):
        model.train()
        total_loss = 0.0
        batches_seen = 0

        for i, (x, y) in enumerate(tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")):
            x, y = x.to(device), y.to(device)

            # (batch, classes, time) -> (batch * time, classes) so that every
            # time step counts as one classification example.
            output = model(x).transpose(1, 2)
            output = output.reshape(-1, output.size(-1))
            y = y.reshape(-1)

            loss = criterion(output, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            batches_seen += 1

            if i % 500 == 0 and i > 0:
                # Mid-epoch checkpoint: records the current epoch index, so a
                # resume restarts this epoch with the weights saved here.
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': loss.item(),
                }, checkpoint_path)

        # An empty dataloader would otherwise raise ZeroDivisionError here.
        mean_loss = total_loss / batches_seen if batches_seen else float("nan")
        print(f"Epoch {epoch+1}/{epochs}, Loss: {mean_loss:.4f}")

        epoch_state = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': mean_loss,
        }
        torch.save(epoch_state, f"wavenet_epoch_{epoch+1}.pt")
        # Also refresh the rolling checkpoint; without this a resume would
        # repeat an epoch that already finished (or never resume at all when
        # an epoch has fewer than 501 batches).
        torch.save(epoch_state, checkpoint_path)

    return model

def generate_audio(model, length=4000, temperature=1.0, sample_rate=16000, filename="generated_audio.wav"):
    """Sample audio from ``model``, plot it, write it to ``filename`` and return a player.

    Args:
        model: Object whose ``generate(num_samples=..., temperature=...)``
            returns an array of integer levels in 0..255, with its seed
            samples first and the newly generated samples last.
        length: Number of new samples to generate and keep.
        temperature: Sampling temperature forwarded to ``model.generate``.
        sample_rate: Sample rate used for the written file and the player.
        filename: Path of the audio file to write.

    Returns:
        An ``IPython.display.Audio`` object for the generated waveform.
    """
    print("Generating audio samples...")
    generated = model.generate(num_samples=length, temperature=temperature)

    # generate() returns its seed followed by the new samples. The seed is
    # not model output, and the default all-zero seed decodes to amplitude
    # -1.0, so drop it before plotting and writing.
    seed_len = max(0, generated.shape[-1] - length)
    generated = generated[..., seed_len:]

    fig, ax = plt.subplots(figsize=(10, 4))
    ax.hist(generated, bins=50)
    ax.set_title("Distribution of Generated Values")
    ax.set_xlabel("Value")
    ax.set_ylabel("Frequency")
    plt.show()

    # Map integer levels 0..255 back onto amplitudes in [-1, 1].
    audio = (generated.astype(float) / 255) * 2 - 1

    sf.write(filename, audio, sample_rate)

    fig, ax = plt.subplots(figsize=(10, 4))
    ax.plot(audio)
    ax.set_title("Generated Audio")
    ax.set_xlabel("Sample")
    ax.set_ylabel("Amplitude")
    plt.show()

    return ipd.Audio(audio, rate=sample_rate)

In [None]:
if __name__ == "__main__":
    # Train on the first audio file found in ./ds, or on a synthetic clip when
    # the directory is missing or holds no .mp3/.wav files.
    data_dir = 'ds'
    audio_exts = ('.mp3', '.wav')

    # os.listdir raises FileNotFoundError for a missing directory, which would
    # make the synthetic fallback unreachable. Sorting makes the chosen file
    # independent of the arbitrary directory listing order.
    if os.path.isdir(data_dir):
        audio_files = sorted(
            f for f in os.listdir(data_dir) if f.lower().endswith(audio_exts)
        )
    else:
        audio_files = []

    if audio_files:
        audio_file = os.path.join(data_dir, audio_files[0])
    else:
        audio_file = create_test_audio(duration=2)

    dataset = AudioDS(audio_file, sequence_length=400)

    dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

    model = WaveNet(n_layers=2, channels=8)
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    model = train_model(model, dataloader, epochs=2, learning_rate=0.001)

    # Weights-only snapshot of the finished model.
    torch.save({
        'model_state_dict': model.state_dict(),
    }, "wavenet_final.pt")

    generated_audio = generate_audio(model, length=4000, temperature=0.9)
    # An assignment inside an `if` block is not auto-rendered by the notebook,
    # so show the audio player explicitly.
    ipd.display(generated_audio)

    print("opt complete")
    print("'generated_audio.wav file created'")
