In [None]:
# Experiment settings. Later cells pass this dict to wandb.init(config=...)
# and read config['exp'] for the run name and the alert text, so 'exp' must
# be present before those cells run (presumably filled in by hand or
# injected as a notebook parameter -- TODO confirm).
config = dict()

In [None]:
import os

import torch
import wandb

In [None]:
# Set up wandb: authenticate with the key stored outside the notebook, then
# start (or resume) the run for this experiment.
with open('../.secrets/wandb_key', 'r') as f:
    # strip() removes the trailing newline as well as any stray spaces/tabs
    # around the key, which would otherwise be sent as part of the key.
    api_key = f.readline().strip()
wandb.login(key=api_key)
run = wandb.init(
    project='bachelor_research',
    group='vae_assessment',
    name=config['exp'],  # raises KeyError if config has no 'exp' entry
    # notes='',
    config=config,
    # 'auto' is the string spelling of the deprecated boolean True.
    # Other accepted values: 'allow', 'must', 'never'.
    resume='auto',
    )

In [None]:
# Store the dataset as a wandb artifact attached to the current run.
artifact = wandb.Artifact('dataset', type='dataset')
# `name` sets the path inside the artifact (the one used when downloading).
artifact.add_dir(local_path='../data', name='data')
run.log_artifact(artifact)

In [None]:
# Training state. When wandb reports that the run was resumed, every value
# below except checkpoint_path and n_epochs is replaced from the checkpoint.
checkpoint_path = "./checkpoint.pth.tar"
n_epochs = 100
model = Model()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
epoch = 0
train_losses = []
val_losses = []
# Same value and type as -np.inf (a plain Python float), but does not depend
# on numpy being imported in this notebook.
best_metric = float('-inf')

if wandb.run.resumed:
    print('resume run by wandb.')
    # Restore by bare file name.
    # NOTE(review): wandb.save() appears to store the file relative to its
    # directory, so the './' prefix is not part of the stored name -- confirm
    # against the wandb version in use.
    restored_file = wandb.restore(os.path.basename(checkpoint_path))
    if restored_file is None:
        # The run was resumed but no checkpoint was ever uploaded (e.g. it
        # stopped before the first epoch finished): keep the fresh state.
        print('no checkpoint found on wandb; starting from epoch 0.')
    else:
        # wandb.restore() hands back an open text-mode handle, which
        # torch.load cannot read; close it and load from its path instead.
        restored_file.close()
        checkpoint = torch.load(restored_file.name)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint['epoch']
        train_losses = checkpoint['train_losses']
        val_losses = checkpoint['val_losses']
        best_metric = checkpoint['best_metric']

while epoch < n_epochs:
    # ----------
    # train here
    # ----------
    # The training step is expected to define train_loss and val_loss
    # (objects with .item()) before the logging call below.

    # Steps are logged 1-based: the first completed epoch is step 1.
    wandb.log({'train_loss': train_loss.item(), 'val_loss': val_loss.item()}, step=epoch+1)

    epoch += 1

    # `epoch` is saved after the increment, i.e. it is the number of
    # completed epochs and therefore the index to continue from on resume.
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
        'train_losses': train_losses,
        'val_losses': val_losses,
        'best_metric': best_metric
    }, checkpoint_path)

    # Upload the checkpoint to the run so a later resume can restore it.
    wandb.save(checkpoint_path)

In [None]:
# Record the best metric in the summary of the current run.
wandb.run.summary.update({"best_metric": best_metric})

In [None]:
# table = wandb.Table(columns=['image', 'label', 'predict', 'score'])
# table.add_data(image, label, predict, score)
# wandb.log({'predict_table': table})

In [None]:
# Send an INFO-level wandb alert naming the experiment taken from
# config['exp'].
wandb.alert(
    title='Succeed',
    text='EXP{} has succeeded.'.format(config["exp"]),
    level=wandb.AlertLevel.INFO,
)

# Finish the current wandb run.
wandb.finish()