In [None]:
import librosa 
import torch
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import soundfile as sf
import os


In [None]:
def _min_max_scale(values, eps=1e-8):
    """Linearly rescale ``values`` so its minimum maps to 0 and its maximum to 1.

    Args:
        values: tensor to rescale.
        eps: lower bound applied to ``max - min`` before dividing.

    Returns:
        ``(scaled, minimum, maximum)`` where ``minimum``/``maximum`` are the
        0-dim tensors needed to undo the scaling later.
    """
    lowest = torch.min(values)
    highest = torch.max(values)
    # A constant input (e.g. digital silence) has max == min; clamping the
    # denominator yields all zeros instead of NaN from a 0/0 division.
    span = torch.clamp(highest - lowest, min=eps)
    return (values - lowest) / span, lowest, highest


def _mel_db_tensor(path):
    """Load an audio file and return its mel spectrogram in dB as a float32 tensor.

    The spectrogram is computed with ``n_fft=2048`` and ``hop_length=512`` and
    gets a leading singleton dimension added in front of the 2-D array that
    librosa returns.
    """
    signal, rate = librosa.load(path)
    power = librosa.feature.melspectrogram(y=signal, sr=rate, n_fft=2048, hop_length=512)
    db = librosa.power_to_db(power)
    return torch.as_tensor(db, dtype=torch.float32).unsqueeze(0)


def load_data(xpath,ypath):
    """Build one normalised (input, target) spectrogram pair from two audio files.

    Args:
        xpath: path of the input audio file (the music track).
        ypath: path of the target audio file (the drum track).

    Returns:
        ``(qt, target, target_min, target_max)``: both dB mel spectrograms
        min-max scaled to [0, 1] independently, plus the minimum and maximum of
        the *target* spectrogram before scaling, which are needed to map a
        prediction back to dB.
    """
    qt, _, _ = _min_max_scale(_mel_db_tensor(xpath))
    target, target_min, target_max = _min_max_scale(_mel_db_tensor(ypath))
    return qt, target, target_min, target_max

# Smoke test: run load_data on one music/drum pair from the local "test"
# folder. Only the two scaled spectrograms are kept; the target's min/max
# values are discarded here.
a,target,_,_= load_data("test/012146 [music].wav","test/012146 [drums].wav")    

