## Model Details

This notebook describes the models we trained.

#### Model Architecture

The models use the architectures shown below. More information about the training process can be found in the training notebooks. We cannot share the lyrics or audio files used to train the models.

In [None]:
import torch
import torch.nn as nn
from transformers import (
    ASTModel,
    AutoModelForAudioClassification,
    AutoModelForSequenceClassification,
    RobertaModel,
)

# Device the language model is placed on: the first CUDA GPU when one is
# available, otherwise the CPU. (The original cell used `device` without
# ever defining it.)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Language model: RoBERTa-large with a 9-way sequence-classification head.
# Bound to a name so the instance is not discarded immediately after
# construction.
text_model = AutoModelForSequenceClassification.from_pretrained(
    "roberta-large", num_labels=9
).to(device)

# Audio model: the AudioSet-fine-tuned AST checkpoint with a 9-way head.
# ignore_mismatched_sizes=True lets the checkpoint's classification head be
# replaced by a freshly initialised one when its output size differs from 9.
# Kept under the name `model` so later cells that reference it still work.
model = AutoModelForAudioClassification.from_pretrained(
    "MIT/ast-finetuned-audioset-10-10-0.4593",
    num_labels=9,
    ignore_mismatched_sizes=True,
)

# The multimodal model consists of these two classes
class CombinedClassificationHead(nn.Module):
    """Map a fused audio+text embedding to class logits.

    The incoming vector (audio features concatenated with text features) is
    layer-normalised over its last dimension and then projected by a single
    linear layer onto ``num_labels`` outputs.

    Args:
        audio_feature_size: Width of the audio part of the fused vector.
        text_feature_size: Width of the text part of the fused vector.
        num_labels: Number of output classes.
    """

    def __init__(self, audio_feature_size, text_feature_size, num_labels):
        super().__init__()
        fused_width = audio_feature_size + text_feature_size
        # The attribute names below become state_dict keys
        # ("layer_norm.*", "fc.*"); keep them stable so saved checkpoints load.
        self.layer_norm = nn.LayerNorm(fused_width)
        self.fc = nn.Linear(fused_width, num_labels)

    def forward(self, combined_features):
        """Normalise ``combined_features`` and return the linear-layer logits.

        Args:
            combined_features: Tensor whose last dimension equals
                ``audio_feature_size + text_feature_size``.

        Returns:
            Tensor whose last dimension is ``num_labels``.
        """
        return self.fc(self.layer_norm(combined_features))

class AudioTextClassificationModel(nn.Module):
    """Multimodal classifier fusing an AST audio encoder with a RoBERTa text encoder.

    The audio input is encoded by ``ASTModel`` and summarised by its pooled
    output. The text input is encoded by ``RobertaModel`` and summarised by
    the hidden state of its first token. The two summaries are concatenated
    along dim 1 and scored by a ``CombinedClassificationHead``.

    Args:
        num_labels: Number of output classes.
        audio_model_name: Hugging Face checkpoint for the audio encoder
            (defaults to the checkpoint originally hard-coded here).
        text_model_name: Hugging Face checkpoint for the text encoder
            (defaults to the checkpoint originally hard-coded here).
    """

    def __init__(
        self,
        num_labels,
        audio_model_name="MIT/ast-finetuned-audioset-10-10-0.4593",
        text_model_name="roberta-large",
    ):
        super().__init__()
        self.audio_model = ASTModel.from_pretrained(audio_model_name)
        # NOTE(review): forward() never uses RoBERTa's pooler output. The
        # pooler is still built (RobertaModel's default); passing
        # add_pooling_layer=False would change the state_dict keys, so it is
        # left as-is for checkpoint compatibility.
        self.text_model = RobertaModel.from_pretrained(text_model_name)

        audio_feature_size = self.audio_model.config.hidden_size
        text_feature_size = self.text_model.config.hidden_size

        self.classifier = CombinedClassificationHead(audio_feature_size, text_feature_size, num_labels)

    def forward(self, input_values, input_ids, attention_mask):
        """Return class logits for a batch of (audio, text) pairs.

        Args:
            input_values: Audio input for ``ASTModel`` (presumably produced by
                the matching AST feature extractor — confirm against caller).
            input_ids: Token ids for the RoBERTa encoder.
            attention_mask: Attention mask matching ``input_ids``.

        Returns:
            Logits from the combined classification head.
        """
        # return_dict=True guarantees named outputs regardless of the loaded
        # config, so fields are accessed by name rather than tuple position.
        audio_output = self.audio_model(input_values=input_values, return_dict=True)
        audio_pooled_output = audio_output.pooler_output

        text_output = self.text_model(
            input_ids=input_ids, attention_mask=attention_mask, return_dict=True
        )
        # First-token (<s>) hidden state is used as the sentence summary.
        text_pooled_output = text_output.last_hidden_state[:, 0]

        combined_features = torch.cat((audio_pooled_output, text_pooled_output), dim=1)
        return self.classifier(combined_features)