Libraries required for developing the project

In [None]:
#@title Libraries required
%%capture
!pip install laion-clap
!pip install transformers==4.30.2
!pip install soundfile
!pip install librosa
!pip install torchlibrosa
!pip install ftfy
!pip install braceexpand
!pip install webdataset
!pip install wget
!pip install wandb
!pip install llvmlite
!pip install scipy
!pip install scikit-learn
!pip install pandas
!pip install h5py
!pip install tqdm
!pip install regex
!pip install torch
!pip install pytube
!pip install pydub

In [1]:
#@title Import of CLAP for embedding
%%capture
import laion_clap
import torch
from huggingface_hub import hf_hub_download
import librosa
import os
import numpy as np

model = laion_clap.CLAP_Module(enable_fusion=False, amodel= 'HTSAT-base').to(device='cuda')
dataset_path = hf_hub_download(repo_id="lukewys/laion_clap", filename="music_speech_audioset_epoch_15_esc_89.98.pt")
model.load_ckpt(dataset_path)
# quantization
def int16_to_float32(x):
    return (x / 32767.0).astype(np.float32)


def float32_to_int16(x):
    x = np.clip(x, a_min=-1., a_max=1.)
    return (x * 32767.).astype(np.int16)


In [None]:
#@title Instructions to load the code and the datase
pat = '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'
!git clone https://{pat}@github.com/LorenzoFrangella/Neural-Networks-Mastrandrea-Frangella/
!mv Neural-Networks-Mastrandrea-Frangella/download .
!mv Neural-Networks-Mastrandrea-Frangella/tmp .
!mv Neural-Networks-Mastrandrea-Frangella/new_balanced.csv .
!mv Neural-Networks-Mastrandrea-Frangella/validation_set.csv .
!mv Neural-Networks-Mastrandrea-Frangella/validation .
!mv Neural-Networks-Mastrandrea-Frangella/class_labels_indices.csv .
!mv Neural-Networks-Mastrandrea-Frangella/eval_segments.csv .
!rm ./download/.DS_Store
!rm ./validation/.DS_Store

In [None]:
#@title Cleanup (!!!!!)
!rm -rf ./download
!rm -rf ./tmp
!rm -rf ./validation
!rm -rf Neural-Networks-Mastrandrea-Frangella/




In [2]:
#@title Utility functions to manage the data
from  pytube import YouTube
import os
from pydub import AudioSegment
import csv
import random
import math
import torch
import torchaudio

def cut_audio(input_file, start_time, end_time):
    audio = AudioSegment.from_file(input_file)
    audio = audio.set_frame_rate(32000)
    cut_audio = audio[start_time:end_time]
    cut_audio.export(input_file, format="mp3")

def get_mixture_audio(audio1,audio2):

    E1 = torch.square(torch.norm(audio1,p=2))
    E2 = torch.square(torch.norm(audio2,p=2))

    alpha = torch.sqrt(E1/E2)

    x = audio1 + alpha * audio2
    return x

def get_audio_clip(video_id, start, end, download=True,path='./download'):

    if download:

        if f"{video_id}.mp3" not in os.listdir(path):
            video_url = f"https://www.youtube.com/watch?v={video_id}"
            selected_video = YouTube(video_url)
            audio = selected_video.streams.filter(only_audio = True).first()
            path_dest = audio.download(path, filename=f"{video_id}.mp3")
            print(f"download:{video_id} in path: {path_dest}")
            cut_audio(path_dest, start*1000, end*1000)

        path_dest = f"{path}/{video_id}.mp3"

    else:

        if f"{video_id}.mp3" not in os.listdir(path):
            return ""

        else:
            path_dest = f"{path}/{video_id}.mp3"

    return path_dest

def download_all_dataset():
    with open("./Neural-Networks-Mastrandrea-Frangella/new_balanced.csv", mode ='r')as file:
        csvFile = csv.reader(file)
        for lines in csvFile:
            video_id = lines[0]
            start = lines[1]
            end = lines[2]
            try:
                get_audio_clip(video_id,float(start),float(end))
            except:
                continue

def get_random_files(directory, count=20):
    files = os.listdir(directory)
    random_files = random.sample(files, count)
    return random_files

