In [None]:
import os
import glob
import functools
import time
import tempfile
from io import BytesIO
from typing import List, Tuple, Union, Optional, Literal, Dict

import numpy as np
import wandb
import jiwer
from fire import Fire
from tqdm import tqdm

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from torch.nn.utils import clip_grad_norm_
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import GradScaler, autocast

import whisper
from whisper import audio, DecodingOptions
from whisper.normalizers import EnglishTextNormalizer
from whisper.tokenizer import get_tokenizer
import whisper.tokenizer
from olmoasr.config.model_dims import VARIANT_TO_DIMS, ModelDimensions
import olmoasr as oa

import webdataset as wds

In [None]:
def decode_audio_bytes(audio_bytes: bytes) -> np.ndarray:
    """Deserialize an in-memory ``.npy`` payload into a NumPy array.

    Args:
        audio_bytes: Raw bytes of a file written with ``np.save``.

    Returns:
        The array stored in the payload, with its original dtype and shape.
    """
    return np.load(BytesIO(audio_bytes))

def decode_text_bytes(text_bytes: bytes) -> str:
    """Decode a UTF-8 encoded transcript payload into a Python string.

    Raises:
        UnicodeDecodeError: if the payload is not valid UTF-8.
    """
    return str(text_bytes, encoding="utf-8")

def decode_sample(sample: Dict[str, Union[str, bytes]]) -> Tuple[str, np.ndarray, str, str]:
    """Decode one raw WebDataset sample into (audio_path, audio_arr, text_path, transcript_str).

    The two paths are string labels built from the shard URL and the sample key;
    nothing is read from those paths here.

    Args:
        sample: Raw WebDataset sample dict. Must contain the "__url__" and "__key__"
            metadata strings plus the "npy" (serialized audio array) and "srt"
            (UTF-8 transcript) byte payloads.

    Returns:
        Tuple of (audio_path, audio_arr, text_path, transcript_str).
    """
    file_path = os.path.join(sample["__url__"], sample["__key__"])
    # NOTE(review): the ".m4a" suffix presumably names the original source audio
    # file; the payload actually decoded below is the "npy" entry.
    audio_path = file_path + ".m4a"
    text_path = file_path + ".srt"
    audio_arr = decode_audio_bytes(sample["npy"])
    transcript_str = decode_text_bytes(sample["srt"])

    return audio_path, audio_arr, text_path, transcript_str

def preprocess_audio(audio_arr: np.ndarray) -> Tuple[torch.Tensor, np.ndarray]:
    """Convert a raw audio array into a Whisper log-mel spectrogram.

    Args:
        audio_arr: Audio samples. The division by 32768 below assumes 16-bit PCM
            integer samples (NOTE(review): confirm the stored .npy payloads are int16;
            float input would be scaled down incorrectly).

    Returns:
        (mel_spec, audio_arr): the output of ``whisper.audio.log_mel_spectrogram``
        and the float32, padded/trimmed waveform it was computed from.
    """
    # Scale int16-range samples to floats (roughly [-1, 1) for int16 input).
    audio_arr = audio_arr.astype(np.float32) / 32768.0
    # Pad or trim to the default length used by whisper.audio.pad_or_trim.
    audio_arr = audio.pad_or_trim(audio_arr)
    mel_spec = audio.log_mel_spectrogram(audio_arr)

    return mel_spec, audio_arr

