In [1]:
import os

# HuggingFace `tokenizers` reads this variable; "false" disables its internal
# parallelism, which avoids the fork warning / potential deadlock once the
# DataLoaders below start worker processes (num_workers > 0).
os.environ.update(TOKENIZERS_PARALLELISM="false")

In [2]:
from utils import preprocess_text
from models.transformer import Transformer
from data.collate_fn import collate_fn
from train.lr_scheduler import NoamLR
from train.loss import loss_function
from train.train import train_loop, eval_loop

import torch
from torch.optim import Adam

from data.dataset import AihubTranslationDataset
from torch.utils.data import DataLoader

  from .autonotebook import tqdm as notebook_tqdm


#### 데이터 전처리

In [3]:
# Maximum sequence length handed to the dataset as `max_len`.
# NOTE: keep in sync with src_len / tgt_len given to the Transformer.
MAX_SEQ_LEN = 60


def _build_dataset(csv_path):
    """Create an AihubTranslationDataset with the notebook's shared settings.

    Args:
        csv_path: path to one filtered AI-Hub split CSV.

    Returns:
        AihubTranslationDataset using `preprocess_text`, `MAX_SEQ_LEN` and
        special tokens enabled.
    """
    return AihubTranslationDataset(
        csv_path=csv_path,
        preprocess_fn=preprocess_text,
        max_len=MAX_SEQ_LEN,
        add_special_tokens=True,
    )


train_ds = _build_dataset("ai_hub_dataset/train_filtered.csv")
test_ds = _build_dataset("ai_hub_dataset/test_filtered.csv")

In [4]:
# Inspect the encoded source ids of one test example (the dataset was built
# with add_special_tokens=True).
sample_src_ids = test_ds[1972]["src_ids"]
print(sample_src_ids)

tensor([    2, 29398,  2252,  2073,  4022,  5831,  3795,  4524,  2138,  6066,
         2118,  1380,  2073,  1643,  1890,  2328, 12850,  2299,  2118,  6082,
         2037,  2116,  2199, 13682,  2371,  4381,  4973,  2170,  1513,  2414,
         3857,  2145,  6032,   886,  2052, 29398,  2252,  2079,  8960,  2138,
        10750,  2227,  7305,  2886,  2062,    18,     3])


In [6]:
from transformers import AutoTokenizer

# Decode the same example with the source-side tokenizer to check that the
# ids round-trip to a readable sentence.
src_tokenizer = AutoTokenizer.from_pretrained("klue/bert-base")
decoded_src = src_tokenizer.decode(test_ds[1972]["src_ids"])
print(decoded_src)

[CLS] 손승원은 사고 직후 아무 조치를 취하지 않은 채 학동 사거리까지 150m가량 도주했으나 인근에 있던 시민과 택시 등이 손승원의 승용차를 가로막아 붙잡았다. [SEP]


#### 모델 학습

In [7]:
# Vocabulary sizes: must match the tokenizers that produced the dataset ids.
SRC_VOCAB_SIZE = 32000  # vocab_size of the "klue/bert-base" tokenizer
TGT_VOCAB_SIZE = 30522  # vocab_size of the "bert-base-uncased" tokenizer

# Training schedule.
EPOCHS = 8
BATCH_SIZE = 128

# Model width: hidden size and feed-forward inner size (passed as d_model / d_ff).
D_MODEL = 256
D_FF = 1024

In [8]:
def _make_loader(dataset, shuffle):
    """Wrap `dataset` in a DataLoader with the notebook's shared batching settings.

    Args:
        dataset: map-style dataset whose items are batched by `collate_fn`.
        shuffle: reshuffle each epoch; intended for the training split only.

    Returns:
        torch.utils.data.DataLoader over `dataset`.
    """
    return DataLoader(
        dataset=dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        pin_memory=True,
        num_workers=4,
        collate_fn=collate_fn,
    )


train_loader = _make_loader(train_ds, shuffle=True)
# Evaluation needs no shuffling. A fixed order keeps the batch composition
# identical between runs, so eval loss / perplexity are reproducible.
test_loader = _make_loader(test_ds, shuffle=False)

In [10]:
# Select the compute device: GPU when available, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Encoder-decoder Transformer configuration. src_len / tgt_len cap the
# encoder / decoder input lengths (the datasets are built with max_len=60).
model_kwargs = dict(
    src_vocab_size=SRC_VOCAB_SIZE,
    tgt_vocab_size=TGT_VOCAB_SIZE,
    src_len=60,
    tgt_len=60,
    d_model=D_MODEL,
    d_ff=D_FF,
    n_heads=8,
    num_encoder_layers=6,
    num_decoder_layers=6,
    dropout=0.3,
)

# Build the model, then move its parameters to the selected device.
model = Transformer(**model_kwargs)
model = model.to(device)

In [11]:
# Adam with the betas / eps used in "Attention Is All You Need".
# The base lr of 1 presumably acts as a unit multiplier that NoamLR rescales
# (warmup then inverse-sqrt decay); TODO confirm in train/lr_scheduler.py.
adam_betas = (0.9, 0.98)
optimizer = Adam(params=model.parameters(), lr=1, betas=adam_betas, eps=1e-9)
scheduler = NoamLR(optimizer, d_model=D_MODEL, warmup_steps=4000)

In [None]:
# Train for EPOCHS epochs. After each epoch, evaluate on the test split and
# save a checkpoint named after that epoch.
CHECKPOINT_DIR = "checkpoints"
# torch.save fails when the target directory does not exist.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}\n-------------------------------")
    # `losses` presumably holds the per-step training losses of this epoch.
    epoch_loss, losses = train_loop(
        train_loader, model, loss_function, optimizer, scheduler, device
    )
    print(f"[epoch {epoch+1}] avg_train_loss = {epoch_loss:.4f}")

    epoch_eval_loss, perplexity = eval_loop(test_loader, model, loss_function, device)
    print(
        f"[epoch {epoch+1}] avg_eval_loss = {epoch_eval_loss:.4f} | perplexity = {perplexity:.4f}"
    )

    # One file per epoch: a fixed file name would be overwritten every epoch
    # and no longer describe the weights it contains.
    checkpoint_path = os.path.join(
        CHECKPOINT_DIR, f"aihub-ko2en-transformer_{epoch + 1}_epoch.pt"
    )
    torch.save(model.state_dict(), checkpoint_path)

Epoch 1
-------------------------------


 17%|█▋        | 1000/5782 [30:33<2:31:39,  1.90s/it]

[train] epoch_end | loss 6.0795


 20%|█▉        | 1135/5782 [34:37<2:21:44,  1.83s/it]Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7430d52ec550>>
Traceback (most recent call last):
  File "/home/masang/anaconda3/envs/torchenv/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 



KeyboardInterrupt: 

#### 번역문 생성

In [26]:
from inference import get_bleu_score
import pandas as pd

In [27]:
# Reference translations of the test split (the "번역문" column).
test_csv = pd.read_csv("ai_hub_dataset/test_filtered.csv")
target_df = test_csv["번역문"]  # note: this is a pandas Series, not a DataFrame

In [28]:
# Sanity check of the metric: scoring the references against themselves
# should give a BLEU of (almost exactly) 1.0.
self_bleu = get_bleu_score(target_df, target_df)
print(self_bleu)

0.9999997417575798
