In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from data_loader import get_train_loader
from speechbrain.inference.vocoders import HIFIGAN
from transformers import AutoModelForAudioClassification
import torchaudio

: 

In [None]:
class AudioTextEmotionModel(nn.Module):
    """Small CNN that re-renders a mel spectrogram conditioned on an emotion label.

    The emotion index is looked up in an embedding table, the embedding vector
    is broadcast along the time axis so that it forms a plane with the same
    shape as the mel spectrogram, and the two planes are stacked as a
    2-channel image that is passed through three ``Conv2d`` + ReLU layers.

    Args:
        num_emotions: Number of distinct emotion classes (embedding rows).
        embedding_dim: Size of each emotion embedding. It must equal the
            number of mel bins of the spectrograms passed to ``forward``,
            because the embedding is stacked with the spectrogram.
        num_frames: Nominal number of spectrogram frames. Stored on the module
            for reference; ``forward`` uses the frame count of the actual
            input so that other lengths also work.
    """

    def __init__(self, num_emotions, embedding_dim, num_frames):
        super().__init__()
        self.num_frames = num_frames

        # Emotion embedding layer
        self.emotion_embedding = nn.Embedding(num_emotions, embedding_dim)

        # Convolutional layers for the 2-channel (mel, emotion) input.
        # Paddings are chosen so that height and width are preserved.
        self.conv1 = nn.Conv2d(2, 10, kernel_size=(3, 3), padding=(1, 1))
        self.conv2 = nn.Conv2d(10, 5, kernel_size=(5, 5), padding=(2, 2))
        self.conv3 = nn.Conv2d(5, 1, kernel_size=(3, 3), padding=(1, 1))

    def forward(self, audio_input, emotion_idx):
        """Generate an emotion-conditioned spectrogram.

        Args:
            audio_input: Mel spectrogram of shape ``(batch, n_mels, frames)``
                or, for a single example, ``(n_mels, frames)``.
            emotion_idx: Integer tensor with one emotion index per example
                (any shape that flattens to ``batch`` elements), or a single
                index that is shared by the whole batch.

        Returns:
            Tensor of shape ``(batch, n_mels, frames)``; ``batch`` is 1 for a
            single un-batched example.

        Raises:
            ValueError: If ``audio_input`` is not 2-D/3-D, if the embedding
                size differs from ``n_mels``, or if the number of emotion
                indices does not match the batch size.
        """
        if audio_input.dim() == 2:
            # Promote a single example to a batch of one.
            audio_input = audio_input.unsqueeze(0)
        if audio_input.dim() != 3:
            raise ValueError(
                "audio_input must have shape (batch, n_mels, frames) or "
                f"(n_mels, frames), got {tuple(audio_input.shape)}"
            )

        batch_size, n_mels, n_frames = audio_input.shape

        # (batch,) or (1,) -> (batch or 1, embedding_dim)
        emotion_vec = self.emotion_embedding(emotion_idx.reshape(-1))
        if emotion_vec.size(-1) != n_mels:
            raise ValueError(
                f"embedding_dim ({emotion_vec.size(-1)}) must equal the number "
                f"of mel bins ({n_mels}) to be stacked with the spectrogram"
            )
        if emotion_vec.size(0) not in (1, batch_size):
            raise ValueError(
                f"got {emotion_vec.size(0)} emotion indices for a batch of "
                f"{batch_size} spectrograms"
            )

        # Repeat the embedding across time (as a view, without copying) so each
        # frame sees the same emotion vector: (batch, n_mels, frames).
        emotion_plane = emotion_vec.unsqueeze(-1).expand(batch_size, n_mels, n_frames)

        # Channel axis goes at dim=1 -> (batch, 2, n_mels, frames), the layout
        # Conv2d expects for batched input.
        mel_emotion = torch.stack((audio_input, emotion_plane), dim=1)

        x = F.relu(self.conv1(mel_emotion))
        x = F.relu(self.conv2(x))
        # NOTE(review): the final ReLU restricts the output to values >= 0.
        # If the downstream vocoder expects log-mel values (which can be
        # negative) this activation may need to be removed - confirm.
        output = F.relu(self.conv3(x))

        # Drop the singleton channel axis: (batch, 1, n_mels, frames) -> (batch, n_mels, frames)
        return output.squeeze(1)

