In [1]:
import torch
from datasets import load_from_disk
from vc.encodec_model.nar_bart_model import NARBartForConditionalGeneration
from vc.trainer_encodec_vc_inference import pack_inputs_v2, get_ar_prediction_v2
from transformers import AutoTokenizer, BartForConditionalGeneration

# load the model
ar_checkpoint = "lca0503/speech-chatgpt-base-ar-v2-epoch10-wotrans"
nar_checkpoint = "lca0503/speech-chatgpt-base-nar-v2-epoch4-wotrans"

device = "cuda" if torch.cuda.is_available() else "cpu"
ar_tokenizer = AutoTokenizer.from_pretrained(ar_checkpoint)
ar_model = BartForConditionalGeneration.from_pretrained(ar_checkpoint)
nar_tokenizer = AutoTokenizer.from_pretrained(nar_checkpoint)
nar_model = NARBartForConditionalGeneration.from_pretrained(nar_checkpoint)
ar_model.to(device)

  from .autonotebook import tqdm as notebook_tqdm


BartForConditionalGeneration(
  (model): BartModel(
    (shared): Embedding(59481, 768, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): Embedding(59481, 768, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 768)
      (layers): ModuleList(
        (0-5): 6 x BartEncoderLayer(
          (self_attn): BartAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (final_layer_norm): LayerNorm((768,), eps=

In [2]:
from datetime import datetime
import os

now = datetime.now()
ts = now.strftime("%m%d-%H%M")
print("timestamp:", ts)

# define the path
base_path = "/work/b0990106x/trl"
agent_input_dir = f"{base_path}/data-encodec"
agent_output_dir = f"{base_path}/output/{ts}"

timestamp: 0619-2002


In [3]:
# load the dataset
dataset = load_from_disk(agent_input_dir)


In [4]:
print(dataset)
data_len = len(dataset)
print("data_len:", data_len)


Dataset({
    features: ['file_id', 'instruction', 'transcription', 'src_encodec_0', 'src_encodec_1', 'src_encodec_2', 'src_encodec_3', 'src_encodec_4', 'src_encodec_5', 'src_encodec_6', 'src_encodec_7', 'tgt_encodec_0', 'tgt_encodec_1', 'tgt_encodec_2', 'tgt_encodec_3', 'tgt_encodec_4', 'tgt_encodec_5', 'tgt_encodec_6', 'tgt_encodec_7'],
    num_rows: 9957
})
data_len: 9957


In [5]:
import json
import sys
from types import SimpleNamespace
sys.path.append("/work/b0990106x/trl/vc") 
from vc.trainer_encodec_vc_inference import get_ar_prediction_v2, get_ar_prediction, get_ar_prediction_for_data
from tqdm import tqdm  # Import tqdm for progress bars

# Assuming `pack_inputs_v2` and `ar_tokenizer` are already defined

observation_list = []
decode_obs_input_str = []
all_src_encodec_layers = []

all_src_encodec = []
all_instruction = []
all_tgt_encodec = []

all_tgt_encodec_layers = []
layer_len = 8

for i in range(layer_len):
    all_src_encodec_layers.append(dataset[f"src_encodec_{i}"])
    all_tgt_encodec_layers.append(dataset[f"tgt_encodec_{i}"])

for i in range(data_len):
    src_encodec = []
    tgt_encodec = []
    for j in range(layer_len):
        src_encodec.append(all_src_encodec_layers[j][i])
        tgt_encodec.append(all_tgt_encodec_layers[j][i])
    all_src_encodec.append(src_encodec)
    all_tgt_encodec.append(tgt_encodec)
    all_instruction.append(dataset["instruction"][i])

    
    size_of_packed_input = (len(all_src_encodec[i][0]) + len(ar_tokenizer(all_instruction[i])["input_ids"][1:-1]) + 3)
    # print("size_of_packed_input:", size_of_packed_input)
    if size_of_packed_input <= 1024 or size_of_packed_input < 4:
        observation_list.append(
            {
                "input": "",
                "src_encodec": [all_src_encodec_layers[j][i] for j in range(layer_len)],
                "instruction": all_instruction[i],
                "tgt_encodec": [all_tgt_encodec_layers[j][i] for j in range(layer_len)],
            }
        )
    else:
        print(f"Notice: Packed input size too large for processing: {size_of_packed_input} elements. Instruction: '{all_instruction[i]}'")


In [6]:
from NISQA.nisqa.NISQA_model import nisqaModel
# Function to calculate reward
def get_reward(output_path):
    args_nisqa = {
        "mode": "predict_file",
        "pretrained_model": f"{base_path}/NISQA/weights/nisqa.tar",
        "deg": output_path,
        "data_dir": None,
        "output_dir": f"{base_path}/NISQA/result/",
        "csv_file": None,
        "csv_deg": None,
        "num_workers": 0,
        "bs": 1,
        "ms_channel": None,
    }
    args_nisqa["tr_bs_val"] = args_nisqa["bs"]
    args_nisqa["tr_num_workers"] = args_nisqa["num_workers"]
    nisqa = nisqaModel(args_nisqa)
    try:
        prediction = nisqa.predict()
        reward = float(prediction["mos_pred"].iloc[0])
        print("Reward:", reward)
        return reward
    except Exception as e:
        print("Error:", e)
        print("get_reward function end ___________________________")
        return None


In [7]:
# List to hold the 'prompt' values
prompts = []
chosen = []
rejected = []

args = SimpleNamespace(output_path=f"{base_path}/output/{ts}/example.wav", seed=0, device=device)

for obs in tqdm(observation_list, desc="Processing observations"):
    # for prompt
    obs_input = pack_inputs_v2(ar_tokenizer, obs["src_encodec"], obs["instruction"])
    tokenize_input= ar_tokenizer.convert_ids_to_tokens(obs_input)
    tokenize_input_str = ar_tokenizer.convert_tokens_to_string(tokenize_input)
    prompts.append(tokenize_input_str)
    
    # for chosen
    tgt_encodec = obs["tgt_encodec"]
    # tgt_ids = ar_tokenizer.convert_tokens_to_ids([f"v_tok_{u}" for u in tgt_encodec[0]])
    tokenize_tgt_encodec = ar_tokenizer.convert_tokens_to_string(
                [f"v_tok_{u}" for u in tgt_encodec[0]]
            ) 
    # tgt_ids = ar_tokenizer(tokenize_tgt_encodec)["input_ids"][1:-1]
    # print("tokenize_tgt_encodec:", tokenize_tgt_encodec)   
    chosen.append(tokenize_tgt_encodec)

    # for rejected
    single_src_encodec = obs["src_encodec"]
    single_instruction = obs["instruction"]
    try:
        decode_ar = get_ar_prediction_for_data(args, ar_model, ar_tokenizer, single_src_encodec, single_instruction)
    except Exception as e:
        print("single_src_encodec:", single_src_encodec)
        print("single_instruction:", single_instruction)
        print(e)
        break
    decode_ar_list = decode_ar.flatten().tolist()
    filtered_decode_ar_list = decode_ar_list[2:-1]
    decode_ar_tokens = ar_tokenizer.convert_ids_to_tokens(filtered_decode_ar_list)
    tokenized_decode_ar = ar_tokenizer.convert_tokens_to_string(decode_ar_tokens)
    # print("tokenized_decode_ar:", tokenized_decode_ar)
    rejected.append(tokenized_decode_ar)

    # # TEST REWARD
    # args_predict = SimpleNamespace(output_path=f"{base_path}/output/{ts}/example.wav", seed=0, device=device)
    # temp1 = get_ar_prediction_v2(args_predict, tgt_ids, nar_model, ar_tokenizer, nar_tokenizer, single_src_encodec, single_instruction, 0)
    # print("GOOD:")
    # get_reward(args_predict.output_path)

    # temp2 = get_ar_prediction_v2(args_predict, filtered_decode_ar_list , nar_model, ar_tokenizer, nar_tokenizer, single_src_encodec, single_instruction,2)
    # print("BAD")
    # get_reward(args_predict.output_path)
# Construct the JSON structure
data = {
    "prompt": prompts,
    "chosen": chosen,  # Placeholder for chosen responses
    "rejected": rejected  # Placeholder for rejected responses
}

# Save the JSON to a file
with open("dpo_data_all_v2.json", "w") as outfile:
    json.dump(data, outfile, indent=4)

Processing observations:   0%|          | 0/3 [00:00<?, ?it/s]

tokenize_tgt_encodec: v_tok_408v_tok_835v_tok_835v_tok_798v_tok_585v_tok_550v_tok_535v_tok_535v_tok_737v_tok_737v_tok_377v_tok_556v_tok_601v_tok_787v_tok_8v_tok_99v_tok_411v_tok_411v_tok_378v_tok_937v_tok_378v_tok_937v_tok_804v_tok_838v_tok_890v_tok_934v_tok_47v_tok_438v_tok_438v_tok_731v_tok_738v_tok_133v_tok_709v_tok_479v_tok_479v_tok_479v_tok_151v_tok_940v_tok_502v_tok_906v_tok_407v_tok_645v_tok_70v_tok_208v_tok_537v_tok_537v_tok_1022v_tok_681v_tok_723v_tok_747v_tok_593v_tok_804v_tok_681v_tok_879v_tok_136v_tok_967v_tok_233v_tok_431v_tok_754v_tok_421v_tok_182v_tok_182v_tok_651v_tok_879v_tok_887v_tok_819v_tok_904v_tok_904v_tok_887v_tok_309v_tok_880v_tok_396v_tok_754v_tok_775v_tok_997v_tok_222v_tok_336v_tok_548v_tok_841v_tok_269v_tok_479v_tok_479v_tok_940v_tok_23v_tok_56v_tok_738v_tok_835v_tok_395v_tok_206v_tok_779v_tok_531v_tok_862v_tok_931v_tok_306v_tok_203v_tok_755v_tok_369v_tok_6v_tok_466v_tok_716v_tok_948v_tok_82v_tok_575v_tok_288v_tok_556v_tok_903v_tok_556v_tok_392v_tok_796v_tok_

Processing observations:  33%|███▎      | 1/3 [00:04<00:09,  4.60s/it]

tokenize_tgt_encodec: v_tok_408v_tok_835v_tok_835v_tok_126v_tok_276v_tok_677v_tok_666v_tok_460v_tok_996v_tok_682v_tok_924v_tok_448v_tok_816v_tok_704v_tok_976v_tok_855v_tok_347v_tok_1021v_tok_552v_tok_301v_tok_301v_tok_448v_tok_186v_tok_611v_tok_574v_tok_438v_tok_438v_tok_133v_tok_430v_tok_856v_tok_877v_tok_800v_tok_830v_tok_997v_tok_997v_tok_358v_tok_850v_tok_233v_tok_233v_tok_846v_tok_951v_tok_666v_tok_62v_tok_62v_tok_834v_tok_835v_tok_63v_tok_602v_tok_887v_tok_935v_tok_339v_tok_834v_tok_430v_tok_835v_tok_492v_tok_699v_tok_462v_tok_465v_tok_1018v_tok_642v_tok_143v_tok_953v_tok_424v_tok_598v_tok_860v_tok_724v_tok_999v_tok_887v_tok_835v_tok_803v_tok_151v_tok_791v_tok_25v_tok_875v_tok_424v_tok_901v_tok_813v_tok_424v_tok_598v_tok_321v_tok_679v_tok_457v_tok_901v_tok_1000v_tok_552v_tok_695v_tok_695v_tok_501v_tok_432v_tok_876v_tok_1019v_tok_910v_tok_839v_tok_261v_tok_481v_tok_481v_tok_695v_tok_695v_tok_51v_tok_162v_tok_904v_tok_865v_tok_629v_tok_62v_tok_835v_tok_835v_tok_372v_tok_764v_tok_74

Processing observations:  67%|██████▋   | 2/3 [00:07<00:03,  3.75s/it]

tokenize_tgt_encodec: v_tok_408v_tok_835v_tok_339v_tok_339v_tok_604v_tok_324v_tok_230v_tok_600v_tok_771v_tok_422v_tok_846v_tok_747v_tok_457v_tok_393v_tok_833v_tok_782v_tok_411v_tok_411v_tok_479v_tok_1021v_tok_906v_tok_151v_tok_495v_tok_563v_tok_611v_tok_611v_tok_151v_tok_727v_tok_317v_tok_347v_tok_475v_tok_835v_tok_835v_tok_835v_tok_339v_tok_475v_tok_339v_tok_123v_tok_254v_tok_103v_tok_182v_tok_784v_tok_912v_tok_755v_tok_375v_tok_261v_tok_435v_tok_951v_tok_323v_tok_709v_tok_819v_tok_475v_tok_339v_tok_835v_tok_779v_tok_257v_tok_339v_tok_341v_tok_254v_tok_38v_tok_38v_tok_103v_tok_121v_tok_62v_tok_141v_tok_731v_tok_73v_tok_651v_tok_563v_tok_321v_tok_860v_tok_325v_tok_325v_tok_679v_tok_696v_tok_582v_tok_613v_tok_216v_tok_683v_tok_291v_tok_11v_tok_862v_tok_627v_tok_666v_tok_764v_tok_679v_tok_291v_tok_501v_tok_451v_tok_501v_tok_198v_tok_112v_tok_392v_tok_348v_tok_793v_tok_793v_tok_11v_tok_192v_tok_23v_tok_402v_tok_1022v_tok_276v_tok_73v_tok_73v_tok_887v_tok_25v_tok_103v_tok_148v_tok_148v_tok

Processing observations: 100%|██████████| 3/3 [00:09<00:00,  3.08s/it]


In [8]:
# import json
# if not os.path.exists(agent_output_dir):
#     os.makedirs(agent_output_dir)

# with open("dpo_data_1.json") as f:
#     dpo_data = json.load(f)

# # Step 2: Parse the JSON content
# prompts = dpo_data['prompt']
# chosen = dpo_data['chosen']
# rejected = dpo_data['rejected']
# args_predict = SimpleNamespace(output_path=f"{base_path}/output/{ts}/example.wav", seed=0, device=device)

# # Step 3: Access each item in the data

# chosen_data = chosen[0]
# rejected_data = rejected[0]
# single_src_encodec = all_src_encodec[0]
# single_instruction = all_instruction[0]

# print("single_src_encodec:", single_src_encodec)
# print("single_instruction:", single_instruction)

# chosen_data_id = ar_tokenizer(chosen_data)["input_ids"][1:-1]
# rejected_data_id = ar_tokenizer(rejected_data)["input_ids"][1:-1]

# print("chosen_data:", chosen_data)
# print("rejected_data:", rejected_data)
# print("chosen_data_id:", chosen_data_id)
# print("rejected_data_id:", rejected_data_id)

# temp1 = get_ar_prediction_v2(args_predict, chosen_data_id, nar_model, ar_tokenizer, nar_tokenizer, single_src_encodec, single_instruction, 0)
# print("GOOD:")
# get_reward(args_predict.output_path)

# temp2 = get_ar_prediction_v2(args_predict, rejected_data_id , nar_model, ar_tokenizer, nar_tokenizer, single_src_encodec, single_instruction, 1)
# print("BAD")
# get_reward(args_predict.output_path)


single_src_encodec: [[408, 835, 835, 798, 585, 550, 535, 535, 737, 737, 377, 556, 601, 787, 8, 99, 411, 411, 378, 937, 378, 937, 804, 838, 890, 934, 47, 438, 438, 731, 738, 133, 709, 479, 479, 479, 151, 940, 502, 906, 407, 645, 70, 208, 537, 537, 1022, 681, 723, 747, 593, 804, 681, 879, 136, 967, 233, 431, 754, 421, 182, 182, 651, 879, 887, 819, 904, 904, 887, 309, 880, 396, 754, 775, 997, 222, 336, 548, 841, 269, 479, 479, 940, 23, 56, 738, 835, 395, 206, 779, 531, 862, 931, 306, 203, 755, 369, 6, 466, 716, 948, 82, 575, 288, 556, 903, 556, 392, 796, 751, 835, 103, 25, 408, 835, 835, 339, 339, 395, 250, 706, 317, 479, 800, 960, 141, 479, 908, 801, 327, 937, 559, 708, 372, 372, 573, 437, 437, 421, 203, 739, 830, 739, 358, 830, 248, 411, 411, 112, 321, 23, 23, 185, 971, 62, 339, 461, 488, 934, 148, 373, 561, 681, 760, 531, 612, 699, 23, 967, 457, 790, 154, 906, 465, 502, 884, 479, 246, 820, 601, 309, 716, 314, 377, 309, 309, 556, 118, 99, 358, 1018, 862, 779, 62, 835, 25, 254, 254, 677,



Episode 0 : audio saved to  /work/b0990106x/trl/output/0619-2002/example_save_0.wav
GOOD:


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)


Reward: 3.2222368717193604




Episode 1 : audio saved to  /work/b0990106x/trl/output/0619-2002/example_save_1.wav
BAD
Reward: 2.7825844287872314


  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)


2.7825844287872314