In [1]:
import torch
from transformers import (GPT2Tokenizer, 
                          GPT2LMHeadModel, 
                          Trainer, 
                          TrainingArguments)
from datasets import load_dataset
from peft import get_peft_model, LoraConfig, TaskType

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#GPU 사용
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
#GPT-2 모델, 토크나이저 로드
model_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
# EOS token 확인
print(f"EOS token: {tokenizer.eos_token}, EOS token ID: {tokenizer.eos_token_id}")

# pad_token 추가 (eos_token을 pad_token으로 사용)
tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})
# EOS token 확인
print(f"EOS token: {tokenizer.eos_token}, EOS token ID: {tokenizer.eos_token_id}")


EOS token: <|endoftext|>, EOS token ID: 50256
EOS token: <|endoftext|>, EOS token ID: 50256




In [4]:
#dataset load 
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
eval_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="validation")

In [5]:


# 데이터셋 전처리 (토큰화 + labels 추가)
def preprocess_function(examples):
    inputs = tokenizer(examples["text"], return_tensors="pt", padding="max_length", truncation=True, max_length=128)
    inputs["labels"] = inputs["input_ids"].clone()  # labels에 input_ids를 복사하여 사용
    return inputs

tokenized_dataset = dataset.map(preprocess_function, batched=True, remove_columns=["text"])
tokenized_eval_dataset = eval_dataset.map(preprocess_function, batched=True, remove_columns=["text"])


Map: 100%|██████████| 36718/36718 [00:05<00:00, 6514.83 examples/s]
Map: 100%|██████████| 3760/3760 [00:00<00:00, 5231.66 examples/s]


# LoRA 적용

In [15]:
lora_config = LoraConfig(
    r=4,
    lora_alpha=16, #학습률 스케일링 
    lora_dropout=0.2, 
    target_modules=["c_attn"],  # attention 레이어에 LoRA적용 
    # Q,K,V를 다르게 적용하고자 하여도 Hugging face에서 하나의 모듈로 통합되어서 개별 적용 못함. 
    task_type=TaskType.CAUSAL_LM #언어 모델링 태스크 
)

In [16]:
# gpt-2 모델 로드 후 모델에 LoRA적용 
lora_model = GPT2LMHeadModel.from_pretrained(model_name).to(device)
lora_model = get_peft_model(lora_model, lora_config)


# pad_token이 추가된 후 토큰 임베딩 크기 재설정
lora_model.resize_token_embeddings(len(tokenizer))



Embedding(50257, 768)

In [17]:
# loss 반환시 오류가 발생하므로 Trainer에 compute_loss() 메서드를 재정의 

class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs= False):
        #labels를 input에서 가져옴
        labels = inputs.get("labels")

        #inputs에서 labels를 제거하고 나머지를 모델에 전달
        outputs=model(**inputs)

        #model에서 logits를 반환
        logits = outputs.get("logits")
        # Shifted logits and labels로 손실 계산 (Causal Language Modeling Loss)
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        
        # CrossEntropyLoss 함수로 손실 계산
        loss_fct = torch.nn.CrossEntropyLoss()
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        
        return (loss, outputs) if return_outputs else loss

In [18]:
training_args_lora = TrainingArguments(
    output_dir="./lora_results",
    per_device_train_batch_size=4,
    num_train_epochs=5,
    logging_dir="./logs",
    logging_steps=500,
    save_steps=500,
    evaluation_strategy="steps",
    eval_steps=500,
    save_total_limit=2,
    learning_rate=5e-4,
    fp16=True,  # GPU의 half precision 사용
    report_to="none"  # 원하지 않는 리포팅 툴로의 연결을 방지
)



In [19]:
# LoRA 모델 트레이너 정의
lora_trainer = Trainer(
    model=lora_model,
    args=training_args_lora,
    train_dataset=tokenized_dataset,
    eval_dataset=tokenized_eval_dataset,  # 평가용 데이터셋

)

  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


In [20]:
# LoRA 모델 학습
lora_trainer.train()

  1%|          | 500/45900 [00:08<12:25, 60.87it/s]

{'loss': 1.6042, 'grad_norm': 0.5960968732833862, 'learning_rate': 0.0004945751633986928, 'epoch': 0.05}


                                                   
  1%|          | 500/45900 [00:14<12:25, 60.87it/s]

{'eval_loss': 1.3665893077850342, 'eval_runtime': 5.6658, 'eval_samples_per_second': 663.631, 'eval_steps_per_second': 82.954, 'epoch': 0.05}


  2%|▏         | 1000/45900 [00:22<12:21, 60.58it/s] 

{'loss': 1.3897, 'grad_norm': 0.11616022139787674, 'learning_rate': 0.0004891285403050109, 'epoch': 0.11}


                                                    
  2%|▏         | 1000/45900 [00:28<12:21, 60.58it/s]

{'eval_loss': 1.3424890041351318, 'eval_runtime': 5.6811, 'eval_samples_per_second': 661.849, 'eval_steps_per_second': 82.731, 'epoch': 0.11}


  3%|▎         | 1500/45900 [00:37<12:13, 60.52it/s]  

{'loss': 1.4279, 'grad_norm': 0.26121753454208374, 'learning_rate': 0.000483681917211329, 'epoch': 0.16}


                                                    
  3%|▎         | 1500/45900 [00:42<12:13, 60.52it/s]

{'eval_loss': 1.3298457860946655, 'eval_runtime': 5.6379, 'eval_samples_per_second': 666.912, 'eval_steps_per_second': 83.364, 'epoch': 0.16}


  4%|▍         | 2000/45900 [00:51<12:07, 60.38it/s]  

{'loss': 1.3338, 'grad_norm': 0.35731732845306396, 'learning_rate': 0.00047823529411764704, 'epoch': 0.22}


                                                    
  4%|▍         | 2000/45900 [00:57<12:07, 60.38it/s]

{'eval_loss': 1.322253942489624, 'eval_runtime': 5.636, 'eval_samples_per_second': 667.141, 'eval_steps_per_second': 83.393, 'epoch': 0.22}


  5%|▌         | 2500/45900 [01:06<12:04, 59.91it/s]  

{'loss': 1.326, 'grad_norm': 0.41638997197151184, 'learning_rate': 0.00047278867102396515, 'epoch': 0.27}


                                                    
  5%|▌         | 2500/45900 [01:11<12:04, 59.91it/s]

{'eval_loss': 1.320235252380371, 'eval_runtime': 5.6848, 'eval_samples_per_second': 661.408, 'eval_steps_per_second': 82.676, 'epoch': 0.27}


  7%|▋         | 3000/45900 [01:20<11:56, 59.85it/s]  

{'loss': 1.3079, 'grad_norm': 0.3831005096435547, 'learning_rate': 0.0004673420479302832, 'epoch': 0.33}


                                                    
  7%|▋         | 3000/45900 [01:26<11:56, 59.85it/s]

{'eval_loss': 1.317421793937683, 'eval_runtime': 5.6632, 'eval_samples_per_second': 663.935, 'eval_steps_per_second': 82.992, 'epoch': 0.33}


  8%|▊         | 3500/45900 [01:34<11:34, 61.01it/s]  

{'loss': 1.3438, 'grad_norm': 0.6888113617897034, 'learning_rate': 0.0004618954248366013, 'epoch': 0.38}


                                                    
  8%|▊         | 3500/45900 [01:40<11:34, 61.01it/s]

