In [1]:
import torch
import torch.nn.functional as F
from torch import nn

import whisper
from whisper.audio import (
    log_mel_spectrogram,
    pad_or_trim,
    load_audio,
)

import jiwer
from tqdm import tqdm
from main import *

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model = whisper.load_model("base.en")

In [3]:
# collect trainable params
params = []
names = []

for name, param in model.named_parameters():
    param.requires_grad = False

for nm, m in model.named_modules():
    # print(str(nm).split('.'))
    trainable = ['weight', 'bias']
    # train_LN
    if isinstance(m, nn.LayerNorm):
        for np, p in m.named_parameters():
            if np in trainable:  
                p.requires_grad = True
                params.append(p)
                names.append(f"{nm}.{np}")
    # train_feature
    if len(str(nm).split('.')) > 1:
        if str(nm).split('.')[0] == 'encoder' and (str(nm).split('.')[1] == 'conv1' or str(nm).split('.')[1] == 'conv2'):
            for np, p in m.named_parameters():
                p.requires_grad = True
                params.append(p)
                names.append(f"{nm}.{np}")
print(names)

['encoder.conv1.weight', 'encoder.conv1.bias', 'encoder.conv2.weight', 'encoder.conv2.bias', 'encoder.blocks.0.attn_ln.weight', 'encoder.blocks.0.attn_ln.bias', 'encoder.blocks.0.mlp_ln.weight', 'encoder.blocks.0.mlp_ln.bias', 'encoder.blocks.1.attn_ln.weight', 'encoder.blocks.1.attn_ln.bias', 'encoder.blocks.1.mlp_ln.weight', 'encoder.blocks.1.mlp_ln.bias', 'encoder.blocks.2.attn_ln.weight', 'encoder.blocks.2.attn_ln.bias', 'encoder.blocks.2.mlp_ln.weight', 'encoder.blocks.2.mlp_ln.bias', 'encoder.blocks.3.attn_ln.weight', 'encoder.blocks.3.attn_ln.bias', 'encoder.blocks.3.mlp_ln.weight', 'encoder.blocks.3.mlp_ln.bias', 'encoder.blocks.4.attn_ln.weight', 'encoder.blocks.4.attn_ln.bias', 'encoder.blocks.4.mlp_ln.weight', 'encoder.blocks.4.mlp_ln.bias', 'encoder.blocks.5.attn_ln.weight', 'encoder.blocks.5.attn_ln.bias', 'encoder.blocks.5.mlp_ln.weight', 'encoder.blocks.5.mlp_ln.bias', 'encoder.ln_post.weight', 'encoder.ln_post.bias', 'decoder.blocks.0.attn_ln.weight', 'decoder.blocks.0.

In [23]:
# check trainable parameter
for name, param in model.named_parameters():
    print("name: ", name)
    print("requires_grad: ", param.requires_grad)

name:  encoder.conv1.weight
requires_grad:  True
name:  encoder.conv1.bias
requires_grad:  True
name:  encoder.conv2.weight
requires_grad:  True
name:  encoder.conv2.bias
requires_grad:  True
name:  encoder.blocks.0.attn.query.weight
requires_grad:  False
name:  encoder.blocks.0.attn.query.bias
requires_grad:  False
name:  encoder.blocks.0.attn.key.weight
requires_grad:  False
name:  encoder.blocks.0.attn.value.weight
requires_grad:  False
name:  encoder.blocks.0.attn.value.bias
requires_grad:  False
name:  encoder.blocks.0.attn.out.weight
requires_grad:  False
name:  encoder.blocks.0.attn.out.bias
requires_grad:  False
name:  encoder.blocks.0.attn_ln.weight
requires_grad:  True
name:  encoder.blocks.0.attn_ln.bias
requires_grad:  True
name:  encoder.blocks.0.mlp.0.weight
requires_grad:  False
name:  encoder.blocks.0.mlp.0.bias
requires_grad:  False
name:  encoder.blocks.0.mlp.2.weight
requires_grad:  False
name:  encoder.blocks.0.mlp.2.bias
requires_grad:  False
name:  encoder.blocks.

In [4]:
# load audio
model = model.to(DEVICE)
options = whisper.DecodingOptions(language="en", without_timestamps=True)
audio = load_audio(file='./p232_001.wav')
audio = pad_or_trim(audio)
mel = log_mel_spectrogram(audio)
mel = mel.unsqueeze(-1)
mel = mel.permute(2,0,1)

## Before TTA

In [5]:
# forward
mel = mel.to(DEVICE)
outputs = model.decode(mel, options)
outputs

([DecodingResult(audio_features=tensor([[-1.6826,  0.1215, -0.4236,  ...,  0.4717, -0.7290,  0.2764],
          [-1.3369, -0.6982,  0.3635,  ...,  0.4441, -0.4841,  0.5293],
          [-1.2100, -1.0420,  0.8184,  ...,  0.0606,  0.1081, -0.0176],
          ...,
          [ 0.1824, -0.4514, -0.4377,  ...,  1.0195, -0.5532,  0.7129],
          [ 0.6265, -0.3445,  0.0033,  ...,  1.0947, -0.7920,  0.2429],
          [ 0.8047, -0.7417,  0.4585,  ...,  0.5771, -0.7646, -0.1542]],
         device='cuda:0', dtype=torch.float16), language='en', language_probs=None, tokens=[4222, 869, 45856, 13], text='Please call Stella.', avg_logprob=-0.1863834500312805, no_speech_prob=0.04131714627146721, temperature=0.0, compression_ratio=0.7037037037037037)],
 [tensor([[ 5.1758e+00, -1.0000e+20, -1.0000e+20,  ...,  2.6641e+00,
            1.3848e+00,  8.3838e-01]], device='cuda:0'),
  tensor([[ 7.6055e+00, -1.0000e+20, -1.0000e+20,  ...,  3.8184e+00,
            3.9199e+00,  2.9531e+00]], device='cuda:0'),
 

## Calculate loss and adapt

In [7]:
steps = 10
optimizer, scheduler = setup_optimizer(params, 'AdamW', lr=0.1, scheduler=None)
for i in range(steps):
    outputs = model.decode(mel, options)
    result_tensor = torch.stack(outputs[1], dim=0)
    result_tensor=result_tensor.permute(1,0,2) # torch.Size([1, 5, 51864])

    e_loss = softmax_entropy(result_tensor).mean(0).mean()
    c_loss = mcc_loss(result_tensor, reweight=False)
    loss = 0
    loss += e_loss
    loss.requires_grad = True
    loss.backward()
    optimizer.step()
    if scheduler is not None: 
        scheduler.step()
    model.zero_grad()

    
    with torch.no_grad():
        outputs = model.decode(mel, options)
        print(outputs)



[INFO]    optimizer: <class 'torch.optim.adamw.AdamW'>
[INFO]    scheduler: None
([DecodingResult(audio_features=tensor([[-1.6826,  0.1215, -0.4236,  ...,  0.4717, -0.7290,  0.2764],
        [-1.3369, -0.6982,  0.3635,  ...,  0.4441, -0.4841,  0.5293],
        [-1.2100, -1.0420,  0.8184,  ...,  0.0606,  0.1081, -0.0176],
        ...,
        [ 0.1824, -0.4514, -0.4377,  ...,  1.0195, -0.5532,  0.7129],
        [ 0.6265, -0.3445,  0.0033,  ...,  1.0947, -0.7920,  0.2429],
        [ 0.8047, -0.7417,  0.4585,  ...,  0.5771, -0.7646, -0.1542]],
       device='cuda:0', dtype=torch.float16), language='en', language_probs=None, tokens=[4222, 869, 45856, 13], text='Please call Stella.', avg_logprob=-0.1863834500312805, no_speech_prob=0.04131714627146721, temperature=0.0, compression_ratio=0.7037037037037037)], [tensor([[ 5.1758e+00, -1.0000e+20, -1.0000e+20,  ...,  2.6641e+00,
          1.3848e+00,  8.3838e-01]], device='cuda:0'), tensor([[ 7.6055e+00, -1.0000e+20, -1.0000e+20,  ...,  3.8184e+