In [1]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('../')
from src.models import CrossEntropyClassification
from src.data import train_val_test_split, get_descriptor_and_labels
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, RichProgressBar
import torch

In [2]:
desc_type = "steinhardt"
use_mda = True
numb_train_samples = 8_000

In [3]:
train_structs, val_structs, test_structs = train_val_test_split(mda=use_mda,num_files=None)

In [4]:
len(train_structs), len(val_structs), len(test_structs)

(1285, 20, 1245)

In [5]:
train_x, train_y, label_mapping = get_descriptor_and_labels(train_structs, num_samples_per_type=numb_train_samples)
val_x, val_y, _ = get_descriptor_and_labels(val_structs, num_samples_per_type=2_500)
test_x, test_y, _ = get_descriptor_and_labels(test_structs, num_samples_per_type=2_500)

In [7]:
label_mapping

{'hda': 0, 'lda': 1, 'mda': 2}

In [8]:
from sklearn import preprocessing

# fit to training data
scaler = preprocessing.StandardScaler().fit(train_x)
scaled_train_x = torch.FloatTensor(scaler.transform(train_x))
scaled_val_x = torch.FloatTensor(scaler.transform(val_x))

In [9]:
train_dataset = TensorDataset(scaled_train_x,train_y)
val_dataset = TensorDataset(scaled_val_x,val_y)

train_loader = DataLoader(train_dataset, batch_size=250, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=10000, shuffle=False)

output_size = train_y.shape[1]

In [13]:
import optuna
from src.data import predict_test_set_classes
from sklearn.metrics import balanced_accuracy_score
from pytorch_lightning.loggers import TensorBoardLogger

def optimise_NN(trial: optuna.Trial):
    # Optuna optimisation function for the NN
    
    # 1. Suggest the hyperparameters
    n_layers = trial.suggest_int("n_layers", 1, 5)
    neurons_per_layer = trial.suggest_int("n_units_l0", 8, 256, log=True)
    hidden_units = [neurons_per_layer] * n_layers
    weight_decay = trial.suggest_float("weight_decay", 1e-8, 1e-1, log=True)
    learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-1, log=True)
    
    input_size = 30

    # 2. Create the model
    model = CrossEntropyClassification(
        input_size,
        *hidden_units,
        output_size,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
    )

    # 3. Train the model
    trainer = Trainer(
        accelerator="auto",
        max_epochs=200,
        callbacks=[
            RichProgressBar(),
            EarlyStopping(monitor="validation_loss", patience=10),
        ],
        logger=TensorBoardLogger("lightning_logs"),
    )
    trainer.fit(model, train_loader, val_loader)
    
    # 4. Load the best model
    model.load_state_dict(torch.load(trainer.checkpoint_callback.best_model_path)['state_dict'])
    
    # 5. Evaluate the model
    pred_classes, val_classes, _ = predict_test_set_classes(val_structs,model=model, scaler=scaler)
    
    return balanced_accuracy_score(val_classes, pred_classes)

In [14]:
study_name = "optimise_NN"  # Unique identifier of the study.
storage_name = f"sqlite:///{study_name}.db"
study = optuna.create_study(study_name=study_name, storage=storage_name, direction="maximize",load_if_exists=True)

[I 2023-12-01 12:55:06,276] Using an existing study with name 'optimise_NN' instead of creating a new one.


In [16]:
study.optimize(optimise_NN, n_trials=30)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Output()

[I 2023-12-01 14:26:53,880] Trial 80 finished with value: 0.8560935691550927 and parameters: {'n_layers': 4, 'n_units_l0': 60, 'weight_decay': 0.0016800401880322472, 'learning_rate': 0.0001988756757628254}. Best is trial 27 with value: 0.8576949508101852.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Output()

[I 2023-12-01 14:28:39,578] Trial 81 finished with value: 0.8568820529513889 and parameters: {'n_layers': 3, 'n_units_l0': 62, 'weight_decay': 0.003468695860286599, 'learning_rate': 0.00011035771440109698}. Best is trial 27 with value: 0.8576949508101852.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Output()

[I 2023-12-01 14:29:40,638] Trial 82 finished with value: 0.8579137731481481 and parameters: {'n_layers': 3, 'n_units_l0': 82, 'weight_decay': 0.006363505017122275, 'learning_rate': 0.0001363056995309084}. Best is trial 82 with value: 0.8579137731481481.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Output()

