# Training

In [1]:
import sys
sys.path.append("/work/b0990106x/trl/vc")
import importlib
import vc
importlib.reload(vc)
import torch
from vc.trainer_encodec_vc_inference import get_ar_prediction_v3, pack_inputs_v2
from types import SimpleNamespace
from transformers import BartForConditionalGeneration, AutoModelForCausalLM, AutoTokenizer
from NISQA.nisqa.NISQA_model import nisqaModel
from datasets import load_from_disk, Dataset
from trl import DPOTrainer, DPOConfig, AutoModelForSeq2SeqLMWithValueHead, create_reference_model
from vc.encodec_model.nar_bart_model import NARBartForConditionalGeneration
from datetime import datetime
import os
import numpy as np
from dpo_eval import get_reward, eval_dpo_mos
import json
from tqdm import tqdm
import time
from typing import List, Tuple



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def generate_output(
        ar_model, 
        nar_model, 
        ar_tokenizer, 
        nar_tokenizer, 
        src_encodec: list, 
        instruction: list, 
        args_predict: SimpleNamespace, 
        episode_counter: int = 0, 
        base_path: str = "/work/b0990106x/trl", 
        temperature: float = 1.0
) -> tuple[float, str]:
    '''
    Generates output from AR model, synthesize the audio, and evaluate the audio using NISQA.

    Args:
        ar_model(BartForConditionalGeneration): AR model
        nar_model(NarbartForConditionalGeneration): NAR model
        ar_tokenizer(AutoTokenizer): AR tokenizer
        nar_tokenizer(AutoTokenizer): NAR tokenizer
        src_encodec(list): A list of inputs, where each input is a list of layers, and each layer is a list of v_token integers.
        instruction(list): A list of string of instructions.
        args_predict(SimpleNamespace): A SimpleNamespace object containing the arguments for the NISQA prediction.
        episode_counter(int): A counter that determine the name of the output audio.
        base_path(str): The path to the base directory.
        temperature(float): The temperature for the AR model.

    Returns:
        tuple:
            reward(float): The reward of the audio.
            tokenized_decode_ar(str): The tokenized output of the AR model - first layer.
    '''
    # Generate predictions using the AR model
    _, decode_ar, output_path_ckpt = get_ar_prediction_v3(
        args_predict, ar_model, nar_model, ar_tokenizer, nar_tokenizer, src_encodec, instruction, episode_counter, temperature=temperature
    )

    # Flatten the decoded AR output tensor and convert it to a list
    list_decode_ar = decode_ar.flatten().tolist()   

    # Evaluate the audio to get the reward
    reward = get_reward(output_path_ckpt, base_path)
    
    # Filter the decoded AR output to remove special tokens
    filtered_decode_ar_list = list_decode_ar[2:-1]

    # Convert the filtered token IDs back to tokens and then to a string
    decode_ar_tokens = ar_tokenizer.convert_ids_to_tokens(filtered_decode_ar_list)
    tokenized_decode_ar = ar_tokenizer.convert_tokens_to_string(decode_ar_tokens)

    return reward, tokenized_decode_ar

def extract_data_from_json(file_path: str) -> Tuple[List[list], List[str], List[list]]:
    """
    Loads data from a JSON file and extracts 'src_encodec', 'instruction', and 'tgt_encodec'.

    Args:
        file_path (str): The path to the JSON file.

    Returns:
        tuple:
            all_src_encodec (List[list]): A list containing the 'src_encodec' data from each item in the JSON file.
            all_instruction (List[str]): A list containing the 'instruction' data from each item in the JSON file.
            all_tgt_encodec (List[list]): A list containing the 'tgt_encodec' data from each item in the JSON file.
    """
    with open(file_path, 'r') as f:
        data = json.load(f)

    all_src_encodec = [item["src_encodec"] for item in data]
    all_instruction = [item["instruction"] for item in data]
    all_tgt_encodec = [item["tgt_encodec"] for item in data]

    return all_src_encodec, all_instruction, all_tgt_encodec

