In [1]:
from TinyImageNetLoader import TinyImageNetDataset
import torchvision.transforms as transforms
import torch

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

val_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize
])

valset = TinyImageNetDataset("/datasets/tiny-imagenet-200", mode="val", transform=val_transform)
#print(next(enumerate(validation_loader)))

train_transform =  transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomRotation(20),
    transforms.RandomHorizontalFlip(0.5),
    normalize
])

trainset = TinyImageNetDataset("/datasets/tiny-imagenet-200", transform=train_transform)

Preloading val data...: 100%|██████████| 10000/10000 [00:56<00:00, 177.04it/s]
Preloading train data...: 100%|██████████| 100000/100000 [09:47<00:00, 170.30it/s]


In [2]:
validation_loader = torch.utils.data.DataLoader(
        valset, batch_size=256, shuffle=False, num_workers=4)
training_loader = torch.utils.data.DataLoader(
        trainset, batch_size=256, shuffle=True, num_workers=4)

In [3]:
import torch.nn as nn
eps=1e-10
def loss_fn(out, labels, predicate_matrix):
    out = out.view(-1, 1, NUM_FEATURES) # out is a batch of 1D binary vectors
    ANDed = out * predicate_matrix # AND operation
    diff = ANDed - out # Difference of ANDed and out => if equal, then out is a subset of its class' predicates

    entr_loss = torch.nn.CrossEntropyLoss()
    loss_cl = entr_loss(diff.sum(dim=2), labels) # Is "out" a subset of its class' predicates?

    batch_size = out.shape[0]

    out = out.view(-1, NUM_FEATURES)
    diff_square = (out - predicate_matrix[labels]).pow(2)
    
    false_positives = (out - predicate_matrix[labels] + diff_square).sum() / batch_size
    missing_attr = (predicate_matrix[labels] - out + diff_square).sum() / batch_size
    
    loss_ft = (1 + false_positives + missing_attr)
    loss_ft *= loss_cl.item()/(loss_ft.item() + eps)
    
    return loss_cl + loss_ft * FT_WEIGHT

def train_one_epoch():
    running_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data["images"], data["labels"]
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs, commit_loss, predicate_matrix = model(inputs)

        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels, predicate_matrix) + commit_loss
        loss.backward()

        # Adjust learning weights
        optimizer.step()
        
        scheduler.step()

        # Gather data and report
        running_loss += loss.item()

    return running_loss / (i+1)

In [4]:
from torchmetrics import Accuracy
import sys, os
sys.path.insert(0, "/".join(os.path.abspath('').split("/")[:-1]) + "/models")
print("/".join(os.path.abspath('').split("/")[:-1]) + "/models")
from ResnetAutoPredicates import ResExtr

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

NUM_FEATURES = 96
NUM_CLASSES = 200
EPOCHS = 50
accuracy = Accuracy(task="multiclass", num_classes=NUM_CLASSES, top_k=1).to(device)

FT_WEIGHT = 0.3

model = ResExtr(NUM_FEATURES, NUM_CLASSES, pretrained=True).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=3e-5, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, 3e-4, epochs=EPOCHS, steps_per_epoch=len(training_loader))

best_stats = {
    "epoch": 0,
    "train_loss": 0,
    "val_loss": 0,
    "val_acc": 0,
    "fp": 0,
    "ma": 0,
    "oa": 0
}

from tqdm import tqdm
for epoch in tqdm(range(EPOCHS)):
    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch()

    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()
    running_acc = 0.0
    running_false_positives = 0.0
    running_missing_attr = 0.0
    running_out_attributes= 0.0

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i, vdata in enumerate(validation_loader):
            vinputs, vlabels = vdata["images"], vdata["labels"]
            vinputs = vinputs.to(device)
            vlabels = vlabels.to(device)
            voutputs, vcommit_loss, predicate_matrix = model(vinputs)
            vloss = loss_fn(voutputs, vlabels, predicate_matrix) + vcommit_loss
            running_vloss += vloss.item()
            voutputs = voutputs.view(-1, 1, NUM_FEATURES)
            ANDed = voutputs * predicate_matrix
            diff = ANDed - voutputs
            running_acc += accuracy(diff.sum(dim=2), vlabels)
            voutputs = voutputs.view(-1, NUM_FEATURES)
            running_false_positives += ((predicate_matrix[vlabels] - voutputs) == -1).sum() / voutputs.shape[0]
            running_missing_attr += ((voutputs - predicate_matrix[vlabels]) == -1).sum() / voutputs.shape[0]
            running_out_attributes += voutputs.sum() / voutputs.shape[0]

    avg_vloss = running_vloss / (i + 1)
    avg_acc = running_acc / (i + 1)
    avg_fp = running_false_positives / (i + 1)
    avg_ma = running_missing_attr / (i + 1)
    avg_oa = running_out_attributes / (i + 1)
    print(f"LOSS: {avg_vloss}, ACC: {avg_acc}, FP: {avg_fp}, MA: {avg_ma}, OA: {avg_oa}")
    #print(model.bin_quantize._codebook.embed)
    
    with open("TINRes18AutoPredData.csv", "a") as f:
        f.write(f"{epoch}, {avg_loss}, {avg_vloss}, {avg_acc}, {avg_fp}, {avg_ma}, {avg_oa}\n")

    if best_stats["val_acc"] < avg_acc:
        best_stats["epoch"] = epoch
        best_stats["train_loss"] = avg_loss
        best_stats["val_loss"] = avg_vloss
        best_stats["val_acc"] = avg_acc.item()
        best_stats["fp"] = avg_fp.item()
        best_stats["ma"] = avg_ma.item()
        best_stats["oa"] = avg_oa.item()