def get_batch(batch_size,dataset_path,labels_file,modality):
    batch_size=batch_size*2
    random_samples = get_random_files(dataset_path,batch_size)
    half = int(batch_size/2)

    first_subset = random_samples[0:half]
    second_subset = random_samples[half:batch_size]

    labels_dict = {}

    with open(labels_file, mode ='r')as file:
        csvFile = csv.reader(file)
        for lines in csvFile:
            label = lines[4][1:-1]
            label = label.replace("[","")
            label = label.replace("]","")
            label = label.replace(",","")
            label = label.replace("'","")
            labels_dict[lines[0]]=[label]

    batch = []
    labels = []
    target = []



    for i in range(half):
        audio1,sample_rate1 = torchaudio.load(f"./download/{first_subset[i]}")
        audio2,sample_rate2 = torchaudio.load(f"./download/{second_subset[i]}")

        mixed = get_mixture_audio(audio1,audio2)

        # computing starting and ending frame for audio1
        start = random.randint(0,160000)
        end = start + 160000

        mixed = mixed[:,start:end]
        audio1 = audio1[:,start:end]

        target.append(audio1)

        batch.append(mixed)
        labels.append(labels_dict[first_subset[i][:-4]][0])

    batch = torch.stack(batch,dim=0)
    target = torch.stack(target,dim=0)

    if modality=="text":
        return (batch,target,labels,False)

    else:
        if random.uniform(0,1) > 0.5:
            labels = ["./download/"+elem for elem in first_subset]
            return (batch,target,labels,True)

    return (batch,target,labels,False)


def get_audio_metadata(video_id,path='./new_balanced.csv'):
  with open(path, mode='r') as file:
    # Create a CSV reader object
    csv_reader = csv.reader(file)
    for row in csv_reader:
      if row[0] == video_id:
        return row

def get_dict_of_classes():
  with open('class_labels_indices.csv', mode='r') as file:
    # Create a CSV reader object
        csv_reader = csv.reader(file)
        # Read and print each row of the CSV file
        classes_dict = {}
        labels = []
        for row in csv_reader:

            label = row[1]
            labels.append(label)

        classes = labels.copy()
        for elem in labels:
            #remove elem from classes
            classes.remove(elem)
            random_classes = random.sample(classes,5)
            #add elem from labels
            classes.append(elem)
            #shuffle of labels
            random.shuffle(classes)
            classes_dict[elem]=random_classes
  file_path = 'eval_segments.csv'
  with open(file_path, 'r') as file:
    reader = csv.reader(file)
    next(reader)
    next(reader)
    next(reader)
    video_dict = {}
    for row in reader:
        new_column = ""
        for i in range(3,len(row)):
            new_column += row[i]
            if(i != len(row)-1):
                new_column += ","
        row[3] = new_column.replace('"', '')
        row[3] = row[3].replace(' ', '')
        tokens = row[3].split(",")
        for elem in tokens:
            if elem not in video_dict:
                video_dict[elem] = []
                video_dict[elem].append(row[0])
            else:
                video_dict[elem].append(row[0])
  return video_dict,classes_dict

In [3]:
#@title Clean the dataset
%%capture
import os
audios = os.listdir("./download")
import shutil


for audio_path in audios:
  audio_clip,w = torchaudio.load(f"./download/{audio_path}")
  if torch.norm(audio_clip,p=2) == 0:
    print(f"./download/{audio_path}")
    os.remove(f"./download/{audio_path}")
  if audio_clip.shape[1] != 320000:
    os.remove(f"./download/{audio_path}")
    print(f"./download/{audio_path}")


In [22]:
#@title Classes to Manage Audios
class Audioclip():
  def __init__(self,video_id,path="./download",file="./new_balanced.csv"):
    self.video_id = video_id
    info = get_audio_metadata(video_id,file)
    self.start, self.end, self.classes, self.labels = info[1],info[2],info[3],info[4]
    self.start = float(self.start)
    self.end = float(self.end)
    self.classes = self.classes.split(",")
    self.labels = self.labels.replace("[","").replace("]","").replace(",","").replace('"',"")

    if (video_id) not in os.listdir("./download"):
      path_dest = get_audio_clip(video_id,self.start,self.end,True,path)
    self.audioclip,self.sample_rate = torchaudio.load(path_dest)

