In [1]:
!pip install -qqq -U transformers datasets huggingface_hub accelerate bitsandbytes --progress-bar off
!FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE pip install -qqq -U flash-attn --no-build-isolation pip install flash-attn --progress-bar off

In [2]:
import huggingface_hub

# Interactive Hugging Face Hub authentication (token prompt widget).
# NOTE(review): presumably needed because meta-llama/Meta-Llama-3-8B is a gated repo — confirm.
login = huggingface_hub.login
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [3]:
import transformers
import torch
from peft import PeftModel

# Pick the compute dtype and attention kernel from the GPU's compute capability:
# capability >= 8 selects bfloat16 + FlashAttention-2, anything else falls back to
# float16 + eager attention. flash-attn itself is installed in the first cell, so it
# is not reinstalled here. The is_available() guard avoids a CUDA error on CPU-only kernels.
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
    torch_dtype = torch.bfloat16
    attn_implementation = "flash_attention_2"
else:
    torch_dtype = torch.float16
    attn_implementation = "eager"

base_model = "meta-llama/Meta-Llama-3-8B"
# Fine-tuned QLoRA adapter; its directory also provides the tokenizer.
new_model = "Meta-Llama-3-8B-qlora-translation-no-tag"

# 4-bit NF4 quantization with double quantization; matmuls run in torch_dtype.
bnb_config = transformers.BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch_dtype,
    bnb_4bit_use_double_quant=True,
)

# Load the tokenizer from the adapter rather than the base model: it may carry tokens
# added during fine-tuning (the load logs "Special tokens have been added in the vocabulary").
tokenizer = transformers.AutoTokenizer.from_pretrained(new_model)

model = transformers.AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    device_map="auto",
    attn_implementation=attn_implementation,
    # Load the non-quantized modules in the same dtype as the 4-bit compute dtype;
    # FlashAttention-2 only runs in fp16/bf16.
    torch_dtype=torch_dtype,
)
# Attach the LoRA adapter, then fold it into the base weights for plain inference.
# NOTE(review): merging into 4-bit weights can introduce rounding differences vs. the
# unmerged adapter — confirm generations match if exactness matters.
model = PeftModel.from_pretrained(model, new_model)
model = model.merge_and_unload()

pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]



In [4]:
from datasets import load_dataset, interleave_datasets

def get_dataset(num_shots=10, seed=42):
    """Build few-shot translation prompts and the interleaved evaluation set.

    For each language pair, the train split is shuffled with ``seed`` and its
    first ``num_shots`` examples become the in-context demonstrations. Each
    demonstration is formatted as ``input:<source>\\noutput:<target>\\n``. Every
    test example is then prefixed with the demonstrations for its own pair and
    left open after ``output:`` so the model completes the translation.

    Args:
        num_shots: number of train examples used as demonstrations per pair.
        seed: shuffle seed used before picking the demonstrations.

    Returns:
        prompts: dict mapping pair name (``"eng-yue"``, ``"yue-cmn"``) to its
            few-shot prefix string.
        eval_dataset: the test splits of all pairs interleaved, with a single
            ``inputs`` column holding the complete prompt text.
        tag_name_dict: mapping from language code to its ``"<code>:"`` tag.
            The prompts built here do not use it.
    """
    single_lang = ["eng", "yue", "cmn"]
    # Hub config names, paired with the translation direction evaluated for each.
    # Note that the "cmn-yue" config is evaluated in the yue -> cmn direction.
    lang_datasets = ["eng-yue", "cmn-yue"]
    lang_pairs = ["eng-yue", "yue-cmn"]

    tag_name_dict = {lang: f'{lang}:' for lang in single_lang}

    prompts = {}
    lm_datasets_test = []
    for lang_dataset, lang_pair in zip(lang_datasets, lang_pairs):
        # Load each config once; it supplies both the demonstrations and the test set.
        lm_dataset = load_dataset("AlienKevin/yue-cmn-eng", lang_dataset)
        source_lang, target_lang = lang_pair.split("-")

        demos = lm_dataset["train"].shuffle(seed=seed).select(range(num_shots))
        prompt = ''.join(
            'input:' + pair[source_lang] + '\n' + 'output:' + pair[target_lang] + '\n'
            for pair in demos["translation"]
        )
        prompts[lang_pair] = prompt

        # map() runs eagerly inside this iteration, so the closure sees this pair's values.
        def preprocess_eval(examples):
            return {
                "inputs": [
                    prompt + 'input:' + pair[source_lang] + '\n' + 'output:'
                    for pair in examples["translation"]
                ]
            }

        lm_datasets_test.append(
            lm_dataset["test"].map(preprocess_eval, batched=True, remove_columns=['translation'])
        )

    print(prompts)

    eval_dataset = interleave_datasets(lm_datasets_test)
    return prompts, eval_dataset, tag_name_dict

