In [40]:

from net import GaussianDiffusion
from net import EpsilonTheta
import torch
from torch.utils.data import DataLoader
import wandb
import math
import optuna
import os


from tqdm import tqdm
import os


In [41]:
def custom_collate_fn(batch):
    """
    Custom collate function to reshape data into [batch size, channels, size].
    """
    # Assuming your signals are originally in the shape [size]
    # and you want to add a single channel dimension
    signals = torch.stack([item['signals'] for item in batch]).unsqueeze(1)  # Adds a channel dimension
    gt = torch.stack([item['gt'] for item in batch])
    sc = torch.stack([item['sc'] for item in batch])
    
    return {'signals': signals, 'gt': gt, 'sc': sc}

In [42]:
def train_model(epochs, train_loader, num_batches_per_epoch, model, optimizer, 
                validation_iter=None, device='cuda', model_save_path='./model_checkpoints/', 
                log_interval=100):
    """
    A function to train the model and return the average validation loss.
    :param epochs: Number of epochs to train.
    :param train_loader: DataLoader for training data.
    :param num_batches_per_epoch: Number of batches in each epoch.
    :param model: The neural network model to train.
    :param optimizer: Optimizer used for training.
    :param validation_iter: DataLoader for validation data, if any.
    :param device: Device to train on, 'cuda' or 'cpu'.
    :param model_save_path: Path to save model checkpoints.
    :param log_interval: Interval to log training progress.
    
    :return: Average validation loss over the validation dataset.
    """
    losses_t = []
    if not os.path.exists(model_save_path):
        os.makedirs(model_save_path)

    for epoch in range(epochs):
        model.train()
        cumm_epoch_loss = 0.0

        with tqdm(train_loader, total=num_batches_per_epoch - 1) as it:
            for batch_no, data_entry in enumerate(it, start=1):
                optimizer.zero_grad()
                signals = data_entry['signals'].to(device)
                losses = model.log_prob(signals)
                cumm_epoch_loss += losses.item()

                avg_epoch_loss = cumm_epoch_loss / batch_no
                it.set_postfix({"epoch": f"{epoch + 1}/{epochs}", "avg_loss": avg_epoch_loss}, refresh=False)

                wandb.log({"train_loss": losses.item()})
                losses.backward()
                optimizer.step()
                #lr_scheduler.step()

                if num_batches_per_epoch == batch_no:
                    break

        losses_t.append(avg_epoch_loss)
        if (epoch + 1) % log_interval == 0:
            model_checkpoint_path = os.path.join(model_save_path, f'model_epoch_{epoch+1}.pth')
            torch.save(model.state_dict(), model_checkpoint_path)
            print(f'Model saved to {model_checkpoint_path}')

    # Validation loop
    if validation_iter is not None:
        model.eval()
        cumm_epoch_loss_val = 0.0
        with tqdm(validation_iter, total=num_batches_per_epoch - 1, colour="green") as it:
            for batch_no, data_entry in enumerate(it, start=1):
                signals = data_entry['signals']
                with torch.no_grad():
                    losses = model.log_prob(signals)

                cumm_epoch_loss_val += losses.item()
                avg_epoch_loss_val = cumm_epoch_loss_val / batch_no

                it.set_postfix({"epoch": f"{epoch + 1}/{epochs}", "avg_val_loss": avg_epoch_loss_val}, refresh=False)

        return avg_epoch_loss_val
    else:
        return sum(losses_t) / len(losses_t)


