In [None]:
%pip install librosa python-dotenv pydot
%pip install torch==1.10.1+cu113 torchvision==0.11.2+cu113 torchaudio==0.10.1+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html

In [None]:
from torchaudio import models
import torchaudio
import torchaudio.transforms as transforms
import torch
import torch.nn as nn
import torch.nn.functional as F
import IPython.display as ipd
import librosa
import librosa.display
import matplotlib.pyplot as plt
import random
import numpy as np

import dataloader

# Use the first CUDA GPU when one is available; otherwise stay on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
%%time
# Build the dataset object from the project-local `dataloader` module, handing
# it the compute device selected in the import cell.
# NOTE(review): presumably the tracks are loaded/transformed up front, which
# would explain timing this cell — confirm in dataloader.MusicSet.
ms = dataloader.MusicSet(dataloader.TRACKS, device=device)

In [None]:
# Scratch check of nn.Threshold: entries that are not strictly greater than 1
# are replaced by 1e-12, larger entries pass through unchanged.
lol = torch.Tensor([0.0, 0.5, 1.0, 1.5, 2.0])
th = nn.Threshold(threshold=1, value=1e-12)
print(th(lol))

In [None]:
# Grab element 0 of dataset item 0 as a probe tensor and show its shape.
# The next cell feeds t[None][None] to a Conv2d, so t is used as a 2-D tensor.
# NOTE(review): presumably a magnitude spectrogram — confirm in dataloader.
t = ms[0][0]
print(t.shape)

In [None]:
# Shape exploration for the conv -> LSTM -> LSTM -> transposed-conv pipeline.
# Each stage is built, applied to the previous stage's output, and its output
# shape is printed.

# Encoder: three strided 9x9 convolutions (stride 2 along dim -2, 3 along dim -1).
# t[None][None] adds two leading singleton dims so the 2-D tensor becomes the
# 4-D (batch, channel, H, W) input Conv2d expects.
cv0 = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=(9, 9), stride = (2, 3)).to(device)
t_0 = cv0(t[None][None])
print(t_0.shape)
cv1 = nn.Conv2d(in_channels=4, out_channels=64, kernel_size=(9, 9), stride = (2, 3)).to(device)
t_1 = cv1(t_0)
print(t_1.shape)
cv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(9, 9), stride = (2, 3)).to(device)
t_2 = cv2(t_1)
print(t_2.shape)

# Bottleneck LSTM: 122 * 44 input features -> 128 hidden units, 2 layers.
# The view only succeeds when t_2's element count is a multiple of 122 * 44.
# nn.LSTM defaults to (seq_len, batch, features) input, so the leading 1 is the
# sequence length and the -1 axis is treated as the batch.
rnn1 = nn.LSTM(122 * 44, 128, 2).to(device)
t_3, _ = rnn1(t_2.view(1, -1, 122 * 44))
print(t_3.shape)

# Expanding LSTM: 128 features back to 122 * 44, then reshape to a 4-D tensor
# with 128 channels and a last dim of 44 for the transposed convolutions.
rnn2 = nn.LSTM(128, 122 * 44, 2).to(device)
t_4p, _ = rnn2(t_3)
t_4 = t_4p.view(1, 128, -1, 44)
print(t_4.shape)

# Decoder: three transposed convolutions with the same kernel/stride as the
# encoder. NOTE: the channel widths used in this scratch cell
# (128 -> 128 -> 64 -> 64) differ from the ones in the RNN_VAE class defined
# later in the notebook (128 -> 64 -> 4 -> 1).
cv3 = nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=(9,9), stride = (2, 3)).to(device)
t_5 = cv3(t_4)
print(t_5.shape)

cv4 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=(9,9), stride = (2, 3)).to(device)
t_6 = cv4(t_5)
print(t_6.shape)

cv5 = nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=(9,9), stride = (2, 3)).to(device)
t_7 = cv5(t_6)
print(t_7.shape)

In [None]:
# Scratch check: summing squared differences over dim 1 of two (1, 1291, 1025)
# tensors and squeezing leaves a vector of length 1025.
a = torch.zeros([1, 1291, 1025])
b = torch.zeros([1, 1291, 1025])
c = ((a - b) ** 2).sum(dim=1).squeeze()
print(c.shape)

