# Tune a CNN on MNIST

This tutorial walks through using Ax to tune two hyperparameters (learning rate and momentum) for a PyTorch CNN on the MNIST dataset trained using SGD with momentum.


In [1]:
import torch
import numpy as np

from ax.plot.contour import plot_contour
from ax.plot.trace import optimization_trace_single_method
from ax.service.managed_loop import optimize
from ax.utils.notebook.plotting import render, init_notebook_plotting
from ax.utils.tutorials.cnn_utils import load_mnist, train, evaluate, CNN

init_notebook_plotting()

[INFO 05-15 13:54:14] ax.utils.notebook.plotting: Injecting Plotly library into cell. Do not overwrite or delete cell.


In [2]:
torch.manual_seed(12345)
dtype = torch.float
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
from IPython.display import display
# import FLuID as fluid
import plotly.express as px
import matplotlib.pyplot as plt
import pandas as pd
import warnings
warnings.filterwarnings('ignore')


import os 
import glob 
import datetime


# import parameters ### TODO actually extract parameters to a separate file

# import parameters.base_parameters as parameters

# print(parameters.lhasa_params)


k = 8

params = {
    
    # experiment details
    'details' : 3,                  # level of detail of the experiment (low=1,medium=2,high=3,full=4)
    
    # datafiles
'training_data_file' : 'hERG_lhasa_training',
    'test_data_file' : 'hERG_lhasa_test',
'transfer_data_file' : 'FLuID_full',
  'fluid_label_file' : 'FLuID_labels',
    
    # data sampling
   'validation_ratio': 0.2,         # ratio validation/training
     'transfer_size' : 50000 ,      # sample for the transfer data (-1 = all)
         'test_size' : -1,          # sample for the test data (-1 = all)
     'training_size' : -1,          # sample for the training data (-1 = all)

    # number of teacher/clusters (kMean)
                 'k' : k,           # number of clusters (kMean)
     'smooth_factor' : 0.05,        # level of post-clustering mixing to avoid fully biased teachers
    
    # teachers
 'teacher_algorithm' : 'rf',        # algorithm used to build the teacher models
    
    # students
 'federated_student' : 'F' + str(k),
      'student_size' : 10000,                                              # size of the student (number of labelled Cronos data used)
      'student_sizes' : [100,250,500, 1000,2500,5000,10000,25000,50000],   # sizes of the student ti study the impact of the size
 'student_algorithm' : 'rf',                                               # default algorithm used to build the student models
      'student_mode' : 'balanced',                                         # default mode used to select the student data 
    
    # random seed for reproductibility
      'random_state' : 42,

    # t-SNE settings
         'tsne_size' : 500,
   'tsne_iterations' : 1000,
    
    # replication level
    'replicate_count' : 3,
    
    # fonts
       'figure_font' : dict(family="Arial",size=14,color="black"),
 'small_figure_font' : dict(family="Arial",size=10,color="black"),

    # colors
'figure_color_scale' : [(0,"red"),(0.2,"orange"), (0.3,'yellow'),(1,'green')],
        'bar_colors' : px.colors.qualitative.Prism,
         'green_map' : plt.get_cmap('Greens')
}

base_params = params.copy()


base_params["FP_type"] = "ECFP4"
base_params["FP_radius"] = 2
base_params["FP_length"] = 2**11


base_params["regressor_layers"] = [base_params["FP_length"], 
                                   base_params["FP_length"], 
                                   base_params["FP_length"]//2**2, 
                                   base_params["FP_length"]//2**4, 
                                   1] # slightly modified from the paper to use powers of 2 for convenience
base_params["regressor_dropout"] = [0.33] # taken from paper
base_params["max_epochs"] = 100
base_params["batch_size"] = 2**7
base_params["learning_rate"] = 0.001 ### TODO Check this is correct - find this from the paper
base_params["convergence_threshold"] = 0.01

base_params["convergence_criterion"] = ""

base_params["base_checkpoint_dir"] = "model_checkpoints"
base_params["base_results_dir"] = "model_results"
os.makedirs(base_params["base_results_dir"] , exist_ok=True)

base_params["data_dir"] = "data"

for dataset in ["training_data", "test_data", "transfer_data", "validation_data", "label_table", "federated_data"]:
  base_params[dataset] = os.path.join(base_params["data_dir"], dataset + ".pkl")


FT_params = base_params.copy()



