## Fine-tune large models for Vietnamese poem generation using low-rank adapters (LoRA)

*   Install requirements
*   Model loading
*   Post processing
*   Apply LoRA
*   Training

In this notebook, we fine-tune a large model using `8-bit` quantization and `low-rank` adaptation for resource efficiency; without them, Colab would not be able to run it. We also use a `custom loss function` to weight the generated result based on its quality/conformity to the rigid rules of Vietnamese poems.


### Install requirements

In [None]:
!pip install -q bitsandbytes datasets accelerate loralib
!pip install -q git+https://github.com/huggingface/transformers.git@main git+https://github.com/huggingface/peft.git

In [None]:
# Optional, Colab only: the model .bin files are too large to upload to GitHub.
# Either upload the project to Google Drive in a folder named "Trainer_file"
# and run this cell to work from there, or skip this cell entirely.
from google.colab import drive
drive.mount('/content/drive/')
# Make the project folder on the mounted drive the working directory, so the
# relative paths used by later cells resolve against it.
%cd /content/drive/My Drive/Trainer_file/

In [None]:
import os
# Expose only GPU 0 to CUDA. Set here, ahead of `import torch`, presumably so that
# it takes effect before CUDA is initialized (TODO confirm intent).
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer

# Location of the fine-tuned LoRA adapter (a path relative to the working directory).
peft_model_id = "modeling/poem_generator_(bloom)"

# The adapter config names the base checkpoint that the adapter has to be attached to.
config = PeftConfig.from_pretrained(peft_model_id)

# Base model, loaded in 8-bit with automatic device placement; the tokenizer comes
# from the same base checkpoint.
model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    return_dict=True,
    load_in_8bit=True,
    device_map='auto',
)
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

# Wrap the base model with the LoRA adapter weights; `model` now refers to the PEFT model.
model = PeftModel.from_pretrained(
    model,
    peft_model_id,
    device_map={"": 0},
)

Downloading (…)lve/main/config.json:   0%|          | 0.00/739 [00:00<?, ?B/s]



Downloading (…)model.bin.index.json:   0%|          | 0.00/27.5k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading (…)l-00001-of-00002.bin:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

Downloading (…)l-00002-of-00002.bin:   0%|          | 0.00/4.16G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/222 [00:00<?, ?B/s]

Downloading tokenizer.json:   0%|          | 0.00/14.5M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/85.0 [00:00<?, ?B/s]

### Inference

In [None]:
# NOTE: `post_process` is defined in the Evaluation cell further down. That cell has to be
# executed before this one, otherwise the call below raises NameError on a fresh kernel.
text = 'Viết một bài thơ lục bát về mùa xuân. Có chứa các từ khóa "sang năm","tôi","kiếm", "bồ", "mới".'

# The prompt is followed by a '\n###\n' separator before it is handed to the model.
batch = tokenizer(text + '\n###\n', return_tensors='pt').to('cuda')
with torch.cuda.amp.autocast():
    output_tokens = model.generate(
        **batch,
        repetition_penalty=1.1,
        max_new_tokens=128,
    )

# Decode the first (only) generated sequence and clean it up for display.
completion = post_process(
    tokenizer.decode(output_tokens[0], skip_special_tokens=True)
)
print(completion)

### Lưu lại 1 bài generated vì thấy hài :)


In [None]:
# A generated poem saved for reference. As a bare string literal it has no effect
# other than being echoed as this cell's output.
'''sang năm mới đến rồi đây
tôi đi kiếm việc để mà làm ăn
bồ tôi cũng đã có chồng
tôi thì vẫn độc thân không có bồ.
'''

### Evaluation

In [None]:
import pandas as pd
from utils.check_rule import *

def blind_preprocess(prompt:str):
    """Remove the poem-genre marker from a prompt.

    The markers are checked in a fixed order ('lục bát ', '4 chữ ', '5 chữ ',
    '7 chữ ', '8 chữ '). Only the first marker that occurs in the prompt is
    removed (every occurrence of it); any other markers are left in place.

    Args:
        prompt: The prompt text.

    Returns:
        The prompt without that marker, or the prompt unchanged when it
        contains none of the markers.
    """
    for genre_marker in ('lục bát ', '4 chữ ', '5 chữ ', '7 chữ ', '8 chữ '):
        if genre_marker in prompt:
            return prompt.replace(genre_marker, '')
    return prompt

def post_process_2nd(completion):
    """Trim a generated poem so that it ends on a complete pair of lines.

    The text is split into stanzas on blank lines ('\\n\\n').

    * With two or more stanzas: a final stanza that contains no '.' is dropped,
      the (new) final stanza is trimmed as described below, every stanza is cut
      right after its first '.' (a '.' is appended when it has none), and the
      stanzas are joined again with blank lines.
    * With a single stanza: only the trimming below is applied.

    Trimming a stanza: if it has an odd number of lines the last line is
    removed; if it has an even number of lines greater than two and its last
    line contains no '.', the last two lines are removed.

    Args:
        completion: Generated poem text.

    Returns:
        The cleaned text with surrounding whitespace stripped.
    """
    def trim_dangling_lines(stanza):
        # Drop an unpaired last line, or a last pair that does not end a sentence.
        lines = stanza.split('\n')
        if len(lines) % 2 != 0:
            lines = lines[:-1]
        elif len(lines) > 2 and '.' not in lines[-1]:
            lines = lines[:-2]
        return '\n'.join(lines)

    stanzas = completion.split('\n\n')
    if len(stanzas) < 2:
        return trim_dangling_lines(stanzas[0]).strip()

    # A last stanza without any '.' is discarded as a whole.
    if '.' not in stanzas[-1]:
        stanzas = stanzas[:-1]
    stanzas[-1] = trim_dangling_lines(stanzas[-1])

    # Keep each stanza only up to (and including) its first sentence end.
    stanzas = [stanza.split('.')[0] + '.' for stanza in stanzas]
    return '\n\n'.join(stanzas).strip()

