# Description

This notebook is the training stage of the Spectrogram U-Net for music source separation. Before running it, run preprocessor.py once so that the Spectrograms directory, containing the data points for both training and testing, has been generated.

The training process uses the U-Net architecture implemented in UNet.py and the loss functions implemented in loss_functions.py.

# Packages

In [None]:
import os

import torch
from torch.utils.data import DataLoader
from torchsummary import summary

from architectures.UNet.UNet import SpectrogramUNet
from architectures.UNet.loss_functions import VocalLoss, InstrumentLoss
from dataset import DSDDataset

# Initializations and hyperparameters

### Initializations

In [2]:
# Directory holding the preprocessed spectrograms that the datasets read from.
SPECTROGRAMS_PATH = './Spectrograms'

# Separation mode switch; it also decides which checkpoint directory is used.
VOCAL_ONLY = True

if VOCAL_ONLY:
    MODEL_PATH = "./models/vocal-accompaniment-separation/"
else:
    MODEL_PATH = "./models/all-separation/"

### UNet Parameters

In [3]:
# Constructor arguments for SpectrogramUNet.
FEATURES = [32, 64, 128, 256, 512]
IN_CHANNELS = 1

# Two output channels in vocal/accompaniment mode, four otherwise
# (presumably one channel per separated source - TODO confirm).
if VOCAL_ONLY:
    OUT_CHANNELS = 2
else:
    OUT_CHANNELS = 4

### Training Parameters

In [4]:
# Number of full passes over the training loader.
EPOCHS = 50
# Step size and weight decay handed to the Adam optimizer.
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-6
# Samples per batch for both the training and the validation loader.
BATCH_SIZE = 8

In [None]:
# Passed to VocalLoss as `alpha`. The value comes from the author's own
# informal calculations; no derivation is recorded in this notebook.
VOCAL_ALPHA = 0.651

In [6]:
# Four coefficients that sum to 1.0. They are not referenced anywhere else in
# the visible notebook; presumably per-source weights for the four-source
# InstrumentLoss - TODO confirm against loss_functions.py.
ALPHA1 = 0.297
ALPHA2 = 0.262
ALPHA3 = 0.232
ALPHA4 = 0.209

# Model

In [7]:
# Instantiate the U-Net from the channel counts and feature list configured above.
model = SpectrogramUNet(in_channel=IN_CHANNELS, out_channel=OUT_CHANNELS, features=FEATURES)

In [8]:
# Print the layer tree with per-layer parameter counts as a quick sanity check.
summary(model=model)

Layer (type:depth-idx)                        Param #
├─ModuleList: 1-1                             --
|    └─DoubleConv: 2-1                        --
|    |    └─Sequential: 3-1                   9,696
|    └─DoubleConv: 2-2                        --
|    |    └─Sequential: 3-2                   55,680
|    └─DoubleConv: 2-3                        --
|    |    └─Sequential: 3-3                   221,952
|    └─DoubleConv: 2-4                        --
|    |    └─Sequential: 3-4                   886,272
|    └─DoubleConv: 2-5                        --
|    |    └─Sequential: 3-5                   3,542,016
├─ModuleList: 1-2                             --
|    └─UpSampling: 2-6                        --
|    |    └─Sequential: 3-6                   3,277,568
|    └─DoubleDeConv: 2-7                      --
|    |    └─Sequential: 3-7                   1,771,008
|    └─UpSampling: 2-8                        --
|    |    └─Sequential: 3-8                   819,584
|    └─DoubleDeConv: 

