In [None]:
!pip install torch==2.4.0 transformers==4.45.1 datasets==3.0.1 accelerate==0.34.2 trl==0.11.1 peft==0.13.0

In [None]:
import torch

from datasets import load_dataset, Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer

In [None]:
# Fetch the "train" split of the finance-news summarization dataset and report its row count.
dataset = load_dataset(path = "iamjoon/finance_news_summarizer", split = "train")
print(len(dataset))

In [None]:
# Positional hold-out split: the leading `test_ratio` fraction of row indices
# becomes the test set, everything after it is used for training (no shuffling).
test_ratio = 0.5

data_indices = [*range(len(dataset))]
test_size = int(test_ratio * len(data_indices))

test_data, train_data = data_indices[:test_size], data_indices[test_size:]

In [None]:
# Number of training indices, then number of test indices (one per line).
print(len(train_data), len(test_data), sep = "\n")

In [None]:
def format_data(sample):
    """Turn one raw dataset row into a chat-style record.

    Args:
        sample: mapping with "system_prompt", "user_prompt" and "assistant" keys.

    Returns:
        dict with a single "messages" key holding the system, user and assistant
        turns in that order; the assistant value is converted with str().
    """
    system_turn = {"role" : "system", "content" : sample["system_prompt"]}
    user_turn = {"role" : "user", "content" : sample["user_prompt"]}
    assistant_turn = {"role" : "assistant", "content" : str(sample["assistant"])}
    return {"messages" : [system_turn, user_turn, assistant_turn]}

In [None]:
# Materialize the chat-formatted records for each split, following the index lists.
train_dataset = list(map(format_data, (dataset[idx] for idx in train_data)))
test_dataset = list(map(format_data, (dataset[idx] for idx in test_data)))

In [None]:
# Show a single formatted record per split instead of dumping every row into the
# cell output (slicing keeps this safe even if a split is empty).
print(train_dataset[:1])
print(test_dataset[:1])
print()

print(type(train_dataset))
print(type(test_dataset))

In [None]:
# Wrap both lists of records as `datasets.Dataset` objects (same names are rebound).
train_dataset, test_dataset = (
    Dataset.from_list(records) for records in (train_dataset, test_dataset)
)

print(type(train_dataset), type(test_dataset), sep = "\n")

In [None]:
# Prefer the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
model_id = "NCSOFT/Llama-VARCO-8B-Instruct"

# Load the base model in bfloat16 with automatic device placement.
# (The original call was missing the comma after `device_map`, a SyntaxError.)
model = AutoModelForCausalLM.from_pretrained(model_id,
                                             device_map = "auto",
                                             torch_dtype = torch.bfloat16)

tokenizer = AutoTokenizer.from_pretrained(model_id)

In [None]:
# LoRA adapter settings for a causal-LM task; adapters are attached only to the
# modules named "q_proj" and "v_proj".
peft_config = LoraConfig(
    task_type = "CAUSAL_LM",
    target_modules = ["q_proj", "v_proj"],
    r = 8,
    lora_alpha = 32,
    lora_dropout = 0.1,
    bias = "none",
)

In [None]:
# Training hyper-parameters for the LoRA fine-tune.
args = SFTConfig(
    output_dir = "llama3-8b-summarizer-ko",
    num_train_epochs = 3,
    per_device_train_batch_size = 2,
    gradient_accumulation_steps = 2,       # 2 x 2 = 4 examples per optimizer step on each device
    gradient_checkpointing = True,
    optim = "adamw_torch_fused",
    logging_steps = 10,
    save_strategy = "steps",
    save_steps = 50,
    bf16 = True,
    learning_rate = 1e-4,
    max_grad_norm = 0.3,
    warmup_ratio = 0.03,
    # NOTE(review): a plain "constant" schedule appears to ignore warmup_ratio;
    # "constant_with_warmup" would apply it. Left unchanged; confirm intent.
    lr_scheduler_type = "constant",
    push_to_hub = False,
    # Keep the raw "messages" column and skip TRL's dataset preparation:
    # tokenization and label masking are done by the custom collate_fn.
    remove_unused_columns = False,
    dataset_kwargs = {"skip_prepare_dataset" : True},
    # The string "none" disables experiment-tracker reporting. Passing None does
    # not: it falls back to the library default (all installed integrations in
    # the pinned Transformers version).
    report_to = "none"
)

In [None]:
def _find_subsequence(sequence, pattern, start = 0):
    """Return the first index >= start at which `pattern` occurs in `sequence`, or -1."""
    width = len(pattern)
    for pos in range(start, len(sequence) - width + 1):
        if sequence[pos : pos + width] == pattern:
            return pos
    return -1


