## PlayWrite Pipeline

What to expect:
1. Caption the uploaded image with the image-captioning model
2. Generate a music prompt from the caption and the user's text with Llama-2-7b
3. Generate music from that prompt with Mustango

Note: Llama-2-7b requires at least 16 GB of VRAM to be loaded in half precision (bfloat16).

In [1]:
#Imports
import pickle
import io
import os
import sys
import soundfile as sf
import IPython

import gc

from PIL import Image
import numpy as np
import torch
from torchvision import transforms, models
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

### Imports

In [2]:
class Vocabulary():
    """
    Bidirectional word <-> index mapping used to turn captions into index sequences.

    Indices 0-3 are reserved for the special tokens <PAD>, <SOS>, <EOS> and <UNK>.

    Args:
        freq_threshold (int, optional): Occurrence count at which a word is added to the vocabulary. Defaults to 2
    """

    def __init__(self, freq_threshold:int=2):
        special_tokens = ("<PAD>", "<SOS>", "<EOS>", "<UNK>")
        self.itos = dict(enumerate(special_tokens)) #index -> token
        self.stoi = {token: index for index, token in enumerate(special_tokens)} #token -> index
        self.freq_threshold = freq_threshold

    def __len__(self):
        """Number of entries in the vocabulary, special tokens included."""
        return len(self.itos)

    @staticmethod
    def tokenizer_eng(text):
        """
        Split text into lower-cased word tokens.

        NOTE(review): relies on a module-level `word_tokenizer` that is not defined
        in this part of the notebook - presumably a spaCy pipeline; confirm.
        """
        tokens = word_tokenizer.tokenizer(text)
        return [token.text.lower() for token in tokens]

    def build_vocabulary(self, sentence_list):
        """
        Count the words of every sentence and register each word at the moment
        its count reaches freq_threshold.

        Args:
            sentence_list: iterable of caption strings
        """
        counts = {}
        next_index = 4 #0-3 are taken by the special tokens

        for sentence in sentence_list:
            for word in self.tokenizer_eng(sentence):
                counts[word] = counts.get(word, 0) + 1

                #only the exact moment the threshold is hit registers the word
                if counts[word] != self.freq_threshold:
                    continue
                self.stoi[word] = next_index
                self.itos[next_index] = word
                next_index += 1

    def numericalize(self, text):
        """
        Convert text to a list of vocabulary indices; unknown words map to <UNK>.
        """
        indices = []
        for token in self.tokenizer_eng(text):
            key = token if token in self.stoi else "<UNK>"
            indices.append(self.stoi[key])
        return indices
  

class InceptionV3EncoderCNN(torch.nn.Module):
    """
    InceptionV3 CNN encoder that exposes the Mixed_7c activations as image features.

    The parameters of Mixed_7c are always trainable; every other parameter
    follows train_cnn.

    Args:
        finetuned_model: Wrapper module whose first child is a finetuned InceptionV3, or None to load the default torchvision weights
        train_cnn (bool, optional): Whether the layers other than Mixed_7c are trainable. Defaults to False.
    """
    def __init__(self, finetuned_model, train_cnn:bool=False):
        super(InceptionV3EncoderCNN, self).__init__()
        #identity check: comparing a module to None with != is not the intended test
        if finetuned_model is not None:
            #the backbone is taken as the first child of the wrapper
            self.inception = list(finetuned_model.children())[0]
        else:
            self.inception = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT)
        self.inception.aux_logits = False

        #Remove last classification layer
        self.inception.fc = torch.nn.Identity()

        #Filled by the forward hook with the latest Mixed_7c output
        self.features = None

        #Register the hook to capture features at output of last CNN layer
        self.inception.Mixed_7c.register_forward_hook(self.capture_features_hook)

        #Mixed_7c is always trained, the rest depends on train_cnn
        for name, param in self.inception.named_parameters():
            param.requires_grad_(True if 'Mixed_7c' in name else train_cnn)


    def capture_features_hook(self, module, input, output):
        """Forward hook: remember the output of the hooked layer."""
        self.features = output


    def forward(self, images):
        """
        Run the backbone and return the captured feature map flattened over space.

        Args:
            images: batch of image tensors accepted by the backbone

        Returns:
            features of size (batch, height*width, feature_maps)
        """
        _ = self.inception(images)  #only run for the hook's side effect
        batch, feature_maps, size_1, size_2 = self.features.size()
        #channels last, then merge the two spatial axes;
        #reshape (not view) so a non-contiguous permuted tensor is also handled
        features = self.features.permute(0, 2, 3, 1)
        features = features.reshape(batch, size_1*size_2, feature_maps)

        return features


