In [1]:
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel, BertConfig
import torch.optim as optim
from torch.nn import MSELoss, TransformerEncoder, TransformerEncoderLayer
import pandas as pd
import numpy as np
import os
import time
import re
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.nn.utils.rnn import pad_sequence

In [2]:
# Load the pickled dataset into a DataFrame and display it.
# NOTE(review): unpickling executes arbitrary code, so only load files from a trusted source.
dataset_path = "/kaggle/input/audio-dataset-mel-256/audio_dataset_256.pkl"
data = pd.read_pickle(dataset_path)
data

Unnamed: 0,ytid,start_s,end_s,audioset_positive_labels,aspect_list,caption,author_id,is_balanced_subset,is_audioset_eval,audio_features
0,-0Gj8-vB1q4,30,40,"/m/0140xf,/m/02cjck,/m/04rlf","['low quality', 'sustained strings melody', 's...",The low quality recording features a ballad so...,4,False,True,"[[-40.99979, -40.858463, -44.591995, -42.415, ..."
1,-0SdAVK79lg,30,40,"/m/0155w,/m/01lyv,/m/0342h,/m/042v_gx,/m/04rlf...","['guitar song', 'piano backing', 'simple percu...",This song features an electric guitar as the m...,0,False,False,"[[-47.59354, -44.512863, -42.904636, -44.10178..."
2,-0vPFx-wRRI,30,40,"/m/025_jnm,/m/04rlf","['amateur recording', 'finger snipping', 'male...",a male voice is singing a melody with changing...,6,False,True,"[[-47.81186, -51.93988, -60.75714, -58.41188, ..."
3,-0xzrMun0Rs,30,40,"/m/01g90h,/m/04rlf","['backing track', 'jazzy', 'digital drums', 'p...",This song contains digital drums playing a sim...,6,False,True,"[[-25.13749, -30.922548, -43.063854, -38.35595..."
4,-1LrH01Ei1w,30,40,"/m/02p0sh1,/m/04rlf","['rubab instrument', 'repetitive melody on dif...",This song features a rubber instrument being p...,0,False,False,"[[-22.071985, -19.954773, -20.680244, -20.6276..."
...,...,...,...,...,...,...,...,...,...,...
5516,zw5dkiklbhE,15,25,"/m/01sm1g,/m/0l14md","['amateur recording', 'percussion', 'wooden bo...",This audio contains someone playing a wooden b...,6,False,False,"[[-42.779514, -41.359673, -48.218506, -50.4325..."
5517,zwfo7wnXdjs,30,40,"/m/02p0sh1,/m/04rlf,/m/06j64v","['instrumental music', 'arabic music', 'genera...",The song is an instrumental. The song is mediu...,1,True,True,"[[-38.654037, -40.134087, -46.6978, -51.105743..."
5518,zx_vcwOsDO4,50,60,"/m/01glhc,/m/02sgy,/m/0342h,/m/03lty,/m/04rlf,...","['instrumental', 'no voice', 'electric guitar'...",The rock music is purely instrumental and feat...,2,True,True,"[[-48.891968, -56.01191, -55.617645, -58.16963..."
5519,zyXa2tdBTGc,30,40,"/m/04rlf,/t/dd00034","['instrumental music', 'gospel music', 'strong...",The song is an instrumental. The song is slow ...,1,False,False,"[[-47.926243, -47.799667, -54.788925, -57.6263..."


In [3]:
# Dimensions of the feature array stored in the first row
first_features = data['audio_features'].iloc[0]
first_features.shape

(256, 861)

In [4]:
# Peek at the raw feature values stored in the first row
data['audio_features'].iat[0]

array([[-40.99979 , -40.858463, -44.591995, ..., -38.19406 , -36.710896,
        -38.839684],
       [-46.70161 , -42.733616, -40.621758, ..., -37.159702, -33.01973 ,
        -35.885857],
       [-44.044918, -45.685513, -45.924263, ..., -37.338154, -36.525146,
        -38.146904],
       ...,
       [-79.06895 , -77.472626, -77.350655, ..., -69.1514  , -75.54788 ,
        -75.59643 ],
       [-80.      , -80.      , -80.      , ..., -79.809616, -80.      ,
        -80.      ],
       [-80.      , -80.      , -80.      , ..., -80.      , -80.      ,
        -80.      ]], dtype=float32)

