In [None]:
def generate_output(
        ar_model, 
        nar_model, 
        ar_tokenizer, 
        nar_tokenizer, 
        src_encodec: list, 
        instruction: str, 
        args_predict: SimpleNamespace, 
        episode_counter: "int | str" = 0, 
        base_path: str = "/work/b0990106x/trl", 
        temperature: float = 1.0
) -> tuple[float, str]:
    """Generate audio with the AR/NAR models and score it with ``get_reward_claps``.

    The AR model's generated ids are also decoded back into a token string so
    the result can be used as a completion for preference (chosen/rejected)
    data.

    Args:
        ar_model: AR model, forwarded to ``get_ar_prediction_get_audio``.
        nar_model: NAR model, forwarded to ``get_ar_prediction_get_audio``.
        ar_tokenizer: AR tokenizer. It is forwarded to the generator and also
            used here to turn the generated ids back into a string.
        nar_tokenizer: NAR tokenizer, forwarded to the generator.
        src_encodec: Source EnCodec codes for one example. This is presumably
            a list of codebook layers, since ``process_data`` measures
            ``len(src_encodec[0])`` (TODO confirm).
        instruction: Text instruction for the example. It is passed to the
            generator and used as the reward prompt. ``process_data`` passes
            a ``str``.
        args_predict: Prediction settings forwarded to the generator.
        episode_counter: Identifier forwarded to the generator.
            ``process_data`` passes a string label such as
            ``"data_0_episode_1"``.
        base_path: Not used by this function; kept for backward compatibility.
        temperature: Sampling temperature forwarded to the generator.

    Returns:
        tuple:
            reward (float): Value returned by ``get_reward_claps`` for the
                generated audio. It is presumably a CLAP text-audio score.
                NOTE(review): ``process_data`` filters out ``None`` rewards,
                so this may be ``None`` on failure; confirm.
            tokenized_decode_ar (str): Decoded string of the AR model's
                generated ids.
    """
    # Generate the AR prediction and the synthesized waveform.
    decode_ar, wav = get_ar_prediction_get_audio(
        args_predict, ar_model, nar_model, ar_tokenizer, nar_tokenizer, src_encodec, instruction, episode_counter, temperature=temperature
    )

    tensor_wav = convert_array_to_tensor_format(wav)
    # Drop a leading singleton axis from the first waveform before scoring.
    # NOTE(review): only element 0 is inspected; confirm `wav` holds one clip.
    if tensor_wav[0].shape[0] == 1:
        tensor_wav[0] = tensor_wav[0].squeeze(0)

    reward = get_reward_claps(prompts=instruction, wavs=tensor_wav)

    # Drop the first two ids and the last id before decoding. These are
    # presumably prefix/special tokens and EOS; TODO confirm against the AR
    # tokenizer.
    generated_ids = decode_ar.flatten().tolist()[2:-1]
    decode_ar_tokens = ar_tokenizer.convert_ids_to_tokens(generated_ids)
    tokenized_decode_ar = ar_tokenizer.convert_tokens_to_string(decode_ar_tokens)
    print(f"REWARD: {reward}")

    return reward, tokenized_decode_ar