def collate_fn(batch):
    """Tokenize chat examples and build a padded batch supervised only on the assistant reply.

    Each example's messages are rendered as
    ``<|start_header_id|>{role}<|end_header_id|>\\n{content}<|eot_id|>`` segments after a
    leading ``<|begin_of_text|>``, tokenized with the module-level ``tokenizer``
    (truncated to the module-level ``max_seq_length``), and given labels that are
    -100 everywhere except the first assistant turn's content and its closing
    ``<|eot_id|>``.

    Args:
        batch: list of dicts, each holding a "messages" list of
            {"role": str, "content": str} entries.

    Returns:
        dict with "input_ids", "attention_mask" and "labels" tensors, every row
        right-padded to the longest sequence in the batch (pad positions get
        attention_mask 0 and label -100).
    """
    # The marker token sequences are the same for every example; encode them once.
    assistant_tokens = tokenizer.encode("<|start_header_id|>assistant<|end_header_id|>\n",
                                        add_special_tokens = False)
    eot_tokens = tokenizer.encode("<|eot_id|>", add_special_tokens = False)

    # Pad positions are masked out of both attention and loss, so the concrete id
    # only has to be a valid integer; fall back to EOS when no pad token is set.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    new_batch = {
        "input_ids" : [],
        "attention_mask" : [],
        "labels" : []
    }

    for example in batch:
        # NOTE(review): the text already starts with <|begin_of_text|>; if the
        # tokenizer also prepends a BOS token by default the sequence will carry
        # two of them. Confirm against the tokenizer config.
        prompt = "<|begin_of_text|>"
        for msg in example["messages"]:
            role = msg["role"]
            content = msg["content"].strip()
            prompt += f"<|start_header_id|>{role}<|end_header_id|>\n{content}<|eot_id|>"

        text = prompt.strip()

        tokenized = tokenizer(text,
                              truncation = True,
                              padding = False,
                              max_length = max_seq_length,
                              return_tensors = None)

        input_ids = list(tokenized["input_ids"])
        attention_mask = list(tokenized["attention_mask"])
        labels = [-100] * len(input_ids)

        header_at = _find_subsequence(input_ids, assistant_tokens)
        if header_at != -1:
            start = header_at + len(assistant_tokens)
            eot_at = _find_subsequence(input_ids, eot_tokens, start)
            # If truncation cut off the closing <|eot_id|>, supervise up to the end
            # of the sequence instead of indexing past it.
            end = len(input_ids) if eot_at == -1 else eot_at + len(eot_tokens)
            labels[start:end] = input_ids[start:end]

        new_batch["input_ids"].append(input_ids)
        new_batch["attention_mask"].append(attention_mask)
        new_batch["labels"].append(labels)

    # Right-pad every row (not the outer list) up to the longest row in the batch.
    max_length = max(len(ids) for ids in new_batch["input_ids"])
    for i in range(len(new_batch["input_ids"])):
        pad = max_length - len(new_batch["input_ids"][i])
        new_batch["input_ids"][i].extend([pad_token_id] * pad)
        new_batch["attention_mask"][i].extend([0] * pad)
        new_batch["labels"][i].extend([-100] * pad)

    for k in new_batch:
        new_batch[k] = torch.tensor(new_batch[k])

    return new_batch

In [None]:
# Truncation limit (in tokens) read by collate_fn when tokenizing each example.
max_seq_length = 8 * 1024

In [None]:
# Sanity check: run the collator on a one-example batch and inspect the result.
example = train_dataset[0]
batch = collate_fn(batch = [example])

print(batch)

In [None]:
# Token ids of the first (only) row, as a plain Python list.
print(batch["input_ids"].tolist()[0])

In [None]:
# Labels of the first (only) row; -100 marks positions excluded from the loss.
print(batch["labels"].tolist()[0])

In [None]:
# Decode the collated ids back to text, keeping special tokens visible.
# The keyword is `clean_up_tokenization_spaces`; the previous spelling was not a
# recognized argument and therefore had no effect.
decoded_text = tokenizer.decode(batch["input_ids"][0].tolist(),
                                skip_special_tokens = False,
                                clean_up_tokenization_spaces = False)

print(decoded_text)

In [None]:
# Keep only the supervised positions (labels other than -100) and decode them to
# verify that the loss covers exactly the assistant reply.
label_ids = [token_id for token_id in batch["labels"][0].tolist() if token_id != -100]
# The keyword is `clean_up_tokenization_spaces`; the previous spelling was not a
# recognized argument and therefore had no effect.
decoded_label = tokenizer.decode(label_ids,
                                 skip_special_tokens = False,
                                 clean_up_tokenization_spaces = False)

decoded_label

In [None]:
# Wire the base model, LoRA config, training arguments and the custom collator
# into the supervised fine-tuning trainer.
trainer = SFTTrainer(
    model = model,
    train_dataset = train_dataset,
    data_collator = collate_fn,
    peft_config = peft_config,
    args = args,
)

In [None]:
# Run the fine-tuning loop, then persist the resulting model via the trainer.
trainer.train()
trainer.save_model()

In [None]:
# Split every rendered test conversation into the generation prompt (everything up
# to and including the assistant header) and the reference answer (the text between
# that header and the next <|eot_id|>).
prompt_list = []
labels_list = []

assistant_header = '<|start_header_id|>assistant<|end_header_id|>\n'

for message in test_dataset['messages']:
    text = tokenizer.apply_chat_template(message, tokenize = False, add_generation_prompt = False)
    parts = text.split(assistant_header)
    # `prompt_text` replaces the former `input` variable, which shadowed the builtin.
    prompt_text = parts[0] + assistant_header
    label = parts[1].split('<|eot_id|>')[0]
    prompt_list.append(prompt_text)
    labels_list.append(label)

In [None]:
# Load the fine-tuned checkpoint once (bfloat16, automatic device placement) and
# hand the in-memory model to the generation pipeline so it is not loaded twice.
# NOTE(review): the step number in the checkpoint path is hard-coded; confirm it
# matches a checkpoint actually written under output_dir.
checkpoint_dir = "llama3-8b-summarizer-ko/checkpoint-372"
fine_model = AutoModelForCausalLM.from_pretrained(checkpoint_dir, device_map = "auto", torch_dtype = torch.bfloat16)
pipe = pipeline("text-generation", model = fine_model, tokenizer = tokenizer)