In [1]:
from TinyImageNetLoader import TinyImageNetDataset
import torchvision.transforms as transforms
import torch

# Dataset location and per-channel normalisation statistics shared by both splits.
DATA_ROOT = "/datasets/tiny-imagenet-200"
NORM_MEAN = [0.4802, 0.4481, 0.3975]
NORM_STD = [0.2302, 0.2265, 0.2262]

# Validation images are only converted to tensors and normalised.
val_steps = [
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
]
val_transform = transforms.Compose(val_steps)
valset = TinyImageNetDataset(DATA_ROOT, mode="val", transform=val_transform)

# Training images additionally get a random rotation (20 degree range) and a
# random horizontal flip (p=0.5) before normalisation.
train_steps = [
    transforms.ToTensor(),
    transforms.RandomRotation(20),
    transforms.RandomHorizontalFlip(0.5),
    transforms.Normalize(NORM_MEAN, NORM_STD),
]
train_transform = transforms.Compose(train_steps)
trainset = TinyImageNetDataset(DATA_ROOT, transform=train_transform)

  from tqdm.autonotebook import tqdm


Preloading val data...:   0%|          | 0/10000 [00:00<?, ?it/s]

Preloading train data...:   0%|          | 0/100000 [00:00<?, ?it/s]

In [2]:
# Both splits use the same batch size and worker count; only the training
# split is reshuffled.
loader_options = {"batch_size": 1024, "num_workers": 4}

validation_loader = torch.utils.data.DataLoader(valset, shuffle=False, **loader_options)
training_loader = torch.utils.data.DataLoader(trainset, shuffle=True, **loader_options)

In [11]:
eps = 1e-10  # added to the detached denominators in loss_fn so they are never exactly zero
import torch.nn as nn


def loss_fn(out, labels, predicate_matrix):
    """Combine a class cross-entropy term with two re-scaled feature penalties.

    Reads the notebook-level globals ``NUM_FEATURES``, ``NUM_CLASSES``,
    ``FT_WEIGHT``, ``POS_FT_WEIGHT`` and ``eps``.

    Args:
        out: Model outputs; reshaped here to ``(-1, 1, NUM_FEATURES)``.
            Presumably (near-)binary feature vectors — TODO confirm against the model.
        labels: Class indices, used both as cross-entropy targets and to pick
            rows of ``predicate_matrix``.
        predicate_matrix: Per-class predicate rows; assumed to be
            ``(NUM_CLASSES, NUM_FEATURES)`` — TODO confirm.

    Returns:
        Scalar tensor ``loss_cl + feature_losses``.
    """
    out = out.view(-1, 1, NUM_FEATURES)  # one row per sample, broadcastable against every class
    ANDed = out * predicate_matrix  # element-wise AND when both operands are binary
    diff = ANDed - out  # all zeros for a class iff `out` is a subset of that class' predicates

    # Per-class score = summed difference; cross-entropy pushes the true class' score highest.
    entr_loss = nn.CrossEntropyLoss()
    loss_cl = entr_loss(diff.sum(dim=2), labels)

    batch_size = out.shape[0]

    # One-hot class mask. It is created on the device of the inputs instead of a
    # hard-coded "cuda", so the function also works when the notebook falls back to CPU.
    classes = torch.zeros(batch_size, NUM_CLASSES, device=out.device)
    classes[torch.arange(batch_size, device=labels.device), labels] = 1
    classes = classes.view(batch_size, NUM_CLASSES, 1).expand(batch_size, NUM_CLASSES, NUM_FEATURES)

    # d + d^2 with d = out - predicate: 2 where out=1/predicate=0, 0 for d in {0, -1}.
    extra_features = out - predicate_matrix + (out - predicate_matrix).pow(2)

    # Penalty summed over all non-target classes, averaged over the batch.
    loss_neg_ft = torch.masked_select(extra_features, (1 - classes).bool()).view(-1, NUM_FEATURES).sum() / batch_size

    # Penalty measured against the target class' own predicate row.
    labels_predicate = predicate_matrix[labels]
    extra_features_in = torch.masked_select(extra_features, classes.bool()).view(-1, NUM_FEATURES)
    loss_pos_ft = (labels_predicate - out.view(batch_size, NUM_FEATURES) + extra_features_in / 2).sum() / batch_size

    # Each penalty is multiplied by WEIGHT * loss_cl / its own value, with the ratio
    # taken via .item(): gradients flow only through the numerators, while the term's
    # magnitude follows the classification loss.
    feature_losses = loss_neg_ft * FT_WEIGHT * loss_cl.item()/(loss_neg_ft.item() + eps) + loss_pos_ft * POS_FT_WEIGHT * loss_cl.item()/(loss_pos_ft.item() + eps)

    return loss_cl + feature_losses