Layer (type:depth-idx)                        Param #
├─ModuleList: 1-1                             --
|    └─DoubleConv: 2-1                        --
|    |    └─Sequential: 3-1                   9,696
|    └─DoubleConv: 2-2                        --
|    |    └─Sequential: 3-2                   55,680
|    └─DoubleConv: 2-3                        --
|    |    └─Sequential: 3-3                   221,952
|    └─DoubleConv: 2-4                        --
|    |    └─Sequential: 3-4                   886,272
|    └─DoubleConv: 2-5                        --
|    |    └─Sequential: 3-5                   3,542,016
├─ModuleList: 1-2                             --
|    └─UpSampling: 2-6                        --
|    |    └─Sequential: 3-6                   3,277,568
|    └─DoubleDeConv: 2-7                      --
|    |    └─Sequential: 3-7                   1,771,008
|    └─UpSampling: 2-8                        --
|    |    └─Sequential: 3-8                   819,584
|    └─DoubleDeConv: 

# Training

### Prerequisites

In [9]:
# Train on the GPU whenever CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

In [10]:
# Show the CUDA version reported by this PyTorch installation.
print(torch.version.cuda)

12.4


In [11]:
# Build the two datasets from the same spectrogram directory; the `train` flag
# selects which one is the training set and which one is used for validation.
training_set = DSDDataset(spectrograms_path=SPECTROGRAMS_PATH, vocal_only=VOCAL_ONLY, train=True)
val_set = DSDDataset(spectrograms_path=SPECTROGRAMS_PATH, vocal_only=VOCAL_ONLY, train=False)

