In [1]:
from TinyImageNetLoader import TinyImageNetDataset
import torchvision.transforms as transforms
import torch

normalize = transforms.Normalize([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262])

val_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize
])

valset = TinyImageNetDataset("/datasets/tiny-imagenet-200", mode="val", transform=val_transform)
#print(next(enumerate(validation_loader)))

train_transform =  transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomRotation(20),
    transforms.RandomHorizontalFlip(0.5),
    normalize
])

trainset = TinyImageNetDataset("/datasets/tiny-imagenet-200", transform=train_transform)

  from tqdm.autonotebook import tqdm


Preloading val data...:   0%|          | 0/10000 [00:00<?, ?it/s]

Preloading train data...:   0%|          | 0/100000 [00:00<?, ?it/s]

In [2]:
validation_loader = torch.utils.data.DataLoader(
        valset, batch_size=1024, shuffle=False, num_workers=4)
training_loader = torch.utils.data.DataLoader(
        trainset, batch_size=1024, shuffle=True, num_workers=4)

In [3]:
eps=1e-10
import torch.nn as nn
def loss_fn(out, labels, predicate_matrix):
    out = out.view(-1, 1, NUM_FEATURES) # out is a batch of 1D binary vectors
    ANDed = out * predicate_matrix # AND operation
    diff = ANDed - out # Difference of ANDed and out => if equal, then out is a subset of its class' predicates

    entr_loss = nn.CrossEntropyLoss()
    loss_cl = entr_loss(diff.sum(dim=2), labels) # Is "out" a subset of its class' predicates?

    batch_size = out.shape[0]

    classes = torch.zeros(batch_size, NUM_CLASSES, device="cuda")
    classes[torch.arange(batch_size), labels] = 1
    classes = classes.view(batch_size, NUM_CLASSES, 1).expand(batch_size, NUM_CLASSES, NUM_FEATURES)

    extra_features = out - predicate_matrix + (out - predicate_matrix).pow(2)

    loss_neg_ft = torch.masked_select(extra_features, (1-classes).bool()).view(-1, NUM_FEATURES).sum() / batch_size

    labels_predicate = predicate_matrix[labels]
    extra_features_in = torch.masked_select(extra_features, classes.bool()).view(-1, NUM_FEATURES)
    loss_pos_ft = (labels_predicate - out.view(batch_size, NUM_FEATURES) + extra_features_in/2).sum() / batch_size
    
    feature_losses = loss_neg_ft * FT_WEIGHT * loss_cl.item()/(loss_neg_ft.item() + eps) + loss_pos_ft * POS_FT_WEIGHT * loss_cl.item()/(loss_pos_ft.item() + eps)

    return loss_cl + feature_losses

def train_one_epoch():
    running_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data["images"], data["labels"]
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs, commit_loss, predicate_matrix = model(inputs)

        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels, predicate_matrix) + commit_loss
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()

    return running_loss / (i+1)

In [6]:
from torchmetrics import Accuracy
import sys, os
sys.path.insert(0, "/".join(os.path.abspath('').split("/")[:-1]) + "/models")
print("/".join(os.path.abspath('').split("/")[:-1]) + "/models")
from ResnetAutoPredicates import ResExtr

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

NUM_FEATURES = 352
NUM_CLASSES = 200
EPOCHS = 30
accuracy = Accuracy(task="multiclass", num_classes=NUM_CLASSES, top_k=1).to(device)

POS_FT_WEIGHT = 0
FT_WEIGHT = 0

model = ResExtr(NUM_FEATURES, NUM_CLASSES, pretrained=True).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=2e-4, weight_decay=1e-5)

best_vloss = 1_000_000.

best_stats = {
    "epoch": 0,
    "train_loss": 0,
    "val_loss": 0,
    "val_acc": 0,
    "val_fp": 0,
}

from tqdm import tqdm
for epoch in tqdm(range(EPOCHS)):
    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch()

    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()
    running_acc = 0.0
    running_false_positives = 0.0

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i, vdata in enumerate(validation_loader):
            vinputs, vlabels = vdata["images"], vdata["labels"]
            vinputs = vinputs.to(device)
            vlabels = vlabels.to(device)
            voutputs, vcommit_loss, predicate_matrix = model(vinputs)
            vloss = loss_fn(voutputs, vlabels, predicate_matrix) + vcommit_loss
            running_vloss += vloss.item()
            voutputs = voutputs.view(-1, 1, NUM_FEATURES)
            ANDed = voutputs * predicate_matrix
            diff = ANDed - voutputs
            running_acc += accuracy(diff.sum(dim=2), vlabels)
            voutputs = voutputs.view(-1, NUM_FEATURES)
            running_false_positives += ((predicate_matrix[vlabels] - voutputs) == -1).sum() / voutputs.shape[0]

    avg_vloss = running_vloss / (i + 1)
    avg_acc = running_acc / (i + 1)
    avg_false_positives = running_false_positives / (i + 1)
    print(f"LOSS: {avg_vloss}, ACC: {avg_acc}, FP: {avg_false_positives}")
    #print(model.bin_quantize._codebook.embed)

    if best_stats["val_acc"] < avg_acc:
        best_vloss = avg_vloss
        best_stats["epoch"] = epoch
        best_stats["train_loss"] = avg_loss
        best_stats["val_loss"] = avg_vloss
        best_stats["val_acc"] = avg_acc.item()
        best_stats["val_fp"] = avg_false_positives.item()


