In [1]:
%rm -rf /dev/shm/whisper/data

In [2]:
import os
import random
import time
from load_dataset import load_data, load_dataset
from transformers import WhisperTokenizer


class Args:
    """Plain attribute container handed to ``load_data`` and ``load_dataset``.

    Only the two fixed settings are created here; the driver script attaches
    further attributes (``dataset``, ``data_dir``) to the instance afterwards.
    """

    def __init__(self):
        # NOTE(review): presumably the decoder token budget and the number of
        # context utterances -- confirm against load_dataset.
        vars(self).update(max_seq_length=448, context_size=5)


# Shared settings object; per-dataset attributes are attached to it later on.
args = Args()

# Tokenizer passed to load_dataset and used to decode sample token ids.
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small")

# Test matrix: every dataset is checked for every mode and every strategy.
datasets_to_test = [
    "Ego4D",
    "Youtube",
]
modes_to_test = [
    "train",
    "val",
    "test",
]
strategies = [
    "Identity Declaration",
    "Accusation",
    "Interrogation",
    "Call for Action",
    "Defense",
    "Evidence",
]

print("\n## Testing All Dataset and Mode Combinations")

# positive_percentage[dataset][mode][strategy] -> fraction of samples whose
# label is 1. A combination whose dataset is empty gets no entry, because the
# fraction is undefined for zero samples.
positive_percentage = {}
for dataset_name in datasets_to_test:
    args.dataset = dataset_name
    args.data_dir = f"/dev/shm/whisper/data/{args.dataset}"
    os.makedirs(args.data_dir, exist_ok=True)

    # Time the raw data load of each split before building any
    # strategy-specific dataset.
    for split in ("train", "val", "test"):
        start_time = time.time()
        load_data(args, split)
        print(f"Time to load {split} data: {time.time() - start_time:.2f} seconds")

    positive_percentage[dataset_name] = {}
    for mode_name in modes_to_test:
        positive_percentage[dataset_name][mode_name] = {}
        for strategy in strategies:
            print(f"\nTesting dataset: {dataset_name}, mode: {mode_name}, strategy: {strategy}")

            # Already set above; re-assigned in case a loader modified args.
            args.dataset = dataset_name

            start_time = time.time()
            test_dataset = load_dataset(args, strategy, tokenizer, mode_name)
            load_time = time.time() - start_time

            num_samples = len(test_dataset)

            # Count both label values in a single pass over the dataset.
            zeros = 0
            ones = 0
            for sample in test_dataset:
                label = sample['labels']
                if label == 0:
                    zeros += 1
                elif label == 1:
                    ones += 1

            print(f"Dataset size: {num_samples} samples")
            print(f"  Number of 0 labels: {zeros}")
            print(f"  Number of 1 labels: {ones}")
            print(f"  Time taken to load dataset: {load_time:.2f} seconds")

            # An empty dataset has no positive fraction and no sample to show;
            # skipping it avoids a ZeroDivisionError.
            if num_samples > 0:
                positive_percentage[dataset_name][mode_name][strategy] = ones / num_samples

                sample = test_dataset[random.randint(0, num_samples - 1)]
                print("Sample:")
                print(f"  Label: {sample['labels']}")
                print(f"  Decoder input ids shape: {len(sample['decoder_input_ids'])}")
                print(f"  Decoder attention mask shape: {len(sample['decoder_attention_mask'])}")
                print(f"  Audio path: {sample['audio_path']}")
                print(f"  Start sample: {sample['start_sample']}")
                print(f"  End sample: {sample['end_sample']}")
                print(f"  Decoded text: {tokenizer.decode(sample['decoder_input_ids'])}")


# Summary table: one row per (dataset, mode), one column per strategy.
print("\n## Positive Percentage Summary (% of samples with label 1)")
header_cells = " ".join(f"{strategy[:10]:<12}" for strategy in strategies)
print(f"{'Dataset':<10} {'Mode':<6} " + header_cells)

for dataset_name in datasets_to_test:
    for mode_name in modes_to_test:
        per_strategy = positive_percentage[dataset_name][mode_name]
        cells = [f"{dataset_name:<10} {mode_name:<6} "]
        for strategy in strategies:
            fraction = per_strategy.get(strategy)
            if fraction is None:
                # No recorded value for this strategy. Pad to the same
                # 13-character width as a percentage cell so that later
                # columns stay aligned.
                cells.append(f"{'N/A':>11}  ")
            else:
                cells.append(f"{fraction * 100:>10.2f}%  ")
        print("".join(cells))

  from .autonotebook import tqdm as notebook_tqdm



## Testing All Dataset and Mode Combinations
Time to load train data: 76.62 seconds
Time to load val data: 13.61 seconds
Time to load test data: 19.05 seconds

Testing dataset: Ego4D, mode: train, strategy: Identity Declaration
Dataset size: 3746 samples
  Number of 0 labels: 3533
  Number of 1 labels: 213
  Time taken to load dataset: 2.26 seconds
Sample:
  Label: 0
  Decoder input ids shape: 448
  Decoder attention mask shape: 448
  Audio path: /dev/shm/whisper/data/Ego4D/audios/train/0c2659db-7bd4-4b37-9b08-4e247befe382_Game7.npy
  Start sample: 4032000
  End sample: 4384000
  Decoded text: <|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoft