In [1]:
from TinyImageNetLoader import TinyImageNetDataset
import torchvision.transforms as transforms
import torch

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

val_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize
])

valset = TinyImageNetDataset("/datasets/tiny-imagenet-200", mode="val", transform=val_transform)
#print(next(enumerate(validation_loader)))

train_transform =  transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomRotation(20),
    transforms.RandomHorizontalFlip(0.5),
    normalize
])

trainset = TinyImageNetDataset("/datasets/tiny-imagenet-200", transform=train_transform)

Preloading val data...: 100%|██████████| 10000/10000 [01:09<00:00, 143.00it/s]
Preloading train data...: 100%|██████████| 100000/100000 [11:47<00:00, 141.37it/s]


In [2]:
validation_loader = torch.utils.data.DataLoader(
        valset, batch_size=256, shuffle=False, num_workers=4)
training_loader = torch.utils.data.DataLoader(
        trainset, batch_size=256, shuffle=True, num_workers=4)

In [7]:
import torch.nn as nn
eps=1e-10
def ratio_loss_fn(out, labels, predicate_matrix):
    out = out.view(-1, 1, NUM_FEATURES) # out is a batch of 1D binary vectors
    ANDed = out * predicate_matrix # AND operation
    diff = ANDed - out # Difference of ANDed and out => if equal, then out is a subset of its class' predicates

    entr_loss = torch.nn.CrossEntropyLoss()
    loss_cl = entr_loss(diff.sum(dim=2), labels) # Is "out" a subset of its class' predicates?

    batch_size = out.shape[0]

    out = out.view(-1, NUM_FEATURES)
    diff_square = (out - predicate_matrix[labels]).pow(2)
    
    false_positives = (out - predicate_matrix[labels] + diff_square).sum() / batch_size
    false_positives *= loss_cl.item() / (false_positives.item() + eps)

    missing_attr = (predicate_matrix[labels] - out + diff_square).sum() / batch_size
    missing_attr *= loss_cl.item() / (missing_attr.item() + eps)
    
    return loss_cl + false_positives/(missing_attr+eps) * FT_WEIGHT

def loss_fn(out, labels, predicate_matrix):
    out = out.view(-1, 1, NUM_FEATURES) # out is a batch of 1D binary vectors
    ANDed = out * predicate_matrix # AND operation
    diff = ANDed - out # Difference of ANDed and out => if equal, then out is a subset of its class' predicates

    entr_loss = torch.nn.CrossEntropyLoss()
    loss_cl = entr_loss(diff.sum(dim=2), labels) # Is "out" a subset of its class' predicates?

    batch_size = out.shape[0]

    out = out.view(-1, NUM_FEATURES)
    diff_square = (out - predicate_matrix[labels]).pow(2)
    
    false_positives = (out - predicate_matrix[labels] + diff_square).sum() / batch_size
    missing_attr = (predicate_matrix[labels] - out + diff_square).sum() / batch_size
    
    loss_ft = false_positives + missing_attr
    loss_ft *= loss_cl.item()/(loss_ft.item() + eps)
    
    return loss_cl + loss_ft * FT_WEIGHT

In [9]:
def train_one_epoch():
    running_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data["images"], data["labels"]
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs, commit_loss, predicate_matrix = model(inputs)

        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels, predicate_matrix) + commit_loss
        loss.backward()

        # Adjust learning weights
        optimizer.step()
        
        scheduler.step()

        # Gather data and report
        running_loss += loss.item()

    return running_loss / (i+1)

In [11]:
from torchmetrics import Accuracy
import sys, os
sys.path.insert(0, "/".join(os.path.abspath('').split("/")[:-1]) + "/models")
print("/".join(os.path.abspath('').split("/")[:-1]) + "/models")
from ResnetAutoPredicates import ResExtr

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