def train_one_epoch():
    """Run one optimisation pass over ``training_loader``.

    Relies on the notebook-level ``model``, ``optimizer``, ``device`` and
    ``loss_fn``. The model is expected to return a
    ``(outputs, commit_loss, predicate_matrix)`` triple for each batch.

    Returns:
        The mean of the per-batch losses (``loss_fn`` result plus commit loss).
    """
    loss_total = 0.

    # enumerate() keeps the index of the last batch, which gives the batch count below.
    for step, batch in enumerate(training_loader):
        images = batch["images"].to(device)
        targets = batch["labels"].to(device)

        # Gradients accumulate by default, so clear them before each batch.
        optimizer.zero_grad()

        predictions, commit_loss, predicate_matrix = model(images)

        batch_loss = loss_fn(predictions, targets, predicate_matrix) + commit_loss
        batch_loss.backward()
        optimizer.step()

        loss_total += batch_loss.item()

    return loss_total / (step + 1)

In [10]:
from torchmetrics import Accuracy
import sys, os

# Model definitions live in the ``models`` directory next to the notebook's
# working directory; put it first on the import path before importing ResExtr.
models_dir = "/".join(os.path.abspath('').split("/")[:-1]) + "/models"
sys.path.insert(0, models_dir)
print(models_dir)
from ResnetAutoPredicates import ResExtr

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

# Experiment configuration (also read as globals by loss_fn).
NUM_FEATURES = 448
NUM_CLASSES = 200
EPOCHS = 30
POS_FT_WEIGHT = 0
FT_WEIGHT = 0.1

accuracy = Accuracy(task="multiclass", num_classes=NUM_CLASSES, top_k=1).to(device)

# Build the model, then swap the backbone head so it emits NUM_FEATURES values.
model = ResExtr(NUM_FEATURES, NUM_CLASSES, pretrained=True).to(device)
model.resnet.fc = nn.Linear(512, NUM_FEATURES).to(device)
model.resnet.avgpool = nn.AdaptiveAvgPool2d(1).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=2e-4, weight_decay=1e-5)

# Statistics of the epoch with the lowest validation loss seen so far.
best_vloss = 1_000_000.
best_stats = dict(epoch=0, train_loss=0, val_loss=0, val_acc=0, val_fp=0)

from tqdm import tqdm
for epoch in tqdm(range(EPOCHS)):
    # Training pass (train mode on).
    model.train()
    avg_loss = train_one_epoch()

    # Validation pass: eval mode, no gradient tracking.
    model.eval()
    running_vloss = 0.0
    running_acc = 0.0
    running_false_positives = 0.0

    with torch.no_grad():
        for i, vdata in enumerate(validation_loader):
            vinputs = vdata["images"].to(device)
            vlabels = vdata["labels"].to(device)

            voutputs, vcommit_loss, predicate_matrix = model(vinputs)
            vloss = loss_fn(voutputs, vlabels, predicate_matrix) + vcommit_loss
            running_vloss += vloss.item()

            # Per-class scores, computed the same way loss_fn builds its logits.
            flat_outputs = voutputs.view(-1, NUM_FEATURES)
            per_class = flat_outputs.unsqueeze(1)
            class_scores = (per_class * predicate_matrix - per_class).sum(dim=2)
            running_acc += accuracy(class_scores, vlabels)

            # Entries equal to -1: feature is 1 in the output but 0 in the
            # target class' predicate row; averaged per sample.
            false_positive_mask = (predicate_matrix[vlabels] - flat_outputs) == -1
            running_false_positives += false_positive_mask.sum() / flat_outputs.shape[0]

    # Means over validation batches.
    num_vbatches = i + 1
    avg_vloss = running_vloss / num_vbatches
    avg_acc = running_acc / num_vbatches
    avg_false_positives = running_false_positives / num_vbatches
    print(f"LOSS: {avg_vloss}, ACC: {avg_acc}, FP: {avg_false_positives}")
    print(model.bin_quantize._codebook.embed)

    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        best_stats.update(
            epoch=epoch,
            train_loss=avg_loss,
            val_loss=avg_vloss,
            val_acc=avg_acc.item(),
            val_fp=avg_false_positives.item(),
        )


