In [None]:
from transformers import ViTModel, ViTImageProcessorFast, ClapModel, ClapProcessor
import torch
from torchvision import transforms
from torch import nn

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

class ImageEncoder(nn.Module):
    """Pretrained vision backbone bundled with the preprocessor for its inputs.

    Only ``model_type == 'vit'`` is supported; any other value makes the
    constructor raise.
    """

    def __init__(self, model_type, checkpoint):
        super().__init__()
        self.model_type = model_type

        # Guard clause: refuse model families we do not know how to load.
        if model_type != 'vit':
            raise Exception('Unsupported model for image encoder')

        # Backbone and input processor both come from the same checkpoint.
        self.model = ViTModel.from_pretrained(checkpoint, device_map='auto')
        self.preprocessor = ViTImageProcessorFast.from_pretrained(checkpoint)

    def encode(self, data):
        """Preprocess ``data`` and return the backbone's ``pooler_output``.

        Returns ``None`` when ``model_type`` is not ``'vit'``.
        """
        if self.model_type != 'vit':
            return None

        preprocess_options = dict(
            return_tensors="pt",
            do_normalize=True,
            do_convert_rgb=True,
            do_rescale=True,
            do_resize=True,
        )
        # The processed batch is moved to the notebook-level `device` before the forward pass.
        batch = self.preprocessor(images=data, **preprocess_options).to(device)
        return self.model(**batch).pooler_output

class AudioEncoder(nn.Module):
    """Pretrained audio backbone bundled with the preprocessor for its inputs.

    Only ``model_type == 'clap'`` is supported; any other value makes the
    constructor raise.
    """

    def __init__(self, model_type, checkpoint):
        super().__init__()
        self.model_type = model_type

        if self.model_type == 'clap':
            # Load the model and the class that the inputs go through
            self.model = ClapModel.from_pretrained(checkpoint, device_map='auto')
            self.preprocessor = ClapProcessor.from_pretrained(checkpoint, use_fast=True)
        else:
            raise Exception('Unsupported model for audio encoder')

    def encode(self, sampling_rate, waveforms):
        """Preprocess raw waveforms and return the model's audio features.

        Args:
            sampling_rate: Sampling rate passed on to the preprocessor.
            waveforms: A ``torch.Tensor`` or anything the preprocessor accepts
                directly (e.g. a NumPy array). Tensors are detached and moved
                to the CPU first, so tensors that require grad or live on the
                GPU are handled as well.

        Returns:
            The result of ``model.get_audio_features`` for the processed
            batch, or ``None`` when ``model_type`` is not ``'clap'``.
        """
        if self.model_type == 'clap':
            if isinstance(waveforms, torch.Tensor):
                # A bare `.numpy()` raises for tensors that require grad or are not on the CPU.
                waveforms = waveforms.detach().cpu().numpy()
            inputs = self.preprocessor(audios=waveforms, sampling_rate=sampling_rate, return_tensors='pt')
            inputs = inputs.to(device)
            outputs = self.model.get_audio_features(**inputs)
            return outputs

class MLPMapper(torch.nn.Module):
    """Small MLP (Linear -> Tanh -> Linear) mapping ``input_dim`` features to ``output_dim``."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Built in this order so the submodules keep the names mlp.0 / mlp.1 / mlp.2.
        layers = [
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.Tanh(),
            torch.nn.Linear(hidden_dim, output_dim),
        ]
        self.mlp = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the MLP and return the result."""
        projected = self.mlp(x)
        return projected

