# Import

In [50]:
import numpy as np
import pandas as pd
import transformers
from transformers import BertTokenizer, BertModel, AutoModel, AutoProcessor
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch, torchaudio, torchtext
from torcheval.metrics.functional import multiclass_f1_score
import torch.nn as nn
import os
import gc
import pickle
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

Using device: cuda


# Load df

In [2]:
try:
    df_path = '/kaggle/input/MM-USElecDeb60to16/MM-USElecDeb60to16.csv'
    audio_path = '/kaggle/input/MM-USElecDeb60to16/audio_clips'
    save_path = '/kaggle/working/'
    df = pd.read_csv(df_path, index_col=0)
except FileNotFoundError:
    df_path = 'multimodal-dataset/files/MM-USElecDeb60to16/MM-USElecDeb60to16.csv'
    audio_path = 'multimodal-dataset/files/MM-USElecDeb60to16/audio_clips'
    save_path = 'multimodal-dataset/files'
    df = pd.read_csv(df_path, index_col=0)
# drop rows where audio length is 0
df = df[df['NewBegin'] != df['NewEnd']]

train_df_complete = df[df['Set'] == 'TRAIN']
val_df_complete = df[df['Set'] == 'VALIDATION']
test_df_complete = df[df['Set'] == 'TEST']

DATASET_RATIO = 0.40

train_df = train_df_complete.iloc[:int(DATASET_RATIO * len(train_df_complete))]
val_df = val_df_complete.iloc[:int(DATASET_RATIO * len(val_df_complete))]
test_df = test_df_complete.iloc[:int(DATASET_RATIO * len(test_df_complete))]

In [3]:
train_df.head()

Unnamed: 0,Text,Part,Document,Order,Sentence,Start,End,Annotator,Tag,Component,...,Speaker,SpeakerType,Set,Date,Year,Name,MainTag,NewBegin,NewEnd,idClip
0,"CHENEY: Gwen, I want to thank you, and I want ...",1,30_2004,0,0,2101,2221,,"{""O"": 27}",O,...,CHENEY,Candidate,TRAIN,05 Oct 2004,2004,Richard(Dick) B. Cheney,O,126.52,131.08,clip_0
1,"It's a very important event, and they've done ...",1,30_2004,1,1,2221,2304,,"{""O"": 19}",O,...,CHENEY,Candidate,TRAIN,05 Oct 2004,2004,Richard(Dick) B. Cheney,O,131.08,134.4,clip_1
2,It's important to look at all of our developme...,1,30_2004,2,2,2304,2418,,"{""O"": 23}",O,...,CHENEY,Candidate,TRAIN,05 Oct 2004,2004,Richard(Dick) B. Cheney,O,134.4,140.56,clip_2
3,"And, after 9/11, it became clear that we had t...",1,30_2004,3,3,2418,2744,,"{""O"": 16, ""Claim"": 50}",Claim,...,CHENEY,Candidate,TRAIN,05 Oct 2004,2004,Richard(Dick) B. Cheney,Claim,140.56,158.92,clip_3
4,And we also then finally had to stand up democ...,1,30_2004,4,4,2744,2974,,"{""O"": 4, ""Claim"": 13, ""Premise"": 25}",Premise,...,CHENEY,Candidate,TRAIN,05 Oct 2004,2004,Richard(Dick) B. Cheney,Mixed,158.92,172.92,clip_4


In [4]:
len(train_df), len(test_df), len(val_df)

(4967, 2986, 2758)

## Distribution of classes over train df

In [25]:
num_claim = len(train_df[train_df['Component'] == 'Claim'])
num_premise = len(train_df[train_df['Component'] == 'Premise'])
num_other = len(train_df[train_df['Component'] == 'O'])

print(f'Total Claim: {num_claim}: {num_claim*100/len(train_df):.2f}%')
print(f'Total Premise: {num_premise}: {num_premise*100/len(train_df):.2f}%')
print(f'Total Other: {num_other}: {num_other*100/len(train_df):.2f}%')

Total Claim: 2103: 42.34%
Total Premise: 1697: 34.17%
Total Other: 1167: 23.50%


# Train and evaluation Loop

In [54]:
class BestModel:
    """
        Class to keep track of the best performing model on validation set during training
    """
    def __init__(self):
        self.best_validation_loss = float('Infinity')
        self.best_state_dict = None
    def __call__(self, model, loss):
        if loss < self.best_validation_loss:
            self.best_validation_loss = loss
            self.best_state_dict = model.state_dict()