In [43]:
def objective(trial):
    # Hyperparameters to optimize
    learning_rate = trial.suggest_loguniform('learning_rate', 1e-5, 1e-2)
    batch_size = trial.suggest_categorical('batch_size', [16, 32, 64, 128])
    num_layers = trial.suggest_int('num_layers', 4, 16)
    residual_channels = trial.suggest_categorical('residual_channels', [16, 32, 64])
    dilation_cycle_length = trial.suggest_categorical('dilation_cycle_length', [1, 2, 4])

    nb_samples = 10000
    num_batches_per_epoch = math.ceil(nb_samples / batch_size)

    
    
    file_path = 'datasets/train_set.pth'
    dataset = torch.load(file_path)

    # DataLoader for training
    train_loader = DataLoader(dataset, batch_size=32, shuffle=True,collate_fn=custom_collate_fn)
    # Initialize model, optimizer, etc. using the suggested hyperparameters
    # For example:
    denoise_fn = EpsilonTheta(target_dim=[256],
                              residual_layers=num_layers,
                              residual_channels=residual_channels,
                              dilation_cycle_length=dilation_cycle_length)
    model = GaussianDiffusion(denoise_fn=denoise_fn, input_size=[256], beta_end=0.1, diff_steps=100, loss_type="l2", betas=None, beta_schedule="linear")
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    

    # Your training function should return the validation loss
    val_loss = train_model(100, train_loader, num_batches_per_epoch, model, optimizer, 
                validation_iter=None, device='cpu', model_save_path='./model_checkpoints_optuna/', 
                log_interval=100)  
    
    # Log to wandb
    wandb.log({'learning_rate': learning_rate,
               'batch_size': batch_size,
               'num_layers': num_layers,
               'residual_channels': residual_channels,
               'dilation_cycle_length': dilation_cycle_length,
               'val_loss': val_loss})
    
    return val_loss

In [44]:
wandb.init(project="optuna_hyperparameter_optimization")

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
train_loss,██▄▄▄▂▃▂▃▃▂▃▄▃▂▁▂▄▂▁▂▃▂▃▁▂▂▂▃▂▂▂▂▂▂▂▂▁▂▂

0,1
train_loss,0.07155


In [45]:
study = optuna.create_study(direction='minimize')
study.optimize(objective, n_trials=100)

    # Complete the wandb run
wandb.run.finish()


[I 2024-03-03 19:53:04,413] A new study created in memory with name: no-name-469d8df2-0f76-4831-8b0e-70b2ac50953c


  This is separate from the ipykernel package so we can avoid doing imports until
