# 실행 전 준비

데이터 출처: https://aihub.or.kr/aihubdata/data/view.do?currMenu=115&topMenu=100&dataSetSn=580

GitHub의 ipynb 렌더링 오류 때문에 셀 출력이 없는 버전을 업로드했습니다.

셀 출력이 있는 버전은 Colab에 업로드되어 있습니다.

링크: https://colab.research.google.com/drive/1_NN0nExSkmpgAt7DWya3XZuSFa3uIGff?usp=sharing

In [None]:
# Quietly install / upgrade the third-party packages this notebook relies on
# (-q: reduced output, -U: upgrade to the latest available version).

!pip install -q -U transformers bitsandbytes protobuf
!pip install -q -U peft trl matplotlib wandb

In [None]:
# Mount Google Drive at /content/drive; the pickled datasets loaded later in
# this notebook are read from paths under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
from google.colab import files

# Upload the helper module from the local machine. The result is used as a
# mapping whose values are the uploaded file contents.
uploaded = files.upload()
if not uploaded:
    # Fail with a clear message instead of an IndexError on an empty upload.
    raise ValueError("No file was uploaded; make_data2.py is required by the import cell below.")

# Only the first uploaded file is used, and it is always stored as
# make_data2.py regardless of its original name, so that
# `from make_data2 import Datasetup` works.
src = next(iter(uploaded.values()))
with open('make_data2.py', 'wb') as module_file:  # context manager closes the handle
    module_file.write(src)

# 라이브러리 import

In [None]:
import os
import re
import math
from tqdm import tqdm
from google.colab import userdata
from huggingface_hub import login
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoModelForImageTextToText, AutoTokenizer, TrainingArguments, set_seed, BitsAndBytesConfig
from datasets import load_dataset, Dataset, DatasetDict
import wandb
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
from datetime import datetime
import matplotlib.pyplot as plt
import pickle

from make_data2 import Datasetup

# 데이터 로드

In [None]:
# Load the pickled training split from Google Drive.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file —
# only load pickles you produced yourself.
# Presumably the pickled objects require make_data2 to be importable
# (Datasetup is imported above) — TODO confirm.
with open("/content/drive/MyDrive/law_ai/new_pickle/train.pkl","rb") as fr:
    train_data = pickle.load(fr)

In [None]:
# Load the pickled test split from Google Drive.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file —
# only load pickles you produced yourself.
with open("/content/drive/MyDrive/law_ai/new_pickle/test.pkl","rb") as fr:
    test_data = pickle.load(fr)

In [None]:
# Load the pickled validation split from Google Drive.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file —
# only load pickles you produced yourself.
with open("/content/drive/MyDrive/law_ai/new_pickle/val.pkl","rb") as fr:
    val_data = pickle.load(fr)

In [None]:
# Display one training record as a quick sanity check of the loaded data.
train_data[1]

In [None]:
# Report the number of records per split, one per line: train, test, validation.
print(len(train_data), len(test_data), len(val_data), sep="\n")

# 모델 학습 준비

In [None]:
#모델 설정
MODEL_ID = "google/gemma-3-4b-pt"
TOKENIZER_MODEL = "google/gemma-3-4b-it"
PROJECT_NAME = "law_ai5"
HF_USER = "dlddu123"

#모델 불러오는 라이브러리 선택
if MODEL_ID == "google/gemma-3-1b-pt":
    MODEL_CLASS = AutoModelForCausalLM
else:
    MODEL_CLASS = AutoModelForImageTextToText

#모델의 dtype 설정
if torch.cuda.get_device_capability()[0] >= 8:
    torch_dtype = torch.bfloat16
else:
    torch_dtype = torch.float16

#프로젝트 이름 설정
RUN_NAME =  f"{datetime.now():%Y-%m-%d_%H.%M.%S}"
PROJECT_RUN_NAME = f"{PROJECT_NAME}-{RUN_NAME}"
HUB_MODEL_NAME = f"{HF_USER}/{PROJECT_RUN_NAME}"

#LoRA 하이퍼파라미터 설정
peft_config = LoraConfig(
    r = 32,
    lora_alpha  = 32,
    target_modules = ['k_proj', 'q_proj', 'v_proj', 'o_proj', "gate_proj", "down_proj", "up_proj"],
    lora_dropout  = 0.1,
    task_type="CAUSAL_LM"
)