In [5]:
def clip_array(arr, max_frames=420):
    """Truncate a 2-D array to at most ``max_frames`` columns.

    Args:
        arr: Array indexable as ``arr[:, :k]`` (e.g. a 2-D NumPy array).
        max_frames: Maximum number of columns to keep along the second axis.
            Defaults to 420, the value this notebook uses.

    Returns:
        A slice of ``arr`` with all rows and the first ``max_frames`` columns.
        Arrays that already have fewer columns are returned at their
        original width.

    Raises:
        ValueError: If ``max_frames`` is negative. A negative bound would
            silently drop trailing columns instead of limiting the width.
    """
    if max_frames < 0:
        raise ValueError(f"max_frames must be non-negative, got {max_frames}")
    return arr[:, :max_frames]

# Run clip_array over every entry of the 'audio_features' column and
# store the clipped arrays back into the same column.
clipped_features = data['audio_features'].apply(clip_array)
data['audio_features'] = clipped_features

In [15]:
# lstm - TS task

class MusicGenerationModel(nn.Module):
    """Text-conditioned autoregressive LSTM that emits a feature matrix.

    A text description is encoded with a frozen BERT model. The embedding at
    the first token position ([CLS]) is projected to the LSTM width and used
    to prime the LSTM once. The LSTM is then unrolled for ``seq_length``
    steps, and each step's projected output is fed back as the next input.

    Args:
        bert_model_name: Name passed to ``BertTokenizer.from_pretrained`` and
            ``BertModel.from_pretrained``.
        hidden_dim: Unused; kept so existing call sites keep working.
        lstm_hidden_dim: LSTM input and hidden size.
        output_dim: Size of the vector produced at every step. Must equal
            ``lstm_hidden_dim`` because each output is fed back into the LSTM.
        n_lstm_layers: Number of stacked LSTM layers.
        dropout: Dropout probability passed to ``nn.LSTM`` (between layers)
            and to the ``self.dropout`` module.

    Raises:
        ValueError: If ``output_dim != lstm_hidden_dim``.
    """

    def __init__(self, bert_model_name='bert-base-uncased', hidden_dim=512, lstm_hidden_dim=256, output_dim=256, n_lstm_layers=2, dropout=0.1):
        super(MusicGenerationModel, self).__init__()
        # Fail fast, before downloading BERT: the feedback loop in forward()
        # passes fc_lstm_to_output's result straight back into the LSTM.
        if output_dim != lstm_hidden_dim:
            raise ValueError(
                f"output_dim ({output_dim}) must equal lstm_hidden_dim ({lstm_hidden_dim}) "
                "because each generated step is fed back as the next LSTM input"
            )

        self.tokenizer = BertTokenizer.from_pretrained(bert_model_name)
        self.text_encoder = BertModel.from_pretrained(bert_model_name)
        # forward() always runs the encoder under no_grad, so it is never
        # trained. Make that explicit and keep its dropout switched off.
        for param in self.text_encoder.parameters():
            param.requires_grad_(False)
        self.text_encoder.eval()

        self.fc_text_to_lstm = nn.Linear(self.text_encoder.config.hidden_size, lstm_hidden_dim)
        self.lstm = nn.LSTM(lstm_hidden_dim, lstm_hidden_dim, n_lstm_layers, batch_first=True, dropout=dropout)
        self.fc_lstm_to_output = nn.Linear(lstm_hidden_dim, output_dim)

        # Not applied in forward(); retained so the module's attributes stay unchanged.
        self.dropout = nn.Dropout(dropout)

    def train(self, mode=True):
        """Set the training mode, but always leave the frozen text encoder in eval mode.

        Without this, ``model.train()`` would re-enable dropout inside the
        encoder and make the conditioning embeddings random from call to call.
        """
        super().train(mode)
        self.text_encoder.eval()
        return self

    def forward(self, text_description, seq_length=420):
        """Generate ``seq_length`` output vectors conditioned on text.

        Args:
            text_description: A string or list of strings accepted by the tokenizer.
            seq_length: Number of autoregressive steps to generate (>= 1).

        Returns:
            Tensor of shape ``(output_dim, seq_length)`` when the batch holds a
            single description, otherwise ``(batch, output_dim, seq_length)``.

        Raises:
            ValueError: If ``seq_length`` is smaller than 1.
        """
        if seq_length < 1:
            raise ValueError(f"seq_length must be at least 1, got {seq_length}")

        device = next(self.parameters()).device

        # Tokenize and encode the input text description
        encoded_input = self.tokenizer(text_description, return_tensors='pt', padding=True, truncation=True).to(device)
        with torch.no_grad():
            text_embeddings = self.text_encoder(**encoded_input).last_hidden_state

        # Use the [CLS] token's embedding as the representation of the entire sequence
        cls_embedding = text_embeddings[:, 0, :]

        # Map the text embedding to the LSTM width: (batch, 1, lstm_hidden_dim)
        lstm_input = self.fc_text_to_lstm(cls_embedding).unsqueeze(1)
        batch_size = lstm_input.size(0)

        # Prime the LSTM with the text vector; omitted initial states default to zeros.
        step_input, state = self.lstm(lstm_input)

        # Unroll one step at a time, feeding each projected output back in.
        outputs = []
        for _ in range(seq_length):
            lstm_output, state = self.lstm(step_input, state)
            step_input = self.fc_lstm_to_output(lstm_output)
            outputs.append(step_input)

        # (batch, seq_length, output_dim) -> (batch, output_dim, seq_length)
        audio_output = torch.cat(outputs, dim=1).transpose(1, 2)
        if batch_size == 1:
            # Single-description calls keep returning a 2-D (output_dim, seq_length) tensor.
            audio_output = audio_output.squeeze(0)

        return audio_output

