# Controllable generation via RL to let Elon Musk speak ill of DOGE
> How to control text generation through a sentiment classifier.



In [1]:
# %pip install pfrl@git+https://github.com/voidful/pfrl.git
# %pip install textrl==0.2.15

In [2]:
import torch
from datasets import load_from_disk
from vc.encodec_model.nar_bart_model import NARBartForConditionalGeneration
from transformers import (AutoTokenizer, BartForConditionalGeneration)
from textrl import TextRLEnv,TextRLActor
import logging
import sys
import pfrl
logging.basicConfig(level=logging.INFO, stream=sys.stdout, format='')

# define path
base_path = '/work/b0990106x/TextRL'
agent_input_dir = f'{base_path}/data-encodec'
agent_output_dir = f'{base_path}/output'
env_input_dir = agent_output_dir
env_output_dir = agent_input_dir

ar_checkpoint = "lca0503/speech-chatgpt-base-ar-v2-epoch10-wotrans"
nar_checkpoint = "lca0503/speech-chatgpt-base-nar-v2-epoch4-wotrans"

device = "cuda" if torch.cuda.is_available() else "cpu"
ar_tokenizer = AutoTokenizer.from_pretrained(ar_checkpoint)
ar_model = BartForConditionalGeneration.from_pretrained(ar_checkpoint)
nar_tokenizer = AutoTokenizer.from_pretrained(nar_checkpoint)
nar_model = NARBartForConditionalGeneration.from_pretrained(nar_checkpoint)
ar_model.to(device)

dataset = load_from_disk(agent_input_dir)
source = dataset[f"src_encodec_0"][0]
instruction = dataset["instruction"][0]
transcription = dataset["transcription"][0]
instruction_ids = ar_tokenizer(instruction)["input_ids"][1 : -1]
transcription_ids = ar_tokenizer(transcription)["input_ids"][1 : -1]
src_encodec_ids = ar_tokenizer.convert_tokens_to_ids(
    [f"v_tok_{u}" for u in dataset[f"src_encodec_0"][0]])
src_encodec_str = ar_tokenizer.convert_tokens_to_string(
    [f"v_tok_{u}" for u in dataset[f"src_encodec_0"][0]])


In [3]:
# prepare data
all_src_encodec_layers = []
all_src_encodec = []
all_instruction = []
all_instruction_ids = []

# data_len = len(dataset)
data_len = 16
layer_len = 8

for i in range(layer_len):
    all_src_encodec_layers.append(dataset[f"src_encodec_{i}"])

for i in range(data_len):
    src_encodec = []
    for j in range(layer_len):
        src_encodec.append(all_src_encodec_layers[j][i])
    all_src_encodec.append(src_encodec)

for i in range(data_len):
    all_instruction.append(dataset["instruction"][i])
    all_instruction_ids.append(ar_tokenizer(all_instruction[i])["input_ids"][1 : -1])
    

In [4]:
# run voice conversion model to get the target speech
import sys
sys.path.append('/work/b0990106x/TextRL/vc')
from vc.trainer_encodec_vc_inference import get_ar_prediction
from types import SimpleNamespace

args_predict = SimpleNamespace(
    output_path = "/work/b0990106x/TextRL/output/example.wav",
    seed = 0,
    device = "cuda"
)    

single_src_encodec = all_src_encodec[0]
single_instruction = all_instruction[0]
print("single_src_encodec: ", single_src_encodec)
print("single_instruction: ", single_instruction)

decode_ar = get_ar_prediction(args_predict, ar_model, nar_model, ar_tokenizer, nar_tokenizer, single_src_encodec, single_instruction)
decode_ar_ids = ar_tokenizer.convert_tokens_to_ids(
    [f"v_tok_{u}" for u in decode_ar])
decode_ar_str = ar_tokenizer.convert_tokens_to_string(
    [f"v_tok_{u}" for u in decode_ar])
    
