In [26]:
import torch
import torch.nn as nn
from huggingface_hub import notebook_login
# notebook_login()

In [27]:
#Hyperparameters
from dataclasses import dataclass
@dataclass
class ModelArgs:
    """Hyper-parameters shared by every cell of this notebook.

    Every entry carries a type annotation, so ``@dataclass`` turns it into a real
    field. Class-level access (``ModelArgs.lr``, ``ModelArgs.device`` ... as used
    throughout the notebook) keeps exactly the same values. In addition,
    ``ModelArgs(lr=1e-4)`` can now override individual values; without
    annotations the decorator generated no fields at all.
    """

    epochs: int = 30
    text_embeddings: int = 768       # input width of the text projection head
    audio_embeds: int = 2048         # input width of the audio projection head
    block_size: int = 100            # tokenizer max_length for transcripts
    # NOTE(review): the DataLoader cell hard-codes batch_size=8 — confirm which value is intended.
    batch_size: int = 32
    lr: float = 4e-4                 # base learning rate handed to Adam
    device: str = 'cuda:0'
    # Audio front-end. SAMPLING_RATE, window_size, hop_size, mel_bins, fmin and fmax
    # are passed to the audio model's constructor.
    # NOTE(review): nothing in this notebook resamples the waveforms, so SAMPLING_RATE
    # should match the dataset audio (GigaSpeech is reportedly 16 kHz) — confirm.
    SAMPLING_RATE: int = 44100
    N_MELS: int = 64
    max_t: int = 500
    n_channels: int = N_MELS
    window_size: int = 1024
    hop_size: int = 320
    mel_bins: int = N_MELS
    fmin: int = 50
    fmax: int = 8000
    output_embeddings: int = 1024    # shared embedding width after projection
    # Per-component learning rates.
    head_lr: float = 1e-3
    audio_encoder_lr: float = 1e-4
    text_encoder_lr: float = 1e-5

In [28]:
# From here on, tensors created without an explicit device land on ModelArgs.device.
torch.set_default_device(device=ModelArgs.device)

In [None]:
# !pip install datasets
# !pip install tqdm
import wandb
# Use the Python API rather than the `!wandb login` shell prompt. It reuses an existing
# WANDB_API_KEY or cached credentials and prompts in the notebook only when none are found.
wandb.login()

In [None]:
# Read the Hugging Face token from the environment (e.g. `export HF_TOKEN=...`) instead
# of hard-coding it in the notebook. When HF_TOKEN is unset, token=None lets `datasets`
# fall back to credentials cached by `huggingface-cli login` / notebook_login().
import os
from datasets import load_dataset

HF_TOKEN = os.environ.get("HF_TOKEN")
gs = load_dataset("speechcolab/gigaspeech", "xs", token=HF_TOKEN)



In [8]:
# !git clone https://github.com/qiuqiangkong/audioset_tagging_cnn.git

In [5]:
# CHECKPOINT_PATH="Cnn14_mAP=0.431.pth"
# !wget -O $CHECKPOINT_PATH https://zenodo.org/record/3987831/files/Cnn14_mAP%3D0.431.pth?download=1
# Name of the audio architecture class to instantiate; the model-loading cell resolves
# it among the names imported from `pytorch.models`.
model_type="Cnn14"

In [None]:
# Load the BERT tokenizer and encoder. The access token comes from the environment and
# is never hard-coded in the notebook.
import os
from transformers import BertTokenizer, BertModel

HF_TOKEN = os.environ.get("HF_TOKEN")
tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased", token=HF_TOKEN)
text_model = BertModel.from_pretrained("google-bert/bert-base-uncased", token=HF_TOKEN)

# Only add a pad token when the tokenizer lacks one. A genuinely new token would also
# need its own row in the embedding matrix, hence the resize.
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': "[PAD]"})
    text_model.resize_token_embeddings(len(tokenizer))

In [31]:
# Fine-tune BERT end-to-end: mark every text-encoder parameter as trainable.
# (Trailing ';' suppresses the module repr that requires_grad_ returns.)
text_model.requires_grad_(True);

In [32]:
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

