In [2]:
import sys
import os
import torchvision.datasets

sys.path.append(os.path.abspath("src"))

from audiocraft.models import MusicGen
import torch
from tools.project import INPUT_PATH, MODELS_PATH
from src.data import TextConcepts, TokensProvider
from src.losses import compute_cross_entropy
import tqdm
import pytorch_lightning as L
from src.model import TIMusicGen, ModelConfig
from torch.optim import Adam
from src.clip_textual_inversion import ConceptDataModule, ClipProjector
from torch.utils.data import Dataset, DataLoader
from src.img_feature_extractor import LitMNISTModel
from torchvision import transforms

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Number of trailing prompt positions that `forward` fills with projector outputs.
TOKENS_NUM = 5
# os.cpu_count() returns None when the CPU count cannot be determined; fall back
# to 1 so the expression does not raise TypeError.
NUM_WORKERS = int((os.cpu_count() or 1) * 0.75)


In [2]:
# Load the pretrained MusicGen checkpoint and configure sampling-based generation.
music_model = MusicGen.get_pretrained('facebook/musicgen-small')
music_model.set_generation_params(
	use_sampling=True,
	top_k=250,
	duration=5
)
# Select the text conditioner by its attribute name instead of by dict position.
# `forward` below feeds its conditioning tensors under this same 'description'
# key, while list(...)[0] would silently pick whichever conditioner happens to be
# first (presumably a problem for checkpoints with more than one conditioner).
text_conditioner = music_model.lm.condition_provider.conditioners['description']
tokenizer = text_conditioner.t5_tokenizer
text_model = text_conditioner.t5



In [118]:
# Tokenize a toy prompt (no special tokens) to inspect the raw T5 ids and mask.
tokenized_prompt = tokenizer(
	['Hello world'],
	add_special_tokens=False,
	padding=True,
	return_tensors="pt",
)
tokenized_prompt

{'input_ids': tensor([[8774,  296]]), 'attention_mask': tensor([[1, 1]])}

In [119]:
# Gradient sanity check: a random leaf tensor used as `inputs_embeds` (shaped like
# the prompt's token embeddings) should receive gradients through the T5 encoder.
prompt_embeddings = text_model.shared.weight[tokenized_prompt['input_ids']]
custom = torch.rand_like(prompt_embeddings, requires_grad=True)
encoder_output = text_model(
	inputs_embeds=custom,
	attention_mask=tokenized_prompt['attention_mask'],
)
embeds = encoder_output.last_hidden_state
loss = embeds.norm()
loss.backward()

In [120]:
# Display the gradient that `loss.backward()` accumulated on the leaf tensor `custom`;
# a populated tensor shows gradients flow back to `inputs_embeds`.
custom.grad

tensor([[[-0.0575, -0.0300, -0.2495,  ..., -0.1649,  0.0237, -0.0073],
         [ 0.1011,  0.0372, -0.0026,  ...,  0.1519, -0.0089, -0.0003]]])

In [3]:
# import clip
# model, clip_preprocess = clip.load("ViT-B/32")

100%|███████████████████████████████████████| 338M/338M [00:03<00:00, 92.1MiB/s]


In [3]:
class FilteredCIFAR10(Dataset):
	"""Subset of ``torchvision.datasets.MNIST`` exposing only samples with one label.

	NOTE: despite its name this wraps MNIST, not CIFAR-10. The class name and the
	``cifar10`` attribute are kept unchanged so existing callers keep working.

	Args:
		root: dataset root directory, forwarded to ``MNIST``.
		train: forwarded to ``MNIST`` (selects the train vs. test split).
		transform: per-image transform, forwarded to ``MNIST``.
		download: forwarded to ``MNIST``.
		target_label: only samples whose label equals this value are exposed.
	"""

	def __init__(self, root, train=True, transform=None, download=False, target_label: int=3):
		self.cifar10 = torchvision.datasets.MNIST(
			root=root,
			train=train,
			transform=transform,
			download=download
		)
		# Positions in the wrapped dataset whose label is `target_label`, in order.
		self.indices = self._matching_indices(self.cifar10, target_label)

	@staticmethod
	def _matching_indices(dataset, target_label):
		"""Return, in ascending order, the indices of ``dataset`` whose label equals ``target_label``."""
		targets = getattr(dataset, 'targets', None)
		if targets is not None:
			# Compare the stored label array directly. Iterating the dataset instead
			# would load and transform every image only to discard it.
			matches = torch.as_tensor(targets) == target_label
			return torch.nonzero(matches).flatten().tolist()
		# Fallback for datasets that do not expose a `targets` attribute.
		return [i for i, (_, label) in enumerate(dataset) if label == target_label]

	def __len__(self):
		return len(self.indices)

	def __getitem__(self, idx):
		"""Return ``(image, label)`` for the ``idx``-th matching sample."""
		image, label = self.cifar10[self.indices[idx]]
		return image, label