[I 2023-12-01 14:30:59,031] Trial 83 finished with value: 0.8535228587962963 and parameters: {'n_layers': 3, 'n_units_l0': 82, 'weight_decay': 0.0066681625295974606, 'learning_rate': 0.0002636669091054654}. Best is trial 82 with value: 0.8579137731481481.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Output()

[I 2023-12-01 14:31:28,941] Trial 84 finished with value: 0.8564430519386574 and parameters: {'n_layers': 3, 'n_units_l0': 95, 'weight_decay': 0.012889415955673661, 'learning_rate': 0.0003356644991869811}. Best is trial 82 with value: 0.8579137731481481.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Output()

[I 2023-12-01 14:32:08,442] Trial 85 finished with value: 0.8559371383101851 and parameters: {'n_layers': 3, 'n_units_l0': 112, 'weight_decay': 0.007206825072282303, 'learning_rate': 0.0005711026396830845}. Best is trial 82 with value: 0.8579137731481481.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Output()

[I 2023-12-01 14:33:27,991] Trial 86 finished with value: 0.8557187680844908 and parameters: {'n_layers': 2, 'n_units_l0': 81, 'weight_decay': 0.03626789503428796, 'learning_rate': 0.00015326882828034204}. Best is trial 82 with value: 0.8579137731481481.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Output()

[I 2023-12-01 14:33:55,091] Trial 87 finished with value: 0.8522650824652778 and parameters: {'n_layers': 3, 'n_units_l0': 102, 'weight_decay': 0.0024145917441584586, 'learning_rate': 0.00042193687077703883}. Best is trial 82 with value: 0.8579137731481481.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Output()

[I 2023-12-01 14:36:19,311] Trial 88 finished with value: 0.8553371853298611 and parameters: {'n_layers': 1, 'n_units_l0': 71, 'weight_decay': 0.021897919761826593, 'learning_rate': 8.354725987115666e-05}. Best is trial 82 with value: 0.8579137731481481.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Output()

[I 2023-12-01 14:36:42,259] Trial 89 finished with value: 0.8520417390046297 and parameters: {'n_layers': 3, 'n_units_l0': 127, 'weight_decay': 0.0010911101009170037, 'learning_rate': 0.0007655491264319187}. Best is trial 82 with value: 0.8579137731481481.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Output()

[I 2023-12-01 14:37:15,019] Trial 90 finished with value: 0.8550713433159722 and parameters: {'n_layers': 4, 'n_units_l0': 93, 'weight_decay': 0.005084791702657964, 'learning_rate': 0.00021394779481748353}. Best is trial 82 with value: 0.8579137731481481.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Output()

[I 2023-12-01 14:38:24,333] Trial 91 finished with value: 0.8556116174768519 and parameters: {'n_layers': 3, 'n_units_l0': 64, 'weight_decay': 0.004170896075837361, 'learning_rate': 0.0001272664791612903}. Best is trial 82 with value: 0.8579137731481481.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Output()

[I 2023-12-01 14:40:08,043] Trial 92 finished with value: 0.8565615053530092 and parameters: {'n_layers': 3, 'n_units_l0': 58, 'weight_decay': 0.011719430709001001, 'learning_rate': 0.00013184788023783465}. Best is trial 82 with value: 0.8579137731481481.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Output()

[I 2023-12-01 14:41:04,354] Trial 93 finished with value: 0.8574562355324075 and parameters: {'n_layers': 3, 'n_units_l0': 52, 'weight_decay': 0.00342337789159726, 'learning_rate': 0.0001810659487072395}. Best is trial 82 with value: 0.8579137731481481.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Output()

[I 2023-12-01 14:42:28,232] Trial 94 finished with value: 0.8560840747974536 and parameters: {'n_layers': 3, 'n_units_l0': 51, 'weight_decay': 0.0017734199873340518, 'learning_rate': 0.00017554098451755316}. Best is trial 82 with value: 0.8579137731481481.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Output()

[I 2023-12-01 14:43:29,398] Trial 95 finished with value: 0.8576981155960648 and parameters: {'n_layers': 3, 'n_units_l0': 73, 'weight_decay': 0.0072726248908818614, 'learning_rate': 0.0003297086064660382}. Best is trial 82 with value: 0.8579137731481481.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Output()

[I 2023-12-01 14:44:55,741] Trial 96 finished with value: 0.8566691080729166 and parameters: {'n_layers': 2, 'n_units_l0': 73, 'weight_decay': 0.007285829524967987, 'learning_rate': 0.0002524963307446916}. Best is trial 82 with value: 0.8579137731481481.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Output()

