In [11]:
import random
import numpy as np

import wandb
import torch
import lightning as L

from torchvision import transforms
from torch.utils.data import DataLoader
from lightning.pytorch.loggers import WandbLogger

from train_adn import TrainADN
from datasets import ArtifactDataset

In [None]:
random_seed = 42


def _seed_all(seed):
    """Seed Python's `random`, NumPy and PyTorch with the same integer."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


# Seed the main process once, before any dataset or loader is built.
_seed_all(random_seed)


def worker_init_fn(worker_id):
    """Seed the RNGs of one DataLoader worker.

    Each worker uses the base ``random_seed`` offset by its own
    ``worker_id``, so different workers draw from different streams.
    """
    _seed_all(random_seed + worker_id)

# Preprocessing applied to every sample: convert to a tensor, then
# normalise as (x - 0.5) / 0.5. One mean/std value is given, i.e. a
# single-channel setting.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# The datasets are named `*_dataset` on purpose: this notebook also defines
# a `train()` function for the sweep. Calling the dataset `train` as well
# means that re-running this cell rebinds the name and breaks the sweep.
train_dataset = ArtifactDataset(
    'spineweb/train/artifact',
    'spineweb/train/no_artifact',
    transform=transform,
)

val_dataset = ArtifactDataset(
    'spineweb/val/artifact',
    'spineweb/val/no_artifact',
    transform=transform,
)

# Training batches are shuffled; every worker is seeded via worker_init_fn.
train_loader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    worker_init_fn=worker_init_fn,
    num_workers=8,
    pin_memory=True,
)

# Validation keeps the dataset order (no shuffling).
val_loader = DataLoader(
    val_dataset,
    batch_size=4,
    worker_init_fn=worker_init_fn,
    num_workers=8,
    pin_memory=True,
)

In [None]:
def train():
    """Run a single training job driven by the active wandb configuration.

    Starts a wandb run, builds a ``TrainADN`` module from the run's config
    and fits it on ``train_loader`` / ``val_loader`` for at most 50 epochs
    using 16-bit mixed precision. The run is closed when the function
    leaves, whether training succeeded or raised.
    """
    # Using the run as a context manager replaces the trailing
    # wandb.finish(): the run is finished even if trainer.fit() raises,
    # so a failed trial cannot leave a run open for the next one.
    with wandb.init() as run:
        # Created after wandb.init(); presumably the logger then attaches
        # to the already-active run rather than starting its own.
        wandb_logger = WandbLogger()

        trainer = L.Trainer(
            logger=wandb_logger,
            precision='16-mixed',
            deterministic=True,
            max_epochs=50,
        )

        trainer.fit(
            # Hyperparameters for this trial come from the run's config.
            TrainADN(hparams=dict(run.config)),
            train_loader,
            val_loader,
        )

In [None]:
# Authenticate without embedding an API key in the notebook. With no
# explicit key, wandb.login() uses the WANDB_API_KEY environment variable
# or previously stored credentials, and prompts if neither is available.
wandb.login()

# Random search that minimises the `val_fid` metric.
sweep_config = {
    'method': 'random',
    'metric': {
        'name': 'val_fid',
        'goal': 'minimize',
    },
    'parameters': {
        'radon': {'value': False},
        # The range spans three orders of magnitude, so sample it on a log
        # scale; a plain uniform draw would almost never land in the
        # 1e-7..1e-6 decades.
        'lr': {'distribution': 'log_uniform_values', 'max': 1e-4, 'min': 1e-7},
        'beta1': {'max': 0.6, 'min': 0.5},
        'beta2': {'max': 0.95, 'min': 0.9},
        # Held fixed at 0; an alternative swept range is
        # {'max': 0.1, 'min': 0.0}.
        'weight_decay': {'value': 0},

        # NOTE(review): integer bounds - wandb presumably samples integers
        # here; use floats (1.0, 20.0) if continuous weights are intended.
        'w_adv': {'max': 20, 'min': 1},
        'w_loss': {'max': 20, 'min': 1},
    }
}

sweep_id = wandb.sweep(sweep_config, project="unsupervised-denoising")

# Run 20 trials of the sweep in this process.
wandb.agent(sweep_id, train, count=20)