# Controllable Generation via RL for Text-Guided Voice Conversion


In [1]:
import torch
from datasets import load_from_disk
from vc.encodec_model.nar_bart_model import NARBartForConditionalGeneration
from transformers import AutoTokenizer, BartForConditionalGeneration

# load the model
# Checkpoint identifiers for the autoregressive (AR) and non-autoregressive
# (NAR) models; each one is used for both its tokenizer and its weights.
ar_checkpoint = "lca0503/speech-chatgpt-base-ar-v2-epoch10-wotrans"
nar_checkpoint = "lca0503/speech-chatgpt-base-nar-v2-epoch4-wotrans"

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
ar_tokenizer = AutoTokenizer.from_pretrained(ar_checkpoint)
ar_model = BartForConditionalGeneration.from_pretrained(ar_checkpoint)
nar_tokenizer = AutoTokenizer.from_pretrained(nar_checkpoint)
nar_model = NARBartForConditionalGeneration.from_pretrained(nar_checkpoint)
# NOTE(review): only ar_model is moved to `device` here; nar_model stays where
# from_pretrained put it -- confirm the downstream inference code places it.
# Kept as the last expression so the notebook displays the module summary.
ar_model.to(device)

BartForConditionalGeneration(
  (model): BartModel(
    (shared): Embedding(59481, 768, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): Embedding(59481, 768, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 768)
      (layers): ModuleList(
        (0-5): 6 x BartEncoderLayer(
          (self_attn): BartAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (final_layer_norm): LayerNorm((768,), eps=

In [2]:
from datetime import datetime
import os

# Timestamp (month-day, hour-minute) that names this run's output directory.
now = datetime.now()
ts = now.strftime("%m%d-%H%M")
print("timestamp:", ts)

# define the path
base_path = "/work/b0990106x/TextRL"
agent_input_dir = f"{base_path}/data-encodec"
agent_output_dir = f"{base_path}/output/{ts}"
# The env directories are the agent directories with input/output swapped.
env_input_dir = agent_output_dir
env_output_dir = agent_input_dir

# exist_ok=True replaces the separate exists() check: no check-then-create
# race, and re-running the cell within the same minute is harmless.
os.makedirs(agent_output_dir, exist_ok=True)

timestamp: 0509-0100


In [3]:
# load the dataset
# Reads the dataset stored under agent_input_dir; later cells index it by
# column name ("src_encodec_<k>", "instruction").
dataset = load_from_disk(agent_input_dir)

In [4]:
# Number of "src_encodec_<k>" columns to read and number of samples to keep.
layer_len = 8
data_len = 3
# data_len = len(dataset)
print("data_len:", data_len)

# Fetch every column exactly once. The previous version evaluated
# dataset["instruction"] inside the per-sample loop, i.e. once per sample.
all_src_encodec_layers = [
    dataset[f"src_encodec_{layer}"] for layer in range(layer_len)
]
instruction_column = dataset["instruction"]

# Regroup from column-major (layer -> sample) to sample-major
# (sample -> list of layer_len entries), limited to the first data_len samples.
all_src_encodec = [
    [layer_column[sample] for layer_column in all_src_encodec_layers]
    for sample in range(data_len)
]
all_instruction = [instruction_column[sample] for sample in range(data_len)]

# Optional tokenised instructions (currently unused):
# all_instruction_ids = [
#     ar_tokenizer(text)["input_ids"][1:-1] for text in all_instruction
# ]

data_len: 3


In [5]:
# Show, for every kept sample, how long its first src_encodec layer is.
for sample_idx in range(data_len):
    first_layer = all_src_encodec[sample_idx][0]
    print(f"src_encodec_{sample_idx} len:", len(first_layer))

src_encodec_0 len: 327
src_encodec_1 len: 336
src_encodec_2 len: 131


### Debugging Section

In [6]:
# redefine the path (one can run the code from here when the model is already loaded)
now = datetime.now()
ts = now.strftime("%m%d-%H%M")
print("timestamp:", ts)

agent_output_dir = f"{base_path}/output/{ts}"
env_input_dir = agent_output_dir

# The fresh timestamp names a directory that may not exist yet (only the
# directory of the earlier timestamp was created). Create it here so files
# written under output/<ts>/ later in the notebook have somewhere to go.
os.makedirs(agent_output_dir, exist_ok=True)


timestamp: 0509-0100


In [7]:
# Re-import textrl so that edits made to its source since the first import
# take effect in this session without restarting the kernel.
from importlib import reload
import textrl
reload(textrl)

# Imported after the reload so these names come from the reloaded module.
from textrl import TextRLEnv, TextRLActor
from NISQA.nisqa.NISQA_model import nisqaModel

# Extend the module search path with the vc directory before importing from it.
import sys
sys.path.append("/work/b0990106x/TextRL/vc")
from vc.trainer_encodec_vc_inference import get_ar_prediction_v2

class MyRLEnv(TextRLEnv):
    """TextRLEnv subclass with a placeholder, length-based terminal reward."""

    def get_reward(self, _, predicted_list, finish):
        """Return the reward for the current step.

        Args:
            _: Ignored by this implementation.
            predicted_list: Indexable collection of predictions; only its
                first entry is inspected.
            finish: Truthy once generation has been flagged as finished.

        Returns:
            0 while the episode is still running (not finished and the first
            prediction is shorter than ``self.env_max_length``); otherwise
            the length of ``predicted_list[0]``.
        """
        still_running = (
            not finish and len(predicted_list[0]) < self.env_max_length
        )
        if still_running:
            return 0

        # Debug output emitted once per finished episode.
        print("Length of predicted_list:", len(predicted_list))
        print("predicted_list:", predicted_list)

        # ------------------------------------------------------------------
        # Disabled NISQA-based reward, kept for reference. It decodes the
        # predicted tokens to audio and scores the file with NISQA.
        #
        # try:
        #     predicted_tokens = predicted_list[0][1:-1]
        #     predicted_ids = self.tokenizer.convert_tokens_to_ids(
        #         [f"{u}" for u in predicted_tokens]
        #     )
        #     # print("predicted_ids:", predicted_ids)
        #
        #     decode_ar = get_ar_prediction_v2(
        #         self.args_predict,
        #         predicted_ids,
        #         self.nar_model,
        #         self.tokenizer,
        #         self.nar_tokenizer,
        #         self.single_src_encodec,
        #         self.single_instruction,
        #         self.episode_counter,
        #     )
        #     # print("decode_ar:", decode_ar)
        #
        #     # use nisqa to get the reward
        #     args_nisqa = {
        #         "mode": "predict_file",
        #         "pretrained_model": f"{base_path}/NISQA/weights/nisqa.tar",
        #         "deg": f"{base_path}/output/{ts}/example.wav",
        #         "data_dir": None,
        #         "output_dir": f"{base_path}/NISQA/result/",
        #         "csv_file": None,
        #         "csv_deg": None,
        #         "num_workers": 0,
        #         "bs": 1,
        #         "ms_channel": None,
        #     }
        #     args_nisqa["tr_bs_val"] = args_nisqa["bs"]
        #     args_nisqa["tr_num_workers"] = args_nisqa["num_workers"]
        #
        #     nisqa = nisqaModel(args_nisqa)
        #     prediction = nisqa.predict()
        #     reward = float(prediction["mos_pred"].iloc[0])*10
        #     # reward = float(prediction["mos_pred"].iloc[0])-3.0
        #     print(
        #         "Length of predicted_list:",
        #         len(predicted_list[0]),
        #         ", Reward:",
        #         reward,
        #     )
        #
        # except Exception as e:
        #     print("Error:", e)
        #     reward = 0
        # ------------------------------------------------------------------

        return len(predicted_list[0])

In [8]:
# One observation per kept sample: an empty text input together with that
# sample's src_encodec entry and instruction.
observation_list = [
    {
        "input": "",
        "src_encodec": all_src_encodec[sample_idx],
        "instruction": all_instruction[sample_idx],
    }
    for sample_idx in range(data_len)
]

# Drop the first two samples, removing them from both lists in lock-step.
for _ in range(2):
    observation_list.pop(0)
    all_instruction.pop(0)

print("observation_list:", observation_list)
print("all_instruction:", all_instruction)

# for i in range(data_len):
#     observation_list.append({'input': "", 'src_encodec': all_src_encodec[i], 'instruction': all_instruction[i]})

observation_list: [{'input': '', 'src_encodec': [[835, 339, 999, 629, 604, 462, 314, 600, 846, 562, 846, 358, 984, 393, 182, 453, 584, 535, 407, 1021, 701, 843, 945, 495, 563, 495, 495, 727, 317, 604, 475, 835, 835, 835, 339, 475, 339, 123, 254, 103, 561, 858, 646, 755, 375, 548, 435, 233, 323, 395, 819, 475, 339, 835, 779, 257, 339, 341, 170, 38, 38, 103, 408, 62, 141, 731, 73, 651, 143, 875, 321, 310, 310, 972, 679, 582, 808, 813, 808, 291, 722, 982, 627, 192, 764, 531, 291, 466, 567, 601, 771, 112, 688, 348, 793, 793, 11, 192, 23, 983, 1022, 23, 73, 73, 276, 537, 103, 53, 148, 148, 148, 463, 176, 148, 463, 463, 463, 463, 463, 463, 463, 433, 25, 472, 257, 228, 395, 133, 395, 475, 126], [646, 841, 168, 1023, 277, 820, 278, 215, 58, 592, 607, 607, 349, 346, 504, 632, 482, 14, 968, 588, 529, 904, 662, 662, 602, 1013, 662, 386, 617, 870, 648, 1023, 277, 277, 913, 200, 1007, 503, 807, 144, 132, 558, 984, 164, 610, 66, 830, 925, 744, 129, 87, 648, 391, 646, 424, 700, 646, 713, 702, 443, 4,

In [9]:
# print("observation_list:", observation_list)

In [10]:
from types import SimpleNamespace

# Settings handed to the environment. `device` is the value computed when the
# models were loaded, so this also works on a machine without CUDA (it was
# previously hard-coded to "cuda").
args_predict = SimpleNamespace(
    output_path=f"{base_path}/output/{ts}/example.wav", seed=0, device=device
)

env = MyRLEnv(
    ar_model,
    ar_tokenizer,
    nar_model,
    nar_tokenizer,
    args_predict,
    observation_input=observation_list,
    compare_sample=1,
)
actor = TextRLActor(env=env, model=ar_model, tokenizer=ar_tokenizer)

# PPO hyper-parameters. These variables are the single source of truth: they
# are passed to agent_ppo below and are also what the training cell prints in
# its run header. Previously they held 1000 / 512 / 10 / 0.001 while the agent
# was actually built with 2048 / 32 / 1 / 0.01, so the logged header was wrong.
update_interval = 2048
minibatch_size = 32
epochs = 1
lr = 0.01

# Configurations tried earlier, kept for reference:
# agent = actor.agent_ppo(update_interval=1800, minibatch_size=256, epochs=10, lr=3e-8)
# agent = actor.agent_ppo(update_interval=1200, minibatch_size=128, epochs=10)
# agent = actor.agent_ppo(update_interval=1000, minibatch_size=128, epochs=10, lr=3e-8)
# agent = actor.agent_ppo(update_interval=1000, minibatch_size=512, epochs=10, lr=0.001)
# agent = actor.agent_ppo(update_interval=1000, minibatch_size=512, epochs=10, lr=0.01, entropy_coef=0.1)
# agent = actor.agent_ppo(update_interval=1000, minibatch_size=2048, epochs=1, lr=0.01, entropy_coef=0.1)
# agent = actor.agent_ppo(update_interval=100, minibatch_size=1024, epochs=1, lr=0.05, entropy_coef=0.5)
# agent = actor.agent_ppo(update_interval=100, minibatch_size=1024, epochs=1, lr=0.01, entropy_coef=0.5)
# agent = actor.agent_ppo(update_interval=2048, minibatch_size=32, epochs=1, lr=0.01, entropy_coef=0.1)
agent = actor.agent_ppo(
    update_interval=update_interval,
    minibatch_size=minibatch_size,
    epochs=epochs,
    lr=lr,
)

model name:  BartForConditionalGeneration
----------------------------- reset -----------------------------


In [11]:
import logging
import os
import sys

# Per-run file names, keyed by the run timestamp.
output_log_path = f"log/log_{ts}.log"
output_file_path = f"log/output_{ts}.txt"

# exist_ok=True makes the separate existence check unnecessary.
os.makedirs("log", exist_ok=True)

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Re-running this cell must not stack handlers. Detach the old ones and close
# them too, otherwise each re-run leaks an open log file.
for handler in logger.handlers[:]:  # iterate over a copy while mutating
    logger.removeHandler(handler)
    handler.close()

file_handler = logging.FileHandler(output_log_path)
logger.addHandler(file_handler)

In [12]:
import sys
import time
import pfrl

start_time = time.time()
pfrl_outdir = f"ckpt/train_{ts}"

# Everything printed during training goes to output_file_path instead of the
# notebook. The try/finally guarantees sys.stdout is restored even when
# training raises or is interrupted; otherwise stdout would be left pointing
# at a file that the `with` block has already closed.
with open(output_file_path, "w") as f:
    original_stdout = sys.stdout
    sys.stdout = f
    try:
        # Run header recording the hyper-parameter variables.
        print(f"update_interval = {update_interval}, minibatch_size = {minibatch_size}, epochs = {epochs}, lr = {lr}")
        pfrl.experiments.train_agent_with_evaluation(
            agent,
            env,
            steps=1000000,
            eval_n_steps=None,
            eval_n_episodes=2,
            train_max_episode_len=10000,
            eval_interval=1000,
            outdir=pfrl_outdir,
            logger=logger,
            use_tensorboard=True,
            checkpoint_freq=5000,
        )
    finally:
        sys.stdout = original_stdout

print("Output has been written to", output_file_path)
print("used time: ", time.time() - start_time)


  actions = torch.tensor([b["action"] for b in dataset], device=device)


In [None]:
# Restore the agent from the "best" directory under this run's output
# directory, then run the actor on the first remaining observation.
best_checkpoint_dir = f"{pfrl_outdir}/best"
agent.load(best_checkpoint_dir)
actor.predict(observation_list[0])

In [None]:
# Show the timestamp that names this run's output, log and checkpoint paths.
print(ts)

In [None]:
# import read_pickle
# ts = "0430-1250"
# # len of the pickle file
# dir_path = f"{base_path}/replay_buffer/{ts}"
# length = len(os.listdir(dir_path))
# for i in range(length):
#     pickle_file = f"replay_buffer_update_{i}.pkl"
#     pickle_data = read_pickle.load_pickle(file_path = f'{base_path}/replay_buffer/{ts}/{pickle_file}')
#     length_of_pickle_data = len(pickle_data[0])
#     length_of_pickle_data1 = len(pickle_data[1])
#     if len(pickle_data)>= 3:
#         length_of_pickle_data2 = len(pickle_data[2])
#         print(f"length of {pickle_file}: {length_of_pickle_data} and {length_of_pickle_data1} and {length_of_pickle_data2}")
#     else:
#         print(f"length of {pickle_file}: {length_of_pickle_data} and {length_of_pickle_data1}") 
 

In [None]:
# import tensorflow as tf
# print(tf.__version__)