In [None]:
!pip install wandb -qqq
import wandb
wandb.login()

In [None]:
import os
from collections import namedtuple

# TensorFlow's native runtime reads TF_CPP_MIN_LOG_LEVEL when it initialises,
# so the variable must be set *before* tensorflow is first imported -- and the
# `keras` imports below pull tensorflow in transitively. Setting it after those
# imports has no effect on the log output produced during import.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import numpy as np
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Dropout, Dense, Flatten, Activation, BatchNormalization
from keras.utils import np_utils
from keras.callbacks import ReduceLROnPlateau, ModelCheckpoint, Callback, EarlyStopping
from tensorflow import keras
from tensorflow.keras.optimizers import RMSprop, SGD, Adam, Nadam
from wandb.keras import WandbCallback
from sklearn.utils import shuffle

# Only show Python-side TensorFlow log messages at ERROR level or above.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

# One data split as stored in a `<split>.npz` file: the input array `x` and
# the arrays `r` and `fi` (train() uses one of them as the regression target).
Dataset = namedtuple("Dataset", ["x", "r", "fi"])

In [None]:
# Configure the sweep: the search strategy, the metric to optimize, and the hyper-parameters to search over.
# W&B sweep definition: Bayesian search that minimises validation MAE.
# Each entry under `parameters` is sampled by the sweep agent and reaches
# `train()` through `wandb.config`.
sweep_config = dict(
    method='bayes',  # other options: 'grid', 'random'
    metric=dict(name='val_mae', goal='minimize'),
    parameters=dict(
        target=dict(values=['r']),
        epochs=dict(values=[20]),
        batch_size=dict(values=[64, 128]),
        # 0 = no BatchNormalization, 1 = BatchNormalization after each hidden Dense layer
        batchnorm_for_layers=dict(values=[0, 1]),
        layer_1_size=dict(values=[2048, 3072, 4096, 5040, 5760]),
        layer_2_size=dict(values=[3072, 4096, 5120, 5760, 6480, 7200]),
        layer_3_size=dict(values=[2160, 2880, 3600, 4096, 5120, 6144]),
        # NOTE(review): uniform sampling over [1e-6, 1e-3] puts ~90% of draws in
        # the top decade; a log-uniform distribution is usually preferable.
        learning_rate=dict(distribution='uniform', max=0.001, min=1e-06),
        optimizer=dict(distribution='categorical', values=['sgd', 'rmsprop']),
        # activation1..activation4: an independent choice for each of the four
        # hidden blocks (each entry gets its own fresh list).
        **{
            f'activation{block}': dict(
                distribution='categorical',
                values=['relu', 'tanh', 'sigmoid'],
            )
            for block in range(1, 5)
        },
    ),
)

In [None]:
# Register the sweep with W&B; the returned id is what the agent cell consumes.
sweep_id = wandb.sweep(
    sweep_config,
    entity="artem-starkov",
    project="flatfasetgen",
)

In [None]:
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, mean_squared_error

def read(data_dir, split):
    """Load one dataset split stored as ``<data_dir>/<split>.npz``.

    Args:
        data_dir: Directory containing the ``.npz`` files (here, the directory
            returned by a W&B artifact ``download()``).
        split: File name without the ``.npz`` extension, e.g. ``"train_set"``.

    Returns:
        Dataset: namedtuple with the ``x``, ``r`` and ``fi`` arrays.

    Raises:
        FileNotFoundError: if the ``.npz`` file does not exist.
        KeyError: if one of ``x``, ``r`` or ``fi`` is missing from the archive.
    """
    path = os.path.join(data_dir, split + ".npz")
    # np.load on an .npz returns an NpzFile that keeps the file handle open;
    # indexing it reads each array fully into memory, so the file can be
    # closed as soon as the three arrays have been extracted.
    with np.load(path) as data:
        return Dataset(x=data["x"], r=data["r"], fi=data["fi"])

In [None]:
# The sweep calls this function with each set of hyperparameters
def _load_splits(run, target):
    """Fetch the dataset artifact and return ``(X_train, X_test, y_train, y_test)``.

    ``target`` selects which ``Dataset`` field (``"r"`` or ``"fi"``) is the label.
    """
    if target not in ("r", "fi"):
        raise ValueError(f"Unsupported target {target!r}; expected 'r' or 'fi'")
    dataset_dir = run.use_artifact('Clear_datasets:v20').download()
    train_split = read(dataset_dir, 'train_set')
    test_split = read(dataset_dir, 'test_set')
    return (train_split.x, test_split.x,
            getattr(train_split, target), getattr(test_split, target))


