# Controllable generation via RL for text-guided voice conversion


In [None]:
import torch
from datasets import load_from_disk
from vc.encodec_model.nar_bart_model import NARBartForConditionalGeneration
from transformers import AutoTokenizer, BartForConditionalGeneration
import sys
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead, create_reference_model
from tqdm import tqdm

# Load the AR policy (wrapped with a value head), a reference copy of it that
# is later passed to PPOTrainer as `ref_model`, the NAR model, and the two
# tokenizers.
ar_checkpoint = "lca0503/speech-chatgpt-base-ar-v2-epoch10-wotrans"
nar_checkpoint = "lca0503/speech-chatgpt-base-nar-v2-epoch4-wotrans"

# Single definition of the device (previously assigned twice, first as a plain
# string and then as a torch.device).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(ar_checkpoint)
# The reference copy is created before the policy is moved to `device`.
model_ref = create_reference_model(model)
tokenizer = AutoTokenizer.from_pretrained(ar_checkpoint)
nar_tokenizer = AutoTokenizer.from_pretrained(nar_checkpoint)
nar_model = NARBartForConditionalGeneration.from_pretrained(nar_checkpoint)
# NOTE(review): padding is forced to reuse the EOS token — confirm this is
# intended for this checkpoint's tokenizer.
tokenizer.pad_token = tokenizer.eos_token

# Only the AR policy is moved here; nar_model is left where from_pretrained put it.
model = model.to(device)

In [None]:
from datetime import datetime
import os

# Timestamp (month-day, hour-minute) that names this run's output folder.
now = datetime.now()
ts = now.strftime("%m%d-%H%M")
print("timestamp:", ts)

# Directory layout: the agent reads from data-encodec and writes to a
# per-run output folder; the environment's input/output dirs are the mirror
# image of the agent's.
base_path = "/work/b0990106x/trl"
agent_input_dir = f"{base_path}/data-encodec"
agent_output_dir = f"{base_path}/output/{ts}"
env_input_dir = agent_output_dir
env_output_dir = agent_input_dir

# exist_ok=True replaces the separate os.path.exists check: it is a single
# call with no check-then-create window.
os.makedirs(agent_output_dir, exist_ok=True)

In [None]:
# Load the dataset stored on disk at agent_input_dir (datasets.load_from_disk).
dataset = load_from_disk(agent_input_dir)

In [None]:
# Number of `src_encodec_{i}` columns read from the dataset.
layer_len = 8
# Number of leading dataset rows to use; switch to len(dataset) to use all rows.
data_len = 3
print("data_len:", data_len)

# Look each column up exactly once. The instruction column used to be fetched
# from the dataset again on every loop iteration although it never changes.
all_src_encodec_layers = [
    dataset[f"src_encodec_{layer_idx}"] for layer_idx in range(layer_len)
]
instruction_column = dataset["instruction"]

# all_src_encodec[row][layer] -> the `layer`-th encodec sequence of sample `row`.
all_src_encodec = [
    [layer_column[row] for layer_column in all_src_encodec_layers]
    for row in range(data_len)
]
# Indexing (rather than slicing) keeps the IndexError when the dataset has
# fewer than data_len rows.
all_instruction = [instruction_column[row] for row in range(data_len)]

In [None]:
# Report, per selected sample, the length of its first encodec layer.
for sample_idx in range(data_len):
    first_layer = all_src_encodec[sample_idx][0]
    print(f"src_encodec_{sample_idx} len:", len(first_layer))

### Debugging Section

In [None]:
# One observation dict per selected sample: empty input text, the sample's
# encodec layers, and its instruction.
observation_list = [
    {
        "input": "",
        "src_encodec": all_src_encodec[sample_idx],
        "instruction": all_instruction[sample_idx],
    }
    for sample_idx in range(data_len)
]

# Drop the first two entries from both lists, leaving only the later samples.
# NOTE(review): all_src_encodec is NOT trimmed here, so after this cell
# all_src_encodec[0] and all_instruction[0] come from different dataset rows.
# Confirm that pairing is intended wherever both are indexed with 0.
for _drop_round in range(2):
    observation_list.pop(0)
    all_instruction.pop(0)