In [None]:
def pre_process(sample):
    """Turn one dataset item into a peak-normalised model input.

    Args:
        sample: indexable item whose element 0 is a 2-D tensor.

    Returns:
        tuple ``(x_p, a_max)`` where ``x_p`` is ``sample[0]`` transposed, with
        a leading singleton dimension, every value ``<= 1e-4`` replaced by 0
        and the result divided by its largest absolute value; ``a_max`` is
        that largest absolute value (0-dim tensor), to be passed to
        ``post_process`` to undo the scaling.

    The input tensor is left untouched.
    """
    # `sample[0].t()[None]` is a view of the caller's tensor, so thresholding or
    # dividing it in place would silently rewrite the dataset item. Use
    # out-of-place ops instead.
    x_p = F.threshold(sample[0].t()[None], 1e-4, 0)
    a_max = x_p.abs().max()
    # A fully thresholded (all-zero) input has a_max == 0; divide by 1 in that
    # case so the result stays all-zero instead of turning into NaNs.
    scale = torch.where(a_max > 0, a_max, torch.ones_like(a_max))
    return x_p / scale, a_max

def post_process(sample, a_max):
    """Scale a normalised output back up by the peak value.

    Args:
        sample: value (typically a tensor) in normalised units.
        a_max: scale factor, as returned alongside the input by pre_process.

    Returns:
        ``sample`` multiplied by ``a_max``.
    """
    rescaled = sample * a_max
    return rescaled

In [None]:
class RNN_VAE(nn.Module):
    """Conv2d encoder -> two stacked LSTMs -> ConvTranspose2d decoder.

    No sampling step or KL term is implemented here, so despite the name the
    module behaves as a plain deterministic autoencoder.

    The LSTM widths are tied to the size of the feature map produced by the
    third convolution (``FEATURE_HEIGHT`` x ``FEATURE_WIDTH``). Inputs that
    lead to a different feature-map size are rejected in :meth:`encode`.

    Every sub-module is moved to the notebook-level ``device`` on construction.
    """

    # Spatial size of cv2's output that the LSTMs are dimensioned for.
    FEATURE_HEIGHT = 122
    FEATURE_WIDTH = 44
    # Channel count of cv2's output / cv3's input.
    CONV_CHANNELS = 128
    # Hidden size of the encoder LSTM (width of the bottleneck).
    LATENT_SIZE = 128
    # Number of stacked layers in each LSTM.
    LSTM_LAYERS = 2
    # Kernel and stride shared by all (transposed) convolutions.
    KERNEL = (9, 9)
    STRIDE = (2, 3)

    def __init__(self, input_size, hidden_size, num_layers, encoder_dim=64):
        """Build the layers.

        Args:
            input_size, hidden_size, num_layers, encoder_dim: currently unused;
                kept so existing constructor calls keep working. The layer
                sizes come from the class-level constants instead.
        """
        super().__init__()
        flat_features = self.FEATURE_HEIGHT * self.FEATURE_WIDTH

        # Construction order is kept (cv0..cv5, rnn1, rnn2) so parameter
        # ordering and random initialisation match the previous definition.
        self.cv0 = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=self.KERNEL, stride=self.STRIDE).to(device)
        self.cv1 = nn.Conv2d(in_channels=4, out_channels=64, kernel_size=self.KERNEL, stride=self.STRIDE).to(device)
        self.cv2 = nn.Conv2d(in_channels=64, out_channels=self.CONV_CHANNELS, kernel_size=self.KERNEL, stride=self.STRIDE).to(device)
        self.cv3 = nn.ConvTranspose2d(in_channels=self.CONV_CHANNELS, out_channels=64, kernel_size=self.KERNEL, stride=self.STRIDE).to(device)
        self.cv4 = nn.ConvTranspose2d(in_channels=64, out_channels=4, kernel_size=self.KERNEL, stride=self.STRIDE).to(device)
        self.cv5 = nn.ConvTranspose2d(in_channels=4, out_channels=1, kernel_size=self.KERNEL, stride=self.STRIDE).to(device)
        self.rnn1 = nn.LSTM(flat_features, self.LATENT_SIZE, self.LSTM_LAYERS).to(device)
        self.rnn2 = nn.LSTM(self.LATENT_SIZE, flat_features, self.LSTM_LAYERS).to(device)

    def encode(self, x):
        """Compress a 4-D input into the LSTM bottleneck representation.

        Args:
            x: tensor accepted by ``cv0`` (one input channel).

        Returns:
            ReLU of the encoder LSTM output, last dimension ``LATENT_SIZE``.

        Raises:
            ValueError: if the convolutions do not yield a
                ``FEATURE_HEIGHT`` x ``FEATURE_WIDTH`` feature map.
        """
        x = self.cv0(x)
        x = self.cv1(x)
        x = self.cv2(x)
        expected = (self.FEATURE_HEIGHT, self.FEATURE_WIDTH)
        if tuple(x.shape[-2:]) != expected:
            # Without this check `view` either fails with an opaque size error
            # or, when the element count happens to divide evenly, silently
            # scrambles the feature map.
            raise ValueError(
                f"encoder feature map is {tuple(x.shape[-2:])}, expected {expected}; "
                "the LSTM layers are sized for that shape only"
            )
        # nn.LSTM uses (seq_len, batch, features) by default: sequence length 1,
        # with the flattened feature map of each channel as one batch entry.
        x, _ = self.rnn1(x.view(1, -1, self.FEATURE_HEIGHT * self.FEATURE_WIDTH))
        return torch.relu(x)

    def decode(self, x):
        """Expand a bottleneck tensor back through the transposed convolutions.

        Args:
            x: output of :meth:`encode`.

        Returns:
            Output of the last transposed convolution (one channel).
        """
        x, _ = self.rnn2(x)
        # Back to (1, channels, H, W) for the transposed convolutions.
        x = x.view(1, self.CONV_CHANNELS, -1, self.FEATURE_WIDTH)
        x = self.cv3(x)
        x = self.cv4(x)
        x = self.cv5(x)
        return x

    def forward(self, x):
        """Encode then decode the first entry of ``x``.

        Args:
            x: tensor whose element 0 is 2-D; it is transposed and given
                batch/channel dimensions before encoding. Further entries
                along dim 0 are ignored.

        Returns:
            4-D reconstruction from :meth:`decode`.
        """
        middle = self.encode(x[0].t()[None][None])
        return self.decode(middle)

