# Controllable generation via RL to let Elon Musk speak ill of DOGE
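> How to control generation through a learned reward. The title and outline come from the TextRL sentiment-classifier example; in this notebook the reward is a NISQA speech-quality (MOS) prediction on the generated audio.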



In [1]:
# %pip install pfrl@git+https://github.com/voidful/pfrl.git
# %pip install textrl==0.2.15

In [2]:
import torch
from datasets import load_from_disk
from vc.encodec_model.nar_bart_model import NARBartForConditionalGeneration
from transformers import (AutoTokenizer, BartForConditionalGeneration)
import logging
import sys
import pfrl
# Route INFO-level (and above) log records to stdout.
logging.basicConfig(format='', stream=sys.stdout, level=logging.INFO)

# ---- paths ----
# The agent reads from data-encodec and writes to output; the environment
# uses the same two directories with the roles swapped.
base_path = '/work/b0990106x/TextRL'
agent_input_dir = f'{base_path}/data-encodec'
agent_output_dir = f'{base_path}/output'
env_input_dir, env_output_dir = agent_output_dir, agent_input_dir

# ---- pretrained checkpoints (autoregressive / non-autoregressive) ----
ar_checkpoint = "lca0503/speech-chatgpt-base-ar-v2-epoch10-wotrans"
nar_checkpoint = "lca0503/speech-chatgpt-base-nar-v2-epoch4-wotrans"

# ---- device selection: GPU when one is visible, CPU otherwise ----
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# ---- tokenizers and models ----
ar_tokenizer = AutoTokenizer.from_pretrained(ar_checkpoint)
nar_tokenizer = AutoTokenizer.from_pretrained(nar_checkpoint)
ar_model = BartForConditionalGeneration.from_pretrained(ar_checkpoint)
nar_model = NARBartForConditionalGeneration.from_pretrained(nar_checkpoint)
# NOTE(review): only the AR model is moved to `device` here; nar_model stays
# where from_pretrained put it. Confirm the inference helpers move it themselves.
ar_model.to(device)

# ---- dataset previously saved to disk ----
dataset = load_from_disk(agent_input_dir)
# source = dataset[f"src_encodec_0"][0]
# instruction = dataset["instruction"][0]
# transcription = dataset["transcription"][0]
# instruction_ids = ar_tokenizer(instruction)["input_ids"][1 : -1]
# transcription_ids = ar_tokenizer(transcription)["input_ids"][1 : -1]
# src_encodec_ids = ar_tokenizer.convert_tokens_to_ids(
#     [f"v_tok_{u}" for u in dataset[f"src_encodec_0"][0]])
# src_encodec_str = ar_tokenizer.convert_tokens_to_string(
#     [f"v_tok_{u}" for u in dataset[f"src_encodec_0"][0]])


In [3]:
# prepare data
# data_len = len(dataset)
data_len = 5 # for testing
layer_len = 8

# Read every needed column from the dataset exactly once. Indexing
# dataset["..."] inside the per-sample loop would re-fetch the whole column
# on every iteration.
all_src_encodec_layers = [
    dataset[f"src_encodec_{layer}"] for layer in range(layer_len)
]
instruction_column = dataset["instruction"]

# Regroup the per-layer columns into per-sample lists:
# all_src_encodec[sample][layer] is the code sequence of that layer.
all_src_encodec = [
    [all_src_encodec_layers[layer][sample] for layer in range(layer_len)]
    for sample in range(data_len)
]

all_instruction = [instruction_column[sample] for sample in range(data_len)]
# Drop the first and last token id of each tokenized instruction
# (presumably the BOS/EOS markers added by the tokenizer -- TODO confirm).
all_instruction_ids = [
    ar_tokenizer(instruction)["input_ids"][1 : -1]
    for instruction in all_instruction
]
    

In [4]:
# check data validity
# all data in all_src_encodec must be the numbers instead of strings
# Each assertion names the offending position so a failure can be traced
# back to a specific sample / layer / index without re-running by hand.
for i in range(data_len):
    for j in range(layer_len):
        layer_codes = all_src_encodec[i][j]
        assert isinstance(layer_codes, list), (
            f"all_src_encodec[{i}][{j}] should be a list, "
            f"got {type(layer_codes).__name__}"
        )
        for k, code in enumerate(layer_codes):
            assert isinstance(code, int), (
                f"all_src_encodec[{i}][{j}][{k}] should be an int, got {code!r}"
            )

In [5]:
# run voice conversion model to get the target speech
import sys
sys.path.append('/work/b0990106x/TextRL/vc')
from vc.trainer_encodec_vc_inference import get_ar_prediction, get_ar_prediction_without_writing_files
from types import SimpleNamespace

# Options handed to the voice-conversion inference helpers.
# The output path and device are derived from the configuration cell
# (agent_output_dir, device) instead of being hard-coded a second time, so the
# notebook keeps working on CPU-only machines and after base_path changes.
args_predict = SimpleNamespace(
    output_path = f"{agent_output_dir}/example.wav",
    seed = 0,
    device = device
)

# single_src_encodec = all_src_encodec[0]
# single_instruction = all_instruction[0]
# print("single_src_encodec: ", single_src_encodec)
# print("single_instruction: ", single_instruction)

# decode_ar = get_ar_prediction(args_predict, ar_model, nar_model, ar_tokenizer, nar_tokenizer, single_src_encodec, single_instruction)
# decode_ar_ids = ar_tokenizer.convert_tokens_to_ids(
#     [f"v_tok_{u}" for u in decode_ar])
# decode_ar_str = ar_tokenizer.convert_tokens_to_string(
#     [f"v_tok_{u}" for u in decode_ar])
    
# print("decode_ar: ", decode_ar)
# print("decode_ar_ids: ", decode_ar_ids)
# print("decode_ar_str: ", decode_ar_str)

In [6]:
# # demo how the tokenization works

# # source speech before tokenization
# print('source: ', source)
# print('size of source: ', len(source))
# print('src_encodec_ids: ', src_encodec_ids)
# print('size of src_encodec_ids: ', len(src_encodec_ids))
# print('src_encodec_str: ', src_encodec_str)
# print('size of src_encodec_str: ', len(src_encodec_str))
# # source speech after tokenization
# tokens = ar_tokenizer.convert_ids_to_tokens(src_encodec_ids)
# ids = ar_tokenizer.convert_tokens_to_ids(tokens)
# print('ar_tokenizer.convert_ids_to_tokens(src_encodec_ids): ', tokens)
# print('ar_tokenizer.convert_tokens_to_ids(tokens): ', ids)
# print('ar_tokenizer.convert_tokens_to_ids(tokens): ', ids)
# print('size of ar_tokenizer.convert_ids_to_tokens(src_encodec_ids): ', len(tokens))
# # instruction before tokenization
# print(instruction)
# # instruction after tokenization
# print(ar_tokenizer.convert_ids_to_tokens(instruction_ids))
# # transcription before tokenization
# print(transcription)
# # transcription after tokenization
# print(ar_tokenizer.convert_ids_to_tokens(transcription_ids))



In [7]:
# # inference all data and replace the src_encodec[0] with the decode_ar
# all_decode_ar = []
# all_decode_ar_str = []
# for i in range(data_len):
#     print(f"Processing {i}...")
#     decode_ar = get_ar_prediction_without_writing_files(args_predict, ar_model, nar_model, ar_tokenizer, nar_tokenizer, all_src_encodec[i], all_instruction[i])
#     all_decode_ar.append(decode_ar)
    
#     decode_ar_str = ar_tokenizer.convert_tokens_to_string(
#         [f"v_tok_{u}" for u in decode_ar])
#     all_decode_ar_str.append(decode_ar_str)

In [8]:
# Re-import the local textrl package and the VC inference module so that edits
# made on disk are picked up without restarting the kernel.
from importlib import reload
import textrl
# NOTE(review): importlib.reload re-executes only the module object it is
# given; confirm that submodules of textrl do not also need reloading.
reload(textrl)
# Import the names after the reload so they bind to the reloaded module.
from textrl import TextRLEnv,TextRLActor
# Relies on this module already being in sys.modules, i.e. on the cell that
# first imports vc.trainer_encodec_vc_inference having been run.
reload(sys.modules['vc.trainer_encodec_vc_inference'])

# Re-bind the helper functions to the freshly reloaded module.
from vc.trainer_encodec_vc_inference import get_ar_prediction, get_ar_prediction_without_writing_files


In [9]:
from NISQA.nisqa.NISQA_model import nisqaModel

class MyRLEnv(TextRLEnv):
    """TextRLEnv variant whose end-of-episode reward is a NISQA prediction."""

    def get_reward(self, input_item, predicted_list, finish):
        """Return the reward for the current step.

        Args:
            input_item: observation dict; only its 'input' entry is read here
                (for logging).
            predicted_list: the list of predicted tokens so far.
            finish: truthy once generation has finished.

        Returns:
            0 while the episode is still running. Once `finish` is truthy or
            `predicted_list` has reached `self.env_max_length`, the `mos_pred`
            value of the first row of the NISQA prediction, as a float.
        """
        episode_over = finish or len(predicted_list) >= self.env_max_length
        if not episode_over:
            return 0

        # Earlier variant (disabled): re-run the VC model from
        # input_item['src_encodec'] / input_item['instruction'] via
        # get_ar_prediction before scoring.

        batch_size = 1
        worker_count = 0
        # NISQA is pointed at a fixed wav file.
        # NOTE(review): presumably that file is written by the environment /
        # inference step before this method runs -- confirm.
        args_nisqa = {
            'mode': 'predict_file',
            'pretrained_model': f'{base_path}/NISQA/weights/nisqa.tar',
            'deg': f'{base_path}/output/example.wav',
            'data_dir': None,
            'output_dir': f'{base_path}/NISQA/result',
            'csv_file': None,
            'csv_deg': None,
            'num_workers': worker_count,
            'bs': batch_size,
            'ms_channel': None,
            # Mirrors of 'bs' / 'num_workers' under the tr_* keys.
            'tr_bs_val': batch_size,
            'tr_num_workers': worker_count,
        }

        # A new model object is built for every scored episode.
        mos_table = nisqaModel(args_nisqa).predict()
        reward = float(mos_table['mos_pred'].iloc[0])

        print("input_item : ",input_item['input'])
        print("predicted_list: ", predicted_list)
        print("reward: ", reward)

        return reward

**Fit one example**

In [10]:
# One observation dict per prepared sample. The 'input' text is left empty;
# the source codes and the instruction are carried alongside it
# (presumably consumed by the customised environment -- confirm).
# Earlier variants seeded 'input' with src_encodec_str or all_decode_ar_str[i].
observation_list = [
    {
        'input': "",
        'src_encodec': all_src_encodec[sample],
        'instruction': all_instruction[sample],
    }
    for sample in range(data_len)
]


In [11]:
# Environment: both models and tokenizers plus the observations to train on.
env = MyRLEnv(
    ar_model,
    ar_tokenizer,
    nar_model,
    nar_tokenizer,
    observation_input=observation_list,
    compare_sample=1,
)

# Actor wraps the AR model, which is the policy being tuned.
actor = TextRLActor(env, ar_model, ar_tokenizer)

# PPO agent; the three values are hyper-parameters chosen for this small test
# run (presumably forwarded to the underlying pfrl PPO agent -- confirm).
agent = actor.agent_ppo(
    update_interval=100,
    minibatch_size=3,
    epochs=10,
)

model name:  BartForConditionalGeneration
single_src_encodec:  [[835, 835, 835, 126, 276, 677, 409, 460, 522, 682, 924, 477, 816, 704, 976, 106, 25, 1021, 552, 301, 301, 448, 186, 934, 574, 47, 438, 133, 430, 856, 877, 800, 830, 997, 997, 358, 850, 846, 233, 846, 206, 602, 62, 62, 834, 339, 63, 602, 887, 798, 339, 834, 976, 835, 335, 276, 462, 400, 776, 642, 368, 858, 424, 598, 860, 724, 999, 887, 835, 144, 151, 151, 25, 875, 1021, 901, 813, 424, 598, 321, 679, 955, 901, 838, 552, 695, 695, 501, 432, 339, 1019, 910, 839, 261, 481, 870, 695, 695, 51, 233, 904, 865, 629, 62, 835, 339, 372, 492, 74, 982, 922, 835, 25, 25, 408, 202, 38, 747, 754, 926, 984, 387, 434, 880, 358, 358, 941, 870, 387, 739, 870, 387, 99, 387, 348, 688, 990, 886, 613, 385, 185, 23, 62, 835, 141, 157, 62, 835, 835, 339, 662, 564, 372, 372, 984, 437, 437, 754, 358, 248, 754, 838, 813, 465, 407, 588, 431, 956, 154, 782, 862, 136, 255, 858, 931, 588, 317, 317, 154, 1021, 838, 813, 385, 804, 925, 722, 385, 679, 843, 97



Write audio to  /work/b0990106x/TextRL/output/example.wav
decode_ar:  [835, 835, 126, 276, 276, 151, 460, 528, 682, 924, 477, 460, 704, 976, 106, 25, 1021, 552, 301, 301, 448, 186, 934, 574, 438, 373, 133, 430, 856, 877, 926, 830, 358, 739, 358, 850, 233, 233, 233, 65, 151, 62, 62, 834, 133, 499, 602, 887, 819, 339, 834, 976, 835, 335, 276, 462, 916, 776, 642, 143, 858, 424, 598, 860, 724, 999, 904, 835, 206, 151, 151, 1017, 875, 1021, 901, 953, 424, 598, 321, 143, 955, 901, 645, 1000, 695, 695, 501, 432, 835, 1019, 910, 839, 261, 481, 870, 695, 695, 233, 233, 904, 922, 629, 62, 835, 835, 372, 492, 74, 982, 922, 835, 339, 25, 408, 202, 38, 747, 754, 926, 984, 248, 742, 880, 358, 358, 870, 870, 387, 739, 870, 387, 99, 387, 348, 688, 683, 886, 942, 385, 185, 23, 62, 835, 951, 157, 62, 835, 408, 835, 662, 564, 372, 428, 984, 437, 437, 754, 358, 248, 754, 501, 813, 465, 407, 944, 588, 956, 154, 612, 862, 136, 255, 890, 931, 944, 317, 317, 465, 1021, 502, 813, 385, 1021, 925, 722, 722, 679,

In [12]:
# Ask the actor for a prediction on the first observation.
# NOTE(review): the saved output of this cell ends with
#   ValueError: invalid literal for int() with base 10: '835 air'
# so the cells below were never executed. The failing int() call is not in this
# notebook; it looks like a decoded token string mixing a code number with
# ordinary text is being parsed as an integer -- confirm in textrl / vc.
predicted_str = actor.predict(observation_list[0])

single_src_encodec:  [[408, 835, 835, 798, 585, 550, 535, 535, 737, 737, 377, 556, 601, 787, 8, 99, 411, 411, 378, 937, 378, 937, 804, 838, 890, 934, 47, 438, 438, 731, 738, 133, 709, 479, 479, 479, 151, 940, 502, 906, 407, 645, 70, 208, 537, 537, 1022, 681, 723, 747, 593, 804, 681, 879, 136, 967, 233, 431, 754, 421, 182, 182, 651, 879, 887, 819, 904, 904, 887, 309, 880, 396, 754, 775, 997, 222, 336, 548, 841, 269, 479, 479, 940, 23, 56, 738, 835, 395, 206, 779, 531, 862, 931, 306, 203, 755, 369, 6, 466, 716, 948, 82, 575, 288, 556, 903, 556, 392, 796, 751, 835, 103, 25, 408, 835, 835, 339, 339, 395, 250, 706, 317, 479, 800, 960, 141, 479, 908, 801, 327, 937, 559, 708, 372, 372, 573, 437, 437, 421, 203, 739, 830, 739, 358, 830, 248, 411, 411, 112, 321, 23, 23, 185, 971, 62, 339, 461, 488, 934, 148, 373, 561, 681, 760, 531, 612, 699, 23, 967, 457, 790, 154, 906, 465, 502, 884, 479, 246, 820, 601, 309, 716, 314, 377, 309, 309, 556, 118, 99, 358, 1018, 862, 779, 62, 835, 25, 254, 254, 677

layer  0 :  v_tok_835v_tok_835v_tok_798v_tok_585v_tok_550v_tok_535v_tok_535v_tok_737v_tok_737v_tok_377v_tok_556v_tok_601v_tok_787v_tok_8v_tok_99v_tok_411v_tok_411v_tok_378v_tok_937v_tok_378v_tok_937v_tok_804v_tok_838v_tok_890v_tok_934v_tok_47v_tok_438v_tok_438v_tok_731v_tok_738v_tok_133v_tok_709v_tok_479v_tok_479v_tok_479v_tok_151v_tok_940v_tok_502v_tok_906v_tok_407v_tok_645v_tok_70v_tok_208v_tok_537v_tok_537v_tok_1022v_tok_681v_tok_723v_tok_747v_tok_593v_tok_804v_tok_681v_tok_879v_tok_136v_tok_967v_tok_233v_tok_431v_tok_754v_tok_421v_tok_182v_tok_182v_tok_651v_tok_879v_tok_887v_tok_819v_tok_904v_tok_904v_tok_887v_tok_309v_tok_880v_tok_396v_tok_754v_tok_775v_tok_997v_tok_222v_tok_336v_tok_548v_tok_841v_tok_269v_tok_479v_tok_479v_tok_940v_tok_23v_tok_56v_tok_738v_tok_835v_tok_395v_tok_206v_tok_779v_tok_531v_tok_862v_tok_931v_tok_306v_tok_203v_tok_755v_tok_369v_tok_6v_tok_466v_tok_716v_tok_948v_tok_82v_tok_575v_tok_288v_tok_556v_tok_903v_tok_556v_tok_392v_tok_796v_tok_751v_tok_835v_tok_1



Write audio to  /work/b0990106x/TextRL/output/example.wav
decode_ar:  [835, 835, 798, 585, 550, 535, 535, 737, 737, 377, 556, 601, 787, 8, 99, 411, 411, 378, 937, 378, 937, 804, 838, 890, 934, 47, 438, 438, 731, 738, 133, 709, 479, 479, 479, 151, 940, 502, 906, 407, 645, 70, 208, 537, 537, 1022, 681, 723, 747, 593, 804, 681, 879, 136, 967, 233, 431, 754, 421, 182, 182, 651, 879, 887, 819, 904, 904, 887, 309, 880, 396, 754, 775, 997, 222, 336, 548, 841, 269, 479, 479, 940, 23, 56, 738, 835, 395, 206, 779, 531, 862, 931, 306, 203, 755, 369, 6, 466, 716, 948, 82, 575, 288, 556, 903, 556, 392, 796, 751, 835, 103, 25, 408, 835, 835, 339, 339, 395, 250, 706, 317, 479, 800, 960, 141, 479, 908, 801, 327, 937, 559, 708, 372, 372, 573, 437, 437, 421, 203, 739, 830, 739, 358, 830, 248, 411, 411, 112, 321, 23, 23, 185, 971, 62, 339, 461, 488, 934, 148, 373, 561, 681, 760, 531, 612, 699, 23, 967, 457, 790, 154, 906, 465, 502, 884, 479, 246, 820, 601, 309, 716, 314, 377, 309, 309, 556, 118, 99, 358,

ValueError: invalid literal for int() with base 10: '835 air'

In [None]:
# decode the predicted token
# Maps the prediction back to token ids, then decodes those ids to text with
# special tokens removed, and prints both.
# NOTE(review): convert_tokens_to_ids expects token strings; verify that
# actor.predict returns tokens rather than already-decoded text, otherwise the
# ids produced here will not correspond to the prediction.
predicted_ids = ar_tokenizer.convert_tokens_to_ids(predicted_str)
decoded_text = ar_tokenizer.decode(predicted_ids, skip_special_tokens=True)
print("predicted ids: ", predicted_ids)
print("decoded text: ", decoded_text)

In [None]:
# Train the PPO agent on `env` with periodic evaluation.
# Settings are small, test-sized values: 300 training steps in total, an
# evaluation of 2 episodes every 10 steps, and training episodes capped at
# 100 steps. Results are written under `outdir`.
# NOTE(review): the outdir name is inherited from the original TextRL example
# and no longer describes this experiment; the load cell below depends on it,
# so rename both together if it is changed.
pfrl.experiments.train_agent_with_evaluation(
    agent,
    env,
    steps=300,
    eval_n_steps=None,
    eval_n_episodes=2,       
    train_max_episode_len=100,  
    eval_interval=10,
    outdir='elon_musk_dogecoin', 
)

Load the best result and run prediction.

In [None]:
# Restore the agent from the "best" directory under the training outdir
# (presumably the best-scoring snapshot saved during evaluation -- confirm).
agent.load("./elon_musk_dogecoin/best")

In [None]:
# Predict on the first observation with the restored agent; as the last
# expression of the cell, its return value is shown as the cell output.
actor.predict(observation_list[0])