In [1]:
import os
import sys
import warnings

from dotenv import load_dotenv

# Project root. This assumes the notebook is launched from a direct sub-directory of the
# repository (e.g. notebooks/) — TODO confirm if the notebook moves.
root_dir = os.path.abspath("..")

# Guard so re-running this cell does not keep appending duplicate entries to sys.path.
if root_dir not in sys.path:
    sys.path.append(root_dir)

dotenv_path = os.path.join(root_dir, ".env")
if not os.path.isfile(dotenv_path):
    warnings.warn(f"No .env file found at {dotenv_path}; environment variables will not be loaded.")

# Assign the result so the cell does not echo a bare True/False as output.
dotenv_loaded = load_dotenv(dotenv_path)

True

In [2]:
from src.data_insert import ParquetRankDataset
from src.model import RankerNN
import polars as pl

In [3]:
train = os.path.join(root_dir, "data", "train_split.parquet")
valid = os.path.join(root_dir, "data", "valid_split.parquet")

LABEL_COL = 'selected'
GROUP_COL = 'ranker_id'
# Columns that must never be fed to the model as features.
EXCLUDED_COLS = {'row_id', GROUP_COL, LABEL_COL}
# Forwarded to ParquetRankDataset as `max_rows` (presumably the max rows per streamed batch — TODO confirm).
MAX_ROWS = 4096

# Only column names are needed here, so read the parquet schema instead of loading the
# whole training split into memory. Schema keys come back in file column order.
# NOTE(review): an earlier version filled nulls in `is_preferred_airline` on an in-memory
# DataFrame that was only used for its column names; ParquetRankDataset is handed the file
# path, not that frame, so the fill never reached training. Confirm that ParquetRankDataset
# handles nulls in that column itself.
FEATURE_COLS = [c for c in pl.read_parquet_schema(train) if c not in EXCLUDED_COLS]

# Fail fast if the validation split does not carry every training feature.
missing_in_valid = set(FEATURE_COLS).difference(pl.read_parquet_schema(valid))
assert not missing_in_valid, f"valid split is missing feature columns: {sorted(missing_in_valid)}"


def make_rank_dataset(parquet_path: str) -> ParquetRankDataset:
    """Build a streaming ranking dataset over `parquet_path` using the shared column config.

    Args:
        parquet_path: path to a parquet split with the same schema as the training split.

    Returns:
        A ParquetRankDataset configured with FEATURE_COLS, LABEL_COL, GROUP_COL and MAX_ROWS.
    """
    return ParquetRankDataset(
        parquet_path=parquet_path,
        feature_cols=FEATURE_COLS,
        label_col=LABEL_COL,
        group_col=GROUP_COL,
        max_rows=MAX_ROWS,
    )


train_dataset_stream = make_rank_dataset(train)
valid_dataset_stream = make_rank_dataset(valid)

In [4]:
import torch

SEED = 42
HIDDEN_LAYERS = [512, 256, 128]
DROPOUT = 0.2

# Seed torch's global RNG before constructing the model so weight initialisation is
# reproducible under Restart & Run All (assumes RankerNN initialises its weights with
# torch's default RNG — TODO confirm).
torch.manual_seed(SEED)

n_features = len(FEATURE_COLS)
model = RankerNN(n_features=n_features, hidden_layers=HIDDEN_LAYERS, dropout=DROPOUT)

In [5]:
import torch
from torch.utils.data import DataLoader


def pairwise_loss(scores: torch.Tensor, labels: torch.Tensor, groups: torch.Tensor) -> torch.Tensor:
    """Mean RankNet-style pairwise logistic loss over within-group pairs.

    For every ordered pair (i, j) in the same group with labels[i] > labels[j], the term is
    softplus(-(scores[i] - scores[j])) = log(1 + exp(-(s_i - s_j))), which pushes the
    preferred row to score higher. Terms are averaged over all such pairs in the batch.

    NOTE(review): earlier versions of this notebook called `pairwise_loss` without defining or
    importing it, so the training cell only worked through leftover kernel state. If the
    project ships its own implementation, import that instead and delete this definition.

    Args:
        scores: model outputs with N elements (flattened, so (N,) and (N, 1) both work).
        labels: relevance labels with N elements; a higher value means more preferred.
        groups: group ids with N elements; pairs are only formed within the same group.

    Returns:
        Scalar loss tensor. When the batch contains no valid pair it is 0 but still attached
        to the autograd graph of `scores`, so calling `.backward()` on it remains safe.
    """
    scores = scores.reshape(-1)
    labels = labels.reshape(-1)
    groups = groups.reshape(-1)

    # (N, N) boolean mask of pairs (i, j): same group and i strictly preferred over j.
    # Memory is O(N^2) booleans, acceptable for batches of a few thousand rows.
    pair_mask = (groups.unsqueeze(1) == groups.unsqueeze(0)) & (labels.unsqueeze(1) > labels.unsqueeze(0))
    winners, losers = pair_mask.nonzero(as_tuple=True)
    if winners.numel() == 0:
        # Zero loss that keeps the graph, so backward() does not fail on pair-less batches.
        return scores.sum() * 0.0
    margins = scores[winners] - scores[losers]
    return torch.nn.functional.softplus(-margins).mean()

# batch_size=None disables DataLoader's automatic batching: each item the dataset yields is
# passed through unchanged (presumably ParquetRankDataset already yields (X, y, g) batches).
train_loader = DataLoader(train_dataset_stream, batch_size=None, shuffle=False)
valid_loader = DataLoader(valid_dataset_stream, batch_size=None, shuffle=False)

LEARNING_RATE = 1e-3

# Prefer the Apple Silicon GPU (Metal), then CUDA, then fall back to CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

N_EPOCHS = 3

for epoch in range(N_EPOCHS):
    # Training pass.
    model.train()
    total_train_loss = 0.0
    n_train_batches = 0
    for X, y, g in train_loader:
        X, y, g = X.to(device), y.to(device), g.to(device)

        optimizer.zero_grad()
        scores = model(X)
        loss = pairwise_loss(scores, y, g)
        loss.backward()
        optimizer.step()

        total_train_loss += loss.item()
        n_train_batches += 1

    # Validation pass: eval mode (dropout off) and no gradient tracking.
    model.eval()
    total_valid_loss = 0.0
    n_valid_batches = 0
    with torch.no_grad():
        for X, y, g in valid_loader:
            X, y, g = X.to(device), y.to(device), g.to(device)
            total_valid_loss += pairwise_loss(model(X), y, g).item()
            n_valid_batches += 1

    # Report per-batch means so values are comparable across splits and epochs;
    # max(..., 1) avoids division by zero on an empty loader.
    mean_train_loss = total_train_loss / max(n_train_batches, 1)
    mean_valid_loss = total_valid_loss / max(n_valid_batches, 1)
    print(f"Epoch {epoch+1} - Train Loss: {mean_train_loss:.4f} - Valid Loss: {mean_valid_loss:.4f}")

KeyboardInterrupt: 