In [3]:
# Make a local FrEIA checkout importable by adding it to the module search path.
# The location defaults to /FrEIA-master and can be overridden via the FREIA_PATH
# environment variable instead of editing an absolute path in the notebook.
import os
import sys

FREIA_PATH = os.environ.get("FREIA_PATH", "/FrEIA-master")
# Guard against piling up duplicate entries when this cell is re-run in the same kernel.
if FREIA_PATH not in sys.path:
    sys.path.append(FREIA_PATH)

In [4]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.pyplot import figure
from torch.utils.data import DataLoader
import plain_inn_swap
from tesladatainn import TeslaDatasetInn
import numpy as np
import time
import wandb
import pprint
import math
import torch

### Constants: W&B hyperparameter-sweep configuration

In [5]:
# Authenticate this kernel with Weights & Biases before creating the sweep.
# The captured output of this cell shows an existing login being reused; never
# hard-code an API key here.
wandb.login()

def log_uniform_int(low, high):
    """Return a W&B sweep parameter spec sampling integers in [low, high] log-uniformly.

    ``q_log_uniform`` takes its ``min``/``max`` bounds in natural-log space and
    rounds every sample to a multiple of ``q``; ``q=1`` therefore yields integers
    whose logarithms are evenly distributed between log(low) and log(high).
    """
    return {
        'distribution': 'q_log_uniform',
        'q': 1,
        'min': math.log(low),
        'max': math.log(high),
    }


# Objective for the sweep: minimise the logged 'Loss' metric.
# NOTE(review): 'Loss' must match the key that INN.train logs through wandb -- confirm.
metric = {
    'name': 'Loss',
    'goal': 'minimize',
}

# The three lambd_* values are forwarded to plain_inn_swap.INN (presumably the
# weights of its loss terms); they all share the same 1..1024 log-uniform range.
LAMBDA_NAMES = ('lambd_predict_back', 'lambd_latent', 'lambd_rev')
parameters_dict = {name: log_uniform_int(1, 1024) for name in LAMBDA_NAMES}
parameters_dict['epochs'] = {'value': 1}  # fixed, not searched

sweep_config = {
    'method': 'random',
    'metric': metric,
    'parameters': parameters_dict,
}


[34m[1mwandb[0m: Currently logged in as: [33mjeyhun[0m (use `wandb login --relogin` to force relogin)


In [6]:
# Echo the complete sweep definition (dict keys are printed sorted), then register it
# with W&B; the returned sweep ID is what the agent cell at the end consumes.
pprint.PrettyPrinter().pprint(sweep_config)
sweep_id = wandb.sweep(sweep_config, project="INN optimization")

{'method': 'random',
 'metric': {'goal': 'minimize', 'name': 'Loss'},
 'parameters': {'epochs': {'value': 1},
                'lambd_latent': {'distribution': 'q_log_uniform',
                                 'max': 6.931471805599453,
                                 'min': 0.0,
                                 'q': 1},
                'lambd_predict_back': {'distribution': 'q_log_uniform',
                                       'max': 6.931471805599453,
                                       'min': 0.0,
                                       'q': 1},
                'lambd_rev': {'distribution': 'q_log_uniform',
                              'max': 6.931471805599453,
                              'min': 0.0,
                              'q': 1}}}
Create sweep with ID: opxbmsu7
Sweep URL: https://wandb.ai/jeyhun/INN%20optimization/sweeps/opxbmsu7


### Training
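The following cell trains an Invertible Neural Network following Ardizzone et al. (2018), *Analyzing Inverse Problems with Invertible Neural Networks*, without zero-padding: `[y, z] <=> x`, where `x` are the parameters, `y` the temperature readings and `z` the latent variable (i.i.d. Gaussian). Training is wrapped in `train()` so that the W&B sweep agent can call it once per run with freshly sampled hyperparameters.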

In [7]:
def train(config=None):
    """Run one W&B sweep trial: load the training data, build the INN and train it.

    Intended to be passed to ``wandb.agent``, which calls it once per sweep run.

    Args:
        config: Optional config forwarded to ``wandb.init``. When ``None`` (as when
            called by ``wandb.agent``), the sweep controller fills ``wandb.config``
            with the sampled hyperparameters.

    Reads from ``wandb.config``:
        lambd_predict_back, lambd_latent, lambd_rev: forwarded to
            ``plain_inn_swap.INN`` (presumably loss-term weights).
        epochs: number of training epochs passed to ``INN.train``.
        batch_size, lr, num_features, num_blocks: optional overrides; they default
            to 2048, 1e-3, 100 and 5 when the sweep does not define them.

    Side effects:
        Records the wall-clock training duration as ``training_time_s`` in the
        run summary.
    """
    with wandb.init(config=config) as run:
        # Read the hyperparameters once, up front, rather than shadowing `config` later.
        hparams = wandb.config
        batch_size = hparams.get('batch_size', 2048)
        lr = hparams.get('lr', 1e-3)
        num_features = hparams.get('num_features', 100)
        num_blocks = hparams.get('num_blocks', 5)

        # Use CUDA if it is available, else fall back to the CPU.
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        ds = TeslaDatasetInn(device=device, data="train")
        # NOTE(review): shuffle=False feeds batches in the same order every epoch;
        # confirm this is intended for a training loader.
        train_dataloader = DataLoader(ds, batch_size=batch_size, shuffle=False)

        # Infer the input dimension from one sample; the output y is treated as 1-D.
        x, _ = ds[0]
        dim_inp = x.shape[0]
        dim_outp = 1

        # ndim_y + ndim_z == ndim_x == ndim_tot, i.e. no zero-padding.
        my_inn = plain_inn_swap.INN(
            ndim_tot=dim_inp, ndim_y=dim_outp, ndim_x=dim_inp, ndim_z=dim_inp - dim_outp,
            device=device,
            lambd_predict_back=hparams.lambd_predict_back,
            lambd_latent=hparams.lambd_latent,
            lambd_rev=hparams.lambd_rev,
            feature=num_features, num_blocks=num_blocks, batch_size=batch_size, lr=lr,
        )

        # TODO(review): the validation loader is the *training* loader, so validation
        # metrics measure training fit, not generalisation. Switch to a held-out split
        # once the supported `data` values of TeslaDatasetInn are confirmed.
        begin = time.perf_counter()
        my_inn.train(hparams.epochs, train_loader=train_dataloader, val_loader=train_dataloader,
                     log_writer=['wb', wandb])
        elapsed = time.perf_counter() - begin

        # The duration was previously computed and then discarded; keep it with the run.
        run.summary['training_time_s'] = elapsed

In [8]:
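# NOTE(review): `begin` and `end` are local variables of train(), so this line would raise
# NameError if uncommented; report the training time from inside train() instead.
#print("training:", end - begin)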

In [None]:
# Start a sweep agent in this kernel: it fetches sampled configs for `sweep_id` and
# invokes train() for each one, for at most 10 runs.
wandb.agent(sweep_id, function=train, count=10)