def train_model(
        model,
        model_ref,
        ar_tokenizer,
        train_dataset: Dataset,
        val_dataset: Dataset,
        model_output_dir: str,
        beta: float,
        resume_from_checkpoint: bool,
        model_checkpoint: str,
        learning_rate: float = 5e-05,
        num_train_epochs: int = 100,
        max_length: int = 1024*9,
        max_prompt_length: int = 1024*9,
        max_target_length: int = 1024*9,
        per_device_train_batch_size: int = 1,
        gradient_accumulation_steps: int = 1,
        seed: int = 42
) -> None:
    '''
    Train the DPO model and save the model.

    Args:
        model(AutoModelForSeq2SeqLMWithValueHead): The DPO model.
        model_ref(AutoModelForCausalLM): The reference model.
        ar_tokenizer(AutoTokenizer): The tokenizer.
        train_dataset(Dataset): The training dataset.
        val_dataset(Dataset): The validation dataset.
        model_output_dir(str): The output directory for the model.
        beta(float): The beta value.
        resume_from_checkpoint(bool): Whether to resume from a checkpoint.
        model_checkpoint(str): The path to the model

    Returns:
        None
    '''

    training_args = DPOConfig(
        beta = beta,
        output_dir = model_output_dir,
        generate_during_eval = True,
        resume_from_checkpoint = model_checkpoint if resume_from_checkpoint else None,
        seed = seed,
        per_device_train_batch_size = per_device_train_batch_size,
        num_train_epochs = num_train_epochs,
        gradient_accumulation_steps = gradient_accumulation_steps,
        learning_rate = learning_rate,
        max_length = max_length,
        max_prompt_length = max_prompt_length,
        max_target_length = max_target_length
    )
    
    trainer = DPOTrainer(
        model=model,
        ref_model=model_ref,
        args=training_args,
        tokenizer=ar_tokenizer,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
    )
    # Train the model
    trainer.train()

    # Save the model
    trainer.save_model(f"{model_output_dir}/dpo_model")
    model.config.to_json_file(f"{model_output_dir}/dpo_model/config.json")
    ar_tokenizer.save_pretrained(f"{model_output_dir}/dpo_model")

In [3]:
def process_data(sample_size: int, 
                 ar_model, 
                 nar_model, 
                 ar_tokenizer, 
                 nar_tokenizer, 
                 all_src_encodec: List[list], 
                 all_instruction: List[str],
                 args_predict: SimpleNamespace, 
                 base_path: str = "/work/b0990106x/trl", 
                 temperature: float = 1.0
):
    """
    Process data to generate outputs, calculate rewards, and organize chosen and rejected data.
    """
    if sample_size < 2:
        raise ValueError("Parameter 'sample_size' must be greater than 1.")

    chosen, rejected, prompts, chosen_rewards, rejected_rewards, average_rewards = [], [], [], [], [], []

    for i in tqdm(range(len(all_src_encodec)), desc="Processing Data"):
        rewards, tokenized_outputs = [], []

        for j in tqdm(range(sample_size), desc="Processing Samples"):
            size_of_packed_input = (
                    len(all_src_encodec[i][0]) +
                    len(ar_tokenizer(all_instruction[i])["input_ids"][1:-1]) +
                    3
            )
            if 4 < size_of_packed_input <= 1024:
                reward, tokenized_decode_ar = generate_output(
                    ar_model=ar_model, 
                    nar_model=nar_model, 
                    ar_tokenizer=ar_tokenizer, 
                    nar_tokenizer=nar_tokenizer,
                    src_encodec = all_src_encodec[i],
                    instruction=all_instruction[i], 
                    args_predict=args_predict,
                    episode_counter=f"data_{i}_episode_{j}",
                    base_path=base_path, 
                    temperature=temperature
                )
                rewards.append(reward)
                tokenized_outputs.append(tokenized_decode_ar)

        valid_rewards = [r for r in rewards if r is not None]
        valid_outputs = [tokenized_outputs[j] for j in range(len(rewards)) if rewards[j] is not None]

        if len(valid_rewards) >= 2:
            max_reward_index = np.argmax(valid_rewards)
            min_reward_index = np.argmin(valid_rewards)
            average_reward = np.mean(valid_rewards)
            chosen_output = valid_outputs[max_reward_index]
            rejected_output = valid_outputs[min_reward_index]

            obs_input = pack_inputs_v2(ar_tokenizer, all_src_encodec[i], all_instruction[i])
            tokenize_input = ar_tokenizer.convert_ids_to_tokens(obs_input)
            tokenize_input_str = ar_tokenizer.convert_tokens_to_string(tokenize_input)
            prompts.append(tokenize_input_str)

            chosen.append(chosen_output)
            chosen_rewards.append(valid_rewards[max_reward_index])
            rejected.append(rejected_output)
            rejected_rewards.append(valid_rewards[min_reward_index])
            average_rewards.append(average_reward)
        else:
            print(f"Not enough valid rewards for data index {i}")

    # print(f"{len(chosen)} datas are processed.")
    return chosen, rejected, prompts, chosen_rewards, rejected_rewards, average_rewards


