In [None]:
# Notebook dependencies. The first line installs the audio / utility packages without
# version pins; the second installs the CUDA 11.3 ("+cu113") builds of torch,
# torchvision and torchaudio at fixed versions from the PyTorch wheel index.
%pip install librosa python-dotenv pydot
%pip install torch==1.10.1+cu113 torchvision==0.11.2+cu113 torchaudio==0.10.1+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html

In [None]:
from torchaudio import models
import torchaudio
import torchaudio.transforms as transforms
import torch
import torch.nn as nn
import torch.nn.functional as F
import IPython.display as ipd
import librosa
import librosa.display
import matplotlib.pyplot as plt
import random

import dataloader

# Run on the first CUDA device when one is visible to PyTorch, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
%%time
# Build the dataset object used by every later cell (`ms`). It is indexed (`ms[i]`),
# iterated over during training, and provides `to_spectro` / `from_spectro`.
# NOTE(review): what MusicSet loads and how long it takes is defined in the local
# `dataloader` module, not here; %%time reports the wall time of this cell.
ms = dataloader.MusicSet(dataloader.TRACKS, device=device)

In [None]:
def pre_process(sample):
    """Peak-normalise, floor and log-compress ``sample[0]``.

    ``sample[0]`` is transposed and given a leading axis of size 1. Every
    step then runs in place on that view, so the tensor behind
    ``sample[0]`` is modified as well:

    1. divide by the peak absolute value,
    2. replace every entry that is not above 1e-2 by 1e-2,
    3. multiply by the peak again,
    4. take log10.

    Note: the name ``pre_process`` is defined a second time further down
    in this cell, so this variant is not the one left bound afterwards.

    Returns:
        tuple: ``(processed_tensor, None)``.
    """
    spec = sample[0].t().unsqueeze(0)
    peak = spec.abs().max()
    spec.div_(peak)
    F.threshold(spec, 1e-2, 1e-2, inplace=True)
    spec.mul_(peak)
    spec.log10_()
    return spec, None

def pre_process(sample):
    """Peak-normalise ``sample[0]`` and zero out very small entries.

    ``sample[0]`` is transposed and given a leading axis of size 1. The
    result is divided by its peak absolute value, and entries that are not
    above 1e-4 are then set to 0. Both steps run in place on a view, so
    the tensor behind ``sample[0]`` is modified too.

    Returns:
        tuple: ``(processed_tensor, None)``.
    """
    spec = sample[0].t().unsqueeze(0)
    spec.div_(spec.abs().max())
    F.threshold(spec, 1e-4, 0, inplace=True)
    return spec, None

def post_process(sample):
    """Return 10 raised to ``sample`` (the inverse of a log10 compression).

    Note: ``post_process`` is defined again right after this with a
    two-argument signature, which replaces this one-argument variant.
    """
    restored = 10 ** sample
    return restored

def post_process(sample, _):
    """Identity post-processing: hand ``sample`` back untouched.

    The second positional argument is accepted and ignored.
    """
    unchanged = sample
    return unchanged

In [None]:
class RNN_VAE(nn.Module):
    """Sequence auto-encoder built from two stacked LSTMs.

    ``rnn1`` is a bidirectional LSTM mapping ``input_size`` features to
    ``hidden_size``; only its forward-direction outputs are kept as the code.
    ``rnn2`` is a unidirectional LSTM mapping the code back to ``input_size``
    features. Both are created with ``batch_first=True``, so every tensor is
    laid out as ``(batch, time, features)``.

    Despite the class name, nothing here samples a latent variable or adds a
    KL term: the model is a deterministic auto-encoder.

    Args:
        input_size: number of features per time step.
        hidden_size: size of the per-step code produced by the encoder.
        num_layers: number of stacked LSTM layers in encoder and decoder.
        encoder_dim: currently unused (the linear bottleneck layers that
            would consume it are disabled); kept so existing calls still work.
    """

    def __init__(self, input_size, hidden_size, num_layers, encoder_dim=64):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.rnn1 = nn.LSTM(input_size,
                            hidden_size,
                            num_layers,
                            batch_first=True,
                            bidirectional=True)
        self.rnn2 = nn.LSTM(hidden_size,
                            input_size,
                            num_layers,
                            batch_first=True)
        # Optional linear bottleneck (hidden_size -> encoder_dim -> hidden_size), disabled:
        #self.fc0 = nn.Linear(hidden_size, encoder_dim)
        #self.fc1 = nn.Linear(encoder_dim, hidden_size)

    def encode(self, x):
        """Encode ``x`` of shape ``(batch, time, input_size)``.

        Returns:
            Tensor of shape ``(batch, time, hidden_size)`` holding the
            forward-direction outputs of the bidirectional encoder.
        """
        both_directions, _ = self.rnn1(x)
        # A bidirectional LSTM concatenates [forward | backward] along the last
        # axis; slicing the first half keeps the forward direction for any
        # batch size (the previous view(1, ...) only worked for batch == 1).
        return both_directions[..., :self.hidden_size]

    def decode(self, inp_dec, forced_teaching=False):
        """Decode a code sequence back to ``(batch, time, input_size)``.

        Args:
            inp_dec: code tensor of shape ``(batch, time, hidden_size)``.
            forced_teaching: when true the decoder sees the full per-step
                code; when false only the last time step of the code is
                used, repeated along the whole sequence length.

        Returns:
            Tensor of shape ``(batch, time, input_size)`` with values
            squashed by ``tanh``.
        """
        if not forced_teaching:
            seq_len = inp_dec.shape[1]
            # Keep the time axis (-1:) so the repeat happens per batch item
            # instead of folding the batch into the time dimension.
            inp_dec = inp_dec[:, -1:, :].repeat(1, seq_len, 1)
        #middle = self.fc0(inp_dec)
        middle = torch.tanh(inp_dec)
        #middle = self.fc1(middle)
        #middle = torch.relu(middle)
        y, _ = self.rnn2(middle)
        return torch.tanh(y)

    def forward(self, x):
        """Encode then decode ``x`` without forced teaching."""
        return self.decode(self.encode(x))

