# Dataloading benchmark

In [6]:
!mkdir -p /tmp/zverev/vggsound 2>/dev/null && \
squashfuse /storage/local/vggsound.squashfs /tmp/zverev/vggsound 
! METADATA_DIR=/tmp/zverev && \
MOUNTPOINT=$METADATA_DIR/vggsound && \
tr_data=/storage/slurm/zverev/datasets/cav-mae/vggsound/metadata/vgg_train_cleaned.json && \
te_data=/storage/slurm/zverev/datasets/cav-mae/vggsound/metadata/vgg_test_cleaned.json && \
w_data=/storage/slurm/zverev/datasets/cav-mae/vggsound/metadata/vgg_train_cleaned_weight.csv && \
cp $tr_data $METADATA_DIR/ && \
cp $te_data $METADATA_DIR/ && \
cp $w_data $METADATA_DIR/ && \
sed -i 's|/storage/slurm/zverev/datasets/cav-mae|'$MOUNTPOINT'|g' $METADATA_DIR/vgg_train_cleaned.json && \
sed -i 's|/storage/slurm/zverev/datasets/cav-mae|'$MOUNTPOINT'|g' $METADATA_DIR/vgg_test_cleaned.json

In [5]:
!fusermount -u /tmp/zverev/vggsound

In [2]:
import torch
import src.dataloader_ft as dataloader
import src.utils as utils
import os

# Fill args from run_videoonly.slurm
args = type('Args', (), {})()
args.data_train = '/tmp/zverev/vgg_train_cleaned.json'
args.label_csv = '/storage/slurm/zverev/datasets/cav-mae/vggsound/metadata/class_labels_indices_vgg.csv'
args.batch_size = 32
args.num_workers = 1
args.sql_path = '/home/wiss/zverev/AVSiam/artefacts/sql'
args.video_path_prefix = '/tmp/zverev/vggsound/video'
# Fill remaining args from run_videoonly.slurm
args.target_length = 1024
args.freqm = 48
args.timem = 192
args.mixup = 0.5
args.dataset = 'vggsound'
args.dataset_mean = -5.081
args.dataset_std = 4.4849
args.noise = True
args.label_smooth = 0.1
im_res = 224  # Standard ViT resolution

args.world_size = 1
args.local_rank = 0
args.dist_url = 'env://'

# Audio config defined from args
audio_conf = {'num_mel_bins': 128, 'target_length': args.target_length, 'freqm': args.freqm, 'timem': args.timem, 'mixup': args.mixup,
              'dataset': args.dataset, 'mode':'train', 'mean':args.dataset_mean, 'std':args.dataset_std,
              'noise':args.noise, 'label_smooth': args.label_smooth, 'im_res': im_res}

os.environ['RANK'] = '0'
os.environ['WORLD_SIZE'] = '1'
os.environ['LOCAL_RANK'] = '0'
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'

if not torch.distributed.is_initialized():
    utils.init_distributed_mode(args)

train_sampler = torch.utils.data.distributed.DistributedSampler(dataloader.AudiosetDataset(args.data_train, label_csv=args.label_csv, audio_conf=audio_conf, sql_path=args.sql_path, video_path_prefix=args.video_path_prefix),shuffle=True)
train_dataset = dataloader.AudiosetDataset(args.data_train, label_csv=args.label_csv, audio_conf=audio_conf, video_path_prefix=args.video_path_prefix)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.batch_size, 
    sampler=train_sampler, 
    shuffle=False, 
    num_workers=args.num_workers, 
    pin_memory=True, 
    drop_last=True,
    persistent_workers=True,
)

| distributed init (rank 0): env://, gpu 0
Using Label Smoothing: 0.1
now using following mask: 48 freq, 192 time
now using mix-up with rate 0.500000
now process vggsound
use dataset mean -5.081 and std 4.485 to normalize the input.
now use noise augmentation
number of classes is 309
now in train mode.
now use frame -1 from total 10 frames
now using 224 * 224 image input
Dataset has 183727 samples
Using Label Smoothing: 0.1
now using following mask: 48 freq, 192 time
now using mix-up with rate 0.500000
now process vggsound
use dataset mean -5.081 and std 4.485 to normalize the input.
now use noise augmentation
number of classes is 309
now in train mode.
now use frame -1 from total 10 frames
now using 224 * 224 image input
Dataset has 183727 samples


In [3]:
# !pip install line_profiler
%load_ext line_profiler

In [7]:
from tqdm.auto import trange
import random
import numpy as np
import torchvision
import torch
import torchaudio
import time

