In [None]:
from scipy.stats import describe
# from src.util_tools import compute_contrastive_loss_with_labels
from tools.project import INPUT_PATH, LOGS_PATH, OUTPUT_PATH
from torch.nn import functional as F
import audiocraft
from audiocraft.models import MusicGen
from audiocraft.utils.notebook import display_audio
import torch
from gradio.cli.commands.components.publish import colors
from omegaconf import DictConfig
from torch import set_grad_enabled
from torch.onnx.symbolic_opset9 import tensor
from torchviz import make_dot
import typing as tp
from audiocraft.modules.conditioners import ConditioningAttributes
import tqdm
import torch
from audiocraft.data.audio import audio_read, audio_write
from audiocraft.data.audio_utils import convert_audio_channels, convert_audio
import umap
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.cm as cm
import matplotlib.colors as mcolors
from torch.utils.tensorboard import SummaryWriter
from sklearn.decomposition import PCA
from random import shuffle
from torch.utils.data import TensorDataset, random_split, DataLoader, Dataset
from dataclasses import dataclass
import random
from collections import namedtuple
from pytorch_metric_learning.losses import NTXentLoss

# --- Notebook-wide configuration -------------------------------------------
# Experiment index (not referenced by any cell visible in this notebook).
EXP_NUM = 1
# Passed as `duration` both when reading clips and when configuring generation.
EXAMPLES_LEN = 5
# Batch size used by the DataLoaders built further down.
BATCH_SIZE = 5
# Number of encoded clips kept per song, and the index range pairs are drawn from.
EXAMPLES_NUM = 20
# Prefer the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Load the pretrained `facebook/musicgen-small` checkpoint and configure
# sampling so generated clips use the same duration as the training examples.
model = MusicGen.get_pretrained('facebook/musicgen-small')
model.set_generation_params(use_sampling=True, top_k=250, duration=EXAMPLES_LEN)

In [None]:
# Register the placeholder token "S*" with the T5 text conditioner.
new_token = "S*"
text_conditioner = list(model.lm.condition_provider.conditioners.values())[0]
tokenizer = text_conditioner.t5_tokenizer
text_model = text_conditioner.t5.to(DEVICE)
lm = model.lm

# add_tokens returns how many of the given tokens were actually new, so a
# re-run of this cell (token already present) skips the embedding surgery.
num_added = tokenizer.add_tokens([new_token])
# The id must be looked up AFTER the token is registered: before that the
# token has no id of its own, and the embedding write below would land on
# some other row (presumably the unknown-token row) instead of the new one.
new_token_id = tokenizer.convert_tokens_to_ids(new_token)

if num_added == 1:
    text_model.resize_token_embeddings(len(tokenizer))
    with torch.no_grad():
        # Initialise the new row with the mean of the embedding matrix.
        text_model.shared.weight[new_token_id] = text_model.shared.weight.mean(dim=0)

In [None]:
def get_music(song: str, idxs):
    """Read the clips ``p{idx}.wav`` of ``song`` and encode them in one batch.

    Args:
        song: Sub-directory name under ``textual-inversion/contrastive``.
        idxs: Iterable of clip indices; each one selects ``p{idx}.wav``.

    Returns:
        The first element returned by ``model.compression_model.encode`` for
        the stacked clips (the second element is discarded).
    """
    clips = []
    for part in idxs:
        wav_path = INPUT_PATH('textual-inversion', 'contrastive', song, f'p{part}.wav')
        # Clips are read with a fixed duration and padded when shorter.
        wav, sample_rate = audio_read(wav_path, duration=EXAMPLES_LEN, pad=True)
        # Add a leading batch axis, then convert with target arguments
        # 32000 and 1 (presumably 32 kHz mono - confirm against convert_audio).
        clips.append(convert_audio(wav[None], sample_rate, 32000, 1))
    stacked = torch.concatenate(clips).to(DEVICE)
    with torch.no_grad():
        codes, _ = model.compression_model.encode(stacked)
    return codes


