In [None]:
# --- Environment setup: imports, logging and PyTorch CPU threading ---
import logging
import os
import shutil
import sys
import time
import uuid

import torch

# Number of CPU threads PyTorch may use, for both intra-op and inter-op work.
NUM_THREADS = 8

# NOTE: logging.basicConfig is a no-op if the root logger already has handlers
# (e.g. when this cell is re-run in the same kernel).
logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)

# Record interpreter version, PyTorch version and CUDA availability for this run.
logger.info("%s", sys.version_info)
logger.info("%s", torch.__version__)
logger.info("%s", torch.cuda.is_available())

# Set before wandb is imported below; tells wandb not to capture console output.
os.environ['WANDB_CONSOLE'] = "off"

torch.set_num_threads(NUM_THREADS)
try:
    torch.set_num_interop_threads(NUM_THREADS)
except RuntimeError:
    # PyTorch accepts this only once per process (and only before inter-op
    # parallel work has started); without the guard, re-running this cell fails.
    logger.warning(
        "Inter-op thread count already fixed at %d; leaving it unchanged.",
        torch.get_num_interop_threads(),
    )

import wandb
from main import main
from torch.utils.tensorboard import SummaryWriter

In [None]:
# Run the project's test suite as a sanity check before the long training run.
import pytest

# `-k .` presumably selects every collected test: keyword matching is a
# substring match that also covers module names such as "test_x.py".
# NOTE: pytest's docs warn that repeated pytest.main() calls in one process
# reuse already-imported modules, so edits to tested code may need a kernel
# restart to be picked up.
pytest_exit_code = pytest.main(["-k", "."])

# pytest.main() only *returns* its status. Without this check, a failing suite
# would be ignored and "Run All" would continue straight into training.
if pytest_exit_code not in (pytest.ExitCode.OK, pytest.ExitCode.NO_TESTS_COLLECTED):
    raise RuntimeError(
        f"pytest failed with exit code {pytest_exit_code!r}; aborting before training"
    )
pytest_exit_code

In [None]:
def build_name(config):
    """Build a human-readable run name from an experiment configuration.

    The name is the ``-``-joined sequence of:

    * the lower-cased dataset name, suffixed with ``[<noise_frac>]`` when
      ``noise_frac > 0``;
    * the model name;
    * a short loss tag followed by ``[M=<num_pos>,...]`` options;
    * ``bs=<batch_size>``;
    * ``lr=<lr>``, only when the learning rate differs from the default 1e-3.

    Args:
        config: Mapping with keys ``dataset``, ``noise_frac``, ``model``,
            ``loss``, ``num_pos``, ``drop_fn``, ``batch_size`` and
            ``m_agg_mode`` (read only when ``num_pos > 1``). ``lr`` is
            optional and treated as 1e-3 when absent.

    Returns:
        str: e.g. ``"cifar10-resnet50-posv2[M=1]-bs=64"``.

    Raises:
        KeyError: if one of the required keys is missing.
    """
    default_lr = 1e-3

    dataset = config["dataset"].lower()
    if config["noise_frac"] > 0.0:
        dataset = f"{dataset}[{config['noise_frac']}]"

    # Loss tag: strip "debiased" and rename "contrastive" to "biased".
    # NOTE(review): a loss named e.g. "DebiasedContrastive" would also become
    # "biased" and collide with plain "Contrastive"; confirm against the loss
    # names main() accepts.
    loss = (
        config["loss"].lower().replace("debiased", "").replace("contrastive", "biased")
    )
    loss_args = [f"M={config['num_pos']}"]
    if config["drop_fn"]:
        loss_args.append("dropFN")
    if config["num_pos"] > 1:
        # The aggregation mode is recorded only when there are multiple positives.
        loss_args.append(config["m_agg_mode"])
    # loss_args always holds the "M=..." entry, so the bracket is always added.
    loss += "[{}]".format(",".join(loss_args))

    name = [dataset, config["model"], loss, f"bs={config['batch_size']}"]

    # "lr" is optional: fall back to the default instead of raising KeyError.
    lr = config.get("lr", default_lr)
    if lr != default_lr:
        name.append(f"lr={lr}")
    return "-".join(name)

In [None]:
# --- Experiment configuration ---
arch = "resnet50"
dataset = "CIFAR10"
batch_size = 64
noise_frac = 0.0
num_pos = 1
drop_fn = False
m_agg_mode = "loss_combination"
lr = 0.001
loss = "DebiasedPosV2"
epochs = 200
tau_plus = 0.1

# Each run writes into a fresh directory named after a random UUID.
run_uuid = uuid.uuid4()
root = "/path/to/model/data"
out = "/path/to/output/folder/{}".format(run_uuid)

# os.mkdir (not makedirs): fails loudly if the placeholder parent path was not edited.
os.mkdir(out)
name = build_name(
    {
        "dataset": dataset,
        "noise_frac": noise_frac,
        "model": arch,
        "loss": loss,
        "num_pos": num_pos,
        "drop_fn": drop_fn,
        "batch_size": batch_size,
        "m_agg_mode": m_agg_mode,
        "lr": lr,
    }
)

# --- Training run ---
wandb.init(dir=out, name=name)
try:
    # Patch before creating the SummaryWriter so its logs are synced to wandb.
    wandb.tensorboard.patch(root_logdir=out)
    writer = SummaryWriter(out)
    try:
        main(
            dataset,
            loss,
            root,
            batch_size,
            arch,
            cuda=True,
            writer=writer,
            epochs=epochs,
            tau_plus=tau_plus,
            num_pos=num_pos,
            drop_fn=drop_fn,
            noise_frac=noise_frac,
            m_agg_mode=m_agg_mode,
            run_uuid=run_uuid,
            lr=lr,
        )
    finally:
        writer.close()
finally:
    # Always end the wandb run and undo the tensorboard patch, even if training
    # fails. Otherwise a re-run of this cell trips over the stale patch and run.
    wandb.finish()
    wandb.tensorboard.unpatch()