In [1]:
# from src.data import ConceptDataset, get_ds
from src.util_tools import compute_cross_entropy, compute_ortho_loss

from torch.utils.data import DataLoader, default_collate
import tqdm
import pytorch_lightning as L
from datasets import load_dataset
from tools.project import INPUT_PATH, LOGS_PATH, OUTPUT_PATH, MODELS_PATH
import torch
import torch.nn.functional as F
from torch.optim import Adam
import numpy as np
import contextlib
import io

from audiocraft.models import MusicGen
from audiocraft.data.audio import audio_read, audio_write
from audioldm_eval.metrics.fad import FrechetAudioDistance

EXAMPLES_LEN = 5
DEVICE = torch.device("cuda" if torch.cuda.is_available() else 'cpu')


In [29]:
from tools.project import INPUT_PATH, LOGS_PATH, OUTPUT_PATH, MODELS_PATH, RAW_PATH
import torch
import os
from datasets import Audio, load_dataset
from datasets import load_dataset
from random import choice
import tqdm

train_desc = [
    "the sound of %s",
    "pure %s audio",
    "the recorded %s sound",
    "%s audio sample",
    "recording of %s",
    "high fidelity %s audio",
    "%s sound clip",
    "audio of %s",
    "captured %s sound",
    "%s audio recording",
    "%s recording capture",
    "audio file of %s",
    "isolated %s sound",
    "distinct %s recording",
    "quality %s audio file",
    "high-definition %s sound",
    "the sound recording of %s",
    "audio segment of %s",
    "raw %s audio",
    "%s sound snippet",
    "%s audio track",
    "%s sound fragment",
    "audio recording for %s",
    "sound capture of %s",
    "%s audio file sample",
    "the isolated %s recording",
    "%s recorded audio",
    "pure capture of %s",
    "audio segment capture of %s",
    "the sample of %s audio",
    "the sound file of %s",
    "full recording of %s",
    "%s audio archive",
    "%s sound collection",
    "captured audio of %s",
    "%s isolated sound file",
    "the audio snippet of %s",
    "clean audio of %s",
    "%s audio capture",
    "%s sound extract"
]

val_desc = [
    "audio capture of %s",
    "%s sound recording",
    "pristine %s audio",
    "clear %s recording",
    "the audio of %s",
    "%s audio sample capture",
    "the recorded sound of %s",
    "sample of %s audio",
    "%s audio segment",
    "recorded audio of %s",
    "%s audio",
    "distinct sound of %s",
    "unprocessed %s audio",
    "%s recording",
    "high clarity %s sound",
    "%s recording sample",
    "audio portion of %s",
    "sampled audio of %s",
    "unfiltered %s audio",
    "audio segment for %s",
    "clip of %s audio",
    "the audio snippet for %s",
    "audio portion of %s",
    "%s recorded segment",
    "sampled sound of %s",
    "%s captured in audio",
    "audio excerpt of %s",
    "full audio capture of %s",
    "%s sound archive",
    "audio track of %s",
    "%s in sound format",
    "%s sound recording sample",
    "captured file of %s sound",
    "the distinct sound of %s",
    "high quality %s sound sample",
    "%s in captured audio",
    "pure audio of %s",
    "clean capture of %s audio",
    "recorded file of %s",
    "audio format of %s"
]

def get_ds():
    return load_dataset('json', data_files={
                'valid': INPUT_PATH('textual-inversion-v3', 'metadata_val.json'),
                'train': INPUT_PATH('textual-inversion-v3', 'metadata_train.json')
                })

class TokensProvider:
    def __init__(self, num: int):
        self.num = num
    
    def get(self, base: str):
        return [f'<{base}_{x}>' for x in range(self.num)]
    
    def get_str(self, base: str):
        return ' '.join(self.get(base))

class PromptProvider:
    def __init__(self, prompts_template):
        self.template = prompts_template
    
    def get(self, *args):
        return choice(self.template) % args