def load_music_to_pt(song: str):
    """Encode clips ``p0`` .. ``p229`` of ``song`` and stack the results.

    The clips are encoded 10 at a time through ``get_music`` and the per-chunk
    results are concatenated along the first axis.

    Args:
        song: Song directory name forwarded to ``get_music``. The previous
            version called ``get_music`` without it, which raised ``TypeError``
            on every call because ``get_music`` requires ``(song, idxs)``.

    Returns:
        The concatenation of the 23 chunk encodings.
    """
    return torch.concatenate([get_music(song, range(i, i + 10)) for i in range(0, 230, 10)])

# torch.save(get_music('music_1', range(25)), INPUT_PATH('textual-inversion', 'contrastive', 'music_1.pt'))
# torch.save(get_music('music_10', range(26)), INPUT_PATH('textual-inversion', 'contrastive', 'music_10.pt'))

# dl = lambda x, s: DataLoader(x, batch_size=BATCH_SIZE, shuffle=s, pin_memory=True if torch.cuda.is_available() else False)
# torch.save(load_music_to_pt(), INPUT_PATH('textual-inversion', '8bit', 'encoded.pt'))

In [None]:
# Load the two pre-encoded songs and keep only the first EXAMPLES_NUM clips of each.
music_1_ts, music_2_ts = (
    torch.load(INPUT_PATH('textual-inversion', 'contrastive', file_name))[:EXAMPLES_NUM, :, :]
    for file_name in ('music_1.pt', 'music_10.pt')
)

In [None]:
# Text prompts containing the placeholder token "S*".
# First group: 20 upbeat / adventurous descriptions.
happy_adventurous_prompts = [
    "In the style of S*, a bright and cheerful journey begins",
    "S* audio of an energetic adventure through sunny lands",
    "A vibrant melody in the style of S*, setting off on a great quest",
    "Start of an epic journey, upbeat audio of S*",
    "S* journey's intro with lively, adventurous sounds",
    "An exciting adventure unfolds in the style of S*",
    "Audio of S*, upbeat and joyful for a new journey",
    "A happy, adventurous intro, S* soundtrack",
    "S* soundtrack for the first steps of a thrilling adventure",
    "Uplifting audio in the style of S* for a bold new quest",
    "Adventurous intro, in the energetic style of S*",
    "S* audio capturing the thrill of setting out on a journey",
    "Bright, adventurous music in the style of S*",
    "The start of a grand journey with S* in a joyful tone",
    "Audio of S* filled with excitement and adventure",
    "An exciting journey begins, S* theme music",
    "Upbeat and energetic, in the style of S*",
    "Audio of S*, inspiring an adventurous beginning",
    "S* style melody, bright and full of anticipation",
    "Joyful, adventurous sounds in the style of S*"
]
# Second group: 20 slow / melancholic descriptions.
slow_melancholic_prompts = [
    "In the style of S*, a slow and melancholic melody",
    "S* audio of a somber and reflective tone",
    "A sorrowful tune in the style of S*, slow and gentle",
    "Melancholic audio of S*, capturing a deep sadness",
    "S* in a slow, reflective style, evoking nostalgia",
    "A slow, melancholic theme, in the style of S*",
    "Audio of S*, quiet and mournful",
    "S* soundtrack, slow and filled with melancholy",
    "A melancholic, emotional piece in the style of S*",
    "In the reflective style of S*, a slow, sad melody",
    "S* audio expressing deep emotion and sorrow",
    "Slow and melancholic, in the style of S*",
    "S* soundtrack evoking a deep sense of longing",
    "A gentle, melancholic piece, in the style of S*",
    "In the style of S*, a slow melody of quiet sadness",
    "Audio of S*, portraying a melancholy mood",
    "A somber theme in the style of S*, slow and emotional",
    "S* audio filled with gentle sorrow",
    "Melancholic and slow, in the reflective style of S*",
    "S* soundtrack, soft and melancholic, filled with longing"
]


