In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
import math
import unicodedata
import re
import gc
import torch.nn as nn
import torch.nn.functional as F
from transformers import get_linear_schedule_with_warmup
from torch.optim import AdamW
import CustomBert
from WordPieceTokenizer import WordPieceTokenizer as Tokenizer
import os
import matplotlib.pyplot as plt
from datasets import load_dataset
from PretrainDataset import TokenizedDataset
from PretrainDataset import CustomDataCollatorForMLM

In [2]:
def group_texts(examples, tokenizer, MAX_SEQUENCE_LENGTH):
    concatenated_text = " ".join(examples["text"])

    encoded_output = tokenizer.encode(
        concatenated_text,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        padding=True,
        add_special_tokens=True
    )
    
    return {
        "input_ids": [torch.tensor(encoded_output['input_ids'], dtype=torch.long)],
        "attention_mask": [torch.tensor(encoded_output['attention_mask'], dtype=torch.long)],
        "token_type_ids": [torch.tensor(encoded_output['token_type_ids'], dtype=torch.long)],
    }

In [3]:
tokenizer = Tokenizer(vocab_file_path="Pretrained/vocab.txt",do_lower_case=False,strip_accents=False,clean_text=True)
VOCAB_SIZE = tokenizer.get_vocab_size()
MAX_SEQUENCE_LENGTH = 128
BATCH_SIZE = 8

datasetsPath = 'datasets/'
PREPROCESSED_TEXT_DIR = f'{datasetsPath}preprocess_wiki_text'

text_files = [os.path.join(PREPROCESSED_TEXT_DIR, f) for f in os.listdir(PREPROCESSED_TEXT_DIR) if f.endswith('.txt')]

if not text_files:
    exit()
print(f'총 {len(text_files)}개의 텍스트 파일 로드 시작...')
raw_dataset = load_dataset("text", data_files={"train": text_files}, split="train")
print(f"원시 데이터셋 로드 완료. 총 {len(raw_dataset)}개의 샘플.")
    
print('데이터셋 토큰화 및 청킹 시작...')
tokenized_dataset = raw_dataset.map(
    group_texts,
    batched=True,
    num_proc=1,
    remove_columns=["text"],
    fn_kwargs={"tokenizer":tokenizer,"MAX_SEQUENCE_LENGTH":MAX_SEQUENCE_LENGTH},
    desc=f"맵핑 데이터셋 (토큰화 및 청킹, 최대 길이 {MAX_SEQUENCE_LENGTH})"
)
print(len(tokenized_dataset))

총 8개의 텍스트 파일 로드 시작...
원시 데이터셋 로드 완료. 총 702964개의 샘플.
데이터셋 토큰화 및 청킹 시작...
703


In [4]:
print("토큰화된 데이터셋 저장 중...")
tokenized_dataset_path = "datasets/tokenized_dataset"
tokenized_dataset.save_to_disk(tokenized_dataset_path)
print(f"토큰화된 데이터셋이 '{tokenized_dataset_path}'에 저장되었습니다.")

토큰화된 데이터셋 저장 중...


Saving the dataset (0/1 shards):   0%|          | 0/703 [00:00<?, ? examples/s]

토큰화된 데이터셋이 'datasets/tokenized_dataset'에 저장되었습니다.


In [5]:
train_dataset = TokenizedDataset(tokenized_dataset)

data_collator = CustomDataCollatorForMLM(tokenizer=tokenizer, mlm_probability=0.15)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=data_collator,
    num_workers=os.cpu_count()
)

print(f"PyTorch DataLoader 생성 완료. 배치 크기: {BATCH_SIZE}")
print(f"총 훈련 배치 수: {len(train_dataloader)}")

PyTorch DataLoader 생성 완료. 배치 크기: 8
총 훈련 배치 수: 88


In [6]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
MODEL_SAVE_PATH = "saves/Pretrain.pt"

HIDDEN_SIZE = 768
NUM_HIDDEN_LAYERS = 12
NUM_ATTENTION_HEADS = 12
INTERMEDIATE_SIZE = 3072
TYPE_VOCAB_SIZE = 2
DROPOUT_PROB = 0.1
if os.path.exists(MODEL_SAVE_PATH):
    model = torch.load(MODEL_SAVE_PATH,weights_only=False)
else:
    model = CustomBert.CustomBertForMaskedLM(
        vocab_size=VOCAB_SIZE,
        hidden_size=HIDDEN_SIZE,
        num_hidden_layers=NUM_HIDDEN_LAYERS,
        num_attention_heads=NUM_ATTENTION_HEADS,
        intermediate_size=INTERMEDIATE_SIZE,
        max_position_embeddings=MAX_SEQUENCE_LENGTH,
        type_vocab_size=TYPE_VOCAB_SIZE,
        dropout_prob=DROPOUT_PROB
    )
model.to(device)

num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Custom Bert 모델 초기화 완료. 총 학습 가능 파라미터 수 : {num_params}')
print(f'모델이 담긴 장치 : {device}')

AttributeError: 'collections.OrderedDict' object has no attribute 'to'

In [None]:
EPOCHS = 50
LEARNING_RATE = 2e-4
WEIGHT_DECAY = 0.01
WARMUP_STEPS = 5000
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
total_steps = len(train_dataloader) * EPOCHS
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=WARMUP_STEPS,num_training_steps=total_steps)

In [None]:
train_losses = []
scaler = torch.cuda.amp.GradScaler()

print(f"\n<--- 학습 시작 ---> ({EPOCHS} 에폭)")
model.train()

CHECKPOINT_DIR = "pretrain_checkpoints"

prev_loss = 10
for e in range(EPOCHS):
    loss_sum = 0
    progress_bar = tqdm(train_dataloader, desc=f"Pre-train Epoch {e+1}")

    for step, batch in enumerate(progress_bar):
        batch = {k: v.to(device) for k, v in batch.items()}

        for k, v in batch.items():
            if torch.isnan(v).any():
                print(f"Warning: NaN found in batch[{k}] at step {step}")
            if torch.isinf(v).any():
                print(f"Warning: Inf found in batch[{k}] at step {step}")
        
        outputs = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            token_type_ids=batch["token_type_ids"],
            labels=batch["labels"]
        )
        loss = outputs["loss"]

        scaler.scale(loss).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        optimizer.zero_grad()

        loss_sum += loss.item()
        progress_bar.set_postfix({'loss': f'{loss.item():.4f}'})

        del outputs, loss
        if 'cuda' in str(device):
            torch.cuda.empty_cache()
        gc.collect()
        
    avg_train_loss = loss_sum / len(train_dataloader)
    train_losses.append(avg_train_loss) # 에폭별 평균 손실 저장

    print(f"Pre-train Epoch {e+1} 완료. 평균 학습 손실: {avg_train_loss:.4f}")
    
    if avg_train_loss < prev_loss - 0.0001:
        torch.save(model.state_dict(),MODEL_SAVE_PATH)
        prev_loss = avg_train_loss
    
    print(f"모델 가중치 '{MODEL_SAVE_PATH}' 저장 완료.")

print("\n<--- 학습 완료 --->")

In [None]:
plt.figure(figsize=(10, 6))
plt.plot(range(1, EPOCHS + 1), train_losses, marker='o', linestyle='-', color='skyblue')
plt.title('Pre-training Learning Curve')
plt.xlabel('Epoch')
plt.ylabel('Average Training Loss')
plt.grid(True)
plt.xticks(range(1, EPOCHS + 1)) # x축 눈금 에폭 수에 맞춰 표시
plt.show()