#WANDB 설정
EPOCHS = 10
BATCH_SIZE = 4
GRADIENT_ACCUMULATION_STEPS = 4
LEARNING_RATE = 1e-4
LR_SCHEDULER_TYPE = 'cosine'
WARMUP_RATIO = 0.03
OPTIMIZER = "adamw_torch_fused"

#WANDB 설정
LOG_TO_WANDB = True

%matplotlib inline

In [None]:
# Display the Hub repository id ("<HF_USER>/<PROJECT_NAME>-<timestamp>") built above.
HUB_MODEL_NAME

In [None]:
# Log in to the Hugging Face Hub with the HF_TOKEN secret stored in Colab's
# userdata; add_to_git_credential=True is passed so the token is also handed
# to the git credential store.

hf_token = userdata.get('HF_TOKEN')
login(hf_token, add_to_git_credential=True)

# WANDB 준비 및 시작

In [None]:
# Fetch the W&B API key from the Colab secret store (google.colab.userdata)
# and expose it to wandb through the environment.
wandb_api_key = userdata.get('WANDB_API_KEY')
os.environ["WANDB_API_KEY"] = wandb_api_key

# W&B project settings, applied in one go.
os.environ.update({
    "WANDB_PROJECT": PROJECT_NAME,
    "WANDB_LOG_MODEL": "checkpoint" if LOG_TO_WANDB else "end",
    "WANDB_WATCH": "gradients",
})

wandb.login()

In [None]:
# If a W&B run is already active in this session, close it before starting a
# new one. This is best-effort cleanup: a failure here must not stop the
# notebook, but it is reported instead of being silently swallowed.
try:
    if wandb.run is not None:
        wandb.finish()
except Exception as exc:
    print(f"Could not finish the previous wandb run: {exc!r}")


In [None]:
# Re-read the W&B API key from the Colab secret store and refuse to continue
# when it is missing, empty, or whitespace-only.
wandb_api_key = userdata.get('WANDB_API_KEY')
if not (wandb_api_key and str(wandb_api_key).strip()):
    raise ValueError("WANDB_API_KEY가 비어 있습니다. userdata 등에 키를 설정해주세요.")

In [None]:
# Log in to W&B explicitly with the validated key; relogin=True is passed to
# force a fresh login rather than reusing an existing one.
wandb.login(key=wandb_api_key, relogin=True)

In [None]:
# Start the W&B run (only when W&B logging is enabled), named after the
# timestamp-based RUN_NAME.
if LOG_TO_WANDB:
  wandb.init(project=PROJECT_NAME, name=RUN_NAME)

# 모델 로드 설정

In [None]:
# Quantization settings for loading the base model with bitsandbytes.
# QUANT_4_BIT selects 4-bit NF4 quantization; otherwise 8-bit is used.
QUANT_4_BIT = True
if QUANT_4_BIT:
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        # Use the dtype chosen from the GPU capability (torch_dtype, set in the
        # config cell) instead of hard-coding bfloat16, so the compute dtype
        # matches the dtype the model and trainer are configured with.
        bnb_4bit_compute_dtype=torch_dtype,
        bnb_4bit_quant_type="nf4",
    )
else:
    # BitsAndBytesConfig has no 8-bit compute-dtype option; the previously
    # passed `bnb_8bit_compute_dtype` was an unused keyword and is dropped.
    quant_config = BitsAndBytesConfig(load_in_8bit=True)

In [None]:
# Load the tokenizer (from TOKENIZER_MODEL, a different checkpoint id than the
# base model) and the quantized base model.
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_MODEL, trust_remote_code=True)

# MODEL_CLASS was chosen in the config cell: AutoModelForImageTextToText for
# multimodal checkpoints, AutoModelForCausalLM for text-only ones.
model_load_kwargs = {
    "quantization_config": quant_config,
    "dtype": torch_dtype,
    "attn_implementation": "sdpa",
    "device_map": "auto",
}
base_model = MODEL_CLASS.from_pretrained(MODEL_ID, **model_load_kwargs)

print(f"Memory footprint: {base_model.get_memory_footprint() / 1e6:.1f} MB")

# 학습에 사용할 데이터 추출

In [None]:
# Display the `prompt` attribute of the first training record; this attribute
# is what gets extracted for training in the next cell.
train_data[0].prompt

In [None]:
# Pull the `prompt` attribute out of every record, split by split.
train_list = [record.prompt for record in train_data]
test_list = [record.prompt for record in test_data]
val_list = [record.prompt for record in val_data]

In [None]:
# Wrap the extracted prompt lists in `datasets.Dataset` objects.
# Note the naming: the validation split becomes `eval_dataset`.
train_dataset = Dataset.from_list(train_list)
eval_dataset = Dataset.from_list(val_list)
test_dataset = Dataset.from_list(test_list)

In [None]:
def del_law_list(target, marker='결론'):
    """Strip the cited-statutes part from an example's answer, keeping the conclusion.

    Training on the related law articles as well gave poor results, so only
    the text that follows the first occurrence of ``marker`` is kept.

    Args:
        target: An example dict with a ``'messages'`` list; the entry at
            index 2 (presumably the assistant answer — verify against the
            dataset builder) has its ``'content'`` rewritten in place.
        marker: Text that introduces the conclusion. The marker itself is
            dropped; later occurrences inside the conclusion are preserved.

    Returns:
        The same ``target`` object, for use with ``Dataset.map``.
    """
    content = target['messages'][2]['content']
    # partition() yields a plain string tail. The earlier str(list) form wrote
    # a Python list repr ("['...']") into the training target, and "[]" when
    # the marker was missing.
    _, found, conclusion = content.partition(marker)
    if found:
        target['messages'][2]['content'] = conclusion.strip()
    # Without a marker the answer is left untouched rather than destroyed.
    return target

In [None]:
# Display a test example before the law-article part is stripped.
test_dataset[1]

In [None]:
# Try the stripping function on a single test example and display the result.
del_law_list(test_dataset[1])

In [None]:
# Apply del_law_list to every example of each split; the results are bound to
# new names and the originals are kept under their old names.
new_train_dataset = train_dataset.map(del_law_list)
new_eval_dataset = eval_dataset.map(del_law_list)
new_test_dataset = test_dataset.map(del_law_list)

In [None]:
# Display the same test example after mapping, to compare with the raw one above.
new_test_dataset[1]

In [None]:
# Number of training examples after mapping.
len(new_train_dataset)

In [None]:
# Inspect the Python type of a single mapped example.
type(new_train_dataset[0])

In [None]:
# Inspect the Python type of the mapped dataset object itself.
type(new_train_dataset)

# SFTTrainer에 사용할 하이퍼파라미터 설정

In [None]:
# Supervised fine-tuning configuration for TRL's SFTTrainer.
train_parameters = SFTConfig(
    # -- run / output --
    output_dir=PROJECT_RUN_NAME,
    run_name=RUN_NAME,
    num_train_epochs=5,
    packing=False,

    # -- batching --
    per_device_train_batch_size=4,
    per_device_eval_batch_size=8,
    # Keep this moderate: a very large value slows training down and hurt
    # training quality in practice.
    gradient_accumulation_steps=2,

    # -- memory / optimizer --
    gradient_checkpointing=True,
    optim="adamw_torch_fused",

    # -- learning-rate schedule and precision --
    learning_rate=1e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    max_grad_norm=0.3,
    # Exactly one of fp16 / bf16 is enabled, following the model dtype.
    fp16=(torch_dtype == torch.float16),
    bf16=(torch_dtype == torch.bfloat16),

    # -- evaluation / checkpointing: once per epoch, keep the lowest eval_loss --
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=5,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    save_safetensors=True,

    # -- logging --
    logging_steps=50,
    report_to="wandb" if LOG_TO_WANDB else "tensorboard",

    # -- Hugging Face Hub upload on every save, to a private repo --
    push_to_hub=True,
    hub_strategy="every_save",
    hub_model_id=HUB_MODEL_NAME,
    hub_private_repo=True,

    # -- dataset preprocessing options --
    dataset_kwargs={
        "add_special_tokens": False,
        "append_concat_token": True,
    },
)


# Build the trainer: base model + LoRA config, train/eval splits, tokenizer.
trainer = SFTTrainer(
    model=base_model,
    processing_class=tokenizer,
    peft_config=peft_config,
    train_dataset=new_train_dataset,
    eval_dataset=new_eval_dataset,
    args=train_parameters,
)

