In [2]:
import numpy as np
import pandas as pd
import transformers
from transformers import BertTokenizer, BertModel, AutoModel, AutoProcessor
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch, torchaudio, torchtext
import torch.nn as nn
import os
import gc
import pickle
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')



Using device: cuda


In [4]:
try:
    df_path = '/kaggle/input/MM-USElecDeb60to16/MM-USElecDeb60to16.csv'
    audio_path = '/kaggle/input/MM-USElecDeb60to16/audio_clips'
    save_path = '/kaggle/working/'
    df = pd.read_csv(df_path, index_col=0)
except FileNotFoundError:
    df_path = 'multimodal-dataset/files/MM-USElecDeb60to16/MM-USElecDeb60to16.csv'
    audio_path = 'multimodal-dataset/files/MM-USElecDeb60to16/audio_clips'
    save_path = 'multimodal-dataset/files'
    df = pd.read_csv(df_path, index_col=0)
# drop rows where audio length is 0
df = df[df['NewBegin'] != df['NewEnd']]

train_df_complete = df[df['Set'] == 'TRAIN']
val_df_complete = df[df['Set'] == 'VALIDATION']
test_df_complete = df[df['Set'] == 'TEST']

DATASET_RATIO = 0.40
# NUM_SAMPLES = 3000
# indexes = np.random.choice(list(range(len(train_df_complete))), NUM_SAMPLES, replace=False)
# train_idx = indexes[:int(0.7*NUM_SAMPLES)]
# val_idx = indexes[int(0.7*NUM_SAMPLES):int(0.8*NUM_SAMPLES)]
# test_idx = indexes[int(0.8*NUM_SAMPLES):]

train_df = train_df_complete.iloc[:int(DATASET_RATIO * len(train_df_complete))]
val_df = val_df_complete.iloc[:int(DATASET_RATIO * len(val_df_complete))]
test_df = test_df_complete.iloc[:int(DATASET_RATIO * len(test_df_complete))]

In [5]:
train_df.head()

Unnamed: 0,Text,Part,Document,Order,Sentence,Start,End,Annotator,Tag,Component,...,Speaker,SpeakerType,Set,Date,Year,Name,MainTag,NewBegin,NewEnd,idClip
0,"CHENEY: Gwen, I want to thank you, and I want ...",1,30_2004,0,0,2101,2221,,"{""O"": 27}",O,...,CHENEY,Candidate,TRAIN,05 Oct 2004,2004,Richard(Dick) B. Cheney,O,126.52,131.08,clip_0
1,"It's a very important event, and they've done ...",1,30_2004,1,1,2221,2304,,"{""O"": 19}",O,...,CHENEY,Candidate,TRAIN,05 Oct 2004,2004,Richard(Dick) B. Cheney,O,131.08,134.4,clip_1
2,It's important to look at all of our developme...,1,30_2004,2,2,2304,2418,,"{""O"": 23}",O,...,CHENEY,Candidate,TRAIN,05 Oct 2004,2004,Richard(Dick) B. Cheney,O,134.4,140.56,clip_2
3,"And, after 9/11, it became clear that we had t...",1,30_2004,3,3,2418,2744,,"{""O"": 16, ""Claim"": 50}",Claim,...,CHENEY,Candidate,TRAIN,05 Oct 2004,2004,Richard(Dick) B. Cheney,Claim,140.56,158.92,clip_3
4,And we also then finally had to stand up democ...,1,30_2004,4,4,2744,2974,,"{""O"": 4, ""Claim"": 13, ""Premise"": 25}",Premise,...,CHENEY,Candidate,TRAIN,05 Oct 2004,2004,Richard(Dick) B. Cheney,Mixed,158.92,172.92,clip_4


In [6]:
len(train_df), len(test_df), len(val_df)

(4967, 2986, 2758)

