In [None]:
import json
import pandas as pd
from tqdm import tqdm
from datasets import Dataset
import pandas as pd
from transformers import AutoTokenizer

import sys
sys.path.append('.')
from utils.data import extract_text_and_spk, format_data


Filter data to only include the training set.

In [None]:
# load testing ids: one id per line, blank lines ignored. Each line is stripped so
# that stray leading/trailing whitespace cannot stop an id from matching, which
# would silently leave that test utterance in the training data.
with open('./data/fisher.txt', 'r') as file:
    fids = [line.strip() for line in file if line.strip()]

# load the json file with the processed data
with open('./data/processed_data.json', 'r') as file:
    res = json.load(file)



In [None]:
# remove the utterances that are in the testing set
# (build a set once so each membership check is O(1) instead of a scan over all ids)
test_ids = set(fids)
res2 = {'utterances': [x for x in res['utterances'] if x['utterance_id'] not in test_ids]}

with open('./data/processed_data_train.json', 'w') as file:
    json.dump(res2, file)


Next, we split the data into prompts and completions with the `train_data_prep.py` script, run from a shell as shown below. Run it before executing the next cell, which reads its output file `./data/prompts_train.jsonl`. The script reads the input data and writes prompts and completions in the format required by the model. It also adds speaker information to the prompts and completions, taken from the `hyp_spk` and `hyp_spk_oracle` fields of the input data, respectively.

```shell
python3 train_data_prep.py \
--input="./data/processed_data_train.json" \
--output="./data/prompts_train.jsonl" \
--output_type=jsonl \
--emit_input_length=2500 \
--emit_target_length=2500 \
--prompt_suffix="" \
--completion_suffix="" \
--input_feature_key="prompt" \
--output_feature_key="completion" \
--text_field="hyp_text" \
--input_speaker_field="hyp_spk" \
--target_speaker_field="hyp_spk_oracle" \
--speaker_prefix="<spk:"
```

In [None]:
# load the train_data
## file is JSON Lines: one JSON object (a dictionary) per line.
## Blank lines (e.g. a trailing empty line) are skipped; json.loads would raise on them.
with open('./data/prompts_train.jsonl', 'r') as file:
    train_data = [json.loads(line) for line in file if line.strip()]


The next steps are needed to preprocess the data for training:

1. Remove prompt/completion pairs whose prompt contains a repeated word/phrase issue, i.e. the same word or short phrase repeated many times in a row.
2. Convert to the instruction format required by the model.
3. Tokenize the data.


In [None]:
# Drop prompts that contain a degenerate repetition: the same word or short phrase
# (up to MAX_PHRASE_LEN words) repeated MIN_REPEATS times consecutively.
MAX_PHRASE_LEN = 3
MIN_REPEATS = 10


def has_repeated_phrase(words, max_phrase_len=MAX_PHRASE_LEN, min_repeats=MIN_REPEATS):
    """Return True if `words` contains a run that looks like a repeated phrase.

    Heuristic: for each phrase length n in 1..max_phrase_len, examine every window
    of n * min_repeats consecutive words. A window containing at most n distinct
    words is treated as a phrase of n words repeated min_repeats times.

    Args:
        words: list of word tokens.
        max_phrase_len: longest phrase length (in words) to check.
        min_repeats: number of consecutive repetitions considered degenerate.

    Returns:
        True if any window matches; False otherwise, including when `words` is
        shorter than every window.
    """
    for phrase_len in range(1, max_phrase_len + 1):
        window = phrase_len * min_repeats
        # `+ 1` so the last window (ending at the final word) is checked as well.
        for start in range(len(words) - window + 1):
            if len(set(words[start:start + window])) <= phrase_len:
                return True
    return False


repeated_words = []  # utterance_id of each dropped prompt
train_data2 = []     # prompts kept for training
for record in tqdm(train_data):
    text, _ = extract_text_and_spk(record['prompt'])
    # Decide per prompt. If several prompts share an utterance_id (e.g. a long
    # utterance split into segments -- TODO confirm), a flagged prompt must not
    # cause later, clean prompts with the same id to be dropped too.
    if has_repeated_phrase(text.split()):
        repeated_words.append(record['utterance_id'])
    else:
        train_data2.append(record)


In [None]:
## convert to prompt and completion
# Fixed instruction placed in front of every prompt. The text (including the ".:"
# before the blank line) is reproduced exactly as used for training.
INSTRUCTION_PREFIX = (
    "In the speaker diarization transcript below, some words are potentially misplaced. "
    "Please correct those words and move them to the right speaker. "
    "Directly show the corrected transcript without explaining what changes were made "
    "or why you made those changes.:\n\n"
)

# Column-oriented dict: instruction[i] pairs with response[i].
dataset_jsonl = {"instruction": [], "response": []}
for record in train_data2:
    dataset_jsonl["instruction"].append(f"{INSTRUCTION_PREFIX}{record['prompt']}")
for record in train_data2:
    dataset_jsonl["response"].append(record['completion'])

In [None]:
# Tokenizer of the base model that will be fine-tuned.
# NOTE(review): recent transformers releases deprecate `use_auth_token` in favor of
# `token`; switch once the pinned transformers version is confirmed.
model_id = "mistralai/Mistral-7B-Instruct-v0.2"
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    use_auth_token=True,
)

# Wrap the instruction/response columns in a Hugging Face Dataset (via pandas).
final_data = pd.DataFrame(dataset_jsonl)
dataset = Dataset.from_pandas(final_data)


In [None]:
def template_dataset(sample):
    """Render one sample with `format_data` and store it, EOS-terminated, under "text".

    The sample dict is updated in place and also returned, as `Dataset.map` expects.
    """
    rendered = format_data(sample)
    sample["text"] = f"{rendered}{tokenizer.eos_token}"
    return sample

# apply the prompt template to every sample; all original columns are dropped so only
# "text" remains
original_columns = list(dataset.features)
dataset = dataset.map(template_dataset, remove_columns=original_columns)


In [None]:
# tokenize dataset: the tokenizer's outputs replace the "text" column
def tokenize_batch(batch):
    """Tokenize a batch of templated samples (batched `Dataset.map` callback)."""
    return tokenizer(batch["text"])

columns_to_drop = list(dataset.features)
lm_dataset = dataset.map(
    tokenize_batch,
    batched=True,
    remove_columns=columns_to_drop,
)

lm_dataset.save_to_disk('./train')
