In [15]:
import os
import sys
working_dir = os.path.join(os.getcwd().split("Text2BGAudio")[0],'Text2BGAudio')
sys.path.append(working_dir)
os.chdir(working_dir)
from datasets import load_dataset
from transformers import ClapModel, ClapProcessor
import torch
import matplotlib.pyplot as plt
import torch.nn.functional as F
from collections import Counter,defaultdict
from tqdm import tqdm
from Dataset_Creation import audio_dataset
from torch.utils.data import DataLoader,Dataset
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
    print("Device Name:", torch.cuda.get_device_name(DEVICE))
else:
    DEVICE = torch.device("cpu")
    print("Device Name: CPU")


Device Name: NVIDIA GeForce RTX 4080 SUPER


In [46]:
class EmbeddedDataset(Dataset):
    def __init__(self, embedded_data):
        self.embedded_data = embedded_data
    def __len__(self):
        return len(self.embedded_data)
    def __getitem__(self, idx):
        x,y= self.embedded_data[idx]
        return x,y

criterion = torch.nn.CrossEntropyLoss()
class MLPHead(torch.nn.Module):
    def __init__(self, input_dim=512, output_dim=6, hidden_dim=256):
        super(MLPHead, self).__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, output_dim),
            torch.nn.Softmax(dim=1)
        )
    
    def forward(self, x):
        return self.mlp(x)

In [32]:
label_to_index  = {'Angry' : 0, 'Joy' : 1, 'Love' : 2, 'Sad' : 3, 'Scary' : 4, 'Surprise' : 5}
index_to_label = {v: k for k, v in label_to_index.items()}

def create_embedded_ds(ds_path,model_path,return_indexed=True):
    embedded_ds_path = ds_path.split('.')[0] + "_embedded.pt"
    if os.path.isfile(embedded_ds_path):
        print(f"Loading Embedded Dataset {os.path.basename(embedded_ds_path)}")
        embedded_data = torch.load(embedded_ds_path,weights_only=False)
    else:
        print(f"Creating Embedded Dataset {os.path.basename(embedded_ds_path)}")
        dataset = audio_dataset.AudioDataset(ds_path)
        data_loader = DataLoader(dataset, batch_size=32, shuffle=True)

        model_name = "laion/larger_clap_music"
        model = ClapModel.from_pretrained(model_name).to(DEVICE)
        processor = ClapProcessor.from_pretrained(model_name)
        model.load_state_dict(torch.load(model_path,weights_only=False)['model_state_dict'])

        embedded_data =list()
        with torch.no_grad():
            for batch in tqdm(data_loader,desc="Batches"):
                audio = batch[0]
                labels = list(batch[1])
                unique_labels = list(set(labels))
                inputs = processor(
                    text=unique_labels,
                    audios=audio.numpy(),
                    return_tensors="pt",
                    sampling_rate=48000,
                    padding=True,
                )
                inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
                outputs = model(**inputs)
                audio_embeds = outputs.audio_embeds
                embedded_data.extend([(audio_embed.cpu().detach(), label) for audio_embed, label in zip(audio_embeds, labels)])
        torch.save(embedded_data,embedded_ds_path)
    if return_indexed:
        embedded_data_indexed = [(audio_embed.cpu().detach(), label_to_index[label]) for audio_embed, label in embedded_data]
        return EmbeddedDataset(embedded_data_indexed)

In [42]:
model_path = r"CLAP\models\clap_fine_tunned_BatchSize_32_LR_1e-05_Epochs_50_LOSS_27.06.pt"
train_dataset_path = r"_Data\Music\music_dataset_train_size7507.pt"
val_dataset_path = r"_Data\Music\music_dataset_test_size39.pt"
train_dataset = create_embedded_ds(train_dataset_path,model_path)
val_dataset = create_embedded_ds(val_dataset_path,model_path)

Loading Embedded Dataset music_dataset_train_size7507_embedded.pt
Creating Embedded Dataset music_dataset_test_size39_embedded.pt


Batches: 100%|██████████| 2/2 [00:00<00:00,  2.17it/s]


In [None]:
batch_size = 32
epochs = 100
train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_data_loader = DataLoader(val_dataset, batch_size=len(val_dataset), shuffle=True)

classification_head = MLPHead().to(DEVICE)
optimizer = torch.optim.Adam(classification_head.parameters(), lr=5e-4)

train_acc = list()
val_acc = list()
for e in range(epochs):
    train_correct = 0
    for batch in train_data_loader:
        audio_embeds, labels = batch
        audio_embeds = audio_embeds.to(DEVICE)
        labels = labels.to(DEVICE)
        
        optimizer.zero_grad()
        outputs = classification_head(audio_embeds)
        est_classification = torch.argmax(outputs, dim=1)
        train_correct += torch.sum(est_classification == labels).item()
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    train_acc.append(train_correct/len(train_dataset))
    with torch.no_grad():
        val_correct = 0
        for batch in val_data_loader:
            audio_embeds, labels = batch
            # audio_embeds = F.normalize(audio_embeds, p=2, dim=1)
            audio_embeds = audio_embeds.to(DEVICE)
            labels = labels.to(DEVICE)
            outputs = classification_head(audio_embeds)
            est_classification = torch.argmax(outputs, dim=1)
            val_correct += torch.sum(est_classification == labels).item()
        val_acc.append(val_correct/len(val_dataset))
    print(f"Epoch {e+1}/{epochs}, Loss: {loss.item()}, Train Acc : {train_acc[-1]} , Val Acc : {val_acc[-1]}")

Epoch 1/100, Loss: 1.1938018798828125, Train Acc : 0.7485013986945518 , Val Acc : 0.6923076923076923
Epoch 2/100, Loss: 1.2446573972702026, Train Acc : 0.815771946183562 , Val Acc : 0.6666666666666666
Epoch 3/100, Loss: 1.156229853630066, Train Acc : 0.8242973224990009 , Val Acc : 0.6666666666666666
Epoch 4/100, Loss: 1.2113853693008423, Train Acc : 0.8266950845877181 , Val Acc : 0.6666666666666666
Epoch 5/100, Loss: 1.2848998308181763, Train Acc : 0.8290928466764353 , Val Acc : 0.6666666666666666
Epoch 6/100, Loss: 1.0612423419952393, Train Acc : 0.8320234447848674 , Val Acc : 0.6666666666666666
Epoch 7/100, Loss: 1.1632806062698364, Train Acc : 0.8348208338883708 , Val Acc : 0.6923076923076923
Epoch 8/100, Loss: 1.1985416412353516, Train Acc : 0.835753296922872 , Val Acc : 0.6923076923076923
Epoch 9/100, Loss: 1.054551124572754, Train Acc : 0.8380178500066604 , Val Acc : 0.6923076923076923
Epoch 10/100, Loss: 1.2772855758666992, Train Acc : 0.8390835220460903 , Val Acc : 0.6153846153