In [None]:
import sys
import os
from dotenv import load_dotenv

# Repository root = parent of the notebook's current working directory.
root_dir = os.path.abspath("..")

# Put the root on the import path so the `src.*` imports below resolve.
# Guarded so that re-running this cell does not keep appending duplicates.
if root_dir not in sys.path:
    sys.path.append(root_dir)

# Load environment variables from the `.env` file at the repository root.
dotenv_path = os.path.join(root_dir, ".env")
load_dotenv(dotenv_path)

In [None]:
from src.data_insert import ParquetRankDataset, ParquetResampleRankDataset
from src.model import SetRank
from src.losses import groupwise_softmax_loss
from src.metrics import hitrate_at_k

from torch.utils.data import DataLoader
import torch
from datetime import datetime

In [None]:
# Ten-way split over the shard indices 0..9.
# Fold k holds out shard (9 - k) for validation and trains on the other nine.
_NUM_SPLITS = 10


def _k_fold(fold):
    """Return ``(train_indices, valid_indices)`` for fold number ``fold``."""
    held_out = _NUM_SPLITS - 1 - fold
    train_indices = [idx for idx in range(_NUM_SPLITS) if idx != held_out]
    return train_indices, [held_out]


K_FOLD_TRAIN_0, K_FOLD_VALID_0 = _k_fold(0)
K_FOLD_TRAIN_1, K_FOLD_VALID_1 = _k_fold(1)
K_FOLD_TRAIN_2, K_FOLD_VALID_2 = _k_fold(2)
K_FOLD_TRAIN_3, K_FOLD_VALID_3 = _k_fold(3)
K_FOLD_TRAIN_4, K_FOLD_VALID_4 = _k_fold(4)
K_FOLD_TRAIN_5, K_FOLD_VALID_5 = _k_fold(5)
K_FOLD_TRAIN_6, K_FOLD_VALID_6 = _k_fold(6)
K_FOLD_TRAIN_7, K_FOLD_VALID_7 = _k_fold(7)
K_FOLD_TRAIN_8, K_FOLD_VALID_8 = _k_fold(8)
K_FOLD_TRAIN_9, K_FOLD_VALID_9 = _k_fold(9)

# K-FOLD Ensemble

In [None]:
# Parquet shard passed to both datasets as `normalization_parquet`.
NORMALIZATION = os.path.join(root_dir, "data", "train", "train_split_0.parquet")

# Training shards for the selected fold (K_FOLD_TRAIN_5 lists nine shard indices).
train_file_paths = [
    os.path.join(root_dir, "data", "train", f"train_split_{i}.parquet")
    for i in K_FOLD_TRAIN_5
]

# Validation shard for the selected fold (K_FOLD_VALID_5 lists one shard index).
valid_file_paths = [
    os.path.join(root_dir, "data", "train", f"train_split_{i}.parquet")
    for i in K_FOLD_VALID_5
]

In [None]:
# Column roles handed to the dataset classes.
LABEL_COL = 'selected'
GROUP_COL = 'ranker_id'

# Passed as `exclude_feature_cols`: the row id plus the group and label columns.
EXCLUDED_COLS = ['row_id', GROUP_COL, LABEL_COL]

# Keyword arguments common to the training and validation datasets, so both
# are guaranteed to be configured identically.
# NOTE(review): the meaning of max_rows=4096 is defined in src.data_insert
# (presumably a per-batch row cap) — confirm there.
_shared_dataset_kwargs = dict(
    exclude_feature_cols=EXCLUDED_COLS,
    label_col=LABEL_COL,
    group_col=GROUP_COL,
    max_rows=4096,
    normalization_parquet=NORMALIZATION,
)

# Training uses the resampling dataset variant; validation uses the plain one.
train_dataset_stream = ParquetResampleRankDataset(
    parquet_paths=train_file_paths,
    **_shared_dataset_kwargs,
)
valid_dataset_stream = ParquetRankDataset(
    parquet_paths=valid_file_paths,
    **_shared_dataset_kwargs,
)

In [None]:
# Optimizer settings (consumed by AdamW in the training cell).
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4

# Validation / early-stopping controls.
VAL_INTERVAL = 40000          # run validation every VAL_INTERVAL training steps
PATIENCE = 3                  # validations without improvement before stopping
BEST_VAL_LOSS = float("inf")  # mutable state: best average validation loss so far
NO_IMPROVE_COUNT = 0          # mutable state: validations since the last improvement

# Number of passes over the training loader; also used as the scheduler's T_max.
NUM_EPOCHS = 3

# Checkpoint file. Plain string literal: nothing is interpolated, so no f-prefix.
# NOTE(review): the previous comment here said "First of 10 models", but the
# file name ends in `_6` and the path cell selects K_FOLD_*_5 — confirm which
# ensemble member this checkpoint belongs to.
MODEL_NAME = "best_model_3_SETRANK_6.pt"
MODEL_OUTPUT = os.path.join(root_dir, "models", MODEL_NAME)

