In [1]:
import torch
from datasets import load_from_disk
from vc.encodec_model.nar_bart_model import NARBartForConditionalGeneration
from vc.trainer_encodec_vc_inference import pack_inputs_v2
from transformers import AutoTokenizer, BartForConditionalGeneration

# Hub identifiers of the two checkpoints used in this notebook: one for the
# autoregressive (AR) model and one for the non-autoregressive (NAR) model.
ar_checkpoint = "lca0503/speech-chatgpt-base-ar-v2-epoch10-wotrans"
nar_checkpoint = "lca0503/speech-chatgpt-base-nar-v2-epoch4-wotrans"

# Use the GPU when one is visible to torch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# AR side: tokenizer and BART model come from the same checkpoint.
ar_tokenizer = AutoTokenizer.from_pretrained(ar_checkpoint)
ar_model = BartForConditionalGeneration.from_pretrained(ar_checkpoint)

# NAR side: tokenizer and NAR-BART model come from the same checkpoint.
nar_tokenizer = AutoTokenizer.from_pretrained(nar_checkpoint)
nar_model = NARBartForConditionalGeneration.from_pretrained(nar_checkpoint)

# NOTE(review): only the AR model is moved to `device` here; `nar_model` is not
# moved by this cell -- confirm that is intended.
# Left as the last expression of the cell, so the notebook displays its value.
ar_model.to(device)

  from .autonotebook import tqdm as notebook_tqdm