In [None]:
def process_data(sample_size: int, 
                 ar_model, 
                 nar_model, 
                 ar_tokenizer, 
                 nar_tokenizer, 
                 all_src_encodec: List[list], 
                 all_instruction: List[str],
                 args_predict: SimpleNamespace, 
                 base_path: str = "/work/b0990106x/trl", 
                 temperature: float = 1.0, 
                 iteration: int = 0
) -> Tuple[List[str], List[str], List[str], List[float], List[float], List[float]]:
    """Build chosen/rejected preference pairs by sampling several outputs per example.

    For each example, ``generate_output`` is called ``sample_size`` times with
    seed ``42 + iteration + j`` for sample ``j``. Samples whose reward is
    ``None`` are discarded. The highest-reward output becomes "chosen" and the
    lowest-reward output becomes "rejected".

    An example is skipped, with a printed message, when either of these holds:

    * its packed input length falls outside ``(4, 1024]``;
    * fewer than two valid rewards remain.

    Args:
        sample_size: Number of generations per example; must be at least 2.
        ar_model: AR model, forwarded to ``generate_output``.
        nar_model: NAR model, forwarded to ``generate_output``.
        ar_tokenizer: AR tokenizer, used for length checks and prompt decoding.
        nar_tokenizer: NAR tokenizer, forwarded to ``generate_output``.
        all_src_encodec: Source EnCodec codes, one entry per example.
        all_instruction: Instructions, aligned index-for-index with
            ``all_src_encodec``.
        args_predict: Prediction settings forwarded to ``generate_output``.
        base_path: Forwarded to ``generate_output``.
        temperature: Sampling temperature forwarded to ``generate_output``.
        iteration: Offset added to the per-sample seed.

    Returns:
        tuple:
            chosen (List[str]): Highest-reward output per kept example.
            rejected (List[str]): Lowest-reward output per kept example.
            prompts (List[str]): Decoded ``pack_inputs_v2`` input per kept example.
            chosen_rewards (List[float]): Reward of each chosen output.
            rejected_rewards (List[float]): Reward of each rejected output.
            average_rewards (List[float]): Mean valid reward per kept example.
        If exactly one pair is produced, every list is duplicated so it can
        feed both a training split and a validation split.

    Raises:
        ValueError: If ``sample_size`` is less than 2.
    """
    # With a single sample there is no best/worst pair to choose from.
    if sample_size < 2:
        raise ValueError("Parameter 'sample_size' must be greater than 1.")

    chosen, rejected, prompts, chosen_rewards, rejected_rewards, average_rewards = [], [], [], [], [], []

    for i in tqdm(range(len(all_src_encodec)), desc="Processing Data"):
        src_encodec = all_src_encodec[i]
        instruction = all_instruction[i]

        # The packed length depends only on the example, not on the sample
        # index, so tokenize the instruction once instead of per sample.
        # The [1:-1] slice drops the tokenizer's leading/trailing ids; the
        # +3 presumably accounts for packing separators (TODO confirm
        # against pack_inputs_v2).
        size_of_packed_input = (
            len(src_encodec[0])
            + len(ar_tokenizer(instruction)["input_ids"][1:-1])
            + 3
        )

        rewards, tokenized_outputs = [], []
        if 4 < size_of_packed_input <= 1024:
            for j in tqdm(range(sample_size), desc="Processing Samples"):
                set_seed(42 + iteration + j)
                reward, tokenized_decode_ar = generate_output(
                    ar_model=ar_model,
                    nar_model=nar_model,
                    ar_tokenizer=ar_tokenizer,
                    nar_tokenizer=nar_tokenizer,
                    src_encodec=src_encodec,
                    instruction=instruction,
                    args_predict=args_predict,
                    episode_counter=f"data_{i}_episode_{j}",
                    base_path=base_path,
                    temperature=temperature
                )
                rewards.append(reward)
                tokenized_outputs.append(tokenized_decode_ar)

        # Keep only samples that produced a reward, preserving reward/output pairing.
        valid_pairs = [(r, out) for r, out in zip(rewards, tokenized_outputs) if r is not None]
        valid_rewards = [r for r, _ in valid_pairs]
        valid_outputs = [out for _, out in valid_pairs]

        if len(valid_rewards) < 2:
            print(f"Not enough valid rewards for data index {i}")
            continue

        # NOTE(review): if all valid rewards are equal, argmax and argmin both
        # return 0, so chosen == rejected for this example.
        max_reward_index = np.argmax(valid_rewards)
        min_reward_index = np.argmin(valid_rewards)
        average_reward = np.mean(valid_rewards)

        obs_input = pack_inputs_v2(ar_tokenizer, src_encodec, instruction)
        tokenize_input = ar_tokenizer.convert_ids_to_tokens(obs_input)
        tokenize_input_str = ar_tokenizer.convert_tokens_to_string(tokenize_input)
        prompts.append(tokenize_input_str)

        chosen.append(valid_outputs[max_reward_index])
        chosen_rewards.append(valid_rewards[max_reward_index])
        rejected.append(valid_outputs[min_reward_index])
        rejected_rewards.append(valid_rewards[min_reward_index])
        average_rewards.append(average_reward)

    # Duplicate a lone pair so there is data for both the training split and
    # the validation split. This is keyed on the number of pairs actually
    # produced, not on the number of inputs, so it also covers the case where
    # several inputs yield only one valid pair.
    if len(chosen) == 1:
        chosen *= 2
        rejected *= 2
        prompts *= 2
        chosen_rewards *= 2
        rejected_rewards *= 2
        average_rewards *= 2

    return chosen, rejected, prompts, chosen_rewards, rejected_rewards, average_rewards
