Connected to open_whisper (Python 3.10.14)

In [2]:
import os
import glob
from io import BytesIO
import numpy as np
import wandb
from typing import List, Tuple, Union, Optional, Literal, Dict
import time
import jiwer
from fire import Fire
from tqdm import tqdm

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from torch.nn.utils import clip_grad_norm_
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import GradScaler, autocast

import whisper
from whisper import audio, DecodingOptions
from whisper.normalizers import EnglishTextNormalizer
from whisper.tokenizer import get_tokenizer
import whisper.tokenizer
from open_whisper.config.model_dims import VARIANT_TO_DIMS, ModelDimensions
import open_whisper as ow

import webdataset as wds
import tempfile

In [None]:
def decode_audio_bytes(audio_bytes: bytes) -> np.ndarray:
    bytes_io = BytesIO(audio_bytes)
    audio_arr = np.load(bytes_io)

    return audio_arr

def decode_text_bytes(text_bytes: bytes) -> str:
    transcript_str = text_bytes.decode("utf-8")

    return transcript_str

def decode_sample(sample: Dict[str, bytes]) -> Tuple[np.ndarray, str]:
    file_path = os.path.join(sample["__url__"], sample["__key__"])
    audio_path = file_path + ".m4a"
    text_path = file_path + ".srt"
    audio_bytes = sample["npy"]
    text_bytes = sample["srt"]
    audio_arr = decode_audio_bytes(audio_bytes)
    transcript_str = decode_text_bytes(text_bytes)

    return audio_path, audio_arr, text_path, transcript_str

def preprocess_audio(audio_arr: np.ndarray) -> torch.Tensor:
    audio_arr = audio_arr.astype(np.float32) / 32768.0
    audio_arr = audio.pad_or_trim(audio_arr)
    mel_spec = audio.log_mel_spectrogram(audio_arr)

    return mel_spec, audio_arr

def preprocess_text(transcript_string: str, tokenizer: whisper.tokenizer.Tokenizer, n_text_ctx: int) -> Tuple[str, torch.Tensor, torch.Tensor, torch.Tensor]:
    reader = ow.utils.TranscriptReader(transcript_string=transcript_string, ext="srt")
    transcript, *_ = reader.read()
    
    if not transcript:
        text_tokens = [tokenizer.no_speech]
    else:
        transcript_text = reader.extract_text(transcript=transcript)

        text_tokens = tokenizer.encode(transcript_text)

    text_tokens = list(tokenizer.sot_sequence_including_notimestamps) + text_tokens

    text_tokens.append(tokenizer.eot)

    # offset
    text_input = text_tokens[:-1]
    text_y = text_tokens[1:]

    padding_mask = torch.zeros((n_text_ctx, n_text_ctx))
    padding_mask[:, len(text_input) :] = -float("inf")

    text_input = np.pad(
        text_input,
        pad_width=(0, n_text_ctx - len(text_input)),
        mode="constant",
        constant_values=51864,
    )
    text_y = np.pad(
        text_y,
        pad_width=(0, n_text_ctx - len(text_y)),
        mode="constant",
        constant_values=51864,
    )

    text_input = torch.tensor(text_input, dtype=torch.long)
    text_y = torch.tensor(text_y, dtype=torch.long)

    return text_input, text_y, padding_mask
    
def preprocess(sample, n_text_ctx: int):
    tokenizer = get_tokenizer(multilingual=False)
    audio_path, audio_arr, text_path, transcript_str = decode_sample(sample)
    audio_input, padded_audio_arr = preprocess_audio(audio_arr)
    text_input, text_y, padding_mask = preprocess_text(transcript_str, tokenizer, n_text_ctx)

    return audio_path, text_path, padded_audio_arr, audio_input, text_input, text_y, padding_mask

def shuffle_shards(shards: str) -> List[str]:
    start_train_shard, end_train_shard = [int(shard_idx) for shard_idx in shards.split("{")[-1].split("}")[0].split("..")]
    rng = np.random.default_rng(42)
    shards_list = np.array(range(start_train_shard, end_train_shard + 1))
    rng.shuffle(shards_list)
    shuffled_shards_list = [f"data/tars/{shard_idx:08d}.tar" for shard_idx in shards_list]
    
    return shuffled_shards_list

In [None]:
dataset = wds.WebDataset("data/tars/{000000..000019}.tar").map(lambda sample: preprocess(sample, 448))

In [None]:
dataloader = DataLoader(dataset, batch_size=1, drop_last=False)
for batch in dataloader:
    audio_path, text_path, padded_audio_arr, audio_input, text_input, text_y, padding_mask = batch
    print(audio_input.shape, text_input.shape, text_y.shape, padding_mask.shape)
    break