print(best_stats)

/notebooks/Concept_ZSL/src/models
Device: cuda


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

  2%|▏         | 1/50 [00:17<14:25, 17.67s/it]

LOSS: 7.322363233566284, ACC: 0.01953125, FP: 4.966992378234863, MA: 26.677051544189453, OA: 8.017969131469727


  4%|▍         | 2/50 [00:32<12:36, 15.77s/it]

LOSS: 7.182964086532593, ACC: 0.04521484300494194, FP: 7.729394435882568, MA: 24.5029296875, OA: 12.925683975219727


  6%|▌         | 3/50 [00:46<11:42, 14.95s/it]

LOSS: 6.502776825428009, ACC: 0.10527344048023224, FP: 9.28525447845459, MA: 22.32392692565918, OA: 16.520898818969727


  8%|▊         | 4/50 [01:00<11:12, 14.63s/it]

LOSS: 5.453660237789154, ACC: 0.21787109971046448, FP: 9.992383003234863, MA: 20.104785919189453, OA: 19.370508193969727


 10%|█         | 5/50 [01:14<10:50, 14.46s/it]

LOSS: 4.689117521047592, ACC: 0.31494140625, FP: 11.009081840515137, MA: 17.760547637939453, OA: 22.62451171875


 12%|█▏        | 6/50 [01:28<10:32, 14.37s/it]

LOSS: 4.310273921489715, ACC: 0.37519532442092896, FP: 11.26025390625, MA: 16.459278106689453, OA: 23.99043083190918


 14%|█▍        | 7/50 [01:42<10:16, 14.34s/it]

LOSS: 4.126537710428238, ACC: 0.40351563692092896, FP: 11.674023628234863, MA: 15.395800590515137, OA: 25.283105850219727


 16%|█▌        | 8/50 [01:56<09:57, 14.23s/it]

LOSS: 3.9896661043167114, ACC: 0.41943359375, FP: 11.70029354095459, MA: 14.807519912719727, OA: 25.73017692565918


 18%|█▊        | 9/50 [02:10<09:39, 14.14s/it]

LOSS: 3.8586939096450807, ACC: 0.4413085877895355, FP: 11.50732421875, MA: 14.50449275970459, OA: 25.619043350219727


 20%|██        | 10/50 [02:25<09:27, 14.20s/it]

LOSS: 3.8799319744110106, ACC: 0.44306641817092896, FP: 11.904296875, MA: 13.771387100219727, OA: 26.468847274780273


 22%|██▏       | 11/50 [02:39<09:13, 14.18s/it]

LOSS: 4.024730110168457, ACC: 0.43876954913139343, FP: 12.254590034484863, MA: 13.42441463470459, OA: 27.002538681030273


 24%|██▍       | 12/50 [02:53<09:00, 14.21s/it]

LOSS: 3.8272547304630278, ACC: 0.46269533038139343, FP: 12.196191787719727, MA: 12.987207412719727, OA: 27.2353515625


 26%|██▌       | 13/50 [03:07<08:44, 14.18s/it]

LOSS: 3.9421212315559386, ACC: 0.44921875, FP: 12.773340225219727, MA: 12.606348037719727, OA: 28.0048828125


 28%|██▊       | 14/50 [03:21<08:31, 14.20s/it]

LOSS: 3.887382960319519, ACC: 0.46503907442092896, FP: 12.453516006469727, MA: 12.369531631469727, OA: 27.77685546875


 30%|███       | 15/50 [03:35<08:15, 14.17s/it]

LOSS: 4.028937566280365, ACC: 0.4559570252895355, FP: 12.50839900970459, MA: 12.353222846984863, OA: 27.626367568969727


 32%|███▏      | 16/50 [03:50<08:01, 14.17s/it]

LOSS: 3.9360669791698455, ACC: 0.47021484375, FP: 12.50693416595459, MA: 11.823633193969727, OA: 27.93281364440918


 34%|███▍      | 17/50 [04:04<07:45, 14.11s/it]

LOSS: 4.025540095567703, ACC: 0.46269533038139343, FP: 13.066308975219727, MA: 11.47421932220459, OA: 28.643848419189453


 36%|███▌      | 18/50 [04:18<07:32, 14.15s/it]

LOSS: 4.042792159318924, ACC: 0.46123048663139343, FP: 12.893750190734863, MA: 11.47812557220459, OA: 28.334081649780273


 38%|███▊      | 19/50 [04:32<07:18, 14.15s/it]