## 1. Load MNIST data
First, we need to load the MNIST data and partition it into training, validation, and test sets.

Note: this will download the dataset if necessary.

In [4]:
BATCH_SIZE = 128
### Load datasets -- TODO make this a seperate script to pull from fluid notebook

# Federated - load in transfer data
federated_data = pd.read_pickle(base_params["federated_data"])

# Clean - load in training data
clean_data = pd.read_pickle(base_params["training_data"])

# Validation - load in validation data
validation_data = pd.read_pickle(base_params["validation_data"])


# Target - load in test data
target_data = pd.read_pickle(base_params["test_data"])


#pre calculate fingerprints for all molecules

### Currently just computed within the fluid notebook


### split the data into training and validation sets

### Currently just using the split from the fluid notebook

In [5]:
# import numpy as np
# from torch.nn.modules.module import Module
# from torch import nn 
# import torch.nn.functional as F

# class Classifier(torch.nn.Module):
#     def __init__(self, layersize=[2**11, 2**11, 2**9, 2**7, 2**0], dropout=0.33):
#         super(Classifier, self).__init__()
#         self.hidden = nn.ModuleList()
#         self.batchnorm = nn.ModuleList()
#         self.dropout = dropout

#         for idx, layer in enumerate(layersize[:-2]):
#             self.hidden.append(nn.Linear(layersize[idx], layersize[idx+1]))
#             self.batchnorm.append(nn.BatchNorm1d(layersize[idx+1]))

#         self.output = nn.Linear(layersize[-2], layersize[-1])  # output layer for binary classification


#         # save names for each layer
#         for idx, layer in enumerate(self.hidden):
#             self.hidden[idx].name = f"hidden_{idx}"
#         for idx, layer in enumerate(self.batchnorm):
#             self.batchnorm[idx].name = f"batchnorm_{idx}"
#         self.output.name = "output"



#     def forward(self, x):
#         for idx, layer in enumerate(self.hidden):
#             # print(f"hidden layer {idx} output shape: {x.shape}")
#             x = F.relu(self.hidden[idx](x))
#             x = F.dropout(x, self.dropout, training=self.training)
#             x = self.batchnorm[idx](x)

#             if idx == len(self.hidden) - 1:
#                 last_hidden = x  # save activation of last hidden layer
#         # print(f"output layer input shape: {x.shape}")
#         # print(f"output layer output shape: {self.output(x).shape}")
#         output = torch.sigmoid(self.output(x))  # apply sigmoid activation to output layer for binary classification
    
#         return output




In [6]:
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, dataframe):
        self.dataframe = dataframe
        self.fp = self.dataframe['FP'].to_numpy()
        self.labels = self.dataframe['CLASS'].to_numpy()

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, index):
        x = torch.tensor(self.fp[index], dtype=torch.float32, device=device)
        y = torch.tensor(self.labels[index], dtype=torch.float32, device=device)
        return x, y


In [7]:
# # broad tuning on federated dataset
# FT_params["experiment_name"] = "broad_tuning"
# FT_params["checkpoint_dir"] = os.path.join(FT_params["base_checkpoint_dir"], FT_params["experiment_name"])
# # make directory for checkpoints
# os.makedirs(FT_params["checkpoint_dir"], exist_ok=True)

from torch.utils.data import DataLoader

# create a dataloader for the federated data

federated_loader = MyDataset(federated_data)
validation_loader = MyDataset(validation_data)
training_loader = MyDataset(clean_data)
testing_loader = MyDataset(target_data)

N=25000
federated_loader = DataLoader(federated_loader, batch_size=N, shuffle=True)

N = BATCH_SIZE

validation_loader = DataLoader(validation_loader, batch_size=N, shuffle=True)
training_loader = DataLoader(training_loader, batch_size=N, shuffle=True)
testing_loader = DataLoader(testing_loader, batch_size=N, shuffle=True)

In [8]:
import torch
import torch.optim as optim

import numpy as np
from torch.nn.modules.module import Module
from torch import nn 
import torch.nn.functional as F
from typing import Dict, List, Optional, Tuple



