# Fine-Tune Whisper For Werewolf

## Prepare Environment

In [17]:
import jax
import jax.numpy as jnp
import numpy as np
import optax
# import wandb
import yaml

from audio.data import create_stream
from audio.models.whisper import FlaxWhisperForConditionalGeneration
from audio.optim import create_learning_rate_schedule
from audio.rolling_avg import RollingAverage
from flax import jax_utils
from flax.training import train_state
from flax.training.common_utils import shard, shard_prng_key
from ml_collections import config_dict
from tqdm.auto import tqdm
from typing import Dict
from utils import cross_entropy_loss_and_accuracy

# Inline YAML configuration; get_config() parses this text into a ConfigDict.
# Keys read by the code visible in this file:
#   metrics.rolling_average_window -> size of the loss/accuracy RollingAverage
#   training.wd, training.b2       -> weight_decay and b2 passed to optax.adamw
#   training.total_steps           -> length of the training progress bar / loop
# The remaining keys (figure_size.*, training.warmup_steps, training.lr,
# training.batch_size) are not read directly here; presumably they are consumed
# by create_learning_rate_schedule(config) or other modules -- verify.
CONFIG = """
figure_size:
  width: 10
  height: 5
metrics:
  rolling_average_window: 20
training:
  total_steps: 100000
  warmup_steps: 10000
  lr: 5e-5
  wd: 0.01
  b2: 0.95
  batch_size: 64
"""


def get_config():
    """
    Build the run configuration from the module-level CONFIG YAML text.

    Returns:
        A config_dict.ConfigDict wrapping whatever yaml.safe_load produces
        for CONFIG.
    """
    return config_dict.ConfigDict(yaml.safe_load(CONFIG))


class TrainStateWithMetrics(train_state.TrainState):
    """
    Flax TrainState carrying two rolling-average trackers and a dropout key.

    On top of the fields inherited from train_state.TrainState, this state
    holds a RollingAverage for the loss, a RollingAverage for the accuracy,
    and a PRNG key reserved for dropout.
    """

    # Rolling average of the training loss.
    loss_metric: RollingAverage
    # Rolling average of the training accuracy.
    acc_metric: RollingAverage
    # PRNG key created alongside the state for dropout.
    dropout_rng: jax.random.PRNGKey

    def replicate(self):
        """
        Return a device-replicated copy of this state.

        Every field is replicated with jax_utils.replicate; the dropout key is
        then overwritten with the output of shard_prng_key so that it is split
        across devices instead of being an identical copy on each one.
        """
        replicated_state = jax_utils.replicate(self)
        per_device_keys = shard_prng_key(self.dropout_rng)
        return replicated_state.replace(dropout_rng=per_device_keys)


def create_train_state(config, model, params):
    """
    Assemble the initial TrainStateWithMetrics for training.

    Args:
        config: configuration object; reads config.training.wd,
            config.training.b2 and config.metrics.rolling_average_window, and
            is forwarded to create_learning_rate_schedule.
        model: model whose __call__ becomes the state's apply_fn.
        params: initial parameters stored in the state.

    Returns:
        A TrainStateWithMetrics with an AdamW optimizer, two RollingAverage
        trackers (loss and accuracy) and a dropout PRNG key.
    """
    # Fixed seed 0; only the second key of the split is kept, for dropout.
    _, dropout_key = jax.random.split(jax.random.PRNGKey(0))

    # AdamW driven by the project's learning-rate schedule.
    optimizer = optax.adamw(
        create_learning_rate_schedule(config),
        weight_decay=config.training.wd,
        b2=config.training.b2,
    )

    # Both trackers share the same window length.
    window = config.metrics.rolling_average_window
    return TrainStateWithMetrics.create(
        apply_fn=model.__call__,
        params=params,
        tx=optimizer,
        loss_metric=RollingAverage.create(size=window),
        acc_metric=RollingAverage.create(size=window),
        dropout_rng=dropout_key,
    )


