In [1]:
import auraloss
import collections
import librosa
import matplotlib.pyplot as plt
import numpy as np
import os
import pickle
import plotly.graph_objects as go
import pretty_midi
import pytorch_lightning as pl
import pywt
import random
import scipy.signal
import sklearn
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import torchaudio
from torch.utils.data import DataLoader, Dataset
import wandb
from pytorch_lightning.loggers import WandbLogger
from tqdm import tqdm

# Report whether a CUDA device is visible to this kernel.
print(torch.cuda.is_available())


# Seed Python, NumPy and torch (CPU plus every CUDA device) with one shared value
# so stochastic steps later in the notebook are repeatable.
seed_value = 3407
random.seed(seed_value)
np.random.seed(seed_value)
torch.manual_seed(seed_value)
torch.cuda.manual_seed(seed_value)
torch.cuda.manual_seed_all(seed_value)
# 'high' permits faster reduced-precision (e.g. TF32) kernels for float32
# matrix multiplication where the hardware supports them.
torch.set_float32_matmul_precision('high')

True


In [2]:
class Autoencoder(pl.LightningModule):
    """Strided 1-D convolutional autoencoder operating on raw waveforms.

    Encoder: eight ``Conv1d`` layers taking 1 channel up to 1024 channels with
    strides 4, 4, 2, 2, 2, 2, 2, 2 (an overall time downsampling of 1024x).
    Decoder: eight ``ConvTranspose1d`` layers mirroring that path back to
    1 channel.

    The attribute names ``enc_conv1..8`` / ``dec_conv1..8`` and every layer
    hyper-parameter must stay exactly as written: weights are restored from a
    checkpoint by state_dict key and tensor shape.

    NOTE(review): no activation function is applied between layers, so each half
    is a composition of affine maps (effectively a linear autoencoder). Adding
    nonlinearities would require retraining.
    NOTE(review): the STFT loss is built with a hard-coded ``device="cuda"``;
    confirm against auraloss whether that matters on CPU-only machines.
    """

    def __init__(self):
        super().__init__()

        # Loss objects. ``forward`` does not use them; presumably a training step
        # defined elsewhere combines them — not visible here.
        self.loss_fn_1 = auraloss.freq.RandomResolutionSTFTLoss(
            sample_rate=32000,
            device="cuda",
        )
        self.loss_fn_2 = auraloss.time.SISDRLoss()
        self.loss_fn_3 = torch.nn.L1Loss()

        # Encoder: 1 -> 1024 channels, downsampling time at every layer.
        self.enc_conv1 = nn.Conv1d(in_channels=1, out_channels=8, kernel_size=33, stride=4, padding=16)
        self.enc_conv2 = nn.Conv1d(in_channels=8, out_channels=16, kernel_size=17, stride=4, padding=8)
        self.enc_conv3 = nn.Conv1d(in_channels=16, out_channels=32, kernel_size=9, stride=2, padding=4)
        self.enc_conv4 = nn.Conv1d(in_channels=32, out_channels=64, kernel_size=9, stride=2, padding=4)
        self.enc_conv5 = nn.Conv1d(in_channels=64, out_channels=128, kernel_size=9, stride=2, padding=4)
        self.enc_conv6 = nn.Conv1d(in_channels=128, out_channels=256, kernel_size=9, stride=2, padding=4)
        self.enc_conv7 = nn.Conv1d(in_channels=256, out_channels=512, kernel_size=9, stride=2, padding=4)
        self.enc_conv8 = nn.Conv1d(in_channels=512, out_channels=1024, kernel_size=9, stride=2, padding=4)

        # Decoder: 1024 -> 1 channel, upsampling time at every layer. The
        # paddings / output_padding differ per layer to land near the input length.
        self.dec_conv1 = nn.ConvTranspose1d(in_channels=1024, out_channels=512, kernel_size=9, stride=2, padding=4, output_padding=1)
        self.dec_conv2 = nn.ConvTranspose1d(in_channels=512, out_channels=256, kernel_size=9, stride=2, padding=4)
        self.dec_conv3 = nn.ConvTranspose1d(in_channels=256, out_channels=128, kernel_size=9, stride=2, padding=5)
        self.dec_conv4 = nn.ConvTranspose1d(in_channels=128, out_channels=64, kernel_size=9, stride=2, padding=4)
        self.dec_conv5 = nn.ConvTranspose1d(in_channels=64, out_channels=32, kernel_size=9, stride=2, padding=4)
        self.dec_conv6 = nn.ConvTranspose1d(in_channels=32, out_channels=16, kernel_size=9, stride=2, padding=4)
        self.dec_conv7 = nn.ConvTranspose1d(in_channels=16, out_channels=8, kernel_size=21, stride=4, padding=9)
        self.dec_conv8 = nn.ConvTranspose1d(in_channels=8, out_channels=1, kernel_size=37, stride=4, padding=22)

    def forward(self, x):
        """Encode ``x`` and decode it back.

        Args:
            x: waveform batch. ``enc_conv1`` has one input channel and the output
                is sliced on three axes, so presumably shaped (batch, 1, samples)
                — TODO confirm with callers.

        Returns:
            ``(reconstruction, latent)``: the decoder output cropped to its first
            160000 samples on the last axis, and the raw encoder output
            (1024 channels).
        """
        latent = x
        for layer_no in range(1, 9):
            latent = getattr(self, f"enc_conv{layer_no}")(latent)

        recon = latent
        for layer_no in range(1, 9):
            recon = getattr(self, f"dec_conv{layer_no}")(recon)

        return recon[:, :, :160000], latent


