In [6]:
import math
import wandb
from volume_dataloader import CTScanDataModule
from unet import UNet
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

wandb.login()

DATA_PATH = '/media/gaetano/DATA/DATA_NIFTI_JAWS/'

In [7]:
sweep_config = {
    'method': 'random'
}

metric = {
    'name': 'val_loss',
    'goal': 'minimize'
}

parameters_dict = {
    'loss_alpha': {
        'distribution': 'q_log_uniform',
        'q': 1,
        'min': math.log(1e-4),
        'max': math.log(1.0)
    },
    'loss_gamma': {
        'distribution': 'q_log_uniform',
        'q': 1,
        'min': math.log(1e-4),
        'max': math.log(3.0)
    },
    'learning_rate': {
        'distribution': 'q_log_uniform',
        'q': 1,
        'min': math.log(1e-5),
        'max': math.log(1e-1)
    },
    'batch_size': {
        'value': 5
    },
    'in_channels': {
        'value': 1
    },
    'out_channels': {
        'value': 4
    },
    'dim': {
        'value': 3
    },
    'attention': {
        'value': False
    },
    'max_epochs': {
        'value': 50
    }

}

sweep_config['metric'] = metric
sweep_config['parameters'] = parameters_dict

In [8]:
def train(config=None):
    with wandb.init(config=config):
        config = wandb.config

        data_loader = CTScanDataModule(DATA_PATH, batch_size=config.batch_size)
        model = UNet(in_channels=config.in_channels,
                     out_channels=config.out_channels,
                     n_blocks=4,
                     start_filters=32,
                     activation='relu',
                     normalization='batch',
                     conv_mode='same',
                     dim=config.dim,
                     attention=config.attention,
                     loss_alpha=config.loss_alpha,  # .7
                     loss_gamma=config.loss_gamma, # 3/4
                     learning_rate=config.learning_rate)
        checkpoint = ModelCheckpoint(monitor='val_loss')
        wandb_logger = WandbLogger()
        trainer = Trainer(gpus=-1,
                          log_every_n_steps=1,
                          max_epochs=config.max_epochs,
                          auto_lr_find=False,
                          callbacks=[checkpoint],
                          logger=wandb_logger)
        trainer.fit(model, data_loader)

In [9]:
sweep_id = wandb.sweep(sweep_config, project="ct-volume-preprocessing")

Create sweep with ID: up1hd729
Sweep URL: https://wandb.ai/playweird/ct-volume-preprocessing/sweeps/up1hd729


In [None]:
wandb.agent(sweep_id, train, count=2)

[34m[1mwandb[0m: Agent Starting Run: 0e4owwbx with config:
[34m[1mwandb[0m: 	attention: False
[34m[1mwandb[0m: 	batch_size: 5
[34m[1mwandb[0m: 	dim: 3
[34m[1mwandb[0m: 	in_channels: 1
[34m[1mwandb[0m: 	learning_rate: 0
[34m[1mwandb[0m: 	loss_alpha: 0
[34m[1mwandb[0m: 	loss_gamma: 1
[34m[1mwandb[0m: 	max_epochs: 50
[34m[1mwandb[0m: 	out_channels: 4


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  "There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse"

  | Name       | Type       | Params
------------------------------------------
0 | down_convs | ModuleList | 3.5 M 
1 | up_convs   | ModuleList | 2.1 M 
2 | conv_final | Conv3d     | 132   
------------------------------------------
5.6 M     Trainable params
0         Non-trainable params
5.6 M     Total params
22.412    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]