print(best_stats)

/notebooks/Concept_ZSL/src/models
Device: cuda


  3%|▎         | 1/30 [00:15<07:22, 15.25s/it]

LOSS: 4.7909797668457035, ACC: 0.21082191169261932, FP: 42.2048225402832


  7%|▋         | 2/30 [00:30<07:14, 15.53s/it]

LOSS: 4.130327272415161, ACC: 0.3118981122970581, FP: 52.329994201660156


 10%|█         | 3/30 [00:46<06:54, 15.37s/it]

LOSS: 3.8393144845962524, ACC: 0.35829681158065796, FP: 56.3805046081543


 13%|█▎        | 4/30 [01:01<06:37, 15.29s/it]

LOSS: 3.7001068353652955, ACC: 0.38954877853393555, FP: 59.440650939941406


 17%|█▋        | 5/30 [01:16<06:22, 15.28s/it]

LOSS: 3.7008814573287965, ACC: 0.4046715795993805, FP: 61.81391525268555


 20%|██        | 6/30 [01:32<06:12, 15.51s/it]

LOSS: 3.6758432149887086, ACC: 0.41243621706962585, FP: 63.7144889831543


 23%|██▎       | 7/30 [01:48<05:57, 15.53s/it]

LOSS: 3.6778097629547117, ACC: 0.43310546875, FP: 66.6617202758789


 27%|██▋       | 8/30 [02:03<05:40, 15.47s/it]

LOSS: 3.8360251903533937, ACC: 0.4284937083721161, FP: 70.02547454833984


 30%|███       | 9/30 [02:19<05:26, 15.56s/it]

LOSS: 3.788875389099121, ACC: 0.4422851502895355, FP: 70.44908905029297


 33%|███▎      | 10/30 [02:34<05:11, 15.59s/it]

LOSS: 3.915377116203308, ACC: 0.4408721625804901, FP: 71.62947082519531


 37%|███▋      | 11/30 [02:50<04:55, 15.56s/it]

LOSS: 4.000402545928955, ACC: 0.4496830999851227, FP: 73.68814086914062


 40%|████      | 12/30 [03:05<04:38, 15.48s/it]

LOSS: 4.025996589660645, ACC: 0.45195111632347107, FP: 75.21492767333984


 43%|████▎     | 13/30 [03:21<04:23, 15.48s/it]

LOSS: 4.091743397712707, ACC: 0.45336219668388367, FP: 76.37532806396484


 47%|████▋     | 14/30 [03:36<04:08, 15.50s/it]

LOSS: 4.217182159423828, ACC: 0.45047831535339355, FP: 78.08949279785156


 50%|█████     | 15/30 [03:51<03:51, 15.43s/it]

LOSS: 4.279665064811707, ACC: 0.46148356795310974, FP: 79.62401580810547


 53%|█████▎    | 16/30 [04:07<03:36, 15.45s/it]

LOSS: 4.398585557937622, ACC: 0.45488080382347107, FP: 81.26549530029297


 57%|█████▋    | 17/30 [04:23<03:22, 15.56s/it]

LOSS: 4.451663589477539, ACC: 0.4615214467048645, FP: 82.36861419677734


 60%|██████    | 18/30 [04:38<03:06, 15.51s/it]

LOSS: 4.501697206497193, ACC: 0.46306800842285156, FP: 84.1758041381836


 63%|██████▎   | 19/30 [04:54<02:52, 15.72s/it]

LOSS: 4.556060838699341, ACC: 0.4603196680545807, FP: 85.20101165771484


 67%|██████▋   | 20/30 [05:10<02:36, 15.61s/it]

LOSS: 4.616769504547119, ACC: 0.46093350648880005, FP: 85.8296127319336


 70%|███████   | 21/30 [05:27<02:24, 16.03s/it]

LOSS: 4.790542507171631, ACC: 0.45850205421447754, FP: 87.0041275024414


 73%|███████▎  | 22/30 [05:42<02:05, 15.74s/it]

