# How to use Ray?

In [None]:
from ray import tune
import ray
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler


In [None]:
# Search space / settings handed to Ray Tune. Entries built with tune.loguniform /
# tune.choice are Tune sampling objects resolved separately for each trial; plain
# values are passed through to every trial unchanged.
# NOTE(review): `Path` and `proxyattention` are not imported in the cells shown
# here; they must be imported earlier in the notebook.
# NOTE(review): the training code also reads config["fname_start"],
# config["device"] and config["label_map"], none of which are set in this dict —
# presumably added before hyperparam_tune / create_folds run; confirm.
config = {
    "experiment_name": "test_asl_starter",
    "ds_path": Path("/mnt/e/Datasets/asl/asl_alphabet_train/asl_alphabet_train"),
    "ds_name": "asl",
    "name_fn": proxyattention.data_utils.asl_name_fn,
    "image_size": 224,
    "batch_size": 64,
    # NOTE(review): not read by any code visible in this notebook section.
    "epoch_steps": [1, 2],
    "enable_proxy_attention": True,
    "change_subset_attention": tune.loguniform(0.1, 0.8),
    "validation_split": 0.3,
    "shuffle_dataset": tune.choice([True, False]),
    # NOTE(review): hyperparam_tune hard-codes 1 GPU; this key is not read by the code shown.
    "num_gpu": 1,
    "transfer_imagenet": False,
    "subset_images": 8000,
    # Compared against grads[...].mean(axis=2) in the pixel-replacement step:
    # pixels above this value are overwritten.
    "proxy_threshold": tune.loguniform(0.008, 0.01),
    # Passed as `method=` to decide_pixel_replacement.
    "pixel_replacement_method": tune.choice(["mean", "max", "min", "black", "white"]),
    "model": "resnet18",
    # Rounds executed in order by train_proxy_steps: "p" = one proxy-attention
    # epoch, an integer n = n ordinary training epochs.
    # "proxy_steps": tune.choice([[1, "p", 1], [3, "p", 1], [1, 1], [3,1]]),
    # "proxy_steps": tune.choice([["p", 1],[1, 1], ["p",1], [1, "p",1], [1,1,1]]),
    # Single candidate: one proxy epoch followed by 3 ordinary epochs.
    "proxy_steps": tune.choice([["p",3]]),
    # Overwritten by train_proxy_steps after every round (True after a "p" round).
    "load_proxy_data": False,
}


In [None]:
def _replace_salient_pixels(original_images, grads, label_wrong, config):
    """Overwrite high-gradient pixels of images in place (proxy-attention step).

    For every index ``ind`` in ``range(len(label_wrong))``, the pixels of
    ``original_images[ind]`` where ``grads[ind].mean(axis=2)`` is greater than
    ``config["proxy_threshold"]`` are set to the value returned by
    ``decide_pixel_replacement(original_image=..., method=config["pixel_replacement_method"])``.

    NOTE(review): this loop is not yet called from ``train_model``; the inputs
    (gradients of misclassified images) are not collected anywhere in the code
    shown. Presumably grads[ind] is (H, W, C) so axis 2 is channels — TODO confirm.
    """
    for ind in tqdm(range(len(label_wrong)), total=len(label_wrong)):
        # TODO Split these into individual comprehensions for speed
        # TODO Check if % of image is gone or not
        original_images[ind][
            grads[ind].mean(axis=2) > config["proxy_threshold"]
        ] = decide_pixel_replacement(
            original_image=original_images[ind],
            method=config["pixel_replacement_method"],
        )


