# Description

This is the training section of the Spectrogram U-Net for Music Source Separation. Before running this, make sure the Spectrograms directory (along with the data points for testing and training, of course) is generated by running the preprocessor.py file once.

The training process utilizes the architecture of U-Net implemented in the UNet.py file and the loss functions implemented in the loss_functions.py

# Packages

In [1]:
import torch
from torch.utils.data import DataLoader
from architectures.UNet.UNet import SpectrogramUNet
from architectures.UNet.loss_functions import VocalLoss, InstrumentLoss
from dataset import DSDDataset
from torchsummary import summary

# Initializations and hyperparameters

### Initializations

In [2]:
SPECTROGRAMS_PATH = './Spectrograms'
VOCAL_ONLY = False
MODEL_PATH = "./models/vocal-accompaniment-separation/" if VOCAL_ONLY else "./models/all-separation/"

### UNet Parameters

In [3]:
IN_CHANNELS = 1
OUT_CHANNELS = 2 if VOCAL_ONLY else 4
FEATURES = [32, 64, 128, 256, 512]

### Training Parameters

In [4]:
EPOCHS = 50
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-6
BATCH_SIZE = 8

In [11]:
VOCAL_ALPHA = 0.651 #as per my calculations that hinge on borderline delusions

In [12]:
ALPHA_VOCAL = 0.233
ALPHA_DRUM = 0.263
ALPHA_GUITAR = 0.286
ALPHA4_OTHER = 0.218

# Model

In [13]:
model = SpectrogramUNet(in_channel=IN_CHANNELS, out_channel=OUT_CHANNELS, features=FEATURES)

In [14]:
summary(model=model)

Layer (type:depth-idx)                        Param #
├─ModuleList: 1-1                             --
|    └─DoubleConv: 2-1                        --
|    |    └─Sequential: 3-1                   9,696
|    └─DoubleConv: 2-2                        --
|    |    └─Sequential: 3-2                   55,680
|    └─DoubleConv: 2-3                        --
|    |    └─Sequential: 3-3                   221,952
|    └─DoubleConv: 2-4                        --
|    |    └─Sequential: 3-4                   886,272
|    └─DoubleConv: 2-5                        --
|    |    └─Sequential: 3-5                   3,542,016
├─ModuleList: 1-2                             --
|    └─UpSampling: 2-6                        --
|    |    └─Sequential: 3-6                   3,277,568
|    └─DoubleDeConv: 2-7                      --
|    |    └─Sequential: 3-7                   1,771,008
|    └─UpSampling: 2-8                        --
|    |    └─Sequential: 3-8                   819,584
|    └─DoubleDeConv: 

Layer (type:depth-idx)                        Param #
├─ModuleList: 1-1                             --
|    └─DoubleConv: 2-1                        --
|    |    └─Sequential: 3-1                   9,696
|    └─DoubleConv: 2-2                        --
|    |    └─Sequential: 3-2                   55,680
|    └─DoubleConv: 2-3                        --
|    |    └─Sequential: 3-3                   221,952
|    └─DoubleConv: 2-4                        --
|    |    └─Sequential: 3-4                   886,272
|    └─DoubleConv: 2-5                        --
|    |    └─Sequential: 3-5                   3,542,016
├─ModuleList: 1-2                             --
|    └─UpSampling: 2-6                        --
|    |    └─Sequential: 3-6                   3,277,568
|    └─DoubleDeConv: 2-7                      --
|    |    └─Sequential: 3-7                   1,771,008
|    └─UpSampling: 2-8                        --
|    |    └─Sequential: 3-8                   819,584
|    └─DoubleDeConv: 

# Training

### Prerequisites

In [15]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [16]:
print(torch.version.cuda)

12.4


In [17]:
training_set = DSDDataset(spectrograms_path=SPECTROGRAMS_PATH, vocal_only=VOCAL_ONLY, train=True)
val_set = DSDDataset(spectrograms_path=SPECTROGRAMS_PATH, vocal_only=VOCAL_ONLY, train=False)

