In [1]:
from musicgen.model.musicgen import MusicGen
import torch.distributed as dist
from accelerate import Accelerator
import torch
from musicgen.utils.torch_utils import print_once, get_rank, get_world_size
from utils import INPUT_PATH
from torch.utils.data import TensorDataset, random_split, DataLoader
import torch
import torch.nn as nn
import tqdm

EXAMPLES_LEN = 5
BATCH_SIZE = 5
N_TOKENS = 5
dist.init_process_group(backend="gloo", init_method="file:///tmp/sharedfile", rank=0, world_size=1)
use_gpu = dist.is_available()

amp_type = 'fp16' if use_gpu else 'fp32'
accelerator = Accelerator(mixed_precision=amp_type)
DEVICE = accelerator.device
rank = get_rank()
world_size = get_world_size()
print(f'Device (rank {rank}): {DEVICE}')

  from .autonotebook import tqdm as notebook_tqdm


Device (rank 0): cuda


In [2]:
model = MusicGen.get_pretrained('facebook/musicgen-small', device=DEVICE)
model.set_generation_params(
	use_sampling=True,
	top_k=250,
	duration=EXAMPLES_LEN
)
model = accelerator.prepare(model)

=>=>=> Loading an LM and conditioner model : [facebook/musicgen-small]
[Unmatched keywards]
	condition_provider.conditioners.description.output_proj.weight
	condition_provider.conditioners.description.output_proj.bias
=>=>=> Loading a compression model : [facebook/musicgen-small]


In [3]:
from musicgen.data.audio_dataset import MonoAudioFilesDataset
ds = MonoAudioFilesDataset('/home/mszawerda/music-sae/dependencies/musicgen/example/dataset/audio')
dl = lambda x, s: DataLoader(x, batch_size=BATCH_SIZE, shuffle=s, pin_memory=True if torch.cuda.is_available() else False)
train_dl, val_dl = dl(ds, True), dl(ds, False)

->-> Searching audio files...
->-> Found 2 files.
->-> Loading audio metadata...
->-> Keep 2 files.


In [4]:
# dl = lambda x, s: DataLoader(x, batch_size=BATCH_SIZE, shuffle=s, pin_memory=True if torch.cuda.is_available() else False)
# ds=torch.load(INPUT_PATH('8bit_encoded.pt'))[:225, :, :].cpu()
# ds = TensorDataset(ds)
# train_ds, val_ds = random_split(ds, [0.8, 0.2], generator=torch.Generator().manual_seed(42))
# train_dl, val_dl = dl(train_ds, True), dl(val_ds, False)

In [5]:
class LitAutoEncoder(nn.Module):
    def __init__(self, input_dim=784, latent_dim=64, sparsity_target=0.05, sparsity_weight=0.001):
        super().__init__()

        # Encoder
        self.encoder = nn.Sequential(nn.Linear(input_dim, latent_dim), nn.ReLU())

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, input_dim),
        )
        self.sparsity_target = sparsity_target
        self.sparsity_weight = sparsity_weight

    def forward(self, x):
        z = self.encoder(x)
        return z, self.decoder(z)

4

In [23]:
hook_point = model.lm.get_submodule('transformer.layers.12.cross_attention.out_proj')
n = hook_point.in_features
sae = LitAutoEncoder(input_dim=n, latent_dim=5 * n).to(DEVICE)
sae_diff = []
bottlneck = []

def perform_sae(module, input, output):
    z, out = sae(output)
    sae_diff.append((out, output))
    bottlneck.append(z)

hook_point.register_forward_hook(perform_sae)

optimizer = torch.optim.Adam(sae.parameters(), lr=1e-3)
a_coef = 1e-3
epochs = 100
with accelerator.autocast():
    model.eval()
    conds = model.conditioner({"text_prompt": ["hello"]*2})
    conds.requite_grad = True

loss_func = lambda x, z: torch.norm(torch.nan_to_num(x[-1][0], 0.0)-torch.nan_to_num(x[-1][1], 0.0)) + a_coef*torch.norm(torch.nan_to_num(z[-1], 0.0), p=1).item()
with tqdm.tqdm(total=epochs) as pbar:
    for epoch in range(epochs):
        model.train()
        sae_diff, bottlneck, total_loss = [], [], 0
        for batch in train_dl:
            music = batch[0].to(DEVICE)
            with torch.no_grad(), accelerator.autocast():
                codec, _ = model.compression_model.encode(music)
            with accelerator.autocast():
                model.lm.compute_predictions(
                    codec, conds.detach()
                )
            loss = loss_func(sae_diff, bottlneck)
            total_loss += loss.item()
            optimizer.zero_grad()
            accelerator.backward(loss)
            optimizer.step()
        with torch.no_grad():
            model.eval()
            sae_diff, bottlneck, val_loss = [], [], 0
            for batch in val_dl:
                batch = batch[0].to(DEVICE)
                with torch.no_grad(), accelerator.autocast():
                    codec, _ = model.compression_model.encode(music)
                with accelerator.autocast():
                    model.lm.compute_predictions(
                        codec, conds
                    )
                val_loss += loss_func(sae_diff, bottlneck)
                
        pbar.set_postfix_str(f'epoch: {epoch}, loss: {total_loss:.3f} val_los::{val_loss:.3f}')
        pbar.update(1)

100%|██████████| 100/100 [00:34<00:00,  2.91it/s, epoch: 99, loss: 142.841 val_los::142.841]