{'eval_loss': 1.3141781091690063, 'eval_runtime': 5.6224, 'eval_samples_per_second': 668.749, 'eval_steps_per_second': 83.594, 'epoch': 0.38}


  9%|▊         | 4000/45900 [01:49<11:26, 61.00it/s]  

{'loss': 1.3762, 'grad_norm': 0.67008376121521, 'learning_rate': 0.00045644880174291943, 'epoch': 0.44}


                                                    
  9%|▊         | 4000/45900 [01:54<11:26, 61.00it/s]

{'eval_loss': 1.315261721611023, 'eval_runtime': 5.6261, 'eval_samples_per_second': 668.317, 'eval_steps_per_second': 83.54, 'epoch': 0.44}


 10%|▉         | 4500/45900 [02:03<11:18, 61.03it/s]  

{'loss': 1.3349, 'grad_norm': 0.4809548854827881, 'learning_rate': 0.0004510021786492375, 'epoch': 0.49}


                                                    
 10%|▉         | 4500/45900 [02:08<11:18, 61.03it/s]

{'eval_loss': 1.3100987672805786, 'eval_runtime': 5.6299, 'eval_samples_per_second': 667.858, 'eval_steps_per_second': 83.482, 'epoch': 0.49}


 11%|█         | 5000/45900 [02:17<11:02, 61.71it/s]  

{'loss': 1.3384, 'grad_norm': 0.5205889344215393, 'learning_rate': 0.00044555555555555554, 'epoch': 0.54}


                                                    
 11%|█         | 5000/45900 [02:23<11:02, 61.71it/s]

{'eval_loss': 1.3099243640899658, 'eval_runtime': 5.5309, 'eval_samples_per_second': 679.813, 'eval_steps_per_second': 84.977, 'epoch': 0.54}


 12%|█▏        | 5500/45900 [02:31<10:50, 62.12it/s]  

{'loss': 1.2887, 'grad_norm': 0.4358111321926117, 'learning_rate': 0.00044010893246187365, 'epoch': 0.6}


                                                    
 12%|█▏        | 5500/45900 [02:37<10:50, 62.12it/s]

{'eval_loss': 1.3085600137710571, 'eval_runtime': 5.5316, 'eval_samples_per_second': 679.735, 'eval_steps_per_second': 84.967, 'epoch': 0.6}


 13%|█▎        | 6000/45900 [02:45<10:44, 61.93it/s]  

{'loss': 1.3023, 'grad_norm': 0.5975645780563354, 'learning_rate': 0.00043466230936819176, 'epoch': 0.65}


                                                    
 13%|█▎        | 6000/45900 [02:51<10:44, 61.93it/s]

{'eval_loss': 1.306800365447998, 'eval_runtime': 5.6003, 'eval_samples_per_second': 671.392, 'eval_steps_per_second': 83.924, 'epoch': 0.65}


 14%|█▍        | 6500/45900 [02:59<10:40, 61.48it/s]  

{'loss': 1.3842, 'grad_norm': 0.25559642910957336, 'learning_rate': 0.0004292156862745098, 'epoch': 0.71}


                                                    
 14%|█▍        | 6500/45900 [03:05<10:40, 61.48it/s]

{'eval_loss': 1.3055347204208374, 'eval_runtime': 5.6308, 'eval_samples_per_second': 667.753, 'eval_steps_per_second': 83.469, 'epoch': 0.71}


 15%|█▌        | 7000/45900 [03:14<10:42, 60.51it/s]  

{'loss': 1.386, 'grad_norm': 0.20895034074783325, 'learning_rate': 0.0004237690631808279, 'epoch': 0.76}


                                                    
 15%|█▌        | 7000/45900 [03:19<10:42, 60.51it/s]

{'eval_loss': 1.3047337532043457, 'eval_runtime': 5.614, 'eval_samples_per_second': 669.755, 'eval_steps_per_second': 83.719, 'epoch': 0.76}


 16%|█▋        | 7500/45900 [03:28<10:24, 61.49it/s]  

{'loss': 1.3277, 'grad_norm': 0.5788065791130066, 'learning_rate': 0.000418322440087146, 'epoch': 0.82}


                                                    
 16%|█▋        | 7500/45900 [03:34<10:24, 61.49it/s]

{'eval_loss': 1.3044812679290771, 'eval_runtime': 5.5737, 'eval_samples_per_second': 674.599, 'eval_steps_per_second': 84.325, 'epoch': 0.82}


 17%|█▋        | 8000/45900 [03:42<10:18, 61.29it/s]  

{'loss': 1.3247, 'grad_norm': 0.7138481140136719, 'learning_rate': 0.00041287581699346403, 'epoch': 0.87}


                                                    
 17%|█▋        | 8000/45900 [03:48<10:18, 61.29it/s]

{'eval_loss': 1.3034727573394775, 'eval_runtime': 5.5545, 'eval_samples_per_second': 676.931, 'eval_steps_per_second': 84.616, 'epoch': 0.87}


 19%|█▊        | 8500/45900 [03:56<10:07, 61.56it/s]  

{'loss': 1.3271, 'grad_norm': 0.20646969974040985, 'learning_rate': 0.00040744008714596955, 'epoch': 0.93}


                                                    
 19%|█▊        | 8500/45900 [04:02<10:07, 61.56it/s]

{'eval_loss': 1.303989291191101, 'eval_runtime': 5.5634, 'eval_samples_per_second': 675.84, 'eval_steps_per_second': 84.48, 'epoch': 0.93}


 20%|█▉        | 9000/45900 [04:11<10:02, 61.28it/s]  

{'loss': 1.3596, 'grad_norm': 0.2908323407173157, 'learning_rate': 0.0004019934640522876, 'epoch': 0.98}


                                                    
 20%|█▉        | 9000/45900 [04:16<10:02, 61.28it/s]

{'eval_loss': 1.3034775257110596, 'eval_runtime': 5.5754, 'eval_samples_per_second': 674.392, 'eval_steps_per_second': 84.299, 'epoch': 0.98}


 21%|██        | 9500/45900 [04:25<09:48, 61.82it/s]  

{'loss': 1.3302, 'grad_norm': 0.8117608428001404, 'learning_rate': 0.000396557734204793, 'epoch': 1.03}


                                                    
 21%|██        | 9500/45900 [04:30<09:48, 61.82it/s]

{'eval_loss': 1.3027015924453735, 'eval_runtime': 5.5678, 'eval_samples_per_second': 675.317, 'eval_steps_per_second': 84.415, 'epoch': 1.03}


 22%|██▏       | 10000/45900 [04:39<09:49, 60.94it/s] 

{'loss': 1.2843, 'grad_norm': 0.3860924541950226, 'learning_rate': 0.0003911111111111111, 'epoch': 1.09}


                                                     
 22%|██▏       | 10000/45900 [04:45<09:49, 60.94it/s]

{'eval_loss': 1.3019016981124878, 'eval_runtime': 5.6214, 'eval_samples_per_second': 668.867, 'eval_steps_per_second': 83.608, 'epoch': 1.09}


 23%|██▎       | 10500/45900 [04:53<09:36, 61.38it/s]  

{'loss': 1.3092, 'grad_norm': 0.3334556818008423, 'learning_rate': 0.00038566448801742923, 'epoch': 1.14}


                                                     
 23%|██▎       | 10500/45900 [04:59<09:36, 61.38it/s]

