In [None]:
from src.util_tools import compute_cross_entropy, compute_ortho_loss
from tools.project import INPUT_PATH, LOGS_PATH, OUTPUT_PATH, MODELS_PATH

import audiocraft
from audiocraft.models import MusicGen
from audiocraft.utils.notebook import display_audio
import torch
from torch.onnx.symbolic_opset9 import tensor
from torchviz import make_dot
import typing as tp
from audiocraft.modules.conditioners import ConditioningAttributes
import tqdm
import torch
from audiocraft.data.audio import audio_read, audio_write
from audiocraft.data.audio_utils import convert_audio_channels, convert_audio
import umap
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.cm as cm
import matplotlib.colors as mcolors
from torch.utils.tensorboard import SummaryWriter
from sklearn.decomposition import PCA
from random import shuffle
from torch.utils.data import TensorDataset, random_split, DataLoader
from audioldm_eval.metrics.fad import FrechetAudioDistance
import os
import contextlib
import io
import torchaudio
import random
def count_directories(path):
    """Return the number of immediate sub-directories of ``path``.

    Used below to derive the next experiment number from the existing run folders.

    Args:
        path: directory to inspect (str or path-like).

    Returns:
        Count of entries directly under ``path`` that are directories (symlinks to
        directories included). Returns 0 when ``path`` does not exist or is not a
        directory, instead of letting ``os.listdir`` raise on a regular file.
    """
    import os
    if not os.path.isdir(path):
        return 0
    with os.scandir(path) as entries:
        return sum(1 for entry in entries if entry.is_dir())

letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'  # ASCII letters, upper then lower case
# Next experiment id: number of existing run folders under the textual-inversion logs, plus one.
EXP_NUM = 1 + count_directories(LOGS_PATH('textual-inversion'))
# NOTE(review): EXAMPLES_LEN, BATCH_SIZE, N_TOKENS, letters and DEVICE are not referenced by the
# cells shown here — confirm they are still needed.
EXAMPLES_LEN = 5
BATCH_SIZE = 5
N_TOKENS = 5
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
torch.cuda.is_available()  # displayed as the cell output: quick check that a GPU is visible

In [None]:
model = MusicGen.get_pretrained(name='facebook/musicgen-style')  # style-conditioned MusicGen checkpoint


In [None]:
model.set_generation_params(
    use_sampling=True,
    top_k=250,
    duration=5,        # seconds of audio per generation (can go up to 30)
    cfg_coef=3.,       # classifier-free guidance (CFG) coefficient
    # Double CFG is needed for text-and-style conditioning. cfg_coef_beta is the beta of the
    # double-CFG formula, between 1 and 9: at 1 it is equivalent to normal CFG, and larger
    # values push the text condition harder. See the bottom of https://musicgenstyle.github.io/
    # to better understand the effect of the two CFG coefficients.
    cfg_coef_beta=8.,
)

model.set_style_conditioner_params(
    # Length in seconds of the excerpt the model takes from the provided audio; between
    # 1.5 and 4.5 s, and it must not be longer than the conditioning audio itself.
    excerpt_length=4.5,
    # Quantization level (integer between 1 and 6) passed through the style conditioner.
    # Lower values make the model adhere less to the audio conditioning.
    eval_q=2,
)

In [None]:
concept = "8bit"
# Reference clips for this concept; the FAD cell at the end scores against this same folder.
examples = os.listdir(INPUT_PATH('textual-inversion-v3', 'data', 'valid', f'{concept}', 'fad'))
random.shuffle(examples)  # NOTE(review): unseeded, so the chosen excerpts change on every run
songs = []
for fname in tqdm.tqdm(examples[:5]):  # first 5 shuffled reference clips
    melody, sr = audio_read(INPUT_PATH('textual-inversion-v3', 'data', 'valid', f'{concept}', 'fad', fname), pad=True, duration=5)
    # Keep only row 0 of `melody` (presumably the first channel — audio_read is assumed to
    # return (channels, samples)) and repeat it 3 times -> a (3, 1, T) batch, i.e. three
    # generations per excerpt.
    songs.append(melody[0][None].expand(3, -1, -1))
songs = torch.cat(songs, dim=0)
# No text descriptions: generation is conditioned on the audio excerpts only.
results = model.generate_with_chroma([None] * len(songs), songs, sr, progress=True)
for a_idx in range(results.shape[0]):
    music = results[a_idx].cpu()
    peak = music.abs().max()
    if peak > 0:  # a fully silent generation would otherwise be divided by 0 and saved as NaNs
        music = music / peak
    path = OUTPUT_PATH("musigen-style", concept, 'temp', f'music_p{a_idx}')
    audio_write(path, music, model.cfg.sample_rate)

In [None]:
exp_name = "polar-totem-39"
# Display the saved textual-inversion checkpoint as this cell's output.
# NOTE(review): torch.load unpickles arbitrary objects — only load checkpoints you produced yourself.
torch.load(MODELS_PATH("textual-inversion-v3", exp_name + "-best.pt"))