class PictureToMusicModel(torch.nn.Module):
    """Embeds images and audio clips so they can be compared with each other.

    An image encoder and an audio encoder each produce an embedding. The image
    embedding always goes through an MLP that outputs
    ``config.shared_embedding_size`` features; the audio embedding goes through
    its own MLP only when ``config.has_audio_mapper`` is set and is otherwise
    used unchanged.
    """

    # Sampling rate (Hz) announced to the audio preprocessor for every clip.
    AUDIO_SAMPLING_RATE = 48000

    def __init__(self, config):
        """Build the encoders and mappers described by ``config``."""
        super().__init__()
        self.config = config
        self.init_image_encoder()
        self.init_audio_encoder()

        self.image_mapper = MLPMapper(
            input_dim=self.config.image_embedding_size,
            hidden_dim=self.config.mlp_hidden_size,
            output_dim=self.config.shared_embedding_size
        )
        if config.has_audio_mapper:
            self.audio_mapper = MLPMapper(
                input_dim=self.config.audio_embedding_size,
                hidden_dim=self.config.mlp_hidden_size,
                output_dim=self.config.shared_embedding_size
            )

        # Learnable temperature; `forward` exponentiates and clamps it.
        self.logit_scale = torch.nn.Parameter(torch.tensor(1.0))

    def _project_audio(self, audio_emb):
        """Apply the audio MLP when the config enables it, else return the embedding as is."""
        if self.config.has_audio_mapper:
            return self.audio_mapper(audio_emb)
        return audio_emb

    def forward(self, image_input, audio_input):
        """Return ``(image_proj, audio_proj, logit_scale)`` for a batch.

        The projections are NOT normalized here; ``logit_scale`` is
        ``exp(self.logit_scale)`` clamped to ``[0, 100]``.
        """
        image_emb = self.image_encoder.encode(data=image_input)
        audio_emb = self.audio_encoder.encode(sampling_rate=self.AUDIO_SAMPLING_RATE, waveforms=audio_input)

        image_proj = self.image_mapper(image_emb)
        audio_proj = self._project_audio(audio_emb)

        # clamping (restricting) the temperature
        logit_scale = self.logit_scale.exp()
        logit_scale = torch.clamp(logit_scale, 0, 100)

        return image_proj, audio_proj, logit_scale

    def encode_image(self, image_input):
        """Return L2-normalized image embeddings as a NumPy array."""
        # Inference only: the result is detached anyway, so skip building the autograd graph.
        with torch.no_grad():
            image_emb = self.image_encoder.encode(data=image_input)
            image_proj = self.image_mapper(image_emb)
            image_proj = torch.nn.functional.normalize(image_proj, dim=-1)
        return image_proj.detach().cpu().numpy()

    def encode_audio(self, audio_input):
        """Return L2-normalized audio embeddings as a NumPy array."""
        # Inference only: the result is detached anyway, so skip building the autograd graph.
        with torch.no_grad():
            audio_emb = self.audio_encoder.encode(sampling_rate=self.AUDIO_SAMPLING_RATE, waveforms=audio_input)
            audio_proj = self._project_audio(audio_emb)
            audio_proj = torch.nn.functional.normalize(audio_proj, dim=-1)
        return audio_proj.detach().cpu().numpy()

    def init_image_encoder(self):
        """Create the image encoder and set which of its parameters are trainable.

        Every backbone parameter is frozen first. When
        ``config.freeze_image_encoder`` is false, the last
        ``config.num_layers_to_unfreeze`` encoder layers are made trainable
        again; a value of 0 or less leaves everything frozen.
        """
        self.image_encoder = ImageEncoder(
            model_type=self.config.image_encoder_type,
            checkpoint=self.config.image_encoder_checkpoint
        )

        for param in self.image_encoder.model.parameters():
            param.requires_grad = False
        if not self.config.freeze_image_encoder:
            num_layers = self.config.num_layers_to_unfreeze
            # Guard against 0: a slice starting at -0 would select every layer.
            if num_layers > 0:
                # Unfreeze the last `num_layers` layers (a slice, not a single index)
                # so we can finetune them.
                layers = list(self.image_encoder.model.encoder.layer)
                for layer in layers[-num_layers:]:
                    for param in layer.parameters():
                        param.requires_grad = True

    def init_audio_encoder(self):
        """Create the audio encoder and freeze it when the config asks for that."""
        self.audio_encoder = AudioEncoder(
            model_type=self.config.audio_encoder_type,
            checkpoint=self.config.audio_encoder_checkpoint
        )
        if self.config.freeze_audio_encoder:
            for param in self.audio_encoder.model.parameters():
                param.requires_grad = False


from dataclasses import dataclass, asdict

