In [51]:
import torch
import torch.nn.functional as F
from torch_frame import stype
from torch_frame.datasets import Yandex
from torch_frame.data import DataLoader
from torch_frame.nn import (
    EmbeddingEncoder,
    FTTransformer,
    LinearBucketEncoder,
    LinearEncoder,
    LinearPeriodicEncoder,
    ResNet
)
from icecream import ic
import tqdm

In [47]:
# ---- Reproducibility -------------------------------------------------------
seed = 42

# ---- Model -----------------------------------------------------------------
# Accepted values: 'fttransformer' or 'resnet' (anything else raises later).
model_type = 'fttransformer'
# Accepted values: 'linear', 'linear_bucket' or 'periodic'.
numerical_encoder_type = 'linear'
channels = 256
num_layers = 4

# ---- Optimisation ----------------------------------------------------------
batch_size = 512
lr = 1e-3

# Whether to wrap the model with `torch.compile` before training.
# NOTE(review): this name shadows the built-in `compile()`; it is kept because
# a later cell reads it. Rename both places together if you change it.
compile = True

In [37]:
# Load the 'adult' dataset from the Yandex collection, rooted at /tmp/yandex.
dataset = Yandex(root='/tmp/yandex', name='adult')
# Debug-print the dataset object and its feature column names.
ic(dataset)
ic(dataset.feat_cols)
# NOTE(review): presumably this builds the tensor frame and column statistics
# that later cells read via `.tensor_frame` / `.col_stats` — confirm.
dataset.materialize()
# Drives the loss, metric and output width chosen in later cells.
is_classification = dataset.task_type.is_classification
# Last expression of the cell: preview the first five rows of the raw frame.
dataset.df.head(5)

ic| dataset: Yandex(name='adult')
ic| dataset.feat_cols: ['C_feature_0',
                        'C_feature_1',
                        'C_feature_2',
                        'C_feature_3',
                        'C_feature_4',
                        'C_feature_5',
                        'C_feature_6',
                        'C_feature_7',
                        'N_feature_0',
                        'N_feature_1',
                        'N_feature_2',
                        'N_feature_3',
                        'N_feature_4',
                        'N_feature_5']


Unnamed: 0,C_feature_0,C_feature_1,C_feature_2,C_feature_3,C_feature_4,C_feature_5,C_feature_6,C_feature_7,N_feature_0,N_feature_1,N_feature_2,N_feature_3,N_feature_4,N_feature_5,target_col,split_col
0,,Some-college,Never-married,,Other-relative,White,Female,United-States,19.0,140399.0,10.0,0.0,0.0,30.0,0,0
1,Private,Some-college,Married-civ-spouse,Machine-op-inspct,Husband,White,Male,United-States,50.0,158284.0,10.0,0.0,0.0,40.0,0,0
2,Private,Some-college,Married-civ-spouse,Exec-managerial,Husband,White,Male,United-States,62.0,183735.0,10.0,0.0,0.0,40.0,0,0
3,Private,HS-grad,Never-married,Adm-clerical,Not-in-family,White,Female,United-States,20.0,154781.0,9.0,0.0,0.0,40.0,0,0
4,Private,Bachelors,Never-married,Adm-clerical,Own-child,White,Female,United-States,25.0,356344.0,13.0,0.0,0.0,40.0,0,0


In [38]:
# Seed PyTorch's global RNG from the configured `seed`.
torch.manual_seed(seed)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [39]:
# Partition the dataset into train / validation / test subsets.
# NOTE(review): presumably driven by the `split_col` column visible in the
# preview of `dataset.df` — confirm.
# NOTE(review): `trrain_dataset` is a typo for `train_dataset`. The name is
# kept because the following cell reads `trrain_dataset.tensor_frame`; rename
# both occurrences together.
trrain_dataset, val_dataset, test_dataset = dataset.split()

In [40]:
# Training split: batches are reshuffled on every pass.
# (`trrain_dataset` (sic) is the name bound by the split cell.)
train_tensor_frame = trrain_dataset.tensor_frame
train_loader = DataLoader(
    train_tensor_frame,
    batch_size=batch_size,
    shuffle=True,
)

# Validation split: fixed order.
val_tensor_frame = val_dataset.tensor_frame
val_loader = DataLoader(
    val_tensor_frame,
    batch_size=batch_size,
    shuffle=False,
)