def generate_data(ar_model, 
                  ar_tokenizer, 
                  nar_model, 
                  nar_tokenizer, 
                  selected_src_encodec, 
                  selected_instruction,
                  args_predict: SimpleNamespace, 
                  sample_size: int, 
                  iteration: int, 
                  agent_output_dir: str, 
                  base_path: str = "/work/b0990106x/trl", 
                  temperature: float = 1.0
):
    """
    Generates data for the dataset and saves info to a JSON file.
    """
    chosen, rejected, prompts, chosen_rewards, rejected_rewards, average_rewards = process_data(
        sample_size=sample_size,
        ar_model=ar_model,
        nar_model=nar_model,
        ar_tokenizer=ar_tokenizer,
        nar_tokenizer=nar_tokenizer,
        all_src_encodec=selected_src_encodec,
        all_instruction=selected_instruction,
        args_predict=args_predict,
        base_path=base_path,
        temperature=temperature
    )

    data = {
        "prompt": prompts,
        "chosen": chosen,
        "rejected": rejected,
        "chosen_rewards": chosen_rewards,
        "rejected_rewards": rejected_rewards,
        "average_rewards": average_rewards
    }

    if len(selected_src_encodec) == 1:
        data = {
            "prompt": prompts + prompts,
            "chosen": chosen + chosen,
            "rejected": rejected + rejected,
            "chosen_rewards": chosen_rewards + chosen_rewards,
            "rejected_rewards": rejected_rewards + rejected_rewards,
            "average_rewards": average_rewards + average_rewards
        }

    with open(f"{agent_output_dir}/data_iter_{iteration}.json", "w") as outfile:
        json.dump(data, outfile, indent=4)

    data_for_dataset = {key: data[key] for key in ["prompt", "chosen", "rejected"]}
    return data_for_dataset, chosen_rewards, rejected_rewards

def train_iteration(model_checkpoint, 
                    iteration, 
                    data_size, 
                    sample_size, 
                    ar_checkpoint, 
                    nar_checkpoint, 
                    all_src_encodec, 
                    all_instruction, 
                    args_predict, 
                    agent_output_dir,
                    model_output_dir_base, 
                    beta = 0.1, 
                    temperature = 1.0,
                    base_path="/work/b0990106x/trl",
                    resume_from_checkpoint = False
):
    """
    Executes one training iteration: generates data, trains the model, and saves the output.
    """
    # print(f"Iteration {iteration}")

    ar_model = BartForConditionalGeneration.from_pretrained(model_checkpoint)
    ar_tokenizer = AutoTokenizer.from_pretrained(ar_checkpoint)
    ar_tokenizer.pad_token = ar_tokenizer.eos_token
    nar_model = NARBartForConditionalGeneration.from_pretrained(nar_checkpoint)
    nar_tokenizer = AutoTokenizer.from_pretrained(nar_checkpoint)

    selected_src_encodec = all_src_encodec[:data_size]
    selected_instruction = all_instruction[:data_size]

    data_for_dataset, chosen_rewards, rejected_rewards = generate_data(ar_model=ar_model,
                                                                        ar_tokenizer=ar_tokenizer,
                                                                        nar_model=nar_model,
                                                                        nar_tokenizer=nar_tokenizer,
                                                                        selected_src_encodec=selected_src_encodec,
                                                                        selected_instruction=selected_instruction,
                                                                        args_predict=args_predict,
                                                                        sample_size=sample_size,
                                                                        iteration=iteration,
                                                                        agent_output_dir=agent_output_dir,
                                                                        base_path=base_path,
                                                                        temperature=temperature)

    dataset = Dataset.from_dict(data_for_dataset)
    dataset_dict = dataset.train_test_split(test_size=0.1)
    train_dataset = dataset_dict["train"]
    val_dataset = dataset_dict["test"]

    model_output_dir = f"{model_output_dir_base}/iter_{iteration}"
    os.makedirs(model_output_dir, exist_ok=True)

    model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(model_checkpoint, return_dict=True)
    model_ref = create_reference_model(model)

    train_model(model=model,
                model_ref=model_ref,
                ar_tokenizer=ar_tokenizer,
                train_dataset=train_dataset,
                val_dataset=val_dataset,
                model_output_dir=model_output_dir,
                beta=beta,
                resume_from_checkpoint=resume_from_checkpoint,
                model_checkpoint=model_checkpoint)

    return f"{model_output_dir}/dpo_model", chosen_rewards, rejected_rewards

