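# Controllable speech-token generation via RL with a NISQA quality reward
> How to steer an autoregressive BART speech-token generator with PPO, using the MOS predicted by NISQA as the reward. (Adapted from the TextRL example "Controllable generation via RL to let Elon Musk speak ill of DOGE", which controls text generation through a sentiment classifier.)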



In [1]:
import logging
import os
import sys

import pfrl
import torch
from datasets import load_from_disk
from transformers import (AutoTokenizer, BartForConditionalGeneration)

from vc.encodec_model.nar_bart_model import NARBartForConditionalGeneration
logging.basicConfig(level=logging.INFO, stream=sys.stdout, format='')

# define path
# The project root defaults to the original cluster location but can be
# overridden with the TEXTRL_BASE_PATH environment variable, so the notebook
# can run on another machine without editing this cell.
base_path = os.environ.get('TEXTRL_BASE_PATH', '/work/b0990106x/TextRL')
agent_input_dir = f'{base_path}/data-encodec'
agent_output_dir = f'{base_path}/output'
# The environment's input/output directories are the agent's, swapped.
env_input_dir = agent_output_dir
env_output_dir = agent_input_dir

# Checkpoint identifiers handed to `from_pretrained` below.
ar_checkpoint = "lca0503/speech-chatgpt-base-ar-v2-epoch10-wotrans"
nar_checkpoint = "lca0503/speech-chatgpt-base-nar-v2-epoch4-wotrans"

# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
ar_tokenizer = AutoTokenizer.from_pretrained(ar_checkpoint)
ar_model = BartForConditionalGeneration.from_pretrained(ar_checkpoint)
nar_tokenizer = AutoTokenizer.from_pretrained(nar_checkpoint)
nar_model = NARBartForConditionalGeneration.from_pretrained(nar_checkpoint)
ar_model.to(device)
# NOTE(review): only ar_model is moved to `device`; nar_model is left wherever
# from_pretrained placed it. Confirm the environment handles its placement.

# Load the pre-processed dataset saved on disk under the agent input directory.
dataset = load_from_disk(agent_input_dir)

In [2]:
# Report how many examples the on-disk dataset contains.
print(len(dataset))

# Only a small prefix is used for quick experiments. Clamp to the dataset size
# so a short dataset cannot trigger an IndexError further down.
data_len = min(22, len(dataset))  # for testing
layer_len = 8

# Read each `src_encodec_<j>` column exactly once.
all_src_encodec_layers = [dataset[f"src_encodec_{j}"] for j in range(layer_len)]

# Regroup per example: all_src_encodec[i][j] is layer j of example i.
all_src_encodec = [
    [layer[i] for layer in all_src_encodec_layers]
    for i in range(data_len)
]

# Fetch the instruction column once instead of re-indexing the dataset on
# every loop iteration.
instruction_column = dataset["instruction"]
all_instruction = [instruction_column[i] for i in range(data_len)]

# Tokenize each instruction and drop the first and last ids
# (presumably the BOS/EOS specials added by the tokenizer — TODO confirm).
all_instruction_ids = [
    ar_tokenizer(instruction)["input_ids"][1:-1]
    for instruction in all_instruction
]

9957


In [3]:
# Disabled alternative: extend sys.path so the `vc` package can be found.
# import sys
# sys.path.append('/work/b0990106x/TextRL/vc')

# Re-import `textrl` so that edits made to the package while the kernel is
# running are picked up. The reload must happen BEFORE the `from ... import`
# below so the imported names come from the freshly reloaded module.
from importlib import reload
import textrl
reload(textrl)

from textrl import TextRLEnv,TextRLActor
# Disabled: reload of the vc inference trainer module.
# reload(sys.modules['vc.trainer_encodec_vc_inference'])

In [4]:
from NISQA.nisqa.NISQA_model import nisqaModel

class MyRLEnv(TextRLEnv):
    """TextRL environment whose reward is the MOS predicted by NISQA."""

    def get_reward(self, input_item, predicted_list, finish):
        """Score an episode with NISQA once it has ended.

        Args:
            input_item: observation dict; only its ``'input'`` entry is read
                here (for logging).
            predicted_list: the list of predicted tokens so far.
            finish: truthy when the episode has terminated.

        Returns:
            ``0`` while the episode is still running; otherwise the first
            ``mos_pred`` value returned by NISQA, as a float.
        """
        episode_over = finish or len(predicted_list) >= self.env_max_length
        if not episode_over:
            return 0

        batch_size = 1
        worker_count = 0
        # NISQA single-file prediction settings. The scored file is the fixed
        # path below (presumably where the environment writes the synthesized
        # audio — TODO confirm).
        nisqa_config = {
            'mode': 'predict_file',
            'pretrained_model': f'{base_path}/NISQA/weights/nisqa.tar',
            'deg': f'{base_path}/output/example.wav',
            'data_dir': None,
            'output_dir': f'{base_path}/NISQA/result',
            'csv_file': None,
            'csv_deg': None,
            'num_workers': worker_count,
            'bs': batch_size,
            'ms_channel': None,
            # The tr_* entries mirror the prediction settings above.
            'tr_bs_val': batch_size,
            'tr_num_workers': worker_count,
        }

        mos_table = nisqaModel(nisqa_config).predict()
        reward = float(mos_table['mos_pred'].iloc[0])

        print("input_item : ", input_item['input'])
        print("predicted_list: ", predicted_list)
        print("reward: ", reward)
        return reward

**Build the observations and fit the agent on the first `data_len` examples**

In [5]:
# One observation per example: an empty 'input' string plus that example's
# source encodec layers and its instruction text.
observation_list = [
    {
        'input': "",
        'src_encodec': all_src_encodec[idx],
        'instruction': all_instruction[idx],
    }
    for idx in range(data_len)
]

In [6]:
# Show the instruction attached to every observation as a sanity check.
for idx in range(data_len):
    instruction_text = observation_list[idx]['instruction']
    print(f"Instruction {idx}: ", instruction_text)

Instruction 0:  Play the audio twice.
Instruction 1:  Mildly decrease the emphasis on the higher frequencies.
Instruction 2:  Considerably abate the bass frequencies.
Instruction 3:  Heighten the chorus effect in the audio by a small amount.
Instruction 4:  Hold off on playing the audio for 1 second.
Instruction 5:  Intensify the sound of the higher frequencies.
Instruction 6:  Give the audio a gradual increase in volume for 5 seconds from the onset.
Instruction 7:  Add a conspicuous chorus effect to the audio.
Instruction 8:  Significantly dampen the vibrations of the high notes.
Instruction 9:  Decrease the pitch of the audio by a moderate amount.
Instruction 10:  Introduce a minor adjustment to the pitch of the audio to make it lower.
Instruction 11:  Enlarge the scope and widen the reach of the sound quality.
Instruction 12:  Amplifying the sound to deliver a clearer and brighter rendition.
Instruction 13:  Backtrack the sound.
Instruction 14:  Enlarge the depth of the lower freque

In [7]:
# Environment: wraps both models/tokenizers and the observations built above.
env = MyRLEnv(
    ar_model,
    ar_tokenizer,
    nar_model,
    nar_tokenizer,
    observation_input=observation_list,
    compare_sample=1,
)

# Actor around the autoregressive model, and the PPO agent derived from it.
actor = TextRLActor(env, ar_model, ar_tokenizer)
ppo_settings = dict(update_interval=3, minibatch_size=3, epochs=10)
agent = actor.agent_ppo(**ppo_settings)

model name:  BartForConditionalGeneration
----------------------------- reset -----------------------------
size_of_packed_input:  121
output_path_ckpt:  /work/b0990106x/TextRL/output/example_save_0.wav
Input IDs shape: torch.Size([1, 121])




Episode Counter:  0


In [8]:
# Run a prediction on the first observation before any PPO training has been
# done, as a baseline to compare with the post-training prediction.
actor.predict(observation_list[0])

----------------------------- reset -----------------------------
size_of_packed_input:  335
output_path_ckpt:  /work/b0990106x/TextRL/output/example_save_1.wav
Input IDs shape: torch.Size([1, 335])




Episode Counter:  1
input_item :  v_tok_835v_tok_835v_tok_798v_tok_585v_tok_550v_tok_535v_tok_535v_tok_737v_tok_737v_tok_377v_tok_556v_tok_601v_tok_787v_tok_8v_tok_99v_tok_411v_tok_411v_tok_378v_tok_937v_tok_378v_tok_937v_tok_804v_tok_838v_tok_890v_tok_934v_tok_47v_tok_438v_tok_438v_tok_731v_tok_738v_tok_133v_tok_709v_tok_479v_tok_479v_tok_479v_tok_151v_tok_940v_tok_502v_tok_906v_tok_407v_tok_645v_tok_70v_tok_208v_tok_537v_tok_537v_tok_1022v_tok_681v_tok_723v_tok_747v_tok_593v_tok_804v_tok_681v_tok_879v_tok_136v_tok_967v_tok_233v_tok_431v_tok_754v_tok_421v_tok_182v_tok_182v_tok_651v_tok_879v_tok_887v_tok_819v_tok_904v_tok_904v_tok_887v_tok_309v_tok_880v_tok_396v_tok_754v_tok_775v_tok_997v_tok_222v_tok_336v_tok_548v_tok_841v_tok_269v_tok_479v_tok_479v_tok_940v_tok_23v_tok_56v_tok_738v_tok_835v_tok_395v_tok_206v_tok_779v_tok_531v_tok_862v_tok_931v_tok_306v_tok_203v_tok_755v_tok_369v_tok_6v_tok_466v_tok_716v_tok_948v_tok_82v_tok_575v_tok_288v_tok_556v_tok_903v_tok_556v_tok_392v_tok_796v_t



['v_tok_408v_tok_835v_tok_835v_tok_798v_tok_585v_tok_550v_tok_535v_tok_535v_tok_737v_tok_737v_tok_377v_tok_556v_tok_601v_tok_787v_tok_8v_tok_99v_tok_411v_tok_411v_tok_378v_tok_937v_tok_378v_tok_937v_tok_804v_tok_838v_tok_890v_tok_934v_tok_47v_tok_438v_tok_438v_tok_731v_tok_738v_tok_133v_tok_709v_tok_479v_tok_479v_tok_479v_tok_151v_tok_940v_tok_502v_tok_906v_tok_407v_tok_645v_tok_70v_tok_208v_tok_537v_tok_537v_tok_1022v_tok_681v_tok_723v_tok_747v_tok_593v_tok_804v_tok_681v_tok_879v_tok_136v_tok_967v_tok_233v_tok_431v_tok_754v_tok_421v_tok_182v_tok_182v_tok_651v_tok_879v_tok_887v_tok_819v_tok_904v_tok_904v_tok_887v_tok_309v_tok_880v_tok_396v_tok_754v_tok_775v_tok_997v_tok_222v_tok_336v_tok_548v_tok_841v_tok_269v_tok_479v_tok_479v_tok_940v_tok_23v_tok_56v_tok_738v_tok_835v_tok_395v_tok_206v_tok_779v_tok_531v_tok_862v_tok_931v_tok_306v_tok_203v_tok_755v_tok_369v_tok_6v_tok_466v_tok_716v_tok_948v_tok_82v_tok_575v_tok_288v_tok_556v_tok_903v_tok_556v_tok_392v_tok_796v_tok_751v_tok_835v_tok_10

In [9]:
import sys

from contextlib import redirect_stdout

output_file_path = 'log.txt'
# Directory where pfrl stores training artifacts; the best agent is saved
# under `<pfrl_outdir>/best` and loaded from there after training.
pfrl_outdir = 'train_steps_1100'

# Send everything printed during training (e.g. the reward printouts from
# MyRLEnv.get_reward) to the log file. `redirect_stdout` restores sys.stdout
# even if training raises, so a failed run cannot leave the notebook printing
# into a closed file.
with open(output_file_path, 'w') as f, redirect_stdout(f):
    pfrl.experiments.train_agent_with_evaluation(
        agent,
        env,
        steps=1100,  # train the agent for n steps
        eval_n_steps=None,
        eval_n_episodes=6,  # evaluate n episodes per evaluation
        train_max_episode_len=1000,
        eval_interval=3,  # evaluation every n steps (not episodes)
        outdir=pfrl_outdir,
    )

print('Output has been written to', output_file_path)


Resetting the environment:
Training start:


  actions = torch.tensor([b["action"] for b in dataset], device=device)


Train - outdir:train_steps_1100 step:202 episode:0 R:2.5893542766571045
statistics:[('average_value', 0.024952846), ('average_entropy', 0.24391486), ('average_value_loss', 0.03236089914687909), ('average_policy_loss', -3.4665893433594166e-10), ('n_updates', 670), ('explained_variance', -144.21032811221025)]
evaluation episode 0 length:647 R:1.9931236505508423
evaluation episode 1 length:92 R:2.043978452682495
evaluation episode 2 length:336 R:1.9852625131607056
evaluation episode 3 length:338 R:2.1231095790863037
evaluation episode 4 length:184 R:2.7214086055755615
evaluation episode 5 length:413 R:2.3125948905944824
The best score is updated -3.4028235e+38 -> 2.196579615275065
Saved the agent to train_steps_1100/best
Evaluation - Evaluating agent at step 202, episode 1
Train - outdir:train_steps_1100 step:507 episode:1 R:1.6393355131149292
statistics:[('average_value', -0.07983082), ('average_entropy', 0.27021095), ('average_value_loss', 0.3697762148547918), ('average_policy_loss', 1.



evaluation episode 3 length:94 R:3.23932147026062
evaluation episode 4 length:305 R:1.5713354349136353
evaluation episode 5 length:336 R:1.9386097192764282
The best score is updated 2.196579615275065 -> 2.2352630893389382
Saved the agent to train_steps_1100/best
Evaluation - Evaluating agent at step 507, episode 2
Train - outdir:train_steps_1100 step:1100 episode:2 R:0
statistics:[('average_value', -0.06612169), ('average_entropy', 0.19046488), ('average_value_loss', 0.0588711621100083), ('average_policy_loss', -5.453825132839541e-08), ('n_updates', 3660), ('explained_variance', -566.5046935767299)]
evaluation episode 0 length:109 R:1.429401159286499
evaluation episode 1 length:132 R:2.8823482990264893
evaluation episode 2 length:116 R:1.6900497674942017
evaluation episode 3 length:306 R:1.3636900186538696
evaluation episode 4 length:344 R:1.714331865310669
evaluation episode 5 length:414 R:2.148758888244629
Evaluation - Evaluating agent at step 1100, episode 3
Saved the agent to train

Load the best saved agent and run prediction.

In [10]:
# Restore the best-scoring agent saved during training.
agent.load(f'{pfrl_outdir}/best')

In [11]:
# Re-run prediction on the first observation with the restored best agent,
# for comparison with the pre-training prediction.
actor.predict(observation_list[0])

----------------------------- reset -----------------------------
size_of_packed_input:  335
output_path_ckpt:  /work/b0990106x/TextRL/output/example_save_23.wav
Input IDs shape: torch.Size([1, 335])
Episode Counter:  23
input_item :  v_tok_835v_tok_835v_tok_798v_tok_585v_tok_550v_tok_535v_tok_535v_tok_737v_tok_737v_tok_377v_tok_556v_tok_601v_tok_787v_tok_8v_tok_99v_tok_411v_tok_411v_tok_378v_tok_937v_tok_378v_tok_937v_tok_804v_tok_838v_tok_890v_tok_934v_tok_47v_tok_438v_tok_438v_tok_731v_tok_738v_tok_133v_tok_709v_tok_479v_tok_479v_tok_479v_tok_151v_tok_940v_tok_502v_tok_906v_tok_407v_tok_645v_tok_70v_tok_208v_tok_537v_tok_537v_tok_1022v_tok_681v_tok_723v_tok_747v_tok_593v_tok_804v_tok_681v_tok_879v_tok_136v_tok_967v_tok_233v_tok_431v_tok_754v_tok_421v_tok_182v_tok_182v_tok_651v_tok_879v_tok_887v_tok_819v_tok_904v_tok_904v_tok_887v_tok_309v_tok_880v_tok_396v_tok_754v_tok_775v_tok_997v_tok_222v_tok_336v_tok_548v_tok_841v_tok_269v_tok_479v_tok_479v_tok_940v_tok_23v_tok_56v_tok_738v_tok_

['v_tok_408v_tok_835v_tok_835v_tok_798v_tok_585v_tok_550v_tok_535v_tok_535v_tok_737v_tok_737v_tok_377v_tok_556v_tok_601v_tok_787v_tok_8v_tok_99v_tok_411v_tok_411v_tok_378v_tok_937v_tok_378v_tok_937v_tok_804v_tok_838v_tok_890v_tok_934v_tok_47v_tok_438v_tok_438v_tok_731v_tok_738v_tok_133v_tok_709v_tok_479v_tok_479v_tok_479v_tok_151v_tok_940v_tok_502v_tok_906v_tok_407v_tok_645v_tok_70v_tok_208v_tok_537v_tok_537v_tok_1022v_tok_681v_tok_723v_tok_747v_tok_593v_tok_804v_tok_681v_tok_879v_tok_136v_tok_967v_tok_233v_tok_431v_tok_754v_tok_421v_tok_182v_tok_182v_tok_651v_tok_879v_tok_887v_tok_819v_tok_904v_tok_904v_tok_887v_tok_309v_tok_880v_tok_396v_tok_754v_tok_775v_tok_997v_tok_222v_tok_336v_tok_548v_tok_841v_tok_269v_tok_479v_tok_479v_tok_940v_tok_23v_tok_56v_tok_738v_tok_835v_tok_395v_tok_206v_tok_779v_tok_531v_tok_862v_tok_931v_tok_306v_tok_203v_tok_755v_tok_369v_tok_6v_tok_466v_tok_716v_tok_948v_tok_82v_tok_575v_tok_288v_tok_556v_tok_903v_tok_556v_tok_392v_tok_796v_tok_751v_tok_835v_tok_10