{'eval_loss': 1.3018972873687744, 'eval_runtime': 5.5957, 'eval_samples_per_second': 671.943, 'eval_steps_per_second': 83.993, 'epoch': 1.14}


 24%|██▍       | 11000/45900 [05:08<09:41, 60.07it/s]  

{'loss': 1.3278, 'grad_norm': 0.5175043344497681, 'learning_rate': 0.0003802178649237473, 'epoch': 1.2}


                                                     
 24%|██▍       | 11000/45900 [05:13<09:41, 60.07it/s]

{'eval_loss': 1.3015220165252686, 'eval_runtime': 5.6145, 'eval_samples_per_second': 669.695, 'eval_steps_per_second': 83.712, 'epoch': 1.2}


 25%|██▌       | 11500/45900 [05:22<09:40, 59.24it/s]  

{'loss': 1.3487, 'grad_norm': 0.5583910942077637, 'learning_rate': 0.0003747712418300654, 'epoch': 1.25}


                                                     
 25%|██▌       | 11500/45900 [05:28<09:40, 59.24it/s]

{'eval_loss': 1.3012371063232422, 'eval_runtime': 5.5888, 'eval_samples_per_second': 672.768, 'eval_steps_per_second': 84.096, 'epoch': 1.25}


 26%|██▌       | 12000/45900 [05:36<09:23, 60.11it/s]  

{'loss': 1.3441, 'grad_norm': 1.741873893479351e-05, 'learning_rate': 0.00036932461873638345, 'epoch': 1.31}


                                                     
 26%|██▌       | 12000/45900 [05:42<09:23, 60.11it/s]

{'eval_loss': 1.301084041595459, 'eval_runtime': 5.6364, 'eval_samples_per_second': 667.092, 'eval_steps_per_second': 83.386, 'epoch': 1.31}


 27%|██▋       | 12500/45900 [05:51<09:01, 61.71it/s]  

{'loss': 1.3768, 'grad_norm': 0.34278687834739685, 'learning_rate': 0.0003638779956427015, 'epoch': 1.36}


                                                     
 27%|██▋       | 12500/45900 [05:56<09:01, 61.71it/s]

{'eval_loss': 1.3004850149154663, 'eval_runtime': 5.6219, 'eval_samples_per_second': 668.816, 'eval_steps_per_second': 83.602, 'epoch': 1.36}


 28%|██▊       | 13000/45900 [06:05<09:00, 60.92it/s]  

{'loss': 1.2873, 'grad_norm': 0.3901369869709015, 'learning_rate': 0.00035843137254901967, 'epoch': 1.42}


                                                     
 28%|██▊       | 13000/45900 [06:11<09:00, 60.92it/s]

{'eval_loss': 1.2998204231262207, 'eval_runtime': 5.5956, 'eval_samples_per_second': 671.962, 'eval_steps_per_second': 83.995, 'epoch': 1.42}


 29%|██▉       | 13500/45900 [06:19<08:49, 61.18it/s]  

{'loss': 1.3075, 'grad_norm': 0.9095379710197449, 'learning_rate': 0.0003529956427015251, 'epoch': 1.47}


                                                     
 29%|██▉       | 13500/45900 [06:25<08:49, 61.18it/s]

{'eval_loss': 1.2996413707733154, 'eval_runtime': 5.6109, 'eval_samples_per_second': 670.124, 'eval_steps_per_second': 83.765, 'epoch': 1.47}


 31%|███       | 14000/45900 [06:34<08:40, 61.28it/s]  

{'loss': 1.3165, 'grad_norm': 0.3350384831428528, 'learning_rate': 0.0003475599128540305, 'epoch': 1.53}


                                                     
 31%|███       | 14000/45900 [06:39<08:40, 61.28it/s]

{'eval_loss': 1.299883246421814, 'eval_runtime': 5.6018, 'eval_samples_per_second': 671.207, 'eval_steps_per_second': 83.901, 'epoch': 1.53}


 32%|███▏      | 14500/45900 [06:48<08:31, 61.37it/s]  

{'loss': 1.3261, 'grad_norm': 0.5183135271072388, 'learning_rate': 0.0003421132897603486, 'epoch': 1.58}


                                                     
 32%|███▏      | 14500/45900 [06:53<08:31, 61.37it/s]

{'eval_loss': 1.2995007038116455, 'eval_runtime': 5.6026, 'eval_samples_per_second': 671.116, 'eval_steps_per_second': 83.89, 'epoch': 1.58}


 33%|███▎      | 15000/45900 [07:02<08:24, 61.23it/s]  

{'loss': 1.2694, 'grad_norm': 0.7671722173690796, 'learning_rate': 0.0003366666666666667, 'epoch': 1.63}


                                                     
 33%|███▎      | 15000/45900 [07:08<08:24, 61.23it/s]

{'eval_loss': 1.2982982397079468, 'eval_runtime': 5.6001, 'eval_samples_per_second': 671.413, 'eval_steps_per_second': 83.927, 'epoch': 1.63}


 34%|███▍      | 15500/45900 [07:16<08:17, 61.07it/s]  

{'loss': 1.28, 'grad_norm': 0.8340439200401306, 'learning_rate': 0.00033122004357298476, 'epoch': 1.69}


                                                     
 34%|███▍      | 15500/45900 [07:22<08:17, 61.07it/s]

{'eval_loss': 1.2985029220581055, 'eval_runtime': 5.6109, 'eval_samples_per_second': 670.122, 'eval_steps_per_second': 83.765, 'epoch': 1.69}


 35%|███▍      | 16000/45900 [07:31<08:09, 61.11it/s]  

{'loss': 1.2834, 'grad_norm': 0.35049009323120117, 'learning_rate': 0.00032577342047930286, 'epoch': 1.74}


                                                     
 35%|███▍      | 16000/45900 [07:36<08:09, 61.11it/s]

{'eval_loss': 1.2971850633621216, 'eval_runtime': 5.5978, 'eval_samples_per_second': 671.689, 'eval_steps_per_second': 83.961, 'epoch': 1.74}


 36%|███▌      | 16500/45900 [07:45<08:12, 59.70it/s]  

{'loss': 1.3577, 'grad_norm': 0.5624255537986755, 'learning_rate': 0.0003203267973856209, 'epoch': 1.8}


                                                     
 36%|███▌      | 16500/45900 [07:50<08:12, 59.70it/s]

{'eval_loss': 1.2979676723480225, 'eval_runtime': 5.6372, 'eval_samples_per_second': 667.001, 'eval_steps_per_second': 83.375, 'epoch': 1.8}


 37%|███▋      | 17000/45900 [07:59<08:04, 59.66it/s]  

{'loss': 1.3597, 'grad_norm': 0.5020273923873901, 'learning_rate': 0.000314880174291939, 'epoch': 1.85}


                                                     
 37%|███▋      | 17000/45900 [08:05<08:04, 59.66it/s]

{'eval_loss': 1.2985306978225708, 'eval_runtime': 5.6366, 'eval_samples_per_second': 667.07, 'eval_steps_per_second': 83.384, 'epoch': 1.85}


 38%|███▊      | 17500/45900 [08:14<07:53, 60.02it/s]  

{'loss': 1.327, 'grad_norm': 0.48254236578941345, 'learning_rate': 0.00030943355119825714, 'epoch': 1.91}


                                                     
 38%|███▊      | 17500/45900 [08:19<07:53, 60.02it/s]