In [None]:
# Optimisation hyper-parameters.
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 0

# Auto-encoder over frames of 1025 features, with a 64-unit code and 3 stacked LSTM layers.
model = RNN_VAE(input_size=1025, hidden_size=64, num_layers=3, encoder_dim=64)
model = model.to(device)

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)
# The learning rate is multiplied by 0.1 once the scheduler has been stepped 50 times.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50], gamma=0.1)
criterion = nn.MSELoss()

In [None]:
TEACHING_RATE = 1
def train(num_epochs, teaching_rate=TEACHING_RATE, log_every=1):
    """Train the global ``model`` on the global dataset ``ms``.

    Uses the notebook-level ``optimizer``, ``scheduler`` and ``criterion``.
    Each dataset item is one optimisation step; the scheduler is stepped
    once per epoch.

    Args:
        num_epochs: number of passes over ``ms``.
        teaching_rate: probability, per step, of decoding with the full
            per-step code (``forced_teaching=True``) instead of only the
            last code step.
        log_every: print the running loss every ``log_every`` steps;
            ``0`` or ``None`` disables the progress line.
    """
    freq_mask = transforms.FrequencyMasking(80)
    # Make sure the model is in training mode even if an earlier cell switched it to eval.
    model.train()
    for epoch in range(num_epochs):
        for i, x in enumerate(ms):
            x_p, _ = pre_process(x)
            # NOTE(review): the masked tensor is both the model input and the
            # reconstruction target below. Also, pre_process transposes its
            # input before this call, while FrequencyMasking is documented for
            # (..., freq, time) input - confirm the mask hits the intended axis.
            x_p = freq_mask(x_p)
            mid = model.encode(x_p)
            forced_teaching = random.random() < teaching_rate
            out = model.decode(mid, forced_teaching)
            optimizer.zero_grad()
            loss = criterion(out, x_p)
            loss.backward()
            # Clip the global gradient norm to 25 before the update.
            nn.utils.clip_grad_norm_(model.parameters(), 25.)
            optimizer.step()
            if log_every and i % log_every == 0:
                # Overwrite the previous progress line in place.
                print(" "*30, end='\r')
                print(f"{epoch},{i}: {loss:.4e}", end='\r')
        scheduler.step()

In [None]:
#model = torch.load("trained_model")

In [None]:
#train(5)

In [None]:
SAMPLE = 7532

# Normalise a copy: an in-place `/=` would write through to whatever tensor
# `ms[SAMPLE]` handed back, which may be the dataset's own storage.
test_song_orig = ms[SAMPLE][0]
test_song_orig = test_song_orig / test_song_orig.max()
print(test_song_orig.shape)
# Invert the normalised spectrogram back to a waveform for listening.
res_orig = ms.from_spectro(test_song_orig)
print(res_orig.shape)


print(test_song_orig.cpu().detach().numpy())
print(f"Max: {test_song_orig.max().cpu().detach().numpy()},\
      min: {test_song_orig.min().cpu().detach().numpy()},\
      mean: {test_song_orig.mean().cpu().detach().numpy()}")

# Sum of absolute values along dim 1 of the transposed spectrogram, one value per row.
p = torch.sum(test_song_orig.t().abs(), 1)
plt.plot(p.cpu().detach())
ipd.Audio(res_orig.cpu(), rate=22050)

In [None]:
# Reconstruct the training sample SAMPLE with the model and listen to the result.
test_song = ms[SAMPLE]
# pre_process returns (tensor, None), so `a_max` is None here and is not used below.
test_song, a_max = pre_process(test_song)
code = model.encode(test_song)
#plt.imshow(code.squeeze().cpu().detach().numpy())
# Decode twice: with the full per-step code, and from the last code step only.
# Only the forced-teaching result is used in the rest of this cell.
spectro_forced = model.decode(code, True).squeeze().t()
spectro_ = model.decode(code, False).squeeze().t()
# print(spectro_.shape)
spectro = post_process(spectro_forced, None)
# The decoder ends in tanh, so negative values are possible; clamp them to zero.
spectro = spectro.relu()
#spectro = F.threshold(spectro, 1e-4, 1e-15)
res = ms.from_spectro(spectro)
# print(res.shape)

