# Fine-Tune Whisper For Werewolf

## Prepare Environment

Map:   0%|          | 0/12504 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (1052 > 1024). Running this sequence through the model will result in indexing errors


Map:   0%|          | 0/3542 [00:00<?, ? examples/s]

Map:   0%|          | 0/2433 [00:00<?, ? examples/s]

In [41]:

from src.audio.data import load_werewolf_data, filter_data, create_prepare_decoder_input_ids_and_labels_fn

import jax
from tqdm.auto import tqdm
import numpy as np
from transformers import WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset

BATCH_SIZE = 50
BOS_LEN = 2
EOS_LEN = 1
MAX_DURATION = 30
MASK_ID = -100
MAX_LENGTH = 448
SAMPLING_RATE = 16000
WORD_ERROR_PENALTY = 100
from datasets import Audio
def create_collate_fn(processor, feature_extractor):
    a_type = Audio()
    def preproc_batch(batch):
        audio_arrays = [a_type.decode_example(x)["array"] for x in batch["audio"]]
        input_features= feature_extractor(audio_arrays, sampling_rate=SAMPLING_RATE).input_features
        input_features = processor.feature_extractor.pad([{"input_features":x} for x in list(input_features)], return_tensors="np").input_features
        # decoder_input_ids = processor.tokenizer.pad([{"input_ids": feature} for feature in list(batch["decoder_input_ids"])], return_tensors="np").input_ids
        # labels_enc = processor.tokenizer.pad([{"input_ids": feature} for feature in list(batch["labels"])], return_tensors="np")
        # labels = np.where(labels_enc.attention_mask>0, labels_enc.input_ids, MASK_ID)
        input_tokens = np.array(batch["input_tokens"], dtype=np.int32)
        target_tokens = np.array(batch["target_tokens"], dtype=np.int32)
        loss_masks = np.array(batch["loss_masks"], dtype=np.float32)
        attention_mask = np.array(batch["attention_mask"], dtype=np.int32)
        return {"input_features":input_features,
                "decoder_input_ids":input_tokens,
                "target_tokens":target_tokens,
                "loss_masks":loss_masks,
                "attention_mask":attention_mask,
                }
    return preproc_batch

def create_process_sample_fn(tokenizer, seq_length):
    def process_sample(sample):
            tokens = tokenizer.encode(sample['prompt'] + sample['completion'], add_special_tokens=False)
            truncated = False
            if len(tokens) > seq_length:
                tokens = tokens[:seq_length]
                truncated = True
            tokens = [tokenizer.bos_token_id] + tokens + [tokenizer.eos_token_id]
            prompt_len = len(tokenizer.encode(sample['prompt'], add_special_tokens=False)) + 1  # add bos token
            loss_masks = ([0.0] * prompt_len) + ([1.0] * (len(tokens) - prompt_len))
            # trunacte and pad everything out
            if len(tokens) > seq_length:
                tokens = tokens[:seq_length]
                loss_masks = loss_masks[:seq_length]
            # before padding, account for shifting
            input_tokens = tokens[:-1]
            loss_masks = loss_masks[1:]
            target_tokens = tokens[1:]
            attention_mask = [1] * len(input_tokens) + [0] * (seq_length - len(input_tokens))
            input_tokens = input_tokens + [tokenizer.pad_token_id] * (seq_length - len(input_tokens))
            target_tokens = target_tokens + [tokenizer.pad_token_id] * (seq_length - len(target_tokens))
            loss_masks = loss_masks + [0.0] * (seq_length - len(loss_masks))
            result ={
                "input_tokens": np.array(input_tokens, dtype=np.int32),
                "target_tokens": np.array(target_tokens, dtype=np.int32),
                "loss_masks": np.array(loss_masks, dtype=np.float32),
                "attention_mask": np.array(attention_mask, dtype=np.int32),
                "truncated": truncated,
            }
            return result
            # sample.update(result)
            # return sample
    return process_sample



model_name = "openai/whisper-small"
tokenizer = WhisperTokenizer.from_pretrained(model_name, bos_token="<|startoftranscript|>")
werewolf_data = load_dataset("iohadrubin/werewolf_dialogue_data_10sec_v2",streaming=True)
process_sample = create_process_sample_fn(tokenizer, 448)
itrable_data = werewolf_data["train"].map(process_sample)

In [35]:
batch = next(itr)

In [38]:
batch["audio"][0].keys()

dict_keys(['bytes', 'path'])