def get_datum(self, index):
    query = f"SELECT * FROM annos WHERE id = {index};"
    res = self.cur.execute(query)
    datum = res.fetchone()

    return datum

def get_item(self, index):
    start_time = time.time()

    # SQL query timing
    sql_start = time.time()
    query = f"SELECT * FROM annos WHERE id = {index};"
    res = self.cur.execute(query)
    datum = res.fetchone()
    sql_time = time.time() - sql_start

    # Path construction timing
    path_start = time.time()
    if self.dataset == 'vggsound':
        video_path = os.path.join(self.video_path_prefix, datum[1] + '.mp4')
        mix_sample_idx = random.randint(0, self.num_samples-1)
        query = f"SELECT * FROM annos WHERE id = {mix_sample_idx};"
        res = self.cur.execute(query)
        mix_datum = res.fetchone()
        video_path_mix = os.path.join(self.video_path_prefix, mix_datum[1] + '.mp4')

    elif self.dataset == 'audioset_20k':
        if self.audio_conf['mode'] == 'train':
            video_path = '/mnt/opr/yblin/audioset_sun/train_balanced/'+datum[1]+'.mp4'
            mix_sample_idx = random.randint(0, self.num_samples-1)
            query = f"SELECT * FROM annos WHERE id = {mix_sample_idx};"
            res = self.cur.execute(query)
            mix_datum = res.fetchone()
            video_path_mix = os.path.join(self.video_path_prefix, mix_datum[1] + '.mp4')
        else:
            video_path = os.path.join(self.video_path_prefix, datum[1] + '.mp4')
            mix_sample_idx = random.randint(0, self.num_samples-1)
            query = f"SELECT * FROM annos WHERE id = {mix_sample_idx};"
            res = self.cur.execute(query)
            mix_datum = res.fetchone()
            video_path_mix = os.path.join(self.video_path_prefix, mix_datum[1] + '.mp4')

    elif self.dataset == 'audioset_2m':
        if self.audio_conf['mode'] == 'train':
            video_path = os.path.join(self.video_path_prefix, datum[1] + '.mp4')
            mix_sample_idx = random.randint(0, self.num_samples-1)
            query = f"SELECT * FROM annos WHERE id = {mix_sample_idx};"
            res = self.cur.execute(query)
            mix_datum = res.fetchone()
            video_path_mix = os.path.join(self.video_path_prefix, mix_datum[1] + '.mp4')
        else:
            video_path = os.path.join(self.video_path_prefix, datum[1] + '.mp4')
            mix_sample_idx = random.randint(0, self.num_samples-1)
            query = f"SELECT * FROM annos WHERE id = {mix_sample_idx};"
            res = self.cur.execute(query)
            mix_datum = res.fetchone()
            video_path_mix = os.path.join(self.video_path_prefix, mix_datum[1] + '.mp4')
    path_time = time.time() - path_start

    # Audio/video processing timing
    if random.random() < self.mixup:
        # get the mixed fbank
        mix_lambda = np.random.beta(10, 10)

        try:
            fbank = self._wav2fbank(video_path, video_path_mix, mix_lambda)
        except Exception as e:
            print(e)
            fbank = torch.zeros([self.target_length, 128]) + 0.01
            print('there is an error in loading audio 1', datum[1],mix_datum[1])

        try:
            reader = torchvision.io.VideoReader(video_path, "video")
            frames = []
            for frame in reader:
                frames.append(frame['data'].unsqueeze(0))

            gg = torch.vstack(frames)
            image = gg[np.linspace(random.randint(0,5), len(frames)-1, num=self.num_frame, dtype=int)]
            image = image/255
            image = self.my_normalize(image)

            #### mixing ###
            reader = torchvision.io.VideoReader(video_path_mix, "video")
            frames = []
            read_start = time.time()
            for frame in reader:
                frames.append(frame['data'].unsqueeze(0))

            gg = torch.vstack(frames)
            image2 = gg[np.linspace(random.randint(0,5), len(frames)-1, num=self.num_frame, dtype=int)]
            image2 = image2/255
            image2 = self.my_normalize(image2)

            weight = random.random()
            image = weight * image + (1-weight)*image2
            image = image[random.randint(0,9)].unsqueeze(0)

        except Exception as e:
            print(e)
            image = torch.zeros([1, 3, self.im_res, self.im_res]) + 0.01
            print('there is an error in loading image 1', video_path, video_path_mix)

        label_indices = np.zeros(self.label_num) + (self.label_smooth / self.label_num)
        for label_str in datum[-1].split(','):
            label_indices[int(self.index_dict[label_str])] += mix_lambda * (1.0 - self.label_smooth)
        for label_str in mix_datum[-1].split(','):
            label_indices[int(self.index_dict[label_str])] += (1.0 - mix_lambda) * (1.0 - self.label_smooth)
        label_indices = torch.FloatTensor(label_indices)

    else:
        label_indices = np.zeros(self.label_num) + (self.label_smooth / self.label_num)
        try:
            torchaudio.load(video_path)
            fbank = self._wav2fbank(video_path, None, 0)
        except:
            fbank = torch.zeros([self.target_length, 128]) + 0.01
            print('there is an error in loading audio 2', datum)
        try:
            reader = torchvision.io.VideoReader(video_path, "video")
            frames = []
            for frame in reader:
                frames.append(frame['data'].unsqueeze(0))

            gg = torch.vstack(frames)
            image = gg[np.linspace(random.randint(0,30), len(frames)-1, num=self.num_frame, dtype=int)]
            image = image/255
            image = self.my_normalize(image)

            if self.mode =='eval':
                pass
            else:
                image = image[random.randint(0,9)].unsqueeze(0)
        except Exception as e:
            print(e)
            if self.mode =='eval':
                image = torch.zeros([10, 3, self.im_res, self.im_res]) + 0.01
            else:
                image = torch.zeros([1, 3, self.im_res, self.im_res]) + 0.01
            print('there is an error in loading image 2', video_path)

        for label_str in datum[-1].split(','):
            label_indices[int(self.index_dict[label_str])] = 1.0 - self.label_smooth
        label_indices = torch.FloatTensor(label_indices)

    # SpecAug timing
    freqm = torchaudio.transforms.FrequencyMasking(self.freqm)
    timem = torchaudio.transforms.TimeMasking(self.timem)
    fbank = torch.transpose(fbank, 0, 1)
    fbank = fbank.unsqueeze(0)
    if self.freqm != 0:
        fbank = freqm(fbank)
    if self.timem != 0:
        fbank = timem(fbank)
    fbank = fbank.squeeze(0)
    fbank = torch.transpose(fbank, 0, 1)

    if self.skip_norm == False:
        fbank = (fbank - self.norm_mean) / (self.norm_std)
    else:
        pass

    if self.noise == True:
        fbank = fbank + torch.rand(fbank.shape[0], fbank.shape[1]) * np.random.rand() / 10
        fbank = torch.roll(fbank, np.random.randint(-self.target_length, self.target_length), 0)
    print(image.shape)
    return fbank, image, label_indices