LOSS: 4.0317851126194, ACC: 0.46533203125, FP: 12.993555068969727, MA: 11.27304744720459, OA: 28.554492950439453


 40%|████      | 20/50 [04:46<07:03, 14.11s/it]

LOSS: 4.065060120820999, ACC: 0.47265625, FP: 13.379297256469727, MA: 10.965234756469727, OA: 29.20849609375


 42%|████▏     | 21/50 [05:00<06:50, 14.15s/it]

LOSS: 4.1562106013298035, ACC: 0.46406251192092896, FP: 13.252344131469727, MA: 11.086133003234863, OA: 28.82783317565918


 44%|████▍     | 22/50 [05:15<06:39, 14.27s/it]

LOSS: 4.073684394359589, ACC: 0.4791015684604645, FP: 13.743456840515137, MA: 10.39990234375, OA: 29.940235137939453


 46%|████▌     | 23/50 [05:29<06:27, 14.34s/it]

LOSS: 4.234736800193787, ACC: 0.46416017413139343, FP: 14.61191463470459, MA: 10.221484184265137, OA: 30.859180450439453


 48%|████▊     | 24/50 [05:43<06:11, 14.27s/it]

LOSS: 4.185402125120163, ACC: 0.466796875, FP: 14.619140625, MA: 10.045605659484863, OA: 30.927051544189453


 50%|█████     | 25/50 [05:58<05:56, 14.25s/it]

LOSS: 4.216517388820648, ACC: 0.47490236163139343, FP: 15.017969131469727, MA: 9.82236385345459, OA: 31.589649200439453


 52%|█████▏    | 26/50 [06:12<05:40, 14.19s/it]

LOSS: 4.309566354751587, ACC: 0.4681640565395355, FP: 15.560254096984863, MA: 9.55918025970459, OA: 32.375587463378906


 54%|█████▍    | 27/50 [06:26<05:27, 14.23s/it]

LOSS: 4.1802359282970425, ACC: 0.48613283038139343, FP: 15.27451229095459, MA: 9.352930068969727, OA: 32.29413986206055


 56%|█████▌    | 28/50 [06:41<05:16, 14.40s/it]

LOSS: 4.24573615193367, ACC: 0.48066407442092896, FP: 15.753710746765137, MA: 9.17929744720459, OA: 32.82050704956055


 58%|█████▊    | 29/50 [06:55<05:02, 14.40s/it]

LOSS: 4.17163981795311, ACC: 0.4859375059604645, FP: 15.735058784484863, MA: 9.108105659484863, OA: 32.94921875


 60%|██████    | 30/50 [07:09<04:47, 14.36s/it]

LOSS: 4.2264452338218685, ACC: 0.4852539002895355, FP: 15.911913871765137, MA: 8.943066596984863, OA: 33.213478088378906


 62%|██████▏   | 31/50 [07:24<04:32, 14.35s/it]

LOSS: 4.260391068458557, ACC: 0.4828124940395355, FP: 16.01123046875, MA: 8.924219131469727, OA: 33.337989807128906


 64%|██████▍   | 32/50 [07:38<04:16, 14.25s/it]

LOSS: 4.263722062110901, ACC: 0.4854492247104645, FP: 15.852441787719727, MA: 8.90341854095459, OA: 33.24687576293945


 66%|██████▌   | 33/50 [07:52<04:02, 14.28s/it]

LOSS: 4.27256680727005, ACC: 0.4876953065395355, FP: 16.228515625, MA: 8.754199028015137, OA: 33.74580001831055


 68%|██████▊   | 34/50 [08:06<03:47, 14.24s/it]

LOSS: 4.236269271373748, ACC: 0.4940429627895355, FP: 15.689844131469727, MA: 8.781054496765137, OA: 33.152931213378906


 70%|███████   | 35/50 [08:20<03:32, 14.19s/it]

LOSS: 4.230468797683716, ACC: 0.4990234375, FP: 15.686327934265137, MA: 8.740918159484863, OA: 33.191505432128906


 72%|███████▏  | 36/50 [08:34<03:17, 14.13s/it]

LOSS: 4.234293764829635, ACC: 0.4996093809604645, FP: 15.340723037719727, MA: 8.837305068969727, OA: 32.7216796875


 74%|███████▍  | 37/50 [08:49<03:03, 14.14s/it]

LOSS: 4.248275631666184, ACC: 0.4991210997104645, FP: 15.340527534484863, MA: 8.809473037719727, OA: 32.76884841918945


 76%|███████▌  | 38/50 [09:03<02:49, 14.10s/it]

LOSS: 4.235264670848847, ACC: 0.501269519329071, FP: 14.71826171875, MA: 8.983301162719727, OA: 31.959569931030273


 78%|███████▊  | 39/50 [09:17<02:35, 14.12s/it]

LOSS: 4.210160672664642, ACC: 0.49980470538139343, FP: 14.33017635345459, MA: 9.01718807220459, OA: 31.53271484375


 80%|████████  | 40/50 [09:31<02:22, 14.30s/it]