In [None]:
# Tokenize both prompt groups together so every row is padded to one length.
# Rows 0-19 are the happy prompts, rows 20-39 the melancholic ones.
res = tokenizer(
    [*happy_adventurous_prompts, *slow_melancholic_prompts],
    padding=True,
    add_special_tokens=False,
    return_tensors='pt',
)
ids = res['input_ids']
mask = res['attention_mask']

In [None]:
# Preview predictions for the first clips of each song, conditioned on "S*".
# One ConditioningAttributes is supplied per clip so the conditioning batch
# matches the codes batch instead of relying on a single shared entry.
out_1 = lm.compute_predictions(
    music_1_ts[:3],
    [ConditioningAttributes({'description': "S*"}) for _ in range(len(music_1_ts[:3]))],
)
out_2 = lm.compute_predictions(
    music_2_ts[:3],
    [ConditioningAttributes({'description': "S*"}) for _ in range(len(music_2_ts[:3]))],
)

In [None]:
def flatten(x, m, k):
    """Return the rows of slice ``k`` of ``x`` that ``m`` selects.

    ``x[:, k, ...]`` is reshaped to ``(-1, x.size(-1))`` and ``m[:, k, ...]``
    to a flat vector; the flat vector is then used to index the rows.

    Args:
        x: Tensor indexed as ``x[:, k, ...]``; its last axis is kept intact.
        m: Tensor indexed as ``m[:, k, ...]`` (presumably a boolean mask with
            one entry per row of the reshaped ``x``).
        k: Index along the second axis.
    """
    width = x.size(-1)
    rows = x[:, k, ...].contiguous().view(-1, width)
    keep = m[:, k, ...].contiguous().view(-1)
    return rows[keep]


# Masked logits of the first codebook (k=0) for each preview output.
a, b = (flatten(pred.logits, pred.mask, 0) for pred in (out_1, out_2))

In [None]:
# Scratch check: product of the [0, 0] entries of the two flattened logit
# tensors (assumes `a` and `b` are the 2-D results of `flatten` above).
a[0, 0, ...] * b[0, 0, ...]

In [None]:
# One training example: two encoded clips, the (ids, mask) pair of each clip's
# prompt, and a 0/1 label saying whether both clips come from the same song.
ContrastiveRow = namedtuple('ContrastiveRow', 'song_1 prompt_1 song_2 prompt_2 example_type')


