In [2]:
from time import time

import torch
import sklearn.datasets
import sklearn.preprocessing
import sklearn.model_selection
import numpy as np
import pandas as pd
import random

import onlinehd

In [48]:
# Load the ISOLET spoken-letter dataset (train: isolet1-4, test: isolet5).
def load(train_path='data/isolet1+2+3+4.data', test_path='data/isolet5.data'):
    """Read the ISOLET train/test CSVs and return row-normalized torch tensors.

    Both files are headerless CSVs. Every column except the last is a feature,
    and the last column is the class label.

    Args:
        train_path: CSV used for training. Defaults to the path that was
            previously hard-coded.
        test_path: CSV used for evaluation. Defaults to the path that was
            previously hard-coded.

    Returns:
        ``(x, x_test, y, y_test)``. Note the order: both feature tensors come
        first, then both label tensors.

        - Features are float32, with each row rescaled by sklearn's
          ``Normalizer`` (L2 norm by default).
        - Labels are int64, shifted down by one. This assumes the files use
          1-based labels (NOTE(review): confirm for any new data file).
    """
    # fetches data
    train_dataset = pd.read_csv(train_path, header=None)
    test_dataset = pd.read_csv(test_path, header=None)

    # Last column is the label; every preceding column is a feature.
    x, y = train_dataset.iloc[:, :-1], train_dataset.iloc[:, -1]
    x_test, y_test = test_dataset.iloc[:, :-1], test_dataset.iloc[:, -1]

    print("train_X shape:", x.shape, "train_y shape:", y.shape, "test_X shape:", x_test.shape, "test_y shape:", y_test.shape)

    # Normalizer rescales each row independently. Fitting it on the training
    # split therefore learns no statistics that could leak test information.
    scaler = sklearn.preprocessing.Normalizer().fit(x)
    x = scaler.transform(x)
    x_test = scaler.transform(x_test)

    x = torch.from_numpy(x).float()
    y = torch.from_numpy(np.array(y)).long()
    x_test = torch.from_numpy(x_test).float()
    y_test = torch.from_numpy(np.array(y_test)).long()

    # Shift labels so they start at 0 and can be used as class indices.
    return x, x_test, y - 1, y_test - 1

In [49]:
# Load the ISOLET train/test splits (no training happens in this cell).
print('Loading...')
x, x_test, y, y_test = load()

# Sanity-check the splits before any training run.
assert x.shape[0] == y.shape[0], 'train features/labels length mismatch'
assert x_test.shape[0] == y_test.shape[0], 'test features/labels length mismatch'
assert x.shape[1] == x_test.shape[1], 'train/test feature counts differ'

Loading...
train_X shape: (6238, 617) train_y shape: (6238,) test_X shape: (1559, 617) test_y shape: (1559,)


In [50]:
def train_parameter(x, y, x_test, y_test, lr, epochs, dim, bootstrap, seed=None):
    """Train one OnlineHD model with the given hyperparameters and score it.

    Args:
        x, y: Training features (rows are samples, so ``x.size(1)`` is the
            feature count) and integer class labels.
        x_test, y_test: Held-out features and labels, laid out like ``x``/``y``.
        lr, epochs, bootstrap: Forwarded unchanged to ``OnlineHD.fit``.
        dim: Passed as the third argument of the ``OnlineHD`` constructor,
            presumably the hypervector dimensionality (confirm in onlinehd).
        seed: Optional. If given, ``torch.manual_seed(seed)`` is called before
            the model is built, so any torch-based random initialisation is
            reproducible. ``None`` keeps the previous unseeded behaviour.

    Returns:
        ``[lr, epochs, dim, bootstrap, train_acc, test_acc, fit_seconds]``.
        The accuracies are Python floats, and ``fit_seconds`` is the
        wall-clock duration of ``fit``.
    """
    if seed is not None:
        # Seed before constructing the model so its initialisation is covered.
        torch.manual_seed(seed)

    # Labels are presumably used as class indices by OnlineHD, so the class
    # count must cover the largest label. Counting unique values would
    # under-size the model whenever some label value is absent from ``y``.
    classes = int(y.max().item()) + 1
    features = x.size(1)
    model = onlinehd.OnlineHD(classes, features, dim)

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        x = x.cuda()
        y = y.cuda()
        x_test = x_test.cuda()
        y_test = y_test.cuda()
        model = model.to('cuda')
        print('Using GPU!')

    print('Training...')
    t = time()
    model = model.fit(x, y, bootstrap=bootstrap, lr=lr, epochs=epochs)
    if use_cuda:
        # CUDA kernels launch asynchronously. Wait for them to finish so the
        # measured time covers the actual training work.
        torch.cuda.synchronize()
    t = time() - t

    print('Validating...')
    yhat = model(x)
    yhat_test = model(x_test)
    acc = (y == yhat).float().mean()
    acc_test = (y_test == yhat_test).float().mean()
    print(f'{acc = :6f}')
    print(f'{acc_test = :6f}')
    print(f'{t = :6f}')

    return [lr, epochs, dim, bootstrap, acc.item(), acc_test.item(), t]