# In [None]:


import torch
import torch.nn as nn
from datasets import load_dataset, Audio
from transformers import EncodecModel, AutoProcessor, TrainingArguments, Trainer
import numpy as np

# --- Run configuration -------------------------------------------------------
# Checkpoint identifier passed to both AutoProcessor.from_pretrained and the
# model's from_pretrained below.
MODEL_ID = "facebook/encodec_24khz"
# Dataset identifier passed to load_dataset; marked by the author as an example.
DATASET_ID = "hf-internal-testing/librispeech_asr_dummy"  # example
# Dataset split requested from load_dataset.
SPLIT = "validation"
# Default fixed clip length, in samples, used by preprocess_function: longer
# clips are truncated and shorter ones are zero-padded to this length.
# NOTE(review): 123840 samples is 5.16 s if the processor reports 24 kHz, as
# the checkpoint name suggests — confirm against processor.sampling_rate.
MAX_DURATION_SAMPLES = 123840
# Directory handed to TrainingArguments(output_dir=...) and used as the parent
# of the final saved model.
OUTPUT_DIR = "./results_encodec_fixed_2"
# Optimizer learning rate passed to TrainingArguments.
LEARNING_RATE = 2e-5
# Per-device batch size, used for both training and evaluation.
BATCH_SIZE = 4
# Number of training epochs passed to TrainingArguments.
NUM_EPOCHS = 3

# Processor that belongs to the checkpoint; the sampling rate it reports is the
# rate every clip is required to have further down (preprocess_function raises
# if a decoded clip reports a different rate).
processor = AutoProcessor.from_pretrained(MODEL_ID)
target_sampling_rate = processor.sampling_rate

# Load the requested split and cast its "audio" column to an Audio feature
# configured with the processor's sampling rate, in one chained expression.
librispeech_dummy = load_dataset(DATASET_ID, split=SPLIT).cast_column(
    "audio",
    Audio(sampling_rate=target_sampling_rate),
)

