# Trained on a single audio segment

# 1. Training on a single audio

In [48]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import numpy as np

In [49]:
class AutoencConv1D(nn.Module):
    def __init__(self):
        super().__init__()
        
        # Encoder filterbank. Using stride for downsampling, padding (kernel_size/2-1) to center filter at 0.
        self.enc_conv1 = nn.Conv1d(in_channels=1, out_channels=32, kernel_size=2048, stride=1024, padding=1023, bias=True)
        
        # Decoder filterbank
        self.dec_convt1 = nn.ConvTranspose1d(in_channels=32, out_channels=1, kernel_size=2048, stride=1024, padding=1023, bias=True)

    def encoder(self, x):
        x = self.enc_conv1(x)
        x = torch.tanh(x)
        return x
      
    def decoder(self, x):
        x = self.dec_convt1(x)
        return x
      
    def forward(self, x):
        z = self.encoder(x)
        x_reconstructed = self.decoder(z)
        return x_reconstructed

In [51]:
audio_path = "audio/dubstep.flac"
waveform, samplerate = torchaudio.load(audio_path)

# Convert to mono
waveform = torch.mean(waveform, dim=0, keepdim=True)
print(waveform.shape)

# Normalize so that max absolute value = 1
max_val = torch.abs(waveform).max()
if max_val > 0:
    waveform = waveform / max_val

torch.Size([1, 302085])


In [52]:
from IPython.display import Audio
Audio(waveform.numpy(), rate=samplerate)

In [53]:
model = AutoencConv1D()
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"AutoencConv1D parameters: {num_params}")

AutoencConv1D parameters: 131105


In [54]:
# Do a forward pass to measure output length
waveform_reconstructed = model(waveform)
len_output = waveform_reconstructed.shape[1]
# Trim target reference to this duration
waveform_target = waveform[:,:len_output]
waveform_target.shape

torch.Size([1, 302082])

In [55]:
# Training loop
n_epochs = 10000
loss_function = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)


for epoch in range(n_epochs):
    waveform_predicted = model(waveform)
    loss = loss_function(waveform_predicted, waveform_target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch+1)%10==0:
        loss_train = loss.item()
        print(
            f'Epoch [{epoch + 1}/{n_epochs}], ' +
            f'Train loss: {loss_train:.4f}'
        )

Epoch [10/10000], Train loss: 0.0708
Epoch [20/10000], Train loss: 0.0673
Epoch [30/10000], Train loss: 0.0634
Epoch [40/10000], Train loss: 0.0600
Epoch [50/10000], Train loss: 0.0572
Epoch [60/10000], Train loss: 0.0550
Epoch [70/10000], Train loss: 0.0532
Epoch [80/10000], Train loss: 0.0517
Epoch [90/10000], Train loss: 0.0504
Epoch [100/10000], Train loss: 0.0492
Epoch [110/10000], Train loss: 0.0481
Epoch [120/10000], Train loss: 0.0470
Epoch [130/10000], Train loss: 0.0459
Epoch [140/10000], Train loss: 0.0449
Epoch [150/10000], Train loss: 0.0439
Epoch [160/10000], Train loss: 0.0429
Epoch [170/10000], Train loss: 0.0419
Epoch [180/10000], Train loss: 0.0409
Epoch [190/10000], Train loss: 0.0400
Epoch [200/10000], Train loss: 0.0391
Epoch [210/10000], Train loss: 0.0382
Epoch [220/10000], Train loss: 0.0373
Epoch [230/10000], Train loss: 0.0365
Epoch [240/10000], Train loss: 0.0357
Epoch [250/10000], Train loss: 0.0349
Epoch [260/10000], Train loss: 0.0342
Epoch [270/10000], Tr

Epoch [2140/10000], Train loss: 0.0115
Epoch [2150/10000], Train loss: 0.0115
Epoch [2160/10000], Train loss: 0.0114
Epoch [2170/10000], Train loss: 0.0114
Epoch [2180/10000], Train loss: 0.0114
Epoch [2190/10000], Train loss: 0.0114
Epoch [2200/10000], Train loss: 0.0114
Epoch [2210/10000], Train loss: 0.0114
Epoch [2220/10000], Train loss: 0.0114
Epoch [2230/10000], Train loss: 0.0114
Epoch [2240/10000], Train loss: 0.0114
Epoch [2250/10000], Train loss: 0.0114
Epoch [2260/10000], Train loss: 0.0114
Epoch [2270/10000], Train loss: 0.0113
Epoch [2280/10000], Train loss: 0.0113
Epoch [2290/10000], Train loss: 0.0113
Epoch [2300/10000], Train loss: 0.0113
Epoch [2310/10000], Train loss: 0.0113
Epoch [2320/10000], Train loss: 0.0113
Epoch [2330/10000], Train loss: 0.0113
Epoch [2340/10000], Train loss: 0.0113
Epoch [2350/10000], Train loss: 0.0113
Epoch [2360/10000], Train loss: 0.0113
Epoch [2370/10000], Train loss: 0.0113
Epoch [2380/10000], Train loss: 0.0113
Epoch [2390/10000], Train

