In [2]:
from utils import INPUT_PATH
from audiocraft.models import MusicGen
from audiocraft.modules.conditioners import ConditioningAttributes
from torch.utils.data import TensorDataset, random_split, DataLoader
import torch
import torch.nn as nn
import tqdm

# Notebook-wide configuration.
EXAMPLES_LEN = 5  # passed as `duration` to model.set_generation_params
BATCH_SIZE = 5    # batch size used by the DataLoaders
N_TOKENS = 5      # not referenced in the visible cells
# Prefer CUDA when a GPU is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# Load the pretrained MusicGen checkpoint onto DEVICE, then configure
# generation: sampling enabled, top-k of 250, clip duration EXAMPLES_LEN.
model = MusicGen.get_pretrained('facebook/musicgen-small', device=DEVICE)
model.set_generation_params(use_sampling=True, top_k=250, duration=EXAMPLES_LEN)



In [4]:
def dl(x, s):
    """Wrap dataset `x` in a DataLoader of BATCH_SIZE items; `s` toggles shuffling.

    Memory pinning is turned on only when CUDA is available.
    """
    return DataLoader(
        x,
        batch_size=BATCH_SIZE,
        shuffle=s,
        pin_memory=torch.cuda.is_available(),
    )


# Load the saved tensor, keep the first 225 entries along dim 0, move it to
# the CPU and wrap it so each dataset item is a 1-tuple holding one entry.
ds = TensorDataset(torch.load(INPUT_PATH('8bit_encoded.pt'))[:225, :, :].cpu())

# Seeded 80/20 train/validation split, so the partition is reproducible.
split_rng = torch.Generator().manual_seed(42)
train_ds, val_ds = random_split(ds, [0.8, 0.2], generator=split_rng)

# Only the training loader shuffles.
train_dl = dl(train_ds, True)
val_dl = dl(val_ds, False)

In [5]:
class LitAutoEncoder(nn.Module):
    """Autoencoder with a Linear+ReLU encoder and a single Linear decoder.

    Args:
        input_dim: Size of the last dimension of the input (and of the reconstruction).
        latent_dim: Size of the latent code produced by the encoder.
        sparsity_target: Stored on the instance; not used by ``forward``.
        sparsity_weight: Stored on the instance; not used by ``forward``.
    """

    def __init__(self, input_dim=784, latent_dim=64, sparsity_target=0.05, sparsity_weight=0.001):
        super().__init__()

        # Encoder: project to the latent space, then clamp negatives with ReLU.
        to_latent = nn.Linear(input_dim, latent_dim)
        self.encoder = nn.Sequential(to_latent, nn.ReLU())

        # Decoder: one linear map back to the input space, no activation.
        from_latent = nn.Linear(latent_dim, input_dim)
        self.decoder = nn.Sequential(from_latent)

        # Sparsity hyper-parameters are only recorded here for callers to read.
        self.sparsity_target = sparsity_target
        self.sparsity_weight = sparsity_weight

    def forward(self, x):
        """Return ``(latent, reconstruction)`` for input ``x``."""
        latent = self.encoder(x)
        reconstruction = self.decoder(latent)
        return latent, reconstruction

In [6]:
# Train a sparse autoencoder (SAE) on the activations emitted by one
# cross-attention output projection of the MusicGen language model.
hook_point = model.lm.get_submodule('transformer.layers.12.cross_attention.out_proj')
# The forward hook receives the layer's *output*, so the SAE width must match
# the projection's output size.
n = hook_point.out_features
sae = LitAutoEncoder(input_dim=n, latent_dim=5 * n).to(DEVICE)

# Each list holds only the result of the most recent hooked forward pass:
#   sae_diff[-1]  -> (reconstruction, target activation)
#   bottlneck[-1] -> latent code
sae_diff = []
bottlneck = []

a_coef = 1e-3  # weight of the L1 sparsity penalty on the latent code
epochs = 100


def perform_sae(module, input, output):
    """Forward hook: run the SAE on the hooked layer's output and record the result.

    Returns None, so the layer's output is passed on unchanged.
    """
    # Detach so loss.backward() stops at the SAE instead of propagating into
    # the language model, whose parameters are not being optimised here.
    target = output.detach()
    z, out = sae(target)
    # Overwrite rather than append: keeping every batch would retain every
    # autograd graph for the whole epoch.
    sae_diff[:] = [(out, target)]
    bottlneck[:] = [z]


# Re-running this cell must not stack hooks: drop the one registered by a
# previous run (if any) before registering a fresh one.
_previous_handle = globals().get("sae_hook_handle")
if _previous_handle is not None:
    _previous_handle.remove()
sae_hook_handle = hook_point.register_forward_hook(perform_sae)


def _run_lm(codes):
    """Run one language-model forward pass on `codes` so the hook fires."""
    # NOTE(review): a single ConditioningAttributes is passed regardless of the
    # batch size — confirm this is what compute_predictions expects.
    with model.autocast:
        model.lm.compute_predictions(
            codes, [ConditioningAttributes(text={"description": "Amazing metal music"})]
        )


def _sae_loss():
    """Reconstruction norm plus weighted L1 norm of the latest recorded latent."""
    out, target = sae_diff[-1]
    return torch.norm(out - target) + a_coef * torch.norm(bottlneck[-1], p=1)


optimizer = torch.optim.Adam(sae.parameters(), lr=1e-3)
with tqdm.tqdm(total=epochs) as pbar:
    for epoch in range(epochs):
        # --- training pass ---
        total_loss = 0
        for batch in train_dl:
            batch = batch[0].to(DEVICE)  # TensorDataset yields 1-tuples
            _run_lm(batch)
            loss = _sae_loss()
            total_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # --- validation pass (no gradients) ---
        val_loss = 0
        with torch.no_grad():
            for batch in val_dl:
                batch = batch[0].to(DEVICE)  # same batch layout as the training loader
                _run_lm(batch)
                # .item() on the whole loss so val_loss stays a Python float.
                val_loss += _sae_loss().item()

        pbar.set_postfix_str(f'epoch: {epoch}, loss: {total_loss:.3f} val_los::{val_loss:.3f}')
        pbar.update(1)

  0%|          | 0/100 [01:05<?, ?it/s]


KeyboardInterrupt: 