class PairMusic(Dataset):
    """Randomly sampled clip pairs for contrastive training.

    The index space is split into three consecutive ranges:

    * ``[0, pos)``           - both clips from ``ts_1`` (label 1)
    * ``[pos, 2 * pos)``     - both clips from ``ts_2`` (label 1)
    * ``[2 * pos, 2 * pos + neg)`` - one clip from each song (label 0)

    Clip and prompt indices are drawn at random on every access, so the same
    ``idx`` does not return the same row twice.

    Args:
        ts_1: Encoded clips of the first song, indexed along axis 0.
        ts_2: Encoded clips of the second song, indexed along axis 0.
        pos: Number of positive pairs per song.
        neg: Number of negative pairs.
        prompt_ids: Tokenized prompt ids, one row per prompt. Defaults to the
            notebook-level ``ids`` as it exists when the dataset is built.
        prompt_mask: Attention masks matching ``prompt_ids``. Defaults to the
            notebook-level ``mask`` as it exists when the dataset is built.

    Raises:
        ValueError: If fewer than 4 indices are available to sample from.
    """

    def __init__(self, ts_1: torch.Tensor, ts_2: torch.Tensor, pos=10, neg=20,
                 prompt_ids: tp.Optional[torch.Tensor] = None,
                 prompt_mask: tp.Optional[torch.Tensor] = None):
        self.ts_1 = ts_1
        self.ts_2 = ts_2
        self.pos = pos
        self.neg = neg
        # Snapshot the prompt tensors at construction time. Reading the
        # notebook globals lazily in __getitem__ breaks as soon as a later
        # cell rebinds `mask` to something else.
        self.prompt_ids = ids if prompt_ids is None else prompt_ids
        self.prompt_mask = mask if prompt_mask is None else prompt_mask
        # Largest index range that is valid for both songs and the prompts.
        # NOTE(review): prompt rows at or beyond this bound are never drawn -
        # confirm that is intended.
        self.n_choices = min(len(self.ts_1), len(self.ts_2), len(self.prompt_ids))
        if self.n_choices < 4:
            raise ValueError("PairMusic needs at least 4 clips and 4 prompts to sample from.")

    def __len__(self):
        # pos positives per song plus neg negatives (the former `- 1` made the
        # last negative index unreachable).
        return 2 * self.pos + self.neg

    def __getitem__(self, idx) -> ContrastiveRow:
        size = len(self)
        if idx < 0:
            idx += size
        if not 0 <= idx < size:
            # Required for plain iteration over the dataset to terminate.
            raise IndexError(f"index {idx} is out of range for PairMusic of length {size}")

        # Four distinct indices: two clip positions and two prompt rows.
        s1, p1, s2, p2 = random.sample(range(self.n_choices), 4)
        prompt_1 = (self.prompt_ids[p1], self.prompt_mask[p1])
        prompt_2 = (self.prompt_ids[p2], self.prompt_mask[p2])

        if idx < self.pos:
            first, second, example_type = self.ts_1, self.ts_1, 1
        elif idx < 2 * self.pos:
            first, second, example_type = self.ts_2, self.ts_2, 1
        else:
            first, second, example_type = self.ts_1, self.ts_2, 0
        return ContrastiveRow(first[s1], prompt_1, second[s2], prompt_2, example_type)


def dl(x, s):
    """Wrap dataset ``x`` in a DataLoader; ``s`` toggles shuffling.

    Memory pinning is enabled exactly when CUDA is available.
    """
    return DataLoader(x, batch_size=BATCH_SIZE, shuffle=s, pin_memory=torch.cuda.is_available())


# Build the pair dataset, split it 80/20 with a fixed seed, and pull one
# training batch for the experiments in the following cells.
ds = PairMusic(music_1_ts, music_2_ts)
train_ds, val_ds = random_split(ds, [0.8, 0.2], generator=torch.Generator().manual_seed(42))
train_dl = dl(train_ds, True)
val_dl = dl(val_ds, False)
batch = next(iter(train_dl))

In [None]:
# Repeats the "S*" token registration so this cell can be run on its own;
# it is a no-op for the embeddings when the token is already registered.
new_token = "S*"
text_conditioner = list(model.lm.condition_provider.conditioners.values())[0]
tokenizer = text_conditioner.t5_tokenizer
text_model = text_conditioner.t5.to(DEVICE)

# add_tokens returns how many of the given tokens were actually new.
num_added = tokenizer.add_tokens([new_token])
# Resolve the id only after registration: looked up earlier, the token has no
# id of its own and the write below would target some other embedding row.
new_token_id = tokenizer.convert_tokens_to_ids(new_token)

if num_added == 1:
    text_model.resize_token_embeddings(len(tokenizer))
    with torch.no_grad():
        # Initialise the new row with the mean of the embedding matrix.
        text_model.shared.weight[new_token_id] = text_model.shared.weight.mean(dim=0)

In [None]:
# Run one batch through the text encoder and the language model by hand.
song_1, prompt_1, song_2, prompt_2, label = batch
# Stack both halves of every pair: rows [:B] belong to song_1, rows [B:] to
# song_2. The tokenizer produced CPU tensors, so move everything to DEVICE
# where `text_model` was placed (assumes `lm` lives on the same device).
# NOTE: this rebinds the notebook-level name `mask`.
mask = torch.concat([prompt_1[1], prompt_2[1]]).to(DEVICE)
input_ids = torch.concat([prompt_1[0], prompt_2[0]]).to(DEVICE)
song = torch.concat([song_1, song_2]).to(DEVICE)
# Both context managers must be listed with a comma. `A and B` evaluates to
# just `B`, which silently skipped the conditioner's autocast context.
with text_conditioner.autocast, torch.set_grad_enabled(True):
    embeds = text_model(input_ids=input_ids, attention_mask=mask).last_hidden_state
