In [1]:
from training_methods import train_model
from data_gen import *
from argparse import Namespace
from DLGN_enums import *
from sklearn.model_selection import train_test_split
import numpy as np
import torch


In [2]:
# Data source switch: True loads the pre-generated arrays stored under
# DATA_DIR, False synthesises a fresh dataset with gen_spherical_data.
use_dataset = True

# Seed torch before anything else in this cell so that both branches leave
# the torch RNG in a known state (previously only the synthetic branch was
# seeded). The same value is reused for the train/test split below.
SEED = 42
torch.manual_seed(SEED)

if not use_dataset:
    # Parameters forwarded to gen_spherical_data and recorded in data_config.
    dim_in = 10
    tree_depth = 3
    num_points = 10000

    X, Y = gen_spherical_data(
        depth=tree_depth,
        dim_in=dim_in,
        type_data='spherical',
        num_points=num_points,
        feat_index_start=0,
        radius=1,
    )

    # Hold out 10% of the points for testing, stratified on the labels.
    x_train, x_test, y_train, y_test = train_test_split(
        X, Y, test_size=0.1, random_state=SEED, stratify=Y
    )

    # Record the generation parameters; later cells read data_config.dim_in.
    data_config = Namespace(
        dim_in=dim_in,
        tree_depth=tree_depth,
        num_points=num_points,
        type_data='spherical',
        feat_index_start=0,
        radius=1,
    )
    data = {
        'train_data': x_train,
        'train_labels': y_train,
        'test_data': x_test,
        'test_labels': y_test,
    }

else:
    DATA_DIR = 'data/dataset2'

    # Each split is stored as <DATA_DIR>/<file stem>.npy and wrapped in a tensor.
    data = {
        key: torch.tensor(np.load(f'{DATA_DIR}/{stem}.npy'))
        for key, stem in (
            ('train_data', 'x_train'),
            ('train_labels', 'y_train'),
            ('test_data', 'x_test'),
            ('test_labels', 'y_test'),
        )
    }

    # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary Python
    # objects; only load config.npy from a trusted source.
    data_config = np.load(DATA_DIR + '/config.npy', allow_pickle=True).item()

    # The contents of config.npy are not visible here. If it holds a plain
    # dict, convert it so attribute access (data_config.dim_in) works the same
    # way as for the Namespace built in the synthetic branch.
    if isinstance(data_config, dict):
        data_config = Namespace(**data_config)

In [3]:
# Pick the compute device for training.
# GPU index 1 is preferred, but it only exists on hosts with at least two
# visible GPUs. Fall back to GPU 0 when index 1 is not visible, and to the
# CPU when CUDA is unavailable, so the device always exists on this machine.
PREFERRED_GPU_INDEX = 1
if torch.cuda.is_available():
    gpu_index = PREFERRED_GPU_INDEX if PREFERRED_GPU_INDEX < torch.cuda.device_count() else 0
    device = torch.device(f'cuda:{gpu_index}')
else:
    device = torch.device('cpu')


In [8]:
# Run configuration for ModelTypes.VN, built directly as a Namespace so the
# settings are read as attributes (config.lr, config.epochs, ...).
# Note: other cells in this notebook assign the same `config` name, so
# whichever configuration cell was executed last is the one train_model gets.
config = Namespace(
    device=device,
    model_type=ModelTypes.VN,
    num_data=len(data['train_data']),   # number of training samples
    dim_in=data_config.dim_in,          # taken from the dataset's own config
    num_hidden_nodes=[500] * 4,         # four entries of 500
    beta=10,
    loss_fn_type=LossTypes.CE,
    optimizer_type=Optim.ADAM,
    mode='pwc',
    lr_ratio=10000,
    log_features=False,
    lr=0.001,
    epochs=1000,
    use_wandb=False,
)

In [4]:
# Run configuration for ModelTypes.VT, built directly as a Namespace so the
# settings are read as attributes (config.lr, config.epochs, ...).
# Note: other cells in this notebook assign the same `config` name, so
# whichever configuration cell was executed last is the one train_model gets.
config = Namespace(
    device=device,
    model_type=ModelTypes.VT,
    num_data=len(data['train_data']),   # number of training samples
    dim_in=data_config.dim_in,          # taken from the dataset's own config
    num_hidden_nodes=[8] * 4,           # four entries of 8
    beta=30,
    mode='pwc',
    value_scale=100,
    BN=False,
    prod='op',
    feat='sf',
    loss_fn_type=LossTypes.CE,
    optimizer_type=Optim.ADAM,
    epochs=200,
    save_freq=100,
    value_freq=100,
    lr=0.001,
    vt_fit=VtFit.LOGISTIC,
    reg=1,
    use_wandb=False,
)

In [6]:
# Run configuration for ModelTypes.KERNEL, built directly as a Namespace so
# the settings are read as attributes (config.width, config.epochs, ...).
# Note: other cells in this notebook assign the same `config` name, so
# whichever configuration cell was executed last is the one train_model gets.
config = Namespace(
    device=device,
    model_type=ModelTypes.KERNEL,
    num_data=len(data['train_data']),   # number of training samples
    dim_in=data_config.dim_in,          # taken from the dataset's own config
    width=128,
    depth=4,
    beta=30,
    alpha_init=None,
    BN=True,
    feat='cf',
    weight_decay=0.01,
    train_method=KernelTrainMethod.PEGASOS,
    reg=0.001,
    loss_fn_type=LossTypes.HINGE,
    optimizer_type=Optim.ADAM,
    log_features=False,
    gates_lr=0.01,
    alpha_lr=0.1,
    epochs=1000,
    value_freq=100,
    num_iter=50000,
    threshold=0.3,
    use_wandb=False,
)


In [5]:
# Launch training with the loaded data dict and the currently bound `config`;
# the value returned by train_model is kept as `model`.
model = train_model(data, config)

  0%|          | 0/201 [00:00<?, ?it/s]

Accuracy:  0.48615
Loss before updating alphas at epoch 0  is  0.6984536426544189
Time taken to fit value net:  129.04826164245605
Accuracy:  0.76735
Test Accuracy:  0.63775
Loss after updating value_net at epoch 0  is  0.4773755151748657


Loss 0.450010:  50%|████▉     | 100/201 [02:54<00:42,  2.37it/s] 

Accuracy:  0.7864
Loss before updating alphas at epoch 100  is  0.4500099825382233
Time taken to fit value net:  126.23777484893799
Accuracy:  0.811
Test Accuracy:  0.6739
Loss after updating value_net at epoch 100  is  0.4118704149723053


Loss 0.392677: 100%|█████████▉| 200/201 [05:44<00:00,  2.33it/s]  

Accuracy:  0.8194
Loss before updating alphas at epoch 200  is  0.3926772129535675
Time taken to fit value net:  106.48081302642822
Accuracy:  0.83595
Test Accuracy:  0.6907
Loss after updating value_net at epoch 200  is  0.3706959485054016


Loss 0.361747: 100%|██████████| 201/201 [07:32<00:00,  2.25s/it]