NUM_FEATURES = 48
NUM_CLASSES = 200
EPOCHS = 50
accuracy = Accuracy(task="multiclass", num_classes=NUM_CLASSES, top_k=1).to(device)

FT_WEIGHT = 0.5

model = ResExtr(NUM_FEATURES, NUM_CLASSES, pretrained=True).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=2e-4, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, 5e-4, epochs=EPOCHS, steps_per_epoch=len(training_loader))

best_stats = {
    "epoch": 0,
    "train_loss": 0,
    "val_loss": 0,
    "val_acc": 0,
    "fp": 0,
    "ma": 0,
    "oa": 0
}

from tqdm import tqdm
for epoch in tqdm(range(EPOCHS)):
    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch()

    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()
    running_acc = 0.0
    running_false_positives = 0.0
    running_missing_attr = 0.0
    running_out_attributes= 0.0

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i, vdata in enumerate(validation_loader):
            vinputs, vlabels = vdata["images"], vdata["labels"]
            vinputs = vinputs.to(device)
            vlabels = vlabels.to(device)
            voutputs, vcommit_loss, predicate_matrix = model(vinputs)
            vloss = loss_fn(voutputs, vlabels, predicate_matrix) + vcommit_loss
            running_vloss += vloss.item()
            voutputs = voutputs.view(-1, 1, NUM_FEATURES)
            ANDed = voutputs * predicate_matrix
            diff = ANDed - voutputs
            running_acc += accuracy(diff.sum(dim=2), vlabels)
            voutputs = voutputs.view(-1, NUM_FEATURES)
            running_false_positives += ((predicate_matrix[vlabels] - voutputs) == -1).sum() / voutputs.shape[0]
            running_missing_attr += ((voutputs - predicate_matrix[vlabels]) == -1).sum() / voutputs.shape[0]
            running_out_attributes += voutputs.sum() / voutputs.shape[0]

    avg_vloss = running_vloss / (i + 1)
    avg_acc = running_acc / (i + 1)
    avg_fp = running_false_positives / (i + 1)
    avg_ma = running_missing_attr / (i + 1)
    avg_oa = running_out_attributes / (i + 1)
    print(f"LOSS: {avg_vloss}, ACC: {avg_acc}, FP: {avg_fp}, MA: {avg_ma}, OA: {avg_oa}")
    #print(model.bin_quantize._codebook.embed)
    
    with open("TINRes18AutoPredData.csv", "a") as f:
        f.write(f"{epoch}, {avg_loss}, {avg_vloss}, {avg_acc}, {avg_fp}, {avg_ma}, {avg_oa}\n")

    if best_stats["val_acc"] < avg_acc:
        best_stats["epoch"] = epoch
        best_stats["train_loss"] = avg_loss
        best_stats["val_loss"] = avg_vloss
        best_stats["val_acc"] = avg_acc.item()
        best_stats["fp"] = avg_fp.item()
        best_stats["ma"] = avg_ma.item()
        best_stats["oa"] = avg_oa.item()


print(best_stats)

/notebooks/Concept_ZSL/src/models
Device: cuda


  2%|▏         | 1/50 [00:15<12:31, 15.33s/it]

LOSS: 8.444142484664917, ACC: 0.02158203162252903, FP: 4.129101753234863, MA: 12.23496150970459, OA: 6.562597751617432


  4%|▍         | 2/50 [00:31<12:30, 15.63s/it]

LOSS: 8.190205073356628, ACC: 0.02685546875, FP: 3.4750001430511475, MA: 12.28623104095459, OA: 5.738574504852295


  6%|▌         | 3/50 [00:47<12:25, 15.87s/it]

LOSS: 8.078891623020173, ACC: 0.04843750223517418, FP: 4.925390720367432, MA: 11.122949600219727, OA: 8.307324409484863


  8%|▊         | 4/50 [01:03<12:08, 15.83s/it]