def train_model(
    model,
    criterion,
    optimizer,
    scheduler,
    dataloaders,
    dataset_sizes,
    num_epochs=25,
    proxy_step=False,
    config=None,
):
    """Train ``model`` and return it with the best validation weights loaded.

    Each epoch runs a "train" phase (mixed-precision forward/backward through a
    GradScaler, then ``scheduler.step()``) followed by a "val" phase. Per-phase
    loss/accuracy are logged to TensorBoard under ``config["fname_start"]``.
    After every val phase a checkpoint ``(model_state, optimizer_state)`` is
    saved into the directory provided by ``tune.checkpoint_dir`` and
    ``tune.report(loss=..., accuracy=...)`` is called.

    Args:
        model: network to train.
        criterion: loss, called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped once per training phase.
        dataloaders: dict with "train" and "val" loaders whose batches are dicts
            with "x" (inputs) and "y" (labels).
        dataset_sizes: dict mapping phase name to sample count; used to average
            loss and accuracy.
        num_epochs: number of train+val epochs.
        proxy_step: marks this round as a proxy-attention round; logged as 0/1.
        config: dict; reads "fname_start" (log dir) and "device".

    Returns:
        ``model`` with the state dict of the epoch with the highest val accuracy.
    """
    writer = SummaryWriter(log_dir=config["fname_start"], comment=config["fname_start"])
    # Created once so the loss scale it calibrates persists across epochs.
    scaler = torch.cuda.amp.GradScaler()
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    pbar = tqdm(range(num_epochs), total=num_epochs)
    try:
        for epoch in pbar:
            # Each epoch has a training and a validation phase.
            for phase in ["train", "val"]:
                is_train = phase == "train"
                if is_train:
                    model.train()
                else:
                    model.eval()

                running_loss = 0.0
                running_corrects = 0

                optimizer.zero_grad(set_to_none=True)
                for inps in tqdm(
                    dataloaders[phase], total=len(dataloaders[phase]), leave=False
                ):
                    inputs = inps["x"].to(config["device"], non_blocking=True)
                    labels = inps["y"].to(config["device"], non_blocking=True)

                    # Autograd history is recorded only while training; under
                    # set_grad_enabled(False) the val pass needs no extra no_grad.
                    with torch.set_grad_enabled(is_train):
                        with torch.cuda.amp.autocast():
                            outputs = model(inputs)
                            _, preds = torch.max(outputs, 1)
                            loss = criterion(outputs, labels)

                        if is_train:
                            # Mixed-precision backward + optimizer step.
                            scaler.scale(loss).backward()
                            scaler.step(optimizer)
                            scaler.update()
                            optimizer.zero_grad(set_to_none=True)

                    running_loss += loss.item() * inputs.size(0)
                    # Accumulate as a Python int so epoch_acc is a plain float,
                    # which tune.report and TensorBoard both accept.
                    running_corrects += torch.sum(preds == labels).item()
                    pbar.set_postfix(
                        {
                            "Phase": "running",
                            "Loss": running_loss / dataset_sizes[phase],
                        }
                    )

                if is_train:
                    scheduler.step()

                # TODO(proxy attention): on proxy steps, collect gradients of
                # misclassified images and pass them to _replace_salient_pixels.

                epoch_loss = running_loss / dataset_sizes[phase]
                epoch_acc = running_corrects / dataset_sizes[phase]

                pbar.set_postfix({"Phase": phase, "Loss": epoch_loss, "Acc": epoch_acc})

                # TODO Add more loss functions
                # TODO Classwise accuracy
                # Logged as 0/1 against the epoch so it forms a proper time series.
                writer.add_scalar("proxy_step", int(bool(proxy_step)), epoch)

                if is_train:
                    writer.add_scalar("Loss/Train", epoch_loss, epoch)
                    writer.add_scalar("Acc/Train", epoch_acc, epoch)
                else:
                    writer.add_scalar("Loss/Val", epoch_loss, epoch)
                    writer.add_scalar("Acc/Val", epoch_acc, epoch)
                    # Save inside the directory Tune provides so the checkpoint is
                    # tracked per trial/epoch instead of overwriting a shared file.
                    with tune.checkpoint_dir(step=epoch) as checkpoint_dir:
                        save_path = Path(checkpoint_dir) / "checkpoint"
                        torch.save((model.state_dict(), optimizer.state_dict()), save_path)

                    tune.report(loss=epoch_loss, accuracy=epoch_acc)

                    # Keep a copy of the best weights seen on validation.
                    if epoch_acc > best_acc:
                        best_acc = epoch_acc
                        best_model_wts = copy.deepcopy(model.state_dict())
    finally:
        # Flush and release the TensorBoard event file even if training aborts.
        writer.close()

    time_elapsed = time.time() - since
    print(f"Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s")
    print(f"Best val Acc: {best_acc:4f}")

    # Restore the best validation weights before returning.
    model.load_state_dict(best_model_wts)
    return model



