In [35]:
import json

hyps = []
refs = []

def extract_multiple_predictions_to_tsv(json_file_path, tsv_file_path):
    """
    Reads a file containing multiple lines of JSON, each containing a prediction and a reference,
    and writes these to a TSV file with 'hyp' and 'tgt' columns.

    Args:
    - json_file_path (str): The path of the file containing the JSON lines.
    - tsv_file_path (str): The path where the TSV file will be saved.

    Returns:
    - None
    """
    with open(json_file_path, 'r', encoding='utf-8') as json_file, \
         open(tsv_file_path, 'w', encoding='utf-8') as tsv_file:
        # Write the headers to the TSV file
        tsv_file.write("hyp\ttgt\n")

        for line in json_file:
            data = json.loads(line)  

            prediction = data["prediction"]
            reference = data["reference"]

            hyps.append(prediction)
            refs.append(reference)

            tsv_file.write(f"{prediction}\t{reference}\n")

# Example usage:
json_file_path = "/mnt/taurus/data1/xixu/runs/sllama/wavlm_clean/stage2/checkpoint-1900/sanity_check/320ms-1001/instances.log"  
tsv_file_path = "sanity_check.tsv"  

# Call the function with the paths
extract_multiple_predictions_to_tsv(json_file_path, tsv_file_path)

print(f"Data has been written to {tsv_file_path}")

Data has been written to sanity_check.tsv


In [36]:
import csv
import sacrebleu

def calculate_bleu_from_tsv(tsv_file_path):
    """
    Calculates the BLEU score for hypotheses and references contained in a TSV file.

    Args:
    - tsv_file_path (str): The path to the TSV file with 'hyp' and 'tgt' columns.

    Returns:
    - The BLEU score as a float.
    """

    global hyps, refs

    # hyps = []  # To store all hypotheses
    # refs = []  # To store all references

    # # Read the hypotheses and references from the TSV file
    # with open(tsv_file_path, 'r', encoding='utf-8') as tsv_file:
    #     reader = csv.DictReader(tsv_file, delimiter='\t')
    #     for row in reader:
    #         hyps.append(row['hyp'])
    #         refs.append('' if row['tgt'] is None else row['tgt'])  # Note: refs must be a list of lists

    # Calculate the BLEU score
    bleu_score = sacrebleu.corpus_bleu(hyps, [refs])

    return bleu_score.score

# Example usage:
tsv_file_path = "sanity_check.tsv"  # The path to your TSV file
bleu_score = calculate_bleu_from_tsv(tsv_file_path)

print(f"BLEU score: {bleu_score}")

BLEU score: 26.785563885367445


In [3]:
import csv
import subprocess
from fairseq import scoring


def prepare_data_for_fairseq(tsv_file_path, hyp_file_path, ref_file_path):
    """
    Splits the TSV file into separate hypothesis and reference files for fairseq evaluation,
    handling cases where the data might be missing.

    Args:
    - tsv_file_path (str): Path to the TSV file with 'hyp' and 'tgt' columns.
    - hyp_file_path (str): Path to save the hypothesis file.
    - ref_file_path (str): Path to save the reference file.
    """
    with open(tsv_file_path, 'r', encoding='utf-8') as tsv_file, \
         open(hyp_file_path, 'w', encoding='utf-8') as hyp_file, \
         open(ref_file_path, 'w', encoding='utf-8') as ref_file:
        reader = csv.DictReader(tsv_file, delimiter='\t')
        for row in reader:
            # Ensure that 'hyp' and 'tgt' are treated as empty strings if they are None
            hyp_text = row['hyp'] if row['hyp'] is not None else ""
            tgt_text = row['tgt'] if row['tgt'] is not None else ""
            hyp_file.write(hyp_text + '\n')
            ref_file.write(tgt_text + '\n')


def calculate_bleu_with_fairseq_python(hyp_file_path, ref_file_path):
    """
    Calculates the BLEU score using Fairseq's Python API.

    Args:
    - hyp_file_path (str): Path to the hypothesis file.
    - ref_file_path (str): Path to the reference file.

    Returns:
    - The BLEU score as a string.
    """
    # Initialize the scorer
    scorer = scoring.build_scorer("sacrebleu", None)
    
    # Read hypotheses and references
    with open(hyp_file_path, 'r', encoding='utf-8') as hyp_file, \
         open(ref_file_path, 'r', encoding='utf-8') as ref_file:
        hyps = [line.strip() for line in hyp_file]
        refs = [line.strip() for line in ref_file]
    
    for hyp, ref in zip(hyps, refs):
        scorer.add_string(ref, hyp)
    
    return scorer.result_string()



tsv_file_path = "sanity_check.tsv"
hyp_file_path = "hypotheses.txt"
ref_file_path = "references.txt"

prepare_data_for_fairseq(tsv_file_path, hyp_file_path, ref_file_path)

bleu_score_output = calculate_bleu_with_fairseq_python(hyp_file_path, ref_file_path)
print(bleu_score_output)

[2024-04-11 15:58:39,998] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)




BLEU = 28.10 61.8/35.5/22.8/15.2 (BP = 0.950 ratio = 0.952 hyp_len = 49110 ref_len = 51605)


In [13]:
import conversation as conversation_lib
import transformers

conv = conversation_lib.default_conversation.copy()
conv.messages = []
conv.append_message(conv.roles[0], None)
conv.append_message(conv.roles[1], None)
prompt_inputs = conv.get_prompt()
print(prompt_inputs)
tokenizer = transformers.AutoTokenizer.from_pretrained(
    '/mnt/data1/xixu/runs/sllama/wavlm_clean/stage2/checkpoint-1900',
    padding_side="right",
    use_fast=False,
)
input_ids1 = tokenizer([prompt_inputs])
input_ids = tokenizer(prompt_inputs, return_tensors="pt", padding=True, truncation=True)
print(input_ids1)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


You are a large language and speech assistant.You are able to understand the speech content that the user provides, and assist the user with a variety of tasks using natural language.Follow the instructions carefully and explain your answers in detail. USER:ASSISTANT:
{'input_ids': [[1, 887, 526, 263, 2919, 4086, 322, 12032, 20255, 29889, 3492, 526, 2221, 304, 2274, 278, 12032, 2793, 393, 278, 1404, 8128, 29892, 322, 6985, 278, 1404, 411, 263, 12875, 310, 9595, 773, 5613, 4086, 29889, 29943, 2952, 278, 11994, 16112, 322, 5649, 596, 6089, 297, 9493, 29889, 3148, 1001, 29901, 22933, 9047, 13566, 29901]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}