LOSS: 4.765489625930786, ACC: 0.4667111933231354, FP: 87.40892791748047


 77%|███████▋  | 23/30 [05:57<01:48, 15.56s/it]

LOSS: 4.78899359703064, ACC: 0.4632274806499481, FP: 89.09098815917969


 80%|████████  | 24/30 [06:12<01:32, 15.48s/it]

LOSS: 4.926737642288208, ACC: 0.46592196822166443, FP: 89.8622817993164


 83%|████████▎ | 25/30 [06:28<01:17, 15.45s/it]

LOSS: 4.998494672775268, ACC: 0.4665597081184387, FP: 90.2645492553711


 87%|████████▋ | 26/30 [06:43<01:02, 15.56s/it]

LOSS: 5.008395385742188, ACC: 0.4600406587123871, FP: 91.72443389892578


 90%|█████████ | 27/30 [06:59<00:46, 15.60s/it]

LOSS: 5.115498781204224, ACC: 0.46100130677223206, FP: 91.05828094482422


 93%|█████████▎| 28/30 [07:15<00:31, 15.62s/it]

LOSS: 5.126066446304321, ACC: 0.46229472756385803, FP: 92.78614044189453


 97%|█████████▋| 29/30 [07:31<00:15, 15.65s/it]

LOSS: 5.183367586135864, ACC: 0.46258172392845154, FP: 92.47284698486328


100%|██████████| 30/30 [07:46<00:00, 15.56s/it]

LOSS: 5.274007844924927, ACC: 0.4619559347629547, FP: 93.1252212524414
{'epoch': 21, 'train_loss': 1.5405966140785996, 'val_loss': 4.765489625930786, 'val_acc': 0.4667111933231354, 'val_fp': 87.40892791748047}





# Resnet Baseline

In [7]:
def train_one_epoch_baseline():
    running_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data["images"], data["labels"]
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs = model(inputs)

        # Compute the loss and its gradients
        loss = loss_fn_baseline(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()

    return running_loss / (i+1)

In [13]:
from torchmetrics import Accuracy
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

NUM_CLASSES = 200
EPOCHS = 15
accuracy = Accuracy(task="multiclass", num_classes=NUM_CLASSES, top_k=1).to(device)

POS_FT_WEIGHT = 0
FT_WEIGHT = 0

from torchvision.models import resnet18, ResNet18_Weights
model = resnet18(weights=ResNet18_Weights.DEFAULT)
model.fc = nn.Linear(512, NUM_CLASSES)

model = model.to(device)

loss_fn_baseline = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4, weight_decay=1e-5)

best_vloss = 1_000_000.

best_stats = {
    "epoch": 0,
    "train_loss": 0,
    "val_loss": 0,
    "val_acc": 0,
}

from tqdm import tqdm
for epoch in tqdm(range(EPOCHS)):
    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch_baseline()

    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()
    running_acc = 0.0
    running_false_positives = 0.0

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i, vdata in enumerate(validation_loader):
            vinputs, vlabels = vdata["images"], vdata["labels"]
            vinputs = vinputs.to(device)
            vlabels = vlabels.to(device)
            voutputs = model(vinputs)
            vloss = loss_fn_baseline(voutputs, vlabels)
            running_vloss += vloss.item()
            running_acc += accuracy(voutputs, vlabels)

    avg_vloss = running_vloss / (i + 1)
    avg_acc = running_acc / (i + 1)
    print(f"LOSS: {avg_vloss}, ACC: {avg_acc}")

    if best_stats["val_acc"] < avg_acc:
        best_vloss = avg_vloss
        best_stats["epoch"] = epoch
        best_stats["train_loss"] = avg_loss
        best_stats["val_loss"] = avg_vloss
        best_stats["val_acc"] = avg_acc.item()

print(best_stats)

Device: cuda


  7%|▋         | 1/15 [00:15<03:32, 15.20s/it]

LOSS: 2.7538333177566527, ACC: 0.3666932284832001


 13%|█▎        | 2/15 [00:30<03:16, 15.14s/it]

LOSS: 2.268373465538025, ACC: 0.45415934920310974


 20%|██        | 3/15 [00:45<03:02, 15.21s/it]

LOSS: 2.1400376319885255, ACC: 0.4812440574169159


 27%|██▋       | 4/15 [01:01<02:48, 15.34s/it]

LOSS: 1.9924053430557251, ACC: 0.5112384557723999


 33%|███▎      | 5/15 [01:16<02:33, 15.37s/it]

LOSS: 1.9608011841773987, ACC: 0.5201749801635742


 40%|████      | 6/15 [01:31<02:18, 15.38s/it]

LOSS: 1.9408017992973328, ACC: 0.5298649072647095


 47%|████▋     | 7/15 [01:47<02:04, 15.54s/it]