def data_loading_benchmark():
    avg_time_spent = 0
    count = 0
    
    for index in trange(1):
        get_item(train_dataset, index)

%lprun -f data_loading_benchmark data_loading_benchmark()

  0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████| 1/1 [00:00<00:00,  1.27it/s]

torch.Size([1, 3, 224, 224])





Timer unit: 1e-09 s

Total time: 0.788445 s
File: /tmp/ipykernel_424865/3565145512.py
Function: data_loading_benchmark at line 178

Line #      Hits         Time  Per Hit   % Time  Line Contents
   178                                           def data_loading_benchmark():
   179         1        475.0    475.0      0.0      avg_time_spent = 0
   180         1        116.0    116.0      0.0      count = 0
   181                                               
   182         2    1862946.0 931473.0      0.2      for index in trange(1):
   183         1  786581567.0    8e+08     99.8          get_item(train_dataset, index)

In [None]:
torch.Size([250, 3, 720, 1280])

# Combine results

In [3]:
!mkdir -p /tmp/zverev/vggsound 2>/dev/null && \
squashfuse /storage/slurm/zverev/datasets/vggsound.squashfs /tmp/zverev/vggsound 
! METADATA_DIR=/tmp/zverev && \
MOUNTPOINT=$METADATA_DIR/vggsound && \
tr_data=/storage/slurm/zverev/datasets/cav-mae/vggsound/metadata/vgg_train_cleaned.json && \
te_data=/storage/slurm/zverev/datasets/cav-mae/vggsound/metadata/vgg_test_cleaned.json && \
w_data=/storage/slurm/zverev/datasets/cav-mae/vggsound/metadata/vgg_train_cleaned_weight.csv && \
cp $tr_data $METADATA_DIR/ && \
cp $te_data $METADATA_DIR/ && \
cp $w_data $METADATA_DIR/ && \
sed -i 's|/storage/slurm/zverev/datasets/cav-mae|'$MOUNTPOINT'|g' $METADATA_DIR/vgg_train_cleaned.json && \
sed -i 's|/storage/slurm/zverev/datasets/cav-mae|'$MOUNTPOINT'|g' $METADATA_DIR/vgg_test_cleaned.json