# Project to the LM's conditioning width, matching the projection's dtype/device.
embeds = text_conditioner.output_proj(embeds.to(text_conditioner.output_proj.weight))
# Zero the embeddings at padded prompt positions.
embeds = (embeds * mask.unsqueeze(-1))
with model.autocast:
    # Conditions are passed as precomputed tensors, so the attribute list is empty.
    out = lm.compute_predictions(song, [], {'description': (embeds, mask)})

In [None]:
# Inspect the shape of the mask that compute_predictions returned with the logits.
out.mask.shape

In [None]:
from torch.nn import functional as F
import torch
import typing as tp


def compute_contrastive_loss_with_labels(
        logits: torch.Tensor,
        labels: torch.Tensor,
        temperature: float = 0.5,
        mask: tp.Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Contrastive loss between the logits of song_1 and song_2, driven by pair labels.

    For every pair, codebook and time step the cosine similarity between the
    two logit vectors (over the ``card`` axis) is computed. Positive pairs are
    penalised by ``(1 - cos)^2`` and negative pairs by ``cos^2``. The penalty
    is averaged over the valid steps of each codebook and then over codebooks.

    Args:
        logits (torch.Tensor): Stacked logits of shape [2 * B, K, T, card].
            logits[:B] are for song_1, logits[B:] are for song_2.
        labels (torch.Tensor): Binary labels of shape [B]; 1 marks a positive
            pair (same song), 0 a negative pair.
        temperature (float): Currently unused; kept so existing calls that
            pass it keep working.
        mask (torch.Tensor, optional): Validity mask of shape [2 * B, K, T].
            When omitted, a step is valid if all of its logits are finite.
            Non-finite steps are always excluded, even if the mask marks them
            as valid.

    Returns:
        torch.Tensor: Scalar loss on the device of ``logits``.
    """
    assert logits.dim() == 4, "Logits should have shape [2 * B, K, T, card]."
    assert logits.shape[0] % 2 == 0, "Logits should be stacked for song_1 and song_2."
    num_pairs = logits.shape[0] // 2

    # A step can only contribute when every logit in it is a real number.
    valid = torch.isfinite(logits).all(dim=-1)
    if mask is not None:
        valid = valid & mask.to(device=logits.device, dtype=torch.bool)
    # A pair's step counts only when it is valid for both songs.
    pair_valid = valid[:num_pairs] & valid[num_pairs:]  # [B, K, T]

    # Replace invalid steps by zeros so NaNs cannot leak into the sums below.
    clean = torch.where(valid.unsqueeze(-1), logits, torch.zeros_like(logits)).float()
    cos_sim = F.cosine_similarity(clean[:num_pairs], clean[num_pairs:], dim=-1)  # [B, K, T]

    # One label per pair, broadcast over codebooks and time steps.
    pair_labels = labels.to(device=logits.device, dtype=cos_sim.dtype).reshape(num_pairs, 1, 1)
    pull_together = pair_labels * torch.pow(torch.clamp(1.0 - cos_sim, min=0.0), 2)
    push_apart = (1 - pair_labels) * torch.pow(cos_sim, 2)
    per_step = (pull_together + push_apart) * pair_valid

    # Mean over valid steps per codebook (guarding against empty codebooks),
    # then mean over the K codebooks.
    steps_per_codebook = pair_valid.sum(dim=(0, 2)).clamp(min=1)
    loss_per_codebook = per_step.sum(dim=(0, 2)) / steps_per_codebook
    return loss_per_codebook.mean()


# Evaluate the contrastive loss on the batch that was just pushed through the model.
compute_contrastive_loss_with_labels(logits=out.logits, labels=label)