In [None]:
import torch
from wavenet_model import *
from audio_data import WavenetDataset
from wavenet_training import *
from model_logging import * # Or TensorboardLogger if using TensorBoard

In [None]:
# Choose the tensor types for data (float) and labels (long) according to
# whether CUDA is available on this machine.
use_cuda = torch.cuda.is_available()

if use_cuda:
    print('Using GPU')
    dtype, ltype = torch.cuda.FloatTensor, torch.cuda.LongTensor
else:
    print('Using CPU')
    dtype, ltype = torch.FloatTensor, torch.LongTensor

In [None]:
# Hyperparameters for the network, gathered in one mapping so they are easy
# to inspect and tune before the model is built.
wavenet_config = dict(
    layers=10,
    blocks=3,
    dilation_channels=32,
    residual_channels=32,
    skip_channels=1024,
    end_channels=512,
    output_length=16,  # adjust based on dataset/task
    dtype=dtype,
    bias=True,
)
model = WaveNetModel(**wavenet_config)

# When CUDA was detected above, move the model to the GPU.
if use_cuda:
    model.cuda()

# Short summary of the freshly built model.
print('Model initialized:')
print('Receptive field:', model.receptive_field)
print('Parameter count:', model.parameter_count())

In [None]:
# Check that the dataset file and the audio folder referenced below exist
# before running this cell.
# The item length follows the usual rule for this model:
# receptive_field + output_length - 1.
item_length = model.receptive_field + model.output_length - 1

data = WavenetDataset(
    dataset_file='train_samples/bach_chaconne/dataset.npz',
    item_length=item_length,
    target_length=model.output_length,
    file_location='train_samples/bach_chaconne',
    test_stride=500,  # test_stride splits data for testing
)
print(f'Dataset loaded with {len(data)} items')

In [None]:
# Standard logger; switch to the TensorBoard variant if that is preferred.
# Both intervals can be adjusted as needed.
logger = Logger(log_interval=200, validation_interval=400)

# Everything the trainer needs, spelled out in one place.
trainer_settings = {
    'model': model,
    'dataset': data,
    'lr': 0.001,                        # learning rate
    'snapshot_path': 'snapshots',       # directory for model checkpoints
    'snapshot_name': 'chaconne_model',  # base name for checkpoints
    'snapshot_interval': 1000,          # how often to save checkpoints
    'logger': logger,
    'dtype': dtype,
    'ltype': ltype,
}
trainer = WavenetTrainer(**trainer_settings)

In [None]:
# Run training for two epochs; reduce batch_size if GPU memory is tight.
print('Starting training...')
trainer.train(epochs=2, batch_size=16)
print('Training finished.')

In [None]:
# Generate audio from the trained model and play it back in the notebook.
import IPython.display as ipd

# Optional: Load the best/latest model from snapshots if not continuing directly
# model = load_latest_model_from('snapshots', use_cuda=use_cuda) # Ensure model is on the correct device (CPU/GPU)

# Prepare starting data if needed (example uses data from the dataset)
# start_data = data[some_index][0]
# start_data = torch.max(start_data, 0)[1] # Convert one-hot to integers if necessary
# if use_cuda:
#     start_data = start_data.cuda()

# Playback/export rate in Hz; kept in one variable so generation length and
# playback stay consistent.
sample_rate = 16000
num_samples_to_generate = sample_rate  # Example: 1 second at 16kHz
print(f'Generating {num_samples_to_generate} samples...')

generated_audio = model.generate_fast(num_samples=num_samples_to_generate,
                                      # first_samples=start_data, # Optional starting sequence
                                      temperature=1.0) # Temperature for sampling diversity
print('Sample generation complete.')

# NOTE(review): the return type of generate_fast is not visible from this
# notebook. Some WaveNet implementations return a NumPy array rather than a
# tensor, in which case calling .cpu() would raise AttributeError. Handle both:
# tensors are detached (numpy() refuses tensors that require grad) and copied
# to host memory; anything else is passed through unchanged.
if torch.is_tensor(generated_audio):
    generated_waveform = generated_audio.detach().cpu().numpy()
else:
    generated_waveform = generated_audio

# You can then save the generated audio, e.g. with soundfile:
# import soundfile as sf
# sf.write('generated_sample.wav', generated_waveform, sample_rate)

# Last expression of the cell, so the notebook renders an audio player.
ipd.Audio(generated_waveform, rate=sample_rate)