LOSS: 4.251973289251327, ACC: 0.502636730670929, FP: 14.13291072845459, MA: 9.086230278015137, OA: 31.260059356689453


 82%|████████▏ | 41/50 [09:46<02:08, 14.28s/it]

LOSS: 4.228124785423279, ACC: 0.500683605670929, FP: 13.61386775970459, MA: 9.289355278015137, OA: 30.537891387939453


 84%|████████▍ | 42/50 [10:00<01:53, 14.24s/it]

LOSS: 4.240718185901642, ACC: 0.504101574420929, FP: 13.487109184265137, MA: 9.294921875, OA: 30.400684356689453


 86%|████████▌ | 43/50 [10:14<01:39, 14.16s/it]

LOSS: 4.22068138718605, ACC: 0.503125011920929, FP: 13.04599666595459, MA: 9.46347713470459, OA: 29.791015625


 88%|████████▊ | 44/50 [10:28<01:25, 14.21s/it]

LOSS: 4.230005085468292, ACC: 0.504199206829071, FP: 12.82666015625, MA: 9.47314453125, OA: 29.56201171875


 90%|█████████ | 45/50 [10:42<01:10, 14.19s/it]

LOSS: 4.220492404699326, ACC: 0.5078125, FP: 12.753222465515137, MA: 9.493456840515137, OA: 29.46826171875


 92%|█████████▏| 46/50 [10:56<00:56, 14.10s/it]

LOSS: 4.216931927204132, ACC: 0.5086914300918579, FP: 12.486132621765137, MA: 9.591699600219727, OA: 29.102930068969727


 94%|█████████▍| 47/50 [11:10<00:42, 14.07s/it]

LOSS: 4.226433050632477, ACC: 0.509765625, FP: 12.38525390625, MA: 9.628222465515137, OA: 28.96552848815918


 96%|█████████▌| 48/50 [11:24<00:28, 14.05s/it]

LOSS: 4.208900904655456, ACC: 0.5086914300918579, FP: 12.232617378234863, MA: 9.692675590515137, OA: 28.748437881469727


 98%|█████████▊| 49/50 [11:38<00:14, 14.05s/it]

LOSS: 4.199996173381805, ACC: 0.5062500238418579, FP: 12.323144912719727, MA: 9.62451171875, OA: 28.907129287719727


100%|██████████| 50/50 [11:53<00:00, 14.26s/it]

LOSS: 4.222430622577667, ACC: 0.5067383050918579, FP: 12.328906059265137, MA: 9.63671875, OA: 28.900684356689453
{'epoch': 46, 'train_loss': 0.417494784992979, 'val_loss': 4.226433050632477, 'val_acc': 0.509765625, 'fp': 12.38525390625, 'ma': 9.628222465515137, 'oa': 28.96552848815918}





# Resnet Baseline

In [7]:
def train_one_epoch_baseline():
    running_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data["images"], data["labels"]
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs = model(inputs)

        # Compute the loss and its gradients
        loss = loss_fn_baseline(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()

    return running_loss / (i+1)

In [13]:
from torchmetrics import Accuracy
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

NUM_CLASSES = 200
EPOCHS = 15
accuracy = Accuracy(task="multiclass", num_classes=NUM_CLASSES, top_k=1).to(device)

POS_FT_WEIGHT = 0
FT_WEIGHT = 0

from torchvision.models import resnet18, ResNet18_Weights
model = resnet18(weights=ResNet18_Weights.DEFAULT)
model.fc = nn.Linear(512, NUM_CLASSES)

model = model.to(device)

loss_fn_baseline = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4, weight_decay=1e-5)

best_stats = {
    "epoch": 0,
    "train_loss": 0,
    "val_loss": 0,
    "val_acc": 0,
}

from tqdm import tqdm
for epoch in tqdm(range(EPOCHS)):
    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch_baseline()

    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()
    running_acc = 0.0
    running_false_positives = 0.0

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i, vdata in enumerate(validation_loader):
            vinputs, vlabels = vdata["images"], vdata["labels"]
            vinputs = vinputs.to(device)
            vlabels = vlabels.to(device)
            voutputs = model(vinputs)
            vloss = loss_fn_baseline(voutputs, vlabels)
            running_vloss += vloss.item()
            running_acc += accuracy(voutputs, vlabels)

    avg_vloss = running_vloss / (i + 1)
    avg_acc = running_acc / (i + 1)
    print(f"LOSS: {avg_vloss}, ACC: {avg_acc}")

    if best_stats["val_acc"] < avg_acc:
        best_stats["epoch"] = epoch
        best_stats["train_loss"] = avg_loss
        best_stats["val_loss"] = avg_vloss
        best_stats["val_acc"] = avg_acc.item()

print(best_stats)

Device: cuda


  7%|▋         | 1/15 [00:15<03:32, 15.20s/it]

LOSS: 2.7538333177566527, ACC: 0.3666932284832001


 13%|█▎        | 2/15 [00:30<03:16, 15.14s/it]

LOSS: 2.268373465538025, ACC: 0.45415934920310974


 20%|██        | 3/15 [00:45<03:02, 15.21s/it]