Epoch [4260/10000], Train loss: 0.0104
Epoch [4270/10000], Train loss: 0.0104
Epoch [4280/10000], Train loss: 0.0104
Epoch [4290/10000], Train loss: 0.0104
Epoch [4300/10000], Train loss: 0.0104
Epoch [4310/10000], Train loss: 0.0104
Epoch [4320/10000], Train loss: 0.0104
Epoch [4330/10000], Train loss: 0.0104
Epoch [4340/10000], Train loss: 0.0104
Epoch [4350/10000], Train loss: 0.0104
Epoch [4360/10000], Train loss: 0.0104
Epoch [4370/10000], Train loss: 0.0104
Epoch [4380/10000], Train loss: 0.0104
Epoch [4390/10000], Train loss: 0.0104
Epoch [4400/10000], Train loss: 0.0104
Epoch [4410/10000], Train loss: 0.0104
Epoch [4420/10000], Train loss: 0.0104
Epoch [4430/10000], Train loss: 0.0104
Epoch [4440/10000], Train loss: 0.0104
Epoch [4450/10000], Train loss: 0.0104
Epoch [4460/10000], Train loss: 0.0104
Epoch [4470/10000], Train loss: 0.0104
Epoch [4480/10000], Train loss: 0.0104
Epoch [4490/10000], Train loss: 0.0104
Epoch [4500/10000], Train loss: 0.0103
Epoch [4510/10000], Train

Epoch [6370/10000], Train loss: 0.0100
Epoch [6380/10000], Train loss: 0.0100
Epoch [6390/10000], Train loss: 0.0100
Epoch [6400/10000], Train loss: 0.0100
Epoch [6410/10000], Train loss: 0.0100
Epoch [6420/10000], Train loss: 0.0100
Epoch [6430/10000], Train loss: 0.0100
Epoch [6440/10000], Train loss: 0.0100
Epoch [6450/10000], Train loss: 0.0100
Epoch [6460/10000], Train loss: 0.0100
Epoch [6470/10000], Train loss: 0.0100
Epoch [6480/10000], Train loss: 0.0100
Epoch [6490/10000], Train loss: 0.0100
Epoch [6500/10000], Train loss: 0.0100
Epoch [6510/10000], Train loss: 0.0100
Epoch [6520/10000], Train loss: 0.0100
Epoch [6530/10000], Train loss: 0.0100
Epoch [6540/10000], Train loss: 0.0100
Epoch [6550/10000], Train loss: 0.0100
Epoch [6560/10000], Train loss: 0.0100
Epoch [6570/10000], Train loss: 0.0100
Epoch [6580/10000], Train loss: 0.0100
Epoch [6590/10000], Train loss: 0.0100
Epoch [6600/10000], Train loss: 0.0100
Epoch [6610/10000], Train loss: 0.0100
Epoch [6620/10000], Train

Epoch [8490/10000], Train loss: 0.0098
Epoch [8500/10000], Train loss: 0.0098
Epoch [8510/10000], Train loss: 0.0098
Epoch [8520/10000], Train loss: 0.0098
Epoch [8530/10000], Train loss: 0.0098
Epoch [8540/10000], Train loss: 0.0098
Epoch [8550/10000], Train loss: 0.0098
Epoch [8560/10000], Train loss: 0.0098
Epoch [8570/10000], Train loss: 0.0098
Epoch [8580/10000], Train loss: 0.0098
Epoch [8590/10000], Train loss: 0.0098
Epoch [8600/10000], Train loss: 0.0098
Epoch [8610/10000], Train loss: 0.0098
Epoch [8620/10000], Train loss: 0.0098
Epoch [8630/10000], Train loss: 0.0098
Epoch [8640/10000], Train loss: 0.0098
Epoch [8650/10000], Train loss: 0.0098
Epoch [8660/10000], Train loss: 0.0098
Epoch [8670/10000], Train loss: 0.0098
Epoch [8680/10000], Train loss: 0.0098
Epoch [8690/10000], Train loss: 0.0098
Epoch [8700/10000], Train loss: 0.0098
Epoch [8710/10000], Train loss: 0.0098
Epoch [8720/10000], Train loss: 0.0098
Epoch [8730/10000], Train loss: 0.0098
Epoch [8740/10000], Train