@dataclass
class PictureToMusicConfig:
    """Settings read by ``PictureToMusicModel`` when it builds its encoders and mappers."""
    # Image encoder family passed to ImageEncoder; only 'vit' is accepted there.
    image_encoder_type: str = 'vit'
    # Checkpoint name handed to the image model's and processor's from_pretrained.
    image_encoder_checkpoint: str = 'google/vit-base-patch16-224-in21k'
    # Audio encoder family passed to AudioEncoder; only 'clap' is accepted there.
    audio_encoder_type: str = 'clap'
    # Checkpoint name handed to the audio model's and processor's from_pretrained.
    audio_encoder_checkpoint: str = 'laion/clap-htsat-unfused'
    # When True, every audio encoder parameter gets requires_grad = False.
    freeze_audio_encoder: bool = True
    # When True, every image encoder parameter stays frozen.
    freeze_image_encoder: bool = True
    # Only read when freeze_image_encoder is False: selects which trailing
    # image-encoder layer(s) are made trainable again.
    num_layers_to_unfreeze: int = 1
    # Input size of the image MLP mapper.
    image_embedding_size: int = 768
    # Input size of the audio MLP mapper (only used when has_audio_mapper is True).
    audio_embedding_size: int = 512
    # Output size of the MLP mapper(s).
    shared_embedding_size: int = 512
    # When False, the audio encoder output is used directly without an MLP mapper.
    has_audio_mapper: bool = False
    # Hidden layer size of the MLP mapper(s).
    mlp_hidden_size: int = 1024

**Load the model from Hugging Face**

In [None]:
from huggingface_hub import hf_hub_download
#from picture_to_music import PictureToMusicModel, PictureToMusicConfig
import torch, json

# Load the config from Hugging Face
config_path = hf_hub_download("Pesho564/Picture-to-music", "config.json")
with open(config_path) as f:
    config = json.load(f)

config_class = PictureToMusicConfig(**config)

# Load weights
weights_path = hf_hub_download("Pesho564/Picture-to-music", "model_state_dict.bin")
model = PictureToMusicModel(config_class).to(device)
# map_location remaps the stored tensors onto the device in use, so weights that
# were saved from a GPU session still load on a CPU-only machine.
# NOTE(review): torch.load unpickles the downloaded file; consider passing
# weights_only=True if the installed torch version supports it.
state_dict = torch.load(weights_path, map_location=device)
model.load_state_dict(state_dict)

model.eval()

# Model is now ready

In [4]:
# Where the MP3 files are read from and where the embedding table is cached.
audio_dir = "/kaggle/input/fma-free-music-archive-small-medium/fma_small/fma_small"
save_path = "/kaggle/working/audio_embeddings.pkl"

# Each clip is loaded at `sr` Hz and padded/trimmed to `duration` seconds.
sr = 48_000
duration = 10

In [5]:
import os 

# Walk the dataset tree once, tallying MP3 files and sub-directories as we go.
n_files = n_dirs = 0
for root, dirs, files in os.walk(audio_dir):
    n_files += sum(1 for name in files if name.endswith(".mp3"))
    n_dirs += len(dirs)
print(f"Total MP3 files: {n_files}\nTotal subdirs: {n_dirs}")

Total MP3 files: 8000
Total subdirs: 156


In [6]:
def load_and_preprocess_audio(filepath, sr, duration):
    """Load a file as mono audio at ``sr`` Hz and force it to ``sr * duration`` samples.

    Clips that are too short are zero-padded at the end; clips that are too
    long keep only their first ``sr * duration`` samples.
    """
    waveform, _ = librosa.load(filepath, sr=sr, mono=True)
    n_target = sr * duration
    shortfall = n_target - len(waveform)
    if shortfall > 0:
        # Too short: append zeros up to the target length.
        waveform = np.pad(waveform, (0, shortfall))
    elif shortfall < 0:
        # Too long: keep the leading part only.
        waveform = waveform[:n_target]
    return waveform

def embed(audio):
    """Return what the notebook-level ``model`` produces from ``encode_audio(audio)``."""
    embedding = model.encode_audio(audio)
    return embedding

**Convert music dataset into embeddings**

In [33]:
import librosa
import numpy as np
import pandas as pd

# Map every track id (file name without extension) to its path on disk. This is
# rebuilt on every run, because it is cheap and later cells need it even when
# the embeddings themselves come from the cache.
track_to_filepath = {}
for root, _, files in os.walk(audio_dir):
    for fname in files:
        if fname.endswith(".mp3"):
            track_to_filepath[os.path.splitext(fname)[0]] = os.path.join(root, fname)

BATCH_SIZE = 100  # tracks embedded per model pass and stored per chunk file


def _embed_and_save_batch(audios, track_ids, path):
    """Embed one batch of equal-length waveforms, pickle the table to `path` and return it."""
    batch_df = pd.DataFrame(embed(torch.tensor(np.array(audios))), index=track_ids)
    batch_df.to_pickle(path)
    return batch_df