class Mixedaudio():
  def __init__(self,audio1, audio2 ,path1 = "./download" ,path2="./download",file="./new_balanced.csv"):
    self.audio1 = Audioclip(audio1,path1,file)
    self.audio2 = Audioclip(audio2,path2,file)

    self.mixed_track = get_mixture_audio(self.audio1.audioclip,self.audio2.audioclip)


  def get_random_sample(self):
    start = random.randint(0,160000)
    end = start + 160000
    return (self.mixed_track[:,start:end],self.audio1.audioclip[:,start:end])

In [23]:
#@title Function to get the validation set
def get_validation_set(classes):
  video_dict,classes_dict = get_dict_of_classes()
  video_keys = list(video_dict.keys())
  print(video_keys)
  print(video_dict)
  j = 0
  m = []
  target = []
  labels = []
  for key in video_keys:
    if j >= classes:
      break
    print(key)
    random_classes = classes_dict[key]
    for c in random_classes:

      while True:
        s1 = random.sample(video_dict[key],1)[0]
        s2 = random.sample(video_dict[c],1)[0]
        try:
          mix = Mixedaudio(s1,s2,"./validation","./validation","./eval_segments.csv")
          mixed,target_cutted = mix.get_random_sample()
          target.append(target_cutted)
          m.append(mixed)
          labels.append(mix.audio1.labels)
          break
        except:
          continue



    j+=1
  return m,target,labels
#random.seed(12345)
#mixed, target, labels = get_validation_set(10)


In [None]:
!zip -r ./validation.zip /content/validation

In [6]:
#@title Utility Functions for the Neural Network
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchlibrosa.stft import STFT, ISTFT, magphase
import numpy as np

stft = STFT(n_fft=1024,
            hop_length=320,
            win_length=1024,
            window='hann',
            center=True,
            pad_mode='reflect',
            freeze_parameters=True).to(device='cuda')

istft = ISTFT(
            n_fft=1024,
            hop_length=320,
            win_length=1024,
            window='hann',
            center=True,
            pad_mode='reflect',
            freeze_parameters=True
        ).to(device='cuda')

def init_bn(bn):
    bn.bias.data.fill_(0.0)
    bn.weight.data.fill_(1.0)

def init_layer(layer):
    nn.init.xavier_uniform_(layer.weight)

    if hasattr(layer, "bias"):
        if layer.bias is not None:
            layer.bias.data.fill_(0.0)

def from_audio_to_spectogram(audios):
    magnitudes = []
    cosines = []
    sines = []
    for i in range(audios.shape[1]):

        (real,imag) = stft(audios[:,i,:])
        mag = torch.clamp(real ** 2 + imag ** 2, 1e-8, np.inf) ** 0.5
        cos = real / mag
        sin = imag / mag
        magnitudes.append(mag)
        cosines.append(cos)
        sines.append(sin)
    mags = torch.cat(magnitudes, dim=1)
    coss = torch.cat(cosines, dim=1)
    sins = torch.cat(sines, dim=1)

    return mags,coss,sins

