In [1]:
import h5py
import torch
import torch.nn.functional as F
import pandas as pd

from pathlib import Path
import sys
# Make the repository root (the parent of the notebook's working directory)
# importable so that the `src.*` imports below resolve.
repo_root = Path.cwd().parent.resolve()
# Guard the append so that re-running this cell does not pile up duplicate
# entries in sys.path.
if str(repo_root) not in sys.path:
    sys.path.append(str(repo_root))

from src.data.auction_dataset import AuctionDataset

In [2]:
from sklearn.model_selection import train_test_split
from src.data.utils import collate_auctions

# --- Data-loading configuration ---------------------------------------------
batch_size = 512
VAL_FRACTION = 0.25           # trailing share of rows held out as validation
TRAIN_SAMPLE_FRACTION = 0.05  # share of the training rows kept for this notebook
SEED = 42                     # seed for the training-row subsample

pairs = pd.read_csv('../generated/auction_indices.csv')

# shuffle=False makes this an ordered split: the leading rows of the CSV become
# the training set and the trailing VAL_FRACTION the validation set.
# scikit-learn ignores `random_state` when shuffle=False, so it is not passed
# (passing it would wrongly suggest that the split is randomized).
train_pairs, val_pairs = train_test_split(pairs, test_size=VAL_FRACTION, shuffle=False)

# Keep only a seeded random subset of the training rows.
train_pairs = train_pairs.sample(frac=TRAIN_SAMPLE_FRACTION, random_state=SEED)

train_dataset = AuctionDataset(train_pairs)

print(f"Train dataset size: {len(train_dataset)}")

train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_auctions,
    num_workers=2,
    pin_memory=True,
    persistent_workers=True,
)

Train dataset size: 2248312


In [3]:
from tqdm import tqdm
import torch

