In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import data_generator as dg
from datetime import datetime

# Log the wall-clock start time of the run.
print(datetime.now().strftime("%H:%M:%S"))

# Shared configuration.
C = 3           # context frames on each side of the centre frame; model windows span 2*C + 1 = 7 frames
L = 1025        # frequency bins per STFT frame (generate_song uses nperseg = 2*L - 2 = 2048)
amount = 50     # passed to dg.data_frame as `amount` (presumably examples per file — TODO confirm)
files = 10      # first positional argument of dg.data_frame (presumably number of files — TODO confirm)
inst = 'Piano'  # target instrument; its centre frame is the training/evaluation label

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's global RNGs (all devices) so model initialisation is reproducible.
# Note: assigning `torch.seed = 2` would overwrite the torch.seed function and seed nothing.
SEED = 2
torch.manual_seed(SEED)

class DNN(nn.Module):
    """Five-layer fully connected regressor with float64 weights.

    Maps rows of ``(2*C + 1) * L`` features to rows of ``L`` outputs. The four
    hidden layers use ReLU; the final layer is purely linear. Submodules are
    named ``layer1`` ... ``layer5`` so saved ``state_dict`` keys line up.
    """

    def __init__(self, C, L):
        super().__init__()
        window_features = (2 * C + 1) * L
        self.layer1 = nn.Linear(window_features, L, dtype=torch.float64)
        # Square L -> L layers, registered in order as layer2 .. layer5.
        for position in range(2, 6):
            setattr(self, f"layer{position}", nn.Linear(L, L, dtype=torch.float64))

    def forward(self, x):
        hidden_layers = (self.layer1, self.layer2, self.layer3, self.layer4)
        for layer in hidden_layers:
            x = F.relu(layer(x))
        return self.layer5(x)

# Two real-valued networks (presumably one per spectrogram component, judging by their
# `_real` / `_imag` checkpoint names); both are created directly on `device`, real first.
dnn_real, dnn_imag = (DNN(C, L).to(device) for _ in range(2))

# Summed (not averaged) squared error over every element.
criterion = nn.MSELoss(reduction='sum')


def complex_relu(complex_input):
    """Apply ReLU independently to the real and imaginary parts of a complex tensor.

    Negative real parts and negative imaginary parts are each replaced by 0; the
    result has the same shape and complex dtype as ``complex_input``.
    """
    real_imag_pairs = torch.view_as_real(complex_input)  # real view with a trailing (re, im) axis
    rectified = real_imag_pairs.clamp(min=0)
    return torch.view_as_complex(rectified)

class DNN_ra(nn.Module):
    """Five-layer fully connected network with complex128 weights.

    Maps rows of ``(2*C + 1) * L`` complex features to rows of ``L`` complex
    outputs. The four hidden layers are followed by ``complex_relu``; the last
    layer is linear. Submodules are named ``layer1`` ... ``layer5`` so that
    checkpoint ``state_dict`` keys match.
    """

    def __init__(self, C, L):
        super().__init__()
        fan_in = (2 * C + 1) * L
        self.layer1 = nn.Linear(fan_in, L, dtype=torch.complex128)
        # Square L -> L layers, registered in order as layer2 .. layer5.
        for k in range(2, 6):
            setattr(self, f"layer{k}", nn.Linear(L, L, dtype=torch.complex128))

    def forward(self, x):
        for hidden in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = complex_relu(hidden(x))
        return self.layer5(x)

# Complex-valued model: restore the saved checkpoint, switch to inference mode, then move it to `device`.
dnn = DNN_ra(C, L)
# The freshly built model is on the CPU, so load the checkpoint tensors onto the CPU too;
# without map_location a checkpoint saved on a GPU fails to load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
dnn.load_state_dict(torch.load('../model_23_06_12_2340_EPOCH_6', map_location='cpu'))
dnn.eval()
dnn.to(device=device)

10:11:37


DNN_ra(
  (layer1): Linear(in_features=7175, out_features=1025, bias=True)
  (layer2): Linear(in_features=1025, out_features=1025, bias=True)
  (layer3): Linear(in_features=1025, out_features=1025, bias=True)
  (layer4): Linear(in_features=1025, out_features=1025, bias=True)
  (layer5): Linear(in_features=1025, out_features=1025, bias=True)
)

# Restore the least-squares real- and imaginary-part networks and switch both to inference mode.
# map_location lets checkpoints saved on another device (e.g. a GPU) load onto this one.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
dnn_real.load_state_dict(torch.load('../DNN_leastSquares_real.pt', map_location=device))
dnn_imag.load_state_dict(torch.load('../DNN_leastSquares_imag.pt', map_location=device))
dnn_real.eval()
dnn_imag.eval()


