In [19]:
import pandas as pd
import numpy as np
import pyarrow.parquet as pq

# Import torch and print its version to verify the installation
import torch
print("torch", torch.__version__)
from torch.utils.data import Dataset, DataLoader

# Load train/val/test
train = pd.read_parquet("data/train.parquet")
val   = pd.read_parquet("data/val.parquet")
test  = pd.read_parquet("data/test.parquet")

num_features = [
    "position",
    "popularity",
    "acousticness",
    "danceability",
    "energy",
    "tempo",
    "duration_sec",
    "skip_prob",
]

target_col = "skip"

# Make sure target is numeric 0/1
for df_ in [train, val, test]:
    df_[target_col] = df_[target_col].astype(int)


train.head()


torch 2.8.0


Unnamed: 0,session_id,user_id,position,time_of_day,day_type,location,track_id,age_group,gender,country,...,genre,popularity,acousticness,danceability,energy,tempo,duration_sec,skip_prob,skip,split
0,1,1067,1,afternoon,weekend,work,2575,18-24,female,LATAM,...,electronic,35,0.72135,0.911924,0.490107,147,182,0.63,1,train
1,1,1067,2,afternoon,weekend,work,1418,18-24,female,LATAM,...,hiphop,85,0.038647,0.263042,0.396802,146,335,0.68,0,train
2,1,1067,3,afternoon,weekend,work,4203,18-24,female,LATAM,...,hiphop,1,0.575145,0.408896,0.433143,94,124,0.68,1,train
3,1,1067,4,afternoon,weekend,work,3896,18-24,female,LATAM,...,latin,21,0.538074,0.584303,0.219636,151,174,0.68,1,train
4,1,1067,6,afternoon,weekend,work,2085,18-24,female,LATAM,...,rock,66,0.085926,0.18691,0.31143,71,193,0.68,0,train


In [12]:
import torch
from torch.utils.data import Dataset

class SessionDataset(Dataset):
    def __init__(self, df, feature_cols, target_col="skip", max_len=50):
        self.feature_cols = feature_cols
        self.target_col = target_col
        self.max_len = max_len

        # group by session
        groups = df.groupby("session_id")

        self.sessions = []
        for sid, g in groups:
            g = g.sort_values("position")
            feats = g[self.feature_cols].values.astype("float32")
            target = g[self.target_col].values.astype("float32")

            # padding
            pad_len = max_len - len(g)
            if pad_len > 0:
                feats = np.vstack([feats, np.zeros((pad_len, len(feature_cols)))])
                target = np.concatenate([target, np.zeros(pad_len)])
            else:
                feats = feats[:max_len]
                target = target[:max_len]

            self.sessions.append((feats, target))

    def __len__(self):
        return len(self.sessions)

    def __getitem__(self, idx):
        feats, target = self.sessions[idx]
        return (
            torch.tensor(feats, dtype=torch.float32),
            torch.tensor(target, dtype=torch.float32)
        )


In [9]:
def collate_batch(batch):
    xs = [torch.tensor(item[0]) for item in batch]
    ys = [torch.tensor(item[1]) for item in batch]

    x_padded = torch.nn.utils.rnn.pad_sequence(xs, batch_first=True)
    y_padded = torch.nn.utils.rnn.pad_sequence(ys, batch_first=True)

    lengths = torch.tensor([len(x) for x in xs])

    return x_padded, y_padded, lengths


In [17]:
# Build datasets + loaders
feature_cols = [
    "position", "popularity", "acousticness", "danceability",
    "energy", "tempo", "duration_sec"
]

train_ds = SessionDataset(train, feature_cols, target_col="skip", max_len=50)
val_ds   = SessionDataset(val,   feature_cols, target_col="skip", max_len=50)
test_ds  = SessionDataset(test,  feature_cols, target_col="skip", max_len=50)

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, collate_fn=collate_batch)
val_loader   = DataLoader(val_ds,   batch_size=64, shuffle=False, collate_fn=collate_batch)

# Training batch
batch_x, batch_y, lengths = next(iter(train_loader))

print("\n=== üîç Batch Example ===")
print("batch_x shape:", batch_x.shape)   # (B, T, F)
print("batch_y shape:", batch_y.shape)   # (B, T)
print("lengths:", lengths[:10])

# Data sizes
print("\n=== üì¶ Dataset sizes ===")
print("Train sessions:", len(train_ds))
print("Val sessions:  ", len(val_ds))
print("Test sessions: ", len(test_ds))

# Example session
x0, y0 = train_ds[0]
print("\n=== üß™ Example Session (idx=0) ===")
print("features shape:", x0.shape)  # (seq_len, num_features)
print("labels shape:  ", y0.shape)

print("\nFirst 3 timesteps (features):")
print(x0[:3])

print("First 3 timesteps (labels):")
print(y0[:3])



=== üîç Batch Example ===
batch_x shape: torch.Size([32, 50, 7])
batch_y shape: torch.Size([32, 50])
lengths: tensor([50, 50, 50, 50, 50, 50, 50, 50, 50, 50])

=== üì¶ Dataset sizes ===
Train sessions: 19994
Val sessions:   19095
Test sessions:  16696

=== üß™ Example Session (idx=0) ===
features shape: torch.Size([50, 7])
labels shape:   torch.Size([50])

First 3 timesteps (features):
tensor([[1.0000e+00, 3.5000e+01, 7.2135e-01, 9.1192e-01, 4.9011e-01, 1.4700e+02,
         1.8200e+02],
        [2.0000e+00, 8.5000e+01, 3.8647e-02, 2.6304e-01, 3.9680e-01, 1.4600e+02,
         3.3500e+02],
        [3.0000e+00, 1.0000e+00, 5.7515e-01, 4.0890e-01, 4.3314e-01, 9.4000e+01,
         1.2400e+02]])
First 3 timesteps (labels):
tensor([1., 0., 1.])


  xs = [torch.tensor(item[0]) for item in batch]
  ys = [torch.tensor(item[1]) for item in batch]