class AudioTextDataset(Dataset):
    """Pair each raw waveform with its lower-cased, tokenized transcript.

    Args:
        dataset: indexable collection whose items expose ``item["audio"]["array"]``
            (waveform samples) and ``item["text"]`` (transcript).
        tokenizer: Hugging Face style tokenizer returning ``input_ids`` and
            ``attention_mask`` tensors.
        max_length: every transcript is padded / truncated to this many tokens.
    """

    def __init__(self, dataset, tokenizer, max_length=ModelArgs.block_size):
        self.dataset, self.tokenizer, self.max_length = dataset, tokenizer, max_length

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        record = self.dataset[idx]
        # The audio entry already holds the samples themselves, not a file path.
        waveform = torch.tensor(record["audio"]["array"], dtype=torch.float32).squeeze(0)
        encoded = self.tokenizer(
            record["text"].lower(),
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        # The tokenizer returns a leading batch dim of 1; drop it so the collate fn can stack.
        return {
            "audio": waveform,
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
        }

In [33]:
def collate_fn(batch):
    """Merge a list of dataset items into one batch dict.

    Waveforms of different lengths are zero-padded (at the end) up to the longest
    clip in the batch. Token ids and attention masks are stacked as-is, which
    requires them to share one length (presumably guaranteed by padding to
    max_length in the dataset).
    """
    padded_audio = pad_sequence([sample["audio"] for sample in batch], batch_first=True)
    stacked = {
        key: torch.stack([sample[key] for sample in batch])
        for key in ("input_ids", "attention_mask")
    }
    return {"audio": padded_audio, **stacked}

In [None]:
# Show the installed PyTorch version (handy when reproducing results).
torch.__version__

In [34]:
# Create the dataset
train_dataset = AudioTextDataset(gs["train"], tokenizer)
# val_dataset = AudioTextDataset(dataset["validation"], tokenizer)
generator = torch.Generator(device=ModelArgs.device)
# Create the DataLoader

train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn, generator=generator)
# val_dataloader = DataLoader(val_dataset, batch_size=8, shuffle=False)

In [None]:
!pip install torchlibrosa

In [None]:
import os
import sys
# Make the audioset_tagging_cnn helper packages importable.
sys.path.insert(1, os.path.join(sys.path[0], '../utils'))
import numpy as np
import argparse
import librosa
import matplotlib.pyplot as plt
import torch

# from utilities import create_folder, get_filename
from pytorch.models import *
from pytorch.pytorch_utils import move_data_to_device
import utils.config


# Resolve the architecture class by name without eval(): the star import above puts the
# model classes (e.g. Cnn14) into this namespace's globals.
Model = globals()[model_type]
# classes_num must match the classifier head stored in the checkpoint, because
# load_state_dict below is strict (presumably 527 = the AudioSet label count).
model = Model(sample_rate=ModelArgs.SAMPLING_RATE, window_size=ModelArgs.window_size,
    hop_size=ModelArgs.hop_size, mel_bins=ModelArgs.mel_bins, fmin=ModelArgs.fmin, fmax=ModelArgs.fmax,
    classes_num=527)

# The checkpoint location can be overridden with CNN14_CHECKPOINT instead of editing code;
# the default is the original local path.
CHECKPOINT_PATH = os.environ.get(
    "CNN14_CHECKPOINT",
    '/mnt/c/Users/yuvra/OneDrive/Desktop/Work/pytorch/Paper-Replications/CLAP/audioset_tagging_cnn/Cnn14_mAP=0.431.pth',
)
checkpoint = torch.load(CHECKPOINT_PATH, map_location=ModelArgs.device)
model.load_state_dict(checkpoint['model'])

In [38]:
# Fine-tune the pretrained audio model end-to-end: every parameter becomes trainable.
# NOTE(review): this also unfreezes any parameters the model may have frozen on purpose — confirm.
# (Trailing ';' suppresses the module repr that requires_grad_ returns.)
model.requires_grad_(True);