class ConceptDataset(torch.utils.data.Dataset):
    def __init__(self, ds, tokenizer, split: str, sr: int=32000, tokens_num: int=1, music_len: int=100):
        self.ds = ds
        self.tokenizer = tokenizer

        if self.ds.cache_files:
            self.base_dir = os.path.dirname(self.ds.cache_files[0]["filename"])
        else:
            raise ValueError("No cache files found in the dataset")
        self.base_dir = INPUT_PATH('textual-inversion-v3')

        if split == 'valid':
            def map_path(x):
                x['audio'] = os.path.join(self.base_dir, x['audio_path'])
                return x
            self.ds = self.ds.map(map_path).cast_column('audio', Audio(sampling_rate=sr))

        self.encoded = {}
        self.tokens_num = tokens_num
        self.prompter = PromptProvider(val_desc if split == 'valid' else train_desc)
        self.tokens_provider = TokensProvider(tokens_num)
        self.music_len = music_len
        self.split = split
        self.concpets = None
        self.tokenized_prompts = {}
        self.tokens_ids = {}
    
    def __len__(self):
        return len(self.ds)
    
    def _random_slice(self, tensor):
        n, k = tensor.shape
        
        if self.music_len <= k:
            start_col = torch.randint(0, k - self.music_len + 1, (1,)).item()
            return tensor[:, start_col:start_col + self.music_len]
        else:
            padding = torch.zeros((n, self.music_len - k), device=tensor.device)
            return torch.cat((tensor, padding), dim=1)
    
    def __getitem__(self, idx):
        row = self.ds[idx]
        path = row['encoded_path']
        if path not in self.encoded:
            self.encoded[path] = torch.load(os.path.join(self.base_dir, path)).squeeze()
        y = path.replace("\\", "").split('/')[2]
        if y not in self.tokens_ids:
            self.tokens_ids[y] = self.tokenizer.convert_tokens_to_ids(list(self.tokens_provider.get(y)))
        prompt = self.prompter.get(self.tokens_provider.get_str(y))
        # if prompt not in self.tokenized_prompts:
        #     self.tokenized_prompts[prompt] = self.tokenizer([prompt], return_tensors='pt', padding=True, add_special_tokens=False)
        return {
            'encoded_music': self._random_slice(self.encoded[path]),
            'prompt': prompt,
            'new_token_ids': self.tokens_ids[y],
            **({} if self.split == 'train' else 
                {
                    'audio': row['audio']['array']
                })
        }
    
    def _get_concepts(self):
        unique_values = set()
        def collect_unique(batch):
            unique_values.update([x.replace("\\", "").split('/')[2] for x in batch['audio_path']])
        self.ds.map(collect_unique, batched=True, batch_size=1000)
        return unique_values
    
    def get_concepts(self):
        if self.concpets is None:
            self.concpets = self._get_concepts()
        return self.concpets
    
    def get_new_tokens(self) -> set[str]:
        res = set()
        for concept in self.get_concepts():
            res.update(self.tokens_provider.get(concept))
        return res
    
    def get_new_tokens_ids(self) -> set[int]:
        return self.tokenizer.convert_tokens_to_ids(self.get_new_tokens())

In [30]:
class ConceptDataModule(L.LightningDataModule):
    def __init__(self, tokenizer, tokens_num:int=5, music_len: int = 255, batch_size: int = 5):
        super().__init__()
        self.tokens_num = tokens_num
        self.music_len = music_len
        self.batch_size = batch_size
        self.tokenizer = tokenizer

    def prepare_data(self) -> None:
        get_ds()
    
    def setup(self, stage: str):
        print(stage)
        ds = get_ds()
        self.train_ds = ConceptDataset(ds['train'], self.tokenizer, 'train', tokens_num=self.tokens_num, music_len=self.music_len)
        self.val_ds = ConceptDataset(ds['valid'], self.tokenizer, 'valid', tokens_num=self.tokens_num, music_len=self.music_len)
    
    def get_new_tokens(self)->list[str]:
        new_tokens = self.train_ds.get_new_tokens()
        new_tokens.update(self.val_ds.get_new_tokens())
        return list(new_tokens)
    
    def collate_fn(self, batch):
        prompts = [item['prompt'] for item in batch]
        tokenized_prompts = self.tokenizer(prompts, return_tensors='pt', padding=True, add_special_tokens=False)
        for i, item in enumerate(batch):
            item['tokenized_prompt'] = {
                'input_ids': tokenized_prompts['input_ids'][i],
                'attention_mask': tokenized_prompts['attention_mask'][i],
            }
        collated_batch = default_collate(batch)
        collated_batch['batch_tokens'] = torch.unique(torch.cat(collated_batch['new_token_ids']))
        return collated_batch
    
    def train_dataloader(self) -> DataLoader:
        return DataLoader(self.train_ds, batch_size=self.batch_size, collate_fn=self.collate_fn)
    
    def val_dataloader(self) -> DataLoader:
        return DataLoader(self.val_ds, batch_size=self.batch_size, collate_fn=self.collate_fn)