class CustomDataset(Dataset):
    """Dataset of (music, drum) spectrogram pairs read from two directories.

    The i-th file of ``dir1`` is paired with the i-th file of ``dir2``.  Both
    listings are sorted so the pairing is deterministic; this assumes matching
    tracks sort to the same position in both folders (e.g. share a filename
    prefix) -- TODO confirm against the actual data layout.

    Args:
        dir1: directory containing the input (music) audio files.
        dir2: directory containing the target (drum) audio files.
        transform: optional callable applied to the input spectrogram.
        target_transform: optional callable applied to the target spectrogram.

    Raises:
        ValueError: if ``dir2`` holds fewer files than ``dir1``, since some
            inputs would then have no target.
    """

    def __init__(self, dir1,dir2, transform=None, target_transform=None):
        self.music = dir1
        self.drum = dir2
        # os.listdir returns entries in arbitrary order, so sort both listings;
        # otherwise index i may refer to different tracks in the two folders.
        self.musicdir = sorted(os.listdir(dir1))
        self.drumdir = sorted(os.listdir(dir2))
        if len(self.drumdir) < len(self.musicdir):
            raise ValueError(
                f"{dir2!r} has {len(self.drumdir)} files but {dir1!r} has "
                f"{len(self.musicdir)}; every input needs a matching target"
            )
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Number of pairs, defined by the number of files in the music directory."""
        return len(self.musicdir)

    def __getitem__(self, idx):
        """Return ``(music, drums, target_min, target_max)`` for pair ``idx``.

        The last two values are whatever ``load_data`` returns in positions 2
        and 3 and are passed through untouched.
        """
        music_path = os.path.join(self.music, self.musicdir[idx])
        drum_path = os.path.join(self.drum, self.drumdir[idx])
        mus, dr, m1, m2 = load_data(music_path, drum_path)

        # The transforms were accepted by __init__ but previously never used.
        if self.transform is not None:
            mus = self.transform(mus)
        if self.target_transform is not None:
            dr = self.target_transform(dr)

        return mus, dr, m1, m2
    
# Raw strings keep the Windows backslashes literal. The previous plain strings
# relied on "\S", "\B" and "\D" being unrecognised escape sequences, which
# Python flags with a warning and plans to reject in a future version.
training_data = CustomDataset(r"D:\Seperated\BassTrainData", r"D:\Seperated\DrumTrainData")



In [None]:
class Network(nn.Module):
    """CNN + LSTM encoder with a variational bottleneck and a transposed-conv decoder.

    Pipeline of ``forward``:
      1. ``cnn``: three Conv2d/MaxPool2d stages, 1 input channel -> 1 output channel.
      2. ``rnn``/``rnn2``: two LSTMs run over the rows of the CNN feature map;
         each row must have exactly 53 values (the LSTM input size).
      3. ``encoder`` + ``mu``/``dev`` heads: produce two vectors of size ``latspa``.
      4. ``reparametrize``: draws the latent sample.
      5. ``decoder`` + ``convtrans``: expand the latent to 848 values, reshape
         to (batch, 1, 16, 53) and upsample with three transposed convolutions
         followed by a sigmoid.

    Args:
        latspa: dimensionality of the latent space.

    NOTE(review): ``reparametrize`` treats the ``dev`` head as a log-variance,
    but that head ends in ReLU, so it can never go below 0 (std >= 1), and
    ``mu`` can never be negative either. Left unchanged here because removing
    the activations changes what the model learns -- confirm intent.
    """

    def __init__(self,latspa):
        super(Network,self).__init__()
        self.latspa = latspa

        # Feature extractor: each stage is a 3x3 convolution (stride 1, no
        # padding) followed by 2x2 max pooling.
        self.cnn = nn.Sequential(
            nn.Conv2d(1,4,3,1),
            nn.MaxPool2d(2),
            nn.Conv2d(4,16,3,1),
            nn.MaxPool2d(2),
            nn.Conv2d(16,1,3,1),
            nn.MaxPool2d(2),
        )

        # Sequence models; batch_first is left at its default (False), so they
        # expect input laid out as (sequence, batch, feature).
        self.rnn = nn.LSTM(53,128)
        self.rnn2 = nn.LSTM(128,128)

        self.encoder = nn.Sequential(
            # Author's note (original wording cleaned up): LeakyReLU(0.1) was
            # tried but led to negative values in the KL divergence / mu and
            # sigma, which was considered unacceptable.
            nn.Linear(128,512),
            nn.LeakyReLU(),
            nn.Linear(512,64),
            nn.LeakyReLU(),
            nn.Linear(64,32),
            nn.LeakyReLU(),
            nn.Linear(32,latspa)
        )
        # Heads producing the two parameters of the latent distribution.
        self.mu = nn.Sequential(
            nn.Linear(latspa,latspa),
            nn.ReLU()
        )
        self.dev = nn.Sequential(
            nn.Linear(latspa,latspa),
            nn.ReLU()
        )

        # Each transposed convolution doubles the height and maps width W to
        # 2*W + 1 (kernel 4, stride 2, padding 1, output_padding (0, 1)).
        self.convtrans = nn.Sequential(
            nn.ConvTranspose2d(1,4,4,2,1,output_padding=(0,1)),
            nn.ConvTranspose2d(4,4,4,2,1,output_padding=(0,1)),
            nn.ConvTranspose2d(4,1,4,2,1,output_padding=(0,1)),
            nn.Sigmoid()
        )

        # Purely linear expansion of the latent vector; 848 = 16 * 53, the
        # feature-map size that forward() reshapes to.
        self.decoder = nn.Sequential(
            nn.Linear(latspa,32),
            nn.Linear(32,64),
            nn.Linear(64,128),
            nn.Linear(128,848)
        )

    def encode(self,nm):
        """Encode a batch of single-channel spectrograms into ``(mu, dev)``.

        Args:
            nm: tensor of shape (batch, 1, height, width) whose CNN output has
                rows of length 53.

        Returns:
            Two tensors of shape (batch, latspa).
        """
        out = self.cnn(nm)
        # Drop the single channel: (batch, 1, H, W) -> (batch, H, W) ...
        out = out.transpose(1,0)[0]
        # ... and put the row axis first, as the LSTMs expect: (H, batch, W).
        out = out.transpose(1,0)
        out,_ = self.rnn(out)
        out,_ = self.rnn2(out)
        # NOTE(review): this keeps the LSTM output of the FIRST step only,
        # which has seen just the first row of the feature map. The last step
        # (out[-1]) may have been intended -- confirm before changing.
        out = out[0]
        a = self.encoder(out)
        mu = self.mu(a)
        dev = self.dev(a)

        return mu,dev

    def reparametrize(self, mu, log_var):
        """Sample ``mu + eps * exp(0.5 * log_var)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self,mfcc):
        """Reconstruct a batch and return ``(output, mu, dev)``.

        ``output`` has the channel dimension removed, i.e. it is 3-D with the
        batch as its first dimension.
        """
        mu,dev = self.encode(mfcc)
        z = self.reparametrize(mu,dev)
        temp = self.decoder(z)
        # -1 lets the batch dimension follow the input; the previous hard-coded
        # 16 made the model crash for any other batch size.
        temp = temp.view(-1,1,16,53)
        temp = self.convtrans(temp)
        # Drop the single channel: (batch, 1, H, W) -> (batch, H, W).
        temp = temp.transpose(1,0)[0]
        return temp,mu,dev

    def sample(self):
        """Decode one latent vector drawn from N(0, I), without gradients.

        Returns the raw output of ``decoder`` with shape (1, 848).
        NOTE(review): unlike ``forward`` this does not apply ``convtrans``;
        kept as is to preserve the return shape -- confirm intent.
        """
        with torch.no_grad():
            # Create the noise on the same device as the weights so sampling
            # also works after the model has been moved to a GPU.
            weights_device = next(self.parameters()).device
            z = torch.randn(1,self.latspa,device=weights_device)
            return self.decoder(z)


