In [136]:
from tqdm import tqdm
import numpy as np
import librosa
import torch.nn as nn
import torch.nn.functional as F
import torch
from torch.autograd import Variable

In [6]:
x_raw = "MiniMega_low.mp3"
y_raw = "MiniMega_norm.mp3"
x, sr_x = librosa.load(x_raw)
y, sr_y = librosa.load(y_raw)
x_D = np.abs(librosa.stft(x))
y_D = np.abs(librosa.stft(y))



In [7]:
X = torch.from_numpy(x_D)
Y = torch.from_numpy(y_D)

In [14]:
width = X.shape[1]

In [8]:
X.shape

torch.Size([1025, 25747])

In [41]:
batch_norm = nn.BatchNorm1d(width)
demo_x = batch_norm(X)
test_slice = demo_x[0].unsqueeze(0)

In [105]:
class Generator(nn.Module):
    def __init__(self, X):
        super(Generator, self).__init__()

        self.init_size = X.shape[1]
        self.conv1 = nn.Conv1d(self.init_size, 15000, 1, stride=3)
        self.relu1 = nn.LeakyReLU(0.2, inplace=True)
        self.conv2 = nn.Conv1d(15000, 7500, 1, stride=3)
        self.relu2 = nn.LeakyReLU(0.2, inplace=True)
        self.fc1 = nn.Linear(7500, 15000)
        self.relu3 = nn.LeakyReLU(0.2, inplace=True)
        self.fc2 = nn.Linear(15000, self.init_size)
        self.relu4 = nn.LeakyReLU(0.2, inplace=True)
        self.tanh = nn.Tanh()

    def forward(self, x):
        x = self.conv1(x.unsqueeze(2)).squeeze(2)
        x = self.relu1(x)
        x = self.conv2(x.unsqueeze(2)).squeeze(2)
        x = self.relu2(x)
        x = self.fc1(x)
        x = self.relu3(x)
        x = self.fc2(x)
        x = self.relu4(x)
        x = self.tanh(x)
        return x

In [106]:
gen = Generator(X)

In [107]:
x_ = gen(X[0].unsqueeze(0))

In [108]:
x_.shape

torch.Size([1, 25747])

In [115]:
class Discriminator(nn.Module):
    def __init__(self, X):
        super(Discriminator, self).__init__()
        
        self.init_size = X.shape[1]
        self.fc1 = nn.Linear(self.init_size, 10000)
        self.fc2 = nn.Linear(10000, 2)
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        return x

In [116]:
dis = Discriminator(X)

In [117]:
out = dis(x_)

In [121]:
adversarial_loss = torch.nn.BCELoss()
generator = Generator(X)
discriminator = Discriminator(X)
learning_rate = 1e-6
optimizer_G = torch.optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))

In [149]:
cuda = True if torch.cuda.is_available() else False
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
for i in tqdm(range(X.shape[0])):
    # Generator Training
    ground_truth = Y[i]
    fake = torch.tensor([0.0, 1.0])
    valid = torch.tensor([1.0, 0.0])
    input_slice = X[i]
    optimizer_G.zero_grad()
    gen_slice = generator(X[i].unsqueeze(0))
    g_loss = adversarial_loss(discriminator(gen_slice), fake)
    g_loss.backward()
    optimizer_G.step()
    
    # Discriminator Training
    optimizer_D.zero_grad()
    real_loss = adversarial_loss(discriminator(ground_truth), valid)
    fake_loss = adversarial_loss(discriminator(gen_slice.detach()), fake)
    d_loss = (real_loss + fake_loss)/2
    d_loss.backward()
    optimizer_D.step()

100%|██████████| 1025/1025 [4:26:43<00:00, 15.61s/it] 


In [151]:
gen_data = []
for i in tqdm(range(X.shape[0])):
    # Generator Training
    input_slice = X[i]
    gen_slice = generator(input_slice.unsqueeze(0))
    gen_data.append(gen_slice.detach().numpy()[0])


D = np.array(gen_data)
reconstructed_audio = librosa.griffinlim(D)
import soundfile
soundfile.write('test.wav', reconstructed_audio, 22050)

100%|██████████| 1025/1025 [32:22<00:00,  1.90s/it]
