In [1]:
from torch.utils.data import DataLoader
# from torch.utils.tensorboard import SummaryWriter
from dataset import BaseDataset
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import numpy as np
from config import model_name 
from tqdm import tqdm
import os
from pathlib import Path
from evaluate import evaluate
import importlib
import datetime
import copy

try:
    # getattr 함수로 model.model_name 모듈 안에서 이름이 model_nam
    # e인 클래스(__init__.py에서서)를 가져옴
    Model = getattr(importlib.import_module(f"model.{model_name}"), model_name)
    config = getattr(importlib.import_module('config'), f"{model_name}Config")
except AttributeError:
    print(f"{model_name} not included!")
    exit()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print('Model : ', Model)
print('config : ', config)
print('device : ', device)

Model :  <class 'model.NAML.NAML'>
config :  <class 'config.NAMLConfig'>
device :  cpu


In [2]:
class EarlyStopping:
    def __init__(self, patience=config.early_stop_patience):
        # 조기 종료 기준을 위한 patience 설정
        self.patience = patience
        self.counter = 0
        self.best_loss = np.Inf

    def __call__(self, val_loss):
        """
        val_loss가 감소했는지(=성능이 좋아졌는지) 확인
        => 좋아졌다면 self.best_loss 갱신, counter=0
        => 아니라면 counter += 1
        => counter >= patience면 early_stop
        """
        if val_loss < self.best_loss:
            early_stop = False
            get_better = True
            self.counter = 0
            self.best_loss = val_loss
        else:
            get_better = False
            self.counter += 1
            if self.counter >= self.patience:
                early_stop = True
            else:
                early_stop = False
        if np.isnan(val_loss):
            early_stop = True
        return early_stop, get_better

In [3]:
# def latest_checkpoint(directory):
#     # 가장 최신 체크포인트 찾기 (특정 형식에 맞는 파일 찾기)
#     if not os.path.exists(directory):
#         return None
#     all_checkpoints = {
#         int(x.split('.')[0].split('-')[1]): x
#         for x in os.listdir(directory)
#         if (x.split('.')[0].split('-')[2] == config.candidate_type)
#         if (x.split('.')[0].split('-')[3] == config.our_type)
#         if (x.split('.')[0].split('-')[4] == config.loss_function)
#     }
#     if not all_checkpoints:
#         return None
#     return os.path.join(directory,
#                         all_checkpoints[max(all_checkpoints.keys())])


# def latest_checkpoint(directory):
#     """
#     디렉토리 내에서
#     '{experiment_data}_ep{epoch}.ckpt' 형태의
#     가장 큰 epoch 번호를 가진 체크포인트 파일을 찾아 반환.
#     """
#     if not os.path.exists(directory):
#         return None

#     all_checkpoints = {}
#     for x in os.listdir(directory):
#         # 파일명 예: MyExp_ep3.ckpt
#         if x.endswith('.ckpt'):
#             parts = x.split('_ep')
#             if len(parts) != 2:
#                 continue
#             exp_data_part = parts[0]  # 예: "MyExp"
#             ep_part = parts[1].split('.')[0]  # "3"
#             if exp_data_part == config.experiment_data:
#                 try:
#                     epoch_num = int(ep_part)
#                     all_checkpoints[epoch_num] = x
#                 except ValueError:
#                     continue

#     if not all_checkpoints:
#         return None

#     # 가장 큰 epoch 번호
#     latest_epoch = max(all_checkpoints.keys())
#     return os.path.join(directory, all_checkpoints[latest_epoch])

In [3]:
def time_since(since):
    """
    Format elapsed time string.
    """
    now = time.time()
    elapsed_time = now - since
    return time.strftime("%H:%M:%S", time.gmtime(elapsed_time))

