In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import time
import json
import os
import random
from typing import Dict, Iterator, List, Tuple

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from loguru import logger
from sklearn.datasets import load_breast_cancer

from nam.config.default import defaults
from nam.types import Config
from nam.utils.args import parse_args
from nam.data.base import NAMDataset

In [3]:
config = defaults()
config

namespace(device='cpu',
          logdir='logs',
          lr=0.01,
          batch_size=1024,
          l2_regularization=0.0,
          output_regularization=0.0,
          decay_rate=0.995,
          dropout=0.5,
          feature_dropout=0.0,
          data_split=1,
          seed=1377,
          num_basis_functions=1000,
          units_multiplier=2,
          cross_val=False,
          max_checkpoints_to_keep=1,
          save_checkpoint_every_n_epochs=10,
          n_models=1,
          num_splits=3,
          fold_num=1,
          activation='exu',
          regression=False,
          debug=False,
          shallow=False,
          use_dnn=False,
          early_stopping_epochs=60,
          n_folds=5)

In [4]:
features_columns = ["longitude", "latitude", "housing_median_age", "total_rooms",  "total_bedrooms", "population", "households",                                "median_income"]
targets_column = ["median_house_value"]

In [5]:
dataset = NAMDataset(config=config,
                    csv_file='data/housing.csv',
                    features_columns=features_columns,
                    targets_column=targets_column)
dataset

<nam.data.base.NAMDataset at 0x7f8400880610>

In [6]:
dl = DataLoader(dataset, batch_size=32, shuffle=True)

In [7]:
batch = next(iter(dl))

In [8]:
batch

[tensor([[-1.2164e+02,  3.6740e+01,  3.0000e+01,  2.6280e+03,  4.4400e+02,
           1.3720e+03,  4.3200e+02,  4.1696e+00],
         [-1.1708e+02,  3.2680e+01,  2.6000e+01,  3.0710e+03,  6.1500e+02,
           2.1560e+03,  5.6800e+02,  2.9318e+00],
         [-1.1954e+02,  3.6520e+01,  1.6000e+01,  2.7030e+03,  4.1500e+02,
           1.1060e+03,  3.7200e+02,  4.2045e+00],
         [-1.1805e+02,  3.3930e+01,  3.1000e+01,  8.9400e+02,  2.0300e+02,
           8.8300e+02,  1.9000e+02,  3.6771e+00],
         [-1.2182e+02,  3.6610e+01,  2.4000e+01,  2.4370e+03,  4.3800e+02,
           1.4300e+03,  4.4400e+02,  3.8015e+00],
         [-1.1823e+02,  3.3910e+01,  3.4000e+01,  1.0600e+03,  2.7600e+02,
           1.2150e+03,  2.5000e+02,  2.0804e+00],
         [-1.1758e+02,  3.3870e+01,  1.7000e+01,  2.7720e+03,  4.4900e+02,
           1.6850e+03,  4.6100e+02,  5.0464e+00],
         [-1.1898e+02,  3.7640e+01,  1.7000e+01,  3.7690e+03,  9.0800e+02,
           1.1600e+03,  4.5300e+02,  3.0500e+00],