def _build_model(config, n_features):
    """Build the fully-connected regressor described by the sweep ``config``.

    Layout: Dense(720) -> Dense(layer_1_size) -> Dense(layer_2_size)
    -> Dense(layer_3_size) -> Dense(1). Each of the first four Dense layers is
    optionally followed by BatchNormalization (``batchnorm_for_layers``) and
    then by its own activation (``activation1`` .. ``activation4``).

    Args:
        config: wandb config holding the keys listed above.
        n_features: Number of input features per sample.
    """
    model = Sequential()
    model.add(Dense(720, input_shape=(n_features,)))
    if config.batchnorm_for_layers:
        model.add(BatchNormalization())
    model.add(Activation(config.activation1))

    # layer_1_size pairs with activation2, layer_2_size with activation3, ...
    for act_idx, size_key in enumerate(("layer_1_size", "layer_2_size", "layer_3_size"), start=2):
        model.add(Dense(config[size_key]))
        if config.batchnorm_for_layers:
            model.add(BatchNormalization())
        model.add(Activation(config[f"activation{act_idx}"]))

    model.add(Dense(1))  # no activation: unbounded scalar regression output
    return model


def _make_optimizer(name, learning_rate, decay=1e-5):
    """Create the optimizer selected by the sweep (``'sgd'`` or ``'rmsprop'``).

    The optimizers introduced in TF 2.11 reject a ``decay=`` constructor
    argument. An ``InverseTimeDecay`` schedule with ``decay_steps=1`` gives the
    same per-step rate, ``lr / (1 + decay * step)``, as the legacy argument and
    works on all TF 2.x versions.

    Raises:
        ValueError: for any other optimizer name (previously an unbound-name crash).
    """
    schedule = tf.keras.optimizers.schedules.InverseTimeDecay(
        initial_learning_rate=learning_rate, decay_steps=1, decay_rate=decay)
    if name == 'sgd':
        return SGD(learning_rate=schedule, nesterov=True)
    if name == 'rmsprop':
        return RMSprop(learning_rate=schedule)
    raise ValueError(f"Unsupported optimizer {name!r}; expected 'sgd' or 'rmsprop'")


def _log_model_artifact(run, model, artifact_name, save_dir, description, metadata):
    """Save ``model`` into ``save_dir`` and log that directory as a W&B model artifact."""
    artifact = wandb.Artifact(artifact_name, type="model",
                              description=description, metadata=metadata)
    model.save(save_dir)
    artifact.add_dir(save_dir)
    run.log_artifact(artifact)


# The sweep calls this function with each set of hyperparameters
def train():
    """Run a single sweep trial: load data, build, train, evaluate, log artifacts.

    Hyper-parameters come from ``wandb.config``; ``config_defaults`` is used
    for any key the sweep does not set. Logs the compiled (untrained) model,
    test-set R2 and RMSE, and the trained model to W&B.
    """
    # Default values for hyper-parameters we're going to sweep over
    config_defaults = {
        'epochs': 20,
        'batch_size': 64,
        # 'weight_decay': 0.0005,
        'learning_rate': 1e-3,
        'activation1': 'relu',
        'activation2': 'relu',
        'activation3': 'relu',
        'activation4': 'relu',
        'optimizer': 'sgd',
        'layer_1_size': 4320,
        'layer_2_size': 4320,
        'layer_3_size': 4320,
        'batchnorm_for_layers': 1,
        'Distance_distribution': 'sqr',
        'target': 'r'
    }

    # The context manager finishes the run even if training raises.
    with wandb.init(project="flatfasetgen",
                    job_type="training_distance_linear_distrubution",
                    config=config_defaults) as run:
        config = run.config
        # Honour the swept `target` instead of always training on `r`.
        X_train, X_test, y_train, y_test = _load_splits(run, config.target)

        model = _build_model(config, X_train.shape[1])
        model.compile(loss='mae',
                      optimizer=_make_optimizer(config.optimizer, config.learning_rate),
                      metrics=['mae', 'mape'])

        _log_model_artifact(
            run, model, "distance_compiled_model", "distance_compiled_models",
            description=f"50k dataset, full search with batch_norm, desrtibution for distance: {config.Distance_distribution}",
            metadata=dict(config))

        model.fit(X_train, y_train, batch_size=config.batch_size,
                  epochs=config.epochs, validation_data=(X_test, y_test),
                  callbacks=[WandbCallback(validation_data=(X_test, y_test)),
                             EarlyStopping(monitor='val_loss', min_delta=0.00001,
                                           patience=10, restore_best_weights=True)])

        # Predict once and reuse it for both metrics; RMSE is computed via
        # sqrt(MSE) because `squared=False` is deprecated in scikit-learn.
        y_pred = model.predict(X_test)
        run.log({
            "R2": r2_score(y_test, y_pred),
            "RMSE": float(np.sqrt(mean_squared_error(y_test, y_pred))),
        })

        _log_model_artifact(
            run, model, "distance_trained_model", "distance_trained_models",
            description="", metadata=dict(config))

In [None]:
# Launch a sweep agent in this process: it runs `train` for 10 hyper-parameter
# combinations proposed by the sweep registered above.
wandb.agent(
    sweep_id,
    function=train,
    count=10,
)