In [1]:
# !pip install segmentation_models_pytorch
# !pip install ray[tune]

In [2]:
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader
import torch
from mydataset import MyDataset
from torch.cuda.amp import autocast
from data_augmentation import RandomRotation, RandomVerticalFlip, RandomHorizontalFlip, Compose, ToTensor
import segmentation_models_pytorch as smp
from ray import tune
from ray.air import session
from ray.air.checkpoint import Checkpoint
from ray.tune.schedulers import ASHAScheduler
import os

In [3]:
import warnings

warnings.filterwarnings("ignore")

In [4]:
batch_size = 100

data_dir = os.path.abspath("data")

train_set = MyDataset(root=data_dir, is_train=True, transform=Compose([
    ToTensor(),
    RandomHorizontalFlip(),
    RandomVerticalFlip(),
    RandomRotation([0, 90, 180, 270]),
]), normalize=transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)))

val_set = MyDataset(root=data_dir, is_train=False, transform=Compose([ToTensor()]),
                    normalize=transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)))

train_loader = DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=24
)

val_loader = DataLoader(
    val_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=24
)

In [5]:
def get_index(pred, label):
    eps = 1e-7
    tp = torch.sum(label * pred)
    fp = torch.sum(pred) - tp
    fn = torch.sum(label) - tp

    p = (tp + eps) / (tp + fp + eps)
    r = (tp + eps) / (tp + fn + eps)
    f1 = (2 * p * r + eps) / (p + r + eps)
    iou = (tp + eps) / (tp + fn + fp + eps)
    return p, r, f1, iou

In [6]:
def train_tune(config):
    model = smp.UnetPlusPlus(
        encoder_name="resnet34",
        encoder_weights="imagenet",
        in_channels=6,
        classes=2,
    )
    criterion = nn.CrossEntropyLoss()
    optim = torch.optim.AdamW(model.parameters(), lr=config["lr"], weight_decay=1e-9)
    scaler = torch.cuda.amp.GradScaler(enabled=True)

    # To restore a checkpoint, use `session.get_checkpoint()`.
    loaded_checkpoint = session.get_checkpoint()
    if loaded_checkpoint:
        with loaded_checkpoint.as_directory() as loaded_checkpoint_dir:
            model_state, optimizer_state = torch.load(os.path.join(loaded_checkpoint_dir, "checkpoint.pt"))
        model.load_state_dict(model_state)
        optim.load_state_dict(optimizer_state)
    model.cuda()
    num_epoch = 10

    def train_model():
        model.train()
        for img1, img2, mask in train_loader:
            img1, img2, mask = img1.cuda(), img2.cuda(), mask.cuda()
            optim.zero_grad()
            mask = mask.long()
            with autocast():
                img = torch.cat((img1, img2), 1)
                outputs = model(img)
                mask = mask.squeeze(1)
                loss = criterion(outputs, mask)
            scaler.scale(loss).backward()
            scaler.step(optim)
            scaler.update()
            # _, pred = torch.max(outputs.data, 1)
            # p, r, f1, iou = get_index(pred, mask)

    def test_model():
        model.eval()
        f1s = 0
        with torch.no_grad():
            for img1, img2, mask in val_loader:
                img1, img2, mask = img1.cuda(), img2.cuda(), mask.cuda()
                img = torch.cat((img1, img2), 1)
                outputs = model(img)
                _, pred = torch.max(outputs.data, 1)
                mask = mask.squeeze(1)
                p, r, f1, iou = get_index(pred, mask)
                f1s += f1
        f1s /= len(val_loader)
        # Here we save a checkpoint. It is automatically registered with
        # Ray Tune and can be accessed through `session.get_checkpoint()`
        # API in future iterations.
        os.makedirs("my_model", exist_ok=True)
        torch.save((model.state_dict(), optim.state_dict()), "my_model/checkpoint.pt")
        checkpoint = Checkpoint.from_directory("my_model")
        session.report({"f1_score": f1s.item()}, checkpoint=checkpoint)

    for epoch in range(0, num_epoch):
        train_model()
        test_model()

    print(config["lr"], " -------- completed!")

In [7]:
def main(num_samples=10, max_num_epochs=10):
    config = {
        "lr": tune.loguniform(1e-4, 1e-1),
    }
    scheduler = ASHAScheduler(
        max_t=max_num_epochs,
        grace_period=1,
        reduction_factor=2,
    )

    tuner = tune.Tuner(
        tune.with_resources(
            tune.with_parameters(train_tune),
            resources={"cpu": 24, "gpu": 1}
        ),
        tune_config=tune.TuneConfig(
            metric="f1_score",
            mode="max",
            scheduler=scheduler,
            num_samples=num_samples,
        ),
        param_space=config,
    )

    results = tuner.fit()
    best_result = results.get_best_result("f1_score", "max")

    print("Best trial config: {}".format(best_result.config))
    print("Best trial final f1_score: {}".format(
        best_result.metrics["f1_score"]))