# Test split: fixed order.
test_tensor_frame = test_dataset.tensor_frame
test_loader = DataLoader(
    test_tensor_frame,
    batch_size=batch_size,
    shuffle=False,
)

In [44]:
# Resolve the configured numerical encoder through a lookup table instead of
# an if/elif chain; unknown names are rejected with the same error as before.
encoder_cls = {
    'linear': LinearEncoder,
    'linear_bucket': LinearBucketEncoder,
    'periodic': LinearPeriodicEncoder,
}.get(numerical_encoder_type)
if encoder_cls is None:
    raise ValueError(f'Unknown numerical encoder type: {numerical_encoder_type}')
numerical_encoder = encoder_cls()

# One encoder per semantic column type handled by the model.
stype_encoder_dict = {
    stype.categorical: EmbeddingEncoder(),
    stype.numerical: numerical_encoder,
}

# Classification emits one output per class; regression emits a single value.
output_channels = dataset.num_classes if is_classification else 1

In [45]:
# Build the backbone selected by `model_type` and move it to `device`.
# Both architectures receive the same hyper-parameters. `num_layers` must be
# passed to ResNet as well: it is a required constructor argument, so leaving
# it out makes the 'resnet' option fail with a TypeError.
if model_type == 'fttransformer':
    model = FTTransformer(
        channels=channels,
        out_channels=output_channels,
        num_layers=num_layers,
        col_stats=dataset.col_stats,
        col_names_dict=train_tensor_frame.col_names_dict,
        stype_encoder_dict=stype_encoder_dict
    ).to(device)
elif model_type == 'resnet':
    model = ResNet(
        channels=channels,
        out_channels=output_channels,
        num_layers=num_layers,
        col_stats=dataset.col_stats,
        col_names_dict=train_tensor_frame.col_names_dict,
        stype_encoder_dict=stype_encoder_dict
    ).to(device)
else:
    raise ValueError(f'Unknown model type: {model_type}')

In [52]:
# Optionally wrap the model with `torch.compile`; `compile` here is the
# notebook's config flag, not the built-in function.
# NOTE(review): `model` is rebound in place, so re-running this cell would
# wrap an already-wrapped model — confirm whether that matters.
if compile:
    model = torch.compile(model, dynamic=True)

# AdamW over the (possibly compiled) model's parameters.
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

def train(epoc: int) -> float:
    """Run one optimisation pass over ``train_loader``.

    Uses the notebook-level ``model``, ``optimizer``, ``device`` and
    ``is_classification``. The loss is cross-entropy for classification and
    mean squared error otherwise.

    Args:
        epoc: Epoch number, used only to label the progress bar.

    Returns:
        The per-sample mean training loss over the whole pass, i.e. each
        batch loss weighted by the number of rows in that batch.

    Raises:
        ZeroDivisionError: If ``train_loader`` yields no batches.
    """
    model.train()
    loss_accum = 0.0
    total_count = 0

    # The notebook does `import tqdm`, which binds the *module*; the progress
    # bar class is `tqdm.tqdm`. Calling the module itself raises TypeError.
    for tf in tqdm.tqdm(train_loader, desc=f'Epoch {epoc}'):
        tf = tf.to(device)
        pred = model(tf)
        if is_classification:
            loss = F.cross_entropy(pred, tf.y)
        else:
            loss = F.mse_loss(pred.view(-1), tf.y.view(-1))
        optimizer.zero_grad()
        loss.backward()
        num_rows = len(tf.y)
        # `loss` is a batch mean, so scale it back to a sum before pooling.
        loss_accum += float(loss) * num_rows
        # Accumulate (not overwrite) so the final division yields a true mean.
        total_count += num_rows
        optimizer.step()
    return loss_accum / total_count

def evaluate(loader: DataLoader) -> float:
    model.eval()
    accum = total_count = 0

    for tf in loader:
        tf = tf.to(device)
        pred = model(tf)
        if is_classification:
            pred_class = pred.argmax(dim=-1)
            accum += float((tf.y == pred_class).sum())
        else:
            accum += float(F.mse_loss(pred.view(-1), tf.y.view(-1), reduction='sum'))
        total_count += len(tf.y)

        if is_classification:
            accuracy = accum / total_count
            return accuracy
        else:
            rmse = (accum / total_count) **0.5
            return rmse