100%|██████████| 156/156 [00:42<00:00,  3.69it/s, epoch=1/100, avg_loss=0.606]
100%|██████████| 156/156 [00:46<00:00,  3.33it/s, epoch=2/100, avg_loss=0.254]
100%|██████████| 156/156 [00:40<00:00,  3.83it/s, epoch=3/100, avg_loss=0.223]
100%|██████████| 156/156 [00:41<00:00,  3.74it/s, epoch=4/100, avg_loss=0.2]  
100%|██████████| 156/156 [00:42<00:00,  3.66it/s, epoch=5/100, avg_loss=0.179]
100%|██████████| 156/156 [00:36<00:00,  4.29it/s, epoch=6/100, avg_loss=0.167]
100%|██████████| 156/156 [00:42<00:00,  3.70it/s, epoch=7/100, avg_loss=0.151]
100%|██████████| 156/156 [00:41<00:00,  3.72it/s, epoch=8/100, avg_loss=0.149]
100%|██████████| 156/156 [00:43<00:00,  3.55it/s, epoch=9/100, avg_loss=0.133]
100%|██████████| 156/156 [00:43<00:00,  3.58it/s, epoch=10/100, avg_loss=0.127]
100%|██████████| 156/156 [00:43<00:00,  3.61it/s, epoch=11/100, avg_loss=0.125]
100%|██████████| 156/156 [00:41<00:00,  3.79it/

Model saved to ./model_checkpoints_optuna/model_epoch_100.pth


100%|██████████| 156/156 [00:38<00:00,  4.01it/s, epoch=1/100, avg_loss=0.266]
100%|██████████| 156/156 [00:42<00:00,  3.68it/s, epoch=2/100, avg_loss=0.151]
100%|██████████| 156/156 [00:43<00:00,  3.60it/s, epoch=3/100, avg_loss=0.126]
100%|██████████| 156/156 [00:51<00:00,  3.05it/s, epoch=4/100, avg_loss=0.112]
100%|██████████| 156/156 [00:54<00:00,  2.87it/s, epoch=5/100, avg_loss=0.103]
100%|██████████| 156/156 [00:48<00:00,  3.20it/s, epoch=6/100, avg_loss=0.0955]
100%|██████████| 156/156 [00:48<00:00,  3.19it/s, epoch=7/100, avg_loss=0.0925]
100%|██████████| 156/156 [00:40<00:00,  3.81it/s, epoch=8/100, avg_loss=0.0903]
100%|██████████| 156/156 [00:47<00:00,  3.28it/s, epoch=9/100, avg_loss=0.0888]
100%|██████████| 156/156 [00:50<00:00,  3.10it/s, epoch=10/100, avg_loss=0.09]  
100%|██████████| 156/156 [00:46<00:00,  3.34it/s, epoch=11/100, avg_loss=0.0864]
100%|██████████| 156/156 [00:46<00:00,  3.32it/s, epoch=12/100, avg_loss=0.0861]
100%|██████████| 156/156 [00:48<00:00,  3.

Model saved to ./model_checkpoints_optuna/model_epoch_100.pth


100%|██████████| 156/156 [00:46<00:00,  3.33it/s, epoch=1/100, avg_loss=1.05]
100%|██████████| 156/156 [00:45<00:00,  3.42it/s, epoch=2/100, avg_loss=0.787]
100%|██████████| 156/156 [00:47<00:00,  3.31it/s, epoch=3/100, avg_loss=0.639]
100%|██████████| 156/156 [00:44<00:00,  3.51it/s, epoch=4/100, avg_loss=0.511]
100%|██████████| 156/156 [00:38<00:00,  4.02it/s, epoch=5/100, avg_loss=0.412]
100%|██████████| 156/156 [00:34<00:00,  4.51it/s, epoch=6/100, avg_loss=0.335]
100%|██████████| 156/156 [00:35<00:00,  4.44it/s, epoch=7/100, avg_loss=0.279]
100%|██████████| 156/156 [00:36<00:00,  4.25it/s, epoch=8/100, avg_loss=0.254]
100%|██████████| 156/156 [00:40<00:00,  3.88it/s, epoch=9/100, avg_loss=0.237]
100%|██████████| 156/156 [00:38<00:00,  4.04it/s, epoch=10/100, avg_loss=0.232]
100%|██████████| 156/156 [00:47<00:00,  3.30it/s, epoch=11/100, avg_loss=0.228]
100%|██████████| 156/156 [00:46<00:00,  3.35it/s, epoch=12/100, avg_loss=0.231]
100%|██████████| 156/156 [00:43<00:00,  3.58it/s, 

Model saved to ./model_checkpoints_optuna/model_epoch_100.pth


100%|██████████| 78/78 [00:25<00:00,  3.03it/s, epoch=1/100, avg_loss=1.05]
100%|██████████| 78/78 [00:26<00:00,  2.91it/s, epoch=2/100, avg_loss=0.726]
100%|██████████| 78/78 [00:26<00:00,  2.93it/s, epoch=3/100, avg_loss=0.529]
100%|██████████| 78/78 [00:25<00:00,  3.09it/s, epoch=4/100, avg_loss=0.401]
100%|██████████| 78/78 [00:24<00:00,  3.22it/s, epoch=5/100, avg_loss=0.309]
100%|██████████| 78/78 [00:23<00:00,  3.36it/s, epoch=6/100, avg_loss=0.262]
100%|██████████| 78/78 [00:23<00:00,  3.36it/s, epoch=7/100, avg_loss=0.252]
100%|██████████| 78/78 [00:23<00:00,  3.30it/s, epoch=8/100, avg_loss=0.25] 
100%|██████████| 78/78 [00:23<00:00,  3.36it/s, epoch=9/100, avg_loss=0.234]
100%|██████████| 78/78 [00:23<00:00,  3.34it/s, epoch=10/100, avg_loss=0.229]
100%|██████████| 78/78 [00:23<00:00,  3.36it/s, epoch=11/100, avg_loss=0.233]
100%|██████████| 78/78 [00:23<00:00,  3.34it/s, epoch=12/100, avg_loss=0.225]
100%|██████████| 78/78 [00:23<00:00,  3.39it/s, epoch=13/100, avg_loss=0.2

Model saved to ./model_checkpoints_optuna/model_epoch_100.pth


100%|██████████| 156/156 [00:35<00:00,  4.45it/s, epoch=1/100, avg_loss=0.788]
100%|██████████| 156/156 [00:37<00:00,  4.20it/s, epoch=2/100, avg_loss=0.298]
100%|██████████| 156/156 [00:37<00:00,  4.16it/s, epoch=3/100, avg_loss=0.232]
100%|██████████| 156/156 [00:36<00:00,  4.24it/s, epoch=4/100, avg_loss=0.217]
100%|██████████| 156/156 [00:31<00:00,  4.95it/s, epoch=5/100, avg_loss=0.204]
100%|██████████| 156/156 [00:31<00:00,  4.90it/s, epoch=6/100, avg_loss=0.187]
100%|██████████| 156/156 [00:31<00:00,  4.93it/s, epoch=7/100, avg_loss=0.175]
100%|██████████| 156/156 [00:32<00:00,  4.83it/s, epoch=8/100, avg_loss=0.164]
100%|██████████| 156/156 [00:31<00:00,  4.95it/s, epoch=9/100, avg_loss=0.157]
100%|██████████| 156/156 [00:32<00:00,  4.87it/s, epoch=10/100, avg_loss=0.146]
100%|██████████| 156/156 [00:31<00:00,  4.90it/s, epoch=11/100, avg_loss=0.14] 
100%|██████████| 156/156 [00:32<00:00,  4.86it/s, epoch=12/100, avg_loss=0.136]
100%|██████████| 156/156 [00:31<00:00,  4.89it/s,

Model saved to ./model_checkpoints_optuna/model_epoch_100.pth


100%|██████████| 78/78 [00:13<00:00,  5.93it/s, epoch=1/100, avg_loss=0.775]
100%|██████████| 78/78 [00:13<00:00,  5.92it/s, epoch=2/100, avg_loss=0.305]
100%|██████████| 78/78 [00:13<00:00,  5.87it/s, epoch=3/100, avg_loss=0.233]
100%|██████████| 78/78 [00:12<00:00,  6.24it/s, epoch=4/100, avg_loss=0.228]
100%|██████████| 78/78 [00:11<00:00,  6.87it/s, epoch=5/100, avg_loss=0.22] 
100%|██████████| 78/78 [00:11<00:00,  6.85it/s, epoch=6/100, avg_loss=0.207]
100%|██████████| 78/78 [00:11<00:00,  6.73it/s, epoch=7/100, avg_loss=0.205]
100%|██████████| 78/78 [00:11<00:00,  6.74it/s, epoch=8/100, avg_loss=0.204]
100%|██████████| 78/78 [00:11<00:00,  6.79it/s, epoch=9/100, avg_loss=0.205]
100%|██████████| 78/78 [00:11<00:00,  6.82it/s, epoch=10/100, avg_loss=0.203]
100%|██████████| 78/78 [00:11<00:00,  6.79it/s, epoch=11/100, avg_loss=0.2]  
100%|██████████| 78/78 [00:11<00:00,  6.82it/s, epoch=12/100, avg_loss=0.185]
100%|██████████| 78/78 [00:11<00:00,  6.64it/s, epoch=13/100, avg_loss=0.

Model saved to ./model_checkpoints_optuna/model_epoch_100.pth


100%|██████████| 312/312 [02:03<00:00,  2.53it/s, epoch=1/100, avg_loss=0.279]
100%|██████████| 312/312 [01:49<00:00,  2.84it/s, epoch=2/100, avg_loss=0.168]
100%|██████████| 312/312 [02:04<00:00,  2.52it/s, epoch=3/100, avg_loss=0.135]
100%|██████████| 312/312 [02:05<00:00,  2.48it/s, epoch=4/100, avg_loss=0.115]
100%|██████████| 312/312 [01:55<00:00,  2.69it/s, epoch=5/100, avg_loss=0.104]
100%|██████████| 312/312 [01:48<00:00,  2.88it/s, epoch=6/100, avg_loss=0.0999]
100%|██████████| 312/312 [01:48<00:00,  2.87it/s, epoch=7/100, avg_loss=0.095] 
100%|██████████| 312/312 [02:05<00:00,  2.50it/s, epoch=8/100, avg_loss=0.0907]
100%|██████████| 312/312 [01:54<00:00,  2.72it/s, epoch=9/100, avg_loss=0.0888]
100%|██████████| 312/312 [01:48<00:00,  2.86it/s, epoch=10/100, avg_loss=0.0877]
100%|██████████| 312/312 [01:54<00:00,  2.73it/s, epoch=11/100, avg_loss=0.0872]
100%|██████████| 312/312 [01:59<00:00,  2.61it/s, epoch=12/100, avg_loss=0.0878]
100%|██████████| 312/312 [01:48<00:00,  2.

Model saved to ./model_checkpoints_optuna/model_epoch_100.pth


100%|██████████| 78/78 [00:14<00:00,  5.24it/s, epoch=1/100, avg_loss=0.837]
100%|██████████| 78/78 [2:47:33<00:00, 128.90s/it, epoch=2/100, avg_loss=0.555]    
100%|██████████| 78/78 [00:12<00:00,  6.20it/s, epoch=3/100, avg_loss=0.381]
100%|██████████| 78/78 [00:12<00:00,  6.16it/s, epoch=4/100, avg_loss=0.293]
100%|██████████| 78/78 [00:12<00:00,  6.12it/s, epoch=5/100, avg_loss=0.261]
100%|██████████| 78/78 [00:11<00:00,  6.55it/s, epoch=6/100, avg_loss=0.243]
100%|██████████| 78/78 [00:12<00:00,  6.01it/s, epoch=7/100, avg_loss=0.239]
100%|██████████| 78/78 [00:12<00:00,  6.28it/s, epoch=8/100, avg_loss=0.234]
100%|██████████| 78/78 [00:12<00:00,  6.42it/s, epoch=9/100, avg_loss=0.224]
100%|██████████| 78/78 [00:12<00:00,  6.41it/s, epoch=10/100, avg_loss=0.224]
100%|██████████| 78/78 [1:11:44<00:00, 55.19s/it, epoch=11/100, avg_loss=0.226]     
100%|██████████| 78/78 [00:12<00:00,  6.33it/s, epoch=12/100, avg_loss=0.212]
100%|██████████| 78/78 [00:13<00:00,  5.78it/s, epoch=13/10

Model saved to ./model_checkpoints_optuna/model_epoch_100.pth


 88%|████████▊ | 273/312 [04:26<00:38,  1.03it/s, epoch=1/100, avg_loss=0.18] 
[W 2024-03-04 09:27:05,569] Trial 8 failed with parameters: {'learning_rate': 0.006144956989616047, 'batch_size': 32, 'num_layers': 14, 'residual_channels': 64, 'dilation_cycle_length': 4} because of the following error: KeyboardInterrupt().
Traceback (most recent call last):
  File "c:\Users\Admin\anaconda3\envs\difonedseg\lib\site-packages\optuna\study\_optimize.py", line 200, in _run_trial
    value_or_values = func(trial)
  File "C:\Users\Admin\AppData\local\Temp\ipykernel_25692\735566430.py", line 32, in objective
    log_interval=100)
  File "C:\Users\Admin\AppData\local\Temp\ipykernel_25692\2610101303.py", line 30, in train_model
    losses = model.log_prob(signals)
  File "c:\Users\Admin\Desktop\diffusion_ts\s\net\gaussian_diffusion.py", line 280, in log_prob
    x.reshape(B * T, 1, -1), time, *args, **kwargs
  File "c:\Users\Admin\Desktop\diffusion_ts\s\net\gaussian_diffusion.py", line 259, in p_los

KeyboardInterrupt: 

In [46]:
print("Number of finished trials: ", len(study.trials))
print("Best trial:")
trial = study.best_trial

print("  Value: ", trial.value)
print("  Params: ")
for key, value in trial.params.items():
    print(f"    {key}: {value}")

Number of finished trials:  9
Best trial:
  Value:  0.07706547251979637
  Params: 
    learning_rate: 0.0007153912999894027
    batch_size: 64
    num_layers: 15
    residual_channels: 16
    dilation_cycle_length: 2