print("decode_ar: ", decode_ar)
print("decode_ar_ids: ", decode_ar_ids)
print("decode_ar_str: ", decode_ar_str)

single_src_encodec:  [[408, 835, 835, 798, 585, 550, 535, 535, 737, 737, 377, 556, 601, 787, 8, 99, 411, 411, 378, 937, 378, 937, 804, 838, 890, 934, 47, 438, 438, 731, 738, 133, 709, 479, 479, 479, 151, 940, 502, 906, 407, 645, 70, 208, 537, 537, 1022, 681, 723, 747, 593, 804, 681, 879, 136, 967, 233, 431, 754, 421, 182, 182, 651, 879, 887, 819, 904, 904, 887, 309, 880, 396, 754, 775, 997, 222, 336, 548, 841, 269, 479, 479, 940, 23, 56, 738, 835, 395, 206, 779, 531, 862, 931, 306, 203, 755, 369, 6, 466, 716, 948, 82, 575, 288, 556, 903, 556, 392, 796, 751, 835, 103, 25, 408, 835, 835, 339, 339, 395, 250, 706, 317, 479, 800, 960, 141, 479, 908, 801, 327, 937, 559, 708, 372, 372, 573, 437, 437, 421, 203, 739, 830, 739, 358, 830, 248, 411, 411, 112, 321, 23, 23, 185, 971, 62, 339, 461, 488, 934, 148, 373, 561, 681, 760, 531, 612, 699, 23, 967, 457, 790, 154, 906, 465, 502, 884, 479, 246, 820, 601, 309, 716, 314, 377, 309, 309, 556, 118, 99, 358, 1018, 862, 779, 62, 835, 25, 254, 254, 677



Write audio to  /work/b0990106x/TextRL/output/example.wav
decode_ar:  [835, 835, 798, 585, 550, 535, 535, 737, 737, 377, 556, 601, 787, 8, 99, 411, 411, 378, 937, 378, 937, 804, 838, 890, 934, 47, 438, 438, 731, 738, 133, 709, 479, 479, 479, 151, 940, 502, 906, 407, 645, 70, 208, 537, 537, 1022, 681, 723, 747, 593, 804, 681, 879, 136, 967, 233, 431, 754, 421, 182, 182, 651, 879, 887, 819, 904, 904, 887, 309, 880, 396, 754, 775, 997, 222, 336, 548, 841, 269, 479, 479, 940, 23, 56, 738, 835, 395, 206, 779, 531, 862, 931, 306, 203, 755, 369, 6, 466, 716, 948, 82, 575, 288, 556, 903, 556, 392, 796, 751, 835, 103, 25, 408, 835, 835, 339, 339, 395, 250, 706, 317, 479, 800, 960, 141, 479, 908, 801, 327, 937, 559, 708, 372, 372, 573, 437, 437, 421, 203, 739, 830, 739, 358, 830, 248, 411, 411, 112, 321, 23, 23, 185, 971, 62, 339, 461, 488, 934, 148, 373, 561, 681, 760, 531, 612, 699, 23, 967, 457, 790, 154, 906, 465, 502, 884, 479, 246, 820, 601, 309, 716, 314, 377, 309, 309, 556, 118, 99, 358,

In [5]:
# demo how the tokenization works

# source speech before tokenization
print('source: ', source)
print('size of source: ', len(source))
print('src_encodec_ids: ', src_encodec_ids)
print('size of src_encodec_ids: ', len(src_encodec_ids))
print('src_encodec_str: ', src_encodec_str)
print('size of src_encodec_str: ', len(src_encodec_str))
# source speech after tokenization
tokens = ar_tokenizer.convert_ids_to_tokens(src_encodec_ids)
ids = ar_tokenizer.convert_tokens_to_ids(tokens)
print('ar_tokenizer.convert_ids_to_tokens(src_encodec_ids): ', tokens)
print('ar_tokenizer.convert_tokens_to_ids(tokens): ', ids)
print('ar_tokenizer.convert_tokens_to_ids(tokens): ', ids)
print('size of ar_tokenizer.convert_ids_to_tokens(src_encodec_ids): ', len(tokens))
# instruction before tokenization
print(instruction)
# instruction after tokenization
print(ar_tokenizer.convert_ids_to_tokens(instruction_ids))
# transcription before tokenization
print(transcription)
# transcription after tokenization
print(ar_tokenizer.convert_ids_to_tokens(transcription_ids))



source:  [408, 835, 835, 798, 585, 550, 535, 535, 737, 737, 377, 556, 601, 787, 8, 99, 411, 411, 378, 937, 378, 937, 804, 838, 890, 934, 47, 438, 438, 731, 738, 133, 709, 479, 479, 479, 151, 940, 502, 906, 407, 645, 70, 208, 537, 537, 1022, 681, 723, 747, 593, 804, 681, 879, 136, 967, 233, 431, 754, 421, 182, 182, 651, 879, 887, 819, 904, 904, 887, 309, 880, 396, 754, 775, 997, 222, 336, 548, 841, 269, 479, 479, 940, 23, 56, 738, 835, 395, 206, 779, 531, 862, 931, 306, 203, 755, 369, 6, 466, 716, 948, 82, 575, 288, 556, 903, 556, 392, 796, 751, 835, 103, 25, 408, 835, 835, 339, 339, 395, 250, 706, 317, 479, 800, 960, 141, 479, 908, 801, 327, 937, 559, 708, 372, 372, 573, 437, 437, 421, 203, 739, 830, 739, 358, 830, 248, 411, 411, 112, 321, 23, 23, 185, 971, 62, 339, 461, 488, 934, 148, 373, 561, 681, 760, 531, 612, 699, 23, 967, 457, 790, 154, 906, 465, 502, 884, 479, 246, 820, 601, 309, 716, 314, 377, 309, 309, 556, 118, 99, 358, 1018, 862, 779, 62, 835, 25, 254, 254, 677, 73, 143, 69

In [6]:
# source = encodec_code[0]
# src_encodec_ids = ar_tokenizer.convert_tokens_to_ids(
#     [f"v_tok_{u}" for u in source])
# src_encodec_str = ar_tokenizer.convert_tokens_to_string(
#     [f"v_tok_{u}" for u in source])

# print('source: ', source)
# print('src_encodec_ids: ', src_encodec_ids)
# print('src_encodec_str: ', src_encodec_str)
# print('size of source: ', len(source))
# print('size of src_encodec_ids: ', len(src_encodec_ids))
# print('size of src_encodec_str: ', len(src_encodec_str))

# decoded_text = ar_tokenizer.decode(encodec_code[0], skip_special_tokens=True)
# print(decoded_text)

In [8]:
# # inference
# import sys
# sys.path.append('/work/b0990106x/TextRL/vc')
# from vc.trainer_encodec_vc_inference import run
# from types import SimpleNamespace

# args_vc = SimpleNamespace(
#     dataset="lca0503/soxdata_small_encodec",
#     splits=["train"],
#     ground_truth_only=False,
#     cascade_ar_nar=True,
#     nar_model_only=False,
#     ground_truth_model_name="voidful/bart-base-unit",
#     ar_checkpoint="lca0503/speech-chatgpt-base-ar-v2-epoch10-wotrans",
#     nar_checkpoint="lca0503/speech-chatgpt-base-nar-v2-epoch4-wotrans",
#     ground_truth_output_path="output_wav/vc/ground_truth/train_1.wav",
#     cascade_output_path="output_wav/vc/ar_nar_cascade/train_1.wav",
#     nar_output_path="output_wav/vc/nar/train_1.wav",
#     seed=0,
#     device="cuda"
# )    

# input_dir = "/work/b0990106x/TextRL/data-encodec"
# output_path = "/work/b0990106x/TextRL/output/example.wav"

# encodec_code = run(args_vc, input_dir, output_path)
# print("Done")

# inference all data and replace the src_encodec[0] with the decode_ar
all_decode_ar = []
all_decode_ar_str = []
for i in range(data_len):
    print(f"Processing {i}...")
    decode_ar = get_ar_prediction(args_predict, ar_model, nar_model, ar_tokenizer, nar_tokenizer, all_src_encodec[i], all_instruction[i])
    all_decode_ar.append(decode_ar)
    
    decode_ar_str = ar_tokenizer.convert_tokens_to_string(
        [f"v_tok_{u}" for u in decode_ar])
    all_decode_ar_str.append(decode_ar_str)

Processing 0...




Write audio to  /work/b0990106x/TextRL/output/example.wav
Processing 1...
Write audio to  /work/b0990106x/TextRL/output/example.wav
Processing 2...
Write audio to  /work/b0990106x/TextRL/output/example.wav
Processing 3...
Write audio to  /work/b0990106x/TextRL/output/example.wav
Processing 4...
Write audio to  /work/b0990106x/TextRL/output/example.wav
Processing 5...
Write audio to  /work/b0990106x/TextRL/output/example.wav
Processing 6...
Write audio to  /work/b0990106x/TextRL/output/example.wav
Processing 7...
Write audio to  /work/b0990106x/TextRL/output/example.wav
Processing 8...
Write audio to  /work/b0990106x/TextRL/output/example.wav
Processing 9...
Write audio to  /work/b0990106x/TextRL/output/example.wav
Processing 10...
Write audio to  /work/b0990106x/TextRL/output/example.wav
Processing 11...
Write audio to  /work/b0990106x/TextRL/output/example.wav
Processing 12...
Write audio to  /work/b0990106x/TextRL/output/example.wav
Processing 13...
Write audio to  /work/b0990106x/Te

In [9]:
from importlib import reload
import textrl
reload(textrl)
from textrl import TextRLEnv,TextRLActor

In [10]:
from NISQA.nisqa.NISQA_model import nisqaModel

class MyRLEnv(TextRLEnv):
    def get_reward(self, input_item, predicted_list, finish): # predicted will be the list of predicted token
        reward = 0
        if finish or len(predicted_list) >= self.env_max_length:
            # nar_encodec_ids = input_item['input'][1:-1]
            single_src_encodec = input_item['src_encodec']
            single_instruction = input_item['instruction']
            decode_ar = get_ar_prediction(args_predict, ar_model, nar_model, ar_tokenizer, nar_tokenizer, single_src_encodec, single_instruction)

            print("- src_encodec_ids: ", src_encodec_ids)
            print("- size of decode_ar: ", len(decode_ar))
            
            args_nisqa = {
                'mode': 'predict_file', 
                'pretrained_model': f'{base_path}/NISQA/weights/nisqa.tar', 
                'deg': f'{base_path}/output/example.wav', 
                'data_dir': None, 
                'output_dir': f'{base_path}/NISQA/result',
                'csv_file': None, 
                'csv_deg': None,  
                'num_workers': 0, 
                'bs': 1,
                'ms_channel': None
            }
            args_nisqa['tr_bs_val'] = args_nisqa['bs']
            args_nisqa['tr_num_workers'] = args_nisqa['num_workers']
            
            nisqa = nisqaModel(args_nisqa)
            prediction = nisqa.predict()
            reward = float(prediction['mos_pred'].iloc[0])
            print("input_item : ",input_item['input'])
            print("predicted_list: ", predicted_list)
            print("reward: ", reward) 
                       
        return reward

**fit one example**

In [11]:
# observation_list = [{'input':src_encodec_str}]
# put all decode_ar to the observation_list
observation_list = []
for i in range(data_len):
    observation_list.append({'input':all_decode_ar_str[i], 'src_encodec':all_src_encodec[i], 'instruction':all_instruction[i]})

In [12]:
env = MyRLEnv(ar_model, ar_tokenizer, observation_input=observation_list, compare_sample=1)
actor = TextRLActor(env, ar_model, ar_tokenizer)
agent = actor.agent_ppo(update_interval=100, minibatch_size=3, epochs=10)

model name:  BartForConditionalGeneration


In [13]:
predicted_str = actor.predict(observation_list[0])



Write audio to  /work/b0990106x/TextRL/output/example.wav
- src_encodec_ids:  [50673, 51100, 51100, 51063, 50850, 50815, 50800, 50800, 51002, 51002, 50642, 50821, 50866, 51052, 50273, 50364, 50676, 50676, 50643, 51202, 50643, 51202, 51069, 51103, 51155, 51199, 50312, 50703, 50703, 50996, 51003, 50398, 50974, 50744, 50744, 50744, 50416, 51205, 50767, 51171, 50672, 50910, 50335, 50473, 50802, 50802, 51287, 50946, 50988, 51012, 50858, 51069, 50946, 51144, 50401, 51232, 50498, 50696, 51019, 50686, 50447, 50447, 50916, 51144, 51152, 51084, 51169, 51169, 51152, 50574, 51145, 50661, 51019, 51040, 51262, 50487, 50601, 50813, 51106, 50534, 50744, 50744, 51205, 50288, 50321, 51003, 51100, 50660, 50471, 51044, 50796, 51127, 51196, 50571, 50468, 51020, 50634, 50271, 50731, 50981, 51213, 50347, 50840, 50553, 50821, 51168, 50821, 50657, 51061, 51016, 51100, 50368, 50290, 50673, 51100, 51100, 50604, 50604, 50660, 50515, 50971, 50582, 50744, 51065, 51225, 50406, 50744, 51173, 51066, 50592, 51202, 5082



In [14]:
# decode the predicted token
predicted_ids = ar_tokenizer.convert_tokens_to_ids(predicted_str)
decoded_text = ar_tokenizer.decode(predicted_ids, skip_special_tokens=True)
print("predicted ids: ", predicted_ids)
print("decoded text: ", decoded_text)

predicted ids:  [3]
decoded text:  


In [15]:
pfrl.experiments.train_agent_with_evaluation(
    agent,
    env,
    steps=300,
    eval_n_steps=None,
    eval_n_episodes=16,       
    train_max_episode_len=100,  
    eval_interval=10,
    outdir='elon_musk_dogecoin', 
)

  actions = torch.tensor([b["action"] for b in dataset], device=device)


outdir:elon_musk_dogecoin step:100 episode:0 R:0
statistics:[('average_value', -0.111027166), ('average_entropy', 0.6466364), ('average_value_loss', 0.06189556084107608), ('average_policy_loss', 0.009546234952285886), ('n_updates', 334), ('explained_variance', -49.1095955034181)]
evaluation episode 0 length:100 R:0
evaluation episode 1 length:100 R:0
evaluation episode 2 length:100 R:0
evaluation episode 3 length:100 R:0
evaluation episode 4 length:100 R:0
evaluation episode 5 length:100 R:0
evaluation episode 6 length:100 R:0
evaluation episode 7 length:100 R:0
evaluation episode 8 length:100 R:0
evaluation episode 9 length:100 R:0
evaluation episode 10 length:100 R:0
evaluation episode 11 length:100 R:0
evaluation episode 12 length:100 R:0
evaluation episode 13 length:100 R:0
evaluation episode 14 length:100 R:0
Write audio to  /work/b0990106x/TextRL/output/example.wav
- src_encodec_ids:  [50673, 51100, 51100, 51063, 50850, 50815, 50800, 50800, 51002, 51002, 50642, 50821, 50866, 5105

(<textrl.actor.TextPPO at 0x7f0dfab37220>,
 [{'average_value': -0.111027166,
   'average_entropy': 0.6466364,
   'average_value_loss': 0.06189556084107608,
   'average_policy_loss': 0.009546234952285886,
   'n_updates': 334,
   'explained_variance': -49.1095955034181,
   'eval_score': 0.22349439561367035},
  {'average_value': -0.08942905,
   'average_entropy': 0.5842839,
   'average_value_loss': 0.05596172848483547,
   'average_policy_loss': 0.00359727269038558,
   'n_updates': 668,
   'explained_variance': -81.97832689315183,
   'eval_score': 0.2874697521328926},
  {'average_value': -0.11561437,
   'average_entropy': 0.54533297,
   'average_value_loss': 0.04943416909314692,
   'average_policy_loss': -0.002569490633904934,
   'n_updates': 1002,
   'explained_variance': -32.31793664470726,
   'eval_score': 0.52552829682827}])

loading the best result and predict.

In [16]:
agent.load("./elon_musk_dogecoin/best")

In [17]:
actor.predict(observation_list[0])

Write audio to  /work/b0990106x/TextRL/output/example.wav
- src_encodec_ids:  [50673, 51100, 51100, 51063, 50850, 50815, 50800, 50800, 51002, 51002, 50642, 50821, 50866, 51052, 50273, 50364, 50676, 50676, 50643, 51202, 50643, 51202, 51069, 51103, 51155, 51199, 50312, 50703, 50703, 50996, 51003, 50398, 50974, 50744, 50744, 50744, 50416, 51205, 50767, 51171, 50672, 50910, 50335, 50473, 50802, 50802, 51287, 50946, 50988, 51012, 50858, 51069, 50946, 51144, 50401, 51232, 50498, 50696, 51019, 50686, 50447, 50447, 50916, 51144, 51152, 51084, 51169, 51169, 51152, 50574, 51145, 50661, 51019, 51040, 51262, 50487, 50601, 50813, 51106, 50534, 50744, 50744, 51205, 50288, 50321, 51003, 51100, 50660, 50471, 51044, 50796, 51127, 51196, 50571, 50468, 51020, 50634, 50271, 50731, 50981, 51213, 50347, 50840, 50553, 50821, 51168, 50821, 50657, 51061, 51016, 51100, 50368, 50290, 50673, 51100, 51100, 50604, 50604, 50660, 50515, 50971, 50582, 50744, 51065, 51225, 50406, 50744, 51173, 51066, 50592, 51202, 5082

['v_tok_408v_tok_835v_tok_835v_tok_798v_tok_585v_tok_550v_tok_535v_tok_535v_tok_737v_tok_737v_tok_377v_tok_556v_tok_601v_tok_787v_tok_8v_tok_99v_tok_411v_tok_411v_tok_378v_tok_937v_tok_378v_tok_937v_tok_804v_tok_838v_tok_890v_tok_934v_tok_47v_tok_438v_tok_438v_tok_731v_tok_738v_tok_133v_tok_709v_tok_479v_tok_479v_tok_479v_tok_151v_tok_940v_tok_502v_tok_906v_tok_407v_tok_645v_tok_70v_tok_208v_tok_537v_tok_537v_tok_1022v_tok_681v_tok_723v_tok_747v_tok_593v_tok_804v_tok_681v_tok_879v_tok_136v_tok_967v_tok_233v_tok_431v_tok_754v_tok_421v_tok_182v_tok_182v_tok_651v_tok_879v_tok_887v_tok_819v_tok_904v_tok_904v_tok_887v_tok_309v_tok_880v_tok_396v_tok_754v_tok_775v_tok_997v_tok_222v_tok_336v_tok_548v_tok_841v_tok_269v_tok_479v_tok_479v_tok_940v_tok_23v_tok_56v_tok_738v_tok_835v_tok_395v_tok_206v_tok_779v_tok_531v_tok_862v_tok_931v_tok_306v_tok_203v_tok_755v_tok_369v_tok_6v_tok_466v_tok_716v_tok_948v_tok_82v_tok_575v_tok_288v_tok_556v_tok_903v_tok_556v_tok_392v_tok_796v_tok_751v_tok_835v_tok_10