In [5]:
# Few-shot prefixes per language pair, the interleaved evaluation set, and the language-tag map.
dataset_bundle = get_dataset()
prompts, eval_dataset, tag_name_dict = dataset_bundle

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/82.6k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/624k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/1500 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/11504 [00:00<?, ? examples/s]

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

{'eng-yue': "input:Please don't put toilet paper into the urinal, so as to avoid clogging it, thanks for your cooperation.\noutput:請勿將廁紙放在尿兜内，以免淤塞，多謝合作。\ninput:This guy is very greedy for money; he was caught stealing money from his company before.\noutput:呢條友好貪錢㗎，之前俾人發現佢偷公司錢。\ninput:nostril\noutput:鼻哥窿\ninput:As your informer, I'll certainly pass on any information to you.\noutput:我做得你條針，實會過料畀你。\ninput:This website was designed by me.\noutput:呢個係我自己設計嘅網站。\ninput:to see the world\noutput:見世面\ninput:Mum! Are you fine?\noutput:媽！你有冇事啊？\ninput:I am becoming clumsier as I get older.\noutput:我老咗做嘢係論盡啲。\ninput:This shirt doesn't have even one pocket.\noutput:呢件裇衫一個衫袋都冇。\ninput:a rain shower\noutput:一陣雨\n", 'yue-cmn': 'input:見衫係紅色嘅\noutput:衣服是红色的\ninput:嗰個地方好多時有賊劏死牛，冇乜事唔好行去嗰度\noutput:那个地方经常有贼烂路抢劫，没什么事不要走到那儿去\ninput:唔成功都唔使心淡吖\noutput:不成功也不用著心灰意冷\ninput:睇呢啲濕星嘢你要唔要呀\noutput:看你要不要这些琐碎的东西\ninput:叫你整闊啲，你又闊過龍\noutput:叫你弄宽点儿，你又弄得太宽了\ninput:我點會笑你喎，我自己都唔叻得去邊\noutput:我怎么会笑你的，我自己也聪明不到哪里去\ninput:呢幾張枱由我包起\

Map:   0%|          | 0/1500 [00:00<?, ? examples/s]

Map:   0%|          | 0/1500 [00:00<?, ? examples/s]

In [6]:
# Sanity-check the evaluation set: a single "inputs" column holding the full few-shot prompts.
eval_dataset

Dataset({
    features: ['inputs'],
    num_rows: 3000
})

In [7]:
# Peek at the first two prompts — one per language pair, since the test splits are interleaved.
eval_dataset[0:2]

{'inputs': ["input:Please don't put toilet paper into the urinal, so as to avoid clogging it, thanks for your cooperation.\noutput:請勿將廁紙放在尿兜内，以免淤塞，多謝合作。\ninput:This guy is very greedy for money; he was caught stealing money from his company before.\noutput:呢條友好貪錢㗎，之前俾人發現佢偷公司錢。\ninput:nostril\noutput:鼻哥窿\ninput:As your informer, I'll certainly pass on any information to you.\noutput:我做得你條針，實會過料畀你。\ninput:This website was designed by me.\noutput:呢個係我自己設計嘅網站。\ninput:to see the world\noutput:見世面\ninput:Mum! Are you fine?\noutput:媽！你有冇事啊？\ninput:I am becoming clumsier as I get older.\noutput:我老咗做嘢係論盡啲。\ninput:This shirt doesn't have even one pocket.\noutput:呢件裇衫一個衫袋都冇。\ninput:a rain shower\noutput:一陣雨\ninput:This is really amusing, a radio controlled car that can climb on walls.\noutput:",
  'input:見衫係紅色嘅\noutput:衣服是红色的\ninput:嗰個地方好多時有賊劏死牛，冇乜事唔好行去嗰度\noutput:那个地方经常有贼烂路抢劫，没什么事不要走到那儿去\ninput:唔成功都唔使心淡吖\noutput:不成功也不用著心灰意冷\ninput:睇呢啲濕星嘢你要唔要呀\noutput:看你要不要这些琐碎的东西\ninput:叫你整闊啲，你又闊過龍\noutput:叫你弄宽点儿

In [9]:
# https://huggingface.co/PygmalionAI/pygmalion-6b/discussions/25#64387bf26c8841ba74e7d9c0
from transformers import StoppingCriteria

class TranslationStoppingCriteria(StoppingCriteria):
    """Stop generating a row as soon as it ends with a newline.

    Every evaluation prompt ends with ``output:`` and each few-shot
    demonstration fits on one ``output:`` line. The first generated newline
    therefore marks the end of the translation.

    Args:
        prompts: mapping from language pair to its few-shot prefix. It is kept
            for backward compatibility and inspection only. The stop decision
            depends solely on the tail of the sequence, so stripping these
            prefixes before the check could never change the result.
        tokenizer: tokenizer used to decode token ids. Defaults to the
            notebook-level ``tokenizer`` global when omitted.
    """

    def __init__(self, prompts, tokenizer=None):
        self.prompts = prompts
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs):
        """Return a BoolTensor of shape (batch,): True where the row now ends with '\\n'."""
        import torch

        tok = self.tokenizer if self.tokenizer is not None else tokenizer
        # Decoding just the last token is enough to detect a trailing newline.
        # It avoids re-decoding the whole few-shot prompt on every generation step.
        done = [tok.decode(row[-1:]).endswith('\n') for row in input_ids]
        return torch.tensor(done, dtype=torch.bool, device=input_ids.device)

    # These let an instance be passed directly as `stopping_criteria=`,
    # by behaving like a one-element criteria list.
    def __len__(self):
        return 1

    def __iter__(self):
        yield self

In [10]:
from transformers.pipelines.pt_utils import KeyDataset
import json
from tqdm import tqdm

# Not used below; kept in case later cells still reference it.
lang_tags = ["eng:", "yue:", "cmn:"]

# Decoding samples (do_sample=True), so fix the RNG seeds to make the evaluation reproducible.
transformers.set_seed(42)

# Given a dataset, the pipeline yields results lazily.
# Generation actually runs when `outputs` is iterated.
outputs = pipeline(
    KeyDataset(eval_dataset, "inputs"),
    max_new_tokens=128,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
    # Pass a real StoppingCriteriaList rather than relying on the criterion
    # impersonating a list.
    stopping_criteria=transformers.StoppingCriteriaList([TranslationStoppingCriteria(prompts)]),
    # Use EOS as the padding id during generation.
    pad_token_id=tokenizer.eos_token_id,
)

def parse_translation(text):
    """Extract the sentences from ``input:``/``output:`` lines of a completion.

    Lines without a colon, or tagged with anything other than ``input``/``output``,
    are skipped. No language detection is done, so every kept sentence is
    labelled ``'unknown'``.

    Returns:
        ``{'langs': [...], 'sents': [...]}``: parallel lists, one entry per kept line.
    """
    sents = []
    for raw_line in text.strip().split('\n'):
        tag, colon, body = raw_line.partition(':')
        if not colon or tag not in ('input', 'output'):
            continue
        sents.append(body.strip())
    return {'langs': ['unknown'] * len(sents), 'sents': sents}

import os

# One JSON line per evaluation example, in eval_dataset order. Generation happens
# as this loop consumes the lazy `outputs` iterator.
os.makedirs('experiment_results', exist_ok=True)
with open(f'experiment_results/translations_{new_model}.jsonl', 'w', encoding='utf-8') as f:
    for output in tqdm(outputs, total=len(eval_dataset)):
        generated_text = output[0]['generated_text']
        # Strip the few-shot prefix so only "input:...\noutput:..." remains to parse.
        for prompt in prompts.values():
            generated_text = generated_text.removeprefix(prompt)
        f.write(json.dumps(parse_translation(generated_text)) + '\n')
        # Flush per line so partial results survive an interrupted run.
        f.flush()

100%|██████████| 3000/3000 [54:21<00:00,  1.09s/it]  
