In [15]:
import sys
sys.path.append('../fairseq')
sys.path.append('../fairseq/examples')
from fairseq.tasks.audio_pretraining import AudioMaskingConfig, AudioPretrainingConfig, AudioPretrainingTask
from fairseq.tasks.masked_lm import MaskedLMConfig, MaskedLMTask
from data2vec.tasks.multimodal import MultimodalPretrainingConfig, MultimodalPretrainingTask
from fairseq.data.round_robin_zip_datasets import RoundRobinZipDatasets

In [2]:
# Audio masking settings; the meaning of each field is defined by fairseq's
# AudioMaskingConfig.
# NOTE(review): am_cfg is not referenced by any other cell visible in this part
# of the notebook — confirm it is still needed.
am_cfg = AudioMaskingConfig(
    # Feature-encoder spec, passed as a string.
    feature_encoder_spec='[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]',
    # Masking knobs.
    mask_length=5,
    mask_prob=0.5,
    mask_prob_adjust=0.05,
    mask_dropout=0,
    inverse_mask=False,
    # Batch cloning.
    clone_batch=2,
)

In [3]:
# Audio pretraining task config pointing at the LibriSpeech data directory
# (a relative path, so it depends on where the notebook is run from).
ap_cfg = AudioPretrainingConfig(
    data='../../data/LibriSpeech',
    min_sample_size=32000,
    max_sample_size=320000,
    normalize=True,
    tpu=False,
)

In [4]:
# Build the standalone audio pretraining task from ap_cfg; a later cell uses it
# to load the audio 'train' split on its own.
a_task = AudioPretrainingTask.setup_task(ap_cfg)

In [5]:
# Per-modality ratios, all set to 1 (presumably sampling/mixing weights — TODO confirm).
# NOTE(review): none of these is read by another cell visible in this part of
# the notebook.
audio_ratio, image_ratio, text_ratio = 1, 1, 1

In [6]:
# Masked-LM task config for the enwik9 text data; field semantics are defined
# by fairseq's MaskedLMConfig.
lm_cfg = MaskedLMConfig(
    # Data location and how it is cut into samples.
    data='../../data/language/enwik9',
    tokens_per_sample=512,
    sample_break_mode='none',
    # Masking-related options.
    skip_masking=True,
    random_token_prob=0,
    leave_unmasked_prob=0,
    # Remaining flags and the fixed seed.
    include_index=True,
    d2v2_multi=True,
    seed=42,
)

In [7]:
# Build the standalone masked-LM task from lm_cfg; setup logs the dictionary
# size (see the INFO line in the output below).
lm_task = MaskedLMTask.setup_task(lm_cfg)

INFO:fairseq.tasks.masked_lm:dictionary: 50264 types


In [8]:
# Combined config: reuses the audio and text configs built above and adds the
# batching settings for the multimodal task.
mm_config = MultimodalPretrainingConfig(
    # Per-modality sub-configs.
    audio=ap_cfg,
    text=lm_cfg,
    # Batching.
    batch_size=32,
    max_tokens=1000000,
    update_freq=[1],
)

In [9]:
# Build the combined audio+text task from mm_config. The masked_lm logger
# reports the dictionary size again here (see output below).
mm_task = MultimodalPretrainingTask.setup_task(mm_config)

INFO:fairseq.tasks.masked_lm:dictionary: 50264 types


In [17]:
# Load the 'train' split for the multimodal task. Per the log output below,
# this loads both the raw-audio samples and the enwik9 text blocks.
mm_task.load_dataset('train')

INFO:fairseq.data.audio.raw_audio_dataset:loaded 28515, skipped 24 samples
INFO:fairseq.data.data_utils:loaded 1 examples from: ../../data/language/enwik9/train
INFO:fairseq.tasks.masked_lm:loaded 283883 blocks from: ../../data/language/enwik9/train


In [12]:
# Inspect the index ordering of the multimodal train dataset. The output below
# is a list of two index arrays (presumably one per modality — TODO confirm).
mm_task.dataset('train').ordered_indices()

