# 1. polglot-ko-1.3b를 squarelike/sharegpt_deepl_ko_translation로 파인 튜닝한 모델

In [2]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

base_model_id = "EleutherAI/polyglot-ko-1.3b"
model_dir = "./results/poly-ko-1.3b-translate"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

tokenizer = AutoTokenizer.from_pretrained(base_model_id)
model = AutoModelForCausalLM.from_pretrained(model_dir, quantization_config=bnb_config, device_map="auto")

In [3]:
model.eval()
model.config.use_cache = True  # silence the warnings. Please re-enable for inference!

In [4]:
from transformers import StoppingCriteria, StoppingCriteriaList

class StoppingCriteriaSub(StoppingCriteria):

    def __init__(self, stops = [], encounters=1):
        super().__init__()
        self.stops = [stop for stop in stops]

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        for stop in self.stops:
            if torch.all((stop == input_ids[0][-len(stop):])).item():
                return True

        return False

stop_words = ["</끝>"]
stop_words_ids = [tokenizer(stop_word, return_tensors='pt')['input_ids'].squeeze() for stop_word in stop_words]
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

In [5]:
def gen(lan="en", x=""):
    if (lan == "ko"):
        prompt = f"### 한국어: {x}</끝>\n### 영어:"
    else:
        prompt = f"### 영어: {x}</끝>\n### 한국어:"
    gened = model.generate(
        **tokenizer(
            prompt,
            return_tensors='pt',
            return_token_type_ids=False
        ),
        max_new_tokens=2048,
        temperature=0.001,
        no_repeat_ngram_size=10,
        early_stopping=True,
        eos_token_id=2,
        stopping_criteria=stopping_criteria
    )
    return tokenizer.decode(gened[0]).replace(prompt+" ", "")

In [8]:
gen(lan="en", x="NATO summit keeps focus on Ukraine as Russia's war continues. Here's what you need to know")

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


'나토 정상회담에서 러시아의 전쟁이 계속되는 만큼 우크라이나에 집중하는 것이 중요합니다. 다음은 여러분이 알아야 할 사항입니다.</끝>'

In [7]:
gen(lan="en", x="Russian defense ministry says Wagner has handed over tanks and other weapons")

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


'러시아 국방부는 와그너가 탱크와 다른 무기를 넘겨받았다고 밝혔습니다.</끝>'