In [None]:
from torchaudio import models
import torchaudio
import torchaudio.transforms as transforms
import torch
import torch.nn as nn
import torch.nn.functional as F
import IPython.display as ipd
import librosa
import librosa.display
import matplotlib.pyplot as plt

import dataloader

# Run on the first CUDA GPU when one is visible to PyTorch, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
%%time
# Build the dataset from the project-local `dataloader` module; `%%time` reports
# how long construction takes. Later cells read items as `ms[i][0]` and turn a
# spectrogram back into audio with `ms.from_spectro(...)`.
ms = dataloader.MusicSet(dataloader.TRACKS, device=device)

In [None]:
# Scratch experiment: run a plain RNN over one item, then drive a second RNN
# step by step with its own (noised) previous output appended to a fixed frame.
# Both RNNs are freshly initialised and untrained, and the loop draws new noise
# each step, so results vary between runs unless a seed is set.
rnn = nn.RNN(1025, 10, 1, batch_first=True).to(device)
# 1035 = 1025 input features + the 10 outputs that are fed back in.
rnn2 = nn.RNN(1035, 10, 1, batch_first=True).to(device)
idx = 100
# Transpose the item's first element and add a leading batch axis of size 1.
out, h = rnn(ms[idx][0].t()[None,])
print(out.shape)
# First row of the transposed item, repeated 1291 times; not used further in this cell.
out_r = ms[idx][0].t()[0][None].repeat(1291, 1).unsqueeze(0)

# Rebinds `out`: the RNN output printed above is discarded from here on.
out = torch.zeros(10).to(device)

# Collects one 10-d output per iteration. Allocated on the default device, not `device`.
finale = torch.zeros(1025, 10)

# First row of the transposed item (presumably 1025 values, to match rnn2's input size).
lol = ms[idx][0].t()[0]

# Input for the first step: the fixed frame followed by the noised previous output.
middle = torch.cat([lol, out + torch.randn(10).to(device)])

# `h` starts as the final hidden state of `rnn` and is carried across iterations.
for i in range(1025):
    # Feed `middle` as a batch of one sequence of length one.
    out, h = rnn2(middle[None, None, ], h)
    out = out[0][0]
    finale[i] = out
    middle = torch.cat([lol, out + torch.randn(10).to(device)])

# Sum the 10 outputs of each iteration into a single value per step.
out_s = torch.sum(finale, 1)
print(out_s.shape, out_s)