[array([ 9701,  1255, 21380, ...,  5143,  3710, 19218]),
 array([283882, 130469, 110054, ...,  47552, 177043,  24172])]

In [14]:
# Peek at item 0 of the multimodal train dataset. The output below is a
# (0, sample) tuple whose sample dict has 'id' and 'source' keys.
mm_task.dataset('train')[0]

(0,
 {'id': 0,
  'source': tensor([-0.0632, -0.1329, -0.0518,  ..., -0.0202, -0.0220, -0.0229])})

In [262]:
# Create the batch iterator over the multimodal train dataset.
# NOTE(review): the name `iter` shadows Python's built-in iter() for the rest
# of the kernel session. A later cell reads this variable, so any rename has to
# be applied to every cell that uses it at the same time.
iter = mm_task.get_batch_iterator(dataset=mm_task.dataset('train'))

INFO:fairseq.data.audio.multi_modality_dataset: raw_sub_batch_samplers exists. No action is taken
INFO:fairseq.data.audio.multi_modality_dataset:dataset Modality.AUDIO batch number is 6508 
INFO:fairseq.data.audio.multi_modality_dataset:dataset Modality.TEXT batch number is 8872 


In [263]:
# Get the iterator for epoch 0 via a leading-underscore (private) method of the
# batch iterator; `iter` here is the notebook variable, not the builtin.
# NOTE(review): the second positional argument (False) is presumably the
# shuffle flag — confirm against the iterator class's signature.
ds = iter._get_iterator_for_epoch(0, False)

In [19]:
# Load the 'train' split on the standalone audio task (the log below reports
# 28515 loaded and 24 skipped samples).
a_task.load_dataset('train')

INFO:fairseq.data.audio.raw_audio_dataset:loaded 28515, skipped 24 samples


In [20]:
# Load the 'train' split on the standalone masked-LM task (the log below
# reports 283883 blocks loaded from the enwik9 train data).
lm_task.load_dataset('train')

INFO:fairseq.data.data_utils:loaded 1 examples from: ../../data/language/enwik9/train
INFO:fairseq.tasks.masked_lm:loaded 283883 blocks from: ../../data/language/enwik9/train


In [21]:
# Zip the two standalone train datasets together under the keys 'audio' and
# 'text' so they can be compared with the multimodal task's dataset.
rr_ds = RoundRobinZipDatasets(
    datasets={
        'audio': a_task.dataset('train'),
        'text': lm_task.dataset('train'),
    }
)

In [23]:
# Inspect the index ordering of the zipped dataset. The output below is a
# single array running 0..283882, unlike the two arrays shown for the
# multimodal dataset earlier.
rr_ds.ordered_indices()

array([     0,      1,      2, ..., 283880, 283881, 283882])

In [24]:
# Peek at item 0 of the zipped dataset. The output below is an OrderedDict with
# one entry per key ('audio', 'text'), each holding that dataset's own sample.
rr_ds[0]

OrderedDict([('audio',
              {'id': 9701,
               'source': tensor([ 0.0086,  0.0029, -0.0063,  ...,  0.0010,  0.0010,  0.0010])}),
             ('text',
              OrderedDict([('id', 283882),
                           ('net_input.source',
                            tensor([    0,  2226,     8,   341,    13, 22917, 25231,    61,    16,    10,
                                     8218,  2502,     5,  1795,     9,  1368, 12749,     7,    10,   346,
                                        9,    97,   160,  6929,  1274,   215,    25, 37653,  1794,  2199,
                                     1668,  8579, 30817, 20633,   219,  1533, 32204, 17301,  1274,  2900,
                                     1356, 15080,     8,    97,  9161, 12876,    33,   648,     7,    28,
                                    30804, 15253,    30,   557,  4566,  3104,     7, 12909,  3225,  1368,
                                    12749,    21,  2226,    25,    10,  1416,    13, 12909, 12876,  3329