class BahdanauAttention(torch.nn.Module):
    """
    Additive (Bahdanau-style) attention over a set of image feature vectors.

    Args:
        feature_dim (int): Size of the last axis of the features (input size of W_a)
        hidden_dim (int): Dimension of the decoder hidden state
        output_dim (int, optional): Size of the score produced per location, default to 1
    """
    def __init__(self, feature_dim:int, hidden_dim:int, output_dim:int = 1):
        super(BahdanauAttention, self).__init__()
        #projects the image features into the hidden space
        self.W_a = torch.nn.Linear(feature_dim, hidden_dim)
        #projects the decoder hidden state into the same space
        self.U_a = torch.nn.Linear(hidden_dim, hidden_dim)
        #turns the combined representation into a score
        self.v_a = torch.nn.Linear(hidden_dim, output_dim)


    def forward(self, features, hidden_state):
        """
        Args:
            features: image features from the encoder, (batch, locations, feature_dim)
            hidden_state: current decoder hidden state, (batch, hidden_dim)

        Returns:
            context: score-weighted sum of the features over the location axis
            atten_weight: softmax weights over the locations, axis 2 squeezed
        """
        #a length-1 location axis lets the state broadcast against every location
        expanded_state = hidden_state.unsqueeze(1)

        projected_features = self.W_a(features)
        projected_state = self.U_a(expanded_state)

        combined = torch.tanh(projected_features + projected_state)
        scores = self.v_a(combined)
        #normalise across the location axis
        atten_weight = scores.softmax(dim=1)

        context = (atten_weight * features).sum(dim=1)

        return context, atten_weight.squeeze(2)


class DecoderRNN(torch.nn.Module):
    """
    LSTM decoder with attention over the encoder features

    Args:
        feature_dim (int): Size of the last axis of the encoder features (per-location feature vector)
        embedding_dim (int): Embedding dimension to embed words
        hidden_dim (int): Hidden state dimension for LSTM
        vocab_size (int): Total number of unique vocab
        drop_prob (float, optional): Dropout layer probability, defaults to 0.5
        sample_temp (float, optional): Temperature the logits are divided by before the predicted word is picked during scheduled sampling. Defaults to 0.5
    """
    def __init__(self, feature_dim:int, embedding_dim:int, hidden_dim:int, vocab_size:int, drop_prob:float=0.5, sample_temp:float=0.5):
        super(DecoderRNN, self).__init__()

        self.feature_dim = feature_dim
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.sample_temp = sample_temp

        #embedding layer that turns word indices into vectors
        self.embeddings = torch.nn.Embedding(vocab_size, embedding_dim)
        #lstm cell that takes [word embedding, attention context] and outputs hidden_dim
        self.lstm = torch.nn.LSTMCell(embedding_dim + feature_dim, hidden_dim)
        #fc linear layer that predicts next word
        self.fc = torch.nn.Linear(hidden_dim, vocab_size)
        #attention layer
        self.attention = BahdanauAttention(feature_dim, hidden_dim)
        #dropout layer
        self.drop = torch.nn.Dropout(p=drop_prob)
        #layers producing the initial hidden state and cell memory from the mean feature
        self.init_h = torch.nn.Linear(feature_dim, hidden_dim)
        self.init_c = torch.nn.Linear(feature_dim, hidden_dim)

    def init_hidden(self, features):
        """
        Initializes hidden state and cell memory using average feature vector
        Args:
            features: feature map of the image, averaged over axis 1
        Returns:
            h0: initial hidden state (short-term memory)
            c0: initial cell state (long-term memory)
        """
        mean_annotations = torch.mean(features, dim = 1)
        h0 = self.init_h(mean_annotations)
        c0 = self.init_c(mean_annotations)
        return h0, c0

    def forward(self, features, captions, device:str, sample_prob:float=0.2):
        """
        Decode a batch of captions with teacher forcing and scheduled sampling.

        At every step after the first, with probability sample_prob the LSTM is
        fed the word the model predicted at the previous step instead of the
        ground-truth word for the current step.

        Args:
            features: feature map of image, (batch size, locations, feature_dim)
            captions: true caption of image as indices, (batch size, seq length)
            device (str): cuda or cpu
            sample_prob (float, optional): Probability of feeding the model's own previous prediction, defaults to 0.2

        Returns:
            outputs: vocabulary logits, (batch size, seq length, vocab_size)
            atten_weights: attention weights, (batch size, seq length, locations)
        """
        embed = self.embeddings(captions)
        h, c = self.init_hidden(features)
        batch_size = captions.size(0)
        seq_len = captions.size(1)
        feature_size = features.size(1)

        #storage of outputs and attention weights of lstm
        outputs = torch.zeros(batch_size, seq_len, self.vocab_size).to(device)
        atten_weights = torch.zeros(batch_size, seq_len, feature_size).to(device)

        prev_output = None #logits of the previous step, source of sampled words
        for t in range(seq_len):
            #the first word can never be sampled: there is no earlier prediction
            s_prob = 0.0 if t==0 else sample_prob
            use_sampling = np.random.random() < s_prob

            if use_sampling:
                #Feed the word predicted at step t-1. Previously the embedding
                #left over from the previous iteration was reused here and the
                #prediction was only computed after the LSTM step, so it was
                #discarded unless the following step happened to sample too.
                scaled_output = prev_output/self.sample_temp
                top_idx = scaled_output.argmax(dim=1)
                word_embeddings = self.embeddings(top_idx)
            else:
                word_embeddings = embed[:, t, :] #ground-truth word at step t

            context, atten_weight = self.attention(features,h)
            inputs = torch.cat([word_embeddings, context], 1)
            h, c = self.lstm(inputs, (h,c))
            output = self.fc(self.drop(h))

            outputs[:,t,:] = output
            atten_weights[:, t, :] = atten_weight
            prev_output = output

        return outputs, atten_weights