print("observation_list:", observation_list)
print("all_instruction:", all_instruction)


# for i in range(data_len):
#     observation_list.append({'input': "", 'src_encodec': all_src_encodec[i], 'instruction': all_instruction[i]})

In [None]:
import sys
sys.path.append("/work/b0990106x/TextRL/vc")
from vc.trainer_encodec_vc_inference import get_ar_prediction
from types import SimpleNamespace


# Argument bundle handed to the inference helpers: output wav path for this
# run, RNG seed, and the torch device.
args_predict = SimpleNamespace(
    output_path=f"{base_path}/output/{ts}/example.wav",
    seed=0,
    device=device,
)

# Run the AR + NAR models once on index 0 of the source / instruction lists.
decode_ar = get_ar_prediction(
    args_predict,
    model,
    nar_model,
    tokenizer,
    nar_tokenizer,
    all_src_encodec[0],
    all_instruction[0],
    0,
)

# Render each predicted unit as its "v_tok_<unit>" token and join them into
# the string that is later used as the PPO query.
voice_tokens = [f"v_tok_{unit}" for unit in decode_ar]
decode_ar_str = tokenizer.convert_tokens_to_string(voice_tokens)
print("Decode AR:", decode_ar)
print("Decode AR str: ", decode_ar_str)

In [None]:
from datetime import datetime
import os

# Fresh timestamp for the log directory.
# NOTE: this rebinds the global `ts`. Paths that were derived from the earlier
# value (agent_output_dir, args_predict.output_path) are not updated, so any
# code that rebuilds an output path from `ts` after this cell may point to a
# different folder than the one created for the run.
now = datetime.now()
ts = now.strftime("%m%d-%H%M")
print("timestamp:", ts)
log_dir = f"logs/{ts}"
os.makedirs(log_dir, exist_ok=True)

# PPO hyper-parameters. These names are now the single source of truth; the
# config previously repeated the batch sizes as literals.
lr = 1.41e-5
batch_size = 1
mini_batch_size = 1

ppo_config = PPOConfig(
    batch_size=batch_size,
    mini_batch_size=mini_batch_size,
    log_with="tensorboard",
    learning_rate=lr,
    project_kwargs={"logging_dir": log_dir},
)
ppo_trainer = PPOTrainer(
    config=ppo_config,
    model=model,
    ref_model=model_ref,
    tokenizer=tokenizer,
)

In [None]:
from importlib import reload
from NISQA.nisqa.NISQA_model import nisqaModel

import sys
sys.path.append("/work/b0990106x/trl/vc") 
from vc.trainer_encodec_vc_inference import get_ar_prediction_v2