In [4]:
# Load all data
all_src_encodec, all_instruction, all_tgt_encodec = extract_data_from_json('dpo_data/src_encodec.json')
print(len(all_src_encodec))
print(len(all_instruction))
print(len(all_tgt_encodec))

# Define paths and device
base_path = "/work/b0990106x/trl"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Define timestamp
now = datetime.now()
ts = now.strftime("%m%d-%H%M")
print("timestamp:", ts)

# Define paths
model_output_dir = f"{base_path}/model_output/{ts}" # Location where the model are saved
agent_input_dir = f"{base_path}/data-encodec" # Location of our original data(input) is stored
agent_output_dir = f"{base_path}/output/{ts}" # Path of saving the generated audio for reward model to evaluate

if not os.path.exists(model_output_dir):
    os.makedirs(model_output_dir)

if not os.path.exists(agent_output_dir):
    os.makedirs(agent_output_dir)

# Define arguments 
args_predict = SimpleNamespace(output_path=f"{base_path}/output/{ts}/example.wav", seed=0, device=device)
ar_checkpoint = "lca0503/speech-chatgpt-base-ar-v2-epoch10-wotrans"
nar_checkpoint = "lca0503/speech-chatgpt-base-nar-v2-epoch4-wotrans"

# Define training parameters
model_checkpoint = ar_checkpoint # set the initial model checkpoint
initial_data_size =  1 # Training: data size for the first iteration
data_size_per_iteration = 1 # Training: each iteration will train how many data
total_data_size = 1 # Training: total data size that we want to train
sample_size = 10 # Prepare Dataset: generate how many outputs to select max and min for chosen and rejected
beta = 0.1 # Training: beta value for DPO
num_iterations = 20  # Training: train how many iterations
eval_selected_indices = [0] # Evaluation: evaluate which data from extract_data_from_json('dpo_data/src_encodec.json')
eval_data_len = 1 # Evaluation: evaluate how many data

# num_iterations = (total_data_size - initial_data_size) // data_size_per_iteration + 1 # Training: train how many iterations
# eval_selected_indices = random.sample(range(len(all_src_encodec)), eval_data_len) # Evaluation: select 10 data for evaluation
print(f"length of all_src_encodec: {len(all_src_encodec)}") # ~ 9000 data
print(f"length of all_instruction: {len(all_instruction)}") # ~ 9000 data


9254
9254
9254
timestamp: 0725-1827
length of all_src_encodec: 9254
length of all_instruction: 9254


In [5]:
print(f"num_iterations: {num_iterations}")
print(f"data_size_per_iteration: {data_size_per_iteration}")
print(f"sample_size: {sample_size}")
print(f"beta: {beta}")
print(f"ar_checkpoint: {ar_checkpoint}")
print(f"nar_checkpoint: {nar_checkpoint}")
print(f"args_predict: {args_predict}")
print(f"model_output_dir: {model_output_dir}")
print(f"agent_output_dir: {agent_output_dir}")
print(f"base_path: {base_path}")
print(f"device: {device}")
print(f"eval_data_len: {eval_data_len}")
print(f"eval_selected_indices: {eval_selected_indices}")
print("Type of batch_src_encodec:", type(all_src_encodec))
print("Type of batch_instruction:", type(all_instruction))
print(all_src_encodec[0:2])
print(all_instruction[0:2])