In [16]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU,
# then build the model and move it to that device.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
model = MusicGenerationModel()
model = model.to(device)

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]



config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

In [24]:
num_epochs = 50  # Set the number of epochs
training_losses = []
validation_losses = []

# Define loss and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-5)
# Cosine annealing is stepped once per epoch and does not consume the validation loss.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

# Directory to save model checkpoints and losses
checkpoint_dir = './checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

# Initialize best loss to a large value
best_loss = float('inf')
early_stopping_patience = 5
early_stopping_counter = 0

# Split your data into training and validation sets
train_data, val_data = train_test_split(data, test_size=0.2, random_state=42)  # Adjust test_size and random_state as needed


def _run_epoch(model, frame, criterion, device, optimizer=None, max_grad_norm=1.0):
    """Run one pass over ``frame`` and return the mean per-sample loss.

    Args:
        model: Callable as ``model([caption], seq_length)``.
        frame: DataFrame with 'caption' and 'audio_features' columns.
        criterion: Loss function applied to ``(output, target)``.
        device: Device the targets are moved to.
        optimizer: When given, the pass trains the model (backward pass,
            gradient clipping, optimizer step). When ``None`` the pass is
            evaluation-only and runs with gradients disabled.
        max_grad_norm: Gradient-norm clipping threshold used when training.

    Returns:
        Mean loss over the samples that were actually used, or ``nan`` when
        every sample was skipped.
    """
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    n_used = 0

    with torch.set_grad_enabled(training):
        for text_description, features in zip(frame['caption'], frame['audio_features']):
            audio_target = torch.tensor(features).to(device)
            if audio_target.dim() == 3:
                # NOTE(review): leftover from an earlier target layout; the permuted
                # result is transposed relative to the model output — confirm before relying on it.
                audio_target = audio_target.squeeze(0).permute(1, 0)
            if audio_target.dim() != 2:
                print(f"Skipping sample with unexpected target shape {tuple(audio_target.shape)}")
                continue

            # The model returns (features, steps), so the number of steps to
            # generate is the target's second axis, not its first.
            seq_length = audio_target.shape[1]

            if training:
                optimizer.zero_grad()
            output = model([text_description], seq_length)

            if output.shape != audio_target.shape:
                print(f"Output shape {output.shape} and target shape {audio_target.shape} do not match!")
                continue

            loss = criterion(output, audio_target)
            if training:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
                optimizer.step()

            total_loss += loss.item()
            n_used += 1

    # Average over the samples that contributed, not over skipped ones.
    return total_loss / n_used if n_used else float('nan')


