## Training notebook for Deep-Augmented MUSIC

This is the training notebook for Deep-Augmented MUSIC. It is meant to be run on Google Colab and clones the repository by itself.

If you are already inside the repository, **skip the next cell**.

In [None]:
!git clone https://MichelDucartier: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@github.com/MichelDucartier/music_doa.git
%cd music_doa

#### Imports

In [None]:
import os
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
import yaml
from IPython.display import clear_output, display

sys.path.append("src")
from src.deep.deep_music import DeepMUSIC, rmspe_loss, predict
from src.deep.synthetic_data import load_microphones, create_dataset

# Train on the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Directory where the training loop writes its periodic checkpoints.
CHECKPOINTS_PATH = "checkpoints/"

Create a dataset from scratch (run either this cell or the next one, not both):

In [None]:

# Number of synthetic samples to generate for training.
N_TRAIN_SAMPLES = 10000
# NOTE(review): the first argument presumably names the generated dataset; verify in src/deep/synthetic_data.py.
data_input, data_output, data_n_sources = create_dataset(
    'train_coherent_dataset', N_TRAIN_SAMPLES, coherent=True
)

Or load an existing dataset from disk:

In [None]:
import h5py

filename = "data/coherent_dataset.h5"

with h5py.File(filename, "r") as f:
    print("Keys: %s" % f.keys())
    # The three datasets are picked by their position in the file's key listing,
    # not by name: the 1st, 2nd and 3rd keys are taken as inputs, target DOAs and
    # source counts. NOTE(review): h5py presumably lists keys in name order unless
    # the file was written with track_order — confirm the mapping for this file.
    keys = list(f.keys())
    data_input, data_output, data_n_sources = [list(f[keys[k]]) for k in range(3)]

## Training loop

Here is the training loop.

We first build the model from its configuration file and create the optimizer:

In [None]:
mics_coords = torch.tensor(load_microphones()).to(device)

# Load the model configuration. A YAML syntax error is allowed to propagate:
# catching it and only printing it left `conf` undefined, so model construction
# then failed with an unrelated-looking NameError.
with open("conf/deep_music.yaml") as stream:
    conf = yaml.safe_load(stream)

model = DeepMUSIC(mics_coords, conf)
# Move the model before constructing the optimizer so the optimizer is built over
# the parameters as they live on `device`.
model.to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-3)

start_epoch = 0

model

Next comes the training loop. To resume training from a saved checkpoint instead of starting from scratch, first set up the following cell:

In [None]:
# Set RESUME_FROM to (epoch, sample) matching a file `checkpoint-{epoch}-{sample}.pth`
# in CHECKPOINTS_PATH to resume from it; leave it as None to train from scratch.
RESUME_FROM = None

if RESUME_FROM is not None:
    checkpoint_epoch, checkpoint_sample = RESUME_FROM
    checkpoint = torch.load(
        os.path.join(CHECKPOINTS_PATH, f'checkpoint-{checkpoint_epoch}-{checkpoint_sample}.pth'),
        # Remap stored tensors onto the current device, e.g. a GPU checkpoint on a CPU-only runtime.
        map_location=device,
    )
    # The training loop restarts at the beginning of this saved epoch.
    start_epoch = checkpoint['epoch']

    model.load_state_dict(checkpoint["model"])
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    optimizer.load_state_dict(checkpoint['optimizer'])

In [None]:
Path(CHECKPOINTS_PATH).mkdir(parents=True, exist_ok=True)

criterion = rmspe_loss

batch_size = 100   # samples whose gradients are accumulated before each optimizer step
plot_update = 100  # samples averaged into each point of the loss curve
# Checkpoint twice per epoch; max(1, ...) avoids a modulo-by-zero on tiny datasets.
save_step = max(1, len(data_input) // 2)
n_epochs = 10

# Live loss curve, redrawn on a single Axes every `plot_update` samples.
fig, ax = plt.subplots()
X = []
Y = []
running_loss = []

model.train()
optimizer.zero_grad()

for epoch in range(start_epoch, n_epochs):
    n_seen = 0
    for i, (audio, doas, n_sources) in enumerate(zip(data_input, data_output, data_n_sources)):
        n_seen = i + 1
        audio = torch.tensor(audio).to(device)
        n_sources = int(n_sources[0])
        doas = torch.tensor(doas[:n_sources]).to(device)

        # forward + backward; gradients accumulate until the next optimizer step
        estimated_doas, _ = model(audio, n_sources)
        loss = criterion(estimated_doas, doas, n_sources)
        loss.backward()
        # Record the loss before the plot check so every plotted point averages
        # exactly the last `plot_update` samples, including the current one.
        running_loss.append(loss.item())

        if (i + 1) % batch_size == 0:
            optimizer.step()
            optimizer.zero_grad()

        # Update plot
        if (i + 1) % plot_update == 0:
            X.append(len(X))
            Y.append(np.mean(running_loss))
            running_loss = []
            ax.clear()  # redraw one line instead of stacking a new one per update
            ax.plot(X, Y, color='b')
            ax.set(
                title="Training loss",
                xlabel=f"update (every {plot_update} samples)",
                ylabel="mean loss",
            )
            clear_output(wait=True)
            display(fig)

        if i % save_step == 0:
            checkpoint = {
                'epoch': epoch,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            torch.save(checkpoint, os.path.join(CHECKPOINTS_PATH, f'checkpoint-{epoch}-{i}.pth'))

    # Apply gradients from a trailing partial batch so they are not carried into
    # the next epoch's first batch (or silently dropped after the last epoch).
    if n_seen % batch_size != 0:
        optimizer.step()
        optimizer.zero_grad()

# The last curve is already in the cell output; closing prevents a duplicate render.
plt.close(fig)
print('Finished Training')

In [None]:
# Persist only the trained weights; rebuild the model with DeepMUSIC(mics_coords, conf)
# and call load_state_dict on it to reuse them.
models_dir = Path("models/")
models_dir.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), str(models_dir / "finetuned_model.pt"))