num_iterations: 20
data_size_per_iteration: 1
sample_size: 10
beta: 0.1
ar_checkpoint: lca0503/speech-chatgpt-base-ar-v2-epoch10-wotrans
nar_checkpoint: lca0503/speech-chatgpt-base-nar-v2-epoch4-wotrans
args_predict: namespace(output_path='/work/b0990106x/trl/output/0725-1827/example.wav', seed=0, device='cuda')
model_output_dir: /work/b0990106x/trl/model_output/0725-1827
agent_output_dir: /work/b0990106x/trl/output/0725-1827
base_path: /work/b0990106x/trl
device: cuda
eval_data_len: 1
eval_selected_indices: [0]
Type of batch_src_encodec: <class 'list'>
Type of batch_instruction: <class 'list'>
[[[408, 835, 835, 798, 585, 550, 535, 535, 737, 737, 377, 556, 601, 787, 8, 99, 411, 411, 378, 937, 378, 937, 804, 838, 890, 934, 47, 438, 438, 731, 738, 133, 709, 479, 479, 479, 151, 940, 502, 906, 407, 645, 70, 208, 537, 537, 1022, 681, 723, 747, 593, 804, 681, 879, 136, 967, 233, 431, 754, 421, 182, 182, 651, 879, 887, 819, 904, 904, 887, 309, 880, 396, 754, 775, 997, 222, 336, 548, 841, 269,

In [6]:
import logging
# Set up logging
logging.basicConfig(
    filename=f'{model_output_dir}/log_training.log', 
    filemode='a', 
    format='%(asctime)s - %(levelname)s - %(message)s', 
    level=logging.INFO
)

logging.info(f"Parameters:")
logging.info(f"num_iterations: {num_iterations}")
logging.info(f"data_size_per_iteration: {data_size_per_iteration}")
logging.info(f"sample_size: {sample_size}")
logging.info(f"beta: {beta}")
logging.info(f"timestep: {ts}")

# Start time
total_start_time = time.time()

original_model_metrics, original_model_rewards = eval_dpo_mos(ar_checkpoint=ar_checkpoint,
                                                                nar_checkpoint=nar_checkpoint,
                                                                trained_model_checkpoint=ar_checkpoint, # original model
                                                                all_src_encodec=all_src_encodec,
                                                                all_instruction=all_instruction,
                                                                eval_data_len=eval_data_len,
                                                                selected_indices=eval_selected_indices,
                                                                device=device,
                                                                iteration = -1,
                                                                args_predict=args_predict)

logging.info(f"Original model metrics: {original_model_metrics}")
logging.info(f"Original model rewards: {original_model_rewards}")

for iteration in tqdm(range(num_iterations), desc="Training Iterations"):
    # if iteration == 0:
    #     start_idx = 0
    #     end_idx = initial_data_size
    # else:
    #     start_idx = initial_data_size + (iteration - 1) * data_size_per_iteration
    #     end_idx = start_idx + data_size_per_iteration
    
    start_idx = 0
    end_idx = data_size_per_iteration

    batch_src_encodec = all_src_encodec[start_idx:end_idx] # select 'data_size_per_iteration' datas
    batch_instruction = all_instruction[start_idx:end_idx]
    resume = iteration > 0 # resume from the previous checkpoint when iteration > 0

    logging.info(f"Starting iteration {iteration}")
    logging.info(f"Processing data from index {start_idx} to {end_idx}")
    
    # model_checkpoint is the model checkpoint from the previous iteration
    # chosen_rewards and rejected_rewards are the rewards of the data
    model_checkpoint, chosen_rewards, rejected_rewards = train_iteration(model_checkpoint=model_checkpoint,
                                       iteration=iteration,
                                       data_size=data_size_per_iteration,
                                       sample_size=sample_size,
                                       ar_checkpoint=ar_checkpoint,
                                       nar_checkpoint=nar_checkpoint,
                                       all_src_encodec=batch_src_encodec,
                                       all_instruction=batch_instruction,
                                       args_predict=args_predict,
                                       agent_output_dir=agent_output_dir,
                                       model_output_dir_base=model_output_dir,
                                       temperature = 1.0,
                                       beta=beta,
                                       base_path=base_path,
                                       resume_from_checkpoint=resume)
    
    
    logging.info(f"Chosen rewards for iteration {iteration}: {chosen_rewards}")
    logging.info(f"Rejected rewards for iteration {iteration}: {rejected_rewards}")
    logging.info(f"Finished training iteration {iteration}")

    # Evaluate the result of the current iteration
    trained_model_metrics, trained_model_rewards = eval_dpo_mos(ar_checkpoint=ar_checkpoint,
                                                                nar_checkpoint=nar_checkpoint,
                                                                trained_model_checkpoint=model_checkpoint,
                                                                all_src_encodec=all_src_encodec,
                                                                all_instruction=all_instruction,
                                                                eval_data_len=eval_data_len,
                                                                selected_indices=eval_selected_indices,
                                                                device=device,
                                                                iteration = iteration,
                                                                args_predict=args_predict)
    
    logging.info(f"Evaluation metrics for iteration {iteration}: {trained_model_metrics}")
    logging.info(f"Evaluation rewards for iteration {iteration}: {trained_model_rewards}")

total_end_time = time.time()

# Calculate total time taken
total_time_taken = total_end_time - total_start_time
logging.info(f"Total time taken for the entire process: {total_time_taken:.2f} seconds")



Episode eval_-1_data_0 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_eval_-1_data_0.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)


Episode eval_-1_data_0 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_eval_-1_data_0.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)


Episode eval_-1_data_0 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_eval_-1_data_0.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)