def train(model, optimizer, loss_fn, train_loader, val_loader, epochs=10, device="cuda"):
    best_model_tracker = BestModel()
    for epoch in tqdm(range(epochs)):
        training_loss = 0.0
        valid_loss = 0.0
        model.train()

        for batch in train_loader:
            optimizer.zero_grad()
            texts, audio_features, audio_attention, targets = batch
            audio_features = audio_features.to(device)
            audio_attention = audio_attention.to(device)
            targets = targets.to(device)
            output = model(texts,audio_features,audio_attention)
            loss = loss_fn(output, targets)
            loss.backward()
            optimizer.step()
            training_loss += loss.detach()
        training_loss = training_loss.cpu().item()
        training_loss /= len(train_loader.dataset)

        model.eval()
        num_correct = 0 
        num_examples = 0
        tot_pred, tot_targ = torch.LongTensor().to(device), torch.LongTensor().to(device)
        for batch in val_loader:
            texts, audio_features, audio_attention, targets = batch
            audio_features = audio_features.to(device)
            audio_attention = audio_attention.to(device)
            targets = targets.to(device)
            output = model(texts,audio_features,audio_attention)
            loss = loss_fn(output, targets)
            valid_loss += loss.detach()
            predicted_labels = torch.argmax(output, dim=-1)
            tot_targ = torch.cat((tot_targ, targets))
            tot_pred = torch.cat((tot_pred, predicted_labels))            
            correct = torch.eq(predicted_labels, targets).view(-1)
            num_correct += torch.sum(correct).item()
            num_examples += correct.shape[0]
        best_model_tracker(model, valid_loss)
        valid_loss = valid_loss.cpu().item()
        valid_loss /= len(val_loader.dataset)  
        print(f'Epoch: {epoch}, Training Loss: {training_loss:.4f}, Validation Loss: {valid_loss:.4f}, accuracy = {num_correct/num_examples:.4f}, F1={multiclass_f1_score(tot_pred, tot_targ, num_classes=3, average="macro"):.4f}')
    model.load_state_dict(best_model_tracker.best_state_dict)

# Dataset Creation

In [6]:
text_model_card = 'bert-base-uncased'
audio_model_card = 'facebook/wav2vec2-base-960h'

tokenizer = BertTokenizer.from_pretrained(text_model_card)
embedder = BertModel.from_pretrained(text_model_card).to(device)

for params in embedder.parameters():
    params.requires_grad = False

label_2_id = {
    'Claim': 0,
    'Premise': 1,
    'O': 2
}

DOWNSAMPLE_FACTOR = 1/5

class MM_Dataset(torch.utils.data.Dataset):
    def __init__(self, df, audio_dir, sample_rate):
        self.audio_dir = audio_dir
        self.sample_rate = sample_rate

        self.audio_processor = AutoProcessor.from_pretrained(audio_model_card)
        self.audio_model = AutoModel.from_pretrained(audio_model_card).to(device)

        self.dataset = []

        # Iterate over df
        for _, row in tqdm(df.iterrows()):
            path = os.path.join(self.audio_dir, f"{row['Document']}/{row['idClip']}.wav")
            if os.path.exists(path):
                # obtain audio WAV2VEC features
                audio, sampling_rate = torchaudio.load(path)
                if sampling_rate != self.sample_rate:
                    audio = torchaudio.functional.resample(audio, sample_rate, self.sample_rate)
                    audio = torch.mean(audio, dim=0, keepdim=True)
                with torch.inference_mode():
                    input_values = self.audio_processor(audio, sampling_rate=self.sample_rate).input_values[0]
                    input_values = torch.tensor(input_values).to(device)
                    audio_model_output = self.audio_model(input_values)
                    audio_features = audio_model_output.last_hidden_state[0].unsqueeze(0)
                    audio_features = torch.nn.functional.interpolate(audio_features.permute(0,2,1), scale_factor=DOWNSAMPLE_FACTOR, mode='linear')
                    audio_features = audio_features.permute(0,2,1)[0]
                    audio_features = audio_features.cpu()
                
                text = row['Text']

                self.dataset.append((text, audio_features, label_2_id[row['Component']]))
    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        return self.dataset[index]

tokenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

In [7]:
try:
    train_dataset = torch.load(f'{save_path}/train_dataset.pkl')
    test_dataset = torch.load(f'{save_path}/test_dataset.pkl')
    val_dataset = torch.load(f'{save_path}/val_dataset.pkl')
    print('Restored datasets from memory')
