In [66]:
import yaml
import wandb
from volume_dataloader import CTScanDataModule
from unet import UNet
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# If this machine is not yet authenticated with W&B, run `wandb.login(relogin=True)` once.

# Root directory of the NIfTI jaw CT volumes consumed by CTScanDataModule.
# Override with the UTOOTH_DATA_PATH environment variable on other machines;
# the default keeps the original hardcoded location.
import os

DATA_PATH = os.environ.get('UTOOTH_DATA_PATH', '/media/gaetano/DATA/DATA_NIFTI_JAWS/')

In [67]:
# Build the W&B sweep definition and dump it to YAML so it can be pasted into
# the W&B sweep UI (or passed to `wandb sweep <file>`).
SWEEP_CONFIG_PATH = '/home/gaetano/PycharmProjects/ct-volume-preprocessing/wandb_config.yaml'

# Bayesian optimisation over the searched parameters ('random' was tried earlier).
sweep_config = {
    'method': 'bayes'
}

# Objective for the sweep. The name must match the metric monitored by
# ModelCheckpoint in `train` ('val_loss'); presumably logged by UNet during validation.
metric = {
    'name': 'val_loss',
    'goal': 'minimize'
}

parameters_dict = {
    # --- searched hyper-parameters -------------------------------------------
    'loss_alpha': {
        'distribution': 'uniform',
        'min': 1e-2,
        'max': 1.0
    },
    'loss_gamma': {
        'distribution': 'uniform',
        'min': 1e-1,
        'max': 3.0
    },
    # NOTE(review): a uniform draw over [1e-5, 1e-2] lands above 1e-4 ~99% of
    # the time; consider 'log_uniform_values' if a log-scale search is intended.
    'learning_rate': {
        'distribution': 'uniform',
        'min': 1e-5,
        'max': 1e-2
    },
    # --- fixed values: not searched, only forwarded to wandb.config ----------
    'batch_size': {
        'value': 5
    },
    'in_channels': {
        'value': 1
    },
    'out_channels': {
        'value': 4
    },
    'dim': {
        'value': 3
    },
    'attention': {
        'value': False
    },
    'max_epochs': {
        'value': 50
    }
}

sweep_config['metric'] = metric
sweep_config['parameters'] = parameters_dict

# yaml.dump writes to the stream and returns None, so there is nothing to keep.
with open(SWEEP_CONFIG_PATH, 'w') as file:
    yaml.dump(sweep_config, file)

In [68]:
def train(config=None):
    """Run one W&B sweep trial: build the data module and UNet from the
    sampled hyper-parameters, then fit them with PyTorch Lightning.

    Intended to be passed as the trial function to ``wandb.agent``.

    Args:
        config: Optional config forwarded to ``wandb.init``. Under a sweep
            agent, W&B injects the sampled hyper-parameters, and they are read
            back through ``wandb.config``.
    """
    with wandb.init(config=config):  # project='utooth', entity='utooth'
        config = wandb.config

        data_loader = CTScanDataModule(DATA_PATH, batch_size=config.batch_size)
        model = UNet(in_channels=config.in_channels,
                     out_channels=config.out_channels,
                     n_blocks=4,
                     start_filters=32,
                     activation='relu',
                     normalization='batch',
                     conv_mode='same',
                     dim=config.dim,
                     attention=config.attention,
                     loss_alpha=config.loss_alpha,  # previous hand-picked value: .7
                     loss_gamma=config.loss_gamma,  # previous hand-picked value: 3/4
                     learning_rate=config.learning_rate)
        # Track the sweep objective; ModelCheckpoint's default mode is 'min'.
        checkpoint = ModelCheckpoint(monitor='val_loss')
        # Do NOT call wandb.finish() before training. WandbLogger reuses the
        # active wandb.run, so val_loss is logged to the run this sweep trial
        # belongs to. Ending that run early detached the metrics from it and
        # left the Bayesian search without an objective value. The enclosing
        # `with` block finishes the run once fitting completes.
        wandb_logger = WandbLogger()  # project='utooth', entity='utooth'
        trainer = Trainer(gpus=-1,
                          log_every_n_steps=1,
                          max_epochs=config.max_epochs,
                          auto_lr_find=False,
                          callbacks=[checkpoint],
                          logger=wandb_logger)
        trainer.fit(model, data_loader)

In [69]:
# this creates a new sweep from the above config
# sweep_id = wandb.sweep(sweep_config, project='utooth', entity="utooth")

In [None]:
# Attach this machine as an agent of the existing sweep (entity/project/sweep_id)
# and run up to 50 trials. An earlier sweep used 'utooth/utooth/a8uyj71e'.
sweep_id_path = 'utooth/utooth/nx80ytvr'
wandb.agent(sweep_id_path, function=train, count=50)