In [7]:
#@title Film module
class FilmModule(nn.Module):
    def __init__(self,input_size,output_size):
        super(FilmModule, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.l1 = nn.Linear(input_size, output_size * 2)
        self.l2 = nn.Linear(output_size * 2, output_size)
        self.init_weights()

    def init_weights(self):
        init_layer(self.l1)
        init_layer(self.l2)

    def forward(self,data,embedding_vector):
        x = F.relu(self.l1(embedding_vector))
        x = F.relu(self.l2(x))
        dim = x.shape
        x = data + x[...,None,None]

        return x



In [8]:
#@title Encoder Block
class EncoderBlock(nn.Module):
    def __init__(self,input_channels, output_channels, embedding_size, momentum,downsample):
        super(EncoderBlock, self).__init__()
        self.downsample = downsample
        self.Film1 = FilmModule(embedding_size,input_channels)
        self.Film2 = FilmModule(embedding_size,output_channels)


        self.bn1 = nn.BatchNorm2d(input_channels,momentum=momentum)

        self.conv1 = nn.Conv2d(
            in_channels=input_channels,
            out_channels=output_channels,
            kernel_size=(3,3),
            stride=(1,1),
            dilation=(1,1),
            padding=(1,1),
            bias=False
            )

        self.bn2 = nn.BatchNorm2d(output_channels,momentum=momentum)

        self.conv2 = nn.Conv2d(
            in_channels=output_channels,
            out_channels=output_channels,
            kernel_size=(3,3),
            stride=(1,1),
            dilation=(1,1),
            padding=(1,1),
            bias=False
        )

        if input_channels != output_channels:
            self.residual_convolution = nn.Conv2d(
                in_channels=input_channels,
                out_channels=output_channels,
                kernel_size=(1,1),
                stride=(1,1),
                padding=(0,0),
            )
            self.has_residual_connection = True
        else:
            self.has_residual_connection = False

        self.init_weights()


    def init_weights(self):
        init_bn(self.bn1)
        init_bn(self.bn2)
        init_layer(self.conv1)
        init_layer(self.conv2)

        if self.has_residual_connection:
            init_layer(self.residual_convolution)



    def forward(self,input_tensor,embedding_vector):

        x = self.bn1(input_tensor)
        x = self.Film1(x,embedding_vector)
        x = F.leaky_relu(x,negative_slope=0.01)
        x = self.conv1(x)
        x = self.bn2(x)
        x = self.Film2(x,embedding_vector)
        x = F.leaky_relu(x,negative_slope=0.01)
        x = self.conv2(x)

        if self.has_residual_connection:
            y = self.residual_convolution(input_tensor)
            x = x + y

        x_pool = F.avg_pool2d(x,self.downsample)

        return x, x_pool

In [9]:
#@title Decoder Block
class DecoderBlock(nn.Module):

    def __init__(self,input_size, output_size,embedding_size,momentum,upsample):
        super(DecoderBlock, self).__init__()
        self.upsample = upsample

        self.conv1 = torch.nn.ConvTranspose2d(
            in_channels=input_size,
            out_channels=output_size,
            kernel_size=self.upsample,
            stride=self.upsample,
            padding=(0,0),
            bias=False,
            dilation=(1,1)

        )

        self.bn1 = nn.BatchNorm2d(input_size,momentum=momentum)

        #self.conv_block2 = ConvBlockRes(
        #    out_channels * 2, out_channels, kernel_size, momentum, has_film,

        self.Film1 = FilmModule(embedding_size,input_size)
        self.Film2 = FilmModule(embedding_size,output_size*2)
        self.Film3 = FilmModule(embedding_size,output_size)

        self.bn2 = nn.BatchNorm2d(output_size*2,momentum=momentum)
        self.bn3 = nn.BatchNorm2d(output_size,momentum=momentum)

        self.conv2 = nn.Conv2d(
            in_channels=output_size*2,
            out_channels=output_size,
            kernel_size=(3,3),
            stride=(1,1),
            dilation=(1,1),
            padding=(1,1),
            bias=False
        )

        self.conv3 = nn.Conv2d(
            in_channels=output_size,
            out_channels=output_size,
            kernel_size=(3,3),
            stride=(1,1),
            dilation=(1,1),
            padding=(1,1),
            bias=False
        )

        self.residual_convolution = nn.Conv2d(
                in_channels=output_size*2,
                out_channels=output_size,
                kernel_size=(1,1),
                stride=(1,1),
                padding=(0,0),
            )
        self.has_residual_connection = True

        self.bn4 = nn.BatchNorm2d(input_size,momentum=momentum)




        self.init_weights()

    def init_weights(self):
        init_bn(self.bn1)
        init_bn(self.bn2)
        init_bn(self.bn3)

        init_layer(self.conv1)
        init_layer(self.conv2)
        init_layer(self.conv3)

        if self.has_residual_connection:
            init_layer(self.residual_convolution)

    def forward(self,input_tensor,concat_tensor,embedding_vector):
        x = self.bn1(input_tensor)
        x = self.Film1(x,embedding_vector)
        x = F.leaky_relu(x)

        x = self.conv1(x)

        x = torch.cat((x,concat_tensor), dim=1)
        x_res = x
        x = self.bn2(x)
        x = self.Film2(x,embedding_vector)
        x = F.leaky_relu(x,negative_slope=0.01)
        x = self.conv2(x)
        x = self.bn3(x)
        x = self.Film3(x,embedding_vector)
        x = F.leaky_relu(x,negative_slope=0.01)
        x = self.conv3(x)
        if self.has_residual_connection:
            y = self.residual_convolution(x_res)



            x = x + y

        return x

In [10]:
#@title ResUNet
class ResUnet(nn.Module):

    def __init__(self, input_size, output_size):
        super(ResUnet, self).__init__()

        self.input_size = input_size;
        self.output_size = output_size;

        self.momentum = 0.01


        # instanziare la preconv che è una conv2d

        # definire la classe degli encoder block
        # definire la classe dei decoder block

        self.batch_norm0 = nn.BatchNorm2d(513,momentum=self.momentum)
        self.preconvolution = nn.Conv2d(
            in_channels=input_size,
            out_channels=32,
            kernel_size=(1,1),
            stride=(1,1),
            padding=(0,0),
            bias=True
        )

        self.EncoderBlock1 = EncoderBlock(
            input_channels=32,
            output_channels=32,
            downsample=(2,2),
            embedding_size=512,
            momentum=0.01
        )

        self.EncoderBlock2 = EncoderBlock(
            input_channels=32,
            output_channels=64,
            downsample=(2,2),
            embedding_size=512,
            momentum=0.01
        )

        self.EncoderBlock3 = EncoderBlock(
            input_channels=64,
            output_channels=128,
            downsample=(2,2),
            embedding_size=512,
            momentum=0.01
        )

        self.EncoderBlock4 = EncoderBlock(
            input_channels=128,
            output_channels=256,
            downsample=(2,2),
            embedding_size=512,
            momentum=0.01
        )

        self.EncoderBlock5 = EncoderBlock(
            input_channels=256,
            output_channels=384,
            downsample=(2,2),
            embedding_size=512,
            momentum=0.01
        )

        self.EncoderBlock6 = EncoderBlock(
            input_channels=384,
            output_channels=384,
            downsample=(1,2),
            embedding_size=512,
            momentum=0.01
        )

        self.EncoderBlock7 = EncoderBlock(
            input_channels=384,
            output_channels=384,
            downsample=(1,1),
            momentum=0.01,
            embedding_size=512
        )

        self.DecoderBlock1 = DecoderBlock(
            input_size=384,
            output_size= 384,
            embedding_size= 512,
            momentum=0.01,
            upsample=(1,2)
            )

        self.DecoderBlock2 = DecoderBlock(
            input_size=384,
            output_size= 384,
            embedding_size= 512,
            momentum=0.01,
            upsample=(2,2)
            )

        self.DecoderBlock3 = DecoderBlock(
            input_size=384,
            output_size= 256,
            embedding_size= 512,
            momentum=0.01,
            upsample=(2,2)
            )

        self.DecoderBlock4 = DecoderBlock(
            input_size=256,
            output_size= 128,
            embedding_size= 512,
            momentum=0.01,
            upsample=(2,2)
            )

        self.DecoderBlock5 = DecoderBlock(
            input_size=128,
            output_size= 64,
            embedding_size= 512,
            momentum=0.01,
            upsample=(2,2)
            )

        self.DecoderBlock6 = DecoderBlock(
            input_size=64,
            output_size= 32,
            embedding_size= 512,
            momentum=0.01,
            upsample=(2,2)
            )


        self.after_conv = nn.Conv2d(
            in_channels=32,
            out_channels=2*3,
            kernel_size=(1, 1),
            stride=(1, 1),
            padding=(0, 0),
            bias=True,
        )

        self.init_weights()
    def init_weights(self):
        init_bn(self.batch_norm0)
        init_layer(self.preconvolution)
        init_layer(self.after_conv)

    def forward(self,input,embedding_vector):
        #x -->mag
        #coss--->coss
        #sins--->sins
        audio_length = input.shape[-1]
        audios = input.clone()
        mags,coss,sins = from_audio_to_spectogram(audios)
        x = mags
        x = x.transpose(1,3)
        x = self.batch_norm0(x)
        x = x.transpose(1,3)
        origin_len = x.shape[2]
        pad_len = (int(np.ceil(x.shape[2] / 2**5)) * (2**5)- origin_len)
        x = F.pad(x, pad=(0, 0, 0, pad_len))
        x = x[:,:,:,0:-1]
        x = self.preconvolution(x)

        x1, x1_pool = self.EncoderBlock1(x,embedding_vector)
        x2, x2_pool = self.EncoderBlock2(x1_pool,embedding_vector)
        x3, x3_pool = self.EncoderBlock3(x2_pool,embedding_vector)
        x4, x4_pool = self.EncoderBlock4(x3_pool,embedding_vector)
        x5, x5_pool = self.EncoderBlock5(x4_pool,embedding_vector)
        x6, x6_pool = self.EncoderBlock6(x5_pool,embedding_vector)
        x_c,x_c_pool = self.EncoderBlock7(x6_pool,embedding_vector)
        x7 = self.DecoderBlock1(x_c,x6,embedding_vector)
        x8 = self.DecoderBlock2(x7,x5,embedding_vector)
        x9 = self.DecoderBlock3(x8,x4,embedding_vector)
        x10 = self.DecoderBlock4(x9,x3,embedding_vector)
        x11 = self.DecoderBlock5(x10,x2,embedding_vector)
        x12 = self.DecoderBlock6(x11,x1,embedding_vector)
        x = self.after_conv(x12)
        x = F.pad(x, pad=(0, 1))
        x = x[:, :, 0:origin_len, :]
        batch_size,_,time_steps,frequency_bins = x.shape
        x = x.reshape(
            batch_size,   #batch size
            1,   #target audio
            2,   #num channels
            3,   #magnitude, cos, sin
            time_steps, #time_steps
            frequency_bins, #frequency bins
        )
        mask_mag = torch.sigmoid(x[:, :, :, 0, :, :]) # get the magnitude mask
        _mask_real = torch.tanh(x[:, :, :, 1, :, :])  # get the real component mask
        _mask_imag = torch.tanh(x[:, :, :, 2, :, :])  # get the imaginary component mask
        #mask_mag, mask_cos, mask_sin = magphase(_mask_real, _mask_imag)# get the fourier transform of the mask
        # apply the filtering functions to the cosine
        # mask_mag = (_mask_real ** 2 + _mask_imag ** 2) ** 0.5
        mask_cos = _mask_real / torch.clamp(_mask_imag, 1e-10, np.inf)
        mask_sin = _mask_imag / torch.clamp(_mask_imag, 1e-10, np.inf)
        out_cos = (
            coss[:, None, :, :, :] * mask_cos - sins[:, None, :, :, :] * mask_sin
        )

        # apply the filtering function to the sine

        out_sin = (
            sins[:, None, :, :, :] * mask_cos + coss[:, None, :, :, :] * mask_sin
        )
        #apply the filtering function to the magnitude

        out_mag = F.relu(mags[:, None, :, :, :] * mask_mag)
        out_real = torch.mul(out_mag , out_cos)
        out_imag = torch.mul(out_mag , out_sin)

        out_real = out_real.reshape(2*batch_size,1,time_steps,frequency_bins)
        out_imag = out_imag.reshape(2*batch_size,1,time_steps,frequency_bins)
        x = istft(out_real, out_imag, audio_length)
        waveform = x.reshape(batch_size,2,audio_length)
        return waveform

In [11]:
#@title Loss Implementation
class LossAudio(nn.Module):
  def __init__(self):
    super(LossAudio, self).__init__()
  def forward(self,s,s_hat):
    loss = torch.abs(s-s_hat)
    return loss.mean()

In [12]:
#@title Network definition with loss and optimizer
class AudioSep(nn.Module):
  def __init__(self,input_size, output_size):
      super(AudioSep, self).__init__()
      self.model = ResUnet(input_size,output_size)
      self.loss = LossAudio()
      self.optimizer = torch.optim.AdamW(params=self.model.parameters(),
                lr=1e-3,
                betas=(0.9, 0.999),
                eps=1e-08,
                weight_decay=0.0,
                amsgrad=True)

  def forward(self,x):
    pass

In [16]:
rete = AudioSep(2,2).to(device='cuda')

In [None]:
torch.save(rete.model.state_dict(), "./checkpoint.pt")

In [17]:
training_steps = 10
batch_size = 8
torch.autograd.set_detect_anomaly(True)
for i in range(training_steps):
  batch,target,labels,isAudio = get_batch(batch_size,"./download","./new_balanced.csv","text")

  audio_embeddings = model.get_text_embedding(labels,use_tensor=True).to(device="cuda")
  batch = batch.to(device="cuda")
  target = target.to(device="cuda")
  output = rete.model(batch,audio_embeddings)
  loss = rete.loss(output,target)
  print("iteration",i,loss)
  # Backward and optimize
  rete.optimizer.zero_grad()
  loss.backward()
  rete.optimizer.step()



iteration 0 tensor(2.9621e+08, device='cuda:0', grad_fn=<MeanBackward0>)
iteration 1 tensor(30952894., device='cuda:0', grad_fn=<MeanBackward0>)
iteration 2 tensor(6509434.5000, device='cuda:0', grad_fn=<MeanBackward0>)
iteration 3 tensor(5757191., device='cuda:0', grad_fn=<MeanBackward0>)
iteration 4 tensor(166935.1562, device='cuda:0', grad_fn=<MeanBackward0>)
iteration 5 tensor(27511.4746, device='cuda:0', grad_fn=<MeanBackward0>)
iteration 6 tensor(1916.7948, device='cuda:0', grad_fn=<MeanBackward0>)
iteration 7 tensor(1815.7736, device='cuda:0', grad_fn=<MeanBackward0>)
iteration 8 tensor(0.0958, device='cuda:0', grad_fn=<MeanBackward0>)
iteration 9 tensor(0.0962, device='cuda:0', grad_fn=<MeanBackward0>)


In [25]:
random.seed(12345)
mixed,target,labels = get_validation_set(3)

['/m/068hy', '/m/07q6cd_', '/m/0bt9lr', '/m/0jbk', '/m/03l9g', '/m/01b_21', '/m/04rlf', '/m/09x0r', '/t/dd00004', '/t/dd00005', '/m/07rgt08', '/m/07sq110', '/t/dd00001', '/m/07pbtc8', '/m/0140xf', '/m/02cjck', '/m/03v3yw', '/m/0k4j', '/m/03k3r', '/m/07q5rw0', '/m/04brg2', '/g/122z_qxw', '/m/025_jnm', '/m/01g90h', '/m/07pzfmf', '/m/015p6', '/m/0chx_', '/t/dd00128', '/m/03dnzn', '/m/07p7b8y', '/m/07ptzwd', '/m/0838f', '/t/dd00088', '/m/0130jx', '/m/02jz0l', '/m/0k65p', '/t/dd00125', '/m/06bz3', '/m/02fs_r', '/m/02zsn', '/m/07pjjrj', '/m/0xzly', '/m/07qn4z3', '/m/01jg1z', '/m/01jt3m', '/m/01h82_', '/m/02mk9', '/m/07pb8fc', '/m/07ppn3j', '/m/014zdl', '/m/01j4z9', '/m/01glhc', '/m/0342h', '/m/042v_gx', '/m/07s0s5r', '/m/0fx80y', '/m/02rhddq', '/m/07r04', '/m/07yv9', '/t/dd00018', '/m/06mb1', '/m/0jb2l', '/m/0ngt1', '/t/dd00038', '/t/dd00013', '/m/07cx4', '/m/04szw', '/m/0l156b', '/t/dd00130', '/m/04wptg', '/m/06hck5', '/m/06wzb', '/m/019jd', '/m/03m9d0z', '/m/06q74', '/t/dd00092', '/m/01bjv

In [29]:
for i in range(len(mixed)):
  audio_input = torch.unsqueeze(mixed[i],dim=0).to(device="cuda")
  print(audio_input.shape)
  label_input = labels[i]
  output_target = torch.unsqueeze(target[i],dim=0).to(device="cuda")

  audio_embedding = model.get_text_embedding([label_input,label_input],use_tensor=True).to(device="cuda")[0]
  output = rete.model(audio_input,audio_embedding)
  validation_loss = rete.loss(output,output_target)
  print(validation_loss)


  torchaudio.save("./mixed.mp3",audio_input.to(device="cpu")[0],32000)
  torchaudio.save("./res.mp3",output.to(device="cpu")[0],32000)
  torchaudio.save("./target.mp3",output_target.to(device="cpu")[0],32000)
  break


torch.Size([1, 2, 160000])
tensor(0.0016, device='cuda:0', grad_fn=<MeanBackward0>)


In [None]:
res = output.to(device="cpu")
target = target.to(device="cpu")
batch = batch.to(device="cpu")
print(labels[0])
torchaudio.save("./mixed.mp3",batch[0],32000)
torchaudio.save("./res.mp3",res[0],32000)
torchaudio.save("./target.mp3",target[0],32000)


In [None]:
from pydub import AudioSegment
import os
list_files = os.listdir("./download")
for elem in list_files:
  sound = AudioSegment.from_file(f"./download/{elem}")
  sound = sound.set_channels(1)
  sound.export(f"./download_mono/{elem}", format="mp3")