BartForConditionalGeneration(
  (model): BartModel(
    (shared): Embedding(59481, 768, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): Embedding(59481, 768, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 768)
      (layers): ModuleList(
        (0-5): 6 x BartEncoderLayer(
          (self_attn): BartAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (final_layer_norm): LayerNorm((768,), eps=

In [2]:
from datetime import datetime
import os

# Timestamp in month-day / hour-minute form ("%m%d-%H%M"); it names this run's
# output folder below.
now = datetime.now()
ts = now.strftime("%m%d-%H%M")
print("timestamp:", ts)

# Directory layout, all rooted at `base_path`.
base_path = "/work/b0990106x/trl"
agent_input_dir = base_path + "/data-encodec"
agent_output_dir = "/".join([base_path, "output", ts])

# The env_* directories mirror the agent_* ones with input and output swapped.
env_input_dir, env_output_dir = agent_output_dir, agent_input_dir

# Creation of the output directory is currently disabled in this notebook:
# if not os.path.exists(agent_output_dir):
#     os.makedirs(agent_output_dir)

timestamp: 0613-1859


In [3]:
# Load the dataset previously saved to disk under `agent_input_dir`.
# The result is bound to `dataset`, which the following cells read from.
dataset = load_from_disk(agent_input_dir)


In [4]:
# Record the number of rows once; later cells iterate over range(data_len).
data_len = len(dataset)

# Show the dataset summary followed by the row count.
print(dataset)
print("data_len:", data_len)


Dataset({
    features: ['file_id', 'instruction', 'transcription', 'src_encodec_0', 'src_encodec_1', 'src_encodec_2', 'src_encodec_3', 'src_encodec_4', 'src_encodec_5', 'src_encodec_6', 'src_encodec_7', 'tgt_encodec_0', 'tgt_encodec_1', 'tgt_encodec_2', 'tgt_encodec_3', 'tgt_encodec_4', 'tgt_encodec_5', 'tgt_encodec_6', 'tgt_encodec_7'],
    num_rows: 9957
})
data_len: 9957


In [5]:
import json
import sys
from types import SimpleNamespace
sys.path.append("/work/b0990106x/trl/vc") 
from vc.trainer_encodec_vc_inference import get_ar_prediction_v2, get_ar_prediction

# Reshape the column-oriented dataset into per-sample records.
#
# Relies on `dataset` and `data_len` defined in earlier cells. Produces:
#   all_src_encodec_layers / all_tgt_encodec_layers : one entry per codec layer,
#       each holding that layer's column for the whole dataset
#   all_src_encodec / all_tgt_encodec : one entry per sample, each a list of
#       `layer_len` per-layer code sequences
#   all_instruction  : one instruction per sample
#   observation_list : one dict per sample bundling the three items above
#
# NOTE: a packed-input size check (source codes + instruction tokens + 3 had to
# be between 4 and 1024) is not applied in this cell, so over-long or very short
# samples are not filtered out here.

decode_obs_input_str = []  # not populated in this cell

# Number of src_encodec_* / tgt_encodec_* columns read for every sample.
layer_len = 8
# data_len = len(dataset)  # Assuming you have a defined `data_len`

# Read every column exactly once. Indexing `dataset[<column name>]` inside a
# per-sample loop would fetch the entire column again on each iteration.
all_src_encodec_layers = [dataset[f"src_encodec_{j}"] for j in range(layer_len)]
all_tgt_encodec_layers = [dataset[f"tgt_encodec_{j}"] for j in range(layer_len)]
instruction_column = dataset["instruction"]

# Transpose layer-major columns into sample-major rows.
all_src_encodec = [
    [all_src_encodec_layers[j][i] for j in range(layer_len)]
    for i in range(data_len)
]
all_tgt_encodec = [
    [all_tgt_encodec_layers[j][i] for j in range(layer_len)]
    for i in range(data_len)
]
all_instruction = [instruction_column[i] for i in range(data_len)]

# One observation per sample. The outer lists are copied so an observation does
# not share its top-level list object with all_src_encodec / all_tgt_encodec.
observation_list = [
    {
        "input": "",
        "src_encodec": list(all_src_encodec[i]),
        "instruction": all_instruction[i],
        "tgt_encodec": list(all_tgt_encodec[i]),
    }
    for i in range(data_len)
]


In [6]:
# Build a preference-style dataset with three parallel columns and save it as
# JSON:
#   prompt   - the packed model input, decoded back into a string
#   chosen   - the first-layer target codes rendered as "v_tok_<id>" tokens
#   rejected - the ids predicted by the model for the same input, rendered the
#              same way
prompts = []
chosen = []
rejected = []

for obs in observation_list:
    obs_input = pack_inputs_v2(ar_tokenizer, obs["src_encodec"], obs["instruction"])
    tgt_encodec = obs["tgt_encodec"]

    # Only the first target layer (index 0) is used for the "chosen" string.
    tokenize_tgt_encodec = ar_tokenizer.convert_tokens_to_string(
        [f"v_tok_{u}" for u in tgt_encodec[0]]
    )
    tokenize_input = ar_tokenizer.convert_ids_to_tokens(obs_input)
    tokenize_input_str = ar_tokenizer.convert_tokens_to_string(tokenize_input)
    prompts.append(tokenize_input_str)
    chosen.append(tokenize_tgt_encodec)

# NOTE(review): `output_path` points inside the timestamped output directory,
# which this notebook does not create -- confirm whether get_ar_prediction
# writes to it.
args = SimpleNamespace(output_path=f"{base_path}/output/{ts}/example.wav", seed=0, device=device)

for i in range(data_len):
    single_src_encodec = all_src_encodec[i]
    single_instruction = all_instruction[i]
    try:
        predicted_ids, decode_ar = get_ar_prediction(args, ar_model, nar_model, ar_tokenizer, nar_tokenizer, single_src_encodec, single_instruction, episode_counter=0)
    except Exception as e:
        # Report the failing sample and stop generating; everything produced so
        # far is still saved below.
        print("i:", i)
        print("single_src_encodec:", single_src_encodec)
        print("single_instruction:", single_instruction)
        print(e)
        break
    decode_ar_str = ar_tokenizer.convert_tokens_to_string(
        [f"v_tok_{u}" for u in predicted_ids]
    )
    rejected.append(decode_ar_str)

# The generation loop may stop early, leaving `rejected` shorter than
# `prompts` / `chosen`. Entries are matched by position, so keep only the
# leading samples for which all three values exist.
num_pairs = min(len(prompts), len(chosen), len(rejected))
if num_pairs < len(prompts):
    print(f"Warning: only {num_pairs} of {len(prompts)} samples have a generated response; saving those.")

# Construct the JSON structure
data = {
    "prompt": prompts[:num_pairs],
    "chosen": chosen[:num_pairs],  # first-layer target codes
    "rejected": rejected[:num_pairs],  # model predictions
}

# Save the JSON to a file
with open("dpo_data_all.json", "w") as outfile:
    json.dump(data, outfile, indent=4)