class InceptV3EncoderAttentionDecoder(torch.nn.Module):
    """
    InceptionV3 Encoder with Attention and LSTM Decoder

    Args:
        finetuned_model: Finetuned InceptionV3 model, else None
        feature_dim (int): Size of the per-location feature vector produced by the encoder
        embedding_dim (int): Embedding dimension to embed words
        hidden_dim (int): Hidden state dimension for LSTM
        vocab_size (int): Total number of unique vocab
        device (str): cuda or cpu
        train_cnn (boolean, optional): Determines if inceptionCNN model is unfreezed. Defaults to False.
        drop_prob (float, optional): Dropout layer probability. Defaults to 0.5.
        sample_temp (float, optional): Temperature forwarded to the decoder. Defaults to 0.5
    """
    def __init__(self, finetuned_model, feature_dim:int, embedding_dim:int, hidden_dim:int, vocab_size:int, device:str, train_cnn:bool=False, drop_prob:float=0.5, sample_temp:float=0.5):
        super(InceptV3EncoderAttentionDecoder, self).__init__()
        self.encoder = InceptionV3EncoderCNN(finetuned_model, train_cnn)
        self.decoder= DecoderRNN(feature_dim, embedding_dim, hidden_dim, vocab_size, drop_prob, sample_temp)
        self.sample_temp = sample_temp
        self.device = device

    def forward(self, image, captions, sample_prob:float=None):
        """
        Encode the images and decode the captions (training path).

        Args:
            image: batch of images for the encoder
            captions: ground-truth caption indices
            sample_prob (float, optional): Scheduled-sampling probability handed to the decoder.
                When omitted, self.sample_temp is used, which is what this method has always passed.

        Returns:
            outputs: decoder logits
            atten_weights: decoder attention weights
        """
        if sample_prob is None:
            #NOTE(review): the temperature is passed in the decoder's sample_prob
            #slot. This looks like an argument mix-up, but it is kept as the
            #default so existing training runs behave the same - confirm intent.
            sample_prob = self.sample_temp
        features = self.encoder(image)
        outputs, atten_weights = self.decoder(features, captions, self.device, sample_prob)
        return outputs, atten_weights


    #for inference
    def caption_image(self, image, vocabulary:"Vocabulary", device:str, max_length:int=50):
        """
        Generate caption using a greedy algorithm based on image input

        The encoder is switched to eval mode (and left there). The decoder is
        switched to eval mode for the duration of the call so that dropout
        cannot randomise the caption, then restored to its previous mode.

        Args:
            image: image input, a batch containing a single image
            vocabulary (Vocabulary): Vocabulary to decode predictions
            device (str): cuda or cpu
            max_length (int, optional): Max length of generated captions. Defaults to 50.

        Returns:
            captions: list of token strings, starting with <SOS>
            atten_weights: list with the attention weights of every generated token
        """
        decoder_was_training = self.decoder.training
        self.encoder.eval()
        self.decoder.eval()

        result_caption = [1] #index 1 is <SOS>
        result_weights = []

        try:
            with torch.no_grad(): #no training
                input_word = torch.tensor(1).unsqueeze(0).to(device)
                features = self.encoder(image)
                h, c = self.decoder.init_hidden(features)

                for _ in range(max_length):
                    embedded_word = self.decoder.embeddings(input_word)
                    context, atten_weight = self.decoder.attention(features, h)
                    # input_concat shape at time step t = (batch, embedding_dim + context size)
                    input_concat = torch.cat([embedded_word, context],  dim = 1)
                    h, c = self.decoder.lstm(input_concat, (h,c))
                    #dropout only sits in front of fc, as in training; it is no
                    #longer written back into the recurrent state
                    output = self.decoder.fc(self.decoder.drop(h))
                    scoring = torch.nn.functional.log_softmax(output, dim=1)
                    top_idx = scoring[0].topk(1)[1] #greedy pick for the single image
                    result_caption.append(top_idx.item())
                    result_weights.append(atten_weight)
                    input_word = top_idx

                    if (len(result_caption) >= max_length or vocabulary.itos[input_word.item()] == "<EOS>"):
                        break
        finally:
            self.decoder.train(decoder_was_training)

        return [vocabulary.itos[idx] for idx in result_caption], result_weights


