In [1]:
from contextlib import nullcontext
from model import Whisper
from utils import TokensPerSecondTimer
from whisper.tokenizer import get_tokenizer
from torch.utils.data import DataLoader
from dataset import AudioDatasetFake, SpectrogramDataset, Batch, LexDataset
import torch
import time

tokenizer = get_tokenizer(True, language='en', task='transcribe')

torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn

num_to_generate = 64

if torch.cuda.is_available():
    print('Using GPU - loading medium model')
    model = Whisper.load_from_pretrained("medium")
    model.to('cuda')
    print('Compiling model')
    unoptimized_model = model
    model = torch.compile(unoptimized_model) # pytorch 2.0
    print('Loading dataset')
    dataset = SpectrogramDataset(LexDataset())
    batch_size = 16
    num_workers = 4
    dtype = torch.float16
    ctx = torch.amp.autocast(device_type='cuda', dtype=dtype)
    device = 'cuda'
else:
    model = Whisper.load_from_pretrained("tiny")
    dataset = SpectrogramDataset(AudioDatasetFake())
    batch_size = 2
    num_workers = 0
    dtype = torch.float32
    ctx = nullcontext()
    device = 'cpu'

model.eval()

dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=Batch.collate_fn)
dataloader_iter = iter(dataloader)
def get_batch():
    try:
        return (
            next(dataloader_iter).to(device),
            torch.tensor([tokenizer.sot_sequence], dtype=torch.int64).repeat(batch_size, 1).to(device),
        )
    except StopIteration:
        return None, None

  def backtrace(trace: np.ndarray):


In [2]:
timer = TokensPerSecondTimer(tokens_per_call=batch_size * num_to_generate)
hours_per_batch = batch_size * 30 / 3600

batch, base_tokens = get_batch()
i = 0
while batch is not None:
    print(f'Batch {i}')
    start_time = time.time()

    # GPU Work    
    encoder_logits = model.encoder(batch.inputs)
    output = model.generate(base_tokens, encoder_logits, max_new_tokens=num_to_generate)

    for i in range(output.shape[0]):
        print(tokenizer.decode(output[i].cpu()))
        
    time_elapsed = time.time() - start_time 
    tokens_per_second = timer()
    print(f'Running: {tokens_per_second} tokens/s, {time_elapsed} s/batch, {(hours_per_batch / time_elapsed) * 60}hours transcribed/m')

    batch, base_tokens = get_batch()
    i += 1


Batch 0
<|startoftranscript|><|en|><|transcribe|><|notimestamps|> Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the exhibition.<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>3333333333333333333333
<|startoftranscript|><|en|><|transcribe|><|notimestamps|> Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the exhibition.<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>3333333333333333333333
Running: -1 tokens/s, 2.1980910301208496 s/batch, 0.4549402123464471hours transcribed/m
Batch 2
<|startoftranscript|><|en|><|transcribe|><|notimestamps|> Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts a