In [19]:
# 세션 유지 코드
from IPython.display import clear_output
import threading, time

def keep_alive():
    for i in range(100000):
        time.sleep(60)
        clear_output(wait=True)
        print(f"Ping {i} ⏱️")

threading.Thread(target=keep_alive).start()

In [None]:
# ===================================================
# 0. 설치 및 환경 설정
# ===================================================
!pip install transformers accelerate peft trl bitsandbytes datasets --quiet

import os
import torch
from datasets import load_dataset, Dataset
from transformers import (
    AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig,
    TrainingArguments
)
from peft import LoraConfig
from trl import SFTTrainer
import pandas as pd
import re
from huggingface_hub import login

# 로그인 토큰 필요 시 사용
login(token="YOUR_HF_TOKEN")

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
torch.cuda.empty_cache()

In [21]:
# ===================================================
# 1. 데이터셋 로드 및 전처리
# ===================================================
print("🔄 Loading CNN/DailyMail dataset...")
dataset = load_dataset("cnn_dailymail", "3.0.0", split="train[:30000]")  # 테스트용 30000개

SYSTEM_PROMPT = """You are GemmaSummarizer, a professional AI assistant trained to generate concise and accurate summaries of news articles.
Your goal is to read the full article, understand its key points, and produce a natural, informative English summary suitable for a general audience.

Guidelines:
1️⃣ Read the entire article and identify the main events, entities, and facts.
2️⃣ Write the summary in your own words. Avoid copying long phrases from the article.
3️⃣ The summary should be:
  - Factual and objective
  - Around 3–4 sentences long
  - Written in fluent, readable English

Restrictions:
- Do not add personal opinions or interpretations.
- Do not mention that you are an AI or refer to the prompt instructions.
- If the article is incomplete or malformed, reply: "Unable to summarize due to incomplete article."""

def format_sample(sample):
    article = sample["article"]
    summary = sample["highlights"]
    return {
        "text": (
            f"<start_of_turn>system\n{SYSTEM_PROMPT}<end_of_turn>\n"
            f"<start_of_turn>user\n{article}<end_of_turn>\n"
            f"<start_of_turn>model\n{summary}<end_of_turn>"
        )
    }

formatted_data = [format_sample(s) for s in dataset]
train_dataset = Dataset.from_pandas(pd.DataFrame(formatted_data))

🔄 Loading CNN/DailyMail dataset...


In [22]:
# ===================================================
# 2. 모델 및 토크나이저 로드 (QLoRA)
# ===================================================
model_id = "google/gemma-3n-E2B-it"

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quant_config,
    device_map="auto"
)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [23]:
# ===================================================
# 3. LoRA 설정
# ===================================================
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM"
)

In [24]:
# ===================================================
# 4. 훈련 파라미터 및 SFTTrainer 정의
# ===================================================
def tokenize(sample):
    return tokenizer(sample["text"], padding="max_length", truncation=True, max_length=1024)

tokenized_dataset = train_dataset.map(tokenize, batched=True)

from transformers import DataCollatorForLanguageModeling

training_args = TrainingArguments(
    output_dir="./gemma3n-summarizer",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    num_train_epochs=2,
    logging_steps=10,
    save_steps=500,
    learning_rate=2e-4,
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    lr_scheduler_type="constant",
    optim="paged_adamw_8bit",
    report_to=None,
    bf16=True,

)

trainer = SFTTrainer(
    model=model,
    train_dataset=tokenized_dataset,
    args=training_args,
    peft_config=lora_config,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

Map:   0%|          | 0/30000 [00:00<?, ? examples/s]

Truncating train dataset:   0%|          | 0/30000 [00:00<?, ? examples/s]

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [25]:
import torch

model.config.altup_coef_clip = 1e-5  # clamp 범위 설정

# prediction_coefs를 가진 모듈을 찾아 float32로 바꾸기
for name, module in model.named_modules():
    if hasattr(module, "weight") and isinstance(module.weight, torch.Tensor):
        if "prediction_coefs" in name:
            print(f"✅ Found {name}, casting to float32...")
            module.weight.data = module.weight.data.to(torch.float32)

# float32 변환 함수 확장
def cast_gemma3n_weights_to_float32(model):
    for name, module in model.named_modules():
        if hasattr(module, "prediction_coefs") and hasattr(module.prediction_coefs, "weight"):
            print(f"✅ Found {name}.prediction_coefs, casting to float32...")
            module.prediction_coefs.weight.data = module.prediction_coefs.weight.data.to(torch.float32)
        if hasattr(module, "correction_coefs") and hasattr(module.correction_coefs, "weight"):
            print(f"✅ Found {name}.correction_coefs, casting to float32...")
            module.correction_coefs.weight.data = module.correction_coefs.weight.data.to(torch.float32)

cast_gemma3n_weights_to_float32(model)

✅ Found model.language_model.layers.0.altup.prediction_coefs, casting to float32...
✅ Found model.language_model.layers.1.altup.prediction_coefs, casting to float32...
✅ Found model.language_model.layers.2.altup.prediction_coefs, casting to float32...
✅ Found model.language_model.layers.3.altup.prediction_coefs, casting to float32...
✅ Found model.language_model.layers.4.altup.prediction_coefs, casting to float32...
✅ Found model.language_model.layers.5.altup.prediction_coefs, casting to float32...
✅ Found model.language_model.layers.6.altup.prediction_coefs, casting to float32...
✅ Found model.language_model.layers.7.altup.prediction_coefs, casting to float32...
✅ Found model.language_model.layers.8.altup.prediction_coefs, casting to float32...
✅ Found model.language_model.layers.9.altup.prediction_coefs, casting to float32...
✅ Found model.language_model.layers.10.altup.prediction_coefs, casting to float32...
✅ Found model.language_model.layers.11.altup.prediction_coefs, casting to f

In [26]:
# ===================================================
# 5. 학습 시작
# ===================================================
print("🚀 Starting training...")
## trainer.train()
trainer.train(resume_from_checkpoint="./gemma3n-summarizer/checkpoint-1000")

Ping 740 ⏱️


Step,Training Loss
1010,17.8353
1020,17.8255
1030,17.682
1040,18.2554
1050,17.8822
1060,17.901
1070,17.584
1080,17.5857
1090,17.4908
1100,17.5357


TrainOutput(global_step=15000, training_loss=14.426356958007812, metrics={'train_runtime': 44463.8771, 'train_samples_per_second': 1.349, 'train_steps_per_second': 0.337, 'total_flos': 1.06625625882624e+18, 'train_loss': 14.426356958007812})

In [27]:
# ===================================================
# 6. 모델 저장 및 Hugging Face 업로드
# ===================================================
save_path = "/content/drive/MyDrive/gemma3n_lora_summary"
model.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)

# Hugging Face 업로드
repo_id = "LeannaJ/gemma3n-lora-summary"
model.push_to_hub(repo_id)
tokenizer.push_to_hub(repo_id)
print(f"✅ Hugging Face에 업로드 완료: https://huggingface.co/{repo_id}")

Ping 743 ⏱️


tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

✅ Hugging Face에 업로드 완료: https://huggingface.co/LeannaJ/gemma3n-lora-summary