{'eval_loss': 1.2968897819519043, 'eval_runtime': 5.6289, 'eval_samples_per_second': 667.983, 'eval_steps_per_second': 83.498, 'epoch': 1.91}


 39%|███▉      | 18000/45900 [08:28<07:43, 60.18it/s]  

{'loss': 1.2895, 'grad_norm': 0.5284808278083801, 'learning_rate': 0.0003039869281045752, 'epoch': 1.96}


                                                     
 39%|███▉      | 18000/45900 [08:34<07:43, 60.18it/s]

{'eval_loss': 1.2972544431686401, 'eval_runtime': 5.692, 'eval_samples_per_second': 660.576, 'eval_steps_per_second': 82.572, 'epoch': 1.96}


 40%|████      | 18500/45900 [08:42<07:29, 61.01it/s]  

{'loss': 1.2946, 'grad_norm': 0.5666742324829102, 'learning_rate': 0.0002985511982570806, 'epoch': 2.02}


                                                     
 40%|████      | 18500/45900 [08:48<07:29, 61.01it/s]

{'eval_loss': 1.2974754571914673, 'eval_runtime': 5.6535, 'eval_samples_per_second': 665.079, 'eval_steps_per_second': 83.135, 'epoch': 2.02}


 41%|████▏     | 19000/45900 [08:57<07:18, 61.32it/s]  

{'loss': 1.2729, 'grad_norm': 0.4621662497520447, 'learning_rate': 0.0002931045751633987, 'epoch': 2.07}


                                                     
 41%|████▏     | 19000/45900 [09:02<07:18, 61.32it/s]

{'eval_loss': 1.2962725162506104, 'eval_runtime': 5.6032, 'eval_samples_per_second': 671.044, 'eval_steps_per_second': 83.881, 'epoch': 2.07}


 42%|████▏     | 19500/45900 [09:11<07:05, 62.07it/s]  

{'loss': 1.299, 'grad_norm': 0.6738652586936951, 'learning_rate': 0.00028765795206971677, 'epoch': 2.12}


                                                     
 42%|████▏     | 19500/45900 [09:16<07:05, 62.07it/s]

{'eval_loss': 1.2966628074645996, 'eval_runtime': 5.544, 'eval_samples_per_second': 678.209, 'eval_steps_per_second': 84.776, 'epoch': 2.12}


 44%|████▎     | 20000/45900 [09:25<06:55, 62.40it/s]  

{'loss': 1.3229, 'grad_norm': 0.0956282764673233, 'learning_rate': 0.0002822113289760349, 'epoch': 2.18}


                                                     
 44%|████▎     | 20000/45900 [09:31<06:55, 62.40it/s]

{'eval_loss': 1.2965831756591797, 'eval_runtime': 5.5291, 'eval_samples_per_second': 680.037, 'eval_steps_per_second': 85.005, 'epoch': 2.18}


 45%|████▍     | 20500/45900 [09:39<07:00, 60.43it/s]  

{'loss': 1.2647, 'grad_norm': 0.38905590772628784, 'learning_rate': 0.00027677559912854034, 'epoch': 2.23}


                                                     
 45%|████▍     | 20500/45900 [09:45<07:00, 60.43it/s]

{'eval_loss': 1.2954447269439697, 'eval_runtime': 5.6003, 'eval_samples_per_second': 671.387, 'eval_steps_per_second': 83.923, 'epoch': 2.23}


 46%|████▌     | 21000/45900 [09:53<06:44, 61.63it/s]  

{'loss': 1.3577, 'grad_norm': 0.7945974469184875, 'learning_rate': 0.0002713289760348584, 'epoch': 2.29}


                                                     
 46%|████▌     | 21000/45900 [09:59<06:44, 61.63it/s]

{'eval_loss': 1.2966231107711792, 'eval_runtime': 5.5677, 'eval_samples_per_second': 675.322, 'eval_steps_per_second': 84.415, 'epoch': 2.29}


 47%|████▋     | 21500/45900 [10:08<06:45, 60.23it/s]  

{'loss': 1.3211, 'grad_norm': 0.6977376937866211, 'learning_rate': 0.00026588235294117645, 'epoch': 2.34}


                                                     
 47%|████▋     | 21500/45900 [10:13<06:45, 60.23it/s]

{'eval_loss': 1.2959312200546265, 'eval_runtime': 5.599, 'eval_samples_per_second': 671.545, 'eval_steps_per_second': 83.943, 'epoch': 2.34}


 48%|████▊     | 22000/45900 [10:22<06:34, 60.55it/s]  

{'loss': 1.2751, 'grad_norm': 0.5582348108291626, 'learning_rate': 0.00026044662309368196, 'epoch': 2.4}


                                                     
 48%|████▊     | 22000/45900 [10:27<06:34, 60.55it/s]

{'eval_loss': 1.2957379817962646, 'eval_runtime': 5.5879, 'eval_samples_per_second': 672.882, 'eval_steps_per_second': 84.11, 'epoch': 2.4}


 49%|████▉     | 22500/45900 [10:36<06:24, 60.78it/s]  

{'loss': 1.344, 'grad_norm': 0.7002683281898499, 'learning_rate': 0.000255, 'epoch': 2.45}


                                                     
 49%|████▉     | 22500/45900 [10:42<06:24, 60.78it/s]

{'eval_loss': 1.2957994937896729, 'eval_runtime': 5.5899, 'eval_samples_per_second': 672.643, 'eval_steps_per_second': 84.08, 'epoch': 2.45}


 50%|█████     | 23000/45900 [10:50<06:13, 61.36it/s]  

{'loss': 1.3548, 'grad_norm': 0.5973064303398132, 'learning_rate': 0.0002495533769063181, 'epoch': 2.51}


                                                     
 50%|█████     | 23000/45900 [10:56<06:13, 61.36it/s]

{'eval_loss': 1.296613097190857, 'eval_runtime': 5.5897, 'eval_samples_per_second': 672.666, 'eval_steps_per_second': 84.083, 'epoch': 2.51}


 51%|█████     | 23500/45900 [11:05<06:07, 60.98it/s]  

{'loss': 1.2652, 'grad_norm': 0.42135030031204224, 'learning_rate': 0.00024410675381263616, 'epoch': 2.56}


                                                     
 51%|█████     | 23500/45900 [11:10<06:07, 60.98it/s]

{'eval_loss': 1.2943086624145508, 'eval_runtime': 5.5856, 'eval_samples_per_second': 673.161, 'eval_steps_per_second': 84.145, 'epoch': 2.56}


 52%|█████▏    | 24000/45900 [11:19<05:53, 61.91it/s]  

{'loss': 1.2924, 'grad_norm': 0.6236741542816162, 'learning_rate': 0.00023866013071895427, 'epoch': 2.61}


                                                     
 52%|█████▏    | 24000/45900 [11:24<05:53, 61.91it/s]

{'eval_loss': 1.294661283493042, 'eval_runtime': 5.6265, 'eval_samples_per_second': 668.266, 'eval_steps_per_second': 83.533, 'epoch': 2.61}


 53%|█████▎    | 24500/45900 [11:33<05:46, 61.69it/s]  

{'loss': 1.3007, 'grad_norm': 9.954725828720257e-05, 'learning_rate': 0.00023321350762527232, 'epoch': 2.67}


                                                     
 53%|█████▎    | 24500/45900 [11:39<05:46, 61.69it/s]