Episode eval_-1_data_0 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_eval_-1_data_0.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)


Episode eval_-1_data_0 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_eval_-1_data_0.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)


Episode eval_-1_data_0 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_eval_-1_data_0.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)


Episode eval_-1_data_0 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_eval_-1_data_0.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)


Episode eval_-1_data_0 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_eval_-1_data_0.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)


Episode eval_-1_data_0 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_eval_-1_data_0.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)


Episode eval_-1_data_0 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_eval_-1_data_0.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)



Episode data_0_episode_0 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_data_0_episode_0.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)



Episode data_0_episode_1 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_data_0_episode_1.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)



Episode data_0_episode_2 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_data_0_episode_2.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)



Episode data_0_episode_3 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_data_0_episode_3.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)



Episode data_0_episode_4 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_data_0_episode_4.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)



Episode data_0_episode_5 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_data_0_episode_5.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)



Episode data_0_episode_6 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_data_0_episode_6.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)



Episode data_0_episode_7 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_data_0_episode_7.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)



Episode data_0_episode_8 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_data_0_episode_8.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)



Episode data_0_episode_9 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_data_0_episode_9.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)

Processing Samples: 100%|██████████| 10/10 [00:52<00:00,  5.29s/it]
Processing Data: 100%|██████████| 1/1 [00:52<00:00, 52.94s/it]
Map: 100%|██████████| 1/1 [00:00<00:00, 129.77 examples/s]
Map: 100%|██████████| 1/1 [00:00<00:00, 147.61 examples/s]
[34m[1mwandb[0m: Currently logged in as: [33mb09901066[0m ([33mb09901066_alan[0m). Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss


Some weights of the model checkpoint at /work/b0990106x/trl/model_output/0725-1827/iter_0/dpo_model were not used when initializing BartForConditionalGeneration: ['v_head.summary.weight', 'v_head.summary.bias']
- This IS expected if you are initializing BartForConditionalGeneration from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BartForConditionalGeneration from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Episode eval_0_data_0 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_eval_0_data_0.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)


Episode eval_0_data_0 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_eval_0_data_0.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)


Episode eval_0_data_0 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_eval_0_data_0.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)


Episode eval_0_data_0 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_eval_0_data_0.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)


Episode eval_0_data_0 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_eval_0_data_0.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)


Episode eval_0_data_0 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_eval_0_data_0.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)


Episode eval_0_data_0 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_eval_0_data_0.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)


Episode eval_0_data_0 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_eval_0_data_0.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)


Episode eval_0_data_0 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_eval_0_data_0.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)


Episode eval_0_data_0 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_eval_0_data_0.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)
Training Iterations:   5%|▌         | 1/20 [02:57<56:18, 177.82s/it]Some weights of the model checkpoint at /work/b0990106x/trl/model_output/0725-1827/iter_0/dpo_model were not used when initializing BartForConditionalGeneration: ['v_head.summary.weight', 'v_head.summary.bias']
- This IS expected if you are initializing BartForConditionalGeneration from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BartForConditionalGeneration from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).



Episode data_0_episode_0 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_data_0_episode_0.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)



Episode data_0_episode_1 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_data_0_episode_1.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)



Episode data_0_episode_2 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_data_0_episode_2.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)



Episode data_0_episode_3 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_data_0_episode_3.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)



Episode data_0_episode_4 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_data_0_episode_4.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)



Episode data_0_episode_5 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_data_0_episode_5.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)



Episode data_0_episode_6 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_data_0_episode_6.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)