In [8]:
main()

2023-05-08 14:21:52,293	INFO worker.py:1625 -- Started a local Ray instance.
2023-05-08 14:21:53,199	INFO tune.py:218 -- Initializing Ray automatically. For cluster usage or custom Ray initialization, call `ray.init(...)` before `Tuner(...)`.


0,1
Current time:,2023-05-08 14:56:17
Running for:,00:34:24.22
Memory:,6.8/29.3 GiB

Trial name,status,loc,lr,iter,total time (s),f1_score
train_tune_adfc8_00000,TERMINATED,10.110.3.60:18529,0.000878598,10,445.054,0.835072
train_tune_adfc8_00001,TERMINATED,10.110.3.60:18529,0.000261386,10,446.552,0.889191
train_tune_adfc8_00002,TERMINATED,10.110.3.60:18529,0.00314686,1,45.378,0.571029
train_tune_adfc8_00003,TERMINATED,10.110.3.60:18529,0.0008969,2,92.2786,0.692513
train_tune_adfc8_00004,TERMINATED,10.110.3.60:18529,0.00017914,10,447.808,0.894567
train_tune_adfc8_00005,TERMINATED,10.110.3.60:18529,0.000111051,8,359.963,0.87454
train_tune_adfc8_00006,TERMINATED,10.110.3.60:18529,0.00676385,1,43.8079,0.531785
train_tune_adfc8_00007,TERMINATED,10.110.3.60:18529,0.0268289,1,44.263,0.381449
train_tune_adfc8_00008,TERMINATED,10.110.3.60:18529,0.000182723,2,89.5578,0.766983
train_tune_adfc8_00009,TERMINATED,10.110.3.60:18529,0.0498112,1,45.6962,0.496408


Trial name,date,done,f1_score,hostname,iterations_since_restore,node_ip,pid,should_checkpoint,time_since_restore,time_this_iter_s,time_total_s,timestamp,training_iteration,trial_id
train_tune_adfc8_00000,2023-05-08_14-29-21,True,0.835072,fiv-991tkrr6l4t9-main,10,10.110.3.60,18529,True,445.054,44.2231,445.054,1683556161,10,adfc8_00000
train_tune_adfc8_00001,2023-05-08_14-36-48,True,0.889191,fiv-991tkrr6l4t9-main,10,10.110.3.60,18529,True,446.552,44.7243,446.552,1683556608,10,adfc8_00001
train_tune_adfc8_00002,2023-05-08_14-37-33,True,0.571029,fiv-991tkrr6l4t9-main,1,10.110.3.60,18529,True,45.378,45.378,45.378,1683556653,1,adfc8_00002
train_tune_adfc8_00003,2023-05-08_14-39-06,True,0.692513,fiv-991tkrr6l4t9-main,2,10.110.3.60,18529,True,92.2786,46.437,92.2786,1683556746,2,adfc8_00003
train_tune_adfc8_00004,2023-05-08_14-46-34,True,0.894567,fiv-991tkrr6l4t9-main,10,10.110.3.60,18529,True,447.808,45.1459,447.808,1683557194,10,adfc8_00004
train_tune_adfc8_00005,2023-05-08_14-52-34,True,0.87454,fiv-991tkrr6l4t9-main,8,10.110.3.60,18529,True,359.963,45.7234,359.963,1683557554,8,adfc8_00005
train_tune_adfc8_00006,2023-05-08_14-53-17,True,0.531785,fiv-991tkrr6l4t9-main,1,10.110.3.60,18529,True,43.8079,43.8079,43.8079,1683557597,1,adfc8_00006
train_tune_adfc8_00007,2023-05-08_14-54-02,True,0.381449,fiv-991tkrr6l4t9-main,1,10.110.3.60,18529,True,44.263,44.263,44.263,1683557642,1,adfc8_00007
train_tune_adfc8_00008,2023-05-08_14-55-31,True,0.766983,fiv-991tkrr6l4t9-main,2,10.110.3.60,18529,True,89.5578,45.271,89.5578,1683557731,2,adfc8_00008
train_tune_adfc8_00009,2023-05-08_14-56-17,True,0.496408,fiv-991tkrr6l4t9-main,1,10.110.3.60,18529,True,45.6962,45.6962,45.6962,1683557777,1,adfc8_00009


2023-05-08 14:56:17,502	INFO tune.py:945 -- Total run time: 2064.30 seconds (2064.21 seconds for the tuning loop).


Best trial config: {'lr': 0.0001791403737353412}
Best trial final f1_score: 0.8945671461454585