def train(
    net: torch.nn.Module,
    train_loader: DataLoader,
    parameters: Dict[str, float],
    dtype: torch.dtype,
    device: torch.device,
) -> nn.Module:
    """
    Train CNN on provided data set.

    Args:
        net: initialized neural network
        train_loader: DataLoader containing training set
        parameters: dictionary containing parameters to be passed to the optimizer.
            - lr: default (0.001)
            - momentum: default (0.0)
            - weight_decay: default (0.0)
            - num_epochs: default (1)
        dtype: torch dtype
        device: torch device
    Returns:
        nn.Module: trained CNN.
    """
    # Initialize network
    net.to(dtype=dtype, device=device)
    net.train()
    # Define loss and optimizer
    criterion = nn.BCELoss() # CE for classifcation
    # optimizer = optim.SGD(
    #     net.parameters(),
    #     lr=parameters.get("lr", 0.001),
    #     momentum=parameters.get("momentum", 0.0),
    #     weight_decay=parameters.get("weight_decay", 0.0),
    # )

    optimizer = optim.Adam(
        net.parameters(),
        lr=parameters.get("lr", 1e-3),
        weight_decay=parameters.get("weight_decay", 0.0),
        # betas=(parameters.get("beta1", 0.9), parameters.get("beta2", 0.999)),
    )

    scheduler = optim.lr_scheduler.StepLR(
        optimizer,
        step_size=int(parameters.get("step_size", 30)),
        gamma=parameters.get("gamma", 1.0),  # default is no learning rate decay
    )

    num_epochs = parameters.get("num_epochs", 10)

    # Train Network
    # pyre-fixme[6]: Expected `int` for 1st param but got `float`.
    for _ in range(num_epochs):
        for inputs, labels in train_loader:
            # move data to proper dtype and device
            inputs = inputs.to(dtype=dtype, device=device)
            labels = labels.to(device=device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)

            labels = labels.reshape(outputs.shape)

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            scheduler.step()
    return net

In [9]:

def compute_BAC(predictions, labels):
    """Computes the Balanced Accuracy for the given predictions and labels."""
    # Convert predictions to binary values (0 or 1)
    binary_predictions = torch.round(predictions)

    # Compute true positives, true negatives, false positives, and false negatives
    true_positives = torch.sum(torch.logical_and(binary_predictions == 1, labels == 1))
    true_negatives = torch.sum(torch.logical_and(binary_predictions == 0, labels == 0))
    false_positives = torch.sum(torch.logical_and(binary_predictions == 1, labels == 0))
    false_negatives = torch.sum(torch.logical_and(binary_predictions == 0, labels == 1))

    # print("True Positives: ", true_positives)
    # print("True Negatives: ", true_negatives)
    # print("False Positives: ", false_positives)
    # print("False Negatives: ", false_negatives)


    # Compute balanced accuracy using torch tensor division
    sensitivity = true_positives.float() / (true_positives + false_negatives).float()
    specificity = true_negatives.float() / (true_negatives + false_positives).float()
    balanced_accuracy = (sensitivity + specificity) / 2

    return balanced_accuracy

def compute_MCC(predictions, labels):
    """Computes the Matthews Correlation Coefficient (MCC) for the given predictions and labels."""
    # Convert predictions to binary values (0 or 1)
    binary_predictions = torch.round(predictions)

    # Compute true positives, true negatives, false positives, and false negatives
    tp = torch.sum((binary_predictions == 1) & (labels == 1))
    tn = torch.sum((binary_predictions == 0) & (labels == 0))
    fp = torch.sum((binary_predictions == 1) & (labels == 0))
    fn = torch.sum((binary_predictions == 0) & (labels == 1))

    # Compute numerator and denominator of MCC equation
    numerator = tp * tn - fp * fn
    denominator = torch.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))

    # Check for division by zero
    if denominator == 0:
        return 0.01

    # Compute MCC
    mcc = numerator / denominator

    return mcc


In [10]:


def evaluate(
    net: nn.Module, data_loader: DataLoader, dtype: torch.dtype, device: torch.device
) -> float:
    """
    Compute classification accuracy on provided dataset.

    Args:
        net: trained model
        data_loader: DataLoader containing the evaluation set
        dtype: torch dtype
        device: torch device
    Returns:
        float: classification accuracy
    """
    net.eval()
    correct = 0
    total = 0
    BAC_count = []
    with torch.no_grad():
        for inputs, labels in data_loader:
            
            # move data to proper dtype and device
            inputs = inputs.to(dtype=dtype, device=device)
            labels = labels.to(device=device)
            outputs = net(inputs)

            # print(outputs)

            predicted = outputs.squeeze()  # Round the outputs to obtain binary predictions
        
            BAC_acc = compute_MCC(predicted, labels)

            correct += (predicted.round() == labels).sum().item()
            total += labels.size(0)

            # _, predicted = torch.max(outputs.data, 1)
            # print(_, predicted)
            # total += labels.size(0)
            # correct += (predicted == labels).sum().item()
            BAC_count.append(BAC_acc)
        BAC_count = np.mean(BAC_count)
        accuracy = correct / total
        print(f"Accuracy: {accuracy:.2f}", f"MCC: {BAC_count:.2f}")
    return BAC_count