LOSS: 7.4551433801651, ACC: 0.08837890625, FP: 3.82568359375, MA: 10.993847846984863, OA: 7.174609661102295


 10%|█         | 5/50 [01:19<12:03, 16.09s/it]

LOSS: 6.691835117340088, ACC: 0.17529296875, FP: 4.742285251617432, MA: 9.22324275970459, OA: 9.654296875


 12%|█▏        | 6/50 [01:35<11:47, 16.09s/it]

LOSS: 5.871543073654175, ACC: 0.27324220538139343, FP: 4.444043159484863, MA: 8.164941787719727, OA: 10.263964653015137


 14%|█▍        | 7/50 [01:51<11:32, 16.11s/it]

LOSS: 5.384763932228088, ACC: 0.34501954913139343, FP: 5.027050971984863, MA: 6.809668064117432, OA: 11.94345760345459


 16%|█▌        | 8/50 [02:08<11:19, 16.19s/it]

LOSS: 5.191639316082001, ACC: 0.3682617247104645, FP: 4.583691596984863, MA: 6.675683498382568, OA: 11.404589653015137


 18%|█▊        | 9/50 [02:23<10:57, 16.04s/it]

LOSS: 4.989289450645447, ACC: 0.4002929627895355, FP: 4.612011909484863, MA: 6.119238376617432, OA: 11.65683650970459


 20%|██        | 10/50 [02:39<10:37, 15.95s/it]

LOSS: 4.8132835030555725, ACC: 0.41914063692092896, FP: 4.661816596984863, MA: 5.631152629852295, OA: 11.863183975219727


 22%|██▏       | 11/50 [02:55<10:19, 15.88s/it]

LOSS: 4.875459432601929, ACC: 0.4149414002895355, FP: 4.610742092132568, MA: 5.494042873382568, OA: 11.5576171875


 24%|██▍       | 12/50 [03:11<10:06, 15.96s/it]

LOSS: 4.713110792636871, ACC: 0.43183594942092896, FP: 4.58740234375, MA: 5.201562404632568, OA: 11.55478572845459


 26%|██▌       | 13/50 [03:27<09:49, 15.93s/it]

LOSS: 4.700749492645263, ACC: 0.4315429627895355, FP: 4.663183689117432, MA: 4.941211223602295, OA: 11.637011528015137


 28%|██▊       | 14/50 [03:43<09:31, 15.88s/it]

LOSS: 4.587230551242828, ACC: 0.4505859315395355, FP: 4.381738185882568, MA: 4.704297065734863, OA: 11.265332221984863


 30%|███       | 15/50 [03:58<09:11, 15.76s/it]

LOSS: 4.554740685224533, ACC: 0.4449218809604645, FP: 4.580175876617432, MA: 4.433300971984863, OA: 11.461328506469727


 32%|███▏      | 16/50 [04:14<08:57, 15.80s/it]

LOSS: 4.630353832244873, ACC: 0.4359374940395355, FP: 4.303613185882568, MA: 4.472070217132568, OA: 11.01171875


 34%|███▍      | 17/50 [04:30<08:38, 15.72s/it]

LOSS: 4.577711844444275, ACC: 0.44550782442092896, FP: 4.354101657867432, MA: 4.339941501617432, OA: 11.0576171875


 36%|███▌      | 18/50 [04:45<08:22, 15.70s/it]

LOSS: 4.5131114482879635, ACC: 0.45361328125, FP: 4.488086223602295, MA: 4.080078125, OA: 11.256640434265137


 38%|███▊      | 19/50 [05:02<08:12, 15.90s/it]

LOSS: 4.32432461977005, ACC: 0.470703125, FP: 3.846972703933716, MA: 4.1669921875, OA: 10.635546684265137


 40%|████      | 20/50 [05:18<07:57, 15.93s/it]

LOSS: 4.3147553443908695, ACC: 0.4708007872104645, FP: 4.068749904632568, MA: 3.8646485805511475, OA: 10.947265625


 42%|████▏     | 21/50 [05:34<07:42, 15.94s/it]