Episode data_0_episode_7 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_data_0_episode_7.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)



Episode data_0_episode_8 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_data_0_episode_8.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)



Episode data_0_episode_9 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_data_0_episode_9.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)

Processing Samples: 100%|██████████| 10/10 [00:56<00:00,  5.66s/it]
Processing Data: 100%|██████████| 1/1 [00:56<00:00, 56.60s/it]
Some weights of the model checkpoint at /work/b0990106x/trl/model_output/0725-1827/iter_0/dpo_model were not used when initializing BartForConditionalGeneration: ['v_head.summary.weight', 'v_head.summary.bias']
- This IS expected if you are initializing BartForConditionalGeneration from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BartForConditionalGeneration from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Map: 100%|██████████| 1/1 [00:00<00:00, 126.11 examples/s]
Map: 100%|██████████| 1/1 [00:00<00:00, 142.49 exam

Step,Training Loss


Some weights of the model checkpoint at /work/b0990106x/trl/model_output/0725-1827/iter_1/dpo_model were not used when initializing BartForConditionalGeneration: ['v_head.summary.weight', 'v_head.summary.bias']
- This IS expected if you are initializing BartForConditionalGeneration from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BartForConditionalGeneration from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Episode eval_1_data_0 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_eval_1_data_0.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)


Episode eval_1_data_0 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_eval_1_data_0.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)


Episode eval_1_data_0 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_eval_1_data_0.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)


Episode eval_1_data_0 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_eval_1_data_0.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)


Episode eval_1_data_0 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_eval_1_data_0.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)


Episode eval_1_data_0 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_eval_1_data_0.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)


Episode eval_1_data_0 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_eval_1_data_0.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)


Episode eval_1_data_0 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_eval_1_data_0.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)


Episode eval_1_data_0 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_eval_1_data_0.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)


Episode eval_1_data_0 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_eval_1_data_0.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)
Training Iterations:  10%|█         | 2/20 [06:13<56:29, 188.29s/it]Some weights of the model checkpoint at /work/b0990106x/trl/model_output/0725-1827/iter_1/dpo_model were not used when initializing BartForConditionalGeneration: ['v_head.summary.weight', 'v_head.summary.bias']
- This IS expected if you are initializing BartForConditionalGeneration from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BartForConditionalGeneration from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).



Episode data_0_episode_0 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_data_0_episode_0.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)



Episode data_0_episode_1 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_data_0_episode_1.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)



Episode data_0_episode_2 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_data_0_episode_2.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)



Episode data_0_episode_3 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_data_0_episode_3.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)



Episode data_0_episode_4 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_data_0_episode_4.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)



Episode data_0_episode_5 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_data_0_episode_5.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)



Episode data_0_episode_6 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_data_0_episode_6.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)



Episode data_0_episode_7 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_data_0_episode_7.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)



Episode data_0_episode_8 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_data_0_episode_8.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)



Episode data_0_episode_9 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_data_0_episode_9.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)

Processing Samples: 100%|██████████| 10/10 [01:15<00:00,  7.50s/it]
Processing Data: 100%|██████████| 1/1 [01:15<00:00, 75.03s/it]
Some weights of the model checkpoint at /work/b0990106x/trl/model_output/0725-1827/iter_1/dpo_model were not used when initializing BartForConditionalGeneration: ['v_head.summary.weight', 'v_head.summary.bias']
- This IS expected if you are initializing BartForConditionalGeneration from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BartForConditionalGeneration from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Map: 100%|██████████| 1/1 [00:00<00:00, 84.26 examples/s]
Map: 100%|██████████| 1/1 [00:00<00:00, 106.19 examp

Step,Training Loss


Some weights of the model checkpoint at /work/b0990106x/trl/model_output/0725-1827/iter_2/dpo_model were not used when initializing BartForConditionalGeneration: ['v_head.summary.weight', 'v_head.summary.bias']
- This IS expected if you are initializing BartForConditionalGeneration from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BartForConditionalGeneration from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Episode eval_2_data_0 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_eval_2_data_0.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)


Episode eval_2_data_0 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_eval_2_data_0.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)


Episode eval_2_data_0 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_eval_2_data_0.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)


Episode eval_2_data_0 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_eval_2_data_0.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)


Episode eval_2_data_0 : audio saved to  /work/b0990106x/trl/output/0725-1827/example_save_eval_2_data_0.wav


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)