LOSS: 1.9406044483184814, ACC: 0.5277164578437805


 53%|█████▎    | 8/15 [02:02<01:47, 15.40s/it]

LOSS: 1.9414733171463012, ACC: 0.5363759398460388


 60%|██████    | 9/15 [02:18<01:32, 15.38s/it]

LOSS: 1.9837289929389954, ACC: 0.5289102792739868


 67%|██████▋   | 10/15 [02:33<01:16, 15.33s/it]

LOSS: 2.0138878226280212, ACC: 0.5291215181350708


 73%|███████▎  | 11/15 [02:48<01:01, 15.31s/it]

LOSS: 2.0528515219688415, ACC: 0.5269291996955872


 80%|████████  | 12/15 [03:03<00:45, 15.20s/it]

LOSS: 2.074764096736908, ACC: 0.5255321860313416


 87%|████████▋ | 13/15 [03:18<00:30, 15.17s/it]

LOSS: 2.1348794221878054, ACC: 0.5197624564170837


 93%|█████████▎| 14/15 [03:33<00:15, 15.18s/it]

LOSS: 2.1823304891586304, ACC: 0.523909866809845


100%|██████████| 15/15 [03:48<00:00, 15.26s/it]

LOSS: 2.1891473531723022, ACC: 0.528218686580658
{'epoch': 7, 'train_loss': 1.1221286131411183, 'val_loss': 1.9414733171463012, 'val_acc': 0.5363759398460388}





# Optuna

In [20]:
eps = 1e-10
def loss_fn_optuna(out, labels, predicate_matrix, NUM_FEATURES, FT_WEIGHT, POS_FT_WEIGHT):
        out = out.view(-1, 1, NUM_FEATURES) # out is a batch of 1D binary vectors
        ANDed = out * predicate_matrix # AND operation
        diff = ANDed - out # Difference of ANDed and out => if equal, then out is a subset of its class' predicates

        entr_loss = nn.CrossEntropyLoss()
        loss_cl = entr_loss(diff.sum(dim=2), labels) # Is "out" a subset of its class' predicates?

        batch_size = out.shape[0]

        classes = torch.zeros(batch_size, NUM_CLASSES, device="cuda")
        classes[torch.arange(batch_size), labels] = 1
        classes = classes.view(batch_size, NUM_CLASSES, 1).expand(batch_size, NUM_CLASSES, NUM_FEATURES)

        extra_features = out - predicate_matrix + (out - predicate_matrix).pow(2)

        loss_neg_ft = torch.masked_select(extra_features, (1-classes).bool()).view(-1, NUM_FEATURES).sum() / batch_size

        labels_predicate = predicate_matrix[labels]
        extra_features_in = torch.masked_select(extra_features, classes.bool()).view(-1, NUM_FEATURES)
        loss_pos_ft = (labels_predicate - out.view(batch_size, NUM_FEATURES) + extra_features_in/2).sum() / batch_size

        return loss_cl + loss_neg_ft * FT_WEIGHT * loss_cl.item()/(loss_neg_ft.item() + eps) + loss_pos_ft * POS_FT_WEIGHT * loss_cl.item()/(loss_pos_ft.item() + eps)

def train_one_epoch_optuna(model, optimizer, NUM_FEATURES, FT_WEIGHT, POS_FT_WEIGHT):
    running_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data["images"], data["labels"]
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs, commit_loss, predicate_matrix = model(inputs)

        # Compute the loss and its gradients
        loss = loss_fn_optuna(outputs, labels, predicate_matrix, NUM_FEATURES, FT_WEIGHT, POS_FT_WEIGHT) + commit_loss
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()

    return running_loss / (i+1)

from tqdm import tqdm
from torch import optim
def objective(trial):
    global trial_num
    trial_num += 1
    print(f"Starting trial {trial_num}")
    NUM_FEATURES = trial.suggest_int("num_features", 0, 12)
    FT_WEIGHT = trial.suggest_float("ft_weight", 0, 1.5)
    POS_FT_WEIGHT = trial.suggest_float("ft_pos_weight", 0, 1.5)
    # Generate the model.
    model = ResExtr(256+NUM_FEATURES*16, NUM_CLASSES, pretrained=True).to(device)

    # Generate the optimizers.
    lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    EPOCHS = 30

    best_acc = 0.0

    for epoch in tqdm(range(EPOCHS)):
        # Make sure gradient tracking is on, and do a pass over the data
        model.train(True)
        _ = train_one_epoch_optuna(model, optimizer, 256+NUM_FEATURES*16, FT_WEIGHT, POS_FT_WEIGHT)

        model.eval()
        running_acc = 0.0

        # Disable gradient computation and reduce memory consumption.
        with torch.no_grad():
            for i, vdata in enumerate(validation_loader):
                vinputs, vlabels = vdata["images"], vdata["labels"]
                vinputs = vinputs.to(device)
                vlabels = vlabels.to(device)
                voutputs, _, predicate_matrix = model(vinputs)
                voutputs = voutputs.view(-1, 1, 256+NUM_FEATURES*16)
                ANDed = voutputs * predicate_matrix
                diff = ANDed - voutputs
                running_acc += accuracy(diff.sum(dim=2), vlabels)

        avg_acc = running_acc / (i + 1)

        if avg_acc > best_acc:
            best_acc = avg_acc

        if epoch == 1 and best_acc < 0.1:
            raise optuna.TrialPruned()
        elif epoch == 4 and best_acc < 0.3:
            raise optuna.TrialPruned()
    
    return best_acc

