<a href="https://colab.research.google.com/github/arminabdeh/Localization3D/blob/smlm/optuna.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# @markdown **Step 0** Prepare Colab instance (~5 mins)
# @markdown  1. install __optuna__ and __optuna-dashboard__
# @markdown  2. replace the preinstalled torch stack with __torch 1.12.0__ (torchvision, torchaudio and torchtext are uninstalled and not reinstalled)
# @markdown  3. install the __spline__ wheel
# @markdown  4. install __decode__ package
# @markdown  5. install __luenn__ package

%%capture
# %%capture silences the (very long) pip/wget logs produced by this cell.
!pip install Optuna
!pip install Optuna-dashboard
# Remove Colab's preinstalled torch packages and pin torch 1.12.0.
# NOTE(review): presumably the version the spline/decode/luenn wheels below were built against — confirm.
!pip uninstall -y torch torchvision torchaudio torchtext
!pip install torch==1.12.0
# Imported only after the pinned torch is installed.
import torch
import yaml
# Clear Colab's example files.
!rm -r sample_data/*
# gateway.yaml holds the download locations of the project wheels under its 'wheels' key.
!wget -c https://raw.githubusercontent.com/arminabdeh/Localization3D/smlm/gateway.yaml
with open('gateway.yaml') as f:
  gateway = yaml.safe_load(f)


# Spline wheel: the filename tags (cp310, linux_x86_64) require a CPython 3.10 / x86_64 Linux runtime.
wheel_spline = gateway['wheels']['spline_py310']
wheel_spline_name = "spline-0.11.1.dev0-cp310-cp310-linux_x86_64.whl"
!wget -O $wheel_spline_name $wheel_spline
!pip install $wheel_spline_name

# decode package (provides decode.utils.param_io used later).
wheel_decode = gateway['wheels']['decode']
!pip install $wheel_decode

# luenn package (provides UNet and fly_simulator used later).
wheel_luenn =  gateway['wheels']['luenn']
!pip install $wheel_luenn

In [2]:
#@markdown **Step 1** Mount Google Drive
#@markdown
#@markdown Run this cell to make your Google Drive available to the Colab runtime.
#@markdown Drive is where the trained model should be saved; downloading it manually instead is possible but not recommended.
#@markdown
#@markdown * Run the cell and follow the authorization prompt that appears.
#@markdown * Sign in with your Google account and grant Colab access to Drive.
#@markdown * Open the "Files" panel in the Colab sidebar and refresh it; the mounted folder appears as "gdrive".
#@markdown
#@markdown Your files are then reachable as `/content/gdrive/MyDrive/[YOUR Folder / File]` (the hyper-parameter search reads its inputs from `/content/gdrive/MyDrive/optuna/`). **Important:** syncing can take a while, especially for large files.
import google.colab.drive as drive

# Attach the user's Drive under /content/gdrive.
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [3]:
from luenn.model import UNet
import decode.utils.param_io as param_io
from luenn.generic import fly_simulator
from torch.utils.data import Dataset, TensorDataset,DataLoader


In [29]:
# @title Optuna search settings
# These values are read as globals by `objective` (defined in this cell):
# training_size -> param.HyperParameter.pseudo_ds_size, test_size -> param.TestSet.test_size.
training_size = 19 # @param {type:"integer"}
test_size = 9 # @param {type:"integer"}
# @markdown Trial range of the learning rate
#@markdown > Set the trial range of the learning rate (sampled log-uniformly between the two bounds)
lr_lim_down = 1e-5 # @param {type:"raw"}
lr_lim_up = 1e-3 # @param {type:"raw"}
# Number of train/validation epochs run per Optuna trial.
epochs = 10 # @param {type:"integer"}
def objective(trial):
    """Optuna objective: train a fresh UNet with sampled hyper-parameters and
    return its validation loss after the last epoch.

    Sampled hyper-parameters:
        optimizer  -- one of Adam / AdamW / RMSprop / SGD
        lr         -- log-uniform in [lr_lim_down, lr_lim_up]
        step/gamma -- StepLR schedule: lr is multiplied by ``gamma`` every
                      ``step`` epochs
        batch_size -- training mini-batch size (log-uniform integer in [2, 16])

    Reads the notebook-level settings ``training_size``, ``test_size``,
    ``lr_lim_down``, ``lr_lim_up`` and ``epochs``.

    Raises:
        ValueError: if ``epochs`` < 1 (no validation loss would be produced).
        optuna.exceptions.TrialPruned: when the pruner stops the trial early.
    """
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")

    param_path = "/content/gdrive/MyDrive/optuna/param/param.yaml"
    param = param_io.load_params(param_path)
    calib_path = "/content/gdrive/MyDrive/optuna/calib/spline_calibration_3d_as_3dcal.mat"
    param.InOut.calibration_file = calib_path
    param.HyperParameter.pseudo_ds_size = training_size
    param.TestSet.test_size = test_size

    model = UNet()
    model.to('cuda')

    optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "AdamW", "RMSprop", "SGD"])
    lr = trial.suggest_float("lr", lr_lim_down, lr_lim_up, log=True)
    optimizer = getattr(torch.optim, optimizer_name)(model.parameters(), lr=lr)
    step = trial.suggest_int("step", 1, 20, log=True)
    gamma = trial.suggest_float("gamma", 0.1, 1, log=True)
    # One scheduler for the whole trial, stepped once per epoch below.
    # Without this, the sampled `step` and `gamma` never affect training.
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step, gamma=gamma)

    # Generate the dataloaders.
    simulator = fly_simulator(param, report=False)
    x_test, y_test, _ = simulator.ds_test()
    dataset_test = torch.utils.data.TensorDataset(x_test, y_test)
    streaming_dataset = training_stream(param, simulator)
    batch_size = trial.suggest_int("batch_size", 2, 16, log=True)
    dataloader_test = DataLoader(dataset_test, batch_size=8, num_workers=0, shuffle=False, pin_memory=True)
    # Use the sampled batch size; it was previously hard-coded to 4, so the
    # searched "batch_size" value had no effect.
    dataloader_train = DataLoader(streaming_dataset, batch_size=batch_size, num_workers=0, shuffle=False, pin_memory=True)

    # The context manager closes the progress bar even when the trial is pruned.
    with tqdm(total=epochs, desc="Processing", unit=" iterations") as progress:
        for epoch in range(epochs):
            val_loss, model = train_val_loops(model, dataloader_train, dataloader_test, optimizer, step, gamma)
            lr_scheduler.step()
            progress.update(1)
            # Report the intermediate value so the pruner can stop bad trials early.
            trial.report(val_loss, epoch)
            if trial.should_prune():
                raise optuna.exceptions.TrialPruned()
    return val_loss

class training_stream(Dataset):
    """Map-style dataset over one simulated training set.

    ``simulator.ds_train()`` is called once, at construction. Its result is
    unpacked as a 3-tuple ``(x, y, gt)``, and items are ``(x[i], y[i])``; ``gt``
    is kept but never served. The assumption is that x and y share their first
    dimension (one entry per frame) — TODO confirm against fly_simulator.

    Args:
        param: accepted for interface compatibility; not used.
        simulator: object exposing ``ds_train()``.
    """

    def __init__(self, param, simulator):
        self.data = simulator.ds_train()
        # Read the length straight from the tensor shape. The previous
        # `.numpy().shape[0]` failed for CUDA tensors and for tensors that
        # require grad.
        self.num_frames = self.data[0].shape[0]

    def __len__(self):
        return self.num_frames

    def __getitem__(self, index):
        x_sim, y_sim, _ = self.data
        return x_sim[index], y_sim[index]

In [None]:
def train_val_loops(model, dataloader_train, dataloader_test, optimizer, step=None, gamma=None):
    """Run one training epoch followed by one validation pass.

    Args:
        model: torch module. Batches are moved to the device holding its
            parameters, so the function works on CUDA or CPU.
        dataloader_train: iterable of ``(inputs, labels)`` training batches.
        dataloader_test: iterable of ``(inputs, labels)`` validation batches.
        optimizer: optimizer over ``model.parameters()``; stepped once per batch.
        step, gamma: unused; kept so existing callers keep working. The
            learning-rate schedule must be owned by the caller and stepped once
            per epoch. A scheduler built here would restart on every call, and
            the old one was never stepped anyway.

    Returns:
        ``(val_loss, model)``. ``val_loss`` is the MSE averaged over all
        validation samples: each batch is weighted by its size, so a smaller
        final batch does not skew the mean. The model is left in eval mode.

    Raises:
        ValueError: if ``dataloader_test`` yields no samples.
    """
    criterion = torch.nn.MSELoss()
    device = next(model.parameters()).device

    # training
    model.train()
    for inputs, labels in dataloader_train:
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        torch.cuda.empty_cache()

    # validation
    model.eval()
    total_loss = 0.0
    n_samples = 0
    with torch.no_grad():
        for inputs, labels in dataloader_test:
            inputs = inputs.to(device)
            labels = labels.to(device)
            batch_n = inputs.size(0)
            # .item() keeps a plain float; weight by batch size for a per-sample mean.
            total_loss += criterion(model(inputs), labels).item() * batch_n
            n_samples += batch_n
            torch.cuda.empty_cache()
    if n_samples == 0:
        raise ValueError("dataloader_test yielded no samples; cannot compute validation loss")
    return total_loss / n_samples, model

import optuna
from tqdm import tqdm
from tqdm.notebook import tqdm


# Search for the hyper-parameters that minimise the validation loss returned
# by `objective`; Optuna stops after 50 trials or 600 s, whichever comes first.
study = optuna.create_study(direction="minimize")
study.optimize(objective, timeout=600, n_trials=50)

# Show the full record of the best trial, then just its hyper-parameters.
for best_summary in (study.best_trial, study.best_params):
    print(best_summary)

[I 2023-11-16 23:45:55,617] A new study created in memory with name: no-name-8e699660-e207-45dc-92ea-900566143c5d


Processing:   0%|          | 0/10 [00:00<?, ? iterations/s]

[I 2023-11-16 23:46:08,063] Trial 0 finished with value: 122619.671875 and parameters: {'optimizer': 'AdamW', 'lr': 1.8655400924888334e-05, 'step': 13, 'gamma': 0.4974881133401941, 'batch_size': 11}. Best is trial 0 with value: 122619.671875.


Processing:   0%|          | 0/10 [00:00<?, ? iterations/s]

[I 2023-11-16 23:46:20,537] Trial 1 finished with value: 110440.33203125 and parameters: {'optimizer': 'RMSprop', 'lr': 0.0004514626552611697, 'step': 3, 'gamma': 0.5955164112179108, 'batch_size': 2}. Best is trial 1 with value: 110440.33203125.


Processing:   0%|          | 0/10 [00:00<?, ? iterations/s]

[I 2023-11-16 23:46:34,496] Trial 2 finished with value: 117122.3515625 and parameters: {'optimizer': 'SGD', 'lr': 1.0005842957009423e-05, 'step': 1, 'gamma': 0.46504287515081366, 'batch_size': 5}. Best is trial 1 with value: 110440.33203125.


Processing:   0%|          | 0/10 [00:00<?, ? iterations/s]

[I 2023-11-16 23:24:04,401] A new study created in memory with name: no-name-e67004ac-6ba7-455d-9228-0a3885366d78


Test data summary:
total seeds are 583
total frames are 40
Average seeds/frame is 14.575

Train data summary:
total seeds are 270
total frames are 20
Average seeds/frame is 13.5



100%|██████████| 5/5 [00:01<00:00,  3.46it/s]
  0%|          | 0/5 [00:00<?, ?it/s][A
 40%|████      | 2/5 [00:00<00:00,  9.68it/s][A
100%|██████████| 5/5 [00:02<00:00,  2.27it/s]
100%|██████████| 5/5 [00:00<00:00,  8.45it/s]
 80%|████████  | 4/5 [00:00<00:00,  6.32it/s]
  0%|          | 0/5 [00:00<?, ?it/s][A
 40%|████      | 2/5 [00:00<00:00, 10.20it/s][A
100%|██████████| 5/5 [00:01<00:00,  3.26it/s]
100%|██████████| 5/5 [00:00<00:00,  8.94it/s]
 80%|████████  | 4/5 [00:00<00:00,  6.29it/s]
  0%|          | 0/5 [00:00<?, ?it/s][A
 40%|████      | 2/5 [00:00<00:00,  8.32it/s][A
100%|██████████| 5/5 [00:01<00:00,  3.09it/s]
100%|██████████| 5/5 [00:00<00:00,  8.29it/s]
 80%|████████  | 4/5 [00:00<00:00,  6.95it/s]
  0%|          | 0/5 [00:00<?, ?it/s][A
 40%|████      | 2/5 [00:00<00:00, 10.10it/s][A
100%|██████████| 5/5 [00:01<00:00,  3.38it/s]
100%|██████████| 5/5 [00:00<00:00,  9.07it/s]
 80%|████████  | 4/5 [00:00<00:00,  6.90it/s]
  0%|          | 0/5 [00:00<?, ?it/s][A
 

Test data summary:
total seeds are 696
total frames are 40
Average seeds/frame is 17.4

Train data summary:
total seeds are 368
total frames are 20
Average seeds/frame is 18.4



100%|██████████| 5/5 [00:00<00:00,  5.56it/s]
  0%|          | 0/5 [00:00<?, ?it/s][A
 40%|████      | 2/5 [00:00<00:00,  9.81it/s][A
100%|██████████| 5/5 [00:01<00:00,  3.31it/s]
100%|██████████| 5/5 [00:00<00:00,  8.93it/s]
100%|██████████| 5/5 [00:00<00:00,  5.49it/s]
  0%|          | 0/5 [00:00<?, ?it/s][A
 40%|████      | 2/5 [00:00<00:00, 10.01it/s][A
100%|██████████| 5/5 [00:01<00:00,  3.29it/s]
100%|██████████| 5/5 [00:00<00:00,  9.01it/s]
100%|██████████| 5/5 [00:00<00:00,  5.49it/s]
  0%|          | 0/5 [00:00<?, ?it/s][A
 40%|████      | 2/5 [00:00<00:00,  9.98it/s][A
100%|██████████| 5/5 [00:01<00:00,  3.31it/s]
100%|██████████| 5/5 [00:00<00:00,  9.11it/s]
100%|██████████| 5/5 [00:00<00:00,  5.51it/s]
  0%|          | 0/5 [00:00<?, ?it/s][A
 40%|████      | 2/5 [00:00<00:00, 10.17it/s][A
100%|██████████| 5/5 [00:01<00:00,  3.31it/s]
100%|██████████| 5/5 [00:00<00:00,  9.08it/s]
100%|██████████| 5/5 [00:00<00:00,  5.56it/s]
  0%|          | 0/5 [00:00<?, ?it/s][A
 

Test data summary:
total seeds are 626
total frames are 40
Average seeds/frame is 15.65

Train data summary:
total seeds are 305
total frames are 20
Average seeds/frame is 15.25



 80%|████████  | 4/5 [00:00<00:00,  6.92it/s]
  0%|          | 0/5 [00:00<?, ?it/s][A
 40%|████      | 2/5 [00:00<00:00, 10.06it/s][A
100%|██████████| 5/5 [00:01<00:00,  3.41it/s]
100%|██████████| 5/5 [00:00<00:00,  9.24it/s]
 80%|████████  | 4/5 [00:00<00:00,  6.85it/s]
  0%|          | 0/5 [00:00<?, ?it/s][A
 40%|████      | 2/5 [00:00<00:00, 10.10it/s][A
100%|██████████| 5/5 [00:01<00:00,  3.38it/s]
100%|██████████| 5/5 [00:00<00:00,  9.16it/s]
 80%|████████  | 4/5 [00:00<00:00,  6.78it/s]
  0%|          | 0/5 [00:00<?, ?it/s][A
 40%|████      | 2/5 [00:00<00:00, 10.24it/s][A
100%|██████████| 5/5 [00:01<00:00,  3.37it/s]
100%|██████████| 5/5 [00:00<00:00,  9.23it/s]
 80%|████████  | 4/5 [00:00<00:00,  6.83it/s]
  0%|          | 0/5 [00:00<?, ?it/s][A
 40%|████      | 2/5 [00:00<00:00, 10.16it/s][A
100%|██████████| 5/5 [00:01<00:00,  3.36it/s]
100%|██████████| 5/5 [00:00<00:00,  9.21it/s]
 80%|████████  | 4/5 [00:00<00:00,  6.81it/s]
  0%|          | 0/5 [00:00<?, ?it/s][A
 

Test data summary:
total seeds are 673
total frames are 40
Average seeds/frame is 16.825

Train data summary:
total seeds are 338
total frames are 20
Average seeds/frame is 16.9



100%|██████████| 5/5 [00:01<00:00,  4.64it/s]
  0%|          | 0/5 [00:00<?, ?it/s][A
 20%|██        | 1/5 [00:00<00:00,  8.05it/s][A
 40%|████      | 2/5 [00:00<00:00,  6.51it/s][A
 60%|██████    | 3/5 [00:00<00:00,  6.88it/s][A
 80%|████████  | 4/5 [00:00<00:00,  6.73it/s][A
100%|██████████| 5/5 [00:01<00:00,  2.55it/s]
100%|██████████| 5/5 [00:00<00:00,  6.24it/s]
100%|██████████| 5/5 [00:01<00:00,  4.03it/s]
  0%|          | 0/5 [00:00<?, ?it/s][A
 20%|██        | 1/5 [00:00<00:00,  8.29it/s][A
 40%|████      | 2/5 [00:00<00:00,  7.12it/s][A
 60%|██████    | 3/5 [00:00<00:00,  6.85it/s][A
 80%|████████  | 4/5 [00:00<00:00,  6.61it/s][A
100%|██████████| 5/5 [00:02<00:00,  2.31it/s]
100%|██████████| 5/5 [00:00<00:00,  6.10it/s]
100%|██████████| 5/5 [00:01<00:00,  4.48it/s]
  0%|          | 0/5 [00:00<?, ?it/s][A
 20%|██        | 1/5 [00:00<00:00,  7.58it/s][A
 40%|████      | 2/5 [00:00<00:00,  7.02it/s][A
 60%|██████    | 3/5 [00:00<00:00,  6.64it/s][A
 80%|████████  |

Test data summary:
total seeds are 658
total frames are 40
Average seeds/frame is 16.45

Train data summary:
total seeds are 365
total frames are 20
Average seeds/frame is 18.25



 80%|████████  | 4/5 [00:00<00:00,  6.86it/s]
  0%|          | 0/5 [00:00<?, ?it/s][A
 40%|████      | 2/5 [00:00<00:00, 10.09it/s][A
100%|██████████| 5/5 [00:01<00:00,  3.35it/s]
100%|██████████| 5/5 [00:00<00:00,  9.01it/s]
 80%|████████  | 4/5 [00:00<00:00,  6.74it/s]
  0%|          | 0/5 [00:00<?, ?it/s][A
 40%|████      | 2/5 [00:00<00:00,  9.53it/s][A
100%|██████████| 5/5 [00:01<00:00,  3.12it/s]
100%|██████████| 5/5 [00:00<00:00,  7.93it/s]
 80%|████████  | 4/5 [00:00<00:00,  5.80it/s]
  0%|          | 0/5 [00:00<?, ?it/s][A
 20%|██        | 1/5 [00:00<00:00,  8.91it/s][A
 40%|████      | 2/5 [00:00<00:00,  7.94it/s][A
 60%|██████    | 3/5 [00:00<00:00,  7.72it/s][A
 80%|████████  | 4/5 [00:00<00:00,  7.81it/s][A
100%|██████████| 5/5 [00:01<00:00,  2.87it/s]
100%|██████████| 5/5 [00:00<00:00,  7.40it/s]
 80%|████████  | 4/5 [00:00<00:00,  5.57it/s]
  0%|          | 0/5 [00:00<?, ?it/s][A
 40%|████      | 2/5 [00:00<00:00,  9.84it/s][A
100%|██████████| 5/5 [00:01<00:00

FrozenTrial(number=2, state=TrialState.COMPLETE, values=[100551.95625], datetime_start=datetime.datetime(2023, 11, 16, 23, 24, 23, 911342), datetime_complete=datetime.datetime(2023, 11, 16, 23, 24, 32, 293291), params={'optimizer': 'RMSprop', 'lr': 0.0001727108722454708, 'step': 1, 'gamma': 0.12369284346571935, 'batch_size': 8}, user_attrs={}, system_attrs={}, intermediate_values={0: 125388.690625, 1: 100559.0546875, 2: 101255.7421875, 3: 101148.8578125, 4: 100551.95625}, distributions={'optimizer': CategoricalDistribution(choices=('Adam', 'AdamW', 'RMSprop', 'SGD')), 'lr': FloatDistribution(high=0.1, log=True, low=1e-05, step=None), 'step': IntDistribution(high=20, log=True, low=1, step=1), 'gamma': FloatDistribution(high=1.0, log=True, low=0.1, step=None), 'batch_size': IntDistribution(high=16, log=True, low=2, step=1)}, trial_id=2, value=None)
{'optimizer': 'RMSprop', 'lr': 0.0001727108722454708, 'step': 1, 'gamma': 0.12369284346571935, 'batch_size': 8}
