In [1]:
import torch
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def read_stored_data(stored_folder):
    audio_masks, text_masks, visual_masks, preds, targets = [], [], [], [], []
    for i in range(1, 23):
        visual_masks.append(torch.load(f"{stored_folder}/masks/mask_visual_batch_{i}.pt").cpu().numpy())
        audio_masks.append(torch.load(f"{stored_folder}/masks/mask_audio_batch_{i}.pt").cpu().numpy())
        text_masks.append(torch.load(f"{stored_folder}/masks/mask_text_batch_{i}.pt").cpu().numpy())
        preds.append(torch.load(f"{stored_folder}/preds/batch_{i}.pt").cpu().numpy())
        targets.append(torch.load(f"{stored_folder}/labels/batch_{i}.pt").cpu().numpy())

    return audio_masks, text_masks, visual_masks, preds, targets

In [3]:
def divide_samples(mask_list):
    all_elements = []
    for batch in mask_list:
        unique_elem_list = [batch[i] for i in range(batch.shape[0])]
        all_elements += unique_elem_list
    return unique_elem_list

In [13]:
audio_masks, text_masks, visual_masks, preds, targets = read_stored_data(stored_folder="../ablation_study/base")
print(len(audio_masks), len(visual_masks), len(text_masks), len(preds), len(targets))
for i in range(len(audio_masks)):
    print(audio_masks[i].shape[0] == text_masks[i].shape[0] == visual_masks[i].shape[0])

22 22 22 22 22
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True


In [15]:
for masks in audio_masks:
    print(masks.shape)

(32, 42, 74)
(32, 68, 74)
(32, 64, 74)
(32, 41, 74)
(32, 48, 74)
(32, 122, 74)
(32, 56, 74)
(32, 40, 74)
(32, 26, 74)
(32, 43, 74)
(32, 37, 74)
(32, 48, 74)
(32, 87, 74)
(32, 61, 74)
(32, 69, 74)
(32, 95, 74)
(32, 42, 74)
(32, 33, 74)
(32, 92, 74)
(32, 51, 74)
(32, 38, 74)
(14, 38, 74)


In [6]:
def get_unique_ordered_samples(batches):
    unique_samples = []
    for batch in batches:
        for sample in batch:
            unique_samples.append(sample)

    return unique_samples

In [7]:
audio_masks = get_unique_ordered_samples(audio_masks)
text_masks = get_unique_ordered_samples(text_masks)
visual_masks = get_unique_ordered_samples(visual_masks)
preds = [item for arr in preds for item in arr]
targets = [item for arr in targets for item in arr]
print(len(audio_masks), len(visual_masks), len(text_masks), len(preds), len(targets))

686 686 686 686 686


In [12]:
audio_masks[0].shape, audio_masks[31].shape

((42, 74), (42, 74))

In [23]:
def read_stored_data(stored_folder):
    audio_masks, text_masks, visual_masks, preds, targets = [], [], [], [], []
    for i in range(1, 23):
        visual_masks.append(torch.load(f"{stored_folder}/masks/mask_visual_batch_{i}.pt"))
        audio_masks.append(torch.load(f"{stored_folder}/masks/mask_audio_batch_{i}.pt"))
        text_masks.append(torch.load(f"{stored_folder}/masks/mask_text_batch_{i}.pt"))
        preds.append(torch.load(f"{stored_folder}/preds/batch_{i}.pt"))
        targets.append(torch.load(f"{stored_folder}/labels/batch_{i}.pt"))

    return audio_masks, text_masks, visual_masks, preds, targets

In [31]:
audio_masks, text_masks, visual_masks, preds, targets = read_stored_data(stored_folder="../ablation_study/base")
print(len(audio_masks), len(visual_masks), len(text_masks), len(preds), len(targets))
"""for i in range(len(audio_masks)):
    print(audio_masks[i].shape[0] == text_masks[i].shape[0] == visual_masks[i].shape[0])"""

22 22 22 22 22


'for i in range(len(audio_masks)):\n    print(audio_masks[i].shape[0] == text_masks[i].shape[0] == visual_masks[i].shape[0])'

In [32]:
# `masks` is a list of (batch_size, time, features) tensors, one per batch
all_audio_masks, all_visual_masks, all_text_masks = [], [], []
all_preds = []
all_labels = []

for batch_audio_mask, batch_text_mask, batch_vis_mask, preds, labels in zip(audio_masks, text_masks, visual_masks, preds, targets):
    # Collapse over time with mean (shape: batch_size × feature_dim)
    audio_time_reduced = batch_audio_mask.mean(dim=1)
    text_time_reduced = batch_text_mask.mean(dim=1)
    visual_time_reduced = batch_vis_mask.mean(dim=1)
    
    all_audio_masks.append(audio_time_reduced.cpu().numpy())
    all_text_masks.append(text_time_reduced.cpu().numpy())
    all_visual_masks.append(visual_time_reduced.cpu().numpy())
    all_preds.append(preds.cpu().numpy())
    all_labels.append(labels.cpu().numpy())

# Concatenate across all batches
all_audio_masks = np.concatenate(all_audio_masks, axis=0)   # (N, 74)
all_text_masks = np.concatenate(all_text_masks, axis=0)   # (N, 74)
all_visual_masks = np.concatenate(all_visual_masks, axis=0)   # (N, 74)
all_preds = np.array([item for arr in all_preds for item in arr])
all_labels = np.array([item for arr in all_labels for item in arr])
#all_preds = np.concatenate(all_preds, axis=0)   # (N,)
#all_labels = np.concatenate(all_labels, axis=0) # (N,)

all_audio_masks.shape, all_text_masks.shape, all_visual_masks.shape, all_preds.shape, all_labels.shape

((686, 74), (686, 300), (686, 35), (686,), (686,))