In [69]:
waveform_reconstructed = model(waveform).detach()
Audio(waveform_reconstructed.numpy(), rate=samplerate)

In [74]:
print("Input:", waveform.shape)
bottleneck = model.encoder(waveform)
print("Bottleneck:", bottleneck.shape)
output = model.decoder(bottleneck)
print("Output:", output.shape)

Input: torch.Size([1, 302085])
Bottleneck: torch.Size([32, 296])
Output: torch.Size([1, 302082])


In [68]:
# Try on another audio
audio2_path = "audio/techno_loop.wav"
waveform2, samplerate = torchaudio.load(audio2_path)

# Convert to mono
waveform2 = torch.mean(waveform2, dim=0, keepdim=True)
print(waveform2.shape)

# Normalize so that max absolute value = 1
max_val = torch.abs(waveform2).max()
if max_val > 0:
    waveform2 = waveform2 / max_val

torch.Size([1, 1354752])


In [78]:
Audio(waveform2.numpy(), rate=samplerate)

In [77]:
waveform2_reconstructed = model(waveform2).detach()
Audio(waveform2_reconstructed.numpy(), rate=samplerate)

# 2. Trained on a dataset

In [136]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import mirdata

from sklearn.preprocessing import OneHotEncoder
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader, Dataset
from torchaudio.transforms import MelSpectrogram, Resample
from tqdm import tqdm

In [137]:
dataset = mirdata.initialize("tinysol")
dataset.download()
split = dataset.get_random_track_splits([0.8, 0.2], split_names=("train", "val"))
train_ids, val_ids = split["train"], split["val"]



In [189]:
class TinySOLDataset(Dataset):
    def __init__(self, mirdata_dataset, ids):
        self.orig_sample_rate = 44100
        self.sample_rate = 16000

        self.audio_duration = 1
        
        self.mirdata_dataset = mirdata_dataset
        self.tids = ids

        # Load audio and labels
        self.audio = {}
        self.natural_labels = {}

        self.resample = Resample(orig_freq=self.orig_sample_rate, new_freq=self.sample_rate)

        n_samples = self.sample_rate * self.audio_duration
        
        for tid in tqdm(ids, desc="Loading audio"):
            track = self.mirdata_dataset.track(tid)
            audio, sr = track.audio

            assert sr == self.orig_sample_rate
            audio = self.resample(torch.Tensor(audio))

            if len(audio) >= n_samples:
                audio = audio[:n_samples]
            else:
                pad_size = n_samples - len(audio)
                audio = torch.cat([audio, torch.zeros(pad_size)])
                
            # Normalize so that max absolute value = 1
            max_val = torch.abs(audio).max()
            if max_val > 0:
                audio = audio / max_val
            
            self.audio[tid] = audio
            self.natural_labels[tid] = track.instrument_full

        # One hot encode labels
        natural_labels = np.array(list(self.natural_labels.values())).reshape(-1, 1)
        ohe = OneHotEncoder()
        one_hot_labels = ohe.fit_transform(natural_labels).toarray()
        self.labels = {k: v for k, v in zip(self.tids, one_hot_labels)}


    def __len__(self):
        return len(self.tids)

    def __getitem__(self, idx, audio_cap=4):
        # TODO audio_cap not used.
        tid = self.tids[idx]
        audio = self.audio[tid]
        labels = self.labels[tid]
        natural_labels = self.natural_labels[tid]
        
        return {"audio": audio, "labels": labels, "natural_labels": natural_labels}

train_dataset = TinySOLDataset(dataset, train_ids)
val_dataset = TinySOLDataset(dataset, val_ids)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

Loading audio: 100%|██████████| 2331/2331 [00:14<00:00, 163.75it/s]
Loading audio: 100%|██████████| 582/582 [00:03<00:00, 166.52it/s]