In [11]:
from DomAdpQSAR.QSARsrgan import DomAdpQSARSRGAN
from DomAdpQSAR.QSARsettings import Settings
settings = Settings()
settings.logs_directory = "AXtuning"

In [12]:
# model = DomAdpQSARSRGAN(settings)
# model.model_setup()
# model.dataset_setup()
# test_data = model.evaluation_epoch(model.DNN, dataset=model.test_dataset)


In [15]:
def train_evaluate(parameterization):

    # layer1 = 2**11
    # layer2 = int(parameterization.get("layer2", 2**9))
    # layer3 = int(parameterization.get("layer3", 2**7))
    # layer4 = int(parameterization.get("layer4", 2**5))
    # layer5 = 2**0




    layer1 = 2**11
    layer2 = 2**int(parameterization.get("layer2", 9))
    layer3 = 2**int(parameterization.get("layer3", 7))
    layer4 = 2**int(parameterization.get("layer4", 5))
    layer5 = 2**0




    layers = [layer1, layer2, layer3, layer4, layer5]
    

    settings.learning_rate = parameterization.get("lr", 1e-4)
    settings.epochs_to_run = parameterization.get("num_epochs", 100)
    # settings.generator_layer_sizes
    settings.use_feature_angle = parameterization.get("use_feature_angle", True)
    
    settings.labeled_loss_multiplier = parameterization.get("labeled_loss_multiplier", 1e1)
    settings.matching_loss_multiplier = parameterization.get("matching_loss_multiplier", 1e0)
    settings.contrasting_loss_multiplier = parameterization.get("contrasting_loss_multiplier", 1e0)
    settings.srgan_loss_multiplier = parameterization.get("srgan_loss_multiplier", 1e1)

    settings.gradient_penalty_multiplier = parameterization.get("gradient_penalty_multiplier", 1e4)
    settings.normalize_feature_norm = parameterization.get("normalize_feature_norm", True)


    settings.layer_sizes = layers


    gan = DomAdpQSARSRGAN(settings)
    # federated_loader = 

    gan.federated_dataframe = gan.validation_dataframe
    gan.train()
    net = gan.D
    # print(test_data.MCC.detach().cpu().numpy())
    return evaluate(
        net=net,
        data_loader=testing_loader,
        dtype=dtype,
        device=device,
    )

## 2. Define function to optimize
In this tutorial, we want to optimize classification accuracy on the validation set as a function of the learning rate and momentum. The function takes in a parameterization (set of parameter values), computes the classification accuracy, and returns a dictionary of metric name ('accuracy') to a tuple with the mean and standard error.

In [16]:
# def train_evaluate(parameterization):
#     net = Classifier()

#     FT_params["experiment_name"] = "broad_tuning"
#     FT_params["checkpoint_dir"] = os.path.join(FT_params["base_checkpoint_dir"], FT_params["experiment_name"])

#     # load model
#     net = load_model(net, FT_params["checkpoint_dir"], latest=True)


#     net = train(net=net, train_loader=training_loader, parameters=parameterization, dtype=dtype, device=device)
#     return evaluate(
#         net=net,
#         data_loader=validation_loader,
#         dtype=dtype,
#         device=device,
#     )

## 3. Run the optimization loop
Here, we set the bounds on the learning rate and momentum and set the parameter space for the learning rate to be on a log scale. 