In [None]:
def setup_train_round(config, proxy_step=False, num_epochs=1):
    """Build data, a fresh model, loss, optimizer and LR schedule, then train one round.

    Args:
        config: experiment/trial config; mutated to add ``"num_classes"``
            (the size of ``config["label_map"]``).
        proxy_step: forwarded to ``train_model`` to mark a proxy-attention round.
        num_epochs: number of epochs to train in this round.

    Returns:
        The trained model returned by ``train_model``.
    """
    # Data part
    # TODO Configure data for proxy attention
    train, val = create_folds(config)
    _image_datasets, dataloaders, dataset_sizes = create_dls(train, val, config)
    # NOTE(review): "label_map" is presumably filled in by create_folds/create_dls — confirm.
    config["num_classes"] = len(config["label_map"])

    # NOTE(review): a new network is built for every round, so weights do not
    # carry over between the rounds of config["proxy_steps"] — confirm intended.
    model_ft = choose_network(config)
    criterion = nn.CrossEntropyLoss()

    # All parameters are optimized.
    optimizer_ft = optim.Adam(model_ft.parameters(), lr=3e-4)

    # Decay LR by a factor of 0.1 every 7 epochs.
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
    return train_model(
        model_ft,
        criterion,
        optimizer_ft,
        exp_lr_scheduler,
        dataloaders,
        dataset_sizes,
        num_epochs=num_epochs,
        config=config,
        proxy_step=proxy_step,
    )


def train_proxy_steps(config):
    """Run the training rounds listed in ``config["proxy_steps"]`` in order.

    Each entry is either the string ``"p"``, meaning one epoch with
    ``proxy_step=True``, or an integer number of ordinary epochs. After each
    round ``config["load_proxy_data"]`` is set to True if that round was a proxy
    round and False otherwise. ``config`` is mutated in place.

    Raises:
        RuntimeError: if CUDA is not available.
    """
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not torch.cuda.is_available():
        raise RuntimeError("train_proxy_steps requires a CUDA device, but none is available.")
    for step in config["proxy_steps"]:
        is_proxy = step == "p"
        setup_train_round(
            config=config,
            proxy_step=is_proxy,
            num_epochs=1 if is_proxy else step,
        )
        # Presumably read by the next round's data setup (create_folds/create_dls) — confirm.
        config["load_proxy_data"] = is_proxy

def tune_func(config):
    """Ray Tune trainable: execute every round of ``config["proxy_steps"]`` for one trial.

    Metrics reach Tune through the ``tune.report`` calls made during training.
    """
    # Optionally wait for the GPU to be idle first:
    # tune.utils.wait_for_gpu(target_util = .1)
    train_proxy_steps(config)

In [None]:
def hyperparam_tune(config):
    """Run a Ray Tune ASHA search over ``config`` and report the best trial.

    Starts Ray with 1 GPU / 12 CPUs and runs 50 samples of ``tune_func``, each
    trial reserving 1 GPU and 8 CPUs. ASHA minimizes the reported "loss". The
    results dataframe is written to ``<config["fname_start"]>/result_log.csv``.
    The best trial by last reported loss is then printed.

    Args:
        config: Tune search space; must contain "fname_start", used as Tune's
            ``local_dir`` and as the directory for ``result_log.csv``.

    Returns:
        The analysis object returned by ``tune.run``.
    """
    # ignore_reinit_error lets this notebook cell be re-run without ray.init raising.
    ray.init(num_gpus=1, num_cpus=12, ignore_reinit_error=True)
    scheduler = ASHAScheduler(
        metric="loss", mode="min", max_t=30, grace_period=1, reduction_factor=2,
    )

    reporter = CLIReporter(metric_columns=["loss", "accuracy", "training_iteration"])

    result = tune.run(
        tune_func,
        config=config,
        scheduler=scheduler,
        progress_reporter=reporter,
        checkpoint_at_end=True,
        max_failures=100,
        num_samples=50,
        resources_per_trial={
            "gpu": 1,
            "cpu": 8,
        },
        local_dir=config["fname_start"],
    )

    df_res = result.get_dataframe()
    df_res.to_csv(Path(config["fname_start"]) / "result_log.csv")

    best_trial = result.get_best_trial("loss", "min", "last")
    print(f"Best trial config: {best_trial.config}")
    print(f"Best trial final validation loss: {best_trial.last_result['loss']}")
    print(f"Best trial final validation accuracy: {best_trial.last_result['accuracy']}")

    print(result)
    return result