torch.Size([1, 80, 3000]) torch.Size([1, 448]) torch.Size([1, 448]) torch.Size([1, 448, 448])


In [None]:
for sample in dataset:
    print(sample)
    break

('data/tars/000000.tar/HNNQSe6SJ-I/00:00:27,593_00:00:56,920.m4a', 'data/tars/000000.tar/HNNQSe6SJ-I/00:00:27,593_00:00:56,920.srt', array([ 0.0000000e+00, -6.1035156e-05, -9.1552734e-05, ...,
        0.0000000e+00,  0.0000000e+00,  0.0000000e+00], dtype=float32), tensor([[-0.5765, -0.1714, -0.3327,  ..., -0.5765, -0.5765, -0.5765],
        [-0.5765, -0.0693, -0.0991,  ..., -0.5765, -0.5765, -0.5765],
        [-0.2038, -0.0826,  0.1299,  ..., -0.5765, -0.5765, -0.5765],
        ...,
        [-0.5765, -0.5765, -0.4130,  ..., -0.5765, -0.5765, -0.5765],
        [-0.5765, -0.5765, -0.3556,  ..., -0.5765, -0.5765, -0.5765],
        [-0.5765, -0.5765, -0.4582,  ..., -0.5765, -0.5765, -0.5765]]), tensor([50257, 50362,    40,   892,   262,  7505,   318,   588,   198,  1169,
         3367,   287,   262,  1621,    11,   612,   318,   645,  8278,  1231,
          294, 19942,    13,  4525,   287,  1204,   612,   338,  1464,   588,
          198, 36747, 34120,   326,   345,   423,   284,  1986,   

In [None]:
dataset = wds.WebDataset("data/tars/{000000..000019}.tar")

In [None]:
count = 0
for sample in dataset:
    count += 1
    print(sample["__key__"])
    print(sample["__url__"])
    if count == 10:
        break

HNNQSe6SJ-I/00:00:27,593_00:00:56,920
data/tars/000000.tar
HNNQSe6SJ-I/00:00:56,920_00:01:25,589
data/tars/000000.tar
HNNQSe6SJ-I/00:01:25,589_00:01:52,850
data/tars/000000.tar
HNNQSe6SJ-I/00:01:52,850_00:02:10,809
data/tars/000000.tar
M0IdAz6O8FA/00:00:00,000_00:00:25,500
data/tars/000000.tar
M0IdAz6O8FA/00:00:25,500_00:00:50,040
data/tars/000000.tar
M0IdAz6O8FA/00:00:50,040_00:01:17,280
data/tars/000000.tar
M0IdAz6O8FA/00:01:17,280_00:01:46,500
data/tars/000000.tar
M0IdAz6O8FA/00:01:46,500_00:02:13,680
data/tars/000000.tar
M0IdAz6O8FA/00:02:13,680_00:02:39,000
data/tars/000000.tar


In [None]:
dataset = wds.DataPipeline(
    wds.SimpleShardList("data/tars/{000000..000019}.tar"),

    # at this point we have an iterator over all the shards
    wds.detshuffle(bufsize=1000, initial=100, seed=42),

    # add wds.split_by_node here if you are using multiple nodes
    wds.split_by_worker)

In [None]:
count = 0
for sample in dataset:
    count += 1
    print(sample)

{'url': 'data/tars/000003.tar'}
{'url': 'data/tars/000000.tar'}
{'url': 'data/tars/000008.tar'}
{'url': 'data/tars/000007.tar'}
{'url': 'data/tars/000016.tar'}
{'url': 'data/tars/000002.tar'}
{'url': 'data/tars/000011.tar'}
{'url': 'data/tars/000001.tar'}
{'url': 'data/tars/000010.tar'}
{'url': 'data/tars/000017.tar'}
{'url': 'data/tars/000012.tar'}
{'url': 'data/tars/000006.tar'}
{'url': 'data/tars/000018.tar'}
{'url': 'data/tars/000015.tar'}
{'url': 'data/tars/000013.tar'}
{'url': 'data/tars/000009.tar'}
{'url': 'data/tars/000004.tar'}
{'url': 'data/tars/000014.tar'}
{'url': 'data/tars/000005.tar'}
{'url': 'data/tars/000019.tar'}


In [None]:
dataset = wds.DataPipeline(
    wds.SimpleShardList("data/tars/{000000..000019}.tar"),
    # add wds.split_by_node here if you are using multiple nodes
    wds.split_by_worker,
    wds.detshuffle(bufsize=1000, initial=100, seed=42))

In [None]:
count = 0
for sample in dataset:
    count += 1
    print(sample)

{'url': 'data/tars/000003.tar'}
{'url': 'data/tars/000000.tar'}
{'url': 'data/tars/000008.tar'}
{'url': 'data/tars/000007.tar'}
{'url': 'data/tars/000016.tar'}
{'url': 'data/tars/000002.tar'}
{'url': 'data/tars/000011.tar'}
{'url': 'data/tars/000001.tar'}
{'url': 'data/tars/000010.tar'}
{'url': 'data/tars/000017.tar'}
{'url': 'data/tars/000012.tar'}
{'url': 'data/tars/000006.tar'}
{'url': 'data/tars/000018.tar'}
{'url': 'data/tars/000015.tar'}
{'url': 'data/tars/000013.tar'}
{'url': 'data/tars/000009.tar'}
{'url': 'data/tars/000004.tar'}
{'url': 'data/tars/000014.tar'}
{'url': 'data/tars/000005.tar'}
{'url': 'data/tars/000019.tar'}


In [None]:
dataset = wds.DataPipeline(
    wds.SimpleShardList("data/tars/{000000..000019}.tar"),
    # add wds.split_by_node here if you are using multiple nodes
    wds.split_by_worker)

In [None]:
count = 0
for sample in dataset:
    count += 1
    print(sample)

{'url': 'data/tars/000000.tar'}
{'url': 'data/tars/000001.tar'}
{'url': 'data/tars/000002.tar'}
{'url': 'data/tars/000003.tar'}
{'url': 'data/tars/000004.tar'}
{'url': 'data/tars/000005.tar'}
{'url': 'data/tars/000006.tar'}
{'url': 'data/tars/000007.tar'}
{'url': 'data/tars/000008.tar'}
{'url': 'data/tars/000009.tar'}
{'url': 'data/tars/000010.tar'}
{'url': 'data/tars/000011.tar'}
{'url': 'data/tars/000012.tar'}
{'url': 'data/tars/000013.tar'}
{'url': 'data/tars/000014.tar'}
{'url': 'data/tars/000015.tar'}
{'url': 'data/tars/000016.tar'}
{'url': 'data/tars/000017.tar'}
{'url': 'data/tars/000018.tar'}
{'url': 'data/tars/000019.tar'}


In [None]:
dataset = wds.DataPipeline(
    wds.SimpleShardList("data/tars/{000000..000019}.tar"),

    wds.split_by_worker,

    wds.shuffle(bufsize=1000, initial=100))

In [None]:
count = 0
for sample in dataset:
    count += 1
    print(sample)

{'url': 'data/tars/000012.tar'}
{'url': 'data/tars/000015.tar'}
{'url': 'data/tars/000016.tar'}
{'url': 'data/tars/000000.tar'}
{'url': 'data/tars/000011.tar'}
{'url': 'data/tars/000005.tar'}
{'url': 'data/tars/000014.tar'}
{'url': 'data/tars/000001.tar'}
{'url': 'data/tars/000010.tar'}
{'url': 'data/tars/000008.tar'}
{'url': 'data/tars/000007.tar'}
{'url': 'data/tars/000003.tar'}
{'url': 'data/tars/000006.tar'}
{'url': 'data/tars/000009.tar'}
{'url': 'data/tars/000004.tar'}
{'url': 'data/tars/000002.tar'}
{'url': 'data/tars/000013.tar'}
{'url': 'data/tars/000018.tar'}
{'url': 'data/tars/000019.tar'}
{'url': 'data/tars/000017.tar'}


In [None]:
dataset = wds.DataPipeline(
    wds.SimpleShardList("data/tars/{000000..000019}.tar"),

    wds.split_by_worker,

    wds.shuffle(bufsize=1000, initial=100),

    # at this point, we have an iterator over the shards assigned to each worker
    wds.tarfile_to_samples(),

    # this shuffles the samples in memory
    wds.shuffle(bufsize=1000, initial=100),
    wds.decode(decode_sample))

In [None]:
dataset = wds.DataPipeline(
    wds.SimpleShardList("data/tars/{000000..000019}.tar"),

    wds.split_by_worker,

    wds.shuffle(bufsize=1000, initial=100),

    # at this point, we have an iterator over the shards assigned to each worker
    wds.tarfile_to_samples(),

    # this shuffles the samples in memory
    wds.shuffle(bufsize=1000, initial=100),
    wds.decode(wds.handle_extension(".npy", decode_audio_bytes)))

In [None]:
count = 0
for sample in dataset:
    count += 1
    print(sample)

{'__key__': 'HU9i7y9gDiQ/00:04:12,201_00:04:37,023', '__url__': 'data/tars/000004.tar', 'npy': array([0.        , 0.        , 0.        , ..., 0.00061035, 0.00045776,
       0.00042725], dtype=float32), 'srt': b'1\n00:00:00,000 --> 00:00:06,219\nYou feel like you could just drift off to sleep in this warm, peaceful spot,\n\n2\n00:00:07,078 --> 00:00:10,258\nlistening to the waves breaking\n\n3\n00:00:10,258 --> 00:00:12,429\nand washing away,\n\n4\n00:00:13,174 --> 00:00:17,409\nbreaking and washing away.\n\n5\n00:00:18,060 --> 00:00:24,822\nLet yourself be still in this relaxed place for a few moments.\n\n'}
{'__key__': 'HpRlyCCgR5c/00:22:07,893_00:22:37,122', '__url__': 'data/tars/000004.tar', 'npy': array([ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00, ...,
       -3.0517578e-05, -3.0517578e-05,  1.2207031e-04], dtype=float32), 'srt': b'1\n00:00:00,000 --> 00:00:03,603\nsee if you have not experienced\na time of refreshing. I want to\n\n2\n00:00:03,603 --> 00:00:03,670\na time of r

KeyboardInterrupt: 

In [None]:
dataset = wds.DataPipeline(
    wds.SimpleShardList("data/tars/{000000..000019}.tar"),

    wds.split_by_worker,

    wds.shuffle(bufsize=1000, initial=100),

    # at this point, we have an iterator over the shards assigned to each worker
    wds.tarfile_to_samples(),

    # this shuffles the samples in memory
    wds.shuffle(bufsize=1000, initial=100),
    wds.decode(wds.handle_extension(".npy", decode_audio_bytes)),
    wds.decode(wds.handle_extension(".srt", decode_text_bytes)))

In [None]:
count = 0
for sample in dataset:
    count += 1
    print(sample)
    break

AssertionError: k,v = npy, [0.         0.         0.         ... 0.02160645 0.02377319 0.02639771]

In [None]:
dataset = wds.DataPipeline(
    wds.SimpleShardList("data/tars/{000000..000019}.tar"),

    wds.split_by_worker,

    wds.shuffle(bufsize=1000, initial=100),

    # at this point, we have an iterator over the shards assigned to each worker
    wds.tarfile_to_samples(),

    # this shuffles the samples in memory
    wds.shuffle(bufsize=1000, initial=100),
    wds.map(decode_sample))

In [None]:
count = 0
for sample in dataset:
    count += 1
    print(sample)
    break

('data/tars/000000.tar/W68f2znp17k/00:01:27,830_00:01:55,210.m4a', array([ 0.0000000e+00, -3.0517578e-05,  0.0000000e+00, ...,
        7.1258545e-02,  6.8267822e-02,  6.3049316e-02], dtype=float32), 'data/tars/000000.tar/W68f2znp17k/00:01:27,830_00:01:55,210.srt', "1\n00:00:00,000 --> 00:00:01,660\nDo you remember what those criteria were?\n\n2\n00:00:01,660 --> 00:00:04,110\nWhat makes you a good\ncandidate for this surgery?\n\n3\n00:00:05,370 --> 00:00:09,230\n- My growth plates down\nhere are still at a zero,\n\n4\n00:00:09,230 --> 00:00:14,230\nso that means that I'm\nstill growing quite a lot.\n\n5\n00:00:14,390 --> 00:00:17,680\n- So, and then we need a\ncurve that is big enough\n\n6\n00:00:17,680 --> 00:00:19,740\nto where we would reasonably assume,\n\n7\n00:00:19,740 --> 00:00:20,980\nif we leave this alone,\n\n8\n00:00:20,980 --> 00:00:23,380\nshe's gonna have to have\nsurgery one day anyway.\n\n9\n00:00:23,380 --> 00:00:25,380\nSophia became the first patient\nin Colorado to

In [None]:
count = 0
for sample in dataset:
    count += 1
    temp_sample = sample
    break

In [None]:
print(temp_sample)

('data/tars/000004.tar/9BV8X76iLF8/00:01:25,930_00:01:55,180.m4a', array([ 0.        ,  0.        ,  0.        , ..., -0.14889526,
       -0.10601807, -0.05722046], dtype=float32), 'data/tars/000004.tar/9BV8X76iLF8/00:01:25,930_00:01:55,180.srt', '1\n00:00:00,000 --> 00:00:02,000\nlike a neon lamp.\n\n2\n00:00:02,000 --> 00:00:08,650\nOn the other hand, there are also Dark Nebula,\nor Molecular Clouds where there are no nearby\n\n3\n00:00:08,650 --> 00:00:12,540\nstars, so ionisation is not taking place.\n\n4\n00:00:12,540 --> 00:00:17,180\nThey often consist of dust and molecular hydrogen.\n\n5\n00:00:17,180 --> 00:00:23,080\nSometimes you can have a mix of the two nebulae,\na Dark Nebula found in the middle of a HII\n\n6\n00:00:23,080 --> 00:00:24,080\nregion.\n\n7\n00:00:24,080 --> 00:00:29,250\nIn such a situation, you can often see the\nouter particles becoming ionised, with the\n\n')


In [None]:
for element in temp_sample:
    print(element)

data/tars/000004.tar/9BV8X76iLF8/00:01:25,930_00:01:55,180.m4a
[ 0.          0.          0.         ... -0.14889526 -0.10601807
 -0.05722046]
data/tars/000004.tar/9BV8X76iLF8/00:01:25,930_00:01:55,180.srt
1
00:00:00,000 --> 00:00:02,000
like a neon lamp.

2
00:00:02,000 --> 00:00:08,650
On the other hand, there are also Dark Nebula,
or Molecular Clouds where there are no nearby

3
00:00:08,650 --> 00:00:12,540
stars, so ionisation is not taking place.

4
00:00:12,540 --> 00:00:17,180
They often consist of dust and molecular hydrogen.

5
00:00:17,180 --> 00:00:23,080
Sometimes you can have a mix of the two nebulae,
a Dark Nebula found in the middle of a HII

6
00:00:23,080 --> 00:00:24,080
region.

7
00:00:24,080 --> 00:00:29,250
In such a situation, you can often see the
outer particles becoming ionised, with the




In [None]:
def preprocess(sample, n_text_ctx):
    tokenizer = get_tokenizer(multilingual=False)
    audio_path = sample[0]
    audio_arr = sample[1]
    text_path = sample[2]
    transcript_str = sample[3]
    audio_input, padded_audio_arr = preprocess_audio(audio_arr)
    text_input, text_y, padding_mask = preprocess_text(transcript_str, tokenizer, n_text_ctx)

    return audio_path, text_path, padded_audio_arr, audio_input, text_input, text_y, padding_mask

In [None]:
dataset = wds.DataPipeline(
    wds.SimpleShardList("data/tars/{000000..000019}.tar"),

    wds.split_by_worker,

    wds.shuffle(bufsize=1000, initial=100),

    # at this point, we have an iterator over the shards assigned to each worker
    wds.tarfile_to_samples(),

    # this shuffles the samples in memory
    wds.shuffle(bufsize=1000, initial=100),
    wds.map(decode_sample),
    wds.map(lambda sample: preprocess(sample, 448)))

In [None]:
count = 0
for sample in dataset:
    count += 1
    print(sample)
    break

('data/tars/000000.tar/SUUJBBcpEpI/00:03:40,700_00:03:40,960.m4a', 'data/tars/000000.tar/SUUJBBcpEpI/00:03:40,700_00:03:40,960.srt', array([0., 0., 0., ..., 0., 0., 0.], dtype=float32), tensor([[-1.0228, -0.8683, -0.7618,  ..., -1.5000, -1.5000, -1.5000],
        [-0.8460, -0.7151, -0.7851,  ..., -1.5000, -1.5000, -1.5000],
        [-0.9912, -0.6684, -0.6794,  ..., -1.5000, -1.5000, -1.5000],
        ...,
        [-0.5172, -0.3456, -0.2355,  ..., -1.5000, -1.5000, -1.5000],
        [-0.6125, -0.3977, -0.3193,  ..., -1.5000, -1.5000, -1.5000],
        [-0.5905, -0.3795, -0.3550,  ..., -1.5000, -1.5000, -1.5000]]), tensor([50257, 50362, 50361, 51864, 51864, 51864, 51864, 51864, 51864, 51864,
        51864, 51864, 51864, 51864, 51864, 51864, 51864, 51864, 51864, 51864,
        51864, 51864, 51864, 51864, 51864, 51864, 51864, 51864, 51864, 51864,
        51864, 51864, 51864, 51864, 51864, 51864, 51864, 51864, 51864, 51864,
        51864, 51864, 51864, 51864, 51864, 51864, 51864, 51864, 518

In [None]:
tokenizer = get_tokenizer(multilingual=False)

In [None]:
def preprocess(sample, tokenizer, n_text_ctx):
    audio_path = sample[0]
    audio_arr = sample[1]
    text_path = sample[2]
    transcript_str = sample[3]
    audio_input, padded_audio_arr = preprocess_audio(audio_arr)
    text_input, text_y, padding_mask = preprocess_text(transcript_str, tokenizer, n_text_ctx)

    return audio_path, text_path, padded_audio_arr, audio_input, text_input, text_y, padding_mask

In [None]:
n_text_ctx = 448

In [None]:
dataset = wds.DataPipeline(
    wds.SimpleShardList("data/tars/{000000..000019}.tar"),

    wds.split_by_worker,

    wds.shuffle(bufsize=1000, initial=100),

    # at this point, we have an iterator over the shards assigned to each worker
    wds.tarfile_to_samples(),

    # this shuffles the samples in memory
    wds.shuffle(bufsize=1000, initial=100),
    wds.map(decode_sample),
    wds.map(lambda sample: preprocess(sample, tokenizer, n_text_ctx)))

In [None]:
count = 0
for sample in dataset:
    count += 1
    print(sample)
    break

('data/tars/000018.tar/G0AZsQ_UCbg/00:13:49,451_00:14:17,925.m4a', 'data/tars/000018.tar/G0AZsQ_UCbg/00:13:49,451_00:14:17,925.srt', array([3.0517578e-05, 1.8310547e-04, 2.1362305e-04, ..., 0.0000000e+00,
       0.0000000e+00, 0.0000000e+00], dtype=float32), tensor([[-0.6308, -0.5292, -0.4150,  ..., -0.6308, -0.6308, -0.6308],
        [-0.6308, -0.3649, -0.2214,  ..., -0.6308, -0.6308, -0.6308],
        [-0.4598, -0.2148, -0.0827,  ..., -0.6308, -0.6308, -0.6308],
        ...,
        [-0.6308, -0.3910, -0.2776,  ..., -0.6308, -0.6308, -0.6308],
        [-0.6308, -0.5203, -0.3812,  ..., -0.6308, -0.6308, -0.6308],
        [-0.6308, -0.5980, -0.4603,  ..., -0.6308, -0.6308, -0.6308]]), tensor([50257, 50362,   259,  2800,  2106,   286,   262,  2663,    11,   340,
          338,  6493,   198, 10734,    12,  1462,    12, 10734, 11478,   326,
        15397,   550,  4750,   326,   460,  3051,    13,   843,    11,   475,
          356,   836,   470,  1107,   198, 16275,  1865,   611,  1997,  

In [None]:
dataset = wds.DataPipeline(
    wds.SimpleShardList("data/tars/{000000..000019}.tar"),

    wds.split_by_worker,

    wds.shuffle(bufsize=1000, initial=100),

    # at this point, we have an iterator over the shards assigned to each worker
    wds.tarfile_to_samples(),

    # this shuffles the samples in memory
    wds.shuffle(bufsize=1000, initial=100),
    wds.map(decode_sample),
    wds.map(lambda sample: preprocess(sample, tokenizer, n_text_ctx)),
    wds.batched(8))

In [None]:
count = 0
for sample in dataset:
    count += 1
    print(sample)
    break

[['data/tars/000009.tar/3DhDbXHOCeY/00:00:56,050_00:01:24,880.m4a', 'data/tars/000009.tar/q28u_xRNN-Q/00:00:10,570_00:00:40,280.m4a', 'data/tars/000009.tar/AlYqaAi1y1o/00:00:38,254_00:01:06,793.m4a', 'data/tars/000009.tar/13x3OTgC1z0/00:03:22,302_00:03:46,693.m4a', 'data/tars/000009.tar/8L83HC4bqHE/00:05:24,639_00:05:49,879.m4a', 'data/tars/000009.tar/eRhEU9Yn-F4/00:11:05,798_00:11:33,993.m4a', 'data/tars/000009.tar/eRhEU9Yn-F4/00:04:18,223_00:04:47,854.m4a', 'data/tars/000009.tar/RRU5Cfpz2n0/00:00:31,380_00:01:01,140.m4a'], ['data/tars/000009.tar/3DhDbXHOCeY/00:00:56,050_00:01:24,880.srt', 'data/tars/000009.tar/q28u_xRNN-Q/00:00:10,570_00:00:40,280.srt', 'data/tars/000009.tar/AlYqaAi1y1o/00:00:38,254_00:01:06,793.srt', 'data/tars/000009.tar/13x3OTgC1z0/00:03:22,302_00:03:46,693.srt', 'data/tars/000009.tar/8L83HC4bqHE/00:05:24,639_00:05:49,879.srt', 'data/tars/000009.tar/eRhEU9Yn-F4/00:11:05,798_00:11:33,993.srt', 'data/tars/000009.tar/eRhEU9Yn-F4/00:04:18,223_00:04:47,854.srt', 'data/

In [None]:
sample

[['data/tars/000009.tar/3DhDbXHOCeY/00:00:56,050_00:01:24,880.m4a',
  'data/tars/000009.tar/q28u_xRNN-Q/00:00:10,570_00:00:40,280.m4a',
  'data/tars/000009.tar/AlYqaAi1y1o/00:00:38,254_00:01:06,793.m4a',
  'data/tars/000009.tar/13x3OTgC1z0/00:03:22,302_00:03:46,693.m4a',
  'data/tars/000009.tar/8L83HC4bqHE/00:05:24,639_00:05:49,879.m4a',
  'data/tars/000009.tar/eRhEU9Yn-F4/00:11:05,798_00:11:33,993.m4a',
  'data/tars/000009.tar/eRhEU9Yn-F4/00:04:18,223_00:04:47,854.m4a',
  'data/tars/000009.tar/RRU5Cfpz2n0/00:00:31,380_00:01:01,140.m4a'],
 ['data/tars/000009.tar/3DhDbXHOCeY/00:00:56,050_00:01:24,880.srt',
  'data/tars/000009.tar/q28u_xRNN-Q/00:00:10,570_00:00:40,280.srt',
  'data/tars/000009.tar/AlYqaAi1y1o/00:00:38,254_00:01:06,793.srt',
  'data/tars/000009.tar/13x3OTgC1z0/00:03:22,302_00:03:46,693.srt',
  'data/tars/000009.tar/8L83HC4bqHE/00:05:24,639_00:05:49,879.srt',
  'data/tars/000009.tar/eRhEU9Yn-F4/00:11:05,798_00:11:33,993.srt',
  'data/tars/000009.tar/eRhEU9Yn-F4/00:04:18,22

In [None]:
len(sample)

7

In [None]:
wds.WebLoader(dataset, batch_size=None, shuffle=False, pin_memory=True, num_workers=4, drop_last=False, persistent_workers=True)

<webdataset.compat.WebLoader at 0x7f500fe35e40>

In [None]:
dataloader = wds.WebLoader(dataset, batch_size=None, shuffle=False, pin_memory=True, num_workers=4, drop_last=False, persistent_workers=True)

In [None]:
for batch_idx, batch in enumerate(dataloader):
    print(batch)
    break

[['data/tars/000000.tar/Uafa6ROC1gU/00:00:26,484_00:00:56,314.m4a', 'data/tars/000000.tar/ZoaqfS-bTFI/00:00:29,840_00:00:30,960.m4a', 'data/tars/000000.tar/T5I-yEkajRc/00:00:58,170_00:01:26,630.m4a', 'data/tars/000000.tar/M0IdAz6O8FA/00:02:13,680_00:02:39,000.m4a', 'data/tars/000000.tar/HNNQSe6SJ-I/00:00:56,920_00:01:25,589.m4a', 'data/tars/000000.tar/q2yt8GrnMqU/00:00:03,350_00:00:31,910.m4a', 'data/tars/000000.tar/W68f2znp17k/00:00:58,670_00:01:27,830.m4a', 'data/tars/000000.tar/SUUJBBcpEpI/00:07:28,800_00:07:55,240.m4a'], ['data/tars/000000.tar/Uafa6ROC1gU/00:00:26,484_00:00:56,314.srt', 'data/tars/000000.tar/ZoaqfS-bTFI/00:00:29,840_00:00:30,960.srt', 'data/tars/000000.tar/T5I-yEkajRc/00:00:58,170_00:01:26,630.srt', 'data/tars/000000.tar/M0IdAz6O8FA/00:02:13,680_00:02:39,000.srt', 'data/tars/000000.tar/HNNQSe6SJ-I/00:00:56,920_00:01:25,589.srt', 'data/tars/000000.tar/q2yt8GrnMqU/00:00:03,350_00:00:31,910.srt', 'data/tars/000000.tar/W68f2znp17k/00:00:58,670_00:01:27,830.srt', 'data/

In [None]:
for batch_idx, batch in enumerate(dataloader):
    print(batch["audio_files"])
    print(batch["transcript_files"])
    break

TypeError: list indices must be integers or slices, not str

In [None]:
for batch_idx, batch in enumerate(dataloader):
    print(batch[0])
    print(batch[1])
    break

['data/tars/000000.tar/M0IdAz6O8FA/00:00:00,000_00:00:25,500.m4a', 'data/tars/000000.tar/T5I-yEkajRc/00:00:00,120_00:00:28,463.m4a', 'data/tars/000000.tar/SUUJBBcpEpI/00:07:00,320_00:07:27,800.m4a', 'data/tars/000000.tar/Uafa6ROC1gU/00:00:56,314_00:01:24,342.m4a', 'data/tars/000000.tar/SUUJBBcpEpI/00:06:03,580_00:06:30,300.m4a', 'data/tars/000000.tar/T5I-yEkajRc/00:00:29,460_00:00:58,170.m4a', 'data/tars/000000.tar/ZoaqfS-bTFI/00:01:36,080_00:01:36,640.m4a', 'data/tars/000000.tar/q2yt8GrnMqU/00:00:31,910_00:00:49,610.m4a']
['data/tars/000000.tar/M0IdAz6O8FA/00:00:00,000_00:00:25,500.srt', 'data/tars/000000.tar/T5I-yEkajRc/00:00:00,120_00:00:28,463.srt', 'data/tars/000000.tar/SUUJBBcpEpI/00:07:00,320_00:07:27,800.srt', 'data/tars/000000.tar/Uafa6ROC1gU/00:00:56,314_00:01:24,342.srt', 'data/tars/000000.tar/SUUJBBcpEpI/00:06:03,580_00:06:30,300.srt', 'data/tars/000000.tar/T5I-yEkajRc/00:00:29,460_00:00:58,170.srt', 'data/tars/000000.tar/ZoaqfS-bTFI/00:01:36,080_00:01:36,640.srt', 'data/ta

In [None]:
dataset = wds.DataPipeline(
    wds.SimpleShardList("data/tars/{000000..000019}.tar"),

    wds.split_by_worker,

    wds.shuffle(bufsize=1000, initial=100),

    # at this point, we have an iterator over the shards assigned to each worker
    wds.tarfile_to_samples(),

    # this shuffles the samples in memory
    wds.shuffle(bufsize=1000, initial=100),
    wds.map(decode_sample),
    wds.map(lambda sample: preprocess(sample, tokenizer, n_text_ctx)),
    wds.shuffle(bufsize=1000, initial=100),
    wds.batched(8))

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import webdataset as wds
from torch.utils.data import DataLoader

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

In [None]:
def train(rank, world_size):
    setup(rank, world_size)

    dataset = wds.DataPipeline(
    wds.SimpleShardList("data/tars/{000000..000019}.tar"),

    wds.split_by_worker,

    wds.shuffle(bufsize=1000, initial=100))

    for sample in dataset:
        print(rank, world_size)
        print(sample)

In [None]:
world_size = torch.cuda.device_count()
mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)

In [None]:
for batch_idx, batch in enumerate(dataloader):
    print(batch[0])
    print(batch[1])
    break

['data/tars/000016.tar/3mnY_m_XiZE/00:22:40,057_00:23:04,215.m4a', 'data/tars/000016.tar/3mnY_m_XiZE/00:08:02,848_00:08:26,672.m4a', 'data/tars/000016.tar/9DLs1Jfx0lo/00:01:57,684_00:01:58,118.m4a', 'data/tars/000016.tar/3mnY_m_XiZE/00:13:29,307_00:13:57,602.m4a', 'data/tars/000016.tar/3mnY_m_XiZE/00:03:47,059_00:04:16,555.m4a', 'data/tars/000016.tar/JJDl-C2sPgc/00:00:17,801_00:00:46,536.m4a', 'data/tars/000016.tar/t9xKzn66quk/00:07:44,130_00:08:12,992.m4a', 'data/tars/000016.tar/3mnY_m_XiZE/00:06:36,528_00:07:05,991.m4a']
['data/tars/000016.tar/3mnY_m_XiZE/00:22:40,057_00:23:04,215.srt', 'data/tars/000016.tar/3mnY_m_XiZE/00:08:02,848_00:08:26,672.srt', 'data/tars/000016.tar/9DLs1Jfx0lo/00:01:57,684_00:01:58,118.srt', 'data/tars/000016.tar/3mnY_m_XiZE/00:13:29,307_00:13:57,602.srt', 'data/tars/000016.tar/3mnY_m_XiZE/00:03:47,059_00:04:16,555.srt', 'data/tars/000016.tar/JJDl-C2sPgc/00:00:17,801_00:00:46,536.srt', 'data/tars/000016.tar/t9xKzn66quk/00:07:44,130_00:08:12,992.srt', 'data/ta

In [None]:
dataset.with_epoch(30764 // 64)

In [None]:
for batch_idx, batch in enumerate(dataloader):
    print(batch)
    break

In [None]:
for epoch in range(0, 3):
    for batch_idx, batch in enumerate(dataloader):
        if batch_idx == 0:
            print(batch[0][0])

In [None]:
for epoch in range(0, 3):
    for batch_idx, batch in enumerate(dataloader):
        if batch_idx == 0:
            print(batch[0][0])