In [4]:
def train():
    """
    모델 학습을 위한 train 함수 정의
    """

    # 결과 파일 이름 설정
    result_file = f"./results/{model_name}/{config.experiment_data}.txt"
    if not os.path.exists('checkpoint'):
        os.makedirs('checkpoint')


    # 사전 학습된 임베딩 불러오기
    try:
        pretrained_word_embedding = torch.from_numpy(
             np.load(f'{config.data_folder}/pretrained_word_embedding.npy')
             ).float()
    except FileNotFoundError:
        pretrained_word_embedding = None


    if model_name == 'DKN':
        # DKN 모델의 경우 추가 임베딩 불러오기 (entity 및 context 임베딩)
        try:
            pretrained_entity_embedding = torch.from_numpy(
                np.load(
                    f'{config.data_folder}/pretrained_entity_embedding.npy')).float()
        except FileNotFoundError:
            pretrained_entity_embedding = None

        try:
            pretrained_context_embedding = torch.from_numpy(
                np.load(
                    f'{config.data_folder}/pretrained_context_embedding.npy')).float()
        except FileNotFoundError:
            pretrained_context_embedding = None

        # 모델 초기화
        model = Model(config, pretrained_word_embedding,
                      pretrained_entity_embedding,
                      pretrained_context_embedding).to(device)
    elif model_name == 'Exp1':
        # Exp1 모델의 경우 앙상블 사용
        models = nn.ModuleList([
            Model(config, pretrained_word_embedding).to(device)
            for _ in range(config.ensemble_factor)
        ])
    elif model_name == 'Exp2':
        model = Model(config).to(device)
    else:
        model = Model(config, pretrained_word_embedding).to(device)


    if model_name != 'Exp1':
        print(model)
    else:
        print(models[0])

 
    
    dataset = BaseDataset(f'{config.data_folder}/train/{config.experiment_data}.tsv',   # behaviors_path
                          f'{config.data_folder}/news_parsed.tsv',                      # news_path
                          f'{config.data_folder}/roberta')                              # roberta_embedding_dir

    print(f"Load training dataset with size {len(dataset)}.")

    dataloader = iter(
        DataLoader(dataset,
                   batch_size=config.batch_size,
                   shuffle=True,
                   num_workers=config.num_workers,
                   drop_last=True,
                   pin_memory=True))
    


    # 손실 함수 및 옵티마이저 설정
    criterion = nn.CrossEntropyLoss()
    # criterion = nn.NLLLoss()
    optimizer = torch.optim.Adam(model.parameters(),
                                    lr=config.learning_rate)
    print(f"Loss function:{config.loss_function}, NS Type: {config.candidate_type}_{config.our_type}")

    # 조기 종료 & “최고 성능” 추적
    early_stopping = EarlyStopping()
    best_ndcg5 = -1.0  # nDCG@5는 0~1 범위가 일반적이므로 -1로 초기화

    start_time = time.time()
    loss_full = []
    exhaustion_count = 0
    step = 0

    # 체크포인트 및 결과 디렉토리 생성
    checkpoint_dir = os.path.join('./checkpoint', model_name)
    result_dir = os.path.join('./results', model_name)
    Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)
    Path(result_dir).mkdir(parents=True, exist_ok=True)

    # 가장 최근 체크포인트 불러오기 (없으면 None)
    # checkpoint_path = latest_checkpoint(checkpoint_dir)
    
    epoch_result = []
    # if checkpoint_path is not None:
    #     # 체크포인트에서 파라미터 불러오기
    #     print(f"Load saved parameters in {checkpoint_path}")
    #     checkpoint = torch.load(checkpoint_path)
    #     early_stopping(checkpoint['early_stop_value'])
    #     step = checkpoint['step']
    #     exhaustion_count = checkpoint['exhaustion_count']
    #     epoch_result = [x.split(' ') for x in checkpoint['epoch_result'].split('\n')]
    #     model.load_state_dict(checkpoint['model_state_dict'])
    #     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    #     model.train()


    # -----------------------
    #     메인 학습 루프
    # -----------------------
    total_iters = config.num_epochs * len(dataset) // config.batch_size

    for i in tqdm(range(1, total_iters + 1), desc="Training"):
        try:
            # dataloader가 반환할 수 있는 데이터(미니배치)가 남아 있는 경우
            minibatch = next(dataloader)
        except StopIteration:
            # dataloader가 데이터셋의 끝에 도달하여 반환할 데이터가 없는 경우 (한 epoch 종료)
            exhaustion_count += 1

            # 검증 데이터 평가 및 결과 저장
            model.eval()
            val_auc, val_mrr, val_ndcg5, val_ndcg10, val_acc = evaluate(
                model, f'{config.data_folder}')
            model.train()

            tqdm.write(
                f"Time {time_since(start_time)}, "
                f"epoch {exhaustion_count}, batch {i}\n"
                f"validation AUC: {val_auc:.4f}, MRR: {val_mrr:.4f}, "
                f"nDCG@5: {val_ndcg5:.4f}, nDCG@10: {val_ndcg10:.4f}, "
                f"ACC: {val_acc:.4f}"
            )
            
            print()
            print('┌─────────────┐')
            print(f'│{exhaustion_count} Epoch Done!│')
            print('└─────────────┘')
            print()

            # 로그 기록
            epoch_result.append([
                exhaustion_count,
                round(val_auc, 4),
                round(val_mrr, 4),
                round(val_ndcg5, 4),
                round(val_ndcg10, 4),
                round(val_acc, 4)
            ])
            with open(result_file, 'w') as wf:
                wf.write("Epoch\tValidation AUC\tValidation MRR\tValidation nDCG@5\tValidation nDCG@10\tValidation ACC\n")
                for row in epoch_result:
                    wf.write(f"{row[0]}\t{row[1]:.4f}\t{row[2]:.4f}\t{row[3]:.4f}\t{row[4]:.4f}\t{row[5]:.4f}\n")

            # epoch_result.append([str(val_auc),str(val_mrr),str(val_ndcg5),str(val_ndcg10)])
            # with open(result_file,'w') as wf:
            #     line = '\n'.join([ ' '.join(x) for x in epoch_result])
            #     wf.write(line)

            val_loss = -sum([val_auc, val_mrr, val_ndcg5, val_ndcg10])  # 지표 합이 커질수록 좋음 → 음수화
            early_stop, get_better = early_stopping(val_loss)

            # # === “가장 성능이 좋았을 때만” 저장 ===
            # if get_better:
            #     checkpoint_name = f"{config.experiment_data}_ep{exhaustion_count}.ckpt"
            #     save_path = os.path.join(checkpoint_dir, checkpoint_name)
            #     torch.save({
            #         'step': step,
            #         'exhaustion_count': exhaustion_count,
            #         'model_state_dict': model.state_dict(),
            #         'optimizer_state_dict': optimizer.state_dict(),
            #         'early_stop_value': val_loss,
            #         'epoch_result': '\n'.join(
            #             [' '.join(map(str, x)) for x in epoch_result]
            #         )
            #     }, save_path)
            #     print(f"Model improved; saved at {save_path}")


            # === nDCG@5가 더 좋아졌을 때만 모델 저장 ===
            if val_ndcg5 > best_ndcg5:
                best_ndcg5 = val_ndcg5
                # EarlyStopping에 사용할 loss = -nDCG@5 (클수록 좋은 지표 → 음수화해서 'loss'로 사용)
                val_loss = -val_ndcg5

                # 모델 저장
                checkpoint_name = f"{config.experiment_data}_ep{exhaustion_count}.ckpt"
                save_path = os.path.join(checkpoint_dir, checkpoint_name)
                torch.save({
                    'epoch': exhaustion_count,
                    'step': step,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'best_ndcg5': best_ndcg5
                }, save_path)
                tqdm.write(f"  >>> Model improved (nDCG@5={val_ndcg5:.4f}); saved at {save_path}")

                # EarlyStopping 업데이트
                early_stop, get_better = early_stopping(val_loss)
            else:
                # nDCG@5가 좋아지지 않았다면 EarlyStopping만 확인
                val_loss = -val_ndcg5
                early_stop, get_better = early_stopping(val_loss)

            # Early stop
            if early_stop:
                tqdm.write(f'Epoch {exhaustion_count} Done! Early stop triggered.')
                break

            # 만약 지정된 총 epoch 수에 도달했으면 중단
            if exhaustion_count == config.num_epochs:
                break

            # 다음 epoch 준비
            dataloader = iter(
                DataLoader(dataset,
                           batch_size=config.batch_size,
                           shuffle=True,
                           num_workers=config.num_workers,
                           drop_last=True,
                           pin_memory=True))
            try:
                minibatch = next(dataloader)
            except StopIteration:
                # 혹시 데이터가 아주 작아서 바로 끝날 수도 있음
                break

        # -----------------------
        #  배치 학습 (forward/backward)
        # -----------------------
        step += 1

        # 모델 예측 및 손실 계산 (forward() 메서드 실행)
        if model_name == 'LSTUR':
            y_pred = model(
                minibatch["user"], 
                minibatch["clicked_news_length"],
                minibatch["candidate_news"],
                minibatch["clicked_news"]
                )
        elif model_name == 'HiFiArk':
            y_pred, regularizer_loss = model(
                minibatch["candidate_news"],
                minibatch["clicked_news"]
                )
        elif model_name == 'TANR':
            y_pred, topic_classification_loss = model(
                minibatch["candidate_news"], 
                minibatch["clicked_news"]
                )
        else:
            y_pred = model(
                minibatch["candidate_news"], 
                minibatch["clicked_news"]
                )

        y_true = torch.zeros(len(y_pred)).long().to(device)
        loss = criterion(y_pred, y_true)
        loss_full.append(loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # 손실 값을 일정 주기로 출력
        if i % config.num_batches_show_loss == 0:
            tqdm.write(
                f"Time {time_since(start_time)}, batches {i}, "
                f"current loss {loss.item():.4f}, "
                f"average loss: {np.mean(loss_full):.4f}, "
                f"latest average loss: {np.mean(loss_full[-256:]):.4f}"
            )
            if np.isnan(loss.item()):
                break

 

    # -----------------------
    #  전체 학습 끝난 후 최종 평가 (선택)
    # -----------------------
    # print("\n=== Training Finished! Evaluating final model... ===")
    # model.eval()
    # val_auc, val_mrr, val_ndcg5, val_ndcg10 = evaluate(model, f'{config.data_folder}')
    # print(f"Final val AUC: {val_auc:.4f}, MRR: {val_mrr:.4f}, "
    #       f"nDCG@5: {val_ndcg5:.4f}, nDCG@10: {val_ndcg10:.4f}")

In [None]:
if __name__ == '__main__':
    print('Using device:', device)
    print(f'Training model {model_name}')
    train()

Using device: cuda:0
Training model NAML
NAML(
  (news_encoder): NewsEncoder(
    (text_encoders): ModuleDict(
      (title): TextEncoder(
        (word_embedding): Embedding(330900, 100, padding_idx=0)
        (CNN): Conv2d(1, 300, kernel_size=(3, 100), stride=(1, 1), padding=(1, 0))
        (additive_attention): AdditiveAttention(
          (linear): Linear(in_features=300, out_features=200, bias=True)
        )
      )
      (abstract): TextEncoder(
        (word_embedding): Embedding(330900, 100, padding_idx=0)
        (CNN): Conv2d(1, 300, kernel_size=(3, 100), stride=(1, 1), padding=(1, 0))
        (additive_attention): AdditiveAttention(
          (linear): Linear(in_features=300, out_features=200, bias=True)
        )
      )
    )
    (element_encoders): ModuleDict(
      (category): ElementEncoder(
        (embedding): Embedding(128, 100, padding_idx=0)
        (linear): Linear(in_features=100, out_features=300, bias=True)
      )
      (subcategory): ElementEncoder(
        

Training:   8%|▊         | 100/1298 [00:43<08:44,  2.28it/s]

Time 00:00:42, batches 100, current loss 1.3193, average loss: 1.3298, latest average loss: 1.3298


Training:  15%|█▌        | 200/1298 [01:24<07:58,  2.29it/s]

Time 00:01:24, batches 200, current loss 1.1507, average loss: 1.2492, latest average loss: 1.2492


Calculating vectors for news: 100%|██████████| 24060/24060 [00:40<00:00, 587.84it/s]
Calculating vectors for users: 100%|██████████| 1000/1000 [00:02<00:00, 334.45it/s]
Calculating probabilities: 100%|██████████| 10989/10989 [00:13<00:00, 796.46it/s]
Training:  20%|█▉        | 259/1298 [02:58<07:13,  2.40it/s]

Time 00:02:58, epoch 1, batch 260
validation AUC: 0.6673, MRR: 0.2540, nDCG@5: 0.2435, nDCG@10: 0.3514, ACC: 0.0919

┌─────────────┐
│1 Epoch Done!│
└─────────────┘



Training:  20%|█▉        | 259/1298 [03:01<07:13,  2.40it/s]

  >>> Model improved (nDCG@5=0.2435); saved at ./checkpoint\NAML\behaviors_user1000_ns4_cdNone_ep1.ckpt


Training:  23%|██▎       | 300/1298 [03:18<07:19,  2.27it/s]  

Time 00:03:18, batches 300, current loss 0.8647, average loss: 1.1605, latest average loss: 1.1167


Training:  31%|███       | 400/1298 [04:00<06:36,  2.26it/s]

Time 00:04:00, batches 400, current loss 0.7779, average loss: 1.0846, latest average loss: 0.9647


Training:  39%|███▊      | 500/1298 [04:42<05:51,  2.27it/s]

Time 00:04:42, batches 500, current loss 0.7003, average loss: 1.0267, latest average loss: 0.8502


Calculating vectors for news: 100%|██████████| 24060/24060 [00:48<00:00, 495.22it/s]
Calculating vectors for users: 100%|██████████| 1000/1000 [00:03<00:00, 282.60it/s]
Calculating probabilities: 100%|██████████| 10989/10989 [00:15<00:00, 702.93it/s]
Training:  40%|███▉      | 518/1298 [06:03<05:27,  2.38it/s]

Time 00:06:03, epoch 2, batch 519
validation AUC: 0.6768, MRR: 0.2648, nDCG@5: 0.2680, nDCG@10: 0.3605, ACC: 0.0868

┌─────────────┐
│2 Epoch Done!│
└─────────────┘



Training:  40%|███▉      | 518/1298 [06:05<05:27,  2.38it/s]

  >>> Model improved (nDCG@5=0.2680); saved at ./checkpoint\NAML\behaviors_user1000_ns4_cdNone_ep2.ckpt


Training:  46%|████▌     | 600/1298 [06:40<05:09,  2.25it/s]  

Time 00:06:40, batches 600, current loss 0.6956, average loss: 0.9813, latest average loss: 0.7880


Training:  54%|█████▍    | 700/1298 [07:24<04:32,  2.20it/s]

Time 00:07:24, batches 700, current loss 0.4921, average loss: 0.9436, latest average loss: 0.7484


Calculating vectors for news: 100%|██████████| 24060/24060 [00:45<00:00, 531.25it/s]
Calculating vectors for users: 100%|██████████| 1000/1000 [00:03<00:00, 322.72it/s]
Calculating probabilities: 100%|██████████| 10989/10989 [00:15<00:00, 724.19it/s]
Training:  60%|█████▉    | 777/1298 [09:06<03:41,  2.35it/s]

Time 00:09:06, epoch 3, batch 778
validation AUC: 0.6689, MRR: 0.2564, nDCG@5: 0.2572, nDCG@10: 0.3489, ACC: 0.0826

┌─────────────┐
│3 Epoch Done!│
└─────────────┘



Training:  62%|██████▏   | 800/1298 [09:16<03:49,  2.17it/s]  

Time 00:09:16, batches 800, current loss 0.8553, average loss: 0.9136, latest average loss: 0.7194


Training:  69%|██████▉   | 900/1298 [09:59<02:57,  2.24it/s]

Time 00:09:59, batches 900, current loss 0.5693, average loss: 0.8897, latest average loss: 0.7001


Training:  77%|███████▋  | 1000/1298 [10:42<02:13,  2.23it/s]

Time 00:10:42, batches 1000, current loss 0.7339, average loss: 0.8680, latest average loss: 0.6918


Calculating vectors for news: 100%|██████████| 24060/24060 [00:42<00:00, 560.35it/s]
Calculating vectors for users: 100%|██████████| 1000/1000 [00:03<00:00, 318.41it/s]
Calculating probabilities: 100%|██████████| 10989/10989 [00:14<00:00, 758.58it/s]
Training:  80%|███████▉  | 1036/1298 [12:03<03:02,  1.43it/s]


Time 00:12:03, epoch 4, batch 1037
validation AUC: 0.6764, MRR: 0.2652, nDCG@5: 0.2676, nDCG@10: 0.3608, ACC: 0.0854

┌─────────────┐
│4 Epoch Done!│
└─────────────┘

Epoch 4 Done! Early stop triggered.


: 

In [14]:
print('┌─────────────┐')
print('│1 Epoch Done!│')
print('└─────────────┘')

┌─────────────┐
│1 Epoch Done!│
└─────────────┘