In [21]:
import optuna

accuracy = Accuracy(task="multiclass", num_classes=NUM_CLASSES, top_k=1).to(device)

trial_num = -1

study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=100)

print("Best trial:")
trial = study.best_trial

print("  Value: ", trial.value)

print("  Params: ")
for key, value in trial.params.items():
    if key == "num_features":
        print("    {}: {}".format(key, 256+value*16))
    else:
        print("    {}: {}".format(key, value))

[I 2023-09-01 20:39:07,461] A new study created in memory with name: no-name-e69b5ece-29ae-4674-9f35-ca2216a41575


Starting trial 0


  3%|▎         | 1/30 [00:29<14:09, 29.28s/it]
[I 2023-09-01 20:39:36,953] Trial 0 pruned. 


Starting trial 1


  3%|▎         | 1/30 [00:29<14:09, 29.30s/it]
[I 2023-09-01 20:40:06,425] Trial 1 pruned. 


Starting trial 2


  3%|▎         | 1/30 [00:29<14:19, 29.64s/it]
[I 2023-09-01 20:40:36,242] Trial 2 pruned. 


Starting trial 3


  3%|▎         | 1/30 [00:29<14:17, 29.56s/it]
[I 2023-09-01 20:41:06,028] Trial 3 pruned. 


Starting trial 4


  3%|▎         | 1/30 [00:29<14:12, 29.39s/it]
[I 2023-09-01 20:41:35,601] Trial 4 pruned. 


Starting trial 5


  3%|▎         | 1/30 [00:29<14:23, 29.77s/it]
[I 2023-09-01 20:42:05,630] Trial 5 pruned. 


Starting trial 6


  3%|▎         | 1/30 [00:29<14:16, 29.55s/it]
[I 2023-09-01 20:42:35,517] Trial 6 pruned. 


Starting trial 7


  3%|▎         | 1/30 [00:29<14:18, 29.62s/it]
[I 2023-09-01 20:43:05,320] Trial 7 pruned. 


Starting trial 8


  3%|▎         | 1/30 [00:29<14:06, 29.19s/it]
[I 2023-09-01 20:43:34,693] Trial 8 pruned. 


Starting trial 9


  3%|▎         | 1/30 [00:29<14:13, 29.44s/it]
[I 2023-09-01 20:44:04,367] Trial 9 pruned. 


Starting trial 10


  3%|▎         | 1/30 [00:29<14:16, 29.53s/it]
[I 2023-09-01 20:44:34,099] Trial 10 pruned. 


Starting trial 11


  3%|▎         | 1/30 [00:29<14:13, 29.42s/it]
[I 2023-09-01 20:45:03,715] Trial 11 pruned. 


Starting trial 12


  3%|▎         | 1/30 [00:29<14:23, 29.79s/it]
[I 2023-09-01 20:45:33,787] Trial 12 pruned. 


Starting trial 13


  3%|▎         | 1/30 [00:29<14:13, 29.41s/it]
[I 2023-09-01 20:46:03,417] Trial 13 pruned. 


Starting trial 14


  3%|▎         | 1/30 [00:29<14:04, 29.11s/it]
[I 2023-09-01 20:46:32,763] Trial 14 pruned. 


Starting trial 15


  3%|▎         | 1/30 [00:29<14:16, 29.53s/it]
[I 2023-09-01 20:47:02,490] Trial 15 pruned. 


Starting trial 16


  3%|▎         | 1/30 [00:30<14:40, 30.37s/it]
[I 2023-09-01 20:47:33,058] Trial 16 pruned. 


Starting trial 17


  3%|▎         | 1/30 [00:29<14:18, 29.62s/it]
[I 2023-09-01 20:48:02,935] Trial 17 pruned. 


Starting trial 18


  3%|▎         | 1/30 [00:29<14:07, 29.24s/it]