def compute_and_save_embeddings(loader, model, device, save_path):
	"""Embed every batch from ``loader`` with ``model.get_features`` and save the result.

	Each embedding row is L2-normalised. The saved object is the tuple
	``(embeddings, labels)``: ``embeddings`` is a float32 CPU tensor with one row per
	sample, and ``labels`` is the concatenation of the loader's label batches (on CPU).
	A one-line summary of the saved shapes is printed.

	Args:
		loader: iterable of ``(images, labels)`` batches.
		model: object exposing ``eval()`` and ``get_features(images)``; it is left in eval mode.
		device: device the images are moved to before encoding.
		save_path: destination passed to ``torch.save``.

	Raises:
		ValueError: if ``loader`` yields no batches.
	"""
	embedding_batches = []
	label_batches = []

	model.eval()
	with torch.no_grad():
		for images, labels in tqdm.tqdm(loader):
			# Encode images with the feature extractor.
			embeddings = model.get_features(images.to(device))
			# L2-normalise each row. F.normalize clamps the norm from below, so an
			# all-zero feature vector becomes zeros instead of NaNs.
			embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
			embedding_batches.append(embeddings.cpu())
			# Labels are never passed to the model, so there is no need to move them to `device`.
			label_batches.append(labels.cpu())

	if not embedding_batches:
		raise ValueError(f"loader yielded no batches; nothing to save to {save_path}")

	all_embeddings = torch.cat(embedding_batches, dim=0).float()
	all_labels = torch.cat(label_batches, dim=0)

	torch.save((all_embeddings, all_labels), save_path)
	print(f"Saved: {save_path} - Embeddings shape: {all_embeddings.shape}, Labels shape: {all_labels.shape}")
# Load the pretrained MNIST feature extractor used by `embeds_for_label`.
# map_location lets weights saved from a GPU load on a CPU-only machine;
# weights_only restricts unpickling to tensors/plain containers, which is all a
# state dict passed to load_state_dict needs.
# NOTE(review): 'minist' may be a typo for 'mnist'; left unchanged because it must
# match the directory on disk.
model = LitMNISTModel()
model.load_state_dict(
	torch.load(
		MODELS_PATH('minist', "mnist_feature_extractor_weights.pth"),
		map_location=DEVICE,
		weights_only=True,
	)
)
model = model.to(DEVICE)
def embeds_for_label(num:int):
	"""Embed the MNIST images of digit ``num`` (``train=True`` and ``train=False`` splits) and save them.

	Uses the module-level feature extractor ``model`` on ``DEVICE``. For each split,
	``compute_and_save_embeddings`` writes ``INPUT_PATH('cifar', f'{name}_{split}_embeds.pt')``
	with ``split`` being ``'train'`` or ``'val'``. ``name`` is ``'cat'`` for 3 and ``'dog'``
	for 5 (these file names look like leftovers from CIFAR-10, where those are the cat/dog
	class ids). Any other value uses ``f'label{num}'``, so other digits no longer
	overwrite the ``cat`` files.
	"""
	transform = transforms.Compose([
		transforms.Resize((32, 32)),
		transforms.ToTensor(),
		# presumably the standard MNIST mean/std — TODO confirm it matches the extractor's training
		transforms.Normalize((0.1307,), (0.3081,))
	])
	name = {3: 'cat', 5: 'dog'}.get(num, f'label{num}')
	root = INPUT_PATH('cifar')
	for split, is_train in (('train', True), ('val', False)):
		dataset = FilteredCIFAR10(root, train=is_train, transform=transform, target_label=num)
		loader = DataLoader(dataset, batch_size=64, shuffle=False)
		# Use DEVICE (where `model` was moved) rather than a hard-coded 'cuda'.
		compute_and_save_embeddings(loader, model, DEVICE, INPUT_PATH('cifar', f'{name}_{split}_embeds.pt'))