fuse: mountpoint is not empty
fuse: if you are sure this is safe, use the 'nonempty' mount option


In [4]:
!ls -l /tmp/zverev

total 45812
drwxr-xr-x 5 zverev tumuser        0 Oct 24 16:11 vggsound
-rw-r--r-- 1 zverev tumuser  3284523 Jan 26 13:42 vgg_test_cleaned.json
-rw-r--r-- 1 zverev tumuser 39028521 Jan 26 13:42 vgg_train_cleaned.json
-rw-r--r-- 1 zverev tumuser  4593175 Jan 26 13:42 vgg_train_cleaned_weight.csv


In [1]:
import torch
import src.dataloader_ft as dataloader
import src.utils as utils
import os

# Fill args from run_videoonly.slurm
args = type('Args', (), {})()
args.data_train = '/tmp/zverev/vgg_train_cleaned.json'
args.label_csv = '/storage/slurm/zverev/datasets/cav-mae/vggsound/metadata/class_labels_indices_vgg.csv'
args.batch_size = 32
args.num_workers = 6
args.sql_path = '/home/wiss/zverev/AVSiam/artefacts/sql'
args.video_path_prefix = '/tmp/zverev/vggsound/video'
# Fill remaining args from run_videoonly.slurm
args.target_length = 1024
args.freqm = 48
args.timem = 192
args.mixup = 0.5
args.dataset = 'vggsound'
args.dataset_mean = -5.081
args.dataset_std = 4.4849
args.noise = True
args.label_smooth = 0.1
im_res = 224  # Standard ViT resolution

args.world_size = 1
args.local_rank = 0
args.dist_url = 'env://'

# Audio config defined from args
audio_conf = {'num_mel_bins': 128, 'target_length': args.target_length, 'freqm': args.freqm, 'timem': args.timem, 'mixup': args.mixup,
              'dataset': args.dataset, 'mode':'train', 'mean':args.dataset_mean, 'std':args.dataset_std,
              'noise':args.noise, 'label_smooth': args.label_smooth, 'im_res': im_res}

os.environ['RANK'] = '0'
os.environ['WORLD_SIZE'] = '1'
os.environ['LOCAL_RANK'] = '0'
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'

if not torch.distributed.is_initialized():
    utils.init_distributed_mode(args)

train_sampler = torch.utils.data.distributed.DistributedSampler(dataloader.AudiosetDataset(args.data_train, label_csv=args.label_csv, audio_conf=audio_conf, sql_path=args.sql_path, video_path_prefix=args.video_path_prefix),shuffle=True)
train_dataset = dataloader.AudiosetDataset(args.data_train, label_csv=args.label_csv, audio_conf=audio_conf, video_path_prefix=args.video_path_prefix),

  from .autonotebook import tqdm as notebook_tqdm


| distributed init (rank 0): env://, gpu 0
Using Label Smoothing: 0.1
now using following mask: 48 freq, 192 time
now using mix-up with rate 0.500000
now process vggsound
use dataset mean -5.081 and std 4.485 to normalize the input.
now use noise augmentation
number of classes is 309
now in train mode.
now use frame -1 from total 10 frames
now using 224 * 224 image input
Dataset has 183727 samples
Using Label Smoothing: 0.1
now using following mask: 48 freq, 192 time
now using mix-up with rate 0.500000
now process vggsound
use dataset mean -5.081 and std 4.485 to normalize the input.
now use noise augmentation
number of classes is 309
now in train mode.
now use frame -1 from total 10 frames
now using 224 * 224 image input
Dataset has 183727 samples


In [2]:
import sqlite3
sql_path =  'artefacts/sql'
con = sqlite3.connect("file:" + f'{sql_path}/val_vgg.sqlite.db' + "?mode=ro", uri=True)
# con = sqlite3.connect("file:" + '/home/wiss/zverev/AVSiam/artefacts/sql/val_vgg_retrieval.sqlite.db' + "?mode=ro", uri=True) ### <----- YB: for retrieval
cur = con.cursor()
num_samples = cur.execute("SELECT COUNT(*) FROM annos").fetchone()[0]

In [14]:
video_ids = [ row[0] for row in cur.execute("SELECT path FROM annos").fetchall() ]

In [27]:

import torch 
import numpy as np
import src.dataloader as dataloader
import pandas as pd
import json

total_frames = 10 # change if your total frame is different

label_csv = pd.read_csv('/storage/slurm/zverev/datasets/cav-mae/vggsound/metadata/class_labels_indices_vgg.csv')
test_json = json.load(open('/storage/slurm/zverev/datasets/cav-mae/vggsound/metadata/vgg_test_cleaned.json', 'r'))["data"]

