In [1]:
import sys
sys.path.append('../src')

import warnings
warnings.filterwarnings("ignore")

import logging
from bunch import Bunch
from torch.utils.data import DataLoader


from callbacks.output import Logger
from callbacks.early_stop_callback import EarlyStop
from callbacks.reduce_lr_on_plateau import ReduceLROnPlateau
from callbacks.validation import Validation

from dataset.movielens import MovieLens1MDataset, MovieLens20MDataset

from logger import initialize_logger
from modules import DeepFM

from util.data_utils import train_val_split
from util.device_utils import set_device_name, set_device_memory, get_device

from sklearn.metrics import roc_auc_score

from torch.nn import BCELoss
from torch.optim import Adam

In [2]:
initialize_logger()
set_device_name('gpu')
device = get_device()

In [3]:
def load_dataset(name):
    if '1m' == name:
        dataset_path = '../datasets/ml-1m/ratings.dat'
        dataset = MovieLens1MDataset(dataset_path=dataset_path)
    else:
        dataset_path = '../datasets/ml-20m/ratings.csv'
        dataset = MovieLens20MDataset(dataset_path=dataset_path)

    logging.info('{} dataset loaded! Shape: {}'.format(dataset_path, dataset.shape))
    logging.info('Target count: {}'.format(dataset.targets_count()))

    return dataset

In [4]:
dataset = load_dataset('1m')

2022-01-08 11:07:54,237 MainProcess root INFO ../datasets/ml-1m/ratings.dat dataset loaded! Shape: (1000209, 2)
2022-01-08 11:07:54,254 MainProcess root INFO Target count: {0.0: 424928, 1.0: 575281}


In [5]:
ps = Bunch({
        'lr': 0.001,
        'lr_factor': 0.1,
        'lr_patience': 1,
        'weight_decay': 1e-6,
        'epochs': 50,
        'embedding_size': 100,
        'units_per_layer': [300, 300],
        'dropout': 0.8,
        'batch_size': 4000,
        'train_percent': 0.7,
        'num_workers': 12,
        'features_n_values': dataset.field_dims,
        'dataset': dataset,
        'device': get_device()
    })

In [6]:
train_set, val_set = train_val_split(ps.dataset, ps.train_percent)

model = DeepFM(
    ps.features_n_values,
    ps.embedding_size,
    ps.units_per_layer,
    ps.dropout
).to(ps.device)

In [7]:
logging.info('Start training...')
model.fit(
    data_loader = DataLoader(train_set, ps.batch_size, num_workers=ps.num_workers),
    loss_fn     = BCELoss(),
    epochs      = ps.epochs,
    optimizer   = Adam(
        params       = model.parameters(),
        lr           = ps.lr,
        weight_decay = ps.weight_decay
    ),
    callbacks   = [
        Validation(
            data_loader=DataLoader(val_set, ps.batch_size, num_workers=ps.num_workers),
            metrics = {
                'val_loss': lambda y_pred, y_true: BCELoss()(y_pred, y_true).item(),
                'val_auc' : lambda y_pred, y_true: roc_auc_score(y_true.cpu().numpy(), y_pred.cpu().numpy())
            },
            each_n_epochs=1
        ),
        Logger(['time', 'epoch', 'train_loss', 'val_loss', 'val_auc', 'patience', 'lr']),
        ReduceLROnPlateau(metric='val_auc', mode='max', factor=ps.lr_factor, patience=ps.lr_patience),
        EarlyStop(metric='val_auc', mode='max', patience=3)
    ]
)

2022-01-08 11:08:00,618 MainProcess root INFO Start training...
2022-01-08 11:08:11,392 MainProcess root INFO {'time': '0:00:05.15', 'epoch': 2, 'train_loss': 96.66918170452118, 'val_loss': 0.539854884147644, 'val_auc': 0.7918233424638522, 'patience': 0, 'lr': 0.001}
2022-01-08 11:08:16,911 MainProcess root INFO {'time': '0:00:05.46', 'epoch': 3, 'train_loss': 94.94861769676208, 'val_loss': 0.5367813110351562, 'val_auc': 0.7940973364903636, 'patience': 0, 'lr': 0.001}
2022-01-08 11:08:22,475 MainProcess root INFO {'time': '0:00:05.51', 'epoch': 4, 'train_loss': 93.97449707984924, 'val_loss': 0.5357990264892578, 'val_auc': 0.7951330838029054, 'patience': 0, 'lr': 0.001}
2022-01-08 11:08:27,821 MainProcess root INFO {'time': '0:00:05.29', 'epoch': 5, 'train_loss': 93.33862218260765, 'val_loss': 0.5352303981781006, 'val_auc': 0.7957603132822509, 'patience': 0, 'lr': 0.001}
2022-01-08 11:08:33,180 MainProcess root INFO {'time': '0:00:05.30', 'epoch': 6, 'train_loss': 92.76799929141998, 'va