LOSS: 2.1400376319885255, ACC: 0.4812440574169159


 27%|██▋       | 4/15 [01:01<02:48, 15.34s/it]

LOSS: 1.9924053430557251, ACC: 0.5112384557723999


 33%|███▎      | 5/15 [01:16<02:33, 15.37s/it]

LOSS: 1.9608011841773987, ACC: 0.5201749801635742


 40%|████      | 6/15 [01:31<02:18, 15.38s/it]

LOSS: 1.9408017992973328, ACC: 0.5298649072647095


 47%|████▋     | 7/15 [01:47<02:04, 15.54s/it]

LOSS: 1.9406044483184814, ACC: 0.5277164578437805


 53%|█████▎    | 8/15 [02:02<01:47, 15.40s/it]

LOSS: 1.9414733171463012, ACC: 0.5363759398460388


 60%|██████    | 9/15 [02:18<01:32, 15.38s/it]

LOSS: 1.9837289929389954, ACC: 0.5289102792739868


 67%|██████▋   | 10/15 [02:33<01:16, 15.33s/it]

LOSS: 2.0138878226280212, ACC: 0.5291215181350708


 73%|███████▎  | 11/15 [02:48<01:01, 15.31s/it]

LOSS: 2.0528515219688415, ACC: 0.5269291996955872


 80%|████████  | 12/15 [03:03<00:45, 15.20s/it]

LOSS: 2.074764096736908, ACC: 0.5255321860313416


 87%|████████▋ | 13/15 [03:18<00:30, 15.17s/it]

LOSS: 2.1348794221878054, ACC: 0.5197624564170837


 93%|█████████▎| 14/15 [03:33<00:15, 15.18s/it]

LOSS: 2.1823304891586304, ACC: 0.523909866809845


100%|██████████| 15/15 [03:48<00:00, 15.26s/it]

LOSS: 2.1891473531723022, ACC: 0.528218686580658
{'epoch': 7, 'train_loss': 1.1221286131411183, 'val_loss': 1.9414733171463012, 'val_acc': 0.5363759398460388}





# Optuna

In [14]:
import torch.nn as nn

eps = 1e-10
def loss_fn_optuna(out, labels, predicate_matrix, NUM_FEATURES, FT_WEIGHT):
    out = out.view(-1, 1, NUM_FEATURES) # out is a batch of 1D binary vectors
    ANDed = out * predicate_matrix # AND operation
    diff = ANDed - out # Difference of ANDed and out => if equal, then out is a subset of its class' predicates

    entr_loss = nn.CrossEntropyLoss()
    loss_cl = entr_loss(diff.sum(dim=2), labels) # Is "out" a subset of its class' predicates?

    batch_size = out.shape[0]

    out = out.view(-1, NUM_FEATURES)
    diff_square = (out - predicate_matrix[labels]).pow(2)
    
    false_positives = (out - predicate_matrix[labels] + diff_square).sum() / batch_size
    missing_attr = (predicate_matrix[labels] - out + diff_square).sum() / batch_size
    
    loss_ft = (1 + false_positives + missing_attr)
    loss_ft *= loss_cl.item()/(loss_ft.item() + eps)
    
    return loss_cl + loss_ft * FT_WEIGHT

from torchmetrics import Accuracy
def train_one_epoch_optuna(model, optimizer, scheduler, NUM_FEATURES, FT_WEIGHT):
    running_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data["images"], data["labels"]
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs, commit_loss, predicate_matrix = model(inputs)

        # Compute the loss and its gradients
        loss = loss_fn_optuna(outputs, labels, predicate_matrix, NUM_FEATURES, FT_WEIGHT) + commit_loss
        loss.backward()

        # Adjust learning weights
        optimizer.step()
        
        scheduler.step()

        # Gather data and report
        running_loss += loss.item()

    return running_loss / (i+1)

from tqdm import tqdm
from torch import optim

import sys, os
sys.path.insert(0, "/".join(os.path.abspath('').split("/")[:-1]) + "/models")
print("/".join(os.path.abspath('').split("/")[:-1]) + "/models")
from ResnetAutoPredicates import ResExtr