label_to_display_name = {
    int(data['labels'].split('_')[1]): label_csv.loc[int(data['labels'].split('_')[1]), "display_name"]
    for data in test_json
}


targets = np.load(f'egs/vggsound/exp/CL-test2-cav-mae-ft-5e-05-2-0.75-1-bs32-ldaFalse-mm_grad-fzFalse-h10-a5/mm_grad_target_0.npy')
targets = targets.argmax(axis=1)
targets = [label_to_display_name[target] for target in targets]


modality_predictions = {}
for modality in ['mm_grad', 'audioonly', 'videoonly']:
    if modality == "audioonly":
        target = np.load(f'egs/vggsound/exp/CL-test2-cav-mae-ft-5e-05-2-0.75-1-bs32-ldaFalse-{modality}-fzFalse-h10-a5/{modality}_target.npy')
        target = target.argmax(axis=1)

        prediction = np.load(f'egs/vggsound/exp/CL-test2-cav-mae-ft-5e-05-2-0.75-1-bs32-ldaFalse-{modality}-fzFalse-h10-a5/{modality}_output.npy')
        multiframe_pred = prediction.mean(axis=1)
    else:
        target = np.load(f'egs/vggsound/exp/CL-test2-cav-mae-ft-5e-05-2-0.75-1-bs32-ldaFalse-{modality}-fzFalse-h10-a5/{modality}_target_0.npy')
        target = target.argmax(axis=1)

        multiframe_pred = []
        for frame in range(total_frames):    
            prediction = np.load(f'egs/vggsound/exp/CL-test2-cav-mae-ft-5e-05-2-0.75-1-bs32-ldaFalse-{modality}-fzFalse-h10-a5/{modality}_output_{frame}.npy')
            prediction = prediction.mean(axis=1)
            multiframe_pred.append(prediction)
            
        multiframe_pred = np.mean(multiframe_pred, axis=0)

    multiframe_top_10 = multiframe_pred.argsort(axis=1)[:, -10:]
    multiframe_top_1 = multiframe_top_10[:, -1]

    # prepare data fo csv format
    correctly_predicted = (multiframe_top_1 == target)
    multiframe_top_10  = [
        [
            label_to_display_name[pred] if multiframe_pred[i,pred] > 1 / 309  else "" 
            for pred in row[::-1]
        ] 
        for (i, row) in enumerate(multiframe_top_10)
    ]
    
    #transpose multiframe_top_10
    multiframe_top_10 = list(map(list, zip(*multiframe_top_10)))


    # log data for the future
    modality_predictions[modality] = {
        'logits': multiframe_pred,
        'correctly_predicted': correctly_predicted,
        'multiframe_top_10': multiframe_top_10,
    }

eval_csv_data = list(zip(
    video_ids, targets, 
    modality_predictions['audioonly']['correctly_predicted'],
    modality_predictions['audioonly']['multiframe_top_10'][0],   
    modality_predictions['videoonly']['correctly_predicted'],
    modality_predictions['videoonly']['multiframe_top_10'][0],
    modality_predictions['mm_grad']['correctly_predicted'],
    modality_predictions['mm_grad']['multiframe_top_10'][0],
    # top labels
    *modality_predictions['audioonly']['multiframe_top_10'],
    *modality_predictions['videoonly']['multiframe_top_10'],
    *modality_predictions['mm_grad']['multiframe_top_10'],
))

columns = ['video_id', 'label', 'a', 'a_label', 'v', 'v_label', 'av', 'av_label']
columns.extend([f'{modality}_top_{i}' for modality in ['a', 'v', 'av'] for i in range(1, 11)])
# columns.extend([f'{modality}_top_{i}' for modality in ['av'] for i in range(1, 11)])

eval_csv_pd = pd.DataFrame(eval_csv_data, columns=columns)
eval_csv_pd.to_csv(
    './test_predictions.csv',
    columns=columns,
    index=False,
)

In [30]:
eval_csv_pd["av"].mean(), eval_csv_pd["a"].mean(), eval_csv_pd["v"].mean()

(0.552505503042859, 0.5688851482584488, 0.472679010747119)

In [31]:
import pickle

modality_name_map  = {
    'audioonly' : 'a',
    'videoonly' : 'v',
    'mm_grad' : 'av'
}

logits_results = {
    video_id : {
        modality_name_map[modality] : modality_predictions['logits'][i]
        for modality, modality_predictions in modality_predictions.items()
    }
    for i, video_id in enumerate(video_ids)
}

pickle.dump(logits_results, open("logits.pkl", "wb"))