In [8]:
class BestModel:
    """
        Class to keep track of the best performing model on validation set during training
    """
    def __init__(self):
        self.best_validation_loss = float('Infinity')
        self.best_state_dict = None
    def __call__(self, model, loss):
        if loss < self.best_validation_loss:
            self.best_validation_loss = loss
            self.best_state_dict = model.state_dict()

def train(model, optimizer, loss_fn, train_loader, val_loader, epochs=10, device="cuda"):
    best_model_tracker = BestModel()
    for epoch in tqdm(range(epochs)):
        training_loss = 0.0
        valid_loss = 0.0
        model.train()

        for batch in train_loader:
            optimizer.zero_grad()
            #text_features, text_attention, audio_features, audio_attention, targets = batch
            texts, audio_features, audio_attention, targets = batch
            #text_features = text_features.to(device)
            #text_attention = text_attention.to(device)
            audio_features = audio_features.to(device)
            audio_attention = audio_attention.to(device)
            targets = targets.to(device)
            #output = model(text_features,text_attention,audio_features,audio_attention)
            output = model(texts,audio_features,audio_attention)
            loss = loss_fn(output, targets)
            loss.backward()
            optimizer.step()
            training_loss += loss.detach()
        training_loss = training_loss.cpu().item()
        training_loss /= len(train_loader.dataset)

        model.eval()
        num_correct = 0 
        num_examples = 0
        for batch in val_loader:
            #text_features, text_attention, audio_features, audio_attention, targets = batch
            texts, audio_features, audio_attention, targets = batch
            #text_features = text_features.to(device)
            #text_attention = text_attention.to(device)
            audio_features = audio_features.to(device)
            audio_attention = audio_attention.to(device)
            targets = targets.to(device)
            #output = model(text_features,text_attention,audio_features,audio_attention)
            output = model(texts,audio_features,audio_attention)
            loss = loss_fn(output, targets)
            valid_loss += loss.detach()
            predicted_labels = torch.argmax(output, dim=-1)
            correct = torch.eq(predicted_labels, targets).view(-1)
            num_correct += torch.sum(correct).item()
            num_examples += correct.shape[0]
        best_model_tracker(model, valid_loss)
        valid_loss = valid_loss.cpu().item()
        valid_loss /= len(val_loader.dataset)
        print('Epoch: {}, Training Loss: {:.2f}, Validation Loss: {:.2f}, accuracy = {:.2f}'.format(epoch, training_loss, valid_loss, num_correct/num_examples))
    model.load_state_dict(best_model_tracker.best_state_dict)

In [9]:
text_model_card = 'bert-base-uncased'
audio_model_card = 'facebook/wav2vec2-base-960h'

tokenizer = BertTokenizer.from_pretrained(text_model_card)
embedder = BertModel.from_pretrained(text_model_card).to(device)

for params in embedder.parameters():
    params.requires_grad = False

label_2_id = {
    'Claim': 0,
    'Premise': 1,
    'O': 2
}

DOWNSAMPLE_FACTOR = 1/5

class MM_Dataset(torch.utils.data.Dataset):
    def __init__(self, df, audio_dir, sample_rate):
        self.audio_dir = audio_dir
        self.sample_rate = sample_rate

        self.audio_processor = AutoProcessor.from_pretrained(audio_model_card)
        self.audio_model = AutoModel.from_pretrained(audio_model_card).to(device)
        
        # self.text_tokenizer = BertTokenizer.from_pretrained(text_model_card)
        # self.text_model = BertModel.from_pretrained(text_model_card).to(device)
        self.dataset = []

        # Iterate over df
        for _, row in tqdm(df.iterrows()):
            path = os.path.join(self.audio_dir, f"{row['Document']}/{row['idClip']}.wav")
            if os.path.exists(path):
                # obtain audio WAV2VEC features
                audio, sampling_rate = torchaudio.load(path)
                if sampling_rate != self.sample_rate:
                    audio = torchaudio.functional.resample(audio, sample_rate, self.sample_rate)
                    audio = torch.mean(audio, dim=0, keepdim=True)
                with torch.inference_mode():
                    input_values = self.audio_processor(audio, sampling_rate=self.sample_rate).input_values[0]
                    input_values = torch.tensor(input_values).to(device)
                    audio_model_output = self.audio_model(input_values)
                    audio_features = audio_model_output.last_hidden_state[0].unsqueeze(0)