In [12]:
# Training batches are reshuffled every epoch; validation batches keep a fixed
# order so the validation pass sees the data in the same sequence each time.
train_loader = DataLoader(dataset=training_set, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(dataset=val_set, batch_size=BATCH_SIZE, shuffle=False)

In [13]:
# Sanity check: pull a single batch and print the shape of the input tensor
# and of each target tensor, plus the keys of the target dict.
for feature, target in train_loader:
    print(f"Feature shape: {feature.shape}\n")
    print(f"Targets: {target.keys()}\n")
    print(f"Vocal target shape: {target['vocals'].shape}\n")
    print(f"Accompaniment target shape: {target['accompaniment'].shape}")

    # Only the first batch is needed for the inspection.
    break

Feature shape: torch.Size([8, 1, 1025, 173])

Targets: dict_keys(['vocals', 'accompaniment'])

Vocal target shape: torch.Size([8, 1, 1025, 173])

Accompaniment target shape: torch.Size([8, 1, 1025, 173])


In [14]:
def train_one(model, dataloader, loss_fn, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    The model is moved to ``device``, then every batch goes through the usual
    zero-grad / forward / loss / backward / step cycle. The mean loss is
    printed every five batches, and once more for any batches left over at
    the end of the loader.

    Args:
        model: Network to train; called as ``model(feature)`` and expected to
            return a tensor whose second axis holds at least two channels.
        dataloader: Iterable of ``(feature, target)`` pairs, where ``target``
            is a dict with tensors under ``'vocals'`` and ``'accompaniment'``.
        loss_fn: Callable invoked as ``loss_fn(vocal_pred, vocal_target,
            accompaniment_pred, accompaniment_target)``; must return a scalar
            tensor.
        optimizer: Optimizer that updates ``model``'s parameters.
        device: Device that the model and every batch are moved to.

    Returns:
        The mean loss of the most recently reported window of batches, or 0
        when ``dataloader`` yields no batches at all.
    """
    report_every = 5  # number of batches averaged per progress line

    model = model.to(device)
    running_loss = 0.0
    last_loss = 0
    pending = 0  # batches accumulated since the last progress line

    for i, (feature, target) in enumerate(dataloader):
        feature = feature.to(device)
        target['accompaniment'] = target['accompaniment'].to(device)
        target['vocals'] = target['vocals'].to(device)

        optimizer.zero_grad()
        outputs = model(feature)

        # Output channel 0 is scored against the vocals and channel 1 against
        # the accompaniment; unsqueeze restores the channel axis that integer
        # indexing removed.
        vocal_channel_output = outputs[:, 0, :, :].unsqueeze(1)
        accompaniment_channel_output = outputs[:, 1, :, :].unsqueeze(1)

        loss = loss_fn(
            vocal_channel_output,
            target['vocals'],
            accompaniment_channel_output,
            target['accompaniment'],
        )
        loss.backward()

        optimizer.step()

        running_loss += loss.item()
        pending += 1

        if pending == report_every:
            last_loss = running_loss / report_every
            print(f'Batch {i+1},  loss: {last_loss}')
            running_loss = 0.0
            pending = 0

    # Flush the final partial window. Without this, a loader with fewer than
    # `report_every` batches would report a loss of 0, and trailing batches
    # would never be reflected in the returned value.
    if pending:
        last_loss = running_loss / pending
        print(f'Batch {i+1},  loss: {last_loss}')

    return last_loss


In [15]:
def train(model, train_loader, val_loader, loss_fn, optimizer, device, epochs):
    """Train ``model`` for ``epochs`` epochs, validating after each one.

    Each epoch runs one training pass through ``train_one`` and then one
    gradient-free pass over ``val_loader``. The training loss returned by
    ``train_one`` and the mean validation loss are printed once per epoch.

    Args:
        model: Network to train.
        train_loader: Iterable of training batches, forwarded to ``train_one``.
        val_loader: Iterable of ``(feature, target)`` validation batches, where
            ``target`` is a dict with ``'vocals'`` and ``'accompaniment'``
            tensors.
        loss_fn: Callable invoked as ``loss_fn(vocal_pred, vocal_target,
            accompaniment_pred, accompaniment_target)``.
        optimizer: Optimizer forwarded to ``train_one``.
        device: Device that the validation batches are moved to.
        epochs: Number of epochs to run.

    Returns:
        None. Progress is reported through ``print`` only.
    """
    for epoch in range(epochs):

        print(f'EPOCH {epoch+1}:')

        model.train(True)
        avg_loss = train_one(model, train_loader, loss_fn, optimizer, device)

        running_val_loss = 0.0
        val_batches = 0

        # Switch to inference behaviour and skip gradient bookkeeping while
        # scoring the validation data.
        model.eval()

        with torch.no_grad():

            for vfeature, vtarget in val_loader:
                vfeature = vfeature.to(device)
                vtarget['accompaniment'] = vtarget['accompaniment'].to(device)
                vtarget['vocals'] = vtarget['vocals'].to(device)
                voutput = model(vfeature)

                # Same channel convention as in training: channel 0 is scored
                # against the vocals, channel 1 against the accompaniment.
                vocal_channel_output = voutput[:, 0, :, :].unsqueeze(1)
                accompaniment_channel_output = voutput[:, 1, :, :].unsqueeze(1)
                vloss = loss_fn(
                    vocal_channel_output,
                    vtarget['vocals'],
                    accompaniment_channel_output,
                    vtarget['accompaniment'],
                )
                running_val_loss += vloss.item()
                val_batches += 1

        # Count batches explicitly instead of reusing the loop index, so an
        # empty validation loader reports NaN rather than raising NameError.
        avg_vloss = running_val_loss / val_batches if val_batches else float('nan')
        print(f'LOSS train {avg_loss}. Validation loss: {avg_vloss} \n\n\n')


In [16]:
# Adam over all model parameters, using the learning rate and weight decay set above.
optimizer = torch.optim.Adam(lr=LEARNING_RATE, params=model.parameters(), weight_decay=WEIGHT_DECAY)
# Loss for the vocal/accompaniment setup, weighted by VOCAL_ALPHA.
loss_fn = VocalLoss(alpha=VOCAL_ALPHA)

### Training the model

In [None]:
# Run the full training loop: EPOCHS epochs, each followed by a validation pass.
train(model, train_loader, val_loader, loss_fn, optimizer, device, EPOCHS)

In [18]:
# Create the checkpoint directory first: torch.save does not create missing
# parent directories, and failing here would throw away the finished training run.
os.makedirs(MODEL_PATH, exist_ok=True)
# Persist only the learned weights (state dict), not the whole model object.
torch.save(model.state_dict(), os.path.join(MODEL_PATH, 'voicemodelp2.pth'))