LOSS: 4.319323050975799, ACC: 0.47089844942092896, FP: 4.111328125, MA: 3.7930665016174316, OA: 11.094629287719727


 44%|████▍     | 22/50 [05:49<07:23, 15.84s/it]

LOSS: 4.308204847574234, ACC: 0.47431641817092896, FP: 3.89453125, MA: 3.875195264816284, OA: 10.732226371765137


 46%|████▌     | 23/50 [06:05<07:03, 15.70s/it]

LOSS: 4.367187094688416, ACC: 0.4686523377895355, FP: 4.272070407867432, MA: 3.7007813453674316, OA: 11.3408203125


 48%|████▊     | 24/50 [06:20<06:48, 15.72s/it]

LOSS: 4.216474562883377, ACC: 0.48310548067092896, FP: 4.106738567352295, MA: 3.5819337368011475, OA: 11.211816787719727


 50%|█████     | 25/50 [06:36<06:34, 15.79s/it]

LOSS: 4.252878177165985, ACC: 0.4813476502895355, FP: 4.228417873382568, MA: 3.492480516433716, OA: 11.32871150970459


 52%|█████▏    | 26/50 [06:52<06:20, 15.84s/it]

LOSS: 4.221562844514847, ACC: 0.4830078184604645, FP: 4.05908203125, MA: 3.553417921066284, OA: 11.221972465515137


 54%|█████▍    | 27/50 [07:08<06:04, 15.86s/it]

LOSS: 4.2381778955459595, ACC: 0.4878906309604645, FP: 4.102832317352295, MA: 3.5125977993011475, OA: 11.314355850219727


 56%|█████▌    | 28/50 [07:24<05:48, 15.83s/it]

LOSS: 4.226000010967255, ACC: 0.49042969942092896, FP: 4.026172161102295, MA: 3.512988328933716, OA: 11.29248046875


 58%|█████▊    | 29/50 [07:39<05:30, 15.75s/it]

LOSS: 4.180812084674836, ACC: 0.49443361163139343, FP: 3.993457078933716, MA: 3.4765625, OA: 11.307909965515137


 60%|██████    | 30/50 [07:55<05:15, 15.76s/it]

LOSS: 4.213381725549698, ACC: 0.49189454317092896, FP: 4.040429592132568, MA: 3.470019578933716, OA: 11.37753963470459


 62%|██████▏   | 31/50 [08:11<05:00, 15.83s/it]

LOSS: 4.263414669036865, ACC: 0.48945313692092896, FP: 4.20458984375, MA: 3.463671922683716, OA: 11.585156440734863


 64%|██████▍   | 32/50 [08:27<04:45, 15.87s/it]

LOSS: 4.25976248383522, ACC: 0.49248048663139343, FP: 4.16162109375, MA: 3.422656297683716, OA: 11.565625190734863


 66%|██████▌   | 33/50 [08:44<04:34, 16.14s/it]

LOSS: 4.257013934850693, ACC: 0.48974609375, FP: 4.215527534484863, MA: 3.4125001430511475, OA: 11.661913871765137


 68%|██████▊   | 34/50 [09:00<04:17, 16.11s/it]

LOSS: 4.257445514202118, ACC: 0.49267578125, FP: 4.208301067352295, MA: 3.4068360328674316, OA: 11.623242378234863


 70%|███████   | 35/50 [09:16<03:59, 15.99s/it]

LOSS: 4.241709440946579, ACC: 0.49580079317092896, FP: 4.173633098602295, MA: 3.3609375953674316, OA: 11.590527534484863


 72%|███████▏  | 36/50 [09:31<03:40, 15.77s/it]

LOSS: 4.246704167127609, ACC: 0.49853515625, FP: 4.153613567352295, MA: 3.381542921066284, OA: 11.57431697845459


 74%|███████▍  | 37/50 [09:47<03:24, 15.77s/it]