# Produce and save embeddings for digit 3 first, then digit 5.
for digit in (3, 5):
	embeds_for_label(digit)

100%|██████████| 96/96 [00:01<00:00, 75.82it/s]


Saved: /home/ubuntu/musical-generative-models-conditioning/data/input/cifar/cat_train_embeds.pt - Embeddings shape: torch.Size([6131, 64]), Labels shape: torch.Size([6131])


100%|██████████| 16/16 [00:00<00:00, 89.90it/s]


Saved: /home/ubuntu/musical-generative-models-conditioning/data/input/cifar/cat_val_embeds.pt - Embeddings shape: torch.Size([1010, 64]), Labels shape: torch.Size([1010])


100%|██████████| 85/85 [00:00<00:00, 87.46it/s]


Saved: /home/ubuntu/musical-generative-models-conditioning/data/input/cifar/dog_train_embeds.pt - Embeddings shape: torch.Size([5421, 64]), Labels shape: torch.Size([5421])


100%|██████████| 14/14 [00:00<00:00, 89.35it/s]

Saved: /home/ubuntu/musical-generative-models-conditioning/data/input/cifar/dog_val_embeds.pt - Embeddings shape: torch.Size([892, 64]), Labels shape: torch.Size([892])





In [3]:
# Build the concept dataset and data module, then display one training batch.
# NOTE(review): TokensProvider(6) while TOKENS_NUM = 5 in the config cell —
# confirm whether these are meant to agree.
# NOTE(review): 'a' is not one of Lightning's usual stage names
# ('fit', 'validate', 'test', 'predict'); confirm what ConceptDataModule.setup expects.
concepts_db = TextConcepts.from_musicgen(
	music_model,
	TokensProvider(6),
	['8bit', 'metal'],
)
dm = ConceptDataModule(concepts_db, 10)
dm.setup('a')
first_train_batch = next(iter(dm.train_dataloader()))
first_train_batch

