# Tutorial: Train PyTorch Models

In [1]:
import numpy as np
import pandas as pd
from sklearn import datasets
from sklearn.model_selection import train_test_split

In [2]:
features, labels = datasets.load_breast_cancer(return_X_y=True)
x_train, x_valid, y_train, y_valid = train_test_split(features, labels, test_size=0.2)

In [3]:
features.shape, labels.shape

((569, 30), (569,))

In [4]:
import torch
import torch.nn as nn
import torch.utils.data as D
from kuma_utils.torch import TorchTrainer, EarlyStopping, TorchLogger
from kuma_utils.torch.model_zoo import TabularNet
from kuma_utils.metrics import AUC, Accuracy
from kuma_utils.utils import sigmoid

torch_xla not found.
nvidia apex not found.


In [5]:
train_ds = D.TensorDataset(
    torch.as_tensor(x_train).float(),
    torch.as_tensor(y_train.reshape(-1, 1)).float()
)
valid_ds = D.TensorDataset(
    torch.as_tensor(x_valid).float(),
    torch.as_tensor(y_valid.reshape(-1, 1)).float()
)
train_loader = D.DataLoader(train_ds, batch_size=64, shuffle=True)
valid_loader = D.DataLoader(valid_ds, batch_size=64, shuffle=False)

In [6]:
model = TabularNet(in_features=x_train.shape[1], out_features=1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 'min', factor=0.5, patience=2, verbose=True, min_lr=1e-5)
criterion = nn.BCEWithLogitsLoss()
LogitsAcc = lambda approx, target: AUC().torch(sigmoid(approx), target)

fit_params = {
    'loader': train_loader,
    'loader_valid': valid_loader,
    'criterion': criterion,
    'optimizer': optimizer,
    'scheduler': scheduler,
    'num_epochs': 30,
    'callbacks': [EarlyStopping(5, maximize=True, skip_epoch=10)],
    'eval_metric': AUC().torch,
    'monitor_metrics': [LogitsAcc], 
    'logger': TorchLogger('logger.log', log_items=[
        'epoch',
        'loss_train', 'loss_valid', 'metric_valid', 'monitor_metrics_valid', 'earlystop'
    ], stdout=True, file=False),
    # 'calibrate_model': True, 
    # 'snapshot_path': 'results/baseline/',
    # 'resume': True
}

TorchLogger created at 20/11/26:09:21:16


In [7]:
trn = TorchTrainer(model)
trn.train(**fit_params)

09:21:16 Model is on cpu
09:21:16 [Epoch   1/ 30] loss_train=4.653262 | loss_valid=0.662130 | metric_valid=0.939241 | monitor_metrics_valid=[0.939241] 
09:21:16 [Epoch   2/ 30] loss_train=2.303672 | loss_valid=0.738571 | metric_valid=0.159494 | monitor_metrics_valid=[0.159494] 
09:21:16 [Epoch   3/ 30] loss_train=1.616207 | loss_valid=0.670362 | metric_valid=0.772152 | monitor_metrics_valid=[0.772152] 
09:21:16 [Epoch   4/ 30] loss_train=1.262613 | loss_valid=0.682979 | metric_valid=0.176492 | monitor_metrics_valid=[0.176492] 
Epoch     5: reducing learning rate of group 0 to 5.0000e-04.
09:21:16 [Epoch   5/ 30] loss_train=0.901546 | loss_valid=0.652833 | metric_valid=0.924051 | monitor_metrics_valid=[0.924051] 
09:21:16 [Epoch   6/ 30] loss_train=0.887527 | loss_valid=0.635706 | metric_valid=0.791320 | monitor_metrics_valid=[0.791320] 
09:21:16 [Epoch   7/ 30] loss_train=0.761478 | loss_valid=0.627334 | metric_valid=0.771067 | monitor_metrics_valid=[0.771067] 
09:21:16 [Epoch   8/ 30]