LOSS: 4.266910475492478, ACC: 0.49296876788139343, FP: 4.15380859375, MA: 3.4232423305511475, OA: 11.57822322845459


 76%|███████▌  | 38/50 [10:02<03:07, 15.65s/it]

LOSS: 4.301922398805618, ACC: 0.49335938692092896, FP: 4.150488376617432, MA: 3.4583985805511475, OA: 11.536328315734863


 78%|███████▊  | 39/50 [10:18<02:51, 15.61s/it]

LOSS: 4.272210782766342, ACC: 0.49345704913139343, FP: 4.181445598602295, MA: 3.386425733566284, OA: 11.61679744720459


 80%|████████  | 40/50 [10:34<02:37, 15.72s/it]

LOSS: 4.276008987426758, ACC: 0.4930664002895355, FP: 4.160742282867432, MA: 3.4300782680511475, OA: 11.568554878234863


 82%|████████▏ | 41/50 [10:49<02:21, 15.70s/it]

LOSS: 4.288152605295181, ACC: 0.49335938692092896, FP: 4.173925876617432, MA: 3.417675733566284, OA: 11.589258193969727


 84%|████████▍ | 42/50 [11:06<02:07, 15.91s/it]

LOSS: 4.272332072257996, ACC: 0.4981445372104645, FP: 4.20947265625, MA: 3.3994140625, OA: 11.64794921875


 86%|████████▌ | 43/50 [11:21<01:50, 15.74s/it]

LOSS: 4.290291661024094, ACC: 0.49462890625, FP: 4.147070407867432, MA: 3.4325196743011475, OA: 11.531445503234863


 88%|████████▊ | 44/50 [11:36<01:33, 15.65s/it]

LOSS: 4.2811939179897305, ACC: 0.4959960877895355, FP: 4.170703411102295, MA: 3.4150390625, OA: 11.577441215515137


 90%|█████████ | 45/50 [11:54<01:20, 16.08s/it]

LOSS: 4.273339408636093, ACC: 0.4947265684604645, FP: 4.155957221984863, MA: 3.4090821743011475, OA: 11.570116996765137


 92%|█████████▏| 46/50 [12:10<01:04, 16.18s/it]

LOSS: 4.27292760014534, ACC: 0.49687501788139343, FP: 4.139550685882568, MA: 3.4195313453674316, OA: 11.548144340515137


 94%|█████████▍| 47/50 [12:26<00:48, 16.02s/it]

LOSS: 4.2805283784866335, ACC: 0.4971679747104645, FP: 4.138867378234863, MA: 3.4263672828674316, OA: 11.54062557220459


 96%|█████████▌| 48/50 [12:42<00:32, 16.05s/it]

LOSS: 4.295258915424347, ACC: 0.49287110567092896, FP: 4.099218845367432, MA: 3.4544923305511475, OA: 11.477734565734863


 98%|█████████▊| 49/50 [12:57<00:15, 15.94s/it]

LOSS: 4.274431455135345, ACC: 0.49462890625, FP: 4.175879001617432, MA: 3.3988282680511475, OA: 11.605175971984863


100%|██████████| 50/50 [13:13<00:00, 15.88s/it]

LOSS: 4.279475277662277, ACC: 0.4950195252895355, FP: 4.172265529632568, MA: 3.40625, OA: 11.603906631469727
{'epoch': 35, 'train_loss': 1.744486412428834, 'val_loss': 4.246704167127609, 'val_acc': 0.49853515625, 'fp': 4.153613567352295, 'ma': 3.381542921066284, 'oa': 11.57431697845459}





# Resnet Baseline