def objective(trial):
    global trial_num
    trial_num += 1
    print(f"Starting trial {trial_num}")
    NUM_FEATURES = trial.suggest_int("num_features", 1, 20)
    FT_WEIGHT = trial.suggest_float("ft_weight", 0, 1.5)
    # Generate the model.
    model = ResExtr(NUM_FEATURES*8, NUM_CLASSES, pretrained=True).to(device)

    EPOCHS = 30

    # Generate the optimizers.
    lr = trial.suggest_float("lr", 1e-5, 1e-3, log=True)
    max_lr = trial.suggest_float("max_lr", 1e-4, 1e-2, log=True)
    
    if lr > max_lr:
        raise optuna.TrialPruned()
    
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr, epochs=EPOCHS, steps_per_epoch=len(training_loader))

    best_acc = 0.0

    for epoch in tqdm(range(EPOCHS)):
        # Make sure gradient tracking is on, and do a pass over the data
        model.train(True)
        _ = train_one_epoch_optuna(model, optimizer, scheduler, NUM_FEATURES*8, FT_WEIGHT)

        model.eval()
        running_acc = 0.0

        # Disable gradient computation and reduce memory consumption.
        with torch.no_grad():
            for i, vdata in enumerate(validation_loader):
                vinputs, vlabels = vdata["images"], vdata["labels"]
                vinputs = vinputs.to(device)
                vlabels = vlabels.to(device)
                voutputs, _, predicate_matrix = model(vinputs)
                voutputs = voutputs.view(-1, 1, NUM_FEATURES*8)
                ANDed = voutputs * predicate_matrix
                diff = ANDed - voutputs
                running_acc += accuracy(diff.sum(dim=2), vlabels)

        avg_acc = running_acc / (i + 1)

        if avg_acc > best_acc:
            best_acc = avg_acc

        if epoch == 20 and best_acc < 0.3:
            raise optuna.TrialPruned()
        elif epoch == 10 and best_acc < 0.1:
            raise optuna.TrialPruned()
    
    return best_acc

/notebooks/Concept_ZSL/src/models


In [None]:
import optuna

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

NUM_CLASSES = 200
accuracy = Accuracy(task="multiclass", num_classes=NUM_CLASSES, top_k=1).to(device)

trial_num = -1

study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=100)

print("Best trial:")
trial = study.best_trial

print("  Value: ", trial.value)

print("  Params: ")
for key, value in trial.params.items():
    if key == "num_features":
        print("    {}: {}".format(key, value*8))
    else:
        print("    {}: {}".format(key, value))

[I 2023-09-13 18:51:22,304] A new study created in memory with name: no-name-24b61a4d-d7df-4b50-b4cb-09e06ff841f4


Device: cuda
Starting trial 0


[I 2023-09-13 18:51:22,619] Trial 0 pruned. 


Starting trial 1


100%|██████████| 30/30 [10:35<00:00, 21.19s/it]
[I 2023-09-13 19:01:58,528] Trial 1 finished with value: 0.5078125 and parameters: {'num_features': 19, 'ft_weight': 0.3492758742432343, 'lr': 0.0001067411088044507, 'max_lr': 0.00018516693616002605}. Best is trial 1 with value: 0.5078125.


Starting trial 2


 33%|███▎      | 10/30 [03:53<07:46, 23.33s/it]
[I 2023-09-13 19:05:52,119] Trial 2 pruned. 


Starting trial 3


 33%|███▎      | 10/30 [03:53<07:47, 23.35s/it]
[I 2023-09-13 19:09:45,956] Trial 3 pruned. 


Starting trial 4


100%|██████████| 30/30 [10:36<00:00, 21.22s/it]
[I 2023-09-13 19:20:22,811] Trial 4 finished with value: 0.510449230670929 and parameters: {'num_features': 6, 'ft_weight': 0.47887075891673964, 'lr': 1.1045510111974618e-05, 'max_lr': 0.0011145443052630428}. Best is trial 4 with value: 0.510449230670929.


Starting trial 5


 33%|███▎      | 10/30 [03:53<07:47, 23.35s/it]
[I 2023-09-13 19:24:16,566] Trial 5 pruned. 


Starting trial 6


100%|██████████| 30/30 [10:38<00:00, 21.27s/it]
[I 2023-09-13 19:34:54,897] Trial 6 finished with value: 0.5086914300918579 and parameters: {'num_features': 8, 'ft_weight': 0.7993528867404869, 'lr': 0.0001289296030916872, 'max_lr': 0.00029815447545380514}. Best is trial 4 with value: 0.510449230670929.


Starting trial 7


100%|██████████| 30/30 [10:41<00:00, 21.40s/it]
[I 2023-09-13 19:45:37,128] Trial 7 finished with value: 0.49628907442092896 and parameters: {'num_features': 16, 'ft_weight': 0.10744107714112938, 'lr': 0.0001058197875842386, 'max_lr': 0.00010748508453160884}. Best is trial 4 with value: 0.510449230670929.


Starting trial 8


 33%|███▎      | 10/30 [03:55<07:50, 23.53s/it]
[I 2023-09-13 19:49:32,718] Trial 8 pruned. 


Starting trial 9


[I 2023-09-13 19:49:33,013] Trial 9 pruned. 


Starting trial 10


100%|██████████| 30/30 [10:40<00:00, 21.34s/it]
[I 2023-09-13 20:00:13,562] Trial 10 finished with value: 0.501953125 and parameters: {'num_features': 7, 'ft_weight': 0.0957160426435335, 'lr': 1.0239110141147658e-05, 'max_lr': 0.0017166738598938687}. Best is trial 4 with value: 0.510449230670929.


Starting trial 11


 33%|███▎      | 10/30 [03:54<07:49, 23.49s/it]
[I 2023-09-13 20:04:08,766] Trial 11 pruned. 


Starting trial 12


 33%|███▎      | 10/30 [03:54<07:48, 23.44s/it]
[I 2023-09-13 20:08:03,494] Trial 12 pruned. 


Starting trial 13