In [None]:
class RNN_VAE(nn.Module):
    """Recurrent encoder/decoder for a single sequence of feature frames.

    The encoder RNN (`rnn1`) reads the sequence; its output at the last time
    step is projected to `encoder_dim` (`fc0`) and back to `hidden_size`
    (`fc1`). The decoder RNN (`rnn2`) then receives, at every time step, that
    projected summary concatenated with the encoder output of the same step,
    and emits one `input_size`-wide frame per step.

    Despite the name, the visible code performs no latent sampling and has no
    KL term: it is a deterministic autoencoder-style network.

    Args:
        input_size: Number of features per input frame (also the output width).
        hidden_size: Hidden width of the encoder RNN.
        num_layers: Number of stacked layers in both RNNs.
        encoder_dim: Width of the bottleneck between `fc0` and `fc1`.
    """

    def __init__(self, input_size, hidden_size, num_layers, encoder_dim=64):
        super(RNN_VAE, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.input_size = input_size
        # Sub-modules are created in this fixed order so that seeded
        # initialisation stays reproducible.
        self.rnn1 = nn.RNN(input_size, hidden_size, num_layers, batch_first=True, nonlinearity="tanh")
        # Decoder input is [projected summary, encoder output] -> 2 * hidden_size.
        self.rnn2 = nn.RNN(hidden_size * 2, input_size, num_layers, batch_first=True, nonlinearity="tanh")
        self.fc0 = nn.Linear(hidden_size, encoder_dim)
        self.fc1 = nn.Linear(encoder_dim, hidden_size)

    def forward(self, x):
        """Encode and re-decode one sequence.

        Args:
            x: Tensor of shape (1, seq_len, input_size). Only a batch of one
                sequence is supported.

        Returns:
            Tensor of shape (seq_len, input_size) on the same device as `x`,
            with values in (-1, 1) because the decoder RNN uses tanh.

        Raises:
            ValueError: If `x` is not a 3-D tensor with batch size 1.
        """
        if x.dim() != 3 or x.shape[0] != 1:
            raise ValueError(
                f"expected input of shape (1, seq_len, {self.input_size}), got {tuple(x.shape)}"
            )

        # nn.RNN uses an all-zero initial hidden state when none is given, on
        # the same device/dtype as the input, so no global `device` is needed.
        middle, _ = self.rnn1(x)

        # Encoder output at the final time step; explicit indexing (instead of
        # squeeze()) keeps this correct when seq_len == 1.
        last_step = middle[0, -1]
        code = self.fc0(last_step)
        summary = self.fc1(code)

        # Every decoder step sees the same summary next to that step's encoder
        # output. Running the decoder once over the whole sequence is
        # equivalent to stepping it manually while carrying the hidden state,
        # and avoids in-place writes into a tensor that requires grad (which
        # raises on CPU, where `.to(device)` returns the leaf tensor itself).
        seq_len = x.shape[1]
        decoder_in = torch.cat([summary.expand(seq_len, -1), middle[0]], dim=1)
        output, _ = self.rnn2(decoder_in.unsqueeze(0))
        return output[0]

In [None]:
# 1025 features per frame, 64 recurrent hidden units, one layer, 64-d bottleneck.
model = RNN_VAE(input_size=1025, hidden_size=64, num_layers=1, encoder_dim=64).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# Earlier loss choice, kept for reference:
#criterion = nn.MSELoss()
# NOTE(review): train() passes a float tensor with the same shape as the model
# output as the target, so CrossEntropyLoss treats each row as class
# probabilities. Confirm this is intended for reconstruction.
criterion = nn.CrossEntropyLoss()

In [None]:
def train(num_epochs, log_every=1):
    """Train the notebook-global `model` on every item of `ms`.

    Each item is processed on its own (batch of one): the item's first
    element is transposed, given a leading batch axis, L2-normalised, passed
    through `model`, and compared against the normalised input with
    `criterion`. Gradients are clipped to a total norm of 25 before each
    `optimizer` step.

    Uses the globals `ms`, `model`, `optimizer` and `criterion`.

    Args:
        num_epochs: Number of full passes over `ms`.
        log_every: Print the loss every this many items (counted from 1
            within each epoch). The default of 1 logs every item; 0 or None
            disables logging.
    """
    for epoch in range(num_epochs):
        # Count items from 1 so the printed index matches the number processed.
        for step, x in enumerate(ms, start=1):
            # F.normalize's default axis is dim=1; spelled out here because on
            # this (1, a, b) tensor it is the middle axis, not the last one.
            # NOTE(review): presumably dim=1 is time here - confirm that is intended.
            x_p = F.normalize(x[0].t()[None,], dim=1)
            out = model(x_p)
            optimizer.zero_grad()
            loss = criterion(out, x_p[0])
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 25.)
            optimizer.step()
            if log_every and step % log_every == 0:
                print(f"{epoch},{step}: {loss}")

In [None]:
# Full training run: 250 passes over the dataset.
train(num_epochs=250)

In [None]:
# Reference example: item 420 as stored in the dataset, plus the audio
# obtained by inverting it with the dataset's own helper.
test_song_orig = ms[420][0]
print(test_song_orig.shape)
res_orig = ms.from_spectro(test_song_orig)
print(res_orig.shape)


print(test_song_orig)
print(test_song_orig.max(), test_song_orig.min(), test_song_orig.mean())

# Sum of absolute values along the last axis of the transposed spectrogram.
p = test_song_orig.abs().t().sum(dim=1)
plt.plot(p.detach().cpu())
# Last expression of the cell: renders the audio player.
ipd.Audio(res_orig.cpu(), rate=22050)

In [None]:
# Reconstruct item 420 with the trained model and listen to the result.
test_song = ms[420][0].t()[None,]
# NOTE(review): train() feeds the model F.normalize'd input, but this input is
# passed in as stored - confirm whether it should be normalised the same way.

# Pure inference: call the module itself (not .forward, which skips hooks) and
# do not record an autograd graph over the whole sequence.
with torch.no_grad():
    spectro = model(test_song).t().abs()
print(spectro.shape)

res = ms.from_spectro(spectro)
print(res.shape)
print(spectro)


print(spectro.max(), spectro.min(), spectro.mean())

# Difference between consecutive rows of the reconstruction.
lol = spectro[1:] - spectro[0:-1]
print(lol.max(), lol.min())

# Sum along the last axis of the transposed reconstruction (already non-negative).
p = torch.sum(spectro.t(), 1)
plt.plot(p.cpu().detach())
# Last expression of the cell: renders the audio player.
ipd.Audio(res.cpu().detach(), rate=22050)

In [None]:
# Log-frequency spectrogram view of the reference item.
librosa.display.specshow(
    test_song_orig.detach().cpu().numpy(),
    sr=22050,
    hop_length=512,
    x_axis="time",
    y_axis="log",
)

In [None]:
# Log-frequency spectrogram view of the model's reconstruction.
librosa.display.specshow(
    spectro.detach().cpu().numpy(),
    sr=22050,
    hop_length=512,
    x_axis="time",
    y_axis="log",
)