class PreTrainedCNNModels(torch.nn.Module):
    """
    Class that contains InceptionV3, Resnet50, Resnet152, EfficientNet, DenseNet, VGG16, MaxVit fine tuned models

    Args:
        model_type (str): Determines which pre-trained models to use
                          Must be: InceptionV3, Resnet50, Resnet152, EfficientNet, DenseNet, VGG16, MaxVit
        num_unfreeze (int): Number of trailing parameter tensors to unfreeze and finetune
                            (0 freezes everything)
        num_class (int): Number of output classes for the classification

    Raises:
        ValueError: model_type is not one of the supported names
    """

    #model_type -> (torchvision.models builder name, weights enum name)
    _ARCHITECTURES = {
        'InceptionV3': ('inception_v3', 'Inception_V3_Weights'),
        'Resnet50': ('resnet50', 'ResNet50_Weights'),
        'Resnet152': ('resnet152', 'ResNet152_Weights'),
        'EfficientNet': ('efficientnet_v2_s', 'EfficientNet_V2_S_Weights'),
        'DenseNet': ('densenet121', 'DenseNet121_Weights'),
        'VGG16': ('vgg16_bn', 'VGG16_BN_Weights'),
        'MaxVit': ('maxvit_t', 'MaxVit_T_Weights'),
    }

    def __init__(self, model_type:str, num_unfreeze:int, num_class:int):
        super(PreTrainedCNNModels, self).__init__()

        #selecting model type
        if model_type not in self._ARCHITECTURES:
            raise ValueError("Invalid model type chosen. Please select one of the following\n[InceptionV3, Resnet50, Resnet152, EfficientNet, DenseNet, VGG16, MaxVit]")
        builder_name, weights_name = self._ARCHITECTURES[model_type]
        self.model = getattr(models, builder_name)(weights=getattr(models, weights_name).DEFAULT)
        if model_type == 'InceptionV3':
            self.model.aux_logits = False

        #modifying final layer
        if model_type in ['InceptionV3', 'Resnet50', 'Resnet152']:
            self.model.fc = torch.nn.Linear(self.model.fc.in_features, num_class)

        elif model_type == 'DenseNet':
            self.model.classifier = torch.nn.Linear(self.model.classifier.in_features, num_class)

        else:
            #remaining architectures expose an indexable classifier; swap its last layer
            self.model.classifier[-1] = torch.nn.Linear(self.model.classifier[-1].in_features, num_class)

        self._unfreeze_last(list(self.model.parameters()), num_unfreeze)


    @staticmethod
    def _unfreeze_last(parameters, num_unfreeze):
        """Freeze every parameter except the last num_unfreeze, which are made trainable."""
        #An explicit split index is required: slicing with [-0:] selects the
        #whole list, which used to unfreeze the entire model for num_unfreeze=0.
        num_unfreeze = min(max(num_unfreeze, 0), len(parameters))
        split = len(parameters) - num_unfreeze
        for param in parameters[:split]:
            param.requires_grad = False
        for param in parameters[split:]:
            param.requires_grad = True


    def forward(self, images):
        """Return the wrapped model's output for images."""
        return self.model(images)
    