In [None]:
# Optimisation hyper-parameters.
LEARNING_RATE = 5e-4
WEIGHT_DECAY = 0
PENALITY = 0.5  # not referenced by any other cell visible in this notebook

# Model on the selected device, plus the loss, optimiser and LR schedule
# (learning rate multiplied by 0.1 once 50 scheduler steps have been taken).
model = RNN_VAE(input_size=1025, hidden_size=64, num_layers=3, encoder_dim=64)
model = model.to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[50],
    gamma=0.1,
)

In [None]:
TEACHING_RATE = 1
def train(num_epochs, teaching_rate=TEACHING_RATE, max_samples_per_epoch=1, log_every=1):
    """Optimise the notebook-level ``model`` on items of ``ms``.

    Relies on the notebook globals ``ms``, ``model``, ``optimizer``,
    ``criterion``, ``scheduler`` and ``pre_process``.

    Args:
        num_epochs: number of passes; the scheduler is stepped once per pass.
        teaching_rate: currently unused; kept for backward compatibility.
        max_samples_per_epoch: how many leading dataset items to use in each
            epoch. The default of 1 reproduces the previous hard-coded
            behaviour (only the first item is ever trained on); pass ``None``
            to iterate over the whole dataset.
        log_every: print the loss every this many items (0/None disables it).
    """
    # Local import: itertools is not part of the notebook's import cell.
    from itertools import islice

    for epoch in range(num_epochs):
        # islice stops without fetching the item after the limit.
        for i, x in enumerate(islice(ms, max_samples_per_epoch)):
            x_p, _ = pre_process(x)
            optimizer.zero_grad()
            # Call the module itself rather than .forward() so hooks still run.
            out = model(x_p)[0][0].t()[None]
            # Crop the target to the extent of the reconstruction so both
            # tensors have the same shape for the loss.
            x_p = x_p[:, :out.shape[1], :out.shape[2]]
            loss = criterion(out, x_p)
            loss.backward()
            optimizer.step()
            if log_every and i % log_every == 0:
                # Carriage returns keep the progress report on a single line.
                print(" "*28, end='\r')
                print(f"{epoch},{i}: {loss}", end='\r')
        scheduler.step()

In [None]:
# Run 20000 epochs. As called here, train() processes only the first dataset
# item in each epoch, so this amounts to 20000 optimiser steps on that item.
train(20000)

In [None]:
# Index of the dataset item inspected in this and the following cells.
SAMPLE = 0