{'eval_loss': 1.294481635093689, 'eval_runtime': 5.5669, 'eval_samples_per_second': 675.416, 'eval_steps_per_second': 84.427, 'epoch': 2.67}


 54%|█████▍    | 25000/45900 [11:47<05:43, 60.87it/s]  

{'loss': 1.284, 'grad_norm': 0.32376566529273987, 'learning_rate': 0.00022776688453159043, 'epoch': 2.72}


                                                     
 54%|█████▍    | 25000/45900 [11:53<05:43, 60.87it/s]

{'eval_loss': 1.294246792793274, 'eval_runtime': 5.5722, 'eval_samples_per_second': 674.783, 'eval_steps_per_second': 84.348, 'epoch': 2.72}


 56%|█████▌    | 25500/45900 [12:02<05:31, 61.46it/s]  

{'loss': 1.3557, 'grad_norm': 0.579584538936615, 'learning_rate': 0.0002223202614379085, 'epoch': 2.78}


                                                     
 56%|█████▌    | 25500/45900 [12:07<05:31, 61.46it/s]

{'eval_loss': 1.294352650642395, 'eval_runtime': 5.5667, 'eval_samples_per_second': 675.44, 'eval_steps_per_second': 84.43, 'epoch': 2.78}


 57%|█████▋    | 26000/45900 [12:16<05:23, 61.48it/s]  

{'loss': 1.3178, 'grad_norm': 0.4015756845474243, 'learning_rate': 0.00021687363834422657, 'epoch': 2.83}


                                                     
 57%|█████▋    | 26000/45900 [12:21<05:23, 61.48it/s]

{'eval_loss': 1.2939198017120361, 'eval_runtime': 5.5638, 'eval_samples_per_second': 675.801, 'eval_steps_per_second': 84.475, 'epoch': 2.83}


 58%|█████▊    | 26500/45900 [12:30<05:14, 61.66it/s]  

{'loss': 1.3658, 'grad_norm': 0.43814343214035034, 'learning_rate': 0.00021142701525054468, 'epoch': 2.89}


                                                     
 58%|█████▊    | 26500/45900 [12:35<05:14, 61.66it/s]

{'eval_loss': 1.2943748235702515, 'eval_runtime': 5.591, 'eval_samples_per_second': 672.506, 'eval_steps_per_second': 84.063, 'epoch': 2.89}


 59%|█████▉    | 27000/45900 [12:44<05:06, 61.59it/s]  

{'loss': 1.2802, 'grad_norm': 0.5247787833213806, 'learning_rate': 0.00020598039215686276, 'epoch': 2.94}


                                                     
 59%|█████▉    | 27000/45900 [12:50<05:06, 61.59it/s]

{'eval_loss': 1.2942544221878052, 'eval_runtime': 5.5986, 'eval_samples_per_second': 671.598, 'eval_steps_per_second': 83.95, 'epoch': 2.94}


 60%|█████▉    | 27500/45900 [12:58<05:00, 61.17it/s]  

{'loss': 1.3142, 'grad_norm': 0.5894718766212463, 'learning_rate': 0.00020053376906318081, 'epoch': 3.0}


                                                     
 60%|█████▉    | 27500/45900 [13:04<05:00, 61.17it/s]

{'eval_loss': 1.294346809387207, 'eval_runtime': 5.5864, 'eval_samples_per_second': 673.06, 'eval_steps_per_second': 84.132, 'epoch': 3.0}


 61%|██████    | 28000/45900 [13:12<04:52, 61.13it/s]  

{'loss': 1.263, 'grad_norm': 0.7964025139808655, 'learning_rate': 0.00019509803921568628, 'epoch': 3.05}


                                                     
 61%|██████    | 28000/45900 [13:18<04:52, 61.13it/s]

{'eval_loss': 1.2939245700836182, 'eval_runtime': 5.599, 'eval_samples_per_second': 671.552, 'eval_steps_per_second': 83.944, 'epoch': 3.05}


 62%|██████▏   | 28500/45900 [13:27<04:42, 61.58it/s]  

{'loss': 1.258, 'grad_norm': 0.4743024408817291, 'learning_rate': 0.00018965141612200438, 'epoch': 3.1}


                                                     
 62%|██████▏   | 28500/45900 [13:32<04:42, 61.58it/s]

{'eval_loss': 1.2932100296020508, 'eval_runtime': 5.5937, 'eval_samples_per_second': 672.179, 'eval_steps_per_second': 84.022, 'epoch': 3.1}


 63%|██████▎   | 29000/45900 [13:41<04:35, 61.36it/s]  

{'loss': 1.2892, 'grad_norm': 0.37552690505981445, 'learning_rate': 0.00018420479302832244, 'epoch': 3.16}


                                                     
 63%|██████▎   | 29000/45900 [13:47<04:35, 61.36it/s]

{'eval_loss': 1.292827844619751, 'eval_runtime': 5.5963, 'eval_samples_per_second': 671.872, 'eval_steps_per_second': 83.984, 'epoch': 3.16}


 64%|██████▍   | 29500/45900 [13:55<04:26, 61.55it/s]  

{'loss': 1.3619, 'grad_norm': 0.5696796178817749, 'learning_rate': 0.00017875816993464052, 'epoch': 3.21}


                                                     
 64%|██████▍   | 29500/45900 [14:01<04:26, 61.55it/s]

{'eval_loss': 1.2933014631271362, 'eval_runtime': 5.5504, 'eval_samples_per_second': 677.424, 'eval_steps_per_second': 84.678, 'epoch': 3.21}


 65%|██████▌   | 30000/45900 [14:09<04:18, 61.51it/s]  

{'loss': 1.3409, 'grad_norm': 0.7194510698318481, 'learning_rate': 0.00017332244008714598, 'epoch': 3.27}


                                                     
 65%|██████▌   | 30000/45900 [14:15<04:18, 61.51it/s]

{'eval_loss': 1.2930142879486084, 'eval_runtime': 5.5604, 'eval_samples_per_second': 676.207, 'eval_steps_per_second': 84.526, 'epoch': 3.27}


 66%|██████▋   | 30500/45900 [14:23<04:10, 61.53it/s]  

{'loss': 1.2782, 'grad_norm': 0.4689163565635681, 'learning_rate': 0.00016787581699346404, 'epoch': 3.32}


                                                     
 66%|██████▋   | 30500/45900 [14:29<04:10, 61.53it/s]

{'eval_loss': 1.2932772636413574, 'eval_runtime': 5.587, 'eval_samples_per_second': 672.99, 'eval_steps_per_second': 84.124, 'epoch': 3.32}


 68%|██████▊   | 31000/45900 [14:38<04:02, 61.32it/s]  

{'loss': 1.3225, 'grad_norm': 0.8495133519172668, 'learning_rate': 0.00016242919389978215, 'epoch': 3.38}


                                                     
 68%|██████▊   | 31000/45900 [14:43<04:02, 61.32it/s]

{'eval_loss': 1.2927476167678833, 'eval_runtime': 5.5689, 'eval_samples_per_second': 675.175, 'eval_steps_per_second': 84.397, 'epoch': 3.38}


 69%|██████▊   | 31500/45900 [14:52<03:53, 61.59it/s]  

{'loss': 1.2592, 'grad_norm': 0.4802933633327484, 'learning_rate': 0.00015698257080610023, 'epoch': 3.43}


                                                     
 69%|██████▊   | 31500/45900 [14:57<03:53, 61.59it/s]