def get_reward(predicted_list, single_src_encodec, single_instruction, episode_counter, finish):
    """Score a generated unit sequence with NISQA and return it as a reward.

    Args:
        predicted_list: Generated text made of concatenated voice tokens,
            e.g. ``"v_tok_410v_tok_411v_tok_595"``.
        single_src_encodec: Source encodec layers of the sample, forwarded
            unchanged to ``get_ar_prediction_v2``.
        single_instruction: Instruction of the sample, forwarded unchanged.
        episode_counter: Episode / iteration index, forwarded unchanged.
        finish: When true the sequence is scored regardless of its length.

    Returns:
        ``float(mos_pred) * 10`` for the first NISQA prediction row, or ``0``
        when the sequence is not scored or any step of the scoring fails
        (failures are printed, not raised).
    """
    reward = 0

    # Only score finished sequences or very long ones.
    # NOTE(review): predicted_list is a string, so this threshold counts
    # characters, not tokens — confirm that is the intended cut-off.
    if not (finish or len(predicted_list) >= 1000):
        return reward

    try:
        # "v_tok_410v_tok_411 ..." -> ["v_tok_410", "v_tok_411", ...].
        # Whitespace around a piece is stripped and empty pieces are skipped
        # so that separators between tokens do not produce unknown tokens.
        predicted_tokens = [
            f"v_tok_{piece.strip()}"
            for piece in predicted_list.split("v_tok_")[1:]
            if piece.strip()
        ]
        predicted_ids = tokenizer.convert_tokens_to_ids(predicted_tokens)
        print("predict length: ", len(predicted_ids))
        print("predicted_tokens: ", predicted_tokens)
        print("predicted_ids: ", predicted_ids)

        # Called for its side effect (presumably writing the wav to
        # args_predict.output_path — verify in the helper); the return value
        # is not needed here.
        get_ar_prediction_v2(
            args_predict,
            predicted_ids,
            nar_model,
            tokenizer,
            nar_tokenizer,
            single_src_encodec,
            single_instruction,
            episode_counter,
        )

        # Score the file at the path that was handed to the predictor. The
        # path used to be rebuilt from the global `ts`, which is rebound after
        # args_predict is created and could therefore name a different folder.
        args_nisqa = {
            "mode": "predict_file",
            "pretrained_model": f"{base_path}/NISQA/weights/nisqa.tar",
            "deg": args_predict.output_path,
            "data_dir": None,
            "output_dir": f"{base_path}/NISQA/result/",
            "csv_file": None,
            "csv_deg": None,
            "num_workers": 0,
            "bs": 1,
            "ms_channel": None,
        }
        args_nisqa["tr_bs_val"] = args_nisqa["bs"]
        args_nisqa["tr_num_workers"] = args_nisqa["num_workers"]

        # A new nisqaModel is constructed on every scored call.
        nisqa = nisqaModel(args_nisqa)
        prediction = nisqa.predict()
        reward = float(prediction["mos_pred"].iloc[0]) * 10
        print(
            "Length of predicted_list:",
            len(predicted_list),
            ", Reward:",
            reward,
        )
    except Exception as e:
        # Deliberate best-effort: a failed scoring yields a zero reward
        # instead of aborting training.
        print("Error:", e)
        reward = 0

    return reward

In [None]:
# import logging
# import os
# import sys

# output_log_path = f"logs/log_{ts}.log"

# logger = logging.getLogger(__name__)
# logger.setLevel(logging.DEBUG)

# handlers = logger.handlers[:]
# for handler in handlers:
#     logger.removeHandler(handler)

# file_handler = logging.FileHandler(output_log_path)
# logger.addHandler(file_handler)

In [None]:
import time
import traceback
from trl.core import respond_to_batch

start_time = time.time()
output_file_path = f"logs/{ts}/output_{ts}.txt"

# Run 100 single-sample PPO updates: generate a response for the fixed query,
# score it with get_reward, and take one PPO step.
try:
    # The query is identical in every iteration, so it is tokenized and moved
    # to the device once. This stays inside the try so that a failure here is
    # reported the same way as a failure inside the loop.
    query_txt = decode_ar_str
    query_tensor = tokenizer.encode(query_txt, return_tensors="pt").to(device)

    for iteration in tqdm(range(100)):
        response_tensor = respond_to_batch(model, query_tensor, txt_len=2000)
        # Decode once and reuse the text for the reward, the logged batch and
        # the length report.
        response_text = tokenizer.decode(response_tensor[0], skip_special_tokens=True)

        # Mimic the batch structure expected by log_stats.
        batch = {
            "query": query_tensor,
            "response": response_text,
        }
        reward_float = get_reward(
            response_text, all_src_encodec[0], all_instruction[0], iteration, True
        )
        reward_length = len(response_text)
        reward = torch.tensor([float(reward_float)], device=device)

        train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], [reward])
        ppo_trainer.log_stats(train_stats, batch, reward)

        print(f"Iteration {iteration + 1}, Reward: {reward.item()}, Length: {len(response_tensor[0])}, Reward_Length: {reward_length}, Predicted Text: {response_text}")

except Exception as e:
    # Still best-effort (the elapsed time is always printed below), but the
    # traceback is now shown so the failing call can be located.
    print("An error occurred:", e)
    traceback.print_exc()

print("used time: ", time.time() - start_time)