class PromptTemplate:
    """
    Builds a Llama-2 style chat prompt from a system prompt and a running
    history of user messages and model replies.

    Args:
        system_prompt (str, optional): text placed inside the <<SYS>> section
    """

    def __init__(self, system_prompt=None):
        self.system_prompt = system_prompt
        #The history lives on the instance. It used to be two class-level lists,
        #so every PromptTemplate shared (and polluted) the same conversation.
        self.user_messages = []
        self.model_replies = []

    def add_user_message(self, message: str, return_prompt=True):
        """Append a user message; return the rebuilt prompt when return_prompt is set."""
        self.user_messages.append(message)
        if return_prompt:
            return self.build_prompt()

    def add_model_reply(self, reply: str, includes_history=True, return_reply=True):
        """
        Append the model's reply to the latest user message.

        Args:
            reply: model output
            includes_history: when True the current prompt text is removed from reply first
            return_reply: when True the stored reply is returned

        Raises:
            ValueError: there is no unanswered user message to attach the reply to
        """
        reply_ = reply.replace(self.build_prompt(), "") if includes_history else reply
        #validate before storing so a rejected reply does not corrupt the history
        if len(self.user_messages) != len(self.model_replies) + 1:
            raise ValueError(
                "Number of user messages does not equal number of system replies."
            )
        self.model_replies.append(reply_)
        if return_reply:
            return reply_

    def get_user_messages(self, strip=True):
        """Return the user messages, whitespace-stripped unless strip is False."""
        return [x.strip() for x in self.user_messages] if strip else self.user_messages

    def get_model_replies(self, strip=True):
        """Return the model replies, whitespace-stripped unless strip is False."""
        return [x.strip() for x in self.model_replies] if strip else self.model_replies

    def clear_chat_history(self):
        """Forget all user messages and model replies."""
        self.user_messages.clear()
        self.model_replies.clear()

    def build_prompt(self):
        """
        Render the conversation as one prompt string.

        With an empty history only the system prompt is emitted. Otherwise
        exactly one user message must be waiting for a reply.

        Raises:
            ValueError: the history does not end in exactly one unanswered user message
        """
        if not self.user_messages and not self.model_replies:
            return f"<s>[INST] <<SYS>>\n{self.system_prompt}\n<</SYS>> [/INST]</s>"

        if len(self.user_messages) != len(self.model_replies) + 1:
            raise ValueError(
                "Error: Expected len(user_messages) = len(model_replies) + 1. Add a new user message!"
            )

        header = "<s>"
        if self.system_prompt is not None:
            header += f"[INST] <<SYS>>\n{self.system_prompt}\n<</SYS>>"

        #zip stops at the shorter list, i.e. it covers only the answered turns
        turns = []
        for i, (user_message, model_reply) in enumerate(zip(self.user_messages, self.model_replies)):
            turn = f"{user_message} [/INST] {model_reply} </s>"
            #the first answered turn gets no "[INST] " prefix of its own
            turns.append(turn if i == 0 else "[INST] " + turn)

        turns.append(f"[INST] {self.user_messages[-1]} [/INST]")

        return header + "".join(turns)

In [3]:
# Mustango is imported from its checkout under ../models/mustango, with the
# working directory switched there for the duration of the import.
# The previous working directory is restored even when the import fails, so
# re-running this cell does not start from the wrong directory.
_notebook_dir = os.getcwd()
os.chdir("../models/mustango")
try:
    from mustango import Mustango
finally:
    os.chdir(_notebook_dir)

### Testing