In [143]:
# 3 conv layers
class AudioAutoencoder(nn.Module):
    def __init__(self, input_length=16000, bottleneck_size=64):
        super().__init__()
        
        # Encoder layers
        self.enc_conv1 = nn.Conv1d(1, 8, kernel_size=9, stride=2, padding=4)  # [batch_size,8,8000]
        self.enc_relu1 = nn.ReLU()
        self.enc_conv2 = nn.Conv1d(8, 16, kernel_size=9, stride=2, padding=4)  # [batch_size,16,4000]
        self.enc_relu2 = nn.ReLU()       
        self.enc_conv3 = nn.Conv1d(16, 32, kernel_size=9, stride=2, padding=4)  # [batch_size,32,2000]
        self.enc_relu3 = nn.ReLU()              
        enc_out_len = input_length // (2*2*2) * 32  # Flattened length after strides
        self.enc_fc = nn.Linear(enc_out_len, bottleneck_size)
               
        # Decoder layers
        self.dec_fc = nn.Linear(bottleneck_size, enc_out_len)        
        self.dec_deconv1 = nn.ConvTranspose1d(32, 16, kernel_size=9, stride=2, padding=4, output_padding=1)
        self.dec_relu1 = nn.ReLU()
        self.dec_deconv2 = nn.ConvTranspose1d(16, 8, kernel_size=9, stride=2, padding=4, output_padding=1)
        self.dec_relu2 = nn.ReLU()
        self.dec_deconv3 = nn.ConvTranspose1d(8, 1, kernel_size=9, stride=2, padding=4, output_padding=1)
        self.dec_tanh = nn.Tanh()

    def encoder(self, x):
        x = self.enc_conv1(x)
        x = self.enc_relu1(x)
        x = self.enc_conv2(x)
        x = self.enc_relu2(x)
        x = self.enc_conv3(x)
        x = self.enc_relu3(x)        
        x = x.flatten(start_dim=1)  # Flatten for bottleneck
        x = self.enc_fc(x)
        return x
        
    def decoder(self, x):
        x = self.dec_fc(x)
        x = x.view(x.size(0), 32, -1)  # reshape for ConvTranspose1d
        x = self.dec_deconv1(x)
        x = self.dec_relu1(x)
        x = self.dec_deconv2(x)
        x = self.dec_relu2(x)
        x = self.dec_deconv3(x)       
        x = self.dec_tanh(x)
        return x
        
    def forward(self, x):
        z = self.encoder(x)
        x = self.decoder(z)
        return x

In [226]:
target_length = 16000
bottleneck_size = 512 #2 #16 #64
model = AudioAutoencoder(input_length=target_length, bottleneck_size=bottleneck_size)

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

num_epochs = 20
for epoch in range(num_epochs):

    # Training loop
    model.train()

    loss_train = []
    for i, batch in enumerate(train_loader):
        x = batch["audio"]
        x = x.unsqueeze(1)
        optimizer.zero_grad()    
        
        x_reconstr = model(x)        
        loss = criterion(x, x_reconstr)
        loss.backward()
        optimizer.step()
        loss_train.append(loss.item())
        
    loss_train = np.mean(loss_train)

    # Validation loop
    model.eval()  
    
    loss_val = []
    for i, batch in enumerate(val_loader):
        x = batch["audio"]
        x = x.unsqueeze(1)
        
        x_reconstr = model(x) 
        loss = criterion(x, x_reconstr)
        loss_val.append(loss.item())
    
    loss_val = np.mean(loss_val)    

    if (epoch+1) % 1 == 0:
        print(f"Epoch {epoch+1}: Train loss={loss_train:.6f}, Val loss={loss_val:.6f}")

Epoch 1: Train loss=0.142921, Val loss=0.138120
Epoch 2: Train loss=0.138805, Val loss=0.136123
Epoch 3: Train loss=0.128771, Val loss=0.117873
Epoch 4: Train loss=0.106712, Val loss=0.099142
Epoch 5: Train loss=0.090085, Val loss=0.089471
Epoch 6: Train loss=0.079120, Val loss=0.083646
Epoch 7: Train loss=0.071471, Val loss=0.080145
Epoch 8: Train loss=0.066137, Val loss=0.078570
Epoch 9: Train loss=0.062229, Val loss=0.077358
Epoch 10: Train loss=0.059133, Val loss=0.075840
Epoch 11: Train loss=0.056484, Val loss=0.075635
Epoch 12: Train loss=0.053841, Val loss=0.074873
Epoch 13: Train loss=0.051283, Val loss=0.074402
Epoch 14: Train loss=0.048803, Val loss=0.074766
Epoch 15: Train loss=0.046013, Val loss=0.073973
Epoch 16: Train loss=0.043450, Val loss=0.074646
Epoch 17: Train loss=0.040857, Val loss=0.074411
Epoch 18: Train loss=0.038034, Val loss=0.074450
Epoch 19: Train loss=0.035502, Val loss=0.074382
Epoch 20: Train loss=0.032938, Val loss=0.073676


In [231]:
batch = next(iter(val_loader))
audio_example = batch["audio"]
audio_example = audio_example.unsqueeze(1)
audio_example.shape

torch.Size([32, 1, 16000])

In [232]:
audio_example_reconstr = model(audio_example).detach()

In [239]:
Audio(audio_example[4].squeeze().numpy(), rate=16000)

In [238]:
Audio(audio_example_reconstr[4].squeeze().numpy(), rate=16000)