In [None]:
# Build the SetRank model. Its input width comes from the training dataset's
# `feature_len`; the remaining architecture sizes are fixed for this experiment.
_setrank_config = {
    "input_dim": train_dataset_stream.feature_len,
    "hidden_dim": 128,
    "num_heads": 4,
    "num_layers": 2,
}
model = SetRank(**_setrank_config)

In [None]:
# Wrap the streaming datasets. batch_size=None turns off the DataLoader's
# automatic batching, so each item the dataset yields is used as one batch.
train_loader = DataLoader(train_dataset_stream, batch_size=None, shuffle=False)
valid_loader = DataLoader(valid_dataset_stream, batch_size=None, shuffle=False)

# Device preference: Apple Silicon GPU via Metal (MPS) > CUDA > CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
# T_max is expressed in epochs, so the scheduler is stepped once per epoch at
# the bottom of the epoch loop. Stepping it per batch would make the cosine
# schedule cycle every 2 * NUM_EPOCHS batches instead of annealing over training.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=NUM_EPOCHS,  # number of epochs
)

# Create the checkpoint directory up front so the first torch.save (which
# happens only after VAL_INTERVAL steps) cannot fail on a missing folder.
os.makedirs(os.path.dirname(MODEL_OUTPUT), exist_ok=True)

# NOTE: BEST_VAL_LOSS and NO_IMPROVE_COUNT are initialised in the configuration
# cell and mutated here; re-run that cell (and the model cell) before
# re-running this one to start from a clean state.
early_stop = False  # set when patience is exhausted; ends the epoch loop too

for epoch in range(NUM_EPOCHS):
    model.train()
    total_train_loss = 0.0
    step = 0  # counts optimizer steps within the current epoch

    print(f"\n[Epoch {epoch+1}] ------------------------------------------------")
    # Each batch is (X, y, g) — presumably features, labels and group ids;
    # verify against src.data_insert.
    for X, y, g in train_loader:
        if X.shape[0] <= 1:
            print("[INFO] Skip batch with less than 2 rows.")
            continue
        X, y, g = X.to(device), y.to(device), g.to(device)

        optimizer.zero_grad()
        scores = model(X)
        loss = groupwise_softmax_loss(scores, y, g)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Gradient clipping
        optimizer.step()

        total_train_loss += loss.item()
        step += 1
        avg_loss = total_train_loss / step

        # Report the per-step training loss and the running epoch average.
        print(f"[Epoch {epoch+1} | Step {step}] Loss: {loss.item():.4f} | Running Avg: {avg_loss:.4f}")

        if step % VAL_INTERVAL == 0:
            # Periodic validation pass over the whole validation loader.
            model.eval()
            total_val_loss, total_hit3 = 0.0, 0.0
            val_steps = 0
            print(f"[Validate][Epoch {epoch+1} | Step {step}] --------------------------")
            with torch.no_grad():
                for Xv, yv, gv in valid_loader:
                    Xv, yv, gv = Xv.to(device), yv.to(device), gv.to(device)
                    val_scores = model(Xv)
                    val_loss = groupwise_softmax_loss(val_scores, yv, gv)

                    # Compute HitRate@3 for this validation batch (on CPU tensors).
                    hit3 = hitrate_at_k(val_scores.cpu(), yv.cpu(), k=3)

                    print(f"[Val] Loss: {val_loss.item():.4f} | HitRate@3: {hit3:.4f}")
                    total_val_loss += val_loss.item()
                    total_hit3 += hit3
                    val_steps += 1

            avg_val_loss = total_val_loss / val_steps if val_steps > 0 else 0.0
            avg_hit3 = total_hit3 / val_steps if val_steps > 0 else 0.0
            print(
                f"[Validate][Epoch {epoch+1} | Step {step}] "
                f"Average Val Loss: {avg_val_loss:.4f} | HitRate@3: {avg_hit3:.4f}",
            )

            if val_steps == 0:
                # The 0.0 average above is only a placeholder; treating it as a
                # real loss would record an unbeatable "best" and save a checkpoint.
                print("[WARN] Validation loader produced no batches; skipping checkpoint and early-stopping update.")
            elif avg_val_loss < BEST_VAL_LOSS:
                # New best: remember it and checkpoint the weights.
                BEST_VAL_LOSS = avg_val_loss
                NO_IMPROVE_COUNT = 0
                torch.save(model.state_dict(), MODEL_OUTPUT)
            else:
                NO_IMPROVE_COUNT += 1
                if NO_IMPROVE_COUNT >= PATIENCE:
                    print(f"[INFO] Early stopping at epoch {epoch+1}")
                    # `break` alone only leaves the batch loop; the flag makes
                    # the epoch loop stop as well.
                    early_stop = True
                    break

            model.train()

    avg_epoch_loss = total_train_loss / step if step > 0 else 0.0
    print(f"[Epoch {epoch+1} Completed] Average Train Loss: {avg_epoch_loss:.4f}")

    if early_stop:
        break

    # Advance the cosine schedule once per completed epoch (matches T_max).
    scheduler.step()