Commit 69dca63
add fp16
loubbrad committed Apr 27, 2024
1 parent 148db77 commit 69dca63
Showing 2 changed files with 9 additions and 2 deletions.
3 changes: 3 additions & 0 deletions amt/inference/model.py
@@ -386,6 +386,7 @@ def setup_cache(
         batch_size,
         max_seq_len=4096,
         max_audio_len=1500,
+        dtype=torch.bfloat16,
     ):
         self.causal_mask = torch.tril(
             torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)
@@ -397,12 +398,14 @@ def setup_cache(
                 max_seq_length=max_seq_len,
                 n_heads=8,
                 head_dim=64,
+                dtype=dtype,
             ).cuda()
             b.cross_attn.kv_cache = KVCache(
                 max_batch_size=batch_size,
                 max_seq_length=max_audio_len,
                 n_heads=8,
                 head_dim=64,
+                dtype=dtype,
             ).cuda()


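The KVCache class itself is outside this diff. As a rough illustration only (buffer names, shapes, and the update signature below are assumptions, not the repo's actual implementation), a cache that honors the new dtype argument might look like:

import torch
import torch.nn as nn

class KVCacheSketch(nn.Module):
    # Hypothetical sketch mirroring the dtype= parameter this commit threads
    # into KVCache; allocating buffers directly in bfloat16/float16 halves
    # cache memory versus float32 and avoids per-step casts under autocast.
    def __init__(
        self,
        max_batch_size: int,
        max_seq_length: int,
        n_heads: int,
        head_dim: int,
        dtype: torch.dtype = torch.bfloat16,
    ):
        super().__init__()
        shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype))

    def update(self, pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        # Write the new keys/values at the given positions, then return
        # full views for attention over everything cached so far.
        self.k_cache[:, :, pos] = k
        self.v_cache[:, :, pos] = v
        return self.k_cache, self.v_cache
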
8 changes: 6 additions & 2 deletions amt/inference/transcribe.py
@@ -132,7 +132,7 @@ def wrapper(*args, **kwargs):
             with torch.autocast("cuda", dtype=torch.bfloat16):
                 return func(*args, **kwargs)
         else:
-            with torch.autocast("cuda", dtype=torch.float32):
+            with torch.autocast("cuda", dtype=torch.float16):
                 return func(*args, **kwargs)
 
     return wrapper
@@ -265,7 +265,11 @@ def gpu_manager(
     if gpu_id is not None:
         os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
 
-    model.decoder.setup_cache(batch_size=batch_size, max_seq_len=MAX_BLOCK_LEN)
+    model.decoder.setup_cache(
+        batch_size=batch_size,
+        max_seq_len=MAX_BLOCK_LEN,
+        dtype=torch.bfloat16 if is_bf16_supported() else torch.float16,
+    )
     model.cuda()
     model.eval()
     if compile is True:
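Condensed, the dtype-selection pattern this commit lands on (bfloat16 where the GPU supports it, float16 otherwise) can be sketched as a standalone decorator; the name autocast_half_precision is illustrative, while the repo's actual wrapper in transcribe.py branches the same way:

import functools

import torch
from torch.cuda import is_bf16_supported

def autocast_half_precision(func):
    # Illustrative sketch, not the repo's exact wrapper: prefer bfloat16 on
    # GPUs that support it (Ampere and newer), else fall back to float16.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        dtype = torch.bfloat16 if is_bf16_supported() else torch.float16
        with torch.autocast("cuda", dtype=dtype):
            return func(*args, **kwargs)

    return wrapper

Relative to the previous float32 fallback, float16 keeps pre-Ampere GPUs on a half-precision path, saving memory and bandwidth at the cost of bfloat16's wider exponent range.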