dm = ConceptDataModule(tokenizer)
dm.setup('a')

# if tokenizer.add_tokens(dm.get_new_tokens()) > 0:
#     text_model.resize_token_embeddings(len(tokenizer))


dl=dm.val_dataloader()
for batch in tqdm.tqdm(dl):
    x, y, a = batch['encoded_music'], batch['prompt'], batch['audio']
dl=dm.train_dataloader()
for batch in tqdm.tqdm(dl):
    x, y = batch['encoded_music'], batch['prompt']

a


Map:   0%|          | 0/120 [00:00<?, ? examples/s]

  0%|          | 0/24 [00:01<?, ?it/s]


RuntimeError: stack expects each tensor to be equal size, but got [3839648] at entry 0 and [3839826] at entry 1

In [8]:
class TransformerTextualInversion(L.LightningModule):
    def __init__(self, text_model, tokenizer, music_model, music_model_conditioner, lamda_tokens, 
                 grad_amplify: float=10.0,
                 entropy_alpha: float=1e1,
                 ortho_alpha: float=1e-2
                 ):
        super().__init__()
        # self.save_hyperparameters()  # Saves all init arguments to the checkpoint
        self.grad_amplify = grad_amplify
        self.entropy_alpha = entropy_alpha
        self.ortho_alpha = ortho_alpha

        self.text_model = text_model
        self.tokenizer = tokenizer
        self.music_model = music_model
        self.fetch_new_tokens = lamda_tokens
        self.music_model_conditioner = music_model_conditioner

        
    def _init_text_model(self, new_tokens):
        if tokenizer.add_tokens(new_tokens) > 0:
            self.text_model.resize_token_embeddings(len(tokenizer))
        new_token_ids = tokenizer.convert_tokens_to_ids(new_tokens)
        with torch.no_grad():
            for new_token_id in new_token_ids:
                text_model.shared.weight[new_token_id] = text_model.shared.weight.mean(dim=0)
        def zero_existing_emb(grad):
            mask = torch.zeros_like(grad)
            for new_token_id in new_token_ids:
                mask[new_token_id] = self.grad_amplify
            return grad * mask

        self.text_model.shared.weight.register_hook(zero_existing_emb)
        
    def on_train_start(self):
        self._init_text_model(self.fetch_new_tokens())

    def forward(self, encoded_music, tokenized_prompt):
        mask = tokenized_prompt['attention_mask']
        with self.music_model_conditioner.autocast and torch.set_grad_enabled(True):
            x_e = self.text_model(**tokenized_prompt).last_hidden_state
        x_e = self.music_model_conditioner.output_proj(x_e.to(self.music_model_conditioner.output_proj.weight))
        x_e = (x_e * mask.unsqueeze(-1))
        with self.music_model.autocast:
            x = self.music_model.lm.compute_predictions(encoded_music, [], {'description': (x_e, mask)})
        return x

    def training_step(self, batch, batch_idx):
        music, prompt = batch['encoded_music'], batch['tokenized_prompt']
        out = self(music, prompt)
        ce_loss, _ = compute_cross_entropy(out.logits, music, out.mask)
        ortho_loss = compute_ortho_loss(self.text_model.shared.weight[batch['batch_tokens']])
        loss = self.entropy_alpha * ce_loss + self.ortho_alpha * ortho_loss
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log("ortho_loss", ortho_loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        music, prompt = batch['encoded_music'], batch['tokenized_prompt']
        out = self(music, prompt)
        val_loss, _ = compute_cross_entropy(out.logits, music, out.mask)
        self.log("val_loss", val_loss, prog_bar=True)
        return val_loss

    def configure_optimizers(self):
        # Optimizer and learning rate scheduler setup
        optimizer = Adam([self.text_model.shared.weight], lr=1e-1)
        # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
        return ([optimizer], 
                []
                )
class GenEvalCallback(L.Callback):
    def __init__(self, generation_concepts, fad, n_epochs=2):
        super().__init__()
        self.n_epochs = n_epochs
        self.concepts = generation_concepts
        self.fad = fad

    def on_validation_epoch_end(self, trainer, pl_module):
        if (trainer.current_epoch+1) % self.n_epochs == 0:
            print(f"Generation time at epoch {trainer.current_epoch + 1}")
            concept = self.concepts[0]
            response = pl_module.music_model.generate([f'In the style of {TokensProvider(5).get_str(concept)}']*3)
            for a_idx in range(response.shape[0]):
                music = response[a_idx].cpu()
                music = music/np.max(np.abs(music.numpy()))
                path = OUTPUT_PATH("textual-inversion-v3", concept, 'temp', f'music_p{a_idx}')
                audio_write(path, music, pl_module.music_model.cfg.sample_rate)
                pl_module.logger.experiment.add_audio(f"{concept} {a_idx}", music, trainer.global_step, sample_rate=pl_module.music_model.cfg.sample_rate)
            with contextlib.redirect_stdout(io.StringIO()):
                fd_score = self.fad.score(INPUT_PATH('textual-inversion-v3', 'data', 'valid', f'{concept}', 'audio'), OUTPUT_PATH("textual-inversion-v3", concept, 'temp'))
                os.remove(OUTPUT_PATH("textual-inversion-v3", concept, 'temp_fad_feature_cache.npy'))
                pl_module.log('FAD', list(fd_score.values())[0], trainer.global_step)

music_model = MusicGen.get_pretrained('facebook/musicgen-small')
music_model.set_generation_params(
	use_sampling=True,
	top_k=250,
	duration=EXAMPLES_LEN
)
text_conditioner=list(music_model.lm.condition_provider.conditioners.values())[0]
tokenizer=text_conditioner.t5_tokenizer
text_model=text_conditioner.t5
fad = FrechetAudioDistance()

dm = ConceptDataModule(tokenizer, music_len=255)
model = TransformerTextualInversion(text_model, tokenizer, music_model, text_conditioner, lambda: dm.get_new_tokens())
tb_logger = L.loggers.TensorBoardLogger(LOGS_PATH, name='textual-inversion-v3')
trainer = L.Trainer(accelerator='cpu', callbacks=[GenEvalCallback(['upbeat'], fad)], enable_checkpointing=False, logger=tb_logger)
trainer.fit(model, dm)

Using cache found in /home/mszawerda/.cache/torch/hub/harritaylor_torchvggish_master
GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/mszawerda/.cache/pypoetry/virtualenvs/musicgen-ufgTm-Qc-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
/home/mszawerda/.cache/pypoetry/virtualenvs/musicgen-ufgTm-Qc-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/utilities.py:72: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.


TrainerFn.FITTING


Map:   0%|          | 0/120 [00:00<?, ? examples/s]


  | Name                    | Type           | Params | Mode
------------------------------------------------------------------
0 | text_model              | T5EncoderModel | 109 M  | eval
1 | music_model_conditioner | T5Conditioner  | 787 K  | eval
------------------------------------------------------------------
110 M     Trainable params
0         Non-trainable params
110 M     Total params
441.664   Total estimated model params size (MB)
0         Modules in train mode
225       Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/mszawerda/.cache/pypoetry/virtualenvs/musicgen-ufgTm-Qc-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


RuntimeError: stack expects each tensor to be equal size, but got [3839648] at entry 0 and [3839826] at entry 1