# Reference: the untouched spectrogram of that item and the audio obtained
# from it via the dataset's from_spectro helper.
# NOTE(review): from_spectro presumably inverts the spectrogram to a waveform —
# confirm in dataloader.MusicSet.
test_song_orig = ms[SAMPLE][0]
print(test_song_orig.shape)
res_orig = ms.from_spectro(test_song_orig)
print(res_orig.shape)


# Raw values plus max / min / mean, for comparison with the reconstruction.
print(test_song_orig.cpu().detach().numpy())
print(f"Max: {test_song_orig.max().cpu().detach().numpy()},\
      min: {test_song_orig.min().cpu().detach().numpy()},\
      mean: {test_song_orig.mean().cpu().detach().numpy()}")

# Sum of absolute values along dim 1 of the transposed spectrogram, i.e. one
# value per row of test_song_orig.t().
p = torch.sum(test_song_orig.t().abs(), 1)
plt.plot(p.cpu().detach())
# Last expression: renders an audio player for the reference signal.
ipd.Audio(res_orig.cpu(), rate=22050)

In [None]:
# Reconstruct item SAMPLE with the trained model and listen to the result.
test_song, a_max = pre_process(ms[SAMPLE])

# Inference only: no_grad stops autograd from recording the forward pass, so
# the notebook-global `spectro` does not keep the whole graph alive in memory.
# Calling the module (not .forward) keeps hooks working.
with torch.no_grad():
    spectro = model(test_song)

# Undo the peak normalisation, take magnitudes, and drop the two leading
# singleton dimensions to get a 2-D spectrogram.
spectro = post_process(spectro, a_max).abs()[0][0]

res = ms.from_spectro(spectro)

print(f"Max: {spectro.max().cpu().detach().numpy()},\
      min: {spectro.min().cpu().detach().numpy()},\
      mean: {spectro.mean().cpu().detach().numpy()}")

# Same per-row magnitude sum as plotted for the original item.
p = torch.sum(spectro.t().abs(), 1)
plt.plot(p.cpu().detach())
# Last expression: renders an audio player for the reconstruction.
ipd.Audio(res.cpu().detach(), rate=22050)

In [None]:
# Spectrogram of the original item on a log-frequency axis. sr=22050 matches
# the playback rate used for ipd.Audio in the earlier cells.
# NOTE(review): hop_length=512 is assumed to match the dataloader's STFT
# settings — confirm.
librosa.display.specshow(test_song_orig.cpu().detach().numpy(), sr=22050, hop_length=512, x_axis='time', y_axis='log')
plt.colorbar()

In [None]:
# Spectrogram of the model's reconstruction, drawn with the same display
# settings as the plot of the original item for side-by-side comparison.
librosa.display.specshow(spectro.cpu().detach().numpy(), sr=22050, hop_length=512, x_axis='time', y_axis='log')
plt.colorbar()

# Save the current figure under the name "lol" in the working directory
# (no extension is given, so matplotlib's default savefig format applies).
plt.savefig("lol")

In [None]:
import soundfile
# Export the reconstructed audio at 22050 Hz. The samples are multiplied by 10
# before writing.
# NOTE(review): the file name mentions "gru" while the model defined in this
# notebook uses nn.LSTM layers — confirm the name is intended.
# NOTE(review): the x10 gain may push samples outside the range the output
# format can represent — confirm it does not clip.
soundfile.write("result_2_3layer_bidim_gru_0-2.wav",10*res.cpu().detach().numpy(), 22050)

In [None]:
# Persist only the model's weights (its state_dict), not the optimiser or
# scheduler state, to "CNN_LSTM_model_0" in the working directory.
torch.save(model.state_dict(), "CNN_LSTM_model_0")

In [None]:
# Restore the weights saved above into the existing model.
# map_location=device lets a checkpoint written on a GPU be loaded on a
# machine where that GPU is not available (e.g. CPU-only).
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
model.load_state_dict(torch.load("CNN_LSTM_model_0", map_location=device))

In [None]:
# Element count of every parameter tensor, in model.parameters() order
# (a per-tensor list, not a single grand total).
pytorch_total_params = [param.numel() for param in model.parameters()]
print(pytorch_total_params)

In [None]:
# Name, shape and element count of every state_dict entry.
# `prod` is never imported or defined in this notebook, so the previous
# expression raised NameError on a fresh kernel; Tensor.numel() gives the
# element count directly. The state_dict is also built once instead of three
# times per key.
[(name, tensor.shape, tensor.numel()) for name, tensor in model.state_dict().items()]