In [1]:
import torch
import torch.nn.functional as F
from torch import nn

import whisper
from whisper.audio import (
    log_mel_spectrogram,
    pad_or_trim,
    load_audio,
)

import jiwer
from tqdm import tqdm
from main import *

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model = whisper.load_model("base.en")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# collect trainable params
params = []
names = []

for name, param in model.named_parameters():
    param.requires_grad = False

for nm, m in model.named_modules():
    # print(str(nm).split('.'))
    trainable = ['weight', 'bias']
    # train_LN
    if isinstance(m, nn.LayerNorm) and str(nm).split('.')[0] == 'encoder':
        for np, p in m.named_parameters():
            if np in trainable:  
                p.requires_grad = True
                params.append(p)
                names.append(f"{nm}.{np}")
    # train_feature
    if len(str(nm).split('.')) > 1:
        if str(nm).split('.')[0] == 'encoder' and (str(nm).split('.')[1] == 'conv1' or str(nm).split('.')[1] == 'conv2'):
            for np, p in m.named_parameters():
                p.requires_grad = True
                params.append(p)
                names.append(f"{nm}.{np}")
print(names)
# check trainable parameter
# for name, param in model.named_parameters():
#     print("name: ", name)
#     print("requires_grad: ", param.requires_grad)

['encoder.conv1.weight', 'encoder.conv1.bias', 'encoder.conv2.weight', 'encoder.conv2.bias', 'encoder.blocks.0.attn_ln.weight', 'encoder.blocks.0.attn_ln.bias', 'encoder.blocks.0.mlp_ln.weight', 'encoder.blocks.0.mlp_ln.bias', 'encoder.blocks.1.attn_ln.weight', 'encoder.blocks.1.attn_ln.bias', 'encoder.blocks.1.mlp_ln.weight', 'encoder.blocks.1.mlp_ln.bias', 'encoder.blocks.2.attn_ln.weight', 'encoder.blocks.2.attn_ln.bias', 'encoder.blocks.2.mlp_ln.weight', 'encoder.blocks.2.mlp_ln.bias', 'encoder.blocks.3.attn_ln.weight', 'encoder.blocks.3.attn_ln.bias', 'encoder.blocks.3.mlp_ln.weight', 'encoder.blocks.3.mlp_ln.bias', 'encoder.blocks.4.attn_ln.weight', 'encoder.blocks.4.attn_ln.bias', 'encoder.blocks.4.mlp_ln.weight', 'encoder.blocks.4.mlp_ln.bias', 'encoder.blocks.5.attn_ln.weight', 'encoder.blocks.5.attn_ln.bias', 'encoder.blocks.5.mlp_ln.weight', 'encoder.blocks.5.mlp_ln.bias', 'encoder.ln_post.weight', 'encoder.ln_post.bias']


In [3]:
# load audio
model = model.to(DEVICE)
options = whisper.DecodingOptions(language="en", without_timestamps=True)
audio = load_audio(file='./p232_022.wav')
audio = pad_or_trim(audio)
mel = log_mel_spectrogram(audio)
mel = mel.unsqueeze(-1)
mel = mel.permute(2,0,1)

## Before TTA

In [4]:
# forward
mel = mel.to(DEVICE)
outputs = model.decode(mel, options)
outputs

([DecodingResult(audio_features=tensor([[-6.8994e-01,  5.6152e-01, -9.4238e-01,  ...,  2.4438e-01,
           -4.6631e-01,  2.3331e-02],
          [-5.1758e-01,  4.1162e-01, -6.4731e-05,  ...,  9.3115e-01,
           -8.7305e-01,  1.8359e-01],
          [-9.4092e-01, -2.0190e-01,  3.5303e-01,  ...,  5.2930e-01,
           -1.8066e-01, -3.0908e-01],
          ...,
          [ 3.3105e-01, -6.3818e-01, -9.0723e-01,  ...,  1.0029e+00,
           -2.2168e-01,  8.9893e-01],
          [ 7.2314e-01, -3.7769e-01, -4.2725e-01,  ...,  1.1074e+00,
           -3.8501e-01,  5.4004e-01],
          [ 7.6709e-01, -7.2461e-01,  3.1763e-01,  ...,  6.4258e-01,
           -4.7339e-01,  2.2205e-01]], device='cuda:0', dtype=torch.float16,
         grad_fn=<UnbindBackward0>), language='en', language_probs=None, tokens=[383, 4036, 4165, 27223, 6515, 318, 531, 284, 307, 262, 1245, 286, 2208, 320, 9150, 286, 257, 1271, 286, 37469, 13], text='The actual primary rainbow observed is said to be the effect of superim

## Calculate loss and adapt

In [5]:
optimizer, scheduler = setup_optimizer(params, 'AdamW', lr=3e-4, scheduler=None)
outputs = model.decode(mel, options)
result_tensor = torch.stack(outputs[1], dim=0)
result_tensor=result_tensor.permute(1,0,2) # torch.Size([1, 5, 51864])

[INFO]    optimizer: <class 'torch.optim.adamw.AdamW'>
[INFO]    scheduler: None


In [6]:
e_loss = softmax_entropy(result_tensor).mean(0).mean()
e_loss


tensor(0.2499, device='cuda:0')

In [7]:
c_loss = mcc_loss(result_tensor, reweight=False)
c_loss

OutOfMemoryError: CUDA out of memory. Tried to allocate 10.02 GiB. GPU 0 has a total capacty of 11.72 GiB of which 7.28 GiB is free. Process 1935689 has 1.83 GiB memory in use. Including non-PyTorch memory, this process has 2.14 GiB memory in use. Of the allocated memory 1.80 GiB is allocated by PyTorch, and 145.23 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [7]:
loss = 0
loss += e_loss
print(loss)

loss.backward()
optimizer.step()
if scheduler is not None: 
    scheduler.step()
model.zero_grad()


# with torch.no_grad():
#     outputs = model.decode(mel, options)
#     print(outputs)

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

In [12]:
with torch.no_grad():
    outputs = model.decode(mel, options)
    print(outputs[0][0].text)

The actual primary rainbow observed is said to be the effect of superimposition of a number of bows.


In [None]:
# for np, p in model.encoder.conv1.named_parameters():
#     if np in trainable:
#         print(p.grad)