#                     audio_features = torch.nn.AvgPool1d(kernel_size=DOWNSAMPLE_FACTOR,stride=DOWNSAMPLE_FACTOR,)
                    audio_features = torch.nn.functional.interpolate(audio_features.permute(0,2,1), scale_factor=DOWNSAMPLE_FACTOR, mode='linear')
                    audio_features = audio_features.permute(0,2,1)[0]
                    audio_features = audio_features.cpu()
                # obtain text BERT features
                text = row['Text']

                # token_text = self.text_tokenizer(text, return_tensors='pt').to(device)
                # with torch.inference_mode():
                #     text_features = self.text_model(**token_text)
                #     text_features = text_features
                #     text_features = text_features.cpu()
                # del token_text
                # gc.collect()

                self.dataset.append((text, audio_features, label_2_id[row['Component']]))
    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        return self.dataset[index]

tokenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

In [10]:
try:
    train_dataset = torch.load(f'{save_path}/train_dataset.pkl')
    test_dataset = torch.load(f'{save_path}/test_dataset.pkl')
    val_dataset = torch.load(f'{save_path}/val_dataset.pkl')
    print('Restored datasets from memory')
except:
    print('Creating new datasets')
    train_dataset = MM_Dataset(train_df, audio_path, 16_000)
    test_dataset = MM_Dataset(test_df, audio_path, 16_000)
    val_dataset = MM_Dataset(val_df, audio_path, 16_000)
    torch.save(train_dataset, f'{save_path}/train_dataset.pkl')
    torch.save(test_dataset, f'{save_path}/test_dataset.pkl')
    torch.save(val_dataset, f'{save_path}/val_dataset.pkl')

Creating new datasets


preprocessor_config.json:   0%|          | 0.00/159 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/163 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.60k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/291 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/85.0 [00:00<?, ?B/s]

Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize


model.safetensors:   0%|          | 0.00/378M [00:00<?, ?B/s]

Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
4967it [08:13, 10.07it/s]


Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize


Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
2986it [04:22, 11.39it/s]


Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize


Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
2758it [05:11,  8.84it/s]


In [11]:
def create_dataloader(dataset, batch_size):
    def pack_fn(batch):
        texts = [x[0] for x in batch]
        audio_features = [x[1] for x in batch]
        labels = torch.tensor([x[2] for x in batch])

        # tokenize text
        # tokenizer_output = tokenizer(texts, return_tensors='pt', padding=True, truncation=False).to(device)
        # embed text
        # embedder_output = embedder(**tokenizer_output, output_hidden_states=True)

        # text_features = embedder_output['hidden_states'][0]
        
        # pad audio features
        audio_features = pad_sequence(audio_features, batch_first=True, padding_value=float('-inf'))

        audio_features_attention_mask = audio_features[:, :, 0] != float('-inf')
        
        audio_features[(audio_features == float('-inf'))] = 0

        #return text_features, tokenizer_output.attention_mask, audio_features, audio_features_attention_mask, labels
        return texts, audio_features, audio_features_attention_mask, labels

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=pack_fn)
    return dataloader

In [12]:
train_dataloader = create_dataloader(train_dataset, 8)
val_dataloader = create_dataloader(val_dataset, 8)
test_dataloader = create_dataloader(test_dataset, 8)

In [13]:
#del early_fusion
gc.collect()

48

