In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from data_loader import get_train_loader
from speechbrain.inference.vocoders import HIFIGAN
from transformers import AutoModelForAudioClassification
import torchaudio
import numpy as np 
import pandas as pd

In [21]:
class AudioTextEmotionModel(nn.Module):
    def __init__(self, num_emotions, embedding_dim, num_frames):
        super(AudioTextEmotionModel, self).__init__()
        self.num_frames = num_frames

        # Emotion embedding layer
        self.emotion_embedding = nn.Embedding(num_emotions, embedding_dim)

        # Convolutional layers for mel spectrogram input
        self.conv1 = nn.Conv2d(2, 10, kernel_size=(3, 3), padding=(1, 1))
        self.conv2 = nn.Conv2d(10, 5, kernel_size=(5, 5), padding=(2, 2))
        self.conv3 = nn.Conv2d(5, 1, kernel_size=(3, 3), padding=(1, 1))            
        
    def forward(self, audio_input, emotion_idx):
        emotion_embedding = self.emotion_embedding(emotion_idx)
        emotion_repeated = emotion_embedding.unsqueeze(-1).repeat(1, 1, self.num_frames)

        mel_emotion = torch.stack((audio_input, emotion_repeated), dim=1)

        x = F.relu(self.conv1(mel_emotion))
        x = F.relu(self.conv2(x))
        output = F.relu(self.conv3(x))
                
        return output

In [5]:
hifi_gan = HIFIGAN.from_hparams(source="speechbrain/tts-hifigan-ljspeech", savedir="pretrained_models/tts-hifigan-ljspeech")
emotion_rec = AutoModelForAudioClassification.from_pretrained("3loi/SER-Odyssey-Baseline-WavLM-Categorical-Attributes", trust_remote_code=True)
mean = emotion_rec.config.mean
std = emotion_rec.config.std

Some weights of the model checkpoint at microsoft/wavlm-large were not used when initializing WavLMModel: ['encoder.pos_conv_embed.conv.weight_g', 'encoder.pos_conv_embed.conv.weight_v']
- This IS expected if you are initializing WavLMModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing WavLMModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of WavLMModel were not initialized from the model checkpoint at microsoft/wavlm-large and are newly initialized: ['encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'encoder.pos_conv_embed.conv.parametrizations.weight.original1']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and i

In [22]:
num_emotions = 8
embedding_dim = 80
num_frames = 444 
lr = 0.001

model = AudioTextEmotionModel(num_emotions, embedding_dim, num_frames)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
model.train()

AudioTextEmotionModel(
  (emotion_embedding): Embedding(8, 80)
  (conv1): Conv2d(2, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(10, 5, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv3): Conv2d(5, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)

In [7]:
emotion_criterion = nn.CrossEntropyLoss()
num_epochs = 10

In [36]:
df = pd.read_excel('output_text.xlsx')
text_indices = np.arange(len(df)) # attempt to get text length (not working)
audio_paths = df['Spectrogram Path'].tolist()
max_length = max(torch.load(path).shape[1] for path in audio_paths)
train_loader = get_train_loader(audio_paths, text_indices, max_length)

for epoch in range(num_epochs):
    total_loss = 0
    
    for audio_input, emotion_input, text_input in train_loader:
        optimizer.zero_grad()
        
        output = model(audio_input, emotion_input)
        output = output.squeeze()
        print("Output Done ", end="")
        print(output.shape)

        waveforms = hifi_gan.decode_batch(output)
        waveforms = waveforms.squeeze()
        # waveforms = waveforms.unsqueeze(1)
        print("Waveforms Done ", end="")
        print(waveforms.shape)

        norm_wav = (waveforms - mean) / (std+0.000001)
        mask = torch.ones(norm_wav.shape)

        # with torch.no_grad():
        pred = emotion_rec(norm_wav, mask)

        probabilities = torch.nn.functional.softmax(pred, dim=1)
        print("Emotion Rec Done ", end="")
        print(probabilities.shape)

        emotion_loss = emotion_criterion(probabilities, emotion_input)
        loss = emotion_loss 
        # loss.requires_grad = True
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()

        print(f"Total Loss: {total_loss}")
    
    avg_loss = total_loss / len(train_loader)
    
    print(f"Epoch [{epoch+1}/{num_epochs}], Avg. Loss: {avg_loss:.4f}")

Output Done torch.Size([8, 80, 444])
Waveforms Done torch.Size([8, 116224])




Emotion Rec Done torch.Size([8, 8])
Total Loss: 2.1301352977752686
Output Done torch.Size([8, 80, 444])
Waveforms Done torch.Size([8, 116224])
Emotion Rec Done torch.Size([8, 8])
Total Loss: 4.2996156215667725
Output Done torch.Size([8, 80, 444])
Waveforms Done torch.Size([8, 116224])
Emotion Rec Done torch.Size([8, 8])
Total Loss: 6.472322225570679
Output Done torch.Size([8, 80, 444])
Waveforms Done torch.Size([8, 116224])
Emotion Rec Done torch.Size([8, 8])
Total Loss: 8.603610754013062
Output Done torch.Size([8, 80, 444])
Waveforms Done torch.Size([8, 116224])
Emotion Rec Done torch.Size([8, 8])
Total Loss: 10.772891759872437
Output Done torch.Size([8, 80, 444])
Waveforms Done torch.Size([8, 116224])
Emotion Rec Done torch.Size([8, 8])
Total Loss: 12.945634841918945
Output Done torch.Size([8, 80, 444])
Waveforms Done torch.Size([8, 116224])
Emotion Rec Done torch.Size([8, 8])
Total Loss: 15.077552318572998
Output Done torch.Size([8, 80, 444])
Waveforms Done torch.Size([8, 116224])
E

KeyboardInterrupt: 

In [37]:
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': total_loss,
            }, 'checkpoint.pth')