[I 2023-09-01 20:48:32,377] Trial 18 pruned. 


Starting trial 19


  3%|▎         | 1/30 [00:29<14:06, 29.18s/it]
[I 2023-09-01 20:49:01,758] Trial 19 pruned. 


Starting trial 20


  3%|▎         | 1/30 [00:29<14:24, 29.81s/it]
[I 2023-09-01 20:49:31,774] Trial 20 pruned. 


Starting trial 21


  3%|▎         | 1/30 [00:29<14:07, 29.23s/it]
[I 2023-09-01 20:50:01,351] Trial 21 pruned. 


Starting trial 22


  3%|▎         | 1/30 [00:29<14:24, 29.82s/it]
[I 2023-09-01 20:50:31,519] Trial 22 pruned. 


Starting trial 23


  3%|▎         | 1/30 [00:29<14:23, 29.79s/it]
[I 2023-09-01 20:51:01,496] Trial 23 pruned. 


Starting trial 24


  3%|▎         | 1/30 [00:29<14:23, 29.77s/it]
[I 2023-09-01 20:51:31,470] Trial 24 pruned. 


Starting trial 25


  3%|▎         | 1/30 [00:29<14:22, 29.75s/it]
[I 2023-09-01 20:52:01,418] Trial 25 pruned. 


Starting trial 26


  3%|▎         | 1/30 [00:30<14:31, 30.06s/it]
[I 2023-09-01 20:52:31,805] Trial 26 pruned. 


Starting trial 27


  3%|▎         | 1/30 [00:29<14:22, 29.73s/it]
[I 2023-09-01 20:53:01,766] Trial 27 pruned. 


Starting trial 28


  3%|▎         | 1/30 [00:30<14:40, 30.35s/it]
[I 2023-09-01 20:53:32,317] Trial 28 pruned. 


Starting trial 29


  3%|▎         | 1/30 [00:29<14:19, 29.65s/it]
[I 2023-09-01 20:54:02,383] Trial 29 pruned. 


Starting trial 30


  3%|▎         | 1/30 [00:29<14:12, 29.38s/it]
[I 2023-09-01 20:54:32,069] Trial 30 pruned. 


Starting trial 31


  3%|▎         | 1/30 [00:29<14:23, 29.77s/it]
[I 2023-09-01 20:55:02,038] Trial 31 pruned. 


Starting trial 32


  3%|▎         | 1/30 [00:29<14:26, 29.88s/it]
[I 2023-09-01 20:55:32,125] Trial 32 pruned. 


Starting trial 33


  3%|▎         | 1/30 [00:29<14:20, 29.69s/it]
[I 2023-09-01 20:56:02,169] Trial 33 pruned. 


Starting trial 34


  3%|▎         | 1/30 [00:30<14:32, 30.10s/it]
[I 2023-09-01 20:56:32,514] Trial 34 pruned. 


Starting trial 35


  3%|▎         | 1/30 [00:29<14:23, 29.78s/it]
[I 2023-09-01 20:57:02,504] Trial 35 pruned. 


Starting trial 36


  3%|▎         | 1/30 [00:29<14:28, 29.96s/it]
[I 2023-09-01 20:57:32,651] Trial 36 pruned. 


Starting trial 37


  3%|▎         | 1/30 [00:29<14:26, 29.88s/it]
[I 2023-09-01 20:58:02,767] Trial 37 pruned. 


Starting trial 38


  3%|▎         | 1/30 [00:30<14:37, 30.26s/it]
[I 2023-09-01 20:58:33,258] Trial 38 pruned. 


Starting trial 39


  3%|▎         | 1/30 [00:29<14:12, 29.40s/it]
[I 2023-09-01 20:59:02,892] Trial 39 pruned. 


Starting trial 40


  3%|▎         | 1/30 [00:29<14:20, 29.67s/it]
[I 2023-09-01 20:59:32,760] Trial 40 pruned. 


Starting trial 41


  3%|▎         | 1/30 [00:29<14:14, 29.48s/it]
[I 2023-09-01 21:00:02,444] Trial 41 pruned. 


Starting trial 42


  3%|▎         | 1/30 [00:29<14:14, 29.45s/it]
[I 2023-09-01 21:00:32,134] Trial 42 pruned. 


Starting trial 43


  3%|▎         | 1/30 [00:29<14:21, 29.70s/it]
[I 2023-09-01 21:01:02,032] Trial 43 pruned. 


Starting trial 44


  3%|▎         | 1/30 [00:29<14:28, 29.95s/it]
[I 2023-09-01 21:01:32,190] Trial 44 pruned. 


Starting trial 45


  3%|▎         | 1/30 [00:29<14:29, 29.97s/it]
[I 2023-09-01 21:02:02,350] Trial 45 pruned. 