except:
    print('Creating new datasets')
    train_dataset = MM_Dataset(train_df, audio_path, 16_000)
    test_dataset = MM_Dataset(test_df, audio_path, 16_000)
    val_dataset = MM_Dataset(val_df, audio_path, 16_000)
    torch.save(train_dataset, f'{save_path}/train_dataset.pkl')
    torch.save(test_dataset, f'{save_path}/test_dataset.pkl')
    torch.save(val_dataset, f'{save_path}/val_dataset.pkl')

Creating new datasets


preprocessor_config.json:   0%|          | 0.00/159 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/163 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.60k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/291 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/85.0 [00:00<?, ?B/s]

Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize


model.safetensors:   0%|          | 0.00/378M [00:00<?, ?B/s]

Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
4967it [08:29,  9.74it/s]


Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize


Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
2986it [04:25, 11.23it/s]


Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize


Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
2758it [05:18,  8.65it/s]


## Dataloader creation

In [8]:
def create_dataloader(dataset, batch_size):
    def pack_fn(batch):
        texts = [x[0] for x in batch]
        audio_features = [x[1] for x in batch]
        labels = torch.tensor([x[2] for x in batch])
        
        # pad audio features
        audio_features = pad_sequence(audio_features, batch_first=True, padding_value=float('-inf'))

        audio_features_attention_mask = audio_features[:, :, 0] != float('-inf')
        
        audio_features[(audio_features == float('-inf'))] = 0

        return texts, audio_features, audio_features_attention_mask, labels

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=pack_fn)
    return dataloader

In [9]:
train_dataloader = create_dataloader(train_dataset, 8)
val_dataloader = create_dataloader(val_dataset, 8)
test_dataloader = create_dataloader(test_dataset, 8)

In [10]:
#del early_fusion
gc.collect()

1096

# Positional Encoding

In [71]:
import math

class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dual_modality=False, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)
        self.dual_modality = dual_modality
        self.pe = self.pe.to(device)

    def forward(self, x, is_first=True):
        if self.dual_modality:
            modality = torch.ones((x.shape[0], x.shape[1], 4), dtype=torch.float32).to(device) * (0 if is_first else 1)
            x = x + self.pe[:x.size(0)]
            x = self.dropout(x)        
            return torch.cat((x, modality), axis=-1)
        else:
            x = x + self.pe[:x.size(0)]
            return self.dropout(x)

# Multimodal-Transformer Model

In [73]:
class MultiModalTransformer(nn.Module):
    def __init__(self, tokenizer, embedder, transformer, head):
        super().__init__()
        self.pos_encoder = PositionalEncoding(768, dual_modality=True)
        self.tokenizer = tokenizer
        self.embedder = embedder
        self.transformer = transformer
        self.head = head

    def forward(self, texts, audio_features, audio_attentions):
        tokenizer_output = self.tokenizer(texts, return_tensors='pt', padding=True, truncation=False).to(device)
        embedder_output = self.embedder(**tokenizer_output, output_hidden_states=True)
        text_features = embedder_output['hidden_states'][0]
        text_features = self.pos_encoder(text_features, is_first=True)
        text_attentions = tokenizer_output.attention_mask
        
        audio_features = self.pos_encoder(audio_features, is_first=False)
        
        concatenated_features = torch.cat((text_features, audio_features), dim=1)
        concatenated_attentions = torch.cat((text_attentions, audio_attentions.float()), dim=1)
        
        # padding mask is 1 where there is padding (i.e. where attention is 0) and 0 otherwise
        concatenated_padding_mask = ~concatenated_attentions.to(torch.bool)
        
        # compute a full attention mask of size [seq_len, seq_len]
        full_attention_mask = torch.zeros((concatenated_features.shape[1], concatenated_features.shape[1]), dtype=torch.bool).to(device)
                
        transformer_output = self.transformer(src=concatenated_features,  mask=full_attention_mask, src_key_padding_mask=concatenated_padding_mask)
        transformer_output_sum = (transformer_output * concatenated_attentions.unsqueeze(-1)).sum(axis=1)
        transformer_output_pooled = transformer_output_sum / concatenated_attentions.sum(axis=1).unsqueeze(-1)
        return self.head(transformer_output_pooled)

transformer_layer = nn.TransformerEncoderLayer(d_model=772, nhead=4, dim_feedforward=512, batch_first=True).to(device)
transformer_encoder = nn.TransformerEncoder(transformer_layer, num_layers=4).to(device)