In [7]:
def train_one_epoch_baseline():
    running_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data["images"], data["labels"]
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs = model(inputs)

        # Compute the loss and its gradients
        loss = loss_fn_baseline(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()

    return running_loss / (i+1)

In [13]:
from torchmetrics import Accuracy
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

NUM_CLASSES = 200
EPOCHS = 15
accuracy = Accuracy(task="multiclass", num_classes=NUM_CLASSES, top_k=1).to(device)

POS_FT_WEIGHT = 0
FT_WEIGHT = 0

from torchvision.models import resnet18, ResNet18_Weights
model = resnet18(weights=ResNet18_Weights.DEFAULT)
model.fc = nn.Linear(512, NUM_CLASSES)

model = model.to(device)

loss_fn_baseline = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4, weight_decay=1e-5)

best_stats = {
    "epoch": 0,
    "train_loss": 0,
    "val_loss": 0,
    "val_acc": 0,
}

from tqdm import tqdm
for epoch in tqdm(range(EPOCHS)):
    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch_baseline()

    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()
    running_acc = 0.0
    running_false_positives = 0.0

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i, vdata in enumerate(validation_loader):
            vinputs, vlabels = vdata["images"], vdata["labels"]
            vinputs = vinputs.to(device)
            vlabels = vlabels.to(device)
            voutputs = model(vinputs)
            vloss = loss_fn_baseline(voutputs, vlabels)
            running_vloss += vloss.item()
            running_acc += accuracy(voutputs, vlabels)

    avg_vloss = running_vloss / (i + 1)
    avg_acc = running_acc / (i + 1)
    print(f"LOSS: {avg_vloss}, ACC: {avg_acc}")

    if best_stats["val_acc"] < avg_acc:
        best_stats["epoch"] = epoch
        best_stats["train_loss"] = avg_loss
        best_stats["val_loss"] = avg_vloss
        best_stats["val_acc"] = avg_acc.item()

print(best_stats)

Device: cuda


  7%|▋         | 1/15 [00:15<03:32, 15.20s/it]

LOSS: 2.7538333177566527, ACC: 0.3666932284832001


 13%|█▎        | 2/15 [00:30<03:16, 15.14s/it]

LOSS: 2.268373465538025, ACC: 0.45415934920310974


 20%|██        | 3/15 [00:45<03:02, 15.21s/it]

LOSS: 2.1400376319885255, ACC: 0.4812440574169159


 27%|██▋       | 4/15 [01:01<02:48, 15.34s/it]

LOSS: 1.9924053430557251, ACC: 0.5112384557723999


 33%|███▎      | 5/15 [01:16<02:33, 15.37s/it]

LOSS: 1.9608011841773987, ACC: 0.5201749801635742


 40%|████      | 6/15 [01:31<02:18, 15.38s/it]

LOSS: 1.9408017992973328, ACC: 0.5298649072647095


 47%|████▋     | 7/15 [01:47<02:04, 15.54s/it]

LOSS: 1.9406044483184814, ACC: 0.5277164578437805


 53%|█████▎    | 8/15 [02:02<01:47, 15.40s/it]

LOSS: 1.9414733171463012, ACC: 0.5363759398460388


 60%|██████    | 9/15 [02:18<01:32, 15.38s/it]

LOSS: 1.9837289929389954, ACC: 0.5289102792739868


 67%|██████▋   | 10/15 [02:33<01:16, 15.33s/it]

LOSS: 2.0138878226280212, ACC: 0.5291215181350708


 73%|███████▎  | 11/15 [02:48<01:01, 15.31s/it]

LOSS: 2.0528515219688415, ACC: 0.5269291996955872


 80%|████████  | 12/15 [03:03<00:45, 15.20s/it]

LOSS: 2.074764096736908, ACC: 0.5255321860313416


 87%|████████▋ | 13/15 [03:18<00:30, 15.17s/it]

LOSS: 2.1348794221878054, ACC: 0.5197624564170837


 93%|█████████▎| 14/15 [03:33<00:15, 15.18s/it]

LOSS: 2.1823304891586304, ACC: 0.523909866809845


100%|██████████| 15/15 [03:48<00:00, 15.26s/it]