{'eval_loss': 1.2931280136108398, 'eval_runtime': 5.5834, 'eval_samples_per_second': 673.423, 'eval_steps_per_second': 84.178, 'epoch': 3.43}


 70%|██████▉   | 32000/45900 [15:06<04:02, 57.31it/s]  

{'loss': 1.2587, 'grad_norm': 0.7223238348960876, 'learning_rate': 0.00015154684095860567, 'epoch': 3.49}


                                                     
 70%|██████▉   | 32000/45900 [15:12<04:02, 57.31it/s]

{'eval_loss': 1.292870044708252, 'eval_runtime': 5.7069, 'eval_samples_per_second': 658.852, 'eval_steps_per_second': 82.356, 'epoch': 3.49}


 71%|███████   | 32500/45900 [15:20<03:38, 61.25it/s]  

{'loss': 1.298, 'grad_norm': 0.6399106383323669, 'learning_rate': 0.00014610021786492375, 'epoch': 3.54}


                                                     
 71%|███████   | 32500/45900 [15:26<03:38, 61.25it/s]

{'eval_loss': 1.2920992374420166, 'eval_runtime': 5.5882, 'eval_samples_per_second': 672.843, 'eval_steps_per_second': 84.105, 'epoch': 3.54}


 72%|███████▏  | 33000/45900 [15:35<03:30, 61.17it/s]  

{'loss': 1.3229, 'grad_norm': 0.29755887389183044, 'learning_rate': 0.00014065359477124186, 'epoch': 3.59}


                                                     
 72%|███████▏  | 33000/45900 [15:40<03:30, 61.17it/s]

{'eval_loss': 1.2927446365356445, 'eval_runtime': 5.5859, 'eval_samples_per_second': 673.118, 'eval_steps_per_second': 84.14, 'epoch': 3.59}


 73%|███████▎  | 33500/45900 [15:49<03:22, 61.25it/s]

{'loss': 1.3254, 'grad_norm': 0.41392335295677185, 'learning_rate': 0.0001352069716775599, 'epoch': 3.65}


                                                     
 73%|███████▎  | 33500/45900 [15:54<03:22, 61.25it/s]

{'eval_loss': 1.2925008535385132, 'eval_runtime': 5.5979, 'eval_samples_per_second': 671.675, 'eval_steps_per_second': 83.959, 'epoch': 3.65}


 74%|███████▍  | 34000/45900 [16:03<03:15, 60.96it/s]

{'loss': 1.2777, 'grad_norm': 0.5700676441192627, 'learning_rate': 0.000129760348583878, 'epoch': 3.7}


                                                     
 74%|███████▍  | 34000/45900 [16:09<03:15, 60.96it/s]

{'eval_loss': 1.2928211688995361, 'eval_runtime': 5.5965, 'eval_samples_per_second': 671.847, 'eval_steps_per_second': 83.981, 'epoch': 3.7}


 75%|███████▌  | 34500/45900 [16:17<03:05, 61.31it/s]

{'loss': 1.3477, 'grad_norm': 0.38329634070396423, 'learning_rate': 0.00012431372549019608, 'epoch': 3.76}


                                                     
 75%|███████▌  | 34500/45900 [16:23<03:05, 61.31it/s]

{'eval_loss': 1.2927082777023315, 'eval_runtime': 5.5992, 'eval_samples_per_second': 671.525, 'eval_steps_per_second': 83.941, 'epoch': 3.76}


 76%|███████▋  | 35000/45900 [16:32<03:00, 60.29it/s]

{'loss': 1.2898, 'grad_norm': 0.6423847675323486, 'learning_rate': 0.00011888888888888889, 'epoch': 3.81}


                                                     
 76%|███████▋  | 35000/45900 [16:37<03:00, 60.29it/s]

{'eval_loss': 1.291896939277649, 'eval_runtime': 5.5955, 'eval_samples_per_second': 671.971, 'eval_steps_per_second': 83.996, 'epoch': 3.81}


 77%|███████▋  | 35500/45900 [16:46<02:49, 61.20it/s]

{'loss': 1.3644, 'grad_norm': 0.43859440088272095, 'learning_rate': 0.00011344226579520697, 'epoch': 3.87}


                                                     
 77%|███████▋  | 35500/45900 [16:52<02:49, 61.20it/s]

{'eval_loss': 1.2916433811187744, 'eval_runtime': 5.5946, 'eval_samples_per_second': 672.079, 'eval_steps_per_second': 84.01, 'epoch': 3.87}


 78%|███████▊  | 36000/45900 [17:00<02:42, 60.77it/s]

{'loss': 1.3344, 'grad_norm': 0.6571110486984253, 'learning_rate': 0.00010799564270152506, 'epoch': 3.92}


                                                     
 78%|███████▊  | 36000/45900 [17:06<02:42, 60.77it/s]

{'eval_loss': 1.2921282052993774, 'eval_runtime': 5.6302, 'eval_samples_per_second': 667.832, 'eval_steps_per_second': 83.479, 'epoch': 3.92}


 80%|███████▉  | 36500/45900 [17:15<02:35, 60.35it/s]

{'loss': 1.2987, 'grad_norm': 0.43901127576828003, 'learning_rate': 0.0001025599128540305, 'epoch': 3.98}


                                                     
 80%|███████▉  | 36500/45900 [17:20<02:35, 60.35it/s]

{'eval_loss': 1.2923061847686768, 'eval_runtime': 5.6313, 'eval_samples_per_second': 667.693, 'eval_steps_per_second': 83.462, 'epoch': 3.98}


 81%|████████  | 37000/45900 [17:29<02:27, 60.36it/s]

{'loss': 1.2782, 'grad_norm': 0.4973202049732208, 'learning_rate': 9.711328976034859e-05, 'epoch': 4.03}


                                                     
 81%|████████  | 37000/45900 [17:35<02:27, 60.36it/s]

{'eval_loss': 1.2920007705688477, 'eval_runtime': 5.6258, 'eval_samples_per_second': 668.352, 'eval_steps_per_second': 83.544, 'epoch': 4.03}


 82%|████████▏ | 37500/45900 [17:43<02:19, 60.34it/s]

{'loss': 1.2808, 'grad_norm': 0.5316959619522095, 'learning_rate': 9.166666666666667e-05, 'epoch': 4.08}


                                                     
 82%|████████▏ | 37500/45900 [17:49<02:19, 60.34it/s]

{'eval_loss': 1.291361689567566, 'eval_runtime': 5.6382, 'eval_samples_per_second': 666.875, 'eval_steps_per_second': 83.359, 'epoch': 4.08}


 83%|████████▎ | 38000/45900 [17:58<02:10, 60.31it/s]

{'loss': 1.3225, 'grad_norm': 0.47546839714050293, 'learning_rate': 8.622004357298475e-05, 'epoch': 4.14}


                                                     
 83%|████████▎ | 38000/45900 [18:03<02:10, 60.31it/s]

{'eval_loss': 1.2913864850997925, 'eval_runtime': 5.6274, 'eval_samples_per_second': 668.157, 'eval_steps_per_second': 83.52, 'epoch': 4.14}


 84%|████████▍ | 38500/45900 [18:12<02:02, 60.22it/s]

{'loss': 1.2836, 'grad_norm': 0.5444508790969849, 'learning_rate': 8.077342047930283e-05, 'epoch': 4.19}


                                                     
 84%|████████▍ | 38500/45900 [18:18<02:02, 60.22it/s]