In [44]:


processor = WhisperProcessor.from_pretrained(model_name)
feature_extractor = WhisperFeatureExtractor.from_pretrained(model_name)
collate_fn = create_collate_fn(processor, feature_extractor)


itr = itrable_data.iter(16)
gen_itr = (collate_fn(x) for x in itr)


In [45]:

for i,batch in tqdm(enumerate(gen_itr)):
    if i%100 == 0:
        print(jax.tree.map(np.shape, batch))
        break

0it [00:00, ?it/s]

{'attention_mask': (16, 448), 'decoder_input_ids': (16, 448), 'input_features': (16, 80, 3000), 'loss_masks': (16, 448), 'target_tokens': (16, 448)}


In [43]:
from src.audio.models.whisper import FlaxWhisperForConditionalGeneration
model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-small", from_pt=True)

  return self.fget.__get__(instance, owner)()


In [51]:
import jax.numpy as jnp
def cross_entropy_loss_and_accuracy(logits, tokens, valid=None):
    if valid is None:
        valid = jnp.ones(tokens.shape[:2])
    valid = valid.astype(jnp.float32)
    valid_text_length = jnp.maximum(jnp.sum(valid, axis=-1), 1e-10)
    logits = logits.astype(jnp.float32)  # for numerical stability
    logp = jax.nn.log_softmax(logits, axis=-1)
    
    token_log_prob = jnp.squeeze(
        jnp.take_along_axis(
            logp,
            jnp.expand_dims(tokens, -1),
            axis=-1,
        ),
        -1,
    )
    token_log_prob = jnp.where(valid > 0.0, token_log_prob, jnp.array(0.0))
    loss = -(jnp.sum(token_log_prob) / jnp.sum(valid))
    # old: loss = -jnp.mean(jnp.sum(token_log_prob, axis=-1) / valid_text_length)
    # changed to match hf implementation
    correct = jnp.where(
        valid > 0.0,
        jnp.argmax(logits, axis=-1) == tokens,
        jnp.array(False)
    )
    accuracy = jnp.mean(jnp.sum(correct, axis=-1) / valid_text_length)
    metrics = {
        'accuracy': accuracy,
        'token_logprob_sum': jnp.sum(token_log_prob),
        'valid_sum': jnp.sum(valid),
    }
    return loss, metrics

In [47]:
batch.keys()

dict_keys(['input_features', 'decoder_input_ids', 'target_tokens', 'loss_masks', 'attention_mask'])

In [48]:
model_output = model(decoder_input_ids=batch["decoder_input_ids"], input_features=batch["input_features"], decoder_attention_mask=batch["attention_mask"])

In [52]:
cross_entropy_loss_and_accuracy(model_output.logits, tokens=batch["target_tokens"], valid=batch["loss_masks"])

(Array(7.3337097, dtype=float32),
 {'accuracy': Array(0.02083333, dtype=float32),
  'token_logprob_sum': Array(-352.01807, dtype=float32),
  'valid_sum': Array(48., dtype=float32)})

In [None]:
batch["target_tokens"], batch["loss_masks"]

In [8]:
from src.vision.optim import create_learning_rate_schedule
from src.vision.rolling_avg import RollingAverage
from ml_collections import config_dict
import yaml
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from flax.training import train_state
import optax
import wandb
from typing import Dict, Any
import jax.numpy as jnp

from flax import jax_utils
# name: "timm/imagenet-1k-wds"
CONFIG = """


figure_size:
  width: 10
  height: 5
metrics:
  rolling_average_window: 20
training:
  total_steps: 100000
  warmup_steps: 10000
  lr: 5e-5
  wd: 0.01
  b2: 0.95
  batch_size: 64
"""

def get_config():
    """
    Load config from the above YAML string into a ConfigDict.
    """
    config_dict_raw = yaml.safe_load(CONFIG)
    return config_dict.ConfigDict(config_dict_raw)



class TrainStateWithMetrics(train_state.TrainState):
    """
    Extends the basic Flax TrainState with rolling metrics for loss & accuracy.
    """
    loss_metric: RollingAverage
    acc_metric: RollingAverage
    dropout_rng: jax.random.PRNGKey
    
    def replicate(self):
        return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))





