In [54]:
import os

import wandb
import yaml
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

from unet import UNet
from volume_dataloader import CTScanDataModule

# wandb.login(relogin=True)

import os

# Data root handed to CTScanDataModule by train().
# Set the UTOOTH_DATA_PATH environment variable to point at the data on another
# machine; the fallback is the original author's local mount (trailing slash kept,
# in case the data module concatenates file names onto it — TODO confirm).
DATA_PATH = os.environ.get('UTOOTH_DATA_PATH', '/media/gaetano/DATA/DATA_NIFTI_JAWS/')

In [55]:
# W&B sweep definition: random search that minimises the logged `val_loss`.
# Every key under `parameters` is read back as an attribute of wandb.config in train().
sweep_config = {
    'method': 'random'
}

metric = {
    # Must match the metric name logged during validation (also monitored by ModelCheckpoint).
    'name': 'val_loss',
    'goal': 'minimize'
}

parameters_dict = {
    # --- sampled per trial ---
    'loss_alpha': {
        'distribution': 'uniform',
        'min': 1e-1,
        'max': 1.0
    },
    'loss_gamma': {
        'distribution': 'uniform',
        'min': 1e-1,
        'max': 3.0
    },
    # NOTE(review): a linear uniform over [1e-5, 1e-2] puts ~90% of draws above 1e-3;
    # consider 'log_uniform_values' if small learning rates should actually be explored.
    'learning_rate': {
        'distribution': 'uniform',
        'min': 1e-5,
        'max': 1e-2
    },
    # --- fixed for every trial ---
    'batch_size': {
        'value': 5
    },
    'in_channels': {
        'value': 1
    },
    'out_channels': {
        'value': 4
    },
    'dim': {
        'value': 3
    },
    'attention': {
        'value': False
    },
    'max_epochs': {
        'value': 50
    }
}

sweep_config['metric'] = metric
sweep_config['parameters'] = parameters_dict

SWEEP_CONFIG_PATH = '/home/gaetano/PycharmProjects/ct-volume-preprocessing/wandb_config.yaml'

# Persist the sweep definition (e.g. for the `wandb sweep <file>` CLI).
# The config holds only dicts/str/numbers/bools, so safe_dump writes the same YAML
# as yaml.dump; both return None when given a stream, so no result is kept.
with open(SWEEP_CONFIG_PATH, 'w') as file:
    yaml.safe_dump(sweep_config, file)

In [56]:
def train(config=None):
    """Run one W&B sweep trial: build the data module and UNet from the
    sampled hyper-parameters and fit them with PyTorch Lightning.

    Meant to be passed as the ``function`` argument of ``wandb.agent``, which
    calls it with no arguments; the sweep's sampled values are then exposed
    through ``wandb.config``.

    Args:
        config: optional dict of hyper-parameters forwarded to ``wandb.init``
            (useful for running a single trial outside a sweep).
    """
    with wandb.init(config=config, project='utooth', entity='utooth'):
        # Read hyper-parameters through wandb.config so sweep-sampled values are used.
        config = wandb.config

        data_loader = CTScanDataModule(DATA_PATH, batch_size=config.batch_size)
        model = UNet(in_channels=config.in_channels,
                     out_channels=config.out_channels,
                     n_blocks=4,
                     start_filters=32,
                     activation='relu',
                     normalization='batch',
                     conv_mode='same',
                     dim=config.dim,
                     attention=config.attention,
                     loss_alpha=config.loss_alpha,  # previous hand-picked value: .7
                     loss_gamma=config.loss_gamma,  # previous hand-picked value: 3/4
                     learning_rate=config.learning_rate)
        # Keep the checkpoint with the lowest val_loss (same metric/goal as the sweep).
        checkpoint = ModelCheckpoint(monitor='val_loss', mode='min')
        # Do NOT call wandb.finish() before training: it would close the sweep
        # trial's run, so val_loss would never reach the sweep controller and the
        # logger would start a separate, unrelated run. With a run already active,
        # WandbLogger attaches to wandb.run; the `with` block finishes it after fit.
        wandb_logger = WandbLogger(project='utooth', entity='utooth')
        trainer = Trainer(gpus=-1,
                          log_every_n_steps=1,
                          max_epochs=config.max_epochs,
                          auto_lr_find=False,
                          callbacks=[checkpoint],
                          logger=wandb_logger)
        trainer.fit(model, data_loader)

In [57]:
# sweep_id = wandb.sweep(sweep_config, project='utooth', entity="utooth")

In [None]:
# Attach an agent to the existing sweep (created with
# wandb.sweep(sweep_config, project='utooth', entity='utooth')) and run two trials.
# entity/project are passed explicitly so the short sweep id resolves against that
# project rather than the user's default W&B entity/project.
wandb.agent('a8uyj71e', train, count=2, entity='utooth', project='utooth')

[34m[1mwandb[0m: Agent Starting Run: ohoekt9x with config:
[34m[1mwandb[0m: 	attention: False
[34m[1mwandb[0m: 	batch_size: 5
[34m[1mwandb[0m: 	dim: 3
[34m[1mwandb[0m: 	in_channels: 1
[34m[1mwandb[0m: 	learning_rate: 0.0020359865172947687
[34m[1mwandb[0m: 	loss_alpha: 0.17049478320410105
[34m[1mwandb[0m: 	loss_gamma: 2.506042915435155
[34m[1mwandb[0m: 	max_epochs: 50
[34m[1mwandb[0m: 	out_channels: 4


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]



  | Name       | Type       | Params
------------------------------------------
0 | down_convs | ModuleList | 3.5 M 
1 | up_convs   | ModuleList | 2.1 M 
2 | conv_final | Conv3d     | 132   
------------------------------------------
5.6 M     Trainable params
0         Non-trainable params
5.6 M     Total params
22.412    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]