if not os.path.exists(save_path):
    chunk_dfs = []
    audios = []
    track_ids = []

    idx = 0
    for track_id, filepath in track_to_filepath.items():
        try:
            audio = load_and_preprocess_audio(filepath, sr, duration)
        except Exception as error:
            # Best effort: skip files that cannot be decoded, but report which one.
            print(f"Skipping {filepath}: {error}")
            continue

        # Record the id only after a successful load so ids and waveforms stay aligned.
        audios.append(audio)
        track_ids.append(track_id)
        idx += 1
        if idx % BATCH_SIZE == 0:
            # All clips do not fit into one pass, so embed and checkpoint every BATCH_SIZE tracks.
            chunk_dfs.append(_embed_and_save_batch(audios, track_ids, save_path + str(idx)))
            audios = []
            track_ids = []
            print('Saved')

    # Flush the remainder, if any (the model is never called on an empty batch).
    if audios:
        chunk_dfs.append(_embed_and_save_batch(audios, track_ids, save_path + str(idx)))

    if not chunk_dfs:
        raise FileNotFoundError(f"No MP3 file could be loaded from {audio_dir}")

    # The cache file holds the complete table, not just the last batch.
    embeddings_df = pd.concat(chunk_dfs)
    embeddings_df.to_pickle(save_path)
else:
    embeddings_df = pd.read_pickle(save_path)

In [31]:
if not os.path.exists(save_path):
    # Since all of the audio is too big to fit into one pass through the model,
    # it was split into multiple passes and multiple files; combine them here.
    chunk_dir = os.path.dirname(save_path) or '.'
    chunk_prefix = os.path.basename(save_path)

    # Chunk files are named "<save_path><number>". Pick up only those, so that
    # unrelated files in the working directory are never fed to read_pickle.
    chunk_paths = []
    for fname in os.listdir(chunk_dir):
        suffix = fname[len(chunk_prefix):]
        if fname.startswith(chunk_prefix) and suffix.isdigit():
            # Join with the directory: a bare file name only resolves when the
            # current working directory happens to be the chunk directory.
            chunk_paths.append((int(suffix), os.path.join(chunk_dir, fname)))

    if not chunk_paths:
        raise FileNotFoundError(f"No embedding chunks named {chunk_prefix}<N> found in {chunk_dir}")

    # Sort numerically so the combined table has a deterministic row order.
    dfs = [pd.read_pickle(path) for _, path in sorted(chunk_paths)]
    embeddings_df = pd.concat(dfs)
    embeddings_df.to_pickle(save_path)

**Load a sample image**

In [32]:
import numpy as np
from PIL import Image
import requests

# Download the sample image. The timeout keeps the cell from hanging forever and
# raise_for_status() surfaces HTTP errors instead of a confusing image-decoding error.
response = requests.get('http://farm2.staticflickr.com/1357/947054324_da3b551fa9_z.jpg', stream=True, timeout=30)
response.raise_for_status()

# Force three RGB channels before converting to an array.
# NOTE(review): the array is handed to the image processor as-is; grayscale or
# RGBA files would presumably not be converted for NumPy input — confirm.
im = Image.open(response.raw).convert('RGB')

image_embedding = model.encode_image(np.array(im))

**Find and display the filename of the best music match**

In [None]:
def find_best_match_knn(query_emb, embeddings, k=5):
    """Return the ``k`` rows of ``embeddings`` with the largest dot product to the query.

    Args:
        query_emb: A single query vector; any array shape is accepted as long
            as it flattens to one value per embedding column.
        embeddings: DataFrame with one embedding per row; its index labels
            the candidates.
        k: Number of best matches to return.

    Returns:
        A Series of scores indexed like ``embeddings``, largest first.
    """
    query_vec = np.asarray(query_emb).reshape(-1)
    # Score against the `embeddings` argument (not a notebook global) with one matrix product.
    scores = pd.Series(embeddings.to_numpy() @ query_vec, index=embeddings.index)
    return scores.nlargest(k)

top_matches = find_best_match_knn(image_embedding, embeddings_df, k=5)
print(top_matches)

# The first entry of `top_matches` is the best-scoring track; show where it lives on disk.
best_track_id = top_matches.index[0]
print(track_to_filepath[best_track_id])