print(best_stats)

/notebooks/Concept_ZSL/src/models
Device: cuda


  3%|▎         | 1/30 [00:16<08:07, 16.80s/it]

LOSS: 5.828151082992553, ACC: 0.004942602012306452, FP: 0.0
tensor([[[0.],
         [1.]]], device='cuda:0')


  7%|▋         | 2/30 [00:34<07:57, 17.05s/it]

LOSS: 5.828151082992553, ACC: 0.004942602012306452, FP: 0.0
tensor([[[0.],
         [1.]]], device='cuda:0')


  7%|▋         | 2/30 [00:38<08:57, 19.18s/it]


KeyboardInterrupt: 

# ResNet Baseline

In [None]:
def train_one_epoch_baseline():
    """Train the baseline classifier for one pass over ``training_loader``.

    Uses the notebook-level ``model``, ``optimizer``, ``device`` and
    ``loss_fn_baseline``; here the model returns only its outputs.

    Returns:
        The mean of the per-batch training losses.
    """
    accumulated = 0.

    # The final value of `last_idx` (index of the last batch) yields the batch count.
    for last_idx, sample in enumerate(training_loader):
        x = sample["images"].to(device)
        y = sample["labels"].to(device)

        # Start each batch from cleared gradients.
        optimizer.zero_grad()

        # Forward pass, loss, backward pass, parameter update.
        step_loss = loss_fn_baseline(model(x), y)
        step_loss.backward()
        optimizer.step()

        accumulated += step_loss.item()

    return accumulated / (last_idx + 1)

In [None]:
from torchmetrics import Accuracy
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

NUM_CLASSES = 200
EPOCHS = 30
accuracy = Accuracy(task="multiclass", num_classes=NUM_CLASSES, top_k=1).to(device)

# Not read by the baseline loss (plain cross-entropy); kept because loss_fn
# reads these names as notebook globals.
POS_FT_WEIGHT = 0
FT_WEIGHT = 0

from torchvision.models import resnet18, ResNet18_Weights
model = resnet18(weights=ResNet18_Weights.DEFAULT)
model.avgpool = nn.AdaptiveAvgPool2d(1)
# Size the new head from the existing one. ResNet-18 feeds 512 features into
# `fc`; the previous hard-coded 2048 does not match and fails on the first
# forward pass.
model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)

model = model.to(device)

loss_fn_baseline = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)

# Statistics of the epoch with the lowest validation loss seen so far.
best_vloss = 1_000_000.

best_stats = {
    "epoch": 0,
    "train_loss": 0,
    "val_loss": 0,
    "val_acc": 0,
}

from tqdm import tqdm
for epoch in tqdm(range(EPOCHS)):
    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch_baseline()

    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()
    running_acc = 0.0

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i, vdata in enumerate(validation_loader):
            vinputs, vlabels = vdata["images"], vdata["labels"]
            vinputs = vinputs.to(device)
            vlabels = vlabels.to(device)
            voutputs = model(vinputs)
            vloss = loss_fn_baseline(voutputs, vlabels)
            running_vloss += vloss.item()
            running_acc += accuracy(voutputs, vlabels)

    # Means over validation batches (i is the index of the last batch).
    avg_vloss = running_vloss / (i + 1)
    avg_acc = running_acc / (i + 1)
    print(f"LOSS: {avg_vloss}, ACC: {avg_acc}")

    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        best_stats["epoch"] = epoch
        best_stats["train_loss"] = avg_loss
        best_stats["val_loss"] = avg_vloss
        best_stats["val_acc"] = avg_acc.item()

print(best_stats)