In [4]:
class playWrite():
    """
    Holds every model of the pipeline: caption vocabulary, image-captioning
    model, Llama model + tokenizer and Mustango.

    Args:
        device (str): cuda or cpu
        vocab_path (str): path to the pickled Vocabulary
        image_caption_path (str): path to the saved image-captioning model
        hg_access_token (str, optional): Hugging Face token; when given, Llama is loaded from the hub
        llama_model_path (str, optional): local Llama model directory (used when no token is given)
        llama_tokenizer_path (str, optional): local Llama tokenizer directory (used when no token is given)
    """
    def __init__(self,
                 device:str,
                 vocab_path:str,
                 image_caption_path:str,
                 hg_access_token:str=None,
                 llama_model_path: str=None,
                 llama_tokenizer_path:str=None
                 ):

        self.vocab = self._load_vocab(vocab_path)
        self.image_caption = self._load_image_caption(image_caption_path, device)
        self.llama_model, self.llama_tokenizer = self._load_llama(hg_access_token, llama_model_path, llama_tokenizer_path)
        self.mustango = self._load_mustango()
        self.device = device

    def _load_vocab(self, filepath:str):
        """
        Unpickle the vocabulary and check its type.

        Raises:
            TypeError: the pickle does not contain a Vocabulary
        """
        #SECURITY: pickle executes arbitrary code on load - only open trusted files
        with open(filepath, 'rb') as file: #closed again even if unpickling fails
            vocab = pickle.load(file)
        if not isinstance(vocab, Vocabulary):
            raise TypeError("Invalid Vocabulary")
        print("Vocabulary Loaded Successfully")
        return vocab

    def _load_image_caption(self, model_path, device):
        """
        Load the saved image-captioning model onto device.

        Raises:
            RuntimeError: the model could not be loaded
        """
        try:
            #map_location lets a checkpoint saved on a GPU be opened on a CPU-only machine
            #NOTE(review): the file appears to hold a whole pickled model; recent
            #torch versions may need weights_only=False here - confirm for the installed version
            model = torch.load(model_path, map_location=device).to(device)
            print("Image Captioning Model Loaded Successfully")
            return model

        except Exception as e:
            raise RuntimeError(f"Unable to load torch model, reasons: {e}") from e

    def _load_llama(self, hg_access, llama_model_path, llama_tokenizer_path):
        """
        Load the Llama model and tokenizer in bfloat16, from the Hugging Face
        hub when a token is given, otherwise from the two local paths.

        Raises:
            ValueError: neither a token nor both local paths were given
            RuntimeError: loading failed
        """
        if hg_access is not None:
            model_source = tokenizer_source = 'meta-llama/Llama-2-7b-chat-hf'
            extra_kwargs = {"token": hg_access}
            origin = "hugging face"
        elif llama_model_path is not None and llama_tokenizer_path is not None:
            model_source, tokenizer_source = llama_model_path, llama_tokenizer_path
            extra_kwargs = {}
            origin = "directory"
        else:
            raise ValueError("No Llama resources provided")

        try:
            print("Loading Llama Models")
            llama_model = AutoModelForCausalLM.from_pretrained(model_source, torch_dtype=torch.bfloat16, device_map="auto", **extra_kwargs)
            llama_tokenizer = AutoTokenizer.from_pretrained(tokenizer_source, **extra_kwargs)
            print("Llama Loaded Successfully")
            return llama_model, llama_tokenizer

        except Exception as e:
            raise RuntimeError(f"Unable to load Llama model from {origin}, reasons: {e}") from e

    def _load_mustango(self):
        """
        Instantiate Mustango from "declare-lab/mustango".

        Raises:
            RuntimeError: Mustango could not be created
        """
        try:
            mustango = Mustango("declare-lab/mustango")
            print("Mustango Loaded Successfully")
            return mustango

        except Exception as e:
            raise RuntimeError(f"Unable to load mustango, reasons: {e}") from e

    @staticmethod
    def _tokens_to_caption(tokens):
        """Join caption tokens into a sentence, dropping a leading <SOS> and a trailing <EOS>."""
        tokens = list(tokens)
        if tokens and tokens[0] == "<SOS>":
            tokens = tokens[1:]
        #only strip the last token when it really is <EOS>: a caption cut off
        #at max_length ends in a real word that must be kept
        if tokens and tokens[-1] == "<EOS>":
            tokens = tokens[:-1]
        return " ".join(tokens)

    @staticmethod
    def _extract_music_prompt(response):
        """Strip the chatty wrapping from a Llama response and return the prompt text."""
        first_quote = response.find('"')
        last_quote = response.rfind('"')
        if first_quote != -1 and last_quote > first_quote: #a real pair of quotation marks
            return response[first_quote+1:last_quote]

        colon_index = response.rfind(":") #text in the format of prompt:\n{actual prompt}
        if colon_index != -1:
            after_colon = response[colon_index+1:].strip()
            if after_colon:
                return after_colon
        #nothing to strip: return the response untouched
        return response


    def caption_image(self, image:bytes, model:"InceptV3EncoderAttentionDecoder", vocab:"Vocabulary", device:str, max_length:int=50):
        """
        Function to caption uploaded image from streamlit

        Args:
            image (bytes): uploaded image by users in bytes based on streamlit file reading format
            model: image captioning model
            vocab (Vocabulary): image captioning model vocabulary
            device (str): cuda or cpu
            max_length (int, optional): max length of generated captions, default to 50

        Returns:
            generated captions: string of image caption
        """

        pil_image = Image.open(io.BytesIO(image)).convert('RGB') #convert bytes to image

        #setup transform to convert image to readable tensor for the model
        transform = transforms.Compose([
            transforms.Resize((299,299)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225]),
        ])

        transformed_image = transform(pil_image).unsqueeze(0) #unsqueeze to add a batch size of 1
        generated_captions, _ = model.caption_image(transformed_image.to(device), vocab, device, max_length)

        return self._tokens_to_caption(generated_captions)



    def generate_music_prompt(self, caption:str, text_prompt:str, llama_model:"AutoModelForCausalLM", llama_tokenizer:"AutoTokenizer", device:str):
        """
        Function to generate music prompt by merging image captions and text prompts from user with llama2-7b

        Args:
            caption (str): generated image caption
            text_prompt (str): text prompt from user
            llama_model (AutoModelForCausalLM): llama model
            llama_tokenizer (AutoTokenizer): llama tokenizer
            device (str): cuda or cpu

        Returns:
            music_prompt: string of llama results
        """
        prompt = f"""=============context================
        {caption},
        {text_prompt},
        =========================================

        Take the following 2 context and merge them to create a textual prompt for music generation. Your prompt should be a single line. Do not give prompts that suggest increasing intensity.
        The prompt should contain the atmosphere of the song, where the song would fit environment wise and chord progression you have come up with. I have given you some example prompts, format your prompt similarly to them but do not copy their content.
        Example prompts: This is a live performance of a classical music piece. There is an orchestra performing the piece with a violin lead playing the main melody. The atmosphere is sentimental and heart-touching. This piece could be playing in the background at a classy restaurant.
        The song is an instrumental. The song is in medium tempo with a classical guitar playing a lilting melody in accompaniment style. The song is emotional and romantic. The song is a romantic instrumental song.
        This is a new age piece. There is a flute playing the main melody with a lot of staccato notes. The rhythmic background consists of a medium tempo electronic drum beat with percussive elements all over the spectrum. There is a playful atmosphere to the piece.

    """
        #the whole instruction is sent as the system prompt, with no user turn
        promptGenerator =  PromptTemplate(system_prompt=prompt)
        llama_prompt = promptGenerator.build_prompt()
        config = GenerationConfig(max_new_tokens=1024,
                                do_sample=True,
                                top_k = 10,
                                num_return_sequences = 1,
                                return_full_text = False,
                                temperature = 0.1,
                                )

        encoded_input = llama_tokenizer.encode(llama_prompt, return_tensors='pt', add_special_tokens=False).to(device)
        results = llama_model.generate(encoded_input, generation_config=config)
        decoded_output = llama_tokenizer.decode(results[0], skip_special_tokens=True)
        #keep only what follows the last [/INST] marker of the decoded text
        response = decoded_output.split("[/INST]")[-1].strip()

        #cleaning up the response to remove additional prompts
        return self._extract_music_prompt(response)


    def generate_music(self, music_prompt:str, model:"Mustango", steps:int, guidance:int):
        """
        Function to generate music from music prompt with mustango

        Args:
            music_prompt (str): text prompt to generate music with mustango
            model (Mustango): mustango model
            steps (int): step count passed to model.generate
            guidance (int): guidance value passed to model.generate

        Returns:
            generated music
        """
        music = model.generate(music_prompt, steps, guidance)
        return music

