In [1]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, SubsetRandomSampler
from GB_SDT import GB_SDT
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------

# Data geometry: a 28x28 MNIST image flattened to one vector, ten digit classes.
input_dim, output_dim = 28 * 28, 10

# Ensemble layout handed to the GB_SDT constructor.
n_trees, depth = 4, 5

# Two separate step sizes: `lr` is for the ensemble update, `internal_lr` is
# for training each individual tree.
lr, internal_lr = 0.01, 0.001

# Coefficients forwarded to the GB_SDT constructor
# (presumably regularisation strengths -- confirm against GB_SDT).
lamda, weight_decay = 1e-3, 5e-4

# Mini-batch size for the data loaders, plus the epoch count and logging
# interval forwarded to GB_SDT.
batch_size, epochs, log_interval = 128, 50, 10

# CUDA is switched off by hand, so everything runs on the CPU; replace the
# literal with torch.cuda.is_available() to pick up a GPU automatically.
use_cuda = False
device = torch.device("cpu") if not use_cuda else torch.device("cuda")

# ---- Preprocessing ---------------------------------------------------------
# Turn each image into a tensor, then normalise it with the mean / std
# constants given below.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

# ---- Training images (downloaded into ../data when not already present) ----
full_train_dataset = datasets.MNIST(
    '../data',
    train=True,
    download=True,
    transform=transform,
)

# ---- Train / validation split ----------------------------------------------
validation_split = 0.2   # fraction of the training images held out
shuffle_dataset = True   # permute the indices before cutting
random_seed = 42         # seed for that permutation

dataset_size = len(full_train_dataset)
split = int(np.floor(dataset_size * validation_split))
indices = list(range(dataset_size))

if shuffle_dataset:
    # Seeding NumPy's global RNG first makes the permutation repeatable.
    np.random.seed(random_seed)
    np.random.shuffle(indices)

# The first `split` positions become the validation set, the remainder is
# used for training.
val_indices = indices[:split]
train_indices = indices[split:]

# ---- Samplers and loaders --------------------------------------------------
# Both subsets are drawn from the same underlying dataset object; the
# samplers restrict each loader to its own index list.
train_sampler = SubsetRandomSampler(train_indices)
validation_sampler = SubsetRandomSampler(val_indices)

train_loader = DataLoader(
    full_train_dataset,
    sampler=train_sampler,
    batch_size=batch_size,
)
validation_loader = DataLoader(
    full_train_dataset,
    sampler=validation_sampler,
    batch_size=batch_size,
)

# The official MNIST test split, iterated in its stored order.
test_loader = DataLoader(
    datasets.MNIST('../data', train=False, transform=transform),
    shuffle=False,
    batch_size=batch_size,
)

In [3]:
# Build the GB_SDT ensemble from the settings chosen in the configuration cell.
# Every argument is passed by keyword, one per line, so each setting is easy
# to spot and to change.
model = GB_SDT(
    input_dim=input_dim,
    output_dim=output_dim,
    n_trees=n_trees,
    depth=depth,
    lr=lr,
    internal_lr=internal_lr,
    lamda=lamda,
    weight_decay=weight_decay,
    epochs=epochs,
    log_interval=log_interval,
    use_cuda=use_cuda,
)

In [4]:
# Restore the trained weights from disk.
# `map_location=device` remaps every stored tensor onto the device selected in
# the configuration cell. Without it, torch.load tries to recreate each tensor
# on the device it was saved from, so a checkpoint written on a GPU machine
# fails to load when CUDA is disabled or unavailable.
# NOTE(review): torch.load unpickles the file -- only load checkpoints that
# come from a trusted source.
model.load_state_dict(
    torch.load('saved_models/xg_sdt_model_2.pth', map_location=device)
)

In [18]:
# Overwrite the `lr` attribute on the loaded model; the constructor call
# above passed lr=0.01.
# NOTE(review): presumably GB_SDT scales each tree's contribution by `lr` when
# predicting, so 1 would remove that scaling -- confirm against GB_SDT.predict
# and record why the override is needed.
model.lr = 1

In [19]:
# Switch the model to evaluation mode before running predictions.
# NOTE(review): assumes GB_SDT follows the torch.nn.Module `eval()` contract
# (disables training-only behaviour and returns the module) -- confirm.
model.eval()

In [20]:
# Run the model over the whole test set and gather, for every sample, the
# predicted class index and the true label.
#
# Results:
#   preds        -- predicted class indices, concatenated over all batches
#   real_labels  -- ground-truth labels in the same order as `preds`
# Both stay None when the loader yields no batches.
batch_preds = []
batch_labels = []

with torch.no_grad():  # inference only: no autograd graph is needed
    for X, labels in test_loader:
        scores = model.predict(X)
        # Predicted class = index of the largest score along dim 1
        # (assumes predict returns one row of class scores per sample -- confirm).
        batch_preds.append(torch.argmax(scores, dim=1))
        batch_labels.append(labels)

# Concatenate once at the end rather than re-copying the growing tensors on
# every batch.
preds = torch.cat(batch_preds, 0) if batch_preds else None
real_labels = torch.cat(batch_labels, 0) if batch_labels else None

In [16]:
# Test accuracy: the fraction of positions where the prediction matches the
# label, with the boolean matches cast to float64 before averaging.
(preds == real_labels).to(torch.float64).mean()

tensor(0.9534, dtype=torch.float64)