Starting trial 46


  3%|▎         | 1/30 [00:29<14:12, 29.38s/it]
[I 2023-09-01 21:02:31,936] Trial 46 pruned. 


Starting trial 47


  3%|▎         | 1/30 [00:29<14:20, 29.68s/it]
[I 2023-09-01 21:03:01,923] Trial 47 pruned. 


Starting trial 48


  3%|▎         | 1/30 [00:30<14:30, 30.02s/it]
[I 2023-09-01 21:03:32,148] Trial 48 pruned. 


Starting trial 49


  3%|▎         | 1/30 [00:29<14:14, 29.47s/it]
[I 2023-09-01 21:04:01,818] Trial 49 pruned. 


Starting trial 50


  3%|▎         | 1/30 [00:29<14:29, 30.00s/it]
[I 2023-09-01 21:04:32,027] Trial 50 pruned. 


Starting trial 51


  3%|▎         | 1/30 [00:29<14:15, 29.51s/it]
[I 2023-09-01 21:05:01,740] Trial 51 pruned. 


Starting trial 52


  3%|▎         | 1/30 [00:29<14:22, 29.76s/it]
[I 2023-09-01 21:05:31,707] Trial 52 pruned. 


Starting trial 53


  3%|▎         | 1/30 [00:30<14:47, 30.60s/it]
[I 2023-09-01 21:06:02,660] Trial 53 pruned. 


Starting trial 54


  3%|▎         | 1/30 [00:30<14:34, 30.16s/it]
[I 2023-09-01 21:06:33,029] Trial 54 pruned. 


Starting trial 55


  3%|▎         | 1/30 [00:29<14:29, 29.98s/it]
[I 2023-09-01 21:07:03,332] Trial 55 pruned. 


Starting trial 56


  3%|▎         | 1/30 [00:30<14:41, 30.39s/it]
[I 2023-09-01 21:07:33,938] Trial 56 pruned. 


Starting trial 57


  3%|▎         | 1/30 [00:30<14:49, 30.68s/it]
[I 2023-09-01 21:08:04,854] Trial 57 pruned. 


Starting trial 58


  3%|▎         | 1/30 [00:29<14:20, 29.69s/it]
[I 2023-09-01 21:08:34,730] Trial 58 pruned. 


Starting trial 59


  3%|▎         | 1/30 [00:30<14:30, 30.01s/it]
[I 2023-09-01 21:09:04,941] Trial 59 pruned. 


Starting trial 60


  3%|▎         | 1/30 [00:30<14:46, 30.58s/it]
[I 2023-09-01 21:09:35,873] Trial 60 pruned. 


Starting trial 61


  3%|▎         | 1/30 [00:30<14:41, 30.39s/it]
[I 2023-09-01 21:10:06,479] Trial 61 pruned. 


Starting trial 62


  3%|▎         | 1/30 [00:30<14:36, 30.22s/it]
[I 2023-09-01 21:10:36,905] Trial 62 pruned. 


Starting trial 63


  3%|▎         | 1/30 [00:30<14:47, 30.59s/it]
[I 2023-09-01 21:11:07,698] Trial 63 pruned. 


Starting trial 64


  3%|▎         | 1/30 [00:30<14:53, 30.83s/it]
[I 2023-09-01 21:11:38,720] Trial 64 pruned. 


Starting trial 65


  3%|▎         | 1/30 [00:30<14:33, 30.11s/it]
[I 2023-09-01 21:12:09,200] Trial 65 pruned. 


Starting trial 66


  3%|▎         | 1/30 [00:29<14:26, 29.88s/it]
[I 2023-09-01 21:12:39,312] Trial 66 pruned. 


Starting trial 67


  3%|▎         | 1/30 [00:30<14:34, 30.16s/it]
[I 2023-09-01 21:13:09,663] Trial 67 pruned. 


Starting trial 68


  3%|▎         | 1/30 [00:29<14:23, 29.79s/it]
[I 2023-09-01 21:13:39,674] Trial 68 pruned. 


Starting trial 69


  3%|▎         | 1/30 [00:31<15:05, 31.21s/it]
[I 2023-09-01 21:14:11,227] Trial 69 pruned. 


Starting trial 70


  3%|▎         | 1/30 [00:30<14:31, 30.06s/it]
[I 2023-09-01 21:14:41,493] Trial 70 pruned. 


Starting trial 71


  3%|▎         | 1/30 [00:29<14:25, 29.84s/it]
[I 2023-09-01 21:15:11,675] Trial 71 pruned. 


Starting trial 72


  3%|▎         | 1/30 [00:29<14:16, 29.54s/it]
[I 2023-09-01 21:15:41,421] Trial 72 pruned. 


