In [1]:
!apt update && apt install ffmpeg

The operation couldn’t be completed. Unable to locate a Java Runtime that supports apt.
Please visit http://www.java.com for information on installing Java.



In [5]:
from contextlib import nullcontext
from model import Whisper
from utils import TokensPerSecondTimer
from whisper.tokenizer import get_tokenizer
from torch.utils.data import DataLoader
from dataset import AudioDatasetFake, SpectrogramDataset, Batch, LexDataset
import torch
import time

tokenizer = get_tokenizer(True, language='en', task='transcribe')

torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn

num_to_generate = 64

if torch.cuda.is_available():
    print('Using GPU - loading medium model')
    model = Whisper.load_from_pretrained("medium")
    model.to('cuda')
    print('Compiling model')
    unoptimized_model = model
    model = torch.compile(unoptimized_model) # pytorch 2.0
    print('Loading dataset')
    dataset = SpectrogramDataset(LexDataset())
    batch_size = 16
    num_workers = 4
    dtype = torch.float16
    ctx = torch.amp.autocast(device_type='cuda', dtype=dtype)
    device = 'cuda'
else:
    model = Whisper.load_from_pretrained("tiny")
    dataset = SpectrogramDataset(AudioDatasetFake())
    batch_size = 2
    num_workers = 0
    dtype = torch.float32
    ctx = nullcontext()
    device = 'cpu'

model.eval()

dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=Batch.collate_fn)
dataloader_iter = iter(dataloader)
def get_batch():
    try:
        return (
            next(dataloader_iter).to(device),
            torch.tensor([50258, 50259, 50359, 50363], dtype=torch.int64).repeat(batch_size, 1).to(device),
            #tokenizer.encode('<|startoftranscript|><|en|><|transcribe|><|notimestamps|>', allowed_special={'<|startoftranscript|>', '<|en|>', '<|transcribe|>', '<|notimestamps|>'})
        )
    except StopIteration:
        return None, None

In [6]:
timer = TokensPerSecondTimer(tokens_per_call=batch_size * num_to_generate)
hours_per_batch = batch_size * 30 / 3600

batch, base_tokens = get_batch()
i = 0
while batch is not None:
    print(f'Batch {i}')
    start_time = time.time()

    # GPU Work    
    encoder_logits = model.encoder(batch.inputs)
    output = model.generate(base_tokens, encoder_logits, max_new_tokens=num_to_generate, use_cache=False)

    for x in range(output.shape[0]):
        print(tokenizer.decode(output[x].cpu()))
        
    time_elapsed = time.time() - start_time 
    tokens_per_second = timer()
    print(f'Running: {tokens_per_second} tokens/s, {time_elapsed} s/batch, {(hours_per_batch / time_elapsed) * 60* 60}hours transcribed/h')

    batch, base_tokens = get_batch()
    i += 1


# Incorrect
# -5.773551940917969
# -5.742635250091553
# -4.675758361816406
# -3.0309195518493652
# -3.0745980739593506


# Correct
# -5.773551940917969
# 7.718759059906006
# 25.459793090820312
# 31.683210372924805
# 37.67772674560547
# 14.447501182556152


Batch 0
-5.773551940917969
7.718759059906006
25.459793090820312
31.683210372924805
37.67772674560547
14.447501182556152
7.189600944519043
38.71146774291992
35.253395080566406
19.81686782836914
6.911682605743408
8.963393211364746
-2.6529881954193115
1.9277257919311523
10.174867630004883
-2.4632441997528076
29.83599281311035
12.924864768981934
29.00969123840332
25.301349639892578
33.810150146484375
29.957555770874023
17.558969497680664
2.7650229930877686
23.852235794067383
2.300333261489868
10.68060302734375
41.86252212524414
26.616365432739258
-1.5971943140029907
21.53993797302246
12.066537857055664
-0.34007346630096436
-3.0347888469696045
-1.5243186950683594
0.7232425212860107
3.2335762977600098
6.375285625457764
8.93490982055664
13.142239570617676
14.69994831085205
16.05259895324707
17.54960823059082
20.314781188964844
13.974905014038086
14.587306022644043
13.32923698425293
13.073711395263672
12.79568099975586
12.455617904663086
12.086496353149414
12.39461898803711
12.210887908935547