for epoch in range(num_epochs):
    epoch_start_time = time.time()

    # Training phase
    avg_epoch_loss = _run_epoch(model, train_data, criterion, device, optimizer=optimizer)
    training_losses.append(avg_epoch_loss)

    # Validation phase
    avg_val_loss = _run_epoch(model, val_data, criterion, device)
    validation_losses.append(avg_val_loss)

    scheduler.step()

    # Early Stopping check
    stop_early = False
    if avg_val_loss < best_loss:
        best_loss = avg_val_loss
        early_stopping_counter = 0

        checkpoint_path = os.path.join(checkpoint_dir, 'best_model.pth')
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_val_loss,
        }, checkpoint_path)
    else:
        early_stopping_counter += 1
        stop_early = early_stopping_counter >= early_stopping_patience

    # Log this epoch before any early exit so the last epoch is never lost.
    # Save training and validation losses to files
    loss_file = os.path.join(checkpoint_dir, 'training_losses.txt')
    with open(loss_file, 'a') as f:
        f.write(f"Epoch {epoch+1}, Training Loss: {avg_epoch_loss}, Validation Loss: {avg_val_loss}\n")

    # Save validation losses to a separate file
    val_loss_file = os.path.join(checkpoint_dir, 'validation_losses.txt')
    with open(val_loss_file, 'a') as f:
        f.write(f"Epoch {epoch+1}, Validation Loss: {avg_val_loss}\n")

    epoch_duration = time.time() - epoch_start_time

    print(f"Epoch {epoch+1}, Training Loss: {avg_epoch_loss}, Validation Loss: {avg_val_loss}, Time Elapsed: {epoch_duration:.2f} seconds")

    if stop_early:
        print("Early stopping triggered.")
        break

# Save the final training and validation losses to files
final_loss_file = os.path.join(checkpoint_dir, 'final_losses.txt')
with open(final_loss_file, 'w') as f:
    for epoch, (train_loss, val_loss) in enumerate(zip(training_losses, validation_losses), 1):
        f.write(f"Epoch {epoch}, Training Loss: {train_loss}, Validation Loss: {val_loss}\n")

torch.Size([256, 420])
torch.Size([256, 420])
torch.Size([256, 420])
torch.Size([256, 420])
torch.Size([256, 420])
torch.Size([256, 420])
torch.Size([256, 420])
torch.Size([256, 420])
torch.Size([256, 420])
torch.Size([256, 420])
Epoch 1, Training Loss: 5.6290473841190884, Validation Loss: 21.23659039599429, Time Elapsed: 6.14 seconds
torch.Size([256, 420])
torch.Size([256, 420])
torch.Size([256, 420])
torch.Size([256, 420])
torch.Size([256, 420])
torch.Size([256, 420])
torch.Size([256, 420])
torch.Size([256, 420])
torch.Size([256, 420])
torch.Size([256, 420])
Epoch 2, Training Loss: 5.628806904195502, Validation Loss: 21.23564761679111, Time Elapsed: 5.41 seconds
torch.Size([256, 420])
torch.Size([256, 420])
torch.Size([256, 420])
torch.Size([256, 420])
torch.Size([256, 420])
torch.Size([256, 420])
torch.Size([256, 420])
torch.Size([256, 420])
torch.Size([256, 420])
torch.Size([256, 420])
Epoch 3, Training Loss: 5.628572239155313, Validation Loss: 21.234749431539726, Time Elapsed: 6.3