Starting trial 73


  3%|▎         | 1/30 [00:29<14:21, 29.72s/it]
[I 2023-09-01 21:16:11,341] Trial 73 pruned. 


Starting trial 74


  3%|▎         | 1/30 [00:30<14:38, 30.29s/it]
[I 2023-09-01 21:16:41,831] Trial 74 pruned. 


Starting trial 75


  3%|▎         | 1/30 [00:30<14:38, 30.30s/it]
[I 2023-09-01 21:17:12,482] Trial 75 pruned. 


Starting trial 76


  3%|▎         | 1/30 [00:30<14:33, 30.13s/it]
[I 2023-09-01 21:17:42,846] Trial 76 pruned. 


Starting trial 77


  3%|▎         | 1/30 [00:30<14:37, 30.26s/it]
[I 2023-09-01 21:18:13,293] Trial 77 pruned. 


Starting trial 78


  3%|▎         | 1/30 [00:29<14:23, 29.76s/it]
[I 2023-09-01 21:18:43,251] Trial 78 pruned. 


Starting trial 79


  3%|▎         | 1/30 [00:30<14:40, 30.35s/it]
[I 2023-09-01 21:19:13,800] Trial 79 pruned. 


Starting trial 80


  3%|▎         | 1/30 [00:29<14:29, 29.97s/it]
[I 2023-09-01 21:19:43,970] Trial 80 pruned. 


Starting trial 81


  3%|▎         | 1/30 [00:29<14:27, 29.92s/it]
[I 2023-09-01 21:20:14,226] Trial 81 pruned. 


Starting trial 82


  3%|▎         | 1/30 [00:30<14:31, 30.06s/it]
[I 2023-09-01 21:20:44,637] Trial 82 pruned. 


Starting trial 83


  3%|▎         | 1/30 [00:29<14:21, 29.72s/it]
[I 2023-09-01 21:21:14,574] Trial 83 pruned. 


Starting trial 84


  3%|▎         | 1/30 [00:30<14:32, 30.08s/it]
[I 2023-09-01 21:21:44,908] Trial 84 pruned. 


Starting trial 85


  3%|▎         | 1/30 [00:30<14:35, 30.18s/it]
[I 2023-09-01 21:22:15,293] Trial 85 pruned. 


Starting trial 86


  3%|▎         | 1/30 [00:30<14:41, 30.39s/it]
[I 2023-09-01 21:22:45,878] Trial 86 pruned. 


Starting trial 87


  3%|▎         | 1/30 [00:30<14:50, 30.72s/it]
[I 2023-09-01 21:23:16,931] Trial 87 pruned. 


Starting trial 88


  3%|▎         | 1/30 [00:29<14:11, 29.36s/it]
[I 2023-09-01 21:23:46,573] Trial 88 pruned. 


Starting trial 89


  3%|▎         | 1/30 [00:30<14:38, 30.28s/it]
[I 2023-09-01 21:24:17,144] Trial 89 pruned. 


Starting trial 90


  3%|▎         | 1/30 [00:30<14:36, 30.22s/it]
[I 2023-09-01 21:24:47,713] Trial 90 pruned. 


Starting trial 91


  3%|▎         | 1/30 [00:29<14:19, 29.62s/it]
[I 2023-09-01 21:25:17,538] Trial 91 pruned. 


Starting trial 92


  3%|▎         | 1/30 [00:31<15:00, 31.04s/it]
[I 2023-09-01 21:25:48,916] Trial 92 pruned. 


Starting trial 93


  3%|▎         | 1/30 [00:30<14:32, 30.08s/it]
[I 2023-09-01 21:26:19,219] Trial 93 pruned. 


Starting trial 94


  3%|▎         | 1/30 [00:29<14:17, 29.58s/it]
[I 2023-09-01 21:26:49,000] Trial 94 pruned. 


Starting trial 95


  3%|▎         | 1/30 [00:29<14:19, 29.65s/it]
[I 2023-09-01 21:27:18,988] Trial 95 pruned. 


Starting trial 96


  3%|▎         | 1/30 [00:30<14:43, 30.45s/it]
[I 2023-09-01 21:27:49,788] Trial 96 pruned. 


Starting trial 97


  3%|▎         | 1/30 [00:29<14:18, 29.59s/it]
[I 2023-09-01 21:28:19,712] Trial 97 pruned. 


Starting trial 98


  3%|▎         | 1/30 [00:29<14:28, 29.95s/it]
[I 2023-09-01 21:28:49,850] Trial 98 pruned. 


Starting trial 99


  3%|▎         | 1/30 [00:29<14:17, 29.57s/it]
[I 2023-09-01 21:29:19,658] Trial 99 pruned. 


Best trial:


ValueError: No trials are completed yet.