{'img': tensor([[-0.0040,  0.0136, -0.0177,  ...,  0.0831, -0.0048,  0.0246],
         [ 0.0096, -0.0562, -0.0107,  ...,  0.0546, -0.0154, -0.0437],
         [ 0.0214, -0.0078, -0.0205,  ...,  0.0582, -0.0021,  0.0043],
         ...,
         [ 0.0159, -0.0214, -0.0319,  ...,  0.0668, -0.0015,  0.0104],
         [-0.0030, -0.0091, -0.0344,  ...,  0.0593,  0.0268,  0.0196],
         [-0.0011, -0.0177, -0.0249,  ...,  0.0495, -0.0052,  0.0117]]),
 'encoded_music': tensor([[[1668, 1288,  433,  ...,  946, 1404, 1077],
          [1714, 1751,  426,  ..., 1462,  637, 1099],
          [1530,  360, 1695,  ...,  711, 1070,  453],
          [ 814,  745,  204,  ...,  599,  962, 1134]],
 
         [[1966, 1808, 1449,  ..., 1608,  941,  380],
          [1664, 1914,  274,  ...,  536,  676,  189],
          [ 141, 1907,  863,  ..., 1537,  319, 1342],
          [1706,  196, 1487,  ...,  380,  391, 1098]],
 
         [[1964,  839, 1095,  ..., 1474,   87, 1767],
          [1166, 1166, 1029,  ...,  872,  

In [170]:

def forward(batch):
	"""Predict MusicGen LM outputs for ``batch`` conditioned on image-derived pseudo-tokens.

	``batch['prompt']`` is tokenized (no special tokens) and embedded with the T5
	input embedding table. The last ``TOKENS_NUM`` real (non-padding) positions of
	each prompt are overwritten with ``projector(batch['img'])``. The sequence is then:
	encoded by T5, passed through the conditioner's ``output_proj``, and masked.
	Finally it is given to ``music_model.lm.compute_predictions`` as the
	``'description'`` condition for ``batch['encoded_music']``.

	Args:
		batch: dict with ``'prompt'`` (tokenizer input, presumably a list of str),
			``'img'`` (projector input) and ``'encoded_music'`` (codes passed to
			``compute_predictions``).

	Returns:
		The result of ``music_model.lm.compute_predictions`` (the commented-out
		training loop reads ``.logits`` and ``.mask`` from it).

	Raises:
		ValueError: if a prompt has fewer than ``TOKENS_NUM`` real tokens.
	"""
	embedding_table = text_model.shared.weight
	device = embedding_table.device
	tokenized = tokenizer(
		batch['prompt'], return_tensors="pt", padding=True, add_special_tokens=False
	)
	# The tokenizer returns CPU tensors, and the mask is later combined with T5
	# outputs, so keep both ids and mask on the embedding table's device.
	input_ids = tokenized['input_ids'].to(device)
	mask = tokenized['attention_mask'].to(device)

	# Only the projector is optimised, so treat the base prompt embeddings as
	# constants instead of building a graph back into the T5 embedding table.
	with torch.no_grad():
		text_with_clip = embedding_table[input_ids]
	batch_size, seq_len, embed_dim = text_with_clip.shape

	# With padding=True, shorter prompts are padded, so the last TOKENS_NUM columns
	# may be padding. Locate each row's last real token instead (assuming real tokens
	# are contiguous; works for left or right padding) and fill the TOKENS_NUM
	# slots that end there.
	if bool((mask.sum(dim=1) < TOKENS_NUM).any()):
		raise ValueError(f"every prompt must have at least {TOKENS_NUM} tokens")
	# argmax returns the first maximum, i.e. the last 1 of the un-flipped mask.
	last_real = seq_len - 1 - mask.flip(dims=[1]).argmax(dim=1)
	offsets = torch.arange(TOKENS_NUM, device=device)
	positions = (last_real - (TOKENS_NUM - 1)).unsqueeze(1) + offsets  # (batch, TOKENS_NUM)
	rows = torch.arange(batch_size, device=device).unsqueeze(1)
	pseudo_tokens = projector(batch['img']).view(-1, TOKENS_NUM, embed_dim)
	text_with_clip[rows, positions] = pseudo_tokens.to(device=device, dtype=text_with_clip.dtype)

	# Enter both context managers. Use a comma, not `and`: `a and b` evaluates to
	# `b` alone, which would skip autocast.
	with text_conditioner.autocast, torch.set_grad_enabled(True):
		text_emb = text_model(inputs_embeds=text_with_clip, attention_mask=mask).last_hidden_state
	text_emb = text_conditioner.output_proj(text_emb.to(text_conditioner.output_proj.weight))
	text_emb = text_emb * mask.unsqueeze(-1)
	with music_model.autocast:
		return music_model.lm.compute_predictions(batch['encoded_music'], [], {'description': (text_emb, mask)})
# Projector whose output fills the pseudo-token embedding slots in `forward`;
# the optimizer updates only the projector's parameters.
projector = ClipProjector()
optimizer = Adam(params=projector.parameters(), lr=1e-3)
epochs = 10
# for epoch in range(epochs):
# 	total_loss, num_batches = 0, len(train_dl)
# 	for batch in tqdm.tqdm(train_dl):
# 		optimizer.zero_grad()
# 		out = forward(batch)
# 		loss, _ = compute_cross_entropy(out.logits, batch['encoded_music'], out.mask)
# 		loss.backward()
# 		optimizer.step()
# 		total_loss += loss.item()
# 
# 	with torch.no_grad():
# 		total_val_loss, val_num_batches = 0.0, len(val_dl)
# 		for val_batch in tqdm.tqdm(val_dl):
# 			val_out = forward(val_batch)
# 			val_loss, _ = compute_cross_entropy(val_out.logits, val_batch['encoded_music'], val_out.mask)
				
		

  0%|          | 0/10 [00:00<?, ?it/s]
  0%|          | 0/1000 [00:00<?, ?it/s][A
  0%|          | 1/1000 [00:08<2:16:43,  8.21s/it][A[E thread_pool.cpp:110] Exception in thread pool task: mutex lock failed: Invalid argument
[E thread_pool.cpp:110] Exception in thread pool task: mutex lock failed: Invalid argument
[E thread_pool.cpp:110] Exception in thread pool task: mutex lock failed: Invalid argument
[E thread_pool.cpp:110] Exception in thread pool task: mutex lock failed: Invalid argument
[E thread_pool.cpp:110] Exception in thread pool task: mutex lock failed: Invalid argument
[E thread_pool.cpp:110] Exception in thread pool task: mutex lock failed: Invalid argument
[E thread_pool.cpp:110] Exception in thread pool task: mutex lock failed: Invalid argument
  0%|          | 1/1000 [00:11<3:04:18, 11.07s/it]
  0%|          | 0/10 [00:11<?, ?it/s]


KeyboardInterrupt: 