def compute_feature_stats(train_loader, output_dir='../generated/', max_batches=None):
    """
    Compute normalization statistics over the training data and save them.

    Two sets of statistics are produced:

    * per-feature mean and (population) standard deviation over all valid
      (non-padded) auction timesteps, and
    * a scalar mean and standard deviation over all non-zero modifier values.

    Each batch yielded by ``train_loader`` must be a mapping with the keys:

    * ``'auctions'``: tensor indexed as (batch, timestep, feature). A timestep
      is treated as padding when its feature 0 equals 0.
    * ``'modifier_values'``: tensor whose zero entries are treated as padding.

    Sums and sums of squares are accumulated in float64. The variance is
    computed as ``E[x^2] - E[x]^2``, which cancels catastrophically in float32
    for features whose mean is large relative to their spread; rounding can
    also push it slightly below zero, so it is clamped at 0 before the square
    root. Results are cast back to the dtype of the input tensors (or the
    default float dtype for non-floating inputs).

    Note that a feature that is constant over the data gets a std of 0;
    callers that divide by ``stds`` must guard against that.

    Args:
        train_loader: iterable of batches as described above.
        output_dir: directory that receives ``feature_stats.pt``; it is
            created if it does not exist.
        max_batches: if not None, stop after this many batches.

    Returns:
        Tuple ``(means, stds, modifier_mean, modifier_std)``. The first two
        have one entry per feature; the last two are 0-dim tensors.

    Raises:
        ValueError: if no valid auction timestep, or no non-zero modifier
            value, was seen (e.g. an empty loader or ``max_batches=0``).
    """
    feature_sum = None
    feature_sq_sum = None
    total_count = 0
    feature_dtype = None

    # Accumulators for the modifier values
    modifier_sum = None
    modifier_sq_sum = None
    modifier_count = 0
    modifier_dtype = None

    for i, batch in enumerate(tqdm(train_loader)):
        if max_batches is not None and i >= max_batches:
            break

        auctions = batch['auctions']
        modifier_values = batch['modifier_values']

        # Remember the dtype the results should be reported in.
        if feature_dtype is None:
            feature_dtype = (
                auctions.dtype if auctions.is_floating_point() else torch.get_default_dtype()
            )
            modifier_dtype = (
                modifier_values.dtype
                if modifier_values.is_floating_point()
                else torch.get_default_dtype()
            )

        # Auction features: keep only timesteps whose feature 0 is non-zero
        # (zero marks padding). Indexing (B, T, F) with a (B, T) boolean mask
        # yields the valid rows as a (total_valid, F) tensor.
        valid_rows = auctions[:, :, 0] != 0
        X_valid = auctions[valid_rows].to(torch.float64)

        batch_sum = X_valid.sum(dim=0)
        batch_sq_sum = (X_valid ** 2).sum(dim=0)
        if feature_sum is None:
            feature_sum = batch_sum
            feature_sq_sum = batch_sq_sum
        else:
            feature_sum = feature_sum + batch_sum
            feature_sq_sum = feature_sq_sum + batch_sq_sum
        total_count += X_valid.size(0)

        # Modifier values: zero entries are padding and are excluded.
        valid_modifiers = modifier_values[modifier_values != 0].to(torch.float64)

        batch_mod_sum = valid_modifiers.sum()
        batch_mod_sq_sum = (valid_modifiers ** 2).sum()
        if modifier_sum is None:
            modifier_sum = batch_mod_sum
            modifier_sq_sum = batch_mod_sq_sum
        else:
            modifier_sum = modifier_sum + batch_mod_sum
            modifier_sq_sum = modifier_sq_sum + batch_mod_sq_sum
        modifier_count += valid_modifiers.size(0)

    # Fail loudly instead of dividing by zero and persisting NaN statistics.
    if total_count == 0:
        raise ValueError(
            "No valid (non-padded) auction timesteps found; cannot compute feature statistics."
        )
    if modifier_count == 0:
        raise ValueError(
            "No non-zero modifier values found; cannot compute modifier statistics."
        )

    # Stats for auction features (population variance, clamped against
    # negative rounding error).
    means = feature_sum / total_count
    variances = torch.clamp(feature_sq_sum / total_count - means ** 2, min=0.0)
    stds = torch.sqrt(variances)

    # Stats for modifier values
    modifier_mean = modifier_sum / modifier_count
    modifier_variance = torch.clamp(
        modifier_sq_sum / modifier_count - modifier_mean ** 2, min=0.0
    )
    modifier_std = torch.sqrt(modifier_variance)

    # Report in the original precision so downstream normalization does not
    # silently promote tensors to float64.
    means = means.to(feature_dtype)
    stds = stds.to(feature_dtype)
    modifier_mean = modifier_mean.to(modifier_dtype)
    modifier_std = modifier_std.to(modifier_dtype)

    # Store in pt file
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    torch.save({
        'means': means.cpu(),
        'stds': stds.cpu(),
        'modifiers_mean': modifier_mean.cpu(),
        'modifiers_std': modifier_std.cpu()
    }, out_dir / 'feature_stats.pt')

    return means, stds, modifier_mean, modifier_std

# Compute the normalization statistics from the sampled training loader and
# persist them under the function's default output directory.
# NOTE(review): the recorded progress output shows 4392 batches, so the
# max_batches=10000 cap is not reached and the whole loader is consumed.
compute_feature_stats(train_dataloader, max_batches=10000)

100%|██████████| 4392/4392 [51:20<00:00,  1.43it/s] 


(tensor([ 8.2451e+00,  8.3870e+00,  1.0000e+00,  3.2520e+01,  1.7332e+01,
          1.2983e-02,  1.1338e-02, -1.9910e-03, -9.9918e-03]),
 tensor([ 2.4380,  2.4546,  0.0000, 18.7104, 13.4391,  0.7069,  0.7071,  0.7034,
          0.7107]),
 tensor(5.7630),
 tensor(2.3759))