def create_train_state(config, model, input_shape):

    
    rng = jax.random.PRNGKey(0)
    rng, dropout_rng = jax.random.split(rng)
    

    params = model.init_weights(rng,input_shape=input_shape)

    # Create learning rate schedule and optimizer
    lr_schedule = create_learning_rate_schedule(config)
    tx = optax.adamw(lr_schedule, weight_decay=config.training.wd, b2=config.training.b2)

    return TrainStateWithMetrics.create(
        apply_fn=model.__call__,
        params=params,
        tx=tx,
        loss_metric=RollingAverage.create(size=config.metrics.rolling_average_window),
        acc_metric=RollingAverage.create(size=config.metrics.rolling_average_window),
        dropout_rng=dropout_rng,
    )



@jax.jit
def train_step(state: TrainStateWithMetrics, batch: Dict[str, jnp.ndarray]):
    def loss_fn(params):
        outputs = state.apply_fn(
            **{"params": params},
            pixel_values=batch["pixel_values"],
            train=True,  # ensure model is in train mode
        )
        logits = outputs.logits  # [batch, num_labels]
        one_hot = jax.nn.one_hot(batch["labels"], num_classes=logits.shape[-1])
        unnorm_loss =  optax.softmax_cross_entropy(logits, one_hot).sum()
        
        return unnorm_loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (unnorm_loss, logits), grads = grad_fn(state.params)
    grads = jax.lax.psum(grads, "batch")
    new_state = state.apply_gradients(grads=grads)
    
    
    predictions = jnp.argmax(logits, axis=-1) == batch["labels"]
    is_correct = jnp.sum(predictions)
    
    total_n_examples = jax.lax.psum(logits.shape[0], "batch")
    total_is_correct = jax.lax.psum(is_correct, "batch")
    total_loss = jax.lax.psum(unnorm_loss, "batch")
    acc = total_is_correct / total_n_examples
    loss = total_loss / total_n_examples

    # Update rolling average metrics
    curr_loss, new_loss_metric = new_state.loss_metric.update(loss)
    curr_acc, new_acc_metric = new_state.acc_metric.update(acc)

    # Replace the old metrics with updated ones
    new_state = new_state.replace(loss_metric=new_loss_metric, acc_metric=new_acc_metric)

    return new_state, curr_loss, curr_acc, total_n_examples

@jax.jit
def eval_step(state: TrainStateWithMetrics,
              batch: Dict[str, jnp.ndarray]):
    outputs = state.apply_fn(
        **{"params": state.params},
        pixel_values=batch["pixel_values"],
        train=True,  # ensure model is in train mode
    )
    logits = outputs.logits  # [batch, num_labels]
    one_hot = jax.nn.one_hot(batch["labels"], num_classes=logits.shape[-1])
    unnorm_loss =  optax.softmax_cross_entropy(logits, one_hot).sum()
    predictions = jnp.argmax(logits, axis=-1) == batch["labels"]
    is_correct = jnp.sum(predictions)
    
    total_n_examples = jax.lax.psum(logits.shape[0], "batch")
    total_is_correct = jax.lax.psum(is_correct, "batch")
    total_loss = jax.lax.psum(unnorm_loss, "batch")
    acc = total_is_correct / total_n_examples
    loss = total_loss / total_n_examples


    return loss, acc



def create_model(config, stream):
    model_config = ViTConfig(
        num_labels=stream.num_labels,
        label2id=stream.label2id,
        id2label=stream.id2label,
        # ignoring mismatched sizes for demonstration, as in the PyTorch code
        ignore_mismatched_sizes=True,
        **dict(config.model),
    )
    model = FlaxViTForImageClassification(model_config)
    input_shape = (1, model_config.image_size, model_config.image_size, model_config.num_channels)
    return model, model_config, input_shape