LOSS: 2.1891473531723022, ACC: 0.528218686580658
{'epoch': 7, 'train_loss': 1.1221286131411183, 'val_loss': 1.9414733171463012, 'val_acc': 0.5363759398460388}





# Optuna

In [3]:
import torch.nn as nn

eps = 1e-10
def loss_fn_optuna(out, labels, predicate_matrix, NUM_FEATURES, FT_WEIGHT):
    out = out.view(-1, 1, NUM_FEATURES) # out is a batch of 1D binary vectors
    ANDed = out * predicate_matrix # AND operation
    diff = ANDed - out # Difference of ANDed and out => if equal, then out is a subset of its class' predicates

    entr_loss = nn.CrossEntropyLoss()
    loss_cl = entr_loss(diff.sum(dim=2), labels) # Is "out" a subset of its class' predicates?

    batch_size = out.shape[0]

    out = out.view(-1, NUM_FEATURES)
    diff_square = (out - predicate_matrix[labels]).pow(2)
    
    false_positives = (out - predicate_matrix[labels] + diff_square).sum() / batch_size
    missing_attr = (predicate_matrix[labels] - out + diff_square).sum() / batch_size
    
    loss_ft = (1 + false_positives + missing_attr)
    loss_ft *= loss_cl.item()/(loss_ft.item() + eps)
    
    return loss_cl + loss_ft * FT_WEIGHT

from torchmetrics import Accuracy
def train_one_epoch_optuna(model, optimizer, scheduler, NUM_FEATURES, FT_WEIGHT):
    running_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data["images"], data["labels"]
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs, commit_loss, predicate_matrix = model(inputs)

        # Compute the loss and its gradients
        loss = loss_fn_optuna(outputs, labels, predicate_matrix, NUM_FEATURES, FT_WEIGHT) + commit_loss
        loss.backward()

        # Adjust learning weights
        optimizer.step()
        
        scheduler.step()

        # Gather data and report
        running_loss += loss.item()

    return running_loss / (i+1)

from tqdm import tqdm
from torch import optim

import sys, os
sys.path.insert(0, "/".join(os.path.abspath('').split("/")[:-1]) + "/models")
print("/".join(os.path.abspath('').split("/")[:-1]) + "/models")
from ResnetAutoPredicates import ResExtr

def objective(trial):
    global trial_num
    trial_num += 1
    print(f"Starting trial {trial_num}")
    NUM_FEATURES = trial.suggest_int("num_features", 4, 16)
    FT_WEIGHT = trial.suggest_float("ft_weight", 0, 1)
    # Generate the model.
    model = ResExtr(NUM_FEATURES*8, NUM_CLASSES, pretrained=True).to(device)

    EPOCHS = 50

    # Generate the optimizers.
    lr = trial.suggest_float("lr", 1e-5, 1e-3, log=True)
    max_lr = trial.suggest_float("max_lr", 1e-4, 4e-3, log=True)
    
    if lr > max_lr:
        raise optuna.TrialPruned()
    
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr, epochs=EPOCHS, steps_per_epoch=len(training_loader))

    best_acc = 0.0

    for epoch in tqdm(range(EPOCHS)):
        # Make sure gradient tracking is on, and do a pass over the data
        model.train(True)
        _ = train_one_epoch_optuna(model, optimizer, scheduler, NUM_FEATURES*8, FT_WEIGHT)

        model.eval()
        running_acc = 0.0

        # Disable gradient computation and reduce memory consumption.
        with torch.no_grad():
            for i, vdata in enumerate(validation_loader):
                vinputs, vlabels = vdata["images"], vdata["labels"]
                vinputs = vinputs.to(device)
                vlabels = vlabels.to(device)
                voutputs, _, predicate_matrix = model(vinputs)
                voutputs = voutputs.view(-1, 1, NUM_FEATURES*8)
                ANDed = voutputs * predicate_matrix
                diff = ANDed - voutputs
                running_acc += accuracy(diff.sum(dim=2), vlabels)

        avg_acc = running_acc / (i + 1)

        if avg_acc > best_acc:
            best_acc = avg_acc

        if epoch == 30 and best_acc < 0.3:
            raise optuna.TrialPruned()
        elif epoch == 15 and best_acc < 0.1:
            raise optuna.TrialPruned()
    
    return best_acc