{'eval_loss': 1.2912814617156982, 'eval_runtime': 5.6278, 'eval_samples_per_second': 668.108, 'eval_steps_per_second': 83.513, 'epoch': 4.19}


 85%|████████▍ | 39000/45900 [18:26<01:54, 60.27it/s]

{'loss': 1.2505, 'grad_norm': 0.321623295545578, 'learning_rate': 7.532679738562091e-05, 'epoch': 4.25}


                                                     
 85%|████████▍ | 39000/45900 [18:32<01:54, 60.27it/s]

{'eval_loss': 1.2920604944229126, 'eval_runtime': 5.6279, 'eval_samples_per_second': 668.097, 'eval_steps_per_second': 83.512, 'epoch': 4.25}


 86%|████████▌ | 39500/45900 [18:41<01:48, 58.82it/s]

{'loss': 1.339, 'grad_norm': 0.8449842929840088, 'learning_rate': 6.988017429193901e-05, 'epoch': 4.3}


                                                     
 86%|████████▌ | 39500/45900 [18:47<01:48, 58.82it/s]

{'eval_loss': 1.2916592359542847, 'eval_runtime': 5.6229, 'eval_samples_per_second': 668.691, 'eval_steps_per_second': 83.586, 'epoch': 4.3}


 87%|████████▋ | 40000/45900 [18:55<01:36, 60.97it/s]

{'loss': 1.2575, 'grad_norm': 0.4686736464500427, 'learning_rate': 6.443355119825708e-05, 'epoch': 4.36}


                                                     
 87%|████████▋ | 40000/45900 [19:01<01:36, 60.97it/s]

{'eval_loss': 1.2915854454040527, 'eval_runtime': 5.6303, 'eval_samples_per_second': 667.813, 'eval_steps_per_second': 83.477, 'epoch': 4.36}


 88%|████████▊ | 40500/45900 [19:10<01:28, 60.74it/s]

{'loss': 1.2429, 'grad_norm': 0.5715323090553284, 'learning_rate': 5.899782135076253e-05, 'epoch': 4.41}


                                                     
 88%|████████▊ | 40500/45900 [19:15<01:28, 60.74it/s]

{'eval_loss': 1.2912774085998535, 'eval_runtime': 5.5503, 'eval_samples_per_second': 677.44, 'eval_steps_per_second': 84.68, 'epoch': 4.41}


 89%|████████▉ | 41000/45900 [19:24<01:19, 61.60it/s]

{'loss': 1.2782, 'grad_norm': 0.8759264349937439, 'learning_rate': 5.355119825708061e-05, 'epoch': 4.47}


                                                     
 89%|████████▉ | 41000/45900 [19:29<01:19, 61.60it/s]

{'eval_loss': 1.2913098335266113, 'eval_runtime': 5.5458, 'eval_samples_per_second': 677.986, 'eval_steps_per_second': 84.748, 'epoch': 4.47}


 90%|█████████ | 41500/45900 [19:38<01:11, 61.91it/s]

{'loss': 1.3237, 'grad_norm': 0.3613627552986145, 'learning_rate': 4.810457516339869e-05, 'epoch': 4.52}


                                                     
 90%|█████████ | 41500/45900 [19:44<01:11, 61.91it/s]

{'eval_loss': 1.291319489479065, 'eval_runtime': 5.5953, 'eval_samples_per_second': 671.988, 'eval_steps_per_second': 83.999, 'epoch': 4.52}


 92%|█████████▏| 42000/45900 [19:52<01:03, 61.56it/s]

{'loss': 1.3144, 'grad_norm': 0.511194109916687, 'learning_rate': 4.2657952069716774e-05, 'epoch': 4.58}


                                                     
 92%|█████████▏| 42000/45900 [19:58<01:03, 61.56it/s]

{'eval_loss': 1.2907941341400146, 'eval_runtime': 5.5941, 'eval_samples_per_second': 672.137, 'eval_steps_per_second': 84.017, 'epoch': 4.58}


 93%|█████████▎| 42500/45900 [20:07<00:56, 60.60it/s]

{'loss': 1.2473, 'grad_norm': 0.23391929268836975, 'learning_rate': 3.721132897603486e-05, 'epoch': 4.63}


                                                     
 93%|█████████▎| 42500/45900 [20:12<00:56, 60.60it/s]

{'eval_loss': 1.290838360786438, 'eval_runtime': 5.5765, 'eval_samples_per_second': 674.259, 'eval_steps_per_second': 84.282, 'epoch': 4.63}


 94%|█████████▎| 43000/45900 [20:21<00:47, 61.43it/s]

{'loss': 1.3543, 'grad_norm': 0.430411159992218, 'learning_rate': 3.1775599128540304e-05, 'epoch': 4.68}


                                                     
 94%|█████████▎| 43000/45900 [20:27<00:47, 61.43it/s]

{'eval_loss': 1.290768027305603, 'eval_runtime': 5.5824, 'eval_samples_per_second': 673.545, 'eval_steps_per_second': 84.193, 'epoch': 4.68}


 95%|█████████▍| 43500/45900 [20:35<00:39, 61.27it/s]

{'loss': 1.3953, 'grad_norm': 0.3879627585411072, 'learning_rate': 2.6339869281045752e-05, 'epoch': 4.74}


                                                     
 95%|█████████▍| 43500/45900 [20:41<00:39, 61.27it/s]

{'eval_loss': 1.2906817197799683, 'eval_runtime': 5.609, 'eval_samples_per_second': 670.347, 'eval_steps_per_second': 83.793, 'epoch': 4.74}


 96%|█████████▌| 44000/45900 [20:50<00:31, 60.83it/s]

{'loss': 1.264, 'grad_norm': 0.1860610693693161, 'learning_rate': 2.0893246187363835e-05, 'epoch': 4.79}


                                                     
 96%|█████████▌| 44000/45900 [20:56<00:31, 60.83it/s]

{'eval_loss': 1.2905021905899048, 'eval_runtime': 5.6787, 'eval_samples_per_second': 662.126, 'eval_steps_per_second': 82.766, 'epoch': 4.79}


 97%|█████████▋| 44500/45900 [21:05<00:23, 60.43it/s]

{'loss': 1.3095, 'grad_norm': 0.51039719581604, 'learning_rate': 1.5446623093681917e-05, 'epoch': 4.85}


                                                     
 97%|█████████▋| 44500/45900 [21:10<00:23, 60.43it/s]

{'eval_loss': 1.2904503345489502, 'eval_runtime': 5.6789, 'eval_samples_per_second': 662.102, 'eval_steps_per_second': 82.763, 'epoch': 4.85}


 98%|█████████▊| 45000/45900 [21:19<00:14, 61.05it/s]

{'loss': 1.3574, 'grad_norm': 0.3818575441837311, 'learning_rate': 1e-05, 'epoch': 4.9}


                                                     
 98%|█████████▊| 45000/45900 [21:25<00:14, 61.05it/s]

{'eval_loss': 1.2903918027877808, 'eval_runtime': 5.6253, 'eval_samples_per_second': 668.411, 'eval_steps_per_second': 83.551, 'epoch': 4.9}


 99%|█████████▉| 45500/45900 [21:33<00:06, 60.10it/s]

{'loss': 1.3238, 'grad_norm': 0.6512754559516907, 'learning_rate': 4.553376906318083e-06, 'epoch': 4.96}


                                                     
 99%|█████████▉| 45500/45900 [21:39<00:06, 60.10it/s]