@jax.jit
def train_step(state: TrainStateWithMetrics, batch: Dict[str, jnp.ndarray]):
    """
    Run one optimisation step and refresh the rolling loss/accuracy metrics.

    Uses collectives over the axis name "batch", so it must be called inside a
    mapped context that defines that axis (e.g. jax.pmap(..., "batch")).

    Args:
        state: current training state (parameters, optimizer, rolling metrics).
        batch: mapping that must provide "input_features", "decoder_input_ids",
            "attention_mask", "target_tokens" and "loss_masks".

    Returns:
        Tuple (new_state, curr_loss, curr_acc, total_n_examples): the updated
        state, the first values returned by the loss and accuracy
        RollingAverage.update calls, and the leading logits dimension summed
        over the "batch" axis.

    NOTE(review): the model is called with train=True but state.dropout_rng is
    never passed or advanced -- confirm the model does not need a dropout key.
    NOTE(review): accuracy divides metrics["is_correct"] by an example count
    (logits.shape[0]); confirm the units match what
    cross_entropy_loss_and_accuracy reports.
    """

    def compute_loss(params):
        # Forward pass in training mode with the candidate parameters.
        model_out = state.apply_fn(
            params=params,
            input_features=batch["input_features"],
            decoder_input_ids=batch["decoder_input_ids"],
            decoder_attention_mask=batch["attention_mask"],
            train=True,
        )
        step_logits = model_out.logits
        raw_loss, step_metrics = cross_entropy_loss_and_accuracy(
            step_logits, tokens=batch["target_tokens"], valid=batch["loss_masks"]
        )
        # Logits and metrics ride along as auxiliary outputs.
        return raw_loss, (step_logits, step_metrics)

    value_and_grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
    (local_loss, (local_logits, local_metrics)), local_grads = value_and_grad_fn(
        state.params
    )

    # Sum gradients across the "batch" axis before the optimizer update.
    summed_grads = jax.lax.psum(local_grads, "batch")
    updated_state = state.apply_gradients(grads=summed_grads)

    # Aggregate counts and loss over the "batch" axis, then normalise both by
    # the aggregated example count.
    total_n_examples = jax.lax.psum(local_logits.shape[0], "batch")
    correct_total = jax.lax.psum(local_metrics["is_correct"], "batch")
    loss_total = jax.lax.psum(local_loss, "batch")
    step_acc = correct_total / total_n_examples
    step_loss = loss_total / total_n_examples

    # Feed this step's values into the rolling trackers.
    curr_loss, loss_tracker = updated_state.loss_metric.update(step_loss)
    curr_acc, acc_tracker = updated_state.acc_metric.update(step_acc)

    updated_state = updated_state.replace(
        loss_metric=loss_tracker, acc_metric=acc_tracker
    )
    return updated_state, curr_loss, curr_acc, total_n_examples


# def main():
# Training script body. It runs at module level; the main() wrapper and the
# wandb / evaluation scaffolding are intentionally left commented out.
config = get_config()
# worker_id = jax.process_index()
# if worker_id==0:
#     wandb.init(project=f"whisper_jax", config=config.to_dict())


# stream = DataStream(config)


# Schedule instance used only to report the learning rate in the step metrics.
lr_schedule = create_learning_rate_schedule(config)


model = FlaxWhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-small", from_pt=True
)
params = model.params
state = create_train_state(config, model, params)
state = state.replicate()
# Map the step over devices; "batch" is the axis name train_step's psum calls
# refer to. The input state buffer is donated to the call.
p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
# p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=tuple())


total_steps = config.training.total_steps
pbar = tqdm(range(total_steps), desc="Training")
# Evaluation bookkeeping; only referenced by the commented-out eval loop below.
eval_freq = 2000
eval_steps = 5
eval_counter = eval_freq
seen_examples = 0
stream = create_stream()
# zip stops at whichever ends first: total_steps or the data stream.
for step, batch in zip(pbar, stream):
    # Debug output: shapes of every entry in the incoming batch.
    print(jax.tree.map(np.shape, batch))

    # Pull the bookkeeping "epoch" entry out BEFORE sharding: shard() reshapes
    # every leaf to a leading device axis, which a scalar / non-array epoch
    # value cannot satisfy, and epoch is not a model input anyway.
    epoch = batch.pop("epoch", 0)
    batch = shard(batch)

    # Single train step
    state, curr_loss, curr_acc, total_n_examples = p_train_step(state, batch)
    # Outputs carry a leading per-device axis; the example count is already
    # aggregated across devices, so the first entry is enough.
    total_n_examples = int(total_n_examples[0])
    seen_examples += total_n_examples
    curr_loss = curr_loss.mean().item()
    curr_acc = curr_acc.mean().item()

    pbar.set_description(f"Loss: {curr_loss:.4f}, Acc: {curr_acc:.4f}")
    metrics = {
        "step": step,
        "loss": curr_loss,
        "accuracy": curr_acc,
        "lr": float(lr_schedule(step)),
        "epoch": epoch,
        "seen_examples": seen_examples,
    }

    # eval_counter -= 1
    # if eval_counter==0:
    #     eval_counter = eval_freq
    #     for i, dev_batch in enumerate(stream.validation_iter()):
    #         if i>=eval_steps:
    #             break
    #         dev_batch.pop("epoch", 0)
    #         curr_loss, curr_acc = p_eval_step(state, dev_batch)
    #         curr_loss = curr_loss.mean().item()
    #         curr_acc = curr_acc.mean().item()

    #     if worker_id==0:
    #         wandb.log({"eval_loss": curr_loss, "eval_accuracy": curr_acc, "epoch": epoch, "seen_examples": seen_examples,
    #                     "step": step})

    # Log metrics (stdout for now; wandb logging is disabled)
    print(metrics)
    # if worker_id==0:
    # wandb.log(metrics)
# if worker_id==0:
# wandb.finish()


# if __name__ == "__main__":
#     main()


RuntimeError: Unable to initialize backend 'tpu': ABORTED: The TPU is already in use by process with pid 4065990. Not attempting to load libtpu.so in this process. (set JAX_PLATFORMS='' to automatically choose an available backend)