# Build an evaluation set from the test split. For every frame dict containing `inst`:
#   * each instrument in `positive` is an input labelled with `inst`'s centre frame (index C);
#   * at most len(positive) instruments from `negative` are inputs labelled with an all-zero frame,
#     so positives and negatives stay balanced.
data, label = [], []
for f in dg.data_frame(files, amount, C = C, L = L, mix_amount = 4, device='cpu', directory='../Data/slakh2100_flac_redux/test'):
    positive, negative = dg.search_dicts(f, inst)
    if inst not in positive:
        continue
    # Real part of the target's centre frame, shared by every positive example of this dict.
    target_frame = torch.view_as_complex(f[inst][:, :, C]).real
    silent_frame = torch.zeros_like(target_frame)
    for instrument in positive:
        data.append(torch.view_as_complex(f[instrument]).real)
        label.append(target_frame)
    for instrument in list(negative)[:len(positive)]:
        data.append(torch.view_as_complex(f[instrument]).real)
        label.append(silent_frame)

if not data:
    raise RuntimeError(f"No test mixtures containing {inst!r} were found; nothing to evaluate.")

# NOTE(review): presumably the total number of examples, assuming each stacked tensor holds `amount` rows — confirm.
N = len(data) * amount
print(N)
data = torch.stack(data).to(device)
label = torch.stack(label).to(device)
# NOTE(review): `input` shadows the Python builtin; the name is kept because later cells may use it.
input = data.reshape(-1, L*(2*C+1))
target = label.reshape(-1, L)

# Inference only: no_grad avoids building and holding an autograd graph over the whole test set.
# Previously recorded summed-MSE values: imag 0.94, real 3.426.
with torch.no_grad():
    loss_real = criterion(dnn_real(input), target)
loss_real

In [2]:
def generate_overlapping_array(original_array, width=7):
    """Return every contiguous window of ``width`` consecutive items, stepping by one.

    Window ``i`` is ``original_array[i:i + width]``, so a sequence of length ``n``
    yields ``max(n - width + 1, 0)`` windows (an empty list when it is shorter
    than ``width``).

    Args:
        original_array: Any sliceable sequence supporting ``len`` (list, str, NumPy array, ...).
            NumPy inputs yield views, not copies.
        width: Items per window. The default 7 equals ``2*C + 1`` for the notebook's
            ``C = 3``; pass ``2*C + 1`` explicitly when C changes.

    Returns:
        list: The windows, in order of their starting index.

    Raises:
        ValueError: If ``width`` is less than 1.
    """
    if width < 1:
        raise ValueError(f"width must be >= 1, got {width}")
    return [original_array[start:start + width]
            for start in range(len(original_array) - width + 1)]

In [3]:
# Generate entire song for testing
import soundfile as sf
import numpy as np
from scipy.signal import stft, istft

def generate_song(C, L, device='cpu', model=None, input_path='test.mp3', output_path='SSE.wav'):
    """Run a whole song through the complex model and write the reconstructed audio.

    The song is read from ``input_path``, down-mixed to mono if it has several
    channels, and transformed with ``scipy.signal.stft`` using
    ``nperseg = 2*L - 2`` (``L`` one-sided bins per frame). Every run of
    ``2*C + 1`` consecutive frames is flattened into one model input row. The
    model's ``L``-bin output rows are inverse-transformed with ``istft`` and
    written to ``output_path``. Each window yields one output frame, so the
    output has ``2*C`` fewer STFT frames than the input.

    Args:
        C: Context frames on each side; windows span ``2*C + 1`` frames.
        L: Frequency bins per STFT frame.
        device: Torch device the model's parameters live on.
        model: Network mapping ``(N, (2*C+1)*L)`` complex rows to ``(N, L)`` complex rows.
            Defaults to the notebook-level ``dnn``.
        input_path: Audio file to read.
        output_path: Where to write the reconstructed audio.

    Raises:
        ValueError: If the song has fewer than ``2*C + 1`` STFT frames.
    """
    if model is None:
        model = dnn  # notebook-level complex model
    nperseg = L * 2 - 2
    width = 2 * C + 1

    song, sr = sf.read(input_path)
    # Down-mix multi-channel audio to mono; a mono file is already 1-D.
    if song.ndim > 1:
        song = np.mean(song, axis=1)
    _, _, Zxx = stft(song, fs=sr, nperseg=nperseg)
    print(Zxx.shape)  # (bins, frames)

    frames = Zxx.T  # (frames, bins)
    if frames.shape[0] < width:
        raise ValueError(f"need at least {width} STFT frames, got {frames.shape[0]}")
    # All windows of `width` consecutive frames. sliding_window_view returns (n_windows, bins, width);
    # swap to (n_windows, width, bins) so each row flattens frame-major: [frame_0 bins, frame_1 bins, ...].
    # NOTE(review): the training cell flattens dg tensors whose centre frame is taken with
    # `[:, :, C]`, which suggests a bin-major (..., L, 2*C+1) layout there — confirm both layouts match.
    windows = np.lib.stride_tricks.sliding_window_view(frames, width, axis=0).swapaxes(1, 2)
    model_input = torch.from_numpy(np.ascontiguousarray(windows.reshape(-1, L * width))).to(device)

    # Inference only: skip autograd bookkeeping for the whole song.
    with torch.no_grad():
        predicted = model(model_input).cpu().numpy()

    _, song = istft(predicted.T, fs=sr, nperseg=nperseg)
    sf.write(output_path, song, sr)

In [4]:
# Run the full-song pipeline (reads test.mp3, writes SSE.wav) with the notebook's settings.
generate_song(C=C, L=L, device=device)

(1025, 9133)