In [None]:
class lf(nn.Module):
    """Beta-weighted VAE loss: MSE reconstruction plus Gaussian KL divergence.

    Args:
        beta: weight of the KL term in the total loss.
    """

    def __init__(self,beta):
        super(lf, self).__init__()
        self.beta = beta

    def forward(self,a,mu,det,target):
        """Compute the loss for one batch.

        Args:
            a: reconstruction, same shape as ``target`` without its channel axis.
            mu: latent means.
            det: latent log-variances, i.e. the same quantity that
                ``Network.reparametrize`` turns into a std via ``exp(0.5 * x)``.
            target: ground truth with a singleton channel axis at dim 1.

        Returns:
            ``(total, kl)`` where ``total = mse + beta * kl``.
        """
        # Drop the singleton channel axis so target lines up with ``a``.
        target = target.transpose(1,0)[0]
        L1 = nn.functional.mse_loss(a,target)
        # Closed-form KL( N(mu, sigma^2) || N(0, 1) ) with log_var = log(sigma^2):
        #   0.5 * (sigma^2 + mu^2 - log(sigma^2) - 1)
        # The previous expression squared ``det`` but took log(det), mixing the
        # std and variance conventions; it could go negative and was not zero
        # for a standard-normal posterior.
        L2 = torch.mean(0.5 * (torch.exp(det) + torch.pow(mu,2) - det - 1))

        return (L1 + self.beta*L2, L2)

   
    
    


# Training configuration and the objects shared by the training cells below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lr = 3e-4
latspa=8
# The model must live on the same device the batches are moved to in
# train_loop; without .to(device) training fails as soon as CUDA is available.
# It is moved before the optimizer is created so the optimizer tracks the
# parameters on their final device.
model=Network(latspa).to(device)
optimizer=torch.optim.Adam(params=model.parameters(),lr=lr)
epoch=60
beta=0.1  # weight of the KL term in the loss
loss_fn =lf(beta)
batch_size =16
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
        

            

            
def train_loop(x,y,ep,i):
    """Perform a single optimisation step on one batch and print its losses.

    Uses the module-level ``model``, ``optimizer``, ``loss_fn`` and ``device``.

    Args:
        x: input batch.
        y: target batch.
        ep: zero-based epoch index (printed one-based).
        i: zero-based batch index (printed one-based).
    """
    inputs = x.to(device)
    targets = y.to(device)

    # Clear gradients left over from the previous step, then run the model.
    optimizer.zero_grad()
    recon, mu, det = model(inputs)

    # loss_fn returns a pair: index 0 is the value to optimise, index 1 the KL term.
    losses = loss_fn(recon, mu, det, targets)
    losses[0].backward()
    optimizer.step()

    print(f"Epoch: {ep+1},batch: {i+1}, loss: {losses[0].item()}, KL: {losses[1].item()}")
                 
# Main training run: `epoch` passes over the dataloader. Only full batches are
# trained on; a batch with a different number of items than batch_size
# (typically the last one of a pass) is skipped.
for ep in range(epoch):
    for i, data in enumerate(train_dataloader):
        music_batch, drum_batch = data[0], data[1]
        if music_batch.shape[0] != batch_size:
            continue
        train_loop(music_batch, drum_batch, ep, i)
         
        
    

In [None]:
# Reconstruct one example from a training batch and write it out as audio.
a = next(iter(train_dataloader))

# Feed the batch on whatever device the model's weights live on, and skip
# autograd bookkeeping since nothing is trained here.
model_device = next(model.parameters()).device
with torch.no_grad():
    b = model(a[0].to(model_device))
# First element of the model output, first item of the batch, back on the CPU.
audio = b[0][0].cpu()

# Undo the [0, 1] scaling using the min/max values the dataset returned for
# this item (positions 2 and 3 of the batch).
mfccmin = a[2][0]
mfccmax = a[3][0]
audio = audio * (mfccmax - mfccmin) + mfccmin

# dB -> power -> waveform, then save.
aud = audio.detach().cpu().numpy()
aud = librosa.db_to_power(aud)
aud = librosa.feature.inverse.mel_to_audio(M=aud)
sf.write('stereo_file_MoreData.wav', data= aud, samplerate=22050,subtype='PCM_24')

In [None]:

# Reference round trip without the model: audio -> mel spectrogram -> dB ->
# power -> audio, written to disk for comparison by ear.
song, sr = librosa.load("test/012146 [drums].wav")
mel = librosa.power_to_db(
    librosa.feature.melspectrogram(y=song, sr=sr, n_fft=2048, hop_length=512)
)
mel = librosa.feature.inverse.mel_to_audio(M=librosa.db_to_power(mel))
sf.write("test.wav", data=mel, samplerate=22050, subtype="PCM_24")