In [17]:
best_parameters, values, experiment, model = optimize(
    parameters=[
        # {"name": "lr", "type": "range", "bounds": [1e-5, 1e-3], "log_scale": True},
        # {"name": "weight_decay", "type": "range", "bounds": [0.0, 0.1]},
        {"name": "num_epochs", "type": "range", "bounds": [25, 100]},
        # {"name": "layer2", "type": "range", "bounds": [0, 2048], "value_type": "int"},
        # {"name": "layer3", "type": "range", "bounds": [0, 2048], "value_type": "int"},
        # {"name": "layer4", "type": "range", "bounds": [0, 2048], "value_type": "int"},

        # {"name": "layer2", "type": "range", "bounds": [0, 15], "value_type": "int"},
        # {"name": "layer3", "type": "range", "bounds": [0, 15], "value_type": "int"},
        # {"name": "layer4", "type": "range", "bounds": [0, 15], "value_type": "int"},


        # {"name": "feature_angle", "type": "choice", "values": [True, False]},
        {"name": "labeled_loss_multiplier", "type": "range", "bounds": [1e0, 1e3], "log_scale": True},
        {"name": "matching_loss_multiplier", "type": "range", "bounds": [1e0, 1e3], "log_scale": True},
        {"name": "contrasting_loss_multiplier", "type": "range", "bounds": [1e0, 1e3], "log_scale": True},
        {"name": "srgan_loss_multiplier", "type": "range", "bounds": [1e0, 1e3], "log_scale": True},
        {"name": "gradient_penalty_multiplier", "type": "range", "bounds": [1e0, 1e4], "log_scale": True},
        # {"name": "normalize_feature_norm", "type": "choice", "values": [True, False]},
    ],
    
    evaluation_function=train_evaluate,
    objective_name='accuracy',
    total_trials=25,
    random_seed=42,
    minimize=False,
)
best_parameters

[INFO 05-15 13:54:59] ax.service.utils.instantiation: Inferred value type of ParameterType.INT for parameter num_epochs. If that is not the expected value type, you can explicity specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 05-15 13:54:59] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter labeled_loss_multiplier. If that is not the expected value type, you can explicity specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 05-15 13:54:59] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter matching_loss_multiplier. If that is not the expected value type, you can explicity specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 05-15 13:54:59] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter contrasting_loss_multiplier. If that is not the expected value type, you can explicity specify 'val