In [3]:
# Restore the trained autoencoder and switch it to inference mode, following the
# documented Lightning pattern. This model has no dropout/batch-norm layers, so
# eval() does not change its outputs today, but it keeps inference correct if such
# layers are ever added. eval() returns the module itself.
model = Autoencoder.load_from_checkpoint('./final_model.ckpt').eval()

In [4]:
# Encode every .flac file in `path` and keep the encoder output (latent) for each.
# `fs[i]` is the file that produced `embeddings[i]`.
path = "../data/rendered_audio/rendered_audio/"
# Sorted so the order of `fs` / `embeddings` is reproducible; os.listdir order is arbitrary.
files = sorted(os.listdir(path))
fs = []
embeddings = []
TARGET_SR = 32000  # matches sample_rate=32000 given to the Autoencoder's STFT loss
# Resample from each file's actual sample rate instead of assuming 44.1 kHz.
# One transform is cached per source rate so the resampling kernel is built once.
resamplers = {}

with torch.inference_mode():  # feature extraction only: no autograd graph needed
    for file in tqdm(files):
        if not file.endswith(".flac"):
            continue
        full_path = os.path.join(path, file)
        try:
            audio, sr = torchaudio.load(full_path)
            # The model's first conv takes a single input channel, so average any
            # multi-channel file down to mono (mono input is left unchanged).
            audio = audio.mean(dim=0, keepdim=True)
            if sr not in resamplers:
                resamplers[sr] = torchaudio.transforms.Resample(sr, TARGET_SR)
            audio = resamplers[sr](audio)

            _, embedding = model(audio.to(model.device).unsqueeze(0))
            # Drop the batch axis: (1, channels, frames) -> (channels, frames).
            embeddings.append(embedding.squeeze(0).cpu().numpy())
            fs.append(full_path)
        except Exception as e:
            # Skip unreadable/invalid files but report them; tqdm.write keeps the
            # progress bar intact, unlike print.
            tqdm.write(f"error {e}")
            


 82%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                         | 22263/27131 [06:03<01:29, 54.38it/s]

error Error opening '../data/rendered_audio/rendered_audio/d270f326-a3f6-4807-ac06-8716c9166ad1.flac': Format not recognised.


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 27131/27131 [08:44<00:00, 51.73it/s]


# Save the embeddings to disk



In [None]:
# Persist the embeddings and, separately, the audio paths they were computed from
# (`embeddings[i]` comes from `fs[i]`); saving only the arrays would lose that
# mapping. `pickle` is already imported in the first cell.
with open('embeddings.pkl', 'wb') as f:
    pickle.dump(embeddings, f)
with open('embedding_paths.pkl', 'wb') as f:
    pickle.dump(fs, f)

# Visualize the embeddings: PCA projection coloured by k-means cluster

In [5]:
# Shape of one latent: (encoder output channels, time frames).
np.shape(embeddings[0])

(1024, 157)

In [6]:
# Flatten each (channels, frames) latent into one feature vector and stack them
# into a single 2-D (n_files, channels * frames) array. Building the array once
# lets PCA and KMeans reuse it instead of each converting a Python list, and
# `ravel` returns a view for contiguous inputs instead of copying every item as
# `flatten` does. With ~160k values per file and tens of thousands of files,
# those extra copies amount to many gigabytes.
emb = np.stack([e.ravel() for e in embeddings])

KeyboardInterrupt: 

In [None]:
# Project the flattened embeddings onto their first 3 principal components.
# svd_solver='full' runs an exact SVD of the entire (n_files x ~160k) matrix,
# far more time and memory than 3 components need; the seeded randomized solver
# computes only the leading components (an accurate approximation) reproducibly.
pca = PCA(n_components=3, svd_solver='randomized', random_state=seed_value)
pca_ = pca.fit_transform(emb)

In [None]:
# Cluster the flattened embeddings into 11 groups; the labels colour the PCA
# scatter plots. fit_predict returns the labels from fitting directly, avoiding a
# second full pass over the data with predict(). random_state pins the
# initialisation instead of depending on how much global NumPy RNG state earlier
# cells consumed.
kmean_model = KMeans(n_clusters=11, random_state=seed_value)
labels = kmean_model.fit_predict(emb)

In [None]:
# One list per principal component (PC1, PC2, PC3) of the PCA projection.
# Plotting more components would need a larger n_components in the PCA cell.
feature_a = list(pca_[:, 0])
feature_b = list(pca_[:, 1])
feature_c = list(pca_[:, 2])

In [None]:
plots = [feature_a, feature_b, feature_c]

# One scatter plot per pair of principal components (later PC on x, earlier PC
# on y), with points coloured by k-means cluster label.
for idx in range(len(plots)):
    for idx1 in range(idx):
        fig, ax = plt.subplots()
        points = ax.scatter(plots[idx], plots[idx1], c=labels)
        ax.set(
            title=f"PC{idx + 1} vs PC{idx1 + 1}",
            xlabel=f"PC{idx + 1}",
            ylabel=f"PC{idx1 + 1}",
        )
        fig.colorbar(points, ax=ax, label="k-means cluster")
        plt.show()