head = nn.Sequential(
    nn.Linear(772, 256),
    nn.ReLU(),
    nn.Linear(256, 3)
).to(device)

multimodal_transformer = MultiModalTransformer(tokenizer, embedder, transformer_encoder, head).to(device)

optimizer = torch.optim.Adam(multimodal_transformer.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

train(multimodal_transformer, optimizer, criterion, train_dataloader, val_dataloader, epochs=10, device=device)

 10%|█         | 1/10 [01:31<13:39, 91.08s/it]

Epoch: 0, Training Loss: 0.1320, Validation Loss: 0.1361, accuracy = 0.4543, F1=0.2083


 20%|██        | 2/10 [03:02<12:08, 91.07s/it]

Epoch: 1, Training Loss: 0.1303, Validation Loss: 0.1335, accuracy = 0.4569, F1=0.2170


 30%|███       | 3/10 [04:33<10:37, 91.07s/it]

Epoch: 2, Training Loss: 0.1321, Validation Loss: 0.1340, accuracy = 0.4543, F1=0.2083


 40%|████      | 4/10 [06:04<09:06, 91.13s/it]

Epoch: 3, Training Loss: 0.1336, Validation Loss: 0.1328, accuracy = 0.4543, F1=0.2083


 50%|█████     | 5/10 [07:35<07:35, 91.08s/it]

Epoch: 4, Training Loss: 0.1319, Validation Loss: 0.1340, accuracy = 0.4543, F1=0.2083


 60%|██████    | 6/10 [09:06<06:04, 91.04s/it]

Epoch: 5, Training Loss: 0.1334, Validation Loss: 0.1326, accuracy = 0.4543, F1=0.2083


 70%|███████   | 7/10 [10:37<04:33, 91.10s/it]

Epoch: 6, Training Loss: 0.1333, Validation Loss: 0.1351, accuracy = 0.3183, F1=0.1610


 80%|████████  | 8/10 [12:08<03:02, 91.10s/it]

Epoch: 7, Training Loss: 0.1336, Validation Loss: 0.1365, accuracy = 0.4543, F1=0.2083


 90%|█████████ | 9/10 [13:40<01:31, 91.16s/it]

Epoch: 8, Training Loss: 0.1329, Validation Loss: 0.1336, accuracy = 0.4543, F1=0.2083


100%|██████████| 10/10 [15:11<00:00, 91.12s/it]

Epoch: 9, Training Loss: 0.1334, Validation Loss: 0.1355, accuracy = 0.4543, F1=0.2083





# Ensembling-Fusion Model

## Text-Only and Audio-Only Models 

In [56]:
class TextModel(nn.Module):
    def __init__(self, tokenizer, embedder, head):
        super().__init__()
        self.pos_encoder = PositionalEncoding(768, dual_modality=False)
        self.tokenizer = tokenizer
        self.embedder = embedder
        self.head = head
    def forward(self, texts, audio_features, audio_attention):
        tokenizer_output = self.tokenizer(texts, return_tensors='pt', padding=True, truncation=False).to(device)
        embedder_output = self.embedder(**tokenizer_output, output_hidden_states=True)
        text_features = embedder_output['last_hidden_state']
        text_features = self.pos_encoder(text_features)        
        text_features_sum = (text_features * tokenizer_output.attention_mask.unsqueeze(-1)).sum(axis=1)
        text_features_pooled = text_features_sum / tokenizer_output.attention_mask.sum(axis=1).unsqueeze(-1)
        return self.head(text_features_pooled)
    
class AudioModel(nn.Module):        
    def __init__(self, transformer, head):
        super().__init__()
        self.pos_encoder = PositionalEncoding(768, dual_modality=False)
        self.transformer = transformer
        self.head = head
        
    def forward(self, texts, audio_features, audio_attention):
        padding_mask = ~audio_attention.to(torch.bool)
        audio_features = self.pos_encoder(audio_features)
        full_attention_mask = torch.zeros((audio_features.shape[1],audio_features.shape[1]), dtype=torch.bool).to(device)
        transformer_output = self.transformer(src=audio_features, mask=full_attention_mask, src_key_padding_mask=padding_mask)
        
        # pooling transformer output
        transformer_output_sum = (transformer_output * audio_attention.unsqueeze(-1)).sum(axis=1)
        transformer_output_pooled = transformer_output_sum / audio_attention.sum(axis=1).unsqueeze(-1)
        return self.head(transformer_output_pooled)

## Ensembling Model

In [58]:
 class EnsemblingFusion(nn.Module):
    def __init__(self, text_model, audio_model):
        super().__init__()
        self.text_model = text_model
        self.audio_model = audio_model
        self.weight = torch.nn.Parameter(torch.tensor(0.0))
        
    def forward(self, texts, audio_features, audio_attentions):
        text_logits = self.text_model(texts, audio_features, audio_attentions)
        audio_logits = self.audio_model(texts, audio_features, audio_attentions)
        
        text_probabilities = torch.nn.functional.softmax(text_logits)
        audio_probabilities = torch.nn.functional.softmax(audio_logits)
        
        coefficient = (torch.tanh(self.weight) + 1) / 2
        
        coefficient = coefficient*0.4 + 0.3
        
        return coefficient*text_probabilities + (1-coefficient)*audio_probabilities
    
text_head = nn.Sequential(
    nn.Linear(768, 256),
    nn.ReLU(),
    nn.Linear(256, 3)
).to(device)

audio_head = nn.Sequential(
    nn.Linear(768, 256),
    nn.ReLU(),
    nn.Linear(256, 3)
).to(device)

transformer_layer = nn.TransformerEncoderLayer(d_model=768, nhead=4, dim_feedforward=512, batch_first=True).to(device)
transformer_encoder = nn.TransformerEncoder(transformer_layer, num_layers=4).to(device)

text_model = TextModel(tokenizer, embedder, text_head)
audio_model = AudioModel(transformer_encoder, audio_head)

ensembling_fusion = EnsemblingFusion(text_model, audio_model)

optimizer = torch.optim.Adam(ensembling_fusion.parameters(), lr=1e-4)

def custom_loss(outputs, targets):
    return torch.nn.functional.nll_loss(torch.log(outputs), targets, reduction='mean')

train(ensembling_fusion, optimizer, custom_loss, train_dataloader, val_dataloader, epochs=5, device=device)

 20%|██        | 1/5 [01:20<05:21, 80.28s/it]

Epoch: 0, Training Loss: 0.1268, Validation Loss: 0.1263, accuracy = 0.4543, F1=0.2083


 40%|████      | 2/5 [02:40<04:00, 80.23s/it]

Epoch: 1, Training Loss: 0.1183, Validation Loss: 0.1215, accuracy = 0.5174, F1=0.4007


 60%|██████    | 3/5 [04:00<02:40, 80.14s/it]

Epoch: 2, Training Loss: 0.1137, Validation Loss: 0.1202, accuracy = 0.5413, F1=0.4591


 80%|████████  | 4/5 [05:21<01:20, 80.34s/it]

Epoch: 3, Training Loss: 0.1113, Validation Loss: 0.1191, accuracy = 0.5518, F1=0.4844


100%|██████████| 5/5 [06:41<00:00, 80.24s/it]

Epoch: 4, Training Loss: 0.1098, Validation Loss: 0.1188, accuracy = 0.5540, F1=0.4901





# Text-Only

In [52]:
head = nn.Sequential(
    nn.Linear(768, 256),
    nn.ReLU(),
    nn.Linear(256, 3)
).to(device)

text_only = TextModel(tokenizer, embedder, head)

optimizer = torch.optim.Adam(text_only.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

train(text_only, optimizer, criterion, train_dataloader, val_dataloader, epochs=5, device=device)

 20%|██        | 1/5 [00:36<02:27, 36.80s/it]

Epoch: 0, Training Loss: 0.1243, Validation Loss: 0.1234, accuracy = 0.4833
0.2991773784160614


 40%|████      | 2/5 [01:14<01:52, 37.58s/it]

Epoch: 1, Training Loss: 0.1142, Validation Loss: 0.1184, accuracy = 0.5370
0.44944000244140625


 60%|██████    | 3/5 [01:53<01:15, 37.93s/it]

Epoch: 2, Training Loss: 0.1090, Validation Loss: 0.1165, accuracy = 0.5544
0.495005339384079


 80%|████████  | 4/5 [02:30<00:37, 37.54s/it]

Epoch: 3, Training Loss: 0.1064, Validation Loss: 0.1160, accuracy = 0.5638
0.511488676071167


100%|██████████| 5/5 [03:07<00:00, 37.43s/it]

Epoch: 4, Training Loss: 0.1052, Validation Loss: 0.1161, accuracy = 0.5667
0.514333188533783





# Unaligned Multimodal Modal

In [111]:
class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model: int, d_ffn: int, dropout: float = 0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ffn)
        self.w_2 = nn.Linear(d_ffn, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w_2(self.dropout(self.w_1(x).relu()))

class CrossModalAttentionBlock(nn.Module):
    def __init__(self, embedding_dim, d_ffn):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.d_ffn = d_ffn
        self.layer_norm = nn.LayerNorm(self.embedding_dim)
        self.mh_attention = nn.MultiheadAttention(self.embedding_dim, 4, 0.1, batch_first=True)
        self.pointwise_ff = PositionwiseFeedForward(self.embedding_dim, d_ffn=self.d_ffn)
    
    def forward(self, elem_a, elem_b, attn_mask):
        elem_a = self.layer_norm(elem_a)
        elem_b = self.layer_norm(elem_b)
        attn_mask = attn_mask.to(torch.float32)
        
        mh_out, _ = self.mh_attention(elem_a, elem_b, elem_b, key_padding_mask=attn_mask, need_weights=False)
        add_out = mh_out + elem_a
        
        add_out_norm = self.layer_norm(add_out)
        out_ffn = self.pointwise_ff(add_out_norm)
        out = out_ffn + add_out
        return out
    
class UnalignedMultimodalModel(nn.Module):
    def __init__(self, embedding_dim, d_ffn, n_blocks, head):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.d_ffn = d_ffn
        self.n_blocks = n_blocks
        self.head = head
        self.text_crossmodal_blocks = nn.ModuleList([
            CrossModalAttentionBlock(self.embedding_dim, self.d_ffn) for _ in range(self.n_blocks)
        ])
        self.audio_crossmodal_blocks = nn.ModuleList([
            CrossModalAttentionBlock(self.embedding_dim, self.d_ffn) for _ in range(self.n_blocks)
        ])
        self.pos_encoder = PositionalEncoding(embedding_dim, dual_modality=False)
    
    def forward(self, texts, audio_features, audio_attentions):
        tokenizer_output = tokenizer(texts, return_tensors='pt', padding=True, truncation=False).to(device)
        embedder_output = embedder(**tokenizer_output, output_hidden_states=True)
        text_features = embedder_output['hidden_states'][0]
        text_features = self.pos_encoder(text_features)
        text_attentions = tokenizer_output.attention_mask
        
        audio_features = self.pos_encoder(audio_features)
        
        text_crossmodal_out = text_features
        for cm_block in self.text_crossmodal_blocks:
            text_crossmodal_out = cm_block(text_crossmodal_out, audio_features, audio_attentions)
        
        audio_crossmodal_out = audio_features
        for cm_block in self.audio_crossmodal_blocks:
            audio_crossmodal_out = cm_block(audio_crossmodal_out, text_features, text_attentions)

        text_crossmodal_out_mean = torch.mean(text_crossmodal_out, dim=1)
        audio_crossmodal_out_mean = torch.mean(audio_crossmodal_out, dim=1)
        
        text_audio = torch.cat((text_crossmodal_out_mean, audio_crossmodal_out_mean), dim=-1)
        
        return self.head(text_audio)
        
        
head = nn.Sequential(
    nn.Linear(768*2, 256),
    nn.ReLU(),
    nn.Linear(256, 3)
).to(device)

unaligned_mm_model = UnalignedMultimodalModel(768, 100, 4, head).to(device)

optimizer = torch.optim.Adam(unaligned_mm_model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

train(unaligned_mm_model, optimizer, criterion, train_dataloader, val_dataloader, epochs=5, device=device)

 20%|██        | 1/5 [01:11<04:46, 71.75s/it]

Epoch: 0, Training Loss: 0.1258, Validation Loss: 0.1215, accuracy = 0.5279, F1=0.4592


 40%|████      | 2/5 [02:23<03:34, 71.65s/it]

Epoch: 1, Training Loss: 0.1098, Validation Loss: 0.1197, accuracy = 0.5558, F1=0.5068


 60%|██████    | 3/5 [03:34<02:23, 71.60s/it]

Epoch: 2, Training Loss: 0.1014, Validation Loss: 0.1253, accuracy = 0.5620, F1=0.5301


 80%|████████  | 4/5 [04:46<01:11, 71.61s/it]

Epoch: 3, Training Loss: 0.0957, Validation Loss: 0.1319, accuracy = 0.5566, F1=0.5191


100%|██████████| 5/5 [05:58<00:00, 71.64s/it]

Epoch: 4, Training Loss: 0.0897, Validation Loss: 0.1395, accuracy = 0.5580, F1=0.5252