100%|██████████| 30/30 [10:39<00:00, 21.32s/it]
[I 2023-09-13 20:18:43,236] Trial 13 finished with value: 0.4766601622104645 and parameters: {'num_features': 6, 'ft_weight': 0.37625824879280484, 'lr': 1.0739615437837204e-05, 'max_lr': 0.003564466764971506}. Best is trial 4 with value: 0.510449230670929.


Starting trial 14


 33%|███▎      | 10/30 [03:55<07:50, 23.54s/it]
[I 2023-09-13 20:22:38,908] Trial 14 pruned. 


Starting trial 15


100%|██████████| 30/30 [10:42<00:00, 21.41s/it]
[I 2023-09-13 20:33:21,524] Trial 15 finished with value: 0.5052734613418579 and parameters: {'num_features': 10, 'ft_weight': 0.7746679968026929, 'lr': 5.98762584874233e-05, 'max_lr': 0.00029431831644941855}. Best is trial 4 with value: 0.510449230670929.


Starting trial 16


100%|██████████| 30/30 [10:42<00:00, 21.41s/it]
[I 2023-09-13 20:44:04,277] Trial 16 finished with value: 0.5049805045127869 and parameters: {'num_features': 5, 'ft_weight': 0.46921600625763543, 'lr': 2.2152353783693883e-05, 'max_lr': 0.0011329154115039698}. Best is trial 4 with value: 0.510449230670929.


Starting trial 17


 33%|███▎      | 10/30 [03:55<07:50, 23.53s/it]
[I 2023-09-13 20:47:59,898] Trial 17 pruned. 


Starting trial 18


100%|██████████| 30/30 [10:42<00:00, 21.41s/it]
[I 2023-09-13 20:58:42,654] Trial 18 finished with value: 0.507031261920929 and parameters: {'num_features': 12, 'ft_weight': 0.628801171916768, 'lr': 0.00019889004338291528, 'max_lr': 0.00026490238250082503}. Best is trial 4 with value: 0.510449230670929.


Starting trial 19


 33%|███▎      | 10/30 [03:55<07:50, 23.51s/it]
[I 2023-09-13 21:02:38,111] Trial 19 pruned. 


Starting trial 20


100%|██████████| 30/30 [10:41<00:00, 21.37s/it]
[I 2023-09-13 21:13:19,526] Trial 20 finished with value: 0.5174804925918579 and parameters: {'num_features': 9, 'ft_weight': 0.4988642523052874, 'lr': 4.138993951704192e-05, 'max_lr': 0.0006375893990921109}. Best is trial 20 with value: 0.5174804925918579.


Starting trial 21


100%|██████████| 30/30 [10:42<00:00, 21.41s/it]
[I 2023-09-13 21:24:02,055] Trial 21 finished with value: 0.5118164420127869 and parameters: {'num_features': 9, 'ft_weight': 0.5107132840647499, 'lr': 3.2869187199493184e-05, 'max_lr': 0.0006267368555172368}. Best is trial 20 with value: 0.5174804925918579.


Starting trial 22


100%|██████████| 30/30 [10:42<00:00, 21.42s/it]
[I 2023-09-13 21:34:44,935] Trial 22 finished with value: 0.51220703125 and parameters: {'num_features': 13, 'ft_weight': 0.4345090447814265, 'lr': 1.8851694300685395e-05, 'max_lr': 0.0006350743999347075}. Best is trial 20 with value: 0.5174804925918579.


Starting trial 23


100%|██████████| 30/30 [10:42<00:00, 21.42s/it]
[I 2023-09-13 21:45:27,987] Trial 23 finished with value: 0.5078125 and parameters: {'num_features': 13, 'ft_weight': 0.23574697767916308, 'lr': 3.544436532376205e-05, 'max_lr': 0.0006387444514466879}. Best is trial 20 with value: 0.5174804925918579.


Starting trial 24


100%|██████████| 30/30 [10:41<00:00, 21.38s/it]
[I 2023-09-13 21:56:09,620] Trial 24 finished with value: 0.51953125 and parameters: {'num_features': 9, 'ft_weight': 0.5489135701425925, 'lr': 2.1280370595448995e-05, 'max_lr': 0.0004108982600260421}. Best is trial 24 with value: 0.51953125.


Starting trial 25


100%|██████████| 30/30 [10:41<00:00, 21.38s/it]
[I 2023-09-13 22:06:51,430] Trial 25 finished with value: 0.5189453363418579 and parameters: {'num_features': 16, 'ft_weight': 0.3929134306989532, 'lr': 1.8083932846622816e-05, 'max_lr': 0.0004407813400071536}. Best is trial 24 with value: 0.51953125.


Starting trial 26


100%|██████████| 30/30 [10:42<00:00, 21.42s/it]
[I 2023-09-13 22:17:34,227] Trial 26 finished with value: 0.5077148675918579 and parameters: {'num_features': 17, 'ft_weight': 0.20674300508343824, 'lr': 2.5386938203628418e-05, 'max_lr': 0.00042106603103250707}. Best is trial 24 with value: 0.51953125.


