In [5]:
"""run"""
from torch import nn
import torch
from torch.utils.data import DataLoader
import lightning as L
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from sklearn.model_selection import train_test_split
import optuna
from models import LitAE


# load data from data folder(we"re in notebooks folder)
sim_arr_tensor = torch.load('./data/sim_arr_tensor.pt')

# train test split
train_data, test_data = train_test_split(
    sim_arr_tensor, test_size=0.2, random_state=42)
train_data, val_data = train_test_split(
    train_data, test_size=0.2, random_state=42)


def objective(trial):
    """ define the objective function."""
    num_layers = 5
    channels = [3,]
    for i in range(num_layers - 1):
        channels.append(trial.suggest_int(
            f'channels_{i}', 7, 20))
    channels.append(12)
    kernel_sizes = [trial.suggest_int(
        f'kernel_{i}', 6, 20) for i in range(num_layers)]

    dilations = [trial.suggest_int(
        f'dilation_{i}', 1, 5) for i in range(num_layers)]

    activations = [trial.suggest_categorical(
        f'activation_{i}', ['nn.Softplus',
                            'nn.SELU',
                            'nn.SiLU',
                            'nn.Tanh']) for i in range(num_layers)]

    lr = trial.suggest_float('lr', 1e-5, 1e-1, log=True)
    batch_size = trial.suggest_categorical('batch_size', [32, 64, 128, 256])

    activations = [eval(activation) for activation in activations]
    hyperparameters_dict = {
        "num_layers": 5,
        "poolsize": [2, 2, 2, 2, 5],
        "channels": channels,
        "kernel_sizes": kernel_sizes,
        "dilations": dilations,
        "activations": activations,
        "lr": lr,
        "batch_size": batch_size
    }

    lit_model = LitAE(hyperparameters=hyperparameters_dict)
    trainer = L.Trainer(callbacks=[EarlyStopping(
        monitor="val_loss")], max_epochs=100)
    trainer.fit(model=lit_model,
                train_dataloaders=DataLoader(train_data,
                                             batch_size=hyperparameters_dict["batch_size"],
                                             num_workers=31),
                val_dataloaders=DataLoader(val_data,
                                           batch_size=hyperparameters_dict["batch_size"],
                                           num_workers=31))

    val_loss = trainer.callback_metrics["val_loss"].item()

    return val_loss


study = optuna.create_study(direction='minimize',
                            sampler=optuna.samplers.TPESampler(),
                            pruner=optuna.pruners.HyperbandPruner(),
                            study_name='AutoEncoder_10',
                            storage='sqlite:///optuna.db',
                            load_if_exists=True)
study.optimize(objective, n_trials=100)


  sim_arr_tensor = torch.load('./data/sim_arr_tensor.pt')
[I 2024-09-12 12:00:55,829] A new study created in RDB with name: AutoEncoder_10
Trainer will use only 1 of 8 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=8)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name  | Type        | Params | Mode 
----------------------------------------------
0 | model | AutoEncoder | 84.6 K | train
----------------------------------------------
84.6 K    Trainable params
0         Non-trainable params
84.6 K    Total params
0.338     Total estimated model params size (MB)
30        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [25]:
import numpy as np
from plot import plot_class

labels = np.load("data/labels.npy")

class_dict = {"1_skyrmion": 0,
                "2_skyrmions": 1,
                "4_skyrmions": 2,
                "5_skyrmions": 3,
                "quasi_sat": 4,
                "target_skyrmion_1": 5,
                "target_skyrmion_2": 6,
                "helical_1": 7,
                "helical_2": 8,
                "cross_1": 9,
                "cross_2": 10,
                "horseshoe_1": 11,
                "horseshoe_2": 12,
                "cross_1": 13,
                "cross_2": 14,
                "cross_3": 15
                }

for i in labels:
    if i in class_dict:
        labels[labels == i] = class_dict[i]

plot_class(labels)


Number of classes:  15


KeyError: 'Class 15'