**Environment Setup**

In [1]:
# %pip install pfrl@git+https://github.com/voidful/pfrl.git
# %pip install textrl==0.2.15
# %pip install ipywidgets

In [2]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("lca0503/speech-chatgpt-base-ar-v2-epoch10-wotrans")
model = AutoModelForSeq2SeqLM.from_pretrained("lca0503/speech-chatgpt-base-ar-v2-epoch10-wotrans")
model.eval()
model.cuda()

BartForConditionalGeneration(
  (model): BartModel(
    (shared): Embedding(59481, 768, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): Embedding(59481, 768, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 768)
      (layers): ModuleList(
        (0-5): 6 x BartEncoderLayer(
          (self_attn): BartAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (final_layer_norm): LayerNorm((768,), eps=

In [3]:
import pfrl
from textrl import TextRLEnv,TextRLActor
from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer, AutoModelWithLMHead
import logging
import sys
import torch
from NISQA.nisqa.NISQA_model import nisqaModel

logging.basicConfig(level=logging.INFO, stream=sys.stdout, format='')

In [4]:
base_path = '/work/b0990106x/TextRL'

# print("model", model)
# print("tokenizer", tokenizer)

**RL Environment: NISQA**

In [5]:
import gym
import logging
import random
import sys
import torch
from torch import autocast
from vc.wav_to_arrow import process_audio

class VcRLEnv(gym.Env):
    def __init__(self, model, tokenizer, observation_input=[], max_length=100, compare_sample=2, unfreeze_layer_from_past=0, env_input_dir=None, env_output_dir=None, instruction="", transcription=""):
        self.model = model
        self.tokenizer = tokenizer
        self.observation_space = observation_input
        self.compare_sample = compare_sample
        self.unfreeze_layer_from_past = 1 if unfreeze_layer_from_past else 0
        self.env_max_length = min(max(self.model.config.max_length, self.tokenizer.model_max_length), max_length)
        self.env_input_dir = env_input_dir
        self.env_output_dir = env_output_dir
        self.instruction = instruction1
        self.transcription = transcription
        # self.reset()
        
        self.gen_stop_toks = []
        logging.disable(sys.maxsize)
        if self.tokenizer.sep_token:
            self.gen_stop_toks.append(self.tokenizer.sep_token)
        if self.tokenizer.eos_token:
            self.gen_stop_toks.append(self.tokenizer.eos_token)
        logging.disable(logging.NOTSET)
    
    @autocast('cuda')
    def reset(self, input_item=None):
        self.predicted = [[]] * self.compare_sample
        self.predicted_end = [False] * self.compare_sample
        self.input_item = {"input": ""}
        return self._get_obs(self.predicted)

    def step(self, count):
        reward = self._predict(count)
        self._get_obs(count)
        return reward
    
    def get_reward(self, input_path=None, output_dir=None, count=0): # predicted will be the list of predicted token
        args = {
            'mode': 'predict_file', 
            'pretrained_model': f'{base_path}/NISQA/weights/nisqa.tar', 
            'deg': f'{base_path}/output/{count}.wav', 
            'data_dir': None, 
            'output_dir': f'{base_path}/NISQA/result',
            'csv_file': None, 
            'csv_deg': None,  
            'num_workers': 0, 
            'bs': 1,
            'ms_channel': None
        }

        if input_path is not None:
            args['deg'] = input_path

        args['tr_bs_val'] = args['bs']
        args['tr_num_workers'] = args['num_workers']
        
        nisqa = nisqaModel(args)
        prediction = nisqa.predict()
        reward = float(prediction['mos_pred'].iloc[0])
        print("count", count, "reward", reward)
        return reward
    
        
    def gat_obs_input(self, input_item):
        return input_item['input']
    
    def _get_obs(self, count, predicted=[]):
        # # backup
        # audio_path = f"{self.env_input_dir}/{count}.wav"
        # process_audio(source_audio_path=audio_path, output_dir = self.env_output_dir, temp_dir="/work/b0990106x/TextRL/vc/data/temp", instruction=self.instruction, transcription=self.transcription)
        # return count
        
        with torch.inference_mode():
            obs_list = []
            
            for p_text in predicted:
                p_text_str = self.tokenizer.convert_tokens_to_string(p_text)
                if self.model.__class__.__name__ == 'OPTForCausalLM':
                    feature_dict = self.tokenizer([[self.gat_obs_input(self.input_item), p_text_str]],
                                                  return_tensors='pt',
                                                  return_token_type_ids=False,
                                                  add_special_tokens=False).to(self.model.device)
                    with torch.cuda.amp.autocast(enabled=False):
                        prediction = self.model(**feature_dict, output_hidden_states=True)
                    outputs = prediction.hidden_states[-self.unfreeze_layer_from_past][:, -1, :]
                else:
                    if len([k for k, v in self.model.named_parameters() if 'decoder' in k]) > 0:
                        feature_dict = self.tokenizer([self.gat_obs_input(self.input_item)],
                                                      return_tensors='pt',
                                                      return_token_type_ids=False,
                                                      add_special_tokens=True).to(self.model.device)
                        if len(p_text) > 0:
                            decoder_input_ids = [self.model.config.decoder_start_token_id] + \
                                                self.tokenizer.convert_tokens_to_ids(p_text)
                            dec_input = torch.tensor([decoder_input_ids]).to(self.model.device)
                            feature_dict['decoder_input_ids'] = dec_input
                        else:
                            feature_dict['decoder_input_ids'] = torch.tensor(
                                [[self.model.config.decoder_start_token_id]]).to(self.model.device)
                        with torch.cuda.amp.autocast(enabled=False):
                            prediction = self.model(**feature_dict, output_hidden_states=True)
                        outputs = prediction.decoder_hidden_states[-self.unfreeze_layer_from_past].squeeze(0)
                    else:
                        if self.model.__class__.__name__ == 'DistributedBloomForCausalLM':
                            with self.model.inference_session(max_length=self.env_max_length) as sess:
                                feature_dict = self.tokenizer([[self.gat_obs_input(self.input_item), p_text_str]],
                                                              return_tensors='pt',
                                                              return_token_type_ids=False,
                                                              add_special_tokens=False).to(self.model.device)
                                embs = self.model.transformer.word_embeddings(feature_dict.input_ids)
                                embs = self.model.transformer.word_embeddings_layernorm(embs)
                                h = sess.step(embs)
                                outputs = self.model.transformer.ln_f(h[:, -1])
                        else:
                            feature_dict = self.tokenizer([[self.gat_obs_input(self.input_item), p_text_str]],
                                                          return_tensors='pt',
                                                          return_token_type_ids=False,
                                                          add_special_tokens=False).to(self.model.device)
                            prediction = self.model(**feature_dict, output_hidden_states=True)
                            outputs = prediction.hidden_states[-self.unfreeze_layer_from_past].squeeze(0)
                obs_list.append(outputs.data[-1])
            return (torch.stack(obs_list))
        
        
    def _predict(self, count):
        self.get_reward(count=count)

        


PyTorch version 2.2.1 available.


**RL Agent: Text-Instruction-Guided Voice Conversion Model**

In [6]:
from torch import autocast
from transformers import (AutoTokenizer, BartForConditionalGeneration,
                          BatchEncoding)
from vc.trainer_encodec_vc_inference import cascade_ar_nar, convert_to_encode_code,synthesize_audio
from vc.encodec_model.nar_bart_model import NARBartForConditionalGeneration
import numpy as np
import soundfile as sf
import textrl.actor

import itertools

import numpy as np
import pfrl
import torch
import torch.nn.functional as F
from pfrl.agents.ppo import _elementwise_clip
from pfrl.utils.mode_of_distribution import mode_of_distribution
from torch import autocast
from datasets import load_dataset, load_from_disk

class VcPPOAgent(pfrl.agents.PPO):
    def _update_if_dataset_is_ready(self):
        dataset_size = (
                sum(len(episode) for episode in self.memory)
                + len(self.last_episode)
                + (
                    0
                    if self.batch_last_episode is None
                    else sum(len(episode) for episode in self.batch_last_episode)
                )
        )
        if dataset_size >= self.update_interval:
            self._flush_last_episode()
            if self.recurrent:
                dataset = pfrl.agents.ppo._make_dataset_recurrent(
                    episodes=self.memory,
                    model=self.model,
                    phi=self.phi,
                    batch_states=self.batch_states,
                    obs_normalizer=self.obs_normalizer,
                    gamma=self.gamma,
                    lambd=self.lambd,
                    max_recurrent_sequence_len=self.max_recurrent_sequence_len,
                    device=self.device,
                )
                self._update_recurrent(dataset)
            else:
                dataset = pfrl.agents.ppo._make_dataset(
                    episodes=self.memory,
                    model=self.model,
                    phi=self.phi,
                    batch_states=self.batch_states,
                    obs_normalizer=self.obs_normalizer,
                    gamma=self.gamma,
                    lambd=self.lambd,
                    device=self.device,
                )
                assert len(dataset) == dataset_size
                self._update(dataset)
            self.explained_variance = self._compute_explained_variance(
                list(itertools.chain.from_iterable(self.memory))
            )
            self.memory = []

    def _compute_explained_variance(self, transitions):
        """Compute 1 - Var[return - v]/Var[return].

        This function computes the fraction of variance that value predictions can
        explain about returns.
        """
        t = np.array([tr["v_teacher"] for tr in transitions])
        y = np.array([tr["v_pred"] for tr in transitions])
        vart = np.var(t)
        if vart == 0:
            return np.nan
        else:
            return float(1 - np.var(np.average(t) - y) / vart)

    def batch_act(self, batch_obs):
        if self.training:
            return self._batch_act_train(batch_obs)
        else:
            return self._batch_act_eval(batch_obs)

    @autocast('cuda')
    def _batch_act_eval(self, batch_obs):
        assert not self.training
        b_state = self.batch_states(batch_obs, self.device, self.phi)

        if self.obs_normalizer:
            b_state = self.obs_normalizer(b_state, update=False)

        with torch.no_grad(), pfrl.utils.evaluating(self.model):
            action_distrib, _ = self.model(b_state)
            if self.act_deterministically:
                action = mode_of_distribution(action_distrib).cpu().numpy()
            else:
                action = action_distrib.sample().cpu().numpy()

        return action

    def _lossfun(
            self, entropy, vs_pred, log_probs, vs_pred_old, log_probs_old, advs, vs_teacher
    ):
        prob_ratio = torch.exp(log_probs - log_probs_old)
        loss_policy = -torch.mean(
            torch.min(
                (prob_ratio * advs),
                torch.clamp(prob_ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advs,
            ),
        )
        if self.clip_eps_vf is None:
            loss_value_func = F.mse_loss(vs_pred.squeeze(), vs_teacher.squeeze())
        else:
            clipped_vs_pred = _elementwise_clip(
                vs_pred,
                vs_pred_old - self.clip_eps_vf,
                vs_pred_old + self.clip_eps_vf,
            )
            loss_value_func = torch.mean(
                torch.max(
                    F.mse_loss(vs_pred.squeeze(), vs_teacher, reduction="none"),
                    F.mse_loss(clipped_vs_pred.squeeze(), vs_teacher, reduction="none"),
                )
            )
        loss_entropy = -torch.mean(entropy)

        self.value_loss_record.append(float(loss_value_func))
        self.policy_loss_record.append(float(loss_policy))
        loss = (
                loss_policy
                + self.value_func_coef * loss_value_func
                + self.entropy_coef * loss_entropy
        )
        return loss

class VcActor():
    def __init__(self, env, model, tokenizer,ar_checkpoint, nar_checkpoint, input_dir, output_dir, device, observation_input=[], max_length=100, compare_sample=2, gpu_id=0):
        
        self.env = env
        self.model = model
        self.tokenizer = tokenizer
        self.observation_input = observation_input
        self.ar_tokenizer = AutoTokenizer.from_pretrained(ar_checkpoint)
        self.ar_model = BartForConditionalGeneration.from_pretrained(ar_checkpoint)
        self.nar_tokenizer = AutoTokenizer.from_pretrained(nar_checkpoint)
        self.nar_model = NARBartForConditionalGeneration.from_pretrained(nar_checkpoint)
        self.gpu_id = gpu_id
        # self.device = torch.device(device)
        self.device = torch.device("cuda:{}".format(gpu_id))
        self.ar_model.to(self.device)
        self.nar_model.to(self.device)
        self.input_dir = input_dir
        self.output_dir = output_dir

    @autocast('cuda')
    def predict(self, count):
        t = 0 
        with torch.inference_mode():
            # use the model to predict the next wav
            with self.agent.eval_mode():
                obs = self.env.reset(self.observation_input)
                while True:
                    action = self.agent.act(obs)
                    obs, reward, done, pred = self.env.step(action)
                    t += 1
                    reset = t >= self.env.env_max_length
                    self.agent.observe(obs, reward, done, reset)
                    if done or reset:
                        return pred.get('predicted_str')

        # dataset = load_from_disk(self.input_dir)
        # layer_list = cascade_ar_nar(self.ar_model, self.nar_model, self.ar_tokenizer, self.nar_tokenizer, dataset, self.device)
        # encodec_code = convert_to_encode_code(self.nar_tokenizer, layer_list)    
        # audio = synthesize_audio(encodec_code, self.device)
        # output_path = f"{self.output_dir}/{count}.wav"
        # sf.write(output_path, np.ravel(audio), samplerate=24000)

    def agent_ppo(self, update_interval=10, minibatch_size=3000, epochs=20, lr=3e-6):
        policy = torch.nn.Sequential(
            self.middle_model,
            self.remaining_model,
            self.converter,
            textrl.actor.SoftmaxCategoricalHead(self.env,
                                   temperature=self.temperature,
                                   top_k=self.top_k,
                                   top_p=self.top_p)
        )
        vf = torch.nn.Sequential(
            torch.nn.Linear(self.obs_size, self.obs_size // 2),
            torch.nn.Linear(self.obs_size // 2, self.obs_size // 4),
            torch.nn.Linear(self.obs_size // 4, 1)
        )
        model = pfrl.nn.Branched(policy, vf)
        if isinstance(self.optimizer, str):
            if self.optimizer.lower() == 'adamw':
                opt = torch.optim.AdamW(model.parameters(), lr=lr)
            else:
                opt = torch.optim.SGD(model.parameters(), lr=lr)
        else:
            opt = self.optimizer
        model = model.cuda()
        agent = VcPPOAgent(
            model,
            opt,
            gpu=self.gpu_id,
            update_interval=update_interval,
            minibatch_size=minibatch_size,
            epochs=epochs,
            clip_eps_vf=None,
            entropy_coef=0,
            gamma=0.95,  # https://arxiv.org/abs/2210.01241
            lambd=1,
            max_grad_norm=1.0,
            standardize_advantages=True,
            act_deterministically=self.act_deterministically
        )
        self.agent = agent
        return agent


In [7]:
import random
import numpy as np
import soundfile as sf
import torch
from datasets import load_dataset, load_from_disk
from encodec import EncodecModel
from vc.encodec_model.nar_bart_model import NARBartForConditionalGeneration
from argparse import ArgumentParser, Namespace
from transformers import (AutoTokenizer, BartForConditionalGeneration,
                          BatchEncoding)
import vc.trainer_encodec_vc_inference as vc_inference
from types import SimpleNamespace

args = SimpleNamespace(
    dataset="lca0503/soxdata_small_encodec",
    splits=["train"],
    ground_truth_only=False,
    cascade_ar_nar=True,
    nar_model_only=False,
    ground_truth_model_name="voidful/bart-base-unit",
    ar_checkpoint="lca0503/speech-chatgpt-base-ar-v2-epoch10-wotrans",
    nar_checkpoint="lca0503/speech-chatgpt-base-nar-v2-epoch4-wotrans",
    ground_truth_output_path="output_wav/vc/ground_truth/train_1.wav",
    cascade_output_path="output_wav/vc/ar_nar_cascade/train_1.wav",
    nar_output_path="output_wav/vc/nar/train_1.wav",
    seed=0,
    device="cuda"
)

# modify ar_checkpoint in args
args.ar_checkpoint = "lca0503/speech-chatgpt-base-ar-v2-epoch10-wotrans"




**Load Datasets**

In [8]:
# %pip install datasets
# from datasets import load_from_disk ,load_dataset

# dataset = load_dataset("lca0503/soxdata_encodec")
# dataset.save_to_disk("data")

# dataset = load_dataset("lca0503/soxdata_encodec", split="+".join(["train"]))
# dataset = dataset.filter(lambda x : len(x[f"src_encodec_0"]) <= 700)
# dataset = dataset.shuffle(0).select(range(1))

# dataset.save_to_disk("data-encodec")
# dataset = load_from_disk("data-encodec")

**Start Training**

*1. Agent to Environment*

In [9]:
# define path
agent_input_dir = f'{base_path}/data-encodec'
agent_output_dir = f'{base_path}/output'
env_input_dir = agent_output_dir
env_output_dir = agent_input_dir

ar_checkpoint = "lca0503/speech-chatgpt-base-ar-v2-epoch10-wotrans"
nar_checkpoint = "lca0503/speech-chatgpt-base-nar-v2-epoch4-wotrans"

In [10]:
device = "cuda" if torch.cuda.is_available() else "cpu"
ar_tokenizer = AutoTokenizer.from_pretrained(ar_checkpoint)
ar_model = BartForConditionalGeneration.from_pretrained(ar_checkpoint)
ar_model.to(device)

dataset = load_from_disk(agent_input_dir)
instruction_ids = ar_tokenizer(dataset["instruction"][0])["input_ids"][1 : -1]
transcription_ids = ar_tokenizer(dataset["transcription"][0])["input_ids"][1 : -1]
instruction = dataset["instruction"][0]
transcription = dataset["transcription"][0]

print("Instruction: ", instruction)
print("Transcription: ", transcription)

# for i in range(len(instruction_ids)):
#     print("Instruction(cascade): ", ar_tokenizer.decode(instruction_ids[i]))
# for i in range(len(transcription_ids)):
#     print("Transcription(cascade): ", ar_tokenizer.decode(transcription_ids[i]))
    
observation_list = [{'input': 0, 'transcription': transcription, 'instruction': instruction, 'dataset': dataset}]

Instruction:  Play the audio twice.
Transcription:  There is even a white row of beehives in the orchard, under the walnut trees.


In [11]:
env = VcRLEnv(model, tokenizer, observation_list, 100, 2, 1, env_input_dir, env_output_dir, instruction, transcription)
actor = VcActor(env, ar_model, ar_tokenizer, ar_checkpoint, nar_checkpoint, agent_input_dir, agent_output_dir, device, observation_list, 100, 2)
actor.predict(0)

Dataset Dataset({
    features: ['file_id', 'instruction', 'transcription', 'src_encodec_0', 'src_encodec_1', 'src_encodec_2', 'src_encodec_3', 'src_encodec_4', 'src_encodec_5', 'src_encodec_6', 'src_encodec_7', 'tgt_encodec_0', 'tgt_encodec_1', 'tgt_encodec_2', 'tgt_encodec_3', 'tgt_encodec_4', 'tgt_encodec_5', 'tgt_encodec_6', 'tgt_encodec_7'],
    num_rows: 4982
})
[20780, 5, 6086, 2330, 4]
[970, 16, 190, 10, 1104, 3236, 9, 28, 14897, 3699, 11, 5, 50, 15782, 6, 223, 5, 21788, 10873, 3980, 4]
cuda




In [12]:
# define env, actor, and agent
# actor = TextRLActor(env,model,tokenizer)
env = VcRLEnv(model, tokenizer, observation_list, 100, 2, 1, env_input_dir, env_output_dir, instruction, transcription)
actor = VcActor(env, ar_model, ar_tokenizer, ar_checkpoint, nar_checkpoint, agent_input_dir, agent_output_dir, device, observation_list, 100, 2)

for i in range(10):
    print("Step: ", i)
    actor.predict(i)
    env.step(i)
    actor.agent_ppo(update_interval=100, minibatch_size=3, epochs=10)
    
# agent = vc_inference.run(args, agent_input_dir, agent_output_path)

Step:  0
Dataset Dataset({
    features: ['file_id', 'instruction', 'transcription', 'src_encodec_0', 'src_encodec_1', 'src_encodec_2', 'src_encodec_3', 'src_encodec_4', 'src_encodec_5', 'src_encodec_6', 'src_encodec_7', 'tgt_encodec_0', 'tgt_encodec_1', 'tgt_encodec_2', 'tgt_encodec_3', 'tgt_encodec_4', 'tgt_encodec_5', 'tgt_encodec_6', 'tgt_encodec_7'],
    num_rows: 4982
})
[20780, 5, 6086, 2330, 4]
[970, 16, 190, 10, 1104, 3236, 9, 28, 14897, 3699, 11, 5, 50, 15782, 6, 223, 5, 21788, 10873, 3980, 4]
cuda




Device: cuda
Model architecture: NISQA_DIM
Loaded pretrained model from /work/b0990106x/TextRL/NISQA/weights/nisqa.tar




count 0 reward 2.0097849369049072
[INFO] It took 0.11599326133728027 seconds to process the file.


Saving the dataset (0/1 shards):   0%|          | 0/1 [00:00<?, ? examples/s]

AttributeError: 'VcActor' object has no attribute 'middle_model'

In [None]:
agent = actor.agent_ppo(update_interval=100, minibatch_size=3, epochs=10)
pfrl.experiments.train_agent_with_evaluation(
    agent,
    env,
    steps=300,
    eval_n_steps=None,
    eval_n_episodes=1,       
    train_max_episode_len=100,  
    eval_interval=10,
    outdir='elon_musk_dogecoin', 
)

In [None]:
# actor.predict(observaton_list[0])

*2. Environment to Agent*

In [None]:
# reward = env.get_reward()
# print("reward", reward)
env._predict()

In [None]:
# output of agent (wav) + instruction + transcription