AXtuning/base r221
dataset rank:  None
7
Step 0, 0:00:00.987608...ize([1000, 2048])DNN Test Values:  ComparisonValues(ACC=tensor(0.2473), BAC=tensor(0.5000), MCC=0.01, predicted_labels=tensor([0.5271, 0.5285, 0.5282,  ..., 0.5273, 0.5274, 0.5288]))
GAN Test Values:  ComparisonValues(ACC=tensor(0.2473), BAC=tensor(0.5000), MCC=0.01, predicted_labels=tensor([0.5354, 0.5364, 0.5357,  ..., 0.5359, 0.5361, 0.5359]))
Step 50, 0:00:18.567234...ize([1000, 2048])DNN Test Values:  ComparisonValues(ACC=tensor(0.8125), BAC=tensor(0.6410), MCC=tensor(0.4263), predicted_labels=tensor([0.4893, 0.2519, 0.3428,  ..., 0.4026, 0.4543, 0.2764]))
GAN Test Values:  ComparisonValues(ACC=tensor(0.2473), BAC=tensor(0.5000), MCC=0.01, predicted_labels=tensor([0.5302, 0.5301, 0.5310,  ..., 0.5290, 0.5304, 0.5296]))
Step 100, 0:00:18.702711...ize([1000, 2048])DNN Test Values:  ComparisonValues(ACC=tensor(0.8460), BAC=tensor(0.7332), MCC=tensor(0.5504), predicted_labels=tensor([0.1294, 0.3910, 0.3430,  ..., 0.2540

In [None]:
break
best_parameters, values, experiment, model = optimize(
    parameters=[
        {"name": "lr", "type": "range", "bounds": [1e-5, 1e-3], "log_scale": True},
        {"name": "weight_decay", "type": "range", "bounds": [0.0, 0.1]},
        {"name": "num_epochs", "type": "range", "bounds": [5, 50]},
        {"name": "layer2", "type": "range", "bounds": [0, 15], "value_type": "int"},
        {"name": "layer3", "type": "range", "bounds": [0, 15], "value_type": "int"},
        {"name": "layer4", "type": "range", "bounds": [0, 15], "value_type": "int"},
    ],
    
    evaluation_function=train_evaluate,
    objective_name='accuracy',
)
best_parameters

In [None]:
# best_parameters, values, experiment, model = optimize(
#     parameters=[
#         {"name": "lr", "type": "range", "bounds": [1e-4, 1e-3], "log_scale": True},
#         {"name": "weight_decay", "type": "range", "bounds": [0.0, 0.1]},
#         {"name": "num_epochs", "type": "range", "bounds": [5, 50]},
#         {"name": "layer2", "type": "range", "bounds": [0, 15], "value_type": "int"},
#         {"name": "layer3", "type": "range", "bounds": [0, 15], "value_type": "int"},
#         {"name": "layer4", "type": "range", "bounds": [0, 15], "value_type": "int"},
#     ],
    
#     evaluation_function=train_evaluate,
#     objective_name='accuracy',
# )
# best_parameters

In [None]:
means, covariances = values
means, covariances

In [None]:
render(plot_contour(model=model, param_x='labeled_loss_multiplier', param_y='matching_loss_multiplier', metric_name='accuracy'))


## 4. Plot response surface

Contour plot showing classification accuracy as a function of the two hyperparameters.

The black squares show points that we have actually run, notice how they are clustered in the optimal region.

In [None]:

render(plot_contour(model=model, param_x='labeled_loss_multiplier', param_y='contrasting_loss_multiplier', metric_name='accuracy'))


In [None]:
render(plot_contour(model=model, param_x='gradient_penalty_multiplier', param_y='contrasting_loss_multiplier', metric_name='accuracy'))


In [None]:
render(plot_contour(model=model, param_x='gradient_penalty_multiplier', param_y='matching_loss_multiplier', metric_name='accuracy'))


In [None]:
render(plot_contour(model=model, param_x='gradient_penalty_multiplier', param_y='labeled_loss_multiplier', metric_name='accuracy'))


In [None]:
render(plot_contour(model=model, param_x='matching_loss_multiplier', param_y='contrasting_loss_multiplier', metric_name='accuracy'))


In [None]:
render(plot_contour(model=model, param_x='layer2', param_y='lr', metric_name='accuracy'))


In [None]:
render(plot_contour(model=model, param_x='layer3', param_y='lr', metric_name='accuracy'))

In [None]:
render(plot_contour(model=model, param_x='layer4', param_y='lr', metric_name='accuracy'))

In [None]:
render(plot_contour(model=model, param_x='layer2', param_y='layer3', metric_name='accuracy'))

## 5. Plot best objective as function of the iteration

Show the model accuracy improving as we identify better hyperparameters.

In [None]:
# `plot_single_method` expects a 2-d array of means, because it expects to average means from multiple 
# optimization runs, so we wrap out best objectives array in another array.
best_objectives = np.array([[trial.objective_mean*100 for trial in experiment.trials.values()]])
best_objective_plot = optimization_trace_single_method(
    y=np.maximum.accumulate(best_objectives, axis=1),
    title="Model performance vs. # of iterations",
    ylabel="Classification Accuracy, %",
)
render(best_objective_plot)

## 6. Train CNN with best hyperparameters and evaluate on test set
Note that the resulting accuracy on the test set might not be exactly the same as the maximum accuracy achieved on the evaluation set throughout optimization. 

In [None]:
data = experiment.fetch_data()
df = data.df
best_arm_name = df.arm_name[df['mean'] == df['mean'].max()].values[0]
best_arm = experiment.arms_by_name[best_arm_name]
best_arm

Arm(name='21_0', parameters={'num_epochs': 47, 'labeled_loss_multiplier': 1000.0, 'matching_loss_multiplier': 1.0, 'contrasting_loss_multiplier': 1.0, 'srgan_loss_multiplier': 1000.0, 'gradient_penalty_multiplier': 3.159383803529542})

In [None]:
combined_train_valid_set = torch.utils.data.ConcatDataset([
    training_loader.dataset, 
    
    validation_loader.dataset,
])
combined_train_valid_loader = torch.utils.data.DataLoader(
    combined_train_valid_set, 
    batch_size=BATCH_SIZE, 
    shuffle=True,
)

In [None]:
net = train(
    net=Classifier(),
    
    train_loader=combined_train_valid_loader, 
    parameters=best_arm.parameters,
    dtype=dtype,
    device=device,
)
test_accuracy = evaluate(
    net=net,
    data_loader=testing_loader,
    dtype=dtype,
    device=device,
)

In [None]:
print(f"Classification Accuracy (test set): {round(test_accuracy*100, 2)}%")