In [51]:
import pickle as pkl
import numpy as np
import torch
import torch.nn as nn
idx2name = pkl.load(open('data/hierarchies/tieredimagenet/tieredimagenet_idx_to_name.pkl', 'rb'))
class2idx = pkl.load(open('data/hierarchies/tieredimagenet/tieredimagenet_class_to_idx.pkl', 'rb'))

In [52]:
from torchvision import transforms

In [53]:
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
t = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

In [54]:
# open ILSVRC2012_val_00047737.JPEG and apply the transform
from PIL import Image
img = Image.open('ILSVRC2012_val_00047737.JPEG')
img = t(img)

imgs_blurred = [img]

for sigma in np.linspace(1/10, 10, 100):
    t_blurred = transforms.GaussianBlur(kernel_size=61, sigma=sigma)
    img_blurred = t_blurred(img)
    imgs_blurred.append(img_blurred)
imgs_blurred = torch.stack(imgs_blurred)

In [55]:
from torchvision.models import (
    alexnet, AlexNet_Weights,
    convnext_tiny, ConvNeXt_Tiny_Weights,
    densenet121, DenseNet121_Weights,
    efficientnet_v2_s, EfficientNet_V2_S_Weights,
    inception_v3, Inception_V3_Weights,
    resnet18, ResNet18_Weights,
    swin_v2_t, Swin_V2_T_Weights,
    vgg11, VGG11_Weights,
    vit_b_16, ViT_B_16_Weights,
)

In [56]:
MODEL_REGISTRY = {
    'alexnet': (alexnet, AlexNet_Weights),
    'convnext_tiny': (convnext_tiny, ConvNeXt_Tiny_Weights),
    'densenet121': (densenet121, DenseNet121_Weights),
    'efficientnet_v2_s': (efficientnet_v2_s, EfficientNet_V2_S_Weights),
    'inception_v3': (inception_v3, Inception_V3_Weights),
    'resnet18': (resnet18, ResNet18_Weights),
    'swin_v2_t': (swin_v2_t, Swin_V2_T_Weights),
    'vgg11': (vgg11, VGG11_Weights),
    'vit_b_16': (vit_b_16, ViT_B_16_Weights),
}

In [57]:
def get_pretrained_model(model_name: str):
    """
    Return a model constructor and weights given a model name.
    """
    if model_name not in MODEL_REGISTRY:
        raise ValueError(f"Unknown model '{model_name}'. Available: {list(MODEL_REGISTRY)}")

    constructor, weights_cls = MODEL_REGISTRY[model_name]
    return constructor, weights_cls


In [58]:
model_constructor, weights_cls = get_pretrained_model('vgg11')
weights = weights_cls.IMAGENET1K_V1
model = model_constructor(weights=weights)

In [59]:
idx_mapping = pkl.load(open('data/hierarchies/tieredimagenet/tiredimagenet_corresponding_index.pkl', 'rb'))

In [60]:
model.classifier[6]

Linear(in_features=4096, out_features=1000, bias=True)

In [11]:
indices = [idx_mapping[key] for key in idx_mapping]
indices = torch.tensor(indices, dtype=torch.long)

# Get original weights and bias
old_weight = model.classifier[6].weight.data
old_bias = model.classifier[6].bias.data

# Prune
new_weight = old_weight[indices]
new_bias = old_bias[indices]

# Replace parameters
model.classifier[6].weight = nn.Parameter(new_weight)
model.classifier[6].bias = nn.Parameter(new_bias)
model.classifier[6].out_features = len(indices)

In [12]:
proba_leaves = model.forward(imgs_blurred).detach().numpy()

In [13]:
import hierulz.hierarchy.hierarchy as hhh
import importlib
importlib.reload(hhh)

<module 'hierulz.hierarchy.hierarchy' from '/home/infres/rplaud/hierarchical_decision_rules/hierulz/hierarchy/hierarchy.py'>

In [14]:
from hierulz.hierarchy import Hierarchy, load_hierarchy


In [15]:
h = load_hierarchy('data/hierarchies/tieredimagenet/tieredimagenet_hierarchy_idx.pkl')

In [16]:
proba_nodes = h.get_probas(proba_leaves)

In [17]:
from hierulz.metrics import load_metric, hFBetaScore, Accuracy
from hierulz.heuristics import TopDown, Plurality

In [18]:
m = hFBetaScore(hierarchy=h, beta=1.0)
argmax = Accuracy(hierarchy=h)
top_down_heuristic = TopDown(hierarchy=h)
plurality_heuristic = Plurality(hierarchy=h)

In [19]:
pred_opt = m.decode(proba_nodes)

In [20]:
pred_argmax = argmax.decode(proba_nodes)

In [21]:
pred_top_down = top_down_heuristic.decode(proba_nodes)
pred_plurality = plurality_heuristic.decode(proba_nodes)

In [45]:
def decode_pred(pred, h, idx2name):
    labels = np.where(pred)[0]
    pred = [h.root_idx]
    bool=True
    while bool:
        bool_i = False
        for l in pred:
            if l in h.leaves_idx:
                pass
            # del l from pred and add the children of l to pred
            else :
                print(h.hierarchy_dico_idx[l], set(labels))
                inter = list(set(h.hierarchy_dico_idx[l]).intersection(set(labels)))
                print(inter)
                if len(inter) > 0:
                    pred.remove(l)
                    pred.extend(inter)
                    bool_i = True
        bool = bool_i
    return [idx2name[p] for p in pred]


In [47]:
h.hierarchy_dico_idx[h.root_idx]

[747, 733]

In [46]:
decode_pred(pred_opt[1], h, idx2name)

[747, 733] {642, 804, 615, 745, 746, 657, 667, 636, 670}
[]


['physical_entity']

In [50]:
np.where(pred_opt[1]), h.root_idx

((array([615, 636, 642, 657, 667, 670, 745, 746, 804]),), 719)