[I 2023-12-01 14:45:48,156] Trial 97 finished with value: 0.8561925817418982 and parameters: {'n_layers': 3, 'n_units_l0': 54, 'weight_decay': 0.01600883472999935, 'learning_rate': 0.0003045623897738026}. Best is trial 82 with value: 0.8579137731481481.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Output()

[I 2023-12-01 14:46:27,109] Trial 98 finished with value: 0.8564195421006945 and parameters: {'n_layers': 3, 'n_units_l0': 48, 'weight_decay': 0.024957562855898986, 'learning_rate': 0.0003812025156283826}. Best is trial 82 with value: 0.8579137731481481.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Output()

[I 2023-12-01 14:47:09,630] Trial 99 finished with value: 0.8564697265625 and parameters: {'n_layers': 3, 'n_units_l0': 43, 'weight_decay': 0.008771000677817702, 'learning_rate': 0.000501291422948635}. Best is trial 82 with value: 0.8579137731481481.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Output()

[I 2023-12-01 14:48:10,399] Trial 100 finished with value: 0.8562893337673612 and parameters: {'n_layers': 2, 'n_units_l0': 67, 'weight_decay': 0.002781990696184675, 'learning_rate': 0.0002339319101073392}. Best is trial 82 with value: 0.8579137731481481.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Output()

[I 2023-12-01 14:49:16,837] Trial 101 finished with value: 0.8559760199652778 and parameters: {'n_layers': 3, 'n_units_l0': 87, 'weight_decay': 0.0037250935454863, 'learning_rate': 0.00018540663840787764}. Best is trial 82 with value: 0.8579137731481481.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Output()

[I 2023-12-01 14:50:58,865] Trial 102 finished with value: 0.8564493815104166 and parameters: {'n_layers': 3, 'n_units_l0': 76, 'weight_decay': 0.005679771881805263, 'learning_rate': 9.879733122813713e-05}. Best is trial 82 with value: 0.8579137731481481.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Output()

[I 2023-12-01 14:51:48,004] Trial 103 finished with value: 0.8555442527488427 and parameters: {'n_layers': 3, 'n_units_l0': 61, 'weight_decay': 0.0023859612096351827, 'learning_rate': 0.00030387754418057173}. Best is trial 82 with value: 0.8579137731481481.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Output()

[I 2023-12-01 14:52:14,279] Trial 104 finished with value: 0.8509046766493055 and parameters: {'n_layers': 3, 'n_units_l0': 99, 'weight_decay': 0.0005809119551623455, 'learning_rate': 0.0005975328709006673}. Best is trial 82 with value: 0.8579137731481481.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Output()

[I 2023-12-01 14:53:33,366] Trial 105 finished with value: 0.8560470015914352 and parameters: {'n_layers': 3, 'n_units_l0': 109, 'weight_decay': 0.012306196394131462, 'learning_rate': 7.22724589816953e-05}. Best is trial 82 with value: 0.8579137731481481.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Output()

[I 2023-12-01 14:54:53,471] Trial 106 finished with value: 0.8570158781828704 and parameters: {'n_layers': 3, 'n_units_l0': 69, 'weight_decay': 0.020596262406653607, 'learning_rate': 0.0001657963637369387}. Best is trial 82 with value: 0.8579137731481481.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Output()

[I 2023-12-01 14:55:25,895] Trial 107 finished with value: 0.8539681893807871 and parameters: {'n_layers': 4, 'n_units_l0': 85, 'weight_decay': 0.0009111462713966313, 'learning_rate': 0.0004139355979785532}. Best is trial 82 with value: 0.8579137731481481.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Output()

[I 2023-12-01 14:56:20,078] Trial 108 finished with value: 0.8567916304976851 and parameters: {'n_layers': 3, 'n_units_l0': 79, 'weight_decay': 0.007913807671549944, 'learning_rate': 0.00021685179775551128}. Best is trial 82 with value: 0.8579137731481481.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Output()

[I 2023-12-01 14:58:22,750] Trial 109 finished with value: 0.8561342592592593 and parameters: {'n_layers': 2, 'n_units_l0': 53, 'weight_decay': 0.0015628291398035485, 'learning_rate': 0.00010021597381068625}. Best is trial 82 with value: 0.8579137731481481.