class CustomEncodecModelForReconstruction(EncodecModel):
    """EncodecModel subclass that adds a waveform reconstruction loss.

    ``forward`` calls the base ``EncodecModel.forward``, takes its
    ``audio_values`` as the reconstruction and, when ``labels`` are given,
    returns a mean-squared error between the two under the ``"loss"`` key so
    the model can be driven by ``Trainer``.
    """

    def __init__(self, config):
        super().__init__(config)
        # Plain mean-reduced MSE, used when no mask is supplied.
        self.loss_fn = nn.MSELoss()

    @staticmethod
    def _fit_last_dim(tensor, target_length):
        """Zero-pad on the right, or truncate, so the last dim is ``target_length``."""
        current_length = tensor.shape[-1]
        if current_length < target_length:
            return nn.functional.pad(
                tensor, (0, target_length - current_length), value=0
            )
        if current_length > target_length:
            return tensor[..., :target_length]
        return tensor

    def forward(
        self,
        input_values,
        attention_mask=None,
        labels=None,
        return_dict=None,
        **kwargs,
    ):
        """Reconstruct ``input_values`` and optionally compute the MSE loss.

        Args:
            input_values: Audio tensor forwarded unchanged to the base model.
            attention_mask: Optional mask (non-zero = valid sample). It is
                forwarded to the base model as ``padding_mask`` and also used
                to exclude padded samples from the loss.
            labels: Optional target waveform. A 2-D tensor gets a channel
                axis inserted at dim 1 before comparison.
            return_dict: If falsy, return a tuple instead of a dict. Defaults
                to ``self.config.use_return_dict``.
            **kwargs: Accepted and ignored, so extra keys supplied by the
                caller (e.g. ``output_attentions``) never reach the base
                model.

        Returns:
            A plain ``dict`` containing ``"loss"`` (only when ``labels`` is
            given) followed by every entry of the base model output; or, when
            ``return_dict`` is falsy, a tuple of the non-``None`` values with
            the loss first.

        Note:
            With a mask, the loss is the sum of squared errors over valid
            samples divided by the number of valid samples. Averaging over
            every position (padding included) would shrink the loss in
            proportion to the amount of padding in the batch.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # Always request dict-style output from the base model so that
        # ``audio_values`` can be read by name.
        outputs = super().forward(
            input_values=input_values,
            padding_mask=attention_mask,
            return_dict=True,
        )
        reconstructed_audio = outputs.audio_values

        loss = None
        if labels is not None:
            output_length = reconstructed_audio.shape[-1]

            # Bring the targets to the reconstruction's length.
            targets = self._fit_last_dim(labels, output_length)
            if targets.dim() == 2:
                targets = targets.unsqueeze(1)

            if attention_mask is None:
                loss = self.loss_fn(reconstructed_audio, targets)
            else:
                # Positions added by padding the mask are zero, i.e. invalid.
                mask = self._fit_last_dim(attention_mask, output_length)
                if mask.dim() == 2:
                    mask = mask.unsqueeze(1)
                valid = mask.bool().expand_as(reconstructed_audio)

                squared_error = (reconstructed_audio - targets) ** 2
                # Normalise by the number of valid elements; the clamp keeps a
                # fully masked batch from dividing by zero (loss is then 0).
                loss = (
                    squared_error.masked_fill(~valid, 0.0).sum()
                    / valid.sum().clamp(min=1)
                )

        final_outputs = {}
        if loss is not None:
            final_outputs["loss"] = loss
        final_outputs.update(outputs.items())

        if not return_dict:
            values = tuple(
                value
                for key, value in final_outputs.items()
                if key != "loss" and value is not None
            )
            return values if loss is None else (loss,) + values
        return final_outputs


# Instantiate the loss-aware subclass from the MODEL_ID checkpoint so that its
# forward() can return a "loss" entry when labels are supplied.
model = CustomEncodecModelForReconstruction.from_pretrained(MODEL_ID)

def preprocess_function(examples, max_length=MAX_DURATION_SAMPLES):
    """Turn one dataset row into fixed-length model inputs.

    The clip is truncated or right-padded with zeros to ``max_length``
    samples, run through the module-level ``processor``, and returned with a
    mask that marks which samples are real audio.

    Args:
        examples: A single (non-batched) row with an ``"audio"`` entry holding
            ``"array"`` and ``"sampling_rate"``.
        max_length: Number of samples every clip is forced to.

    Returns:
        A dict with ``"input_values"`` (processor output with its leading
        dimension squeezed), ``"attention_mask"`` (1 for real samples, 0 for
        the zero padding added here) and ``"labels"`` (the fixed-length
        float32 waveform). If the audio cannot be read, the three keys map to
        empty lists so the row can be filtered out afterwards.

    Raises:
        ValueError: If the row's sampling rate differs from
            ``target_sampling_rate``.
    """
    try:
        audio_data = examples["audio"]["array"]
        current_sampling_rate = examples["audio"]["sampling_rate"]
    except Exception as e:
        # Deliberately broad and best-effort: an unreadable row is reported
        # and replaced by empty placeholders instead of aborting the whole map.
        print(f"Error accessing audio data for example: {examples}. Error: {e}")
        return {"input_values": [], "attention_mask": [], "labels": []}

    if current_sampling_rate != target_sampling_rate:
        raise ValueError(f"Incorrect sampling rate: {current_sampling_rate} vs {target_sampling_rate}")

    audio_data = np.asarray(audio_data, dtype=np.float32)
    # Number of real samples that survive truncation; everything after this
    # index in the fixed-length buffer is zero padding.
    valid_length = min(len(audio_data), max_length)

    processed_audio = np.zeros(max_length, dtype=np.float32)
    processed_audio[:valid_length] = audio_data[:valid_length]

    processed = processor(
        raw_audio=processed_audio,
        sampling_rate=target_sampling_rate,
        return_tensors="pt"
    )

    # Drop the leading dimension the processor adds for a single clip.
    input_values = torch.squeeze(processed.input_values, 0)

    padding_mask = processed.get("padding_mask")
    if padding_mask is not None:
        # Clone so the in-place edit below does not touch the processor output.
        mask = torch.squeeze(padding_mask, 0).clone()
    else:
        mask = torch.ones_like(input_values[0], dtype=torch.long)

    # The processor only ever sees the already-padded clip, so its mask cannot
    # know about the zeros added above. Mark them as padding explicitly.
    mask[..., valid_length:] = 0

    labels = torch.tensor(processed_audio, dtype=torch.float32)

    return {"input_values": input_values, "attention_mask": mask, "labels": labels}

# Convert every row to fixed-length model inputs, dropping the raw columns.
tokenized_datasets = librispeech_dummy.map(
    preprocess_function,
    remove_columns=librispeech_dummy.column_names,
    num_proc=1,
)
# Remove rows for which preprocess_function returned empty placeholders.
tokenized_datasets = tokenized_datasets.filter(
    lambda example: "input_values" in example and len(example["input_values"]) > 0
)

# Hold out part of the data for evaluation. Evaluating on the very rows the
# model is trained on would make eval_loss (and therefore
# load_best_model_at_end) reward memorisation rather than generalisation.
if len(tokenized_datasets) >= 2:
    split_datasets = tokenized_datasets.train_test_split(test_size=0.1, seed=42)
    train_dataset = split_datasets["train"]
    eval_dataset = split_datasets["test"]
else:
    # Too few rows to split: fall back to using the same data for both.
    train_dataset = eval_dataset = tokenized_datasets

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    eval_strategy="epoch",
    learning_rate=LEARNING_RATE,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=NUM_EPOCHS,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
    save_strategy="epoch",
    # Keep the checkpoint with the lowest evaluation loss.
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    report_to="none",
    # The dataset columns are passed to forward() as-is, so none of them is
    # dropped by Trainer.
    remove_unused_columns=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

print("Starting training...")
trainer.train()
print("Training finished.")

# Save the model and its processor side by side under the same directory.
trainer.save_model(f"{OUTPUT_DIR}/final_model")
processor.save_pretrained(f"{OUTPUT_DIR}/final_model")
print(f"Final model saved to {OUTPUT_DIR}/final_model")