<a href="https://colab.research.google.com/github/IRPARKS/NMML/blob/main/NMMLHW13.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

1. **Importing Libraries**:
   - The code starts by importing the necessary libraries: `torch` (PyTorch), `torch.nn.functional` (neural network functions such as the cross-entropy loss), `DataLoader` (batch data loading), and the WaveNet-specific modules from the cloned `pytorch-wavenet` repository (`WavenetDataset`, `WaveNetModel`, `WavenetTrainer`, `Logger`). These imports, like the relative dataset paths used later, assume the cloned repository is the working directory.

2. **Loading a Saved Model Checkpoint**:
   - The function `load_latest_model_from` loads model weights from the checkpoint file `best_model.pth` in a given `snapshot_path`. It reads the checkpoint onto the CPU and copies its model state dictionary (`model_state_dict`) into the model. The function is defined for later use but is not called in the training cell.

3. **Initializing CUDA**:
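   - The code checks whether CUDA (GPU) is available and sets `device` accordingly, so the model and each batch of data are moved to the GPU when one is present. It also selects the matching tensor types (`dtype` for floats, `ltype` for integer class labels), which the repository's model and trainer take as arguments.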

4. **WaveNet Model Initialization**:
   - The WaveNet model (`WaveNetModel`) is initialized with specified parameters such as the number of layers (`layers`), blocks (`blocks`), dilation channels (`dilation_channels`), residual channels (`residual_channels`), skip channels (`skip_channels`), end channels (`end_channels`), and output length (`output_length`). The model is then moved to the specified device (GPU if available).

5. **Initializing Dataset and DataLoader**:
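   - A custom `WavenetDataset` is initialized with the dataset parameters (`dataset_file`, `item_length`, `target_length`, `file_location`, `test_stride`). `item_length` is set to `model.receptive_field + model.output_length - 1`, which is the number of input samples needed to predict `output_length` consecutive target samples. The dataset is then wrapped in a `DataLoader` that serves shuffled batches of 16 items during training.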

6. **Initializing Logger**:
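   - A `Logger` object is initialized with intervals for logging training progress (`log_interval`), validation (`validation_interval`), and sample generation (`generate_interval`). The manual training loop uses only `log_interval`.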

7. **Initializing WavenetTrainer**:
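   - A `WavenetTrainer` object is initialized with the WaveNet model, the data loader, the learning rate (`lr`), the snapshot settings (`snapshot_path`, `snapshot_name`, `snapshot_interval`), and the logger. The trainer creates the optimizer. The notebook does not call the trainer's own training routine. Instead, it drives training with the explicit loop described next, using `trainer.optimizer` and `trainer.snapshot_interval`.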

8. **Training Loop**:
   - The training loop runs for a specified number of `epochs`. Inside each epoch, it iterates over batches from the dataset loader.
   - For each batch:
     - Data (`x`, `target`) is moved to the selected device (the GPU if available).
     - The model computes predictions (`output`) for the input (`x`).
     - The output is reshaped to one row of class scores per predicted sample, the targets are flattened to match, and the cross-entropy loss between them is computed.
     - Gradients from the previous step are cleared (`trainer.optimizer.zero_grad()`), and backpropagation (`loss.backward()`) computes new ones.
     - The optimizer (`trainer.optimizer`) updates the model parameters (`trainer.optimizer.step()`).
     - Training progress (including loss) is printed at specified logging intervals (`logger.log_interval`).
   - Model checkpoints are saved periodically, as controlled by `trainer.snapshot_interval`.

9. **Saving Model Checkpoints**:
   - At specified intervals (`trainer.snapshot_interval`), the current model state (including epoch number, model state dictionary, optimizer state, and loss) is saved to a file in the specified `snapshot_path`.
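
Item 5 sets `item_length` to `receptive_field + output_length - 1`. Each of the `output_length` predictions needs `receptive_field` samples of context, and consecutive windows overlap by all but one sample. The sketch below does that arithmetic for the notebook's `layers=3`, `blocks=2`. It assumes a kernel size of 2 (the model's default, not set in the notebook) and dilations 1, 2, 4, ... that restart in every block. The helper `wavenet_receptive_field` is hypothetical; the notebook itself reads the real value from `model.receptive_field`.

```python
def wavenet_receptive_field(layers: int, blocks: int, kernel_size: int = 2) -> int:
    """Receptive field of stacked dilated causal convolutions (hypothetical helper).

    Assumes dilations 1, 2, 4, ..., 2**(layers - 1) within each block,
    restarting at 1 for every block.
    """
    field = 1
    for _ in range(blocks):
        for i in range(layers):
            # A kernel of size k with dilation d widens the field by (k - 1) * d samples
            field += (kernel_size - 1) * 2 ** i
    return field


layers, blocks, output_length = 3, 2, 4  # values used in the training cell
receptive_field = wavenet_receptive_field(layers, blocks)
item_length = receptive_field + output_length - 1
print(receptive_field, item_length)  # 15 18
```

Under these assumptions each block adds 1 + 2 + 4 = 7 samples, so two blocks give a receptive field of 15 and an item length of 18.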


In [1]:
!git clone https://github.com/Vichoko/pytorch-wavenet.git

Cloning into 'pytorch-wavenet'...
remote: Enumerating objects: 1168, done.
remote: Counting objects: 100% (10/10), done.
remote: Compressing objects: 100% (7/7), done.
remote: Total 1168 (delta 3), reused 6 (delta 3), pack-reused 1158
Receiving objects: 100% (1168/1168), 268.95 MiB | 20.86 MiB/s, done.
Resolving deltas: 100% (720/720), done.
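
The loss step in the next cell reshapes the model output and the targets before calling `F.cross_entropy`. The sketch below reproduces that reshaping on random stand-in tensors, with no WaveNet code involved. It rests on three assumptions:
- There are 256 output classes (the model's default; the notebook does not set it).
- The model returns its scores with the class dimension last, one row per predicted sample.
- The rows are ordered by batch item and then by time step, which is the same order `target.view(-1)` produces.

The successful run is consistent with these assumptions, but the model code is not shown here.

```python
import torch
import torch.nn.functional as F

batch_size, output_length, classes = 16, 4, 256  # classes=256 is an assumption

# Stand-in for the model output: one row of class scores per predicted sample
output = torch.randn(batch_size * output_length, classes)

# Stand-in for the targets: one quantized sample index per predicted position
target = torch.randint(0, classes, (batch_size, output_length))

# Flattening the targets lines each one up with its row of scores
loss = F.cross_entropy(output.view(-1, output.size(-1)), target.view(-1))
print(output.view(-1, output.size(-1)).shape, target.view(-1).shape, loss.item())
```

`F.cross_entropy` expects an `(N, C)` score matrix and an `(N,)` vector of class indices. Here N is 16 × 4 = 64 predicted samples per batch.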


In [None]:
import os
import sys

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

# The WaveNet modules and the relative dataset paths below come from the
# cloned repository, so run from inside the clone
os.chdir('/content/pytorch-wavenet')
sys.path.insert(0, os.getcwd())

from audio_data import WavenetDataset
from wavenet_model import WaveNetModel
from wavenet_training import WavenetTrainer
from model_logging import Logger


# Function to load model weights from a saved checkpoint (defined for later use)
def load_latest_model_from(snapshot_path, model):
    # Checkpoint file expected inside snapshot_path
    checkpoint_file = f'{snapshot_path}/best_model.pth'

    # Load the checkpoint onto the CPU so this works with or without a GPU;
    # load_state_dict then copies the weights onto the model's own device
    try:
        checkpoint = torch.load(checkpoint_file, map_location=torch.device('cpu'))
    except FileNotFoundError:
        print(f'No checkpoint found at {checkpoint_file}; keeping current weights.')
        return model

    # Restore the model weights
    model.load_state_dict(checkpoint['model_state_dict'])

    return model


# Initialize CUDA if available
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
ltype = torch.cuda.LongTensor if use_cuda else torch.LongTensor

# Model parameters
layers = 3
blocks = 2
dilation_channels = 8
residual_channels = 8
skip_channels = 64
end_channels = 32
output_length = 4

# Initialize WaveNet model
model = WaveNetModel(layers=layers,
                     blocks=blocks,
                     dilation_channels=dilation_channels,
                     residual_channels=residual_channels,
                     skip_channels=skip_channels,
                     end_channels=end_channels,
                     output_length=output_length,
                     dtype=dtype,
                     bias=True).to(device)

# Dataset parameters
dataset_file = '/content/pytorch-wavenet/train_samples/bach_chaconne/dataset.npz'
item_length = model.receptive_field + model.output_length - 1
target_length = model.output_length
file_location = 'train_samples/bach_chaconne'
test_stride = 500

# Initialize WavenetDataset
data = WavenetDataset(dataset_file=dataset_file,
                      item_length=item_length,
                      target_length=target_length,
                      file_location=file_location,
                      test_stride=test_stride)

# DataLoader for batch processing
batch_size = 16
data_loader = DataLoader(data, batch_size=batch_size, shuffle=True)

# Initialize Logger for model training
logger = Logger(log_interval=200,
                validation_interval=400,
                generate_interval=1000)

# Directory for checkpoints (created if missing)
snapshot_path = 'snapshots'
os.makedirs(snapshot_path, exist_ok=True)

# Initialize WavenetTrainer
trainer = WavenetTrainer(model=model,
                          dataset=data_loader,
                          lr=0.001,
                          snapshot_path=snapshot_path,
                          snapshot_name='chaconne_model',
                          snapshot_interval=1000,
                          logger=logger,
                          dtype=dtype,
                          ltype=ltype)

# Training loop
epochs = 10
step = 0  # optimizer steps taken across all epochs
model.train()
for epoch in range(epochs):
    for batch_idx, (x, target) in enumerate(data_loader):
        x, target = x.to(device), target.to(device)

        # Forward pass: class scores for each predicted sample
        output = model(x)

        # Flatten the targets so each one lines up with a row of scores
        target = target.view(-1)

        # Compute loss
        loss = F.cross_entropy(output.view(-1, output.size(-1)), target)

        # Backward and optimize
        trainer.optimizer.zero_grad()
        loss.backward()
        trainer.optimizer.step()
        step += 1

        # Print training progress
        if batch_idx % logger.log_interval == 0:
            print(f"Epoch: {epoch + 1}, Batch: {batch_idx}, Loss: {loss.item()}")

        # Save a checkpoint every trainer.snapshot_interval optimizer steps
        if step % trainer.snapshot_interval == 0:
            torch.save({
                'epoch': epoch + 1,
                'step': step,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': trainer.optimizer.state_dict(),
                'loss': loss.item()  # plain float, detached from the graph
            }, f'{snapshot_path}/step_{step}.pth')


one hot input
Epoch: 1, Batch: 0, Loss: 5.517580986022949
Epoch: 1, Batch: 200, Loss: 5.147301197052002
Epoch: 1, Batch: 400, Loss: 5.076307773590088
Epoch: 1, Batch: 600, Loss: 4.77754020690918
Epoch: 1, Batch: 800, Loss: 5.2534918785095215
Epoch: 1, Batch: 1000, Loss: 4.678632736206055
Epoch: 1, Batch: 1200, Loss: 4.581977844238281
Epoch: 1, Batch: 1400, Loss: 4.732357501983643
Epoch: 1, Batch: 1600, Loss: 4.235000133514404
Epoch: 1, Batch: 1800, Loss: 4.090487480163574
Epoch: 1, Batch: 2000, Loss: 4.087838649749756
Epoch: 1, Batch: 2200, Loss: 4.087427139282227
Epoch: 1, Batch: 2400, Loss: 3.736981153488159
Epoch: 1, Batch: 2600, Loss: 4.095047473907471
Epoch: 1, Batch: 2800, Loss: 4.018249988555908
Epoch: 1, Batch: 3000, Loss: 4.022060394287109
Epoch: 1, Batch: 3200, Loss: 3.8930771350860596
Epoch: 1, Batch: 3400, Loss: 3.8599467277526855
Epoch: 1, Batch: 3600, Loss: 3.641538619995117
Epoch: 1, Batch: 3800, Loss: 3.9623162746429443
Epoch: 1, Batch: 4000, Loss: 3.993640422821045
Epo
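
The training loop saves a dictionary holding the epoch, the model state, the optimizer state, and the loss. Restoring both state dicts lets training resume with the optimizer's internal statistics intact, not just the weights. The minimal sketch below shows that save/resume round trip. It makes these substitutions:
- A tiny `nn.Linear` stands in for the WaveNet model.
- `torch.optim.Adam` stands in for the trainer's optimizer; its actual type is not shown in the notebook.
- A temporary directory replaces `snapshots`.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Stand-ins (assumptions): a tiny model and an Adam optimizer
model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# One training step so the optimizer has state worth saving
loss = model(torch.randn(8, 4)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

with tempfile.TemporaryDirectory() as snapshot_dir:
    path = os.path.join(snapshot_dir, 'step_1.pth')
    torch.save({
        'epoch': 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss.item(),  # store a float rather than a graph-attached tensor
    }, path)

    # Resume: rebuild the same architecture, then restore both state dicts
    resumed_model = nn.Linear(4, 2)
    resumed_optimizer = torch.optim.Adam(resumed_model.parameters(), lr=0.001)
    checkpoint = torch.load(path, map_location='cpu')
    resumed_model.load_state_dict(checkpoint['model_state_dict'])
    resumed_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
```

`load_latest_model_from` restores only `model_state_dict`. That is enough for inference, but resuming training also needs the optimizer state.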

Over the logged part of epoch 1, the loss falls from about 5.52 at batch 0 to between roughly 3.6 and 4.1 from batch 2400 onward, which indicates that the model is learning from the training data. Some logged values rise temporarily (for example, from 4.78 at batch 600 to 5.25 at batch 800). Each printed value is the loss on a single batch of 16 items, so such fluctuations are expected from minibatch sampling and do not by themselves point to a problem with the data. The overall downward trend is what shows the model improving during epoch 1.
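
Two checks put the logged numbers in context. First, an untrained model that spreads probability evenly over C classes has a cross-entropy of ln C. Assuming the model's default of 256 classes (not set in the notebook), that is ln 256 ≈ 5.545, close to the first logged value. Second, averaging over a window of logged values smooths out single-batch noise. The sketch below uses the 21 losses printed above for batches 0 to 4000. It gives a mean of about 5.15 over the first five and about 3.87 over the last five. The `moving_average` helper is written here for illustration.

```python
import math

# Loss values printed for epoch 1 at batches 0, 200, ..., 4000
logged_losses = [
    5.517580986022949, 5.147301197052002, 5.076307773590088, 4.77754020690918,
    5.2534918785095215, 4.678632736206055, 4.581977844238281, 4.732357501983643,
    4.235000133514404, 4.090487480163574, 4.087838649749756, 4.087427139282227,
    3.736981153488159, 4.095047473907471, 4.018249988555908, 4.022060394287109,
    3.8930771350860596, 3.8599467277526855, 3.641538619995117, 3.9623162746429443,
    3.993640422821045,
]

# Cross-entropy of a uniform guess over 256 classes (assumed class count)
uniform_baseline = math.log(256)


def moving_average(values, window):
    """Mean of each run of `window` consecutive values."""
    return [sum(values[i:i + window]) / window for i in range(len(values) - window + 1)]


smoothed = moving_average(logged_losses, window=5)
print(f'baseline {uniform_baseline:.3f}, '
      f'first window {smoothed[0]:.3f}, last window {smoothed[-1]:.3f}')
```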