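# Controllable speech generation via RL (TextRL)
> How to steer an encodec-token speech model with a NISQA speech-quality reward. (Adapted from the TextRL example "Controllable generation via RL to let Elon Musk speak ill of DOGE", which controlled text generation through a sentiment classifier.)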



In [None]:
import torch
from datasets import load_from_disk
from vc.encodec_model.nar_bart_model import NARBartForConditionalGeneration
from transformers import (AutoTokenizer, BartForConditionalGeneration)

import os

# Root of the TextRL checkout. Override with the TEXTRL_BASE_PATH environment
# variable instead of editing the notebook; defaults to the original path.
base_path = os.environ.get('TEXTRL_BASE_PATH', '/work/b0990106x/TextRL')
# The agent reads the encodec dataset and writes predictions; the environment's
# input/output directories are the same two folders, swapped.
agent_input_dir = f'{base_path}/data-encodec'
agent_output_dir = f'{base_path}/output-predict'
env_input_dir = agent_output_dir
env_output_dir = agent_input_dir

# Pretrained AR and NAR checkpoints loaded with from_pretrained below.
ar_checkpoint = "lca0503/speech-chatgpt-base-ar-v2-epoch10-wotrans"
nar_checkpoint = "lca0503/speech-chatgpt-base-nar-v2-epoch4-wotrans"

device = "cuda" if torch.cuda.is_available() else "cpu"
ar_tokenizer = AutoTokenizer.from_pretrained(ar_checkpoint)
ar_model = BartForConditionalGeneration.from_pretrained(ar_checkpoint)
nar_tokenizer = AutoTokenizer.from_pretrained(nar_checkpoint)
nar_model = NARBartForConditionalGeneration.from_pretrained(nar_checkpoint)
# Only the AR model (later handed to TextRLActor as the policy) is moved to `device`.
# NOTE(review): nar_model is never moved to `device` here — confirm the env
# handles its placement.
# The trailing ';' suppresses the very long module repr in the cell output.
ar_model.to(device);

In [None]:
# Load the encodec dataset previously saved with `save_to_disk`.
dataset = load_from_disk(dataset_path=agent_input_dir)

In [None]:
# Reshape the column-oriented dataset into per-example records.
data_len = len(dataset)
print(data_len)

# Number of src_encodec_<j> columns in the dataset (one per encodec layer).
layer_len = 8

# Fetch each layer column once: all_src_encodec_layers[j][i] is layer j of example i.
all_src_encodec_layers = [dataset[f"src_encodec_{j}"] for j in range(layer_len)]

# Transpose to per-example lists: all_src_encodec[i][j] is layer j of example i.
all_src_encodec = [
    [all_src_encodec_layers[j][i] for j in range(layer_len)]
    for i in range(data_len)
]

# Fetch the instruction column ONCE. In `datasets` versions where column access
# materializes a full Python list, indexing dataset["instruction"] inside the
# loop re-read the whole column per example (quadratic in data_len).
instruction_column = dataset["instruction"]
all_instruction = [instruction_column[i] for i in range(data_len)]
# Instruction token ids with the first and last token stripped
# (presumably the tokenizer's BOS/EOS specials — TODO confirm).
all_instruction_ids = [
    ar_tokenizer(instruction)["input_ids"][1:-1] for instruction in all_instruction
]

In [None]:
# Reload the local `textrl` package so edits to it take effect without a
# kernel restart, then pull in the environment/actor classes.
from importlib import reload
import textrl
textrl = reload(textrl)
from textrl import TextRLEnv, TextRLActor

In [None]:
from NISQA.nisqa.NISQA_model import nisqaModel

class MyRLEnv(TextRLEnv):
    """TextRL environment rewarded by the NISQA-predicted MOS of the generated audio."""

    def get_reward(self, input_item, predicted_list, finish):
        """Score the current step.

        Args:
            input_item: the current observation (not used by this reward).
            predicted_list: predicted token sequences — presumably one list per
                compared sample (``compare_sample``); only ``predicted_list[0]``
                is inspected.
            finish: whether generation has finished.

        Returns:
            0 for intermediate steps. Once ``finish`` is set or the first
            predicted sequence reaches ``self.env_max_length`` tokens: the NISQA
            ``mos_pred`` for the scored audio file, or 0 if scoring raises.
        """
        reward = 0
        # Token length of the first predicted sequence. Note: len(predicted_list)
        # would count sequences (1 with compare_sample=1), not tokens, so the
        # max-length cut-off could never trigger.
        predicted_len = len(predicted_list[0]) if predicted_list else 0

        if finish or predicted_len >= self.env_max_length:
            try:
                print("Length of predicted_list:", predicted_len)

                # NOTE(review): NISQA scores the fixed file example.wav, while the
                # env reports saving example_save_<N>.wav per episode; presumably
                # example.wav is (re)written before the reward is requested — confirm.
                args_nisqa = {
                    'mode': 'predict_file',
                    'pretrained_model': f'{base_path}/NISQA/weights/nisqa.tar',
                    'deg': f'{base_path}/output-predict/example.wav',
                    'data_dir': None,
                    'output_dir': f'{base_path}/NISQA/result-predict',
                    'csv_file': None,
                    'csv_deg': None,
                    'num_workers': 0,
                    'bs': 1,
                    'ms_channel': None
                }
                args_nisqa['tr_bs_val'] = args_nisqa['bs']
                args_nisqa['tr_num_workers'] = args_nisqa['num_workers']

                # A fresh nisqaModel (and args dict) is built per call rather than
                # cached. NOTE(review): keep it that way unless it is confirmed that
                # nisqaModel neither reads 'deg' at construction nor mutates its args.
                nisqa = nisqaModel(args_nisqa)
                prediction = nisqa.predict()
                reward = float(prediction['mos_pred'].iloc[0])
            except Exception as e:
                # Best effort: a failed scoring run yields reward 0 instead of
                # aborting training.
                print("Error:", e)
                reward = 0

            print("Reward:", reward)

        return reward

**Fit the examples** — build one observation per dataset example.

In [21]:
# One observation per dataset example: an empty text prompt plus the source
# encodec layers and the editing instruction of that example.
# To overfit a single example instead, build 10 separate dicts for example 0:
#   [{'input': "", 'src_encodec': all_src_encodec[0], 'instruction': all_instruction[0]} for _ in range(10)]
observation_list = [
    {'input': "", 'src_encodec': all_src_encodec[idx], 'instruction': all_instruction[idx]}
    for idx in range(data_len)
]

In [13]:
# Environment built from the AR and NAR models; compare_sample=1 presumably means
# a single predicted sequence per step (get_reward only reads predicted_list[0]).
env = MyRLEnv(
    ar_model,
    ar_tokenizer,
    nar_model,
    nar_tokenizer,
    observation_input=observation_list,
    compare_sample=1,
)
# The actor wraps the AR model as the policy; agent_ppo presumably builds a PPO agent.
actor = TextRLActor(env, ar_model, ar_tokenizer)
agent = actor.agent_ppo(
    update_interval=3,
    minibatch_size=3,
    epochs=10,
)

model name:  BartForConditionalGeneration
----------------------------- reset -----------------------------
size_of_packed_input:  335
Input IDs shape: torch.Size([1, 335])
Episode 0 : audio saved to  /work/b0990106x/TextRL/output-predict/example_save_0.wav


Predict on an example, then load the best (and intermediate) checkpoints and predict again.

In [26]:
# Show the editing instruction of example 5.
# NOTE(review): the prediction cells below all use observation_list[0], not example 5.
print(observation_list[5]["instruction"])

Intensify the sound of the higher frequencies.


In [14]:
# Baseline: predict on example 0 before any trained checkpoint is loaded.
actor.predict(observation_list[0])

----------------------------- reset -----------------------------
size_of_packed_input:  335
Input IDs shape: torch.Size([1, 335])
Episode 1 : audio saved to  /work/b0990106x/TextRL/output-predict/example_save_1.wav
Length of predicted_list: 650
n_wins 851  seg_length 15  x.shape[1] 865
x.shape torch.Size([1300, 1, 48, 15])  n_wins 851
Reward: 2.37109375


['v_tok_408v_tok_835v_tok_835v_tok_798v_tok_585v_tok_550v_tok_535v_tok_535v_tok_737v_tok_737v_tok_377v_tok_556v_tok_601v_tok_787v_tok_8v_tok_99v_tok_411v_tok_411v_tok_378v_tok_937v_tok_378v_tok_937v_tok_804v_tok_838v_tok_890v_tok_934v_tok_47v_tok_438v_tok_438v_tok_731v_tok_738v_tok_133v_tok_709v_tok_479v_tok_479v_tok_479v_tok_151v_tok_940v_tok_502v_tok_906v_tok_407v_tok_645v_tok_70v_tok_208v_tok_537v_tok_537v_tok_1022v_tok_681v_tok_723v_tok_747v_tok_593v_tok_804v_tok_681v_tok_879v_tok_136v_tok_967v_tok_233v_tok_431v_tok_754v_tok_421v_tok_182v_tok_182v_tok_651v_tok_879v_tok_887v_tok_819v_tok_904v_tok_904v_tok_887v_tok_309v_tok_880v_tok_396v_tok_754v_tok_775v_tok_997v_tok_222v_tok_336v_tok_548v_tok_841v_tok_269v_tok_479v_tok_479v_tok_940v_tok_23v_tok_56v_tok_738v_tok_835v_tok_395v_tok_206v_tok_779v_tok_531v_tok_862v_tok_931v_tok_306v_tok_203v_tok_755v_tok_369v_tok_6v_tok_466v_tok_716v_tok_948v_tok_82v_tok_575v_tok_288v_tok_556v_tok_903v_tok_556v_tok_392v_tok_796v_tok_751v_tok_835v_tok_10

In [15]:
# Training output directory (relative to the notebook's working directory);
# load its 'best' checkpoint and predict on example 0 again.
pfrl_outdir = 'train-0424-300000'
agent.load(f'{pfrl_outdir}/best')
actor.predict(observation_list[0])

----------------------------- reset -----------------------------
size_of_packed_input:  335
Input IDs shape: torch.Size([1, 335])
Episode 2 : audio saved to  /work/b0990106x/TextRL/output-predict/example_save_2.wav
Length of predicted_list: 652
n_wins 853  seg_length 15  x.shape[1] 867
x.shape torch.Size([1300, 1, 48, 15])  n_wins 853
Reward: 2.310546875


['v_tok_835v_tok_835v_tok_835v_tok_798v_tok_585v_tok_550v_tok_535v_tok_535v_tok_737v_tok_737v_tok_377v_tok_556v_tok_601v_tok_787v_tok_8v_tok_99v_tok_411v_tok_411v_tok_378v_tok_937v_tok_378v_tok_937v_tok_804v_tok_838v_tok_890v_tok_934v_tok_47v_tok_438v_tok_438v_tok_731v_tok_738v_tok_133v_tok_709v_tok_479v_tok_479v_tok_479v_tok_151v_tok_940v_tok_502v_tok_906v_tok_407v_tok_645v_tok_70v_tok_208v_tok_537v_tok_537v_tok_1022v_tok_681v_tok_723v_tok_747v_tok_593v_tok_804v_tok_681v_tok_879v_tok_136v_tok_967v_tok_233v_tok_431v_tok_754v_tok_421v_tok_182v_tok_182v_tok_651v_tok_879v_tok_887v_tok_819v_tok_904v_tok_904v_tok_887v_tok_309v_tok_880v_tok_396v_tok_754v_tok_775v_tok_997v_tok_222v_tok_336v_tok_548v_tok_841v_tok_269v_tok_479v_tok_479v_tok_940v_tok_23v_tok_56v_tok_738v_tok_835v_tok_395v_tok_206v_tok_779v_tok_531v_tok_862v_tok_931v_tok_306v_tok_203v_tok_755v_tok_369v_tok_6v_tok_466v_tok_716v_tok_948v_tok_82v_tok_575v_tok_288v_tok_556v_tok_903v_tok_556v_tok_392v_tok_796v_tok_751v_tok_835v_tok_10

In [16]:
# Evaluate the intermediate checkpoints 50000_checkpoint ... 250000_checkpoint
# on example 0. Inside a loop the return value of actor.predict is not displayed,
# so keep each prediction, keyed by checkpoint name, for later comparison.
checkpoint_predictions = {}
for i in range(1, 6):
    ckpt = f'{i*50000}_checkpoint'
    agent.load(pfrl_outdir + '/' + ckpt)
    checkpoint_predictions[ckpt] = actor.predict(observation_list[0])

----------------------------- reset -----------------------------
size_of_packed_input:  335
Input IDs shape: torch.Size([1, 335])
Episode 3 : audio saved to  /work/b0990106x/TextRL/output-predict/example_save_3.wav
Length of predicted_list: 658
n_wins 861  seg_length 15  x.shape[1] 875
x.shape torch.Size([1300, 1, 48, 15])  n_wins 861
Reward: 2.451171875
----------------------------- reset -----------------------------
size_of_packed_input:  335
Input IDs shape: torch.Size([1, 335])
Episode 4 : audio saved to  /work/b0990106x/TextRL/output-predict/example_save_4.wav
Length of predicted_list: 654
n_wins 856  seg_length 15  x.shape[1] 870
x.shape torch.Size([1300, 1, 48, 15])  n_wins 856
Reward: 2.26953125
----------------------------- reset -----------------------------
size_of_packed_input:  335
Input IDs shape: torch.Size([1, 335])
Episode 5 : audio saved to  /work/b0990106x/TextRL/output-predict/example_save_5.wav
Length of predicted_list: 650
n_wins 851  seg_length 15  x.shape[1] 8

In [17]:
# Final checkpoint written at the end of training; predict on example 0.
ckpt = '300000_finish'
agent.load(f'{pfrl_outdir}/{ckpt}')
actor.predict(observation_list[0])

----------------------------- reset -----------------------------
size_of_packed_input:  335
Input IDs shape: torch.Size([1, 335])
Episode 8 : audio saved to  /work/b0990106x/TextRL/output-predict/example_save_8.wav
Length of predicted_list: 655
n_wins 857  seg_length 15  x.shape[1] 871
x.shape torch.Size([1300, 1, 48, 15])  n_wins 857
Reward: 2.12890625


['v_tok_780v_tok_835v_tok_835v_tok_798v_tok_585v_tok_550v_tok_535v_tok_535v_tok_737v_tok_737v_tok_377v_tok_556v_tok_601v_tok_787v_tok_8v_tok_99v_tok_411v_tok_411v_tok_378v_tok_937v_tok_378v_tok_937v_tok_804v_tok_838v_tok_890v_tok_934v_tok_47v_tok_438v_tok_438v_tok_731v_tok_738v_tok_133v_tok_709v_tok_479v_tok_479v_tok_479v_tok_151v_tok_940v_tok_502v_tok_906v_tok_407v_tok_645v_tok_70v_tok_208v_tok_537v_tok_537v_tok_1022v_tok_681v_tok_723v_tok_747v_tok_593v_tok_804v_tok_681v_tok_879v_tok_136v_tok_967v_tok_233v_tok_431v_tok_754v_tok_421v_tok_182v_tok_182v_tok_651v_tok_879v_tok_887v_tok_819v_tok_904v_tok_904v_tok_887v_tok_309v_tok_880v_tok_396v_tok_754v_tok_775v_tok_997v_tok_222v_tok_336v_tok_548v_tok_841v_tok_269v_tok_479v_tok_479v_tok_940v_tok_23v_tok_56v_tok_738v_tok_835v_tok_395v_tok_206v_tok_779v_tok_531v_tok_862v_tok_931v_tok_306v_tok_203v_tok_755v_tok_369v_tok_6v_tok_466v_tok_716v_tok_948v_tok_82v_tok_575v_tok_288v_tok_556v_tok_903v_tok_556v_tok_392v_tok_796v_tok_751v_tok_835v_tok_10