#print(spectro.cpu().detach().numpy())
print(f"Max: {spectro.max().cpu().detach().numpy()},\
      min: {spectro.min().cpu().detach().numpy()},\
      mean: {spectro.mean().cpu().detach().numpy()}")

# Sum of absolute values along dim 1 of the transposed reconstruction, one value per row.
p = torch.sum(spectro.t().abs(), 1)
plt.plot(p.cpu().detach())
ipd.Audio(res.cpu().detach(), rate=22050)

In [None]:
# log10 view of the normalised original spectrogram.
librosa.display.specshow(
    test_song_orig.log10().detach().cpu().numpy(),
    sr=22050,
    hop_length=512,
    x_axis='time',
    y_axis='log',
)
plt.colorbar()

In [None]:
# log10 view of the reconstructed spectrogram, plotted with the same axes settings.
librosa.display.specshow(
    spectro.log10().detach().cpu().numpy(),
    sr=22050,
    hop_length=512,
    x_axis='time',
    y_axis='log',
)
plt.colorbar()

In [None]:
#torch.save(model, "./trained_model")

In [None]:
import soundfile
#soundfile.write("result_2_3layer_bidim_gru_0-2.wav",10*res.cpu().detach().numpy(), 22050)

In [None]:
## Testing
import os
# os.listdir returns entries in arbitrary order; sort them so that positional
# indices into test_set_files refer to the same file on every run and machine.
test_set_files = sorted(os.listdir("../test_set"))
test_set_files = ["../test_set/"+x for x in test_set_files]
# One loss slot per test file, filled in by the evaluation loop.
losses = [None]*len(test_set_files)

def load_file(filename):
    """Load an audio file and convert it to a spectrogram.

    The file is read with torchaudio, moved to the notebook's ``device``,
    averaged over dim 0 (channel mix-down to mono), resampled to 22050 Hz
    and passed through ``ms.to_spectro``.

    Args:
        filename: path of the audio file to load.

    Returns:
        tuple: ``(spectro, audio)`` - the output of ``ms.to_spectro`` and
        the resampled mono waveform.
    """
    # Read the file named by the argument (previously the global
    # `test_sample_file` was read, ignoring `filename`).
    audio, sr = torchaudio.load(filename)
    audio = audio.to(device)
    audio = audio.mean(dim=0) # to mono
    audio = torchaudio.functional.resample(audio, sr, 22050)
    spectro = ms.to_spectro(audio)
    return spectro, audio

In [None]:
# Reconstruction loss (MSE) for every file of the test set.
# Evaluation only: no autograd graph is needed, and losses are stored as plain
# floats so the list does not keep device tensors (and their graphs) alive.
with torch.no_grad():
    for i, test_sample_file in enumerate(test_set_files):
        spectro, _ = load_file(test_sample_file)
        inp, _ = pre_process(spectro[None,])
        code = model.encode(inp)
        out = model.decode(code, True)
        out = post_process(out, None)
        out = out[0].t()
        # Compare against the pre-processed input explicitly. pre_process works
        # in place on a view of `spectro`, so this is the same data the loss
        # was compared with before, without relying on that side effect.
        target = inp[0].t()
        losses[i] = criterion(out, target).item()

In [None]:
# Per-file test losses in ascending order. The whole list is sorted;
# `losses[i]` is a single scalar and cannot be iterated.
plt.plot(sorted(float(loss) for loss in losses))

In [None]:
SAMPLE_TEST = 22
test_sample_file = test_set_files[SAMPLE_TEST]

# Run one test file through the auto-encoder with forced teaching.
spectro, audio = load_file(test_sample_file)
# NOTE(review): pre_process operates in place on a view of `spectro`, so the
# spectrogram plotted below is the pre-processed one - confirm this is intended.
inp, _ = pre_process(spectro[None,])
code = model.encode(inp)
out = model.decode(code, True)
out = post_process(out, None)
# Back to the spectrogram's own layout, with negative decoder outputs clamped to zero.
out = out[0].t().relu()
librosa.display.specshow(spectro.cpu().detach().numpy(), sr=22050, hop_length=512, x_axis='time', y_axis='log')
# min / max / mean (the third value used to repeat max).
print(spectro.min(), spectro.max(), spectro.mean())
plt.colorbar()
plt.show()

In [None]:
# Reconstructed spectrogram for the selected test file.
librosa.display.specshow(out.cpu().detach().numpy(), sr=22050, hop_length=512, x_axis='time', y_axis='log')
# min / max / mean (the third value used to repeat max).
print(out.min(), out.max(), out.mean())
plt.colorbar()
plt.show()

In [None]:
# Play the resampled mono input clip for comparison with the reconstruction.
ipd.Audio(audio.detach().cpu(), rate=22050)

In [None]:
# Turn the reconstructed spectrogram back into a waveform and play it.
out_audio = ms.from_spectro(out)
ipd.Audio(out_audio.detach().cpu(), rate=22050)