def preprocess_text(
    transcript_string: str,
    tokenizer: whisper.tokenizer.Tokenizer,
    n_text_ctx: int,
    pad_token_id: int = 51864,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Tokenize an SRT transcript into fixed-length decoder input/target tensors.

    The token sequence is ``sot_sequence_including_notimestamps + text + [eot]``.
    When the transcript is empty, ``[no_speech]`` replaces the text. The decoder input
    is that sequence without its last token, and the target is the sequence shifted
    left by one. Both are truncated to ``n_text_ctx`` if needed, then right-padded
    with ``pad_token_id`` to exactly ``n_text_ctx``.

    Args:
        transcript_string: Raw SRT file contents.
        tokenizer: Whisper tokenizer used for encoding and for special tokens.
        n_text_ctx: Fixed output length (decoder text context) in tokens.
        pad_token_id: Value written into padded positions. NOTE(review): 51864 is
            presumably a padding / ignore id; confirm the model's embedding size and
            the loss's ignore_index both handle it.

    Returns:
        (text_input, text_y, padding_mask). The first two are ``(n_text_ctx,)`` long
        tensors. The mask is an ``(n_text_ctx, n_text_ctx)`` float tensor that is 0 for
        real input positions and ``-inf`` in the columns of padded positions.
    """
    reader = oa.utils.TranscriptReader(transcript_string=transcript_string, ext="srt")
    transcript, *_ = reader.read()

    if not transcript:
        # Empty transcript: the no_speech token is the only text token.
        text_tokens = [tokenizer.no_speech]
    else:
        transcript_text = reader.extract_text(transcript=transcript)
        text_tokens = tokenizer.encode(transcript_text)

    text_tokens = list(tokenizer.sot_sequence_including_notimestamps) + list(text_tokens)
    text_tokens.append(tokenizer.eot)

    # Offset by one: target position i is the token following input position i.
    # Truncate to the context length; otherwise np.pad gets a negative pad width
    # and raises on long transcripts.
    text_input = text_tokens[:-1][:n_text_ctx]
    text_y = text_tokens[1:][:n_text_ctx]

    # Additive mask: columns belonging to padded input positions are -inf.
    padding_mask = torch.zeros((n_text_ctx, n_text_ctx))
    padding_mask[:, len(text_input):] = -float("inf")

    text_input = np.pad(
        text_input,
        pad_width=(0, n_text_ctx - len(text_input)),
        mode="constant",
        constant_values=pad_token_id,
    )
    text_y = np.pad(
        text_y,
        pad_width=(0, n_text_ctx - len(text_y)),
        mode="constant",
        constant_values=pad_token_id,
    )

    return (
        torch.tensor(text_input, dtype=torch.long),
        torch.tensor(text_y, dtype=torch.long),
        padding_mask,
    )
    
def preprocess(sample, n_text_ctx: int):
    """Run the full decode + audio/text preprocessing chain on one raw sample dict.

    Returns a 7-tuple:
    (audio_path, text_path, padded_audio_arr, audio_input, text_input, text_y, padding_mask).
    """
    english_tokenizer = get_tokenizer(multilingual=False)
    audio_path, audio_arr, text_path, transcript_str = decode_sample(sample)
    mel, padded_waveform = preprocess_audio(audio_arr)
    decoder_in, decoder_target, mask = preprocess_text(
        transcript_str, english_tokenizer, n_text_ctx
    )
    return (audio_path, text_path, padded_waveform, mel, decoder_in, decoder_target, mask)

def shuffle_shards(shards: str, seed: int = 42) -> List[str]:
    """Expand a brace-range shard pattern and return the paths in a seeded random order.

    Example: ``"data/tars/{000000..000019}.tar"`` expands to the 20 paths
    ``data/tars/000000.tar`` ... ``data/tars/000019.tar``, shuffled.

    Args:
        shards: Pattern containing one inclusive ``{start..end}`` range (the last
            ``{`` in the string is used). Text before and after the braces is kept
            as the path prefix and suffix. Indices are zero-padded to the width of
            ``start``, so the output names match the input pattern.
        seed: Seed for ``np.random.default_rng``; the same seed always gives the same order.

    Returns:
        Shard paths in shuffled order.
    """
    prefix, _, remainder = shards.rpartition("{")
    index_range, _, suffix = remainder.partition("}")
    start_str, end_str = index_range.split("..")
    width = len(start_str)

    shard_indices = np.array(range(int(start_str), int(end_str) + 1))
    rng = np.random.default_rng(seed)
    rng.shuffle(shard_indices)

    return [f"{prefix}{int(shard_idx):0{width}d}{suffix}" for shard_idx in shard_indices]

In [None]:
# Simple WebDataset view: each raw sample dict goes through the dict-based `preprocess`
# with a 448-token text context. functools.partial binds the `preprocess` that exists
# right now. A lambda would look the name up at iteration time and pick up the later,
# incompatible redefinitions of `preprocess` in this notebook.
dataset = wds.WebDataset("data/tars/{000000..000019}.tar").map(functools.partial(preprocess, n_text_ctx=448))

In [None]:
# Smoke test: wrap the WebDataset in a plain DataLoader and report the shapes of the
# tensors in the first batch.
dataloader = DataLoader(
    dataset,
    batch_size=1,
    drop_last=False,
)
for batch in dataloader:
    (audio_path, text_path, padded_audio_arr,
     audio_input, text_input, text_y, padding_mask) = batch
    shapes = (audio_input.shape, text_input.shape, text_y.shape, padding_mask.shape)
    print(*shapes)
    break

In [None]:
# Peek at the first preprocessed sample (the 7-tuple returned by `preprocess`) and stop.
for sample in dataset:
    print(sample)
    break

In [None]:
# Raw WebDataset with no decoding or mapping: samples are dicts with "__key__" and
# "__url__" plus raw payloads keyed by extension (decode_sample reads "npy" and "srt").
dataset = wds.WebDataset("data/tars/{000000..000019}.tar")

In [None]:
# Print the key and source shard URL of the first 10 raw samples.
count = 0
for count, sample in enumerate(dataset, start=1):
    print(sample["__key__"], sample["__url__"], sep="\n")
    if count == 10:
        break

In [None]:
# Shard-level pipeline only (no tar files are opened): iterating it yields the shard
# entries themselves. Shards are shuffled with a fixed seed first, then split across
# DataLoader workers. Insert wds.split_by_node before the worker split for multi-node runs.
shard_stages = [
    wds.SimpleShardList("data/tars/{000000..000019}.tar"),
    wds.detshuffle(bufsize=1000, initial=100, seed=42),
    wds.split_by_worker,
]
dataset = wds.DataPipeline(*shard_stages)

In [None]:
# Print every item the current pipeline yields; `count` ends as the number of items.
count = 0
for count, sample in enumerate(dataset, start=1):
    print(sample)

In [None]:
# Same shard-level pipeline, but split across workers *before* the seeded shuffle,
# so each worker shuffles only the shards assigned to it.
dataset = wds.DataPipeline(
    wds.SimpleShardList("data/tars/{000000..000019}.tar"),
    wds.split_by_worker,  # add wds.split_by_node ahead of this for multi-node runs
    wds.detshuffle(bufsize=1000, initial=100, seed=42),
)

In [None]:
# Dump every shard entry yielded by the pipeline above.
count = 0
for count, sample in enumerate(dataset, start=1):
    print(sample)

In [None]:
# Baseline: shard entries split across workers, with no shuffle stage.
shard_stages = [
    wds.SimpleShardList("data/tars/{000000..000019}.tar"),
    # add wds.split_by_node here if you are using multiple nodes
    wds.split_by_worker,
]
dataset = wds.DataPipeline(*shard_stages)

In [None]:
# Dump the unshuffled shard entries for comparison with the shuffled variants.
count = 0
for count, sample in enumerate(dataset, start=1):
    print(sample)

In [None]:
# Shard entries split across workers, then shuffled with wds.shuffle
# (no seed argument, unlike the detshuffle variants above).
dataset = wds.DataPipeline(
    wds.SimpleShardList("data/tars/{000000..000019}.tar"),
    wds.split_by_worker,
    wds.shuffle(bufsize=1000, initial=100),
)

In [None]:
# Dump the shuffled shard entries.
count = 0
for count, sample in enumerate(dataset, start=1):
    print(sample)

In [None]:
# Full sample pipeline: shuffle shard order, expand each tar into sample dicts, shuffle
# samples in memory, then turn each dict into (audio_path, audio_arr, text_path, transcript_str).
# decode_sample consumes a whole sample dict, so it belongs in wds.map. wds.decode calls its
# handlers per entry as handler(key, data), which decode_sample's signature does not accept.
dataset = wds.DataPipeline(
    wds.SimpleShardList("data/tars/{000000..000019}.tar"),

    wds.split_by_worker,

    wds.shuffle(bufsize=1000, initial=100),

    # at this point, we have an iterator over the shards assigned to each worker
    wds.tarfile_to_samples(),

    # this shuffles the samples in memory
    wds.shuffle(bufsize=1000, initial=100),
    wds.map(decode_sample))

In [None]:
# Sample pipeline that decodes only ".npy" entries with decode_audio_bytes.
# Other entries (such as ".srt") are not handled by this decode stage.
shard_stages = [
    wds.SimpleShardList("data/tars/{000000..000019}.tar"),
    wds.split_by_worker,
    wds.shuffle(bufsize=1000, initial=100),
]
sample_stages = [
    wds.tarfile_to_samples(),  # shards assigned to this worker -> sample dicts
    wds.shuffle(bufsize=1000, initial=100),  # in-memory sample shuffle
    wds.decode(wds.handle_extension(".npy", decode_audio_bytes)),
]
dataset = wds.DataPipeline(*shard_stages, *sample_stages)

In [None]:
# Print every partially decoded sample.
count = 0
for count, sample in enumerate(dataset, start=1):
    print(sample)

In [None]:
# Sample pipeline decoding both payloads: ".npy" via decode_audio_bytes in one decode
# stage and ".srt" via decode_text_bytes in a second one.
dataset = wds.DataPipeline(
    # shard stage: list -> split per worker -> shuffle shard order
    wds.SimpleShardList("data/tars/{000000..000019}.tar"),
    wds.split_by_worker,
    wds.shuffle(bufsize=1000, initial=100),
    # sample stage: tar -> sample dicts -> in-memory shuffle -> per-extension decoding
    wds.tarfile_to_samples(),
    wds.shuffle(bufsize=1000, initial=100),
    wds.decode(wds.handle_extension(".npy", decode_audio_bytes)),
    wds.decode(wds.handle_extension(".srt", decode_text_bytes)),
)

In [None]:
# Show only the first decoded sample.
count = 0
for count, sample in enumerate(dataset, start=1):
    print(sample)
    break

In [None]:
# Sample pipeline that maps each raw sample dict through decode_sample, yielding
# (audio_path, audio_arr, text_path, transcript_str) tuples.
shard_stages = [
    wds.SimpleShardList("data/tars/{000000..000019}.tar"),
    wds.split_by_worker,
    wds.shuffle(bufsize=1000, initial=100),
]
dataset = wds.DataPipeline(
    *shard_stages,
    wds.tarfile_to_samples(),  # shards assigned to this worker -> sample dicts
    wds.shuffle(bufsize=1000, initial=100),  # in-memory sample shuffle
    wds.map(decode_sample),
)

In [None]:
# Show the first decode_sample tuple.
count = 0
for count, sample in enumerate(dataset, start=1):
    print(sample)
    break

In [None]:
# Capture the first decoded sample in `temp_sample` for inspection in later cells.
count = 0
for count, sample in enumerate(dataset, start=1):
    temp_sample = sample
    break

In [None]:
# Show the whole decode_sample tuple captured into `temp_sample`.
print(temp_sample)

In [None]:
# Print each field of the captured tuple on its own line
# (audio_path, audio_arr, text_path, transcript_str).
for element in temp_sample:
    print(element)

In [None]:
def preprocess(sample, n_text_ctx):
    """Preprocess an already-decoded sample tuple (the output of decode_sample).

    `sample` is (audio_path, audio_arr, text_path, transcript_str). Returns
    (audio_path, text_path, padded_audio_arr, audio_input, text_input, text_y, padding_mask).
    NOTE: this redefines the dict-based `preprocess` from an earlier cell.
    """
    english_tokenizer = get_tokenizer(multilingual=False)
    audio_path, audio_arr, text_path, transcript_str = (sample[0], sample[1], sample[2], sample[3])
    mel, waveform = preprocess_audio(audio_arr)
    tokens_in, tokens_out, mask = preprocess_text(transcript_str, english_tokenizer, n_text_ctx)
    return audio_path, text_path, waveform, mel, tokens_in, tokens_out, mask

In [None]:
# Full pipeline: shuffle shards, expand tars into samples, shuffle samples, decode each
# into (audio_path, audio_arr, text_path, transcript_str), then run the tuple-based
# `preprocess` with a 448-token text context. functools.partial binds the current
# `preprocess` now. A lambda would resolve the name lazily and break once `preprocess`
# is redefined later with a different signature.
dataset = wds.DataPipeline(
    wds.SimpleShardList("data/tars/{000000..000019}.tar"),

    wds.split_by_worker,

    wds.shuffle(bufsize=1000, initial=100),

    # at this point, we have an iterator over the shards assigned to each worker
    wds.tarfile_to_samples(),

    # this shuffles the samples in memory
    wds.shuffle(bufsize=1000, initial=100),
    wds.map(decode_sample),
    wds.map(functools.partial(preprocess, n_text_ctx=448)))

In [None]:
# Show the first fully preprocessed sample.
count = 0
for count, sample in enumerate(dataset, start=1):
    print(sample)
    break

In [None]:
# Build the non-multilingual Whisper tokenizer once, so the tokenizer-taking
# `preprocess` below can reuse it instead of calling get_tokenizer per sample.
tokenizer = get_tokenizer(multilingual=False)

In [None]:
def preprocess(sample, tokenizer, n_text_ctx):
    """Preprocess a decode_sample tuple using a caller-supplied tokenizer.

    `sample` is (audio_path, audio_arr, text_path, transcript_str). Returns
    (audio_path, text_path, padded_audio_arr, audio_input, text_input, text_y, padding_mask).
    NOTE: this redefines the earlier `preprocess` variants in this notebook.
    """
    audio_path, audio_arr, text_path, transcript_str = (sample[i] for i in range(4))
    mel_input, padded_waveform = preprocess_audio(audio_arr)
    decoder_input, decoder_target, mask = preprocess_text(transcript_str, tokenizer, n_text_ctx)
    return (audio_path, text_path, padded_waveform, mel_input, decoder_input, decoder_target, mask)

In [None]:
# Text context length in tokens: preprocess_text pads decoder inputs/targets to this
# length and builds an (n_text_ctx, n_text_ctx) padding mask.
n_text_ctx = 448

In [None]:
# Full pipeline using the shared `tokenizer` and `n_text_ctx`. functools.partial binds the
# current `preprocess`, `tokenizer` and `n_text_ctx` at construction time. A lambda would
# re-read those globals on every sample and silently follow any later reassignment.
dataset = wds.DataPipeline(
    wds.SimpleShardList("data/tars/{000000..000019}.tar"),

    wds.split_by_worker,

    wds.shuffle(bufsize=1000, initial=100),

    # at this point, we have an iterator over the shards assigned to each worker
    wds.tarfile_to_samples(),

    # this shuffles the samples in memory
    wds.shuffle(bufsize=1000, initial=100),
    wds.map(decode_sample),
    wds.map(functools.partial(preprocess, tokenizer=tokenizer, n_text_ctx=n_text_ctx)))

In [None]:
# Show the first sample preprocessed with the shared tokenizer.
count = 0
for count, sample in enumerate(dataset, start=1):
    print(sample)
    break

In [None]:
# Full pipeline ending in batches of 8 preprocessed samples (wds.batched collates each
# field of the 7-tuple positionally). functools.partial binds `preprocess`, `tokenizer`
# and `n_text_ctx` now rather than re-reading the globals lazily on every sample.
dataset = wds.DataPipeline(
    wds.SimpleShardList("data/tars/{000000..000019}.tar"),

    wds.split_by_worker,

    wds.shuffle(bufsize=1000, initial=100),

    # at this point, we have an iterator over the shards assigned to each worker
    wds.tarfile_to_samples(),

    # this shuffles the samples in memory
    wds.shuffle(bufsize=1000, initial=100),
    wds.map(decode_sample),
    wds.map(functools.partial(preprocess, tokenizer=tokenizer, n_text_ctx=n_text_ctx)),
    wds.batched(8))

In [None]:
# Show the first batch of 8 samples.
count = 0
for count, sample in enumerate(dataset, start=1):
    print(sample)
    break

In [None]:
# Display the last item pulled from the batched pipeline (one collated batch).
sample

In [None]:
# Number of fields in the batch (one per element returned by `preprocess`).
len(sample)

In [None]:
# Construct a WebLoader only to display its repr; it is not assigned to a name.
# batch_size=None because the pipeline already batches with wds.batched(8).
wds.WebLoader(dataset, batch_size=None, shuffle=False, pin_memory=True, num_workers=4, drop_last=False, persistent_workers=True)

In [None]:
# batch_size=None: batching already happens inside the pipeline (wds.batched), so the
# loader must not batch again. Workers persist across epochs.
dataloader = wds.WebLoader(
    dataset,
    batch_size=None,
    shuffle=False,
    drop_last=False,
    num_workers=4,
    persistent_workers=True,
    pin_memory=True,
)

In [None]:
# Print the first batch produced by the WebLoader and stop.
for batch_idx, batch in enumerate(dataloader):
    print(batch)
    break

In [None]:
# Batches from wds.batched are positional sequences (one entry per field returned by
# `preprocess`), not dicts. Index by position: 0 = audio paths, 1 = transcript paths.
for batch_idx, batch in enumerate(dataloader):
    print(batch[0])
    print(batch[1])
    break

In [None]:
# Print the audio paths (field 0) and transcript paths (field 1) of the first batch.
for batch_idx, batch in enumerate(dataloader):
    audio_paths, text_paths = batch[0], batch[1]
    print(audio_paths, text_paths, sep="\n")
    break

In [None]:
# Same as the batched pipeline, plus an extra in-memory shuffle of preprocessed samples
# before batching. functools.partial binds `preprocess`, `tokenizer` and `n_text_ctx`
# now rather than re-reading the globals lazily on every sample.
dataset = wds.DataPipeline(
    wds.SimpleShardList("data/tars/{000000..000019}.tar"),

    wds.split_by_worker,

    wds.shuffle(bufsize=1000, initial=100),

    # at this point, we have an iterator over the shards assigned to each worker
    wds.tarfile_to_samples(),

    # this shuffles the samples in memory
    wds.shuffle(bufsize=1000, initial=100),
    wds.map(decode_sample),
    wds.map(functools.partial(preprocess, tokenizer=tokenizer, n_text_ctx=n_text_ctx)),
    wds.shuffle(bufsize=1000, initial=100),
    wds.batched(8))

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import webdataset as wds
from torch.utils.data import DataLoader

def setup(rank, world_size, backend="nccl", master_addr="localhost", master_port="12355"):
    """Initialize the default torch.distributed process group for a single-machine run.

    Args:
        rank: Rank of this process (0 .. world_size - 1).
        world_size: Total number of processes in the group.
        backend: torch.distributed backend. "nccl" requires CUDA GPUs; "gloo" also works on CPU.
        master_addr: Rendezvous host address, written to MASTER_ADDR.
        master_port: Rendezvous port, written to MASTER_PORT. It must be free and identical on all ranks.
    """
    os.environ['MASTER_ADDR'] = master_addr
    os.environ['MASTER_PORT'] = str(master_port)
    dist.init_process_group(backend, rank=rank, world_size=world_size)

In [None]:
def train(rank, world_size):
    """Debug worker for mp.spawn: print which shard entries this rank receives.

    Joins the process group, builds a shard-level pipeline, and prints the rank next
    to each shard entry it yields. It always tears the process group down afterwards.
    """
    setup(rank, world_size)
    try:
        dataset = wds.DataPipeline(
            wds.SimpleShardList("data/tars/{000000..000019}.tar"),
            # Split shards across ranks. split_by_worker alone only splits across
            # DataLoader workers, so without this every rank would see every shard.
            wds.split_by_node,
            wds.split_by_worker,
            wds.shuffle(bufsize=1000, initial=100),
        )

        for sample in dataset:
            print(rank, world_size)
            print(sample)
    finally:
        dist.destroy_process_group()

In [None]:
# Launch one `train` process per visible GPU (setup's default backend is NCCL).
# NOTE(review): mp.spawn locates `train` by reference. Functions defined in a notebook's
# __main__ usually cannot be found by the spawned processes. If this fails with an
# AttributeError, move `setup`/`train` into a .py module and import them.
world_size = torch.cuda.device_count()
if world_size == 0:
    # mp.spawn with nprocs=0 would start nothing and report nothing.
    print("No CUDA devices visible; skipping the distributed shard-split demo.")
else:
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)

In [None]:
# Re-check the audio paths (field 0) and transcript paths (field 1) of the first batch.
for batch_idx, batch in enumerate(dataloader):
    for field in (batch[0], batch[1]):
        print(field)
    break

In [None]:
# Cap each epoch at a fixed number of pipeline items. The pipeline ends in wds.batched(8),
# so the cap counts batches. NOTE(review): 64 does not match the batch size of 8 used in
# the pipeline. Also, per the webdataset docs each WebLoader worker applies the cap
# separately. Confirm the intended epoch length.
NUM_TRAIN_SAMPLES = 30764
dataset = dataset.with_epoch(NUM_TRAIN_SAMPLES // 64)
# The existing `dataloader` wraps an earlier pipeline object, and its persistent workers
# hold their own copies, so it would never see this epoch cap. Rebuild it on this dataset.
dataloader = wds.WebLoader(dataset, batch_size=None, shuffle=False, pin_memory=True, num_workers=4, drop_last=False, persistent_workers=True)

In [None]:
# Print the first batch from `dataloader` and stop.
for batch_idx, batch in enumerate(dataloader):
    print(batch)
    break

In [None]:
# Run 3 full epochs and print the first audio path of each epoch's first batch,
# to see whether shuffling changes the order between epochs.
for epoch in range(3):
    for batch_idx, batch in enumerate(dataloader):
        if batch_idx != 0:
            continue
        print(batch[0][0])

In [None]:
# Repeat the per-epoch check: consume 3 full epochs, printing the first audio path
# of the first batch in each.
epochs_to_check = 3
for epoch in range(epochs_to_check):
    for batch_idx, batch in enumerate(dataloader):
        is_first_batch = batch_idx == 0
        if is_first_batch:
            print(batch[0][0])