# Using the model to cluster songs according to their genre

### Loading the model

In [None]:
import torch
import torch.nn as nn
import deeplay as dl
import deeptrack as dt
import os

# Rebuild the VAE skeleton and load the trained weights into it. The
# constructor arguments must match the architecture the checkpoint was
# trained with, otherwise load_state_dict raises on missing/mis-shaped keys.
vae = dl.VariationalAutoEncoder(
    input_size=(256, 2560),
    latent_dim=10,
    channels=[64, 128, 256, 512],
    reconstruction_loss=torch.nn.MSELoss(reduction="sum"),
).create()
# The model above is never moved off the CPU; map_location="cpu" lets a
# checkpoint that was saved from a GPU run load on a machine without CUDA.
vae.load_state_dict(torch.load("vae.pth", map_location="cpu"))
vae.eval()  # inference mode for layers such as dropout / batch-norm

# Preprocess data
data_dir = os.path.expanduser("./trainImages")

# Load image files using ImageFolder
trainFiles = dt.sources.ImageFolder(root=data_dir)

class CropWidth:
    """Callable transform that keeps only the first ``target_width`` columns.

    The slice is taken on the last axis, so tensors of any rank are accepted.
    Inputs narrower than ``target_width`` come back at their full width,
    because Python slicing clamps the stop index.
    """

    def __init__(self, target_width):
        # Number of leading entries along the last (width) axis to retain.
        self.target_width = target_width

    def __call__(self, x: torch.Tensor):
        # Intended for [C, H, W] tensors; only the trailing axis is cut.
        keep = slice(None, self.target_width)
        return x[..., keep]
    
# Per-image preprocessing chain: load the file, min-max normalise it,
# move axis 2 to the front (presumably HWC -> CHW), convert to a float
# tensor, then crop the width to 2560 columns for the VAE input.
image_pip = (
    dt.LoadImage(trainFiles.path)
    >> dt.NormalizeMinMax()
    >> dt.MoveAxis(2, 0)
    >> dt.pytorch.ToTensor(dtype=torch.float)
    >> CropWidth(2560)
)

# `&` pairs the pipeline with itself so each item yields two outputs —
# presumably (input, target), the target of an autoencoder being its input.
train_dataset = dt.pytorch.Dataset(image_pip & image_pip, inputs=trainFiles)
train_loader = dl.DataLoader(train_dataset, batch_size=1, shuffle=True)

from Image2Sound import *
from torchvision.utils import save_image

# Take a single (shuffled) batch from the loader; batch_size=1, so this is
# one spectrogram image plus its paired copy, which is ignored here.
first_batch = next(iter(train_loader), None)
if first_batch is None:
    raise RuntimeError("train_loader yielded no batches; check the training image folder")
image, _ = first_batch

# Inference only: run without autograd so no graph is recorded (this
# replaces the previous .clone().detach() workaround).
with torch.no_grad():
    # Decode from the first encoder output (named mu, presumably the latent
    # mean) for a deterministic reconstruction instead of a random sample.
    mu, _ = vae.encode(image)
    image = vae.decode(mu).squeeze(0)  # drop the batch dimension
image = image.float()


class conf:
    # Audio / mel-spectrogram settings passed to Image2Sound.
    sampling_rate = 44100
    duration = 30
    samples = sampling_rate * duration
    # Presumably matches the VAE input height (input_size=(256, 2560)) — confirm.
    n_mels = 256
    hop_length = 512
    n_fft = 2048
    fmin = 20
    fmax = sampling_rate // 2  # Nyquist frequency


# NOTE(review): JPEG is lossy, so the round-trip through 'test_sample.jpg'
# perturbs the reconstructed spectrogram before it is turned into audio;
# a lossless format may be preferable if Image2Sound accepts it.
save_image(image, 'test_sample.jpg')
audio = Image2Sound('test_sample.jpg', conf)
SaveAudio(audio, os.getcwd(), "test_sample.mp3")



ffmpeg version 4.4.2-0ubuntu0.22.04.1 Copyright (c) 2000-2021 the FFmpeg developers
  built with gcc 11 (Ubuntu 11.2.0-19ubuntu1)
  configuration: --prefix=/usr --extra-version=0ubuntu0.22.04.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libdav1d --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librabbitmq --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libsrt --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enab