def main():
    config = get_config()
    worker_id = jax.process_index()
    if worker_id==0:
        wandb.init(project=f"whisper_jax", config=config.to_dict())

    
    stream = DataStream(config)

 
    lr_schedule = create_learning_rate_schedule(config)
    
    
    model, _, input_shape = create_model(config, stream)
    state = create_train_state(config, model, input_shape)
    state = state.replicate()
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
    p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=tuple())
    

    total_steps = config.training.total_steps
    pbar = tqdm(range(total_steps), desc="Training")
    eval_freq = 2000
    eval_steps = 5
    eval_counter = eval_freq
    seen_examples = 0
    for step, batch in zip(pbar, stream.train_iter()):
        # Single train step
        epoch = batch.pop("epoch", 0)
        
        state, curr_loss, curr_acc, total_n_examples = p_train_step(state, batch)
        total_n_examples = int(total_n_examples[0])
        seen_examples += total_n_examples
        curr_loss = curr_loss.mean().item()
        curr_acc = curr_acc.mean().item()
        

        pbar.set_description(f"Loss: {curr_loss:.4f}, Acc: {curr_acc:.4f}")
        metrics = {
                "step": step,
                "loss": float(curr_loss),
                "accuracy": float(curr_acc),
                "lr": float(lr_schedule(step)),
                "epoch": epoch,
                "seen_examples": seen_examples,
            }
        
        eval_counter -= 1
        if eval_counter==0:
            eval_counter = eval_freq
            for i, dev_batch in enumerate(stream.validation_iter()):
                if i>=eval_steps:
                    break
                dev_batch.pop("epoch", 0)
                curr_loss, curr_acc = p_eval_step(state, dev_batch)
                curr_loss = curr_loss.mean().item()
                curr_acc = curr_acc.mean().item()
                
            if worker_id==0:
                wandb.log({"eval_loss": curr_loss, "eval_accuracy": curr_acc, "epoch": epoch, "seen_examples": seen_examples,
                           "step": step})

        # Log to wandb
        if worker_id==0:
            wandb.log(metrics)
    if worker_id==0:
        wandb.finish()


if __name__ == "__main__":
    main()


In [5]:
12504/16

781.5

In [3]:
werewolf_data["train"]

Dataset({
    features: ['audio', 'dialogue', 'start', 'end', 'idx', 'Game_ID', 'file_name', 'video_name', 'startRoles', 'startTime', 'endRoles', 'playerNames', 'decoder_input_ids', 'labels', 'target'],
    num_rows: 12504
})

In [6]:
x