In [None]:
# Start training.
trainer.train()

In [None]:
# Push the trainer's final model to the Hub as a private repository.
# HUB_MODEL_NAME ("<HF_USER>/<PROJECT_RUN_NAME>") is used so this upload
# targets the same fully-qualified repo id as the trainer's hub_model_id and
# as the id the inference section loads the adapter from.
# NOTE(review): load_best_model_at_end=True is set in the SFTConfig, so
# trainer.model is presumably the best checkpoint rather than the last one.
trainer.model.push_to_hub(HUB_MODEL_NAME, private=True)
print(f"Saved to the hub: {HUB_MODEL_NAME}")

In [None]:
# Close the W&B run (only when W&B logging was enabled).
if LOG_TO_WANDB:
  wandb.finish()

In [None]:
# Free the training objects before loading the model again for inference.
import gc

del base_model
del trainer
# Dropping the names alone does not release objects that are still held in
# reference cycles; collect them first so that empty_cache() can actually
# return their cached GPU memory.
gc.collect()
torch.cuda.empty_cache()

# 모델 출력 테스트

In [None]:
# Identify which fine-tuned run to load for inference. RUN_NAME is pinned to a
# specific earlier run here, overriding the timestamp generated in the config
# cell; HF_USER still comes from that cell.
RUN_NAME = "2025-12-04_21.54.03"
PROJECT_NAME = 'law_ai5'
PROJECT_RUN_NAME = f"{PROJECT_NAME}-{RUN_NAME}"
# NOTE(review): REVISION is not used by any of the visible cells below.
REVISION = "" # or REVISION = None
FINETUNED_MODEL = f"{HF_USER}/{PROJECT_RUN_NAME}"

In [None]:
# Display the Hub repository id the adapter will be loaded from.
FINETUNED_MODEL

In [None]:
from peft import PeftModel

# Load the base model again, this time without the quantization config, dtype
# or device_map arguments that were used for training.
model = MODEL_CLASS.from_pretrained(MODEL_ID, low_cpu_mem_usage=True)

# Attach the fine-tuned LoRA adapter from the Hub on top of the base model.
# NOTE(review): no merge (e.g. merge_and_unload) or model save happens in this
# cell — the adapter-wrapped `peft_model` is what the cells below use.
peft_model = PeftModel.from_pretrained(model,FINETUNED_MODEL)


# Load the tokenizer and save it to the local "merged_model" directory; only
# the tokenizer files are written there by this cell.
processor = AutoTokenizer.from_pretrained(TOKENIZER_MODEL)
processor.save_pretrained("merged_model")

In [None]:
# Display one processed test example.
new_test_dataset[13]

In [None]:
# Rebind `tokenizer` to the tokenizer loaded in the inference section above.
tokenizer = processor

In [None]:
from transformers import pipeline
from random import randint
import re

# Wrap the LoRA-adapted model and its tokenizer in a text-generation pipeline.
pipe = pipeline("text-generation", model=peft_model, tokenizer=tokenizer)

# rand_idx is a valid random position in the processed test split. randint()
# includes both endpoints, so the upper bound must be len - 1; the previous
# bound could yield an out-of-range index.
# The sample itself stays pinned to index 19 so that runs are comparable; use
# new_test_dataset[rand_idx] instead to spot-check other examples.
rand_idx = randint(0, len(new_test_dataset) - 1)
test_sample = new_test_dataset[19]

# Build the prompt from the first two messages with the tokenizer's chat
# template, ending with the generation prompt so the model writes the answer.
stop_token_ids = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<end_of_turn>")]
prompt = pipe.tokenizer.apply_chat_template(test_sample["messages"][:2], tokenize=False, add_generation_prompt=True)

# Generate the answer without sampling, stopping at EOS or <end_of_turn>.
outputs = pipe(prompt, max_new_tokens=300, do_sample=False, temperature=1, eos_token_id=stop_token_ids, disable_compile=True)

# Print the reference answer next to the generated one. The pipeline output
# starts with the prompt text, so the prompt is sliced off.

print(f"Original Answer:\n{test_sample['messages'][2]['content']}")
print(f"Generated Answer:\n{outputs[0]['generated_text'][len(prompt):].strip()}")