In [5]:
def _release_model(owner, attribute:str):
    """Drop owner.<attribute> and hand the memory it occupied back."""
    delattr(owner, attribute)
    #collect first so objects kept alive only by reference cycles are really
    #gone before the CUDA cache is emptied
    gc.collect()
    torch.cuda.empty_cache()


def generate(playwrite:"playWrite", byte_image:bytes, text_prompt:str, max_length:int=50, steps:int=100, guidance:int=3, delete_model:bool=True):
    """
    Overall Function to generate music from image and textual prompts

    Args:
        playwrite (playWrite): Class instance with all models initiated
        byte_image (bytes): Image prompt
        text_prompt (str): Textual prompt
        max_length (int, optional): Maximum caption length, defaults to 50
        steps (int, optional): Step count handed to the music generation model, defaults to 100
        guidance (int, optional): Guidance value handed to the music generation model, defaults to 3
        delete_model (bool, optional): Delete the captioning and Llama models from playwrite as soon as
            they have been used, to save memory. The instance then has to be re-created before the
            next call. Defaults to True

    Returns:
        generated music
    """

    image_caption = playwrite.caption_image(image=byte_image,
                                            model=playwrite.image_caption,
                                            vocab=playwrite.vocab,
                                            device=playwrite.device,
                                            max_length=max_length
                                            )
    print(f"Image Caption: {image_caption}")
    if delete_model:
        _release_model(playwrite, "image_caption")

    music_prompt = playwrite.generate_music_prompt(caption=image_caption,
                                                   text_prompt=text_prompt,
                                                   llama_model=playwrite.llama_model,
                                                   llama_tokenizer=playwrite.llama_tokenizer,
                                                   device=playwrite.device)
    print(f"Music Prompt: {music_prompt}")
    if delete_model:
        _release_model(playwrite, "llama_model")

    return playwrite.generate_music(music_prompt=music_prompt,
                                    model=playwrite.mustango,
                                    steps=steps,
                                    guidance=guidance)