In [None]:
# Pretrained vocoder: turns mel spectrograms into waveforms.
hifi_gan = HIFIGAN.from_hparams(
    source="speechbrain/tts-hifigan-ljspeech",
    savedir="pretrained_models/tts-hifigan-ljspeech",
)

# Pretrained speech-emotion recogniser.
# NOTE: trust_remote_code=True executes Python code shipped in the hub
# repository; only keep it for repositories you trust.
emotion_rec = AutoModelForAudioClassification.from_pretrained(
    "3loi/SER-Odyssey-Baseline-WavLM-Categorical-Attributes",
    trust_remote_code=True,
)

# Waveform normalisation statistics stored in the recogniser's config.
mean, std = emotion_rec.config.mean, emotion_rec.config.std

In [None]:
# Seed the global torch RNG before any weights are created so that a
# Restart & Run All reproduces the same initialisation.
seed = 42
torch.manual_seed(seed)

# Model / optimisation hyper-parameters.
num_emotions = 8      # number of emotion classes in the embedding table
embedding_dim = 80    # presumably the number of mel bins of the input spectrograms - confirm
num_frames = 444      # nominal number of spectrogram frames per example
lr = 0.001            # Adam learning rate

model = AudioTextEmotionModel(
    num_emotions=num_emotions,
    embedding_dim=embedding_dim,
    num_frames=num_frames,
)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Put the model in training mode (also displays the module as the cell output).
model.train()

In [None]:
# Number of full passes over the training data.
num_epochs = 10

# Classification loss; expects raw (un-normalised) logits and class indices.
emotion_criterion = torch.nn.CrossEntropyLoss()

In [None]:
train_loader = get_train_loader()

# Output sample rate of the LJSpeech HiFi-GAN vocoder (per its model card).
# The emotion recogniser expects `emotion_rec.config.sampling_rate`, so the
# generated audio is resampled before classification.
vocoder_sample_rate = 22050

# The vocoder and the recogniser are fixed feature extractors: only `model`
# is optimised. Their parameters are frozen, but the forward passes below are
# NOT wrapped in torch.no_grad(), otherwise no gradient could flow back from
# the emotion loss to `model`.
for frozen_net in (hifi_gan, emotion_rec):
    frozen_net.eval()
    for frozen_param in frozen_net.parameters():
        frozen_param.requires_grad_(False)

for epoch in range(num_epochs):
    total_loss = 0

    # text_input is provided by the loader but not used by this objective.
    for audio_input, emotion_input, text_input in train_loader:
        optimizer.zero_grad()

        # Emotion-conditioned mel spectrogram; assumed (batch, n_mels, frames) - confirm.
        output = model(audio_input, emotion_input)

        # HIFIGAN.decode_batch() is an inference helper that runs under
        # torch.no_grad(); call the generator module directly so the
        # waveform stays attached to the autograd graph.
        waveforms = hifi_gan.hparams.generator(output)
        # (batch, 1, samples) -> (batch, samples)
        waveforms = waveforms.reshape(waveforms.size(0), -1)

        # Match the recogniser's expected sample rate (differentiable).
        waveforms = torchaudio.functional.resample(
            waveforms,
            orig_freq=vocoder_sample_rate,
            new_freq=emotion_rec.config.sampling_rate,
        )

        norm_wav = (waveforms - mean) / (std+0.000001)
        # One mask entry per audio sample; every sample is valid because all
        # waveforms in the batch have the same length.
        mask = torch.ones_like(norm_wav)

        # Logits of the recogniser, shape (batch, num_classes).
        # NOTE(review): this assumes the loader's emotion indices use the same
        # class order as emotion_rec.config.id2label - confirm.
        pred = emotion_rec(norm_wav, mask)

        # Probabilities are kept for inspection only; the loss takes logits.
        probabilities = torch.nn.functional.softmax(pred.detach(), dim=1)

        # Train the generator so that the recogniser hears the target emotion.
        emotion_loss = emotion_criterion(pred, emotion_input)
        loss = emotion_loss

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Guard against an empty loader so the epoch summary never divides by zero.
    avg_loss = total_loss / max(len(train_loader), 1)

    print(f"Epoch [{epoch+1}/{num_epochs}], Avg. Loss: {avg_loss:.4f}")

In [None]:
# Persist the training state after the last epoch: the final epoch index, the
# model and optimizer state, and the summed loss of that epoch.
torch.save(
    obj={
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': total_loss,
    },
    f='checkpoint.pth',
)