Starting trial 27


100%|██████████| 30/30 [10:42<00:00, 21.42s/it]
[I 2023-09-13 22:28:17,196] Trial 27 finished with value: 0.508593738079071 and parameters: {'num_features': 20, 'ft_weight': 0.34609377158907073, 'lr': 1.5123968964568208e-05, 'max_lr': 0.00043051412418218467}. Best is trial 24 with value: 0.51953125.


Starting trial 28


100%|██████████| 30/30 [10:41<00:00, 21.40s/it]
[I 2023-09-13 22:38:59,351] Trial 28 finished with value: 0.5044922232627869 and parameters: {'num_features': 15, 'ft_weight': 0.023909416662490357, 'lr': 4.5480827824692644e-05, 'max_lr': 0.0007707057479019298}. Best is trial 24 with value: 0.51953125.


Starting trial 29


100%|██████████| 30/30 [10:41<00:00, 21.39s/it]
[I 2023-09-13 22:49:41,275] Trial 29 finished with value: 0.5137695670127869 and parameters: {'num_features': 10, 'ft_weight': 0.5946338730624403, 'lr': 2.6464639161742053e-05, 'max_lr': 0.0005100630187383013}. Best is trial 24 with value: 0.51953125.


Starting trial 30


100%|██████████| 30/30 [10:42<00:00, 21.43s/it]
[I 2023-09-13 23:00:24,431] Trial 30 finished with value: 0.513671875 and parameters: {'num_features': 11, 'ft_weight': 0.5888691937818911, 'lr': 1.5123258500474641e-05, 'max_lr': 0.00042318793793346874}. Best is trial 24 with value: 0.51953125.


Starting trial 31


100%|██████████| 30/30 [10:43<00:00, 21.45s/it]
[I 2023-09-13 23:11:08,238] Trial 31 finished with value: 0.5137695670127869 and parameters: {'num_features': 10, 'ft_weight': 0.5881766030167181, 'lr': 2.4317090820634536e-05, 'max_lr': 0.0004613512200994929}. Best is trial 24 with value: 0.51953125.


Starting trial 32


100%|██████████| 30/30 [10:41<00:00, 21.40s/it]
[I 2023-09-13 23:21:50,466] Trial 32 finished with value: 0.51171875 and parameters: {'num_features': 9, 'ft_weight': 0.6664910675019375, 'lr': 2.714065790356308e-05, 'max_lr': 0.00023668322945178573}. Best is trial 24 with value: 0.51953125.


Starting trial 33


100%|██████████| 30/30 [10:42<00:00, 21.41s/it]
[I 2023-09-13 23:32:33,164] Trial 33 finished with value: 0.518261730670929 and parameters: {'num_features': 9, 'ft_weight': 0.40275608163547705, 'lr': 4.644609405095557e-05, 'max_lr': 0.0003572066404074793}. Best is trial 24 with value: 0.51953125.


Starting trial 34


100%|██████████| 30/30 [10:41<00:00, 21.37s/it]
[I 2023-09-13 23:43:14,564] Trial 34 finished with value: 0.5215820670127869 and parameters: {'num_features': 12, 'ft_weight': 0.39668827868753737, 'lr': 4.35589490084384e-05, 'max_lr': 0.000316898745521825}. Best is trial 34 with value: 0.5215820670127869.


Starting trial 35


100%|██████████| 30/30 [10:42<00:00, 21.43s/it]
[I 2023-09-13 23:53:57,685] Trial 35 finished with value: 0.511523425579071 and parameters: {'num_features': 15, 'ft_weight': 0.4213347379043827, 'lr': 5.01136579523928e-05, 'max_lr': 0.00020072101535965168}. Best is trial 34 with value: 0.5215820670127869.


Starting trial 36


100%|██████████| 30/30 [10:43<00:00, 21.43s/it]
[I 2023-09-14 00:04:41,031] Trial 36 finished with value: 0.5194336175918579 and parameters: {'num_features': 12, 'ft_weight': 0.3292573045256373, 'lr': 1.965055759710382e-05, 'max_lr': 0.00035719795370998383}. Best is trial 34 with value: 0.5215820670127869.


Starting trial 37


100%|██████████| 30/30 [10:42<00:00, 21.42s/it]
[I 2023-09-14 00:15:24,061] Trial 37 finished with value: 0.50341796875 and parameters: {'num_features': 12, 'ft_weight': 0.34206758678132837, 'lr': 1.8553623015476767e-05, 'max_lr': 0.00015653866343888877}. Best is trial 34 with value: 0.5215820670127869.


Starting trial 38


100%|██████████| 30/30 [10:41<00:00, 21.37s/it]
[I 2023-09-14 00:26:05,464] Trial 38 finished with value: 0.50732421875 and parameters: {'num_features': 18, 'ft_weight': 0.28722445073263325, 'lr': 1.2605226813785215e-05, 'max_lr': 0.00032380513698269546}. Best is trial 34 with value: 0.5215820670127869.


Starting trial 39


 87%|████████▋ | 26/30 [09:16<01:25, 21.48s/it]