{'eval_loss': 1.2904012203216553, 'eval_runtime': 5.6093, 'eval_samples_per_second': 670.319, 'eval_steps_per_second': 83.79, 'epoch': 4.96}


100%|██████████| 45900/45900 [21:47<00:00, 35.12it/s]

{'train_runtime': 1307.0626, 'train_samples_per_second': 140.46, 'train_steps_per_second': 35.117, 'train_loss': 1.3187543216613902, 'epoch': 5.0}





TrainOutput(global_step=45900, training_loss=1.3187543216613902, metrics={'train_runtime': 1307.0626, 'train_samples_per_second': 140.46, 'train_steps_per_second': 35.117, 'total_flos': 1.201344191004672e+16, 'train_loss': 1.3187543216613902, 'epoch': 5.0})

In [21]:
lora_model.save_pretrained("./lora_gpt2")
tokenizer.save_pretrained("./lora_gpt2")


('./lora_gpt2/tokenizer_config.json',
 './lora_gpt2/special_tokens_map.json',
 './lora_gpt2/vocab.json',
 './lora_gpt2/merges.txt',
 './lora_gpt2/added_tokens.json')

In [22]:
for name, module in lora_model.named_modules():
    print(name)


base_model
base_model.model
base_model.model.transformer
base_model.model.transformer.wte
base_model.model.transformer.wpe
base_model.model.transformer.drop
base_model.model.transformer.h
base_model.model.transformer.h.0
base_model.model.transformer.h.0.ln_1
base_model.model.transformer.h.0.attn
base_model.model.transformer.h.0.attn.c_attn
base_model.model.transformer.h.0.attn.c_attn.base_layer
base_model.model.transformer.h.0.attn.c_attn.lora_dropout
base_model.model.transformer.h.0.attn.c_attn.lora_dropout.default
base_model.model.transformer.h.0.attn.c_attn.lora_A
base_model.model.transformer.h.0.attn.c_attn.lora_A.default
base_model.model.transformer.h.0.attn.c_attn.lora_B
base_model.model.transformer.h.0.attn.c_attn.lora_B.default
base_model.model.transformer.h.0.attn.c_attn.lora_embedding_A
base_model.model.transformer.h.0.attn.c_attn.lora_embedding_B
base_model.model.transformer.h.0.attn.c_attn.lora_magnitude_vector
base_model.model.transformer.h.0.attn.c_proj
base_model.model.

In [23]:

# LoRA 모델 평가
lora_results = lora_trainer.evaluate()
print(f"LoRA Model Results: {lora_results}")

100%|██████████| 470/470 [00:05<00:00, 83.34it/s]

LoRA Model Results: {'eval_loss': 1.2904026508331299, 'eval_runtime': 5.6514, 'eval_samples_per_second': 665.328, 'eval_steps_per_second': 83.166, 'epoch': 5.0}





In [31]:
# 전체 파라미터 수 계산 함수
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# 기존 GPT-2 모델 로드 및 파라미터 수 계산
baseline_model = GPT2LMHeadModel.from_pretrained(model_name).to(device)  # 기존 GPT-2 모델을 GPU로 로드
baseline_params = count_parameters(baseline_model)
print(f"Baseline GPT-2 Model Parameters: {baseline_params}")

# LoRA가 적용된 GPT-2 모델 파라미터 수 계산
lora_model = lora_model.to(device)  # LoRA 모델을 GPU로 이동
lora_params = count_parameters(lora_model)
print(f"LoRA GPT-2 Model Parameters: {lora_params}")

# 파라미터 수 비교
print("Parameter Count Comparison:")
print(f"Baseline GPT-2 Model: {baseline_params} parameters")
print(f"LoRA GPT-2 Model: {lora_params} parameters")


Baseline GPT-2 Model Parameters: 124439808
LoRA GPT-2 Model Parameters: 147456
Parameter Count Comparison:
Baseline GPT-2 Model: 124439808 parameters
LoRA GPT-2 Model: 147456 parameters


In [32]:
print(f"Pad token: {tokenizer.pad_token}, Pad token ID: {tokenizer.pad_token_id}")

Pad token: <|endoftext|>, Pad token ID: 50256


In [33]:
from datasets import load_metric
import evaluate

In [34]:
# 평가 메트릭 로드 (
bleu_metric = evaluate.load("bleu")
meteor_metric = evaluate.load("meteor")

[nltk_data] Downloading package wordnet to /home/park/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to /home/park/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /home/park/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


In [40]:
# 참조 문장 설정 
references = [
    ["The cat is on the mat.", "A cat is sitting on the mat.", "There is a cat on the mat.", "The feline is resting on the rug."],
    ["The feline lies on the carpet.", "The cat rests on the mat.", "A small cat is sitting on the mat."],
    ["A cat is on the floor.", "The mat is where the cat lies.", "The cat is lying on the floor mat."]
]



In [41]:
inputs = tokenizer("The cat is", return_tensors="pt", padding=True, truncation=True).input_ids.to(device)
attention_mask = tokenizer("The cat is", return_tensors="pt", padding=True, truncation=True).attention_mask.to(device)

# 1. GPT-2 모델로 텍스트 생성 (Beam Search 및 텍스트 길이 늘리기)
gpt2_generated_outputs = baseline_model.generate(
    inputs, 
    max_length=100,  # 텍스트 길이 늘림
    attention_mask=attention_mask, 
    pad_token_id=tokenizer.pad_token_id, 
    num_beams=5,  # Beam Search 적용
    early_stopping=True  # 적절한 위치에서 멈춤
)
gpt2_generated_text = tokenizer.decode(gpt2_generated_outputs[0], skip_special_tokens=True)

# 2. LoRA 모델로 텍스트 생성 (Beam Search 및 텍스트 길이 늘리기)
lora_generated_outputs = lora_model.generate(
    inputs, 
    max_length=100,  # 텍스트 길이 늘림
    attention_mask=attention_mask, 
    pad_token_id=tokenizer.pad_token_id, 
    num_beams=5,  # Beam Search 적용
    early_stopping=True
)
lora_generated_text = tokenizer.decode(lora_generated_outputs[0], skip_special_tokens=True)

# BLEU 및 METEOR 점수 계산

# GPT-2 모델 BLEU 및 METEOR 점수 계산
gpt2_metrics = {
    "bleu": bleu_metric.compute(predictions=[gpt2_generated_text], references=[references])["bleu"],
    "meteor": meteor_metric.compute(predictions=[gpt2_generated_text], references=[references])["meteor"],
}

# LoRA 모델 BLEU 및 METEOR 점수 계산
lora_metrics = {
    "bleu": bleu_metric.compute(predictions=[lora_generated_text], references=[references])["bleu"],
    "meteor": meteor_metric.compute(predictions=[lora_generated_text], references=[references])["meteor"],
}

# 결과 출력
print("GPT-2 Model Evaluation:")
print(f"BLEU: {gpt2_metrics['bleu']}, METEOR: {gpt2_metrics['meteor']}")

print("LoRA Model Evaluation:")
print(f"BLEU: {lora_metrics['bleu']}, METEOR: {lora_metrics['meteor']}")

GPT-2 Model Evaluation:
BLEU: 0.0, METEOR: 0.2601304945054945
LoRA Model Evaluation:
BLEU: 0.0, METEOR: 0.1756440281030445