In [None]:
def append_new_tokens(tokenizer, weights, data_by_concept):
    """Register learned concept tokens in ``tokenizer`` and write their embeddings.

    For every concept, each token string in ``data['tokens']`` is added to the
    tokenizer, its id is looked up, and the matching row of ``data['embeds']`` is
    copied into ``weights`` at that id with autograd disabled.

    Args:
        tokenizer: object exposing ``add_tokens`` and ``convert_tokens_to_ids``.
        weights: embedding matrix indexed by token id; modified in place.
        data_by_concept: mapping concept -> {'tokens': list of str, 'embeds': tensor}
            where ``embeds`` has one row per token.

    Raises:
        AssertionError: if a concept's token count differs from its number of embedding rows.
    """
    for data in data_by_concept.values():
        token_list = data['tokens']
        embed_rows = data['embeds']
        assert len(token_list) == embed_rows.shape[0]
        for row_idx, new_token in enumerate(token_list):
            # add_tokens is a no-op for a token that already exists, so re-running reuses its id.
            tokenizer.add_tokens(new_token)
            new_id = tokenizer.convert_tokens_to_ids([new_token])[0]
            with torch.no_grad():
                weights[new_id] = embed_rows[row_idx]
class TokensProvider:
    """Produce the placeholder token strings used for one textual-inversion concept.

    With ``num=3`` and base ``"8bit"`` the tokens are ``<8bit_0>``, ``<8bit_1>``, ``<8bit_2>``.
    """

    def __init__(self, num: int):
        # How many placeholder tokens each concept gets.
        self.num = num

    def get(self, base: str):
        """Return ``num`` placeholder tokens of the form ``<base_i>`` for i = 0 .. num-1."""
        tokens = []
        for idx in range(self.num):
            tokens.append(f'<{base}_{idx}>')
        return tokens

    def get_str(self, base: str):
        """Return the placeholder tokens for ``base`` joined by single spaces, ready for a prompt."""
        separator = ' '
        return separator.join(self.get(base))
concept = "8bit"
exp_name = "polar-totem-39"
# Learned textual-inversion vectors, read below as embedings[concept]['embeds'] (one row per token).
# NOTE(review): torch.load unpickles the file — only load checkpoints you produced yourself.
embedings = torch.load(MODELS_PATH("textual-inversion-v3", f"{exp_name}-best.pt"))
model = MusicGen.get_pretrained("facebook/musicgen-small")
model.set_generation_params(
    use_sampling=True,
    top_k=250,
    duration=5
)
tokens_provider = TokensProvider(embedings[concept]['embeds'].shape[0])
# The first conditioner is taken to be the T5 text conditioner (it exposes t5 / t5_tokenizer).
text_conditioner = list(model.lm.condition_provider.conditioners.values())[0]
tokenizer = text_conditioner.t5_tokenizer
text_model = text_conditioner.t5

new_tokens = tokens_provider.get(concept)
# Register the placeholder tokens and size the embedding table to the tokenizer *before*
# writing the learned vectors. The writes then land in the final weight matrix and can
# never index past its end.
tokenizer.add_tokens(new_tokens)
text_model.resize_token_embeddings(len(tokenizer))
append_new_tokens(tokenizer, text_model.shared.weight, {concept: {
    'tokens': new_tokens,
    'embeds': embedings[concept]['embeds']
    }})
# NOTE(review): the style cell writes 15 clips (5 excerpts x 3) but this writes 5; FAD on such
# small, unequal sample sets is noisy and hard to compare — confirm this is intended.
ti_res = model.generate([f'In the style of {tokens_provider.get_str(concept)}'] * 5)
# Save the textual-inversion generations (ti_res, not the style model's `results`) into the
# folder that the FAD cell scores as "TI".
for a_idx in range(ti_res.shape[0]):
    music = ti_res[a_idx].cpu()
    peak = music.abs().max()
    if peak > 0:  # avoid dividing a fully silent clip by 0 (would write NaNs)
        music = music / peak
    path = OUTPUT_PATH("textual-inversion-v3", concept, 'temp', f'music_p{a_idx}')
    audio_write(path, music, model.cfg.sample_rate)

In [None]:
# FAD scorer from audioldm_eval, used by the scoring cell that follows.
fad = FrechetAudioDistance(
    use_pca=True,
    use_activation=True,
    verbose=True,
)


In [None]:
def calc_fad(path, reference_dir=None, scale=1e-5):
    """Return the scaled Frechet Audio Distance between a folder of generations and a reference set.

    Args:
        path: folder of generated audio to score.
        reference_dir: folder of reference clips. When None, this is the validation 'fad'
            folder of the global ``concept``, looked up at call time (as the original lambda did).
        scale: multiplier applied to the raw score (default 1e-5, as before).
            NOTE(review): the reason for this rescaling is undocumented — confirm it is only
            cosmetic and is applied identically to every score being compared.

    Returns:
        The first value of the dict returned by ``fad.score``, multiplied by ``scale``.
    """
    if reference_dir is None:
        reference_dir = INPUT_PATH('textual-inversion-v3', 'data', 'valid', f'{concept}', 'fad')
    # recalculate=True presumably forces embeddings to be recomputed rather than cached — confirm.
    scores = fad.score(reference_dir, path, recalculate=True)
    return list(scores.values())[0] * scale


print("STYLE:", calc_fad(OUTPUT_PATH("musigen-style", concept, 'temp')))
print("TI:", calc_fad(OUTPUT_PATH("textual-inversion-v3", concept, 'temp')))