In [15]:
class EarlyFusion(nn.Module):
    def __init__(self, tokenizer, embedder, transformer, head):
        super().__init__()
        self.tokenizer = tokenizer
        self.embedder = embedder
        self.transformer = transformer
        self.head = head

    #def forward(self, text_features, text_attentions, audio_features, audio_attentions):
    def forward(self, texts, audio_features, audio_attentions):
        tokenizer_output = self.tokenizer(texts, return_tensors='pt', padding=True, truncation=False).to(device)
        embedder_output = self.embedder(**tokenizer_output, output_hidden_states=True)
        text_features = embedder_output['hidden_states'][0]
        text_attentions = tokenizer_output.attention_mask
        
        concatenated_features = torch.cat((text_features, audio_features), dim=1)
        concatenated_attentions = torch.cat((text_attentions, audio_attentions.float()), dim=1)
        
        # padding mask is 1 where there is padding (i.e. where attention is 0) and 0 otherwise
        concatenated_padding_mask = ~concatenated_attentions.to(torch.bool)
        
        # compute a full attention mask of size [seq_len, seq_len]
        full_attention_mask = torch.zeros((concatenated_features.shape[1], concatenated_features.shape[1]), dtype=torch.bool).to(device)
                
        transformer_output = self.transformer(src=concatenated_features,  mask=full_attention_mask, src_key_padding_mask=concatenated_padding_mask)
        transformer_output_sum = (transformer_output * concatenated_attentions.unsqueeze(-1)).sum(axis=1)
        transformer_output_pooled = transformer_output_sum / concatenated_attentions.sum(axis=1).unsqueeze(-1)
        return self.head(transformer_output_pooled)

transformer_layer = nn.TransformerEncoderLayer(d_model=768, nhead=4, dim_feedforward=512, batch_first=True).to(device)
transformer_encoder = nn.TransformerEncoder(transformer_layer, num_layers=4).to(device)

head = nn.Sequential(
    nn.Linear(768, 256),
    nn.ReLU(),
    nn.Linear(256, 3)
).to(device)

early_fusion = EarlyFusion(tokenizer, embedder, transformer_encoder, head).to(device)