def post_process(completion):
    """Extract the generated poem from the decoded model output and clean it.

    The poem is the text after the first '###' separator (up to the next '###',
    if there is one), cut at the first '@@@'. Every line that contains a '.' is
    truncated right after its first '.'. The result is then passed through
    `post_process_2nd`.

    Args:
        completion: Full decoded text (prompt, '###' separator, generation).

    Returns:
        The cleaned poem text.

    Raises:
        IndexError: If `completion` contains no '###'.
    """
    poem = completion.split('###')[1]
    poem = poem.split('@@@')[0]
    trimmed_lines = [
        line.split('.')[0] + '.' if '.' in line else line
        for line in poem.split('\n')
    ]
    return post_process_2nd('\n'.join(trimmed_lines))

def eval_score(prompt, completion):
    """Score a completion with the genre named in its prompt.

    The prompt is searched, in order, for 'lục bát', '4 chữ', '5 chữ', '7 chữ'
    and '8 chữ'. The first match selects the genre passed to `calculate_score`.
    When the prompt names none of them, `calculate_score` is called without a
    genre argument.

    Args:
        prompt: The prompt the completion was generated from.
        completion: The post-processed poem text.

    Returns:
        Whatever `calculate_score` returns.
    """
    genre_by_marker = (
        ('lục bát', 'luc bat'),
        ('4 chữ', '4 chu'),
        ('5 chữ', '5 chu'),
        ('7 chữ', '7 chu'),
        ('8 chữ', '8 chu'),
    )
    for marker, genre in genre_by_marker:
        if marker in prompt:
            return calculate_score(completion, genre)
    return calculate_score(completion)

def eval_generator(num):
    """Generate and score a poem for each of the first `num` evaluation prompts.

    Uses the notebook-level `eval_data` (list of prompts), `tokenizer` and
    `model`. For every prompt the post-processed completion and its score are
    printed. At the end the mean of the collected scores is printed.

    Scores equal to 0 are left out of the mean. For any other score the first
    element (`score[0]`) is used; `calculate_score` presumably returns either 0
    or an indexable result (TODO confirm).

    Args:
        num: Number of prompts, taken from the start of `eval_data`.

    Returns:
        None. All results are printed.
    """
    scores = []
    for prompt in eval_data[:num]:
        batch = tokenizer(prompt + '\n###\n', return_tensors='pt').to('cuda')
        with torch.cuda.amp.autocast():
            output_tokens = model.generate(**batch, repetition_penalty=1.1, max_new_tokens=128)

        completion = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
        completion = post_process(completion)
        score = eval_score(prompt, completion)
        if score != 0:
            scores.append(score[0])
        print(completion)
        print(score)

    # Guard against dividing by zero when nothing was collected
    # (num == 0, empty eval_data, or every score was 0).
    if scores:
        print(sum(scores) / len(scores))
    else:
        print('No non-zero scores to average.')

# Fixed seed so that the shuffled evaluation prompts (and therefore the prompts picked by
# `eval_generator(num)`) are the same on every run.
EVAL_SHUFFLE_SEED = 42

# Evaluation prompts: rows from position 20000 onward (presumably the rows not used for
# training; TODO confirm), restricted to the '8 chu' genre, shuffled, as a plain list.
eval_data = (
    pd.read_csv('resource/dataset/dataset.csv')
    .iloc[20000:]
    .loc[lambda frame: frame['genre'] == '8 chu']
    .sample(frac=1, random_state=EVAL_SHUFFLE_SEED)
    .reset_index(drop=True)
    ['prompt']
    .tolist()
)
# Disabled variant, kept from the original notebook; it appears to strip the genre marker from
# the prompts (likely the 'blind' setting; TODO confirm) and would have to run on the DataFrame:
#eval_data['prompt'] = eval_data['prompt'].apply(lambda x: blind_preprocess(x)).sample(frac=1).reset_index(drop=True)

In [None]:
# Generate and score poems for the first 50 prompts of `eval_data`; prints each result and the mean score.
eval_generator(50)

In [None]:
# Performance figures recorded per genre key (plus a 'blind' entry). They are presumably
# the mean scores printed by `eval_generator` for the BLOOM-based model (TODO confirm).
performance_bloom_20000 = {
    'luc bat': 0.678,
    '7 chu': 0.367,
    '8 chu': 0.279,
    '4 chu': 0.44,
    '5 chu': 0.48,
    'blind': 0.596,
}