In [18]:
train_loader = DataLoader(dataset=training_set, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(dataset=val_set, batch_size=BATCH_SIZE, shuffle=False)

In [19]:
for feature, target in train_loader:
    print(f"Feature shape: {feature.shape}\n")
    print(f"Targets: {target.keys()}\n")
    print(f"Vocal target shape: {target['vocals'].shape}\n")
    print(f"Guitar target shape: {target['guitar'].shape}")

    break

Feature shape: torch.Size([8, 1, 1025, 173])

Targets: dict_keys(['vocals', 'drums', 'guitar', 'other'])

Vocal target shape: torch.Size([8, 1, 1025, 173])

Guitar target shape: torch.Size([8, 1, 1025, 173])


In [20]:
def determine_loss(model, feature, target, loss_fn):

 outputs = model(feature)

 vocal_channel_output = outputs[:, 0, :, :].unsqueeze(1)

 if VOCAL_ONLY:
    accompaniment_channel_output = outputs[:, 1, :, :].unsqueeze(1)
    loss_list = [vocal_channel_output ,target['vocals'],  accompaniment_channel_output ,target['accompaniment']]
 else:
    drums_channel_output = outputs[:, 1, :, :].unsqueeze(1)
    guitar_channel_output = outputs[:, 2, :, :].unsqueeze(1)
    other_channel_output = outputs[:, 3, :, :].unsqueeze(1)
    loss_list = [vocal_channel_output ,target['vocals'],  drums_channel_output ,target['drums'], guitar_channel_output, target['guitar'], other_channel_output, target['other']]

 loss = loss_fn(*loss_list)
 return loss

In [21]:
def train_one(model, dataloader, loss_fn, optimizer, device):
    
    model = model.to(device)
    running_loss = 0
    last_loss = 0

    for i, data in enumerate(dataloader):
         
        feature, target = data

        feature = feature.to(device)
        
        for key in target:
         target[key] = target[key].to(device)
        
        
        optimizer.zero_grad()
        
        loss = determine_loss(model, feature, target, loss_fn )
        loss.backward()

        optimizer.step()

        running_loss+=loss.item()

        if (i+1)%5==0:
         last_loss = running_loss/5
         print(f'Batch {i+1},  loss: {last_loss}')
         running_loss=0
    return last_loss


In [22]:
def train(model, train_loader, val_loader, loss_fn, optimizer, device, epochs):

    for epoch in range(epochs):

        print(f'EPOCH {epoch+1}:')

        model.train(True)
        avg_loss = train_one(model, train_loader, loss_fn, optimizer, device)

        running_val_loss = 0.0

        model.eval()

        with torch.no_grad():

            for i, vdata in enumerate(val_loader):
                vfeature, vtarget = vdata
                vfeature = vfeature.to(device)

                for key in vtarget:
                 vtarget[key] = vtarget[key].to(device)
                
                vloss = determine_loss(model, vfeature, vtarget, loss_fn)
                running_val_loss += vloss.item()
            avg_vloss = running_val_loss / (i + 1)
            print(f'LOSS train {avg_loss}. Validation loss: {avg_vloss} \n\n\n')
        
        if (epoch+1) % 5 == 0:
           checkpoint = {
                'epoch': epoch,  
                'model_state_dict': model.state_dict(), 
                'optimizer_state_dict': optimizer.state_dict(), 
                'avg_loss':avg_loss,
                'avg_vloss':avg_vloss
               
                   }

           torch.save(checkpoint, f'checkpoint_{epoch+1}.pth')


In [23]:
optimizer = torch.optim.Adam(lr=LEARNING_RATE, params=model.parameters(), weight_decay=WEIGHT_DECAY)
loss_fn = VocalLoss(alpha=VOCAL_ALPHA) if VOCAL_ONLY else InstrumentLoss(alpha1=ALPHA_VOCAL, alpha2=ALPHA_DRUM, alpha3=ALPHA_GUITAR, alpha4=ALPHA4_OTHER)

### Training the model

In [24]:
train(model, train_loader, val_loader, loss_fn, optimizer, device, EPOCHS)

EPOCH 1:
Batch 5,  loss: 0.321671724319458
Batch 10,  loss: 0.3291707456111908


KeyboardInterrupt: 

In [None]:
torch.save(model.state_dict(), MODEL_PATH+'multisepmodelp1.pth')