optimizer = torch.optim.Adam(early_fusion.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

train(early_fusion, optimizer, criterion, train_dataloader, val_dataloader, epochs=10, device=device)

 10%|█         | 1/10 [01:22<12:26, 82.94s/it]

Epoch: 0, Training Loss: 0.12, Validation Loss: 0.12, accuracy = 0.53


 20%|██        | 2/10 [02:45<11:02, 82.82s/it]

Epoch: 1, Training Loss: 0.11, Validation Loss: 0.12, accuracy = 0.54


 30%|███       | 3/10 [04:08<09:39, 82.73s/it]

Epoch: 2, Training Loss: 0.10, Validation Loss: 0.13, accuracy = 0.54


 40%|████      | 4/10 [05:30<08:16, 82.71s/it]

Epoch: 3, Training Loss: 0.09, Validation Loss: 0.14, accuracy = 0.53


 50%|█████     | 5/10 [06:53<06:53, 82.70s/it]

Epoch: 4, Training Loss: 0.09, Validation Loss: 0.16, accuracy = 0.54


 60%|██████    | 6/10 [08:16<05:30, 82.70s/it]

Epoch: 5, Training Loss: 0.08, Validation Loss: 0.16, accuracy = 0.54


 70%|███████   | 7/10 [09:39<04:08, 82.71s/it]

Epoch: 6, Training Loss: 0.07, Validation Loss: 0.17, accuracy = 0.54


 80%|████████  | 8/10 [11:01<02:45, 82.68s/it]

Epoch: 7, Training Loss: 0.07, Validation Loss: 0.17, accuracy = 0.55


 90%|█████████ | 9/10 [12:24<01:22, 82.68s/it]

Epoch: 8, Training Loss: 0.06, Validation Loss: 0.19, accuracy = 0.54


100%|██████████| 10/10 [13:47<00:00, 82.72s/it]

Epoch: 9, Training Loss: 0.06, Validation Loss: 0.18, accuracy = 0.54





In [44]:
class TextModel(nn.Module):
    def __init__(self, tokenizer, embedder, head):
        super().__init__()
        self.tokenizer = tokenizer
        self.embedder = embedder
        self.head = head
    def forward(self, texts):
        tokenizer_output = self.tokenizer(texts, return_tensors='pt', padding=True, truncation=False).to(device)
        embedder_output = self.embedder(**tokenizer_output, output_hidden_states=True)
        text_features = embedder_output['last_hidden_state']
        
        text_features_sum = (text_features * tokenizer_output.attention_mask.unsqueeze(-1)).sum(axis=1)
        text_features_pooled = text_features_sum / tokenizer_output.attention_mask.sum(axis=1).unsqueeze(-1)
        return self.head(text_features_pooled)
    
class AudioModel(nn.Module):        
    def __init__(self, transformer, head):
        super().__init__()
        self.transformer = transformer
        self.head = head
        
    def forward(self, audio_features, audio_attention):
        padding_mask = ~audio_attention.to(torch.bool)
        full_attention_mask = torch.zeros((audio_features.shape[1],audio_features.shape[1]), dtype=torch.bool).to(device)
        transformer_output = self.transformer(src=audio_features, mask=full_attention_mask, src_key_padding_mask=padding_mask)
        
        # pooling transformer output
        transformer_output_sum = (transformer_output * audio_attention.unsqueeze(-1)).sum(axis=1)
        transformer_output_pooled = transformer_output_sum / audio_attention.sum(axis=1).unsqueeze(-1)
        return self.head(transformer_output_pooled)
    
class EnsemblingFusion(nn.Module):
    def __init__(self, text_model, audio_model):
        super().__init__()
        self.text_model = text_model
        self.audio_model = audio_model
        self.weight = torch.nn.Parameter(torch.tensor(0.0))
        
    def forward(self, texts, audio_features, audio_attentions):
        text_logits = self.text_model(texts)
        audio_logits = self.audio_model(audio_features, audio_attentions)
        
        text_probabilities = torch.nn.functional.softmax(text_logits)
        audio_probabilities = torch.nn.functional.softmax(audio_logits)
        
        coefficient = (torch.tanh(self.weight) + 1) / 2
        
        return coefficient*text_probabilities + (1-coefficient)*audio_probabilities
    
text_head = nn.Sequential(
    nn.Linear(768, 256),
    nn.ReLU(),
    nn.Linear(256, 3)
).to(device)

audio_head = nn.Sequential(
    nn.Linear(768, 256),
    nn.ReLU(),
    nn.Linear(256, 3)
).to(device)

transformer_layer = nn.TransformerEncoderLayer(d_model=768, nhead=4, dim_feedforward=512, batch_first=True).to(device)
transformer_encoder = nn.TransformerEncoder(transformer_layer, num_layers=4).to(device)

text_model = TextModel(tokenizer, embedder, text_head)
audio_model = AudioModel(transformer_encoder, audio_head)

ensembling_fusion = EnsemblingFusion(text_model, audio_model)

optimizer = torch.optim.Adam(early_fusion.parameters(), lr=1e-4)

def custom_loss(outputs, targets):
    return torch.nn.functional.nll_loss(torch.log(outputs), targets, reduction='mean')

train(ensembling_fusion, optimizer, custom_loss, train_dataloader, val_dataloader, epochs=10, device=device)

 10%|█         | 1/10 [00:23<03:35, 23.99s/it]

Epoch: 0, Training Loss: 0.14, Validation Loss: 0.14, accuracy = 0.35


 20%|██        | 2/10 [00:47<03:11, 23.99s/it]

Epoch: 1, Training Loss: 0.14, Validation Loss: 0.14, accuracy = 0.35


 30%|███       | 3/10 [01:11<02:47, 23.96s/it]

Epoch: 2, Training Loss: 0.14, Validation Loss: 0.14, accuracy = 0.35


 30%|███       | 3/10 [01:24<03:17, 28.23s/it]


KeyboardInterrupt: 