{'input_features': array([[[ 0.82272065,  0.14677548,  0.05719948, ..., -0.53930366,
          -0.53930366, -0.53930366],
         [ 0.8889208 ,  0.38297367,  0.3025496 , ..., -0.53930366,
          -0.53930366, -0.53930366],
         [ 1.0776412 ,  0.85012746,  0.68373585, ..., -0.53930366,
          -0.53930366, -0.53930366],
         ...,
         [ 0.22516626,  0.18507051,  0.26033157, ..., -0.53930366,
          -0.53930366, -0.53930366],
         [ 0.18408287,  0.24704015,  0.18730009, ..., -0.53930366,
          -0.53930366, -0.53930366],
         [ 0.25767684,  0.23682338,  0.09144157, ..., -0.53930366,
          -0.53930366, -0.53930366]],
 
        [[ 0.82272065,  0.14677548,  0.05719948, ..., -0.53930366,
          -0.53930366, -0.53930366],
         [ 0.8889208 ,  0.38297367,  0.3025496 , ..., -0.53930366,
          -0.53930366, -0.53930366],
         [ 1.0776412 ,  0.85012746,  0.68373585, ..., -0.53930366,
          -0.53930366, -0.53930366],
         ...,
         [ 0.22

In [None]:
itr.

In [81]:
from torch.utils.data import DataLoader

In [74]:
batch = next(itr)

In [None]:

# audio_arrays = [a_feature.decode_example(value=x)["array"] for x in batch["audio"]]


In [None]:
itr = tqdm(itr, total=len(itr))

In [None]:
dloader = DataLoader(dataset, batch_size=per_proc_batch_size, collate_fn=collate_fn)

In [None]:

data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)

In [57]:
itrable_data

IterableDataset({
    features: ['audio', 'dialogue', 'start', 'end', 'idx', 'Game_ID', 'file_name', 'video_name', 'startRoles', 'startTime', 'endRoles', 'playerNames', 'decoder_input_ids', 'labels', 'target'],
    num_shards: 1
})

In [56]:
feature_extractor = WhisperFeatureExtractor.from_pretrained(model_name)

In [54]:
batch = next(itr)

In [55]:
batch

{'audio': [{'path': '459079951205a9baccf61758ef03f7c623c5f4dcba760df6afede955f6ea4e4a.wav',
   'array': array([-0.01821899, -0.01461792, -0.01190186, ...,  0.02914429,
           0.03326416,  0.0355835 ]),
   'sampling_rate': 16000},
  {'path': '65b017f792b6e615810e8cd946f1d0bab8d69357eb6ba82bfd321bd8094260e5.wav',
   'array': array([-0.0149231 , -0.01565552, -0.01480103, ..., -0.20046997,
          -0.27645874, -0.36056519]),
   'sampling_rate': 16000},
  {'path': 'eb5d0e83dfc61fdb6cbafdf9e6c6e94e496d5398c347e40fa6344c77de18bb25.wav',
   'array': array([ 0.10049438,  0.09466553,  0.08831787, ..., -0.20046997,
          -0.27645874, -0.36056519]),
   'sampling_rate': 16000},
  {'path': '97344fe47ff1ce9654b4af785b0bfdb765f1f6f9992319840a1432886c8eeecb.wav',
   'array': array([ 0.10049438,  0.09466553,  0.08831787, ..., -0.20046997,
          -0.27645874, -0.36056519]),
   'sampling_rate': 16000},
  {'path': '4dcfea3b01359b572c35eda73b5fd475d78c9f645510d05cb8c85f9660f04a74.wav',
   'arra

In [48]:
?Audio.decode_example

[0;31mSignature:[0m
[0mAudio[0m[0;34m.[0m[0mdecode_example[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mself[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mvalue[0m[0;34m:[0m [0mdict[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mtoken_per_repo_id[0m[0;34m:[0m [0mOptional[0m[0;34m[[0m[0mDict[0m[0;34m[[0m[0mstr[0m[0;34m,[0m [0mUnion[0m[0;34m[[0m[0mstr[0m[0;34m,[0m [0mbool[0m[0;34m,[0m [0mNoneType[0m[0;34m][0m[0;34m][0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m [0;34m->[0m [0mdict[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m
Decode example audio file into audio data.

Args:
    value (`dict`):
        A dictionary with keys:

        - `path`: String with relative audio file path.
        - `bytes`: Bytes of the audio file.
    token_per_repo_id (`dict`, *optional*):
        To access and decode
        audio files from private repositories on the Hub, you can pass
        a dictionary repo

In [37]:
from datasets import Audio

In [27]:
# batch

In [39]:

# batch["audio"]

<class 'dict'>
dict_keys(['bytes', 'path'])


NameError: name 'jax' is not defined

In [21]:

batch["input_features"].shape
batch["decoder_input_ids"].shape
batch["labels"].shape

<class 'dict'>
dict_keys(['bytes', 'path'])


TypeError: string indices must be integers

In [13]:
itr = itr

AttributeError: 'generator' object has no attribute 'map'

In [12]:
prepare_audio(batch)

{'audio': [{'path': '459079951205a9baccf61758ef03f7c623c5f4dcba760df6afede955f6ea4e4a.wav',
   'array': array([-0.01821899, -0.01461792, -0.01190186, ...,  0.02914429,
           0.03326416,  0.0355835 ]),
   'sampling_rate': 16000},
  {'path': '65b017f792b6e615810e8cd946f1d0bab8d69357eb6ba82bfd321bd8094260e5.wav',
   'array': array([-0.0149231 , -0.01565552, -0.01480103, ..., -0.20046997,
          -0.27645874, -0.36056519]),
   'sampling_rate': 16000},
  {'path': 'eb5d0e83dfc61fdb6cbafdf9e6c6e94e496d5398c347e40fa6344c77de18bb25.wav',
   'array': array([ 0.10049438,  0.09466553,  0.08831787, ..., -0.20046997,
          -0.27645874, -0.36056519]),
   'sampling_rate': 16000},
  {'path': '97344fe47ff1ce9654b4af785b0bfdb765f1f6f9992319840a1432886c8eeecb.wav',
   'array': array([ 0.10049438,  0.09466553,  0.08831787, ..., -0.20046997,
          -0.27645874, -0.36056519]),
   'sampling_rate': 16000},
  {'path': '4dcfea3b01359b572c35eda73b5fd475d78c9f645510d05cb8c85f9660f04a74.wav',
   'arra

In [None]:

werewolf_data = werewolf_data.map(prepare_audio, batched=True, batch_size=BATCH_SIZE,num_proc=16)


In [None]:

werewolf_data = werewolf_data.map(prepare_decoder_input_ids_and_labels)
werewolf_data = werewolf_data.remove_columns(["start", "end", "idx", "Game_ID", "file_name", "video_name", "startRoles", "startTime", "endRoles", "playerNames"])
    # return werewolf_data, tokenizer, feature_extractor

In [3]:
werewolf_data, tokenizer, feature_extractor = load_and_prepare_werewolf_data()

Map:   0%|          | 0/12504 [00:00<?, ? examples/s]

2025-02-25 14:56:58.191675: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-02-25 14:56:58.196304: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-02-25 14:56:58.209014: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1740495418.230889 3972896 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1740495418.237570 3972896 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-02-25 14:56:58.264115: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

KeyboardInterrupt: 

In [21]:
werewolf_data.set_format(type="numpy", columns=["input_features", "decoder_input_ids", "labels"])

In [39]:
itr = iter(werewolf_data["train"].iter(16))

In [45]:
batch = next(itr)
jax.tree.map(np.shape, batch)

{'decoder_input_ids': (16,), 'input_features': (16, 80, 3000), 'labels': (16,)}

In [26]:
import jax.numpy as jnp
import jax

In [49]:

# batch


from dataclasses import dataclass
from transformers import WhisperProcessor


data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)


In [50]:
data_collator(batch)

{'input_features': tensor([[[ 1.4271e-01,  3.1942e-01,  3.2700e-01,  ..., -6.8846e-01,
          -6.8846e-01, -6.8846e-01],
         [ 1.8945e-01,  2.9090e-01,  3.5406e-01,  ..., -6.8846e-01,
          -6.8846e-01, -6.8846e-01],
         [ 2.3707e-01,  1.2807e-01,  2.5274e-01,  ..., -6.8846e-01,
          -6.8846e-01, -6.8846e-01],
         ...,
         [-6.8846e-01, -6.7930e-01, -6.8846e-01,  ..., -6.8846e-01,
          -6.8846e-01, -6.8846e-01],
         [-6.8846e-01, -6.7836e-01, -6.8846e-01,  ..., -6.8846e-01,
          -6.8846e-01, -6.8846e-01],
         [-6.8143e-01, -6.3424e-01, -6.7985e-01,  ..., -6.8846e-01,
          -6.8846e-01, -6.8846e-01]],

        [[ 2.4604e-01,  3.1551e-01,  3.2865e-01,  ..., -6.8846e-01,
          -6.8846e-01, -6.8846e-01],
         [ 9.7445e-02,  3.2855e-01,  3.2837e-01,  ..., -6.8846e-01,
          -6.8846e-01, -6.8846e-01],
         [-4.9492e-02,  2.2324e-01,  1.8133e-01,  ..., -6.8846e-01,
          -6.8846e-01, -6.8846e-01],
         ...,
      

## Compute Metrics

In [53]:
!python3.10 -m pip install --upgrade evaluate jiwer

Defaulting to user installation because normal site-packages is not writeable
Collecting jiwer
  Downloading jiwer-3.1.0-py3-none-any.whl.metadata (2.6 kB)
Collecting rapidfuzz>=3.9.7 (from jiwer)
  Downloading rapidfuzz-3.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Downloading jiwer-3.1.0-py3-none-any.whl (22 kB)
Downloading rapidfuzz-3.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m87.9 MB/s[0m eta [36m0:00:00[0m
[0mInstalling collected packages: rapidfuzz, jiwer
Successfully installed jiwer-3.1.0 rapidfuzz-3.12.1


In [55]:
!python3.10 -m  pip install  evaluate==0.4.3

Defaulting to user installation because normal site-packages is not writeable
[0m

In [54]:
import evaluate

metric = evaluate.load("wer")

def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    label_ids[label_ids == MASK_ID] = tokenizer.pad_token_id

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = WORD_ERROR_PENALTY * metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

AttributeError: 'DownloadConfig' object has no attribute 'use_auth_token'

## Training Arguments

In [None]:
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-small-werewolf",
    eval_strategy="steps",
    eval_steps=1000,
    max_steps=4000,
    warmup_steps=500,
    logging_steps=25,
    save_steps=1000,
    fp16=True,
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    report_to=["tensorboard"],
    push_to_hub=True,
    gradient_checkpointing=True,
    predict_with_generate=True
)

## Trainer

In [None]:
from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=werewolf_data["train"],
    eval_dataset=werewolf_data["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

### Training

In [None]:
trainer.train()

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`...


Step,Training Loss,Validation Loss


Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English.This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


ValueError: one or more references are empty strings

## Publish

In [None]:
kwargs = {
    "dataset_tags": "iohadrubin/werewolf_dialogue_data_10sec",
    "dataset_args": "split: test",
    "model_name": "Whisper Small Werewolf",
    "finetuned_from": "openai/whisper-small",
    "tasks": "classification",
}

In [None]:
trainer.push_to_hub(**kwargs)