/notebooks/Concept_ZSL/src/models


In [None]:
import optuna

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

NUM_CLASSES = 200
accuracy = Accuracy(task="multiclass", num_classes=NUM_CLASSES, top_k=1).to(device)

trial_num = -1

study = optuna.create_study(direction="maximize", study_name='TIN-ResNet18-AutoPred', load_if_exists=True, storage='sqlite:///Optuna.db')
study.optimize(objective, n_trials=100)

print("Best trial:")
trial = study.best_trial

print("  Value: ", trial.value)

print("  Params: ")
for key, value in trial.params.items():
    if key == "num_features":
        print("    {}: {}".format(key, value*8))
    else:
        print("    {}: {}".format(key, value))

Device: cuda


[I 2023-10-09 10:03:35,516] Using an existing study with name 'TIN-ResNet18-AutoPred' instead of creating a new one.


Starting trial 0


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 255MB/s]
100%|██████████| 50/50 [13:15<00:00, 15.92s/it]
[I 2023-10-09 10:16:52,492] Trial 134 finished with value: 0.508007824420929 and parameters: {'num_features': 11, 'ft_weight': 0.52094269759797, 'lr': 1.0017170923451517e-05, 'max_lr': 0.0004997788710074081}. Best is trial 32 with value: 0.5235351920127869.


Starting trial 1


100%|██████████| 50/50 [13:01<00:00, 15.62s/it]
[I 2023-10-09 10:29:54,228] Trial 135 finished with value: 0.512011706829071 and parameters: {'num_features': 9, 'ft_weight': 0.4918199991960398, 'lr': 1.1914706449785521e-05, 'max_lr': 0.000354892605686437}. Best is trial 32 with value: 0.5235351920127869.


Starting trial 2


100%|██████████| 50/50 [12:45<00:00, 15.32s/it]
[I 2023-10-09 10:42:40,485] Trial 136 finished with value: 0.5064453482627869 and parameters: {'num_features': 9, 'ft_weight': 0.5406689145403181, 'lr': 1.3356989015631941e-05, 'max_lr': 0.0003168264428963777}. Best is trial 32 with value: 0.5235351920127869.


Starting trial 3


100%|██████████| 50/50 [13:05<00:00, 15.72s/it]
[I 2023-10-09 10:55:46,716] Trial 137 finished with value: 0.5146484375 and parameters: {'num_features': 14, 'ft_weight': 0.4630174432834556, 'lr': 1.09586677353307e-05, 'max_lr': 0.0004477327149481473}. Best is trial 32 with value: 0.5235351920127869.


Starting trial 4


100%|██████████| 50/50 [13:06<00:00, 15.74s/it]
[I 2023-10-09 11:08:54,366] Trial 138 finished with value: 0.5029296875 and parameters: {'num_features': 8, 'ft_weight': 0.5268294439754077, 'lr': 1.4954940285625902e-05, 'max_lr': 0.00035089089513476804}. Best is trial 32 with value: 0.5235351920127869.


Starting trial 5


100%|██████████| 50/50 [13:48<00:00, 16.58s/it]
[I 2023-10-09 11:22:43,747] Trial 139 finished with value: 0.5088867545127869 and parameters: {'num_features': 10, 'ft_weight': 0.496299707598186, 'lr': 1.2616341489098811e-05, 'max_lr': 0.0004067613027165046}. Best is trial 32 with value: 0.5235351920127869.


Starting trial 6


 20%|██        | 10/50 [02:46<11:02, 16.57s/it]