In [7]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Load every model of the pipeline once. hg_access_token is None, so Llama is
# loaded from the local model/tokenizer directories rather than the hub.
playwrite = playWrite(device=device,
                      vocab_path='../resources/Vocabulary.pkl',
                      image_caption_path='../models/image_captioning/model.pt',
                      hg_access_token=None,
                      llama_model_path='../models/llama/model',
                      llama_tokenizer_path='../models/llama/tokenizer')

Vocabulary Loaded Successfully
Image Captioning Model Loaded Successfully
Loading Llama Models


Loading checkpoint shards: 100%|██████████| 3/3 [00:18<00:00,  6.19s/it]


Llama Loaded Successfully


Fetching 13 files: 100%|██████████| 13/13 [00:00<?, ?it/s]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Some weights of DebertaV2ForTokenClassificationRegression were not initialized from the model checkpoint at microsoft/deberta-v3-large and are newly initialized: ['hidden2.bias', 'regressor.bias', 'classifier.weight', 'hidden2.weight', 'regressor.weight', 'hidden1.weight', 'classifier.bias', 'hidden1.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  fft_window = pad_center(fft_window, filter_length)
  mel_basis = librosa_mel_fn(


UNet initialized randomly.
Successfully loaded checkpoint from: declare-lab/mustango
Mustango Loaded Successfully


In [8]:
# Read the test image as raw bytes, the form generate() expects for byte_image.
with open("../input/Test/ce3273386dd20ea00320c54667c13c9048ea17c9-hq-1069880.jpeg", "rb") as uploaded_image:
    f = uploaded_image.read()
    b = bytearray(f)

text_prompt = "light-hearted fun platform adventure game where the player explore different areas to collect stars and fight bosses"

# delete_model=True removes the captioning and Llama models from `playwrite`
# once they have been used, so `playwrite` must be rebuilt before another run.
generated_music = generate(playwrite=playwrite,
                           byte_image=b,
                           text_prompt=text_prompt,
                           max_length=50,
                           steps=150,
                           guidance=3,
                           delete_model=True)

# Last expression of the cell: renders an audio player for the result at a 16000 Hz rate.
IPython.display.Audio(data=generated_music, rate=16000)

Image Caption: a man is in a black jacket standing on a snowy mountain .
Music Prompt: This is an adventurous platformer game set on a snowy mountain, with a light-hearted and fun atmosphere. The player must explore different areas to collect stars and fight bosses, creating a sense of excitement and wonder. The music should reflect this atmosphere, with a medium tempo and a mix of classical and electronic elements to create a playful and emotional sound. Imagine a classical guitar playing a lilting melody in accompaniment style, with a medium tempo electronic drum beat providing a rhythmic background and percussive elements adding depth and complexity. The overall feeling should be sentimental and heart-touching, with a sense of wonder and excitement.


  hidden_states = F.scaled_dot_product_attention(