In [39]:
class Projection(nn.Module):
    """Residual two-layer projection head: linear -> GELU -> linear -> dropout, then LayerNorm.

    Args:
        d_in: size of the last dimension of the input.
        d_out: size of the projected embedding.
        p: dropout probability applied to the second linear layer's output.
    """

    def __init__(self, d_in: int, d_out: int, p: float=0.5) -> None:
        super().__init__()
        self.linear1 = nn.Linear(d_in, d_out, bias=False)
        self.linear2 = nn.Linear(d_out, d_out, bias=False)
        self.layer_norm = nn.LayerNorm(d_out)
        self.drop = nn.Dropout(p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` (last dim ``d_in``) to a layer-normalised tensor with last dim ``d_out``."""
        embed1 = self.linear1(x)
        # Use nn.functional explicitly: `F` was never imported in this notebook and only
        # worked because `from pytorch.models import *` happened to leak it.
        embed2 = self.drop(self.linear2(nn.functional.gelu(embed1)))
        # Residual connection around the second layer, then normalise.
        embeds = self.layer_norm(embed1 + embed2)
        return embeds

In [40]:
class CLSBERT:
    """Sentence embedding read from sequence position 0 (the [CLS] slot for BERT tokenizers)."""

    def embedding(self, input_ids, attention_mask, model):
        """Run ``model`` on the batch and return its hidden state at position 0.

        Assumes ``last_hidden_state`` is (batch, seq, hidden), which makes the result
        (batch, hidden) — TODO confirm for non-BERT encoders.
        """
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        hidden = outputs.last_hidden_state
        return hidden[:, 0, :]
    

In [None]:
# Self-contained tokenizer check. The encoding is built here rather than reusing the
# `text` variable that only a LATER cell defines, which broke Restart & Run All. The
# encoding already holds tensors, so no torch.tensor() re-wrap is needed.
sample_encoding = tokenizer("Hi my name ", return_tensors='pt')
sample_encoding['input_ids']

In [None]:
# Smoke test: compute the [CLS] embedding of one sentence and display its shape.
text = "Hi my name "
text = tokenizer(text,  return_tensors='pt')
print(text)
x = CLSBERT()
# Shape check only — skip building an autograd graph through the (trainable) BERT.
with torch.no_grad():
    cls_shape = x.embedding(text['input_ids'], text['attention_mask'], text_model).shape
cls_shape
# text_model(**text)

In [41]:
class EmbeddingCNN14:
    """Extract clip-level embeddings from a CNN14-style audio tagging model."""

    def embedding(self, x, model):
        """Return ``model(x, None)['embedding']``.

        Args:
            x: batch of waveforms passed straight to ``model``
               (assumed (batch, samples) — TODO confirm against the model).
            model: callable taking ``(waveform, None)`` and returning a dict that
                contains an ``'embedding'`` tensor. The second argument is presumably
                the mixup coefficient; None disables it — confirm.

        Returns:
            The embedding tensor, still attached to the autograd graph.
        """
        batch_output_dict = model(x, None)
        # Do NOT take `.data` here: it detaches the tensor, which silently stops gradients
        # from reaching the audio encoder even though its parameters are marked trainable
        # and handed to the optimizer.
        return batch_output_dict['embedding']

In [None]:
# Scratch cell: display a one-element random tensor created with requires_grad=True.
torch.randn(1, requires_grad=True)

In [45]:
import numpy as np
class CLAP(nn.Module):
    """Contrastive audio–text model.

    The text encoder's [CLS] embedding and the audio encoder's clip embedding are each
    projected to ``ModelArgs.output_embeddings`` dims and L2-normalised. The resulting
    cosine similarities are then divided by a learnable temperature to form the logits.
    """

    # Lower bound on the temperature: keeps it positive and caps logits at 1 / 0.01 = 100.
    MIN_TEMP = 0.01

    def __init__(self):
        super().__init__()
        self.multimodelTextLayerProjector = Projection(ModelArgs.text_embeddings, ModelArgs.output_embeddings)
        self.multimodalAudioLayerProjector = Projection(ModelArgs.audio_embeds, ModelArgs.output_embeddings)
        # Learnable temperature (initial 0.07), registered as an nn.Parameter. That makes it
        # part of clap.parameters(): the optimizer updates it and .to() moves it. The previous
        # `torch.ones(..., requires_grad=True) * 0.07` was a non-leaf tensor that was never
        # trained and whose graph was shared across steps (hence retain_graph=True).
        self.temp = nn.Parameter(torch.tensor(0.07, device=ModelArgs.device))
        self.text_embeds = CLSBERT()
        self.audio_embeds = EmbeddingCNN14()
        # Wrap the globally loaded encoders so their parameters are registered as submodules.
        self.text_model = text_model
        self.audio_model = model

    def forward(self, x):
        """Return the (text batch, audio batch) similarity logits.

        Args:
            x: dict with ``'input_ids'``, ``'attention_mask'`` (text) and ``'audio'``.

        Returns:
            Tensor whose entry [i, j] scores text i against audio clip j.
        """
        text_embeds = self.text_embeds.embedding(x['input_ids'], x['attention_mask'], self.text_model)
        audio_embeds = self.audio_embeds.embedding(x['audio'], self.audio_model)
        proj_txt = torch.nn.functional.normalize(self.multimodelTextLayerProjector(text_embeds))
        proj_audio = torch.nn.functional.normalize(self.multimodalAudioLayerProjector(audio_embeds))
        # Cosine similarities lie in [-1, 1]. DIVIDING by a small temperature sharpens them
        # into usable logits; multiplying by 0.07 squashed them toward 0 (near-uniform softmax).
        temperature = self.temp.clamp(min=self.MIN_TEMP)
        return (proj_txt @ proj_audio.T) / temperature

In [None]:
# Peek at the padded token ids of the first training example.
train_dataset[0]['input_ids']

In [46]:
# Instantiate the contrastive model. CLAP.__init__ wraps the globally loaded
# `text_model` (BERT) and audio `model`.
clap = CLAP()
# clap = clap.to('cuda:0')
# clap(inp)

In [47]:

import itertools

# Apply the per-component learning rates configured in ModelArgs (previously unused, so
# BERT was being fine-tuned at the head-sized lr=4e-4). The pretrained encoders move
# gently, while the freshly initialised projection heads learn faster.
# "Head" parameters are every CLAP parameter not owned by an encoder. This picks up the
# temperature automatically whenever it is a registered nn.Parameter, without naming
# `clap.temp` directly (Adam rejects non-leaf tensors).
encoder_param_ids = {
    id(p) for p in itertools.chain(clap.audio_model.parameters(), clap.text_model.parameters())
}
params = [
    {"params": list(clap.audio_model.parameters()), "lr": ModelArgs.audio_encoder_lr},
    {"params": list(clap.text_model.parameters()), "lr": ModelArgs.text_encoder_lr},
    {"params": [p for p in clap.parameters() if id(p) not in encoder_param_ids], "lr": ModelArgs.head_lr},
]

# ModelArgs.lr remains the fallback for any group without its own "lr".
optimizer = torch.optim.Adam(params, lr=ModelArgs.lr)

In [48]:
# Trade some float32 matmul precision for speed (permits reduced-precision internal kernels such as TF32).
torch.set_float32_matmul_precision(precision='high')

# scaler = torch.amp.GradScaler(enabled=True)

In [23]:
def find_unused_parameters(model):
    """Return the names of parameters whose ``.grad`` is still None.

    Intended for debugging after ``backward()``: any name listed received no gradient.
    """
    return [name for name, param in model.named_parameters() if param.grad is None]


In [None]:
clap = clap.to(ModelArgs.device)
clap.train()

device = ModelArgs.device
# One slot per batch. It is fully overwritten every epoch, so its mean is the epoch's average loss.
train_losses = torch.zeros(len(train_dataloader))
wandb.init(
    project='clap-From-Scratch'
)
for epoch in range(ModelArgs.epochs):
    print("Starting train...")
    for count, X in enumerate(train_dataloader):
        X['input_ids'] = X['input_ids'].to(device)
        X['attention_mask'] = X['attention_mask'].to(device)
        X['audio'] = X['audio'].to(device)

        logits = clap(X)
        logits = torch.clamp(logits, max=100)
        # Row i scores text i against every audio clip in the batch. Matching pairs sit on
        # the diagonal, so the target for row (and column) i is simply i.
        classes = torch.arange(logits.shape[0], device=logits.device)
        # Symmetric contrastive loss: text->audio over rows, audio->text over columns.
        # (`classes` is 1-D, so it needs no transpose; `.T` on 1-D tensors is deprecated.)
        loss_t = torch.nn.functional.cross_entropy(logits, classes)
        loss_a = torch.nn.functional.cross_entropy(logits.T, classes)
        final_loss = (loss_t + loss_a) / 2

        step_loss = final_loss.item()
        train_losses[count] = step_loss
        print("Loss: ", step_loss)

        optimizer.zero_grad()
        # NOTE(review): retain_graph=True is only required while CLAP.temp is a non-leaf
        # tensor built once in __init__ (its graph is then shared by every step). Once temp
        # is an nn.Parameter this can be dropped to avoid keeping each graph alive.
        final_loss.backward(retain_graph=True)
        optimizer.step()

    # TODO: add a validation pass once a validation split / dataloader is available.
    epoch_loss = train_losses.mean().item()
    wandb.log({
        "Train Loss": epoch_loss,
        "epoch": epoch,
    })
    print("Epoch: ", epoch, "|", "Train Loss: ", epoch_loss)

wandb.finish()
