# 1. INIT - Import packages

In [None]:
import torch
import os
import sys
from pathlib import Path

# Make the parent of the current working directory importable, so that
# project packages living next to this notebook's folder can be imported.
file_dir = Path.cwd()
workspace_dir = str(file_dir.parent)
sys.path.append(workspace_dir)

# Restrict the process to a single physical GPU. This is set before the
# first CUDA call below.
os.environ['CUDA_VISIBLE_DEVICES'] = '5'

# Report the runtime environment.
print('Pytorch version :', torch.__version__)
print('CUDA version\t:', torch.version.cuda)
print('GPU\t\t:', torch.cuda.get_device_name())

In [2]:
from collections import OrderedDict
from classes import IMAGENET2012_CLASSES
import torchvision
from A01_ImageNet import model, utils, data
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
import scipy.io
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import matplotlib.patches as mpatches
from torch import nn
from scipy.spatial.distance import jensenshannon

# 2. INIT - Define functions

In [3]:
def get_test_stat(net, dataloader, use_feature = True, concept = None, device = 'cuda', num_classes = 1000):
    """Count, per class label, the samples for which output 1 exceeds output 0.

    Every batch ``(X, y)`` from ``dataloader`` is pushed through ``net``; a
    sample is counted when ``y_hat[:, 0] < y_hat[:, 1]``, and the count is
    accumulated at the index given by that sample's label in ``y``.

    Args:
        net: Model to evaluate. ``net.feature_forward`` is called when
            ``use_feature`` is true, otherwise ``net`` itself is called.
        dataloader: Iterable yielding ``(X, y)`` batches, where ``y`` holds
            integer class labels in ``[0, num_classes)``.
        use_feature: Select ``net.feature_forward`` instead of ``net(...)``.
        concept: Optional 1-D tensor. When given, it is repeated once per
            sample in the batch and passed as the second argument of the
            forward call.
        device: Device the inputs (and ``concept``) are moved to. The model
            is expected to already live there.
        num_classes: Length of the returned count vector.

    Returns:
        A float tensor of shape ``[num_classes]`` on the CPU holding the
        per-class counts.
    """
    forward = net.feature_forward if use_feature else net
    if concept is not None:
        # Move the concept once instead of once per batch.
        concept = concept.to(device)

    class_counts = torch.zeros([num_classes])
    # Pure evaluation: do not build an autograd graph for the forward passes.
    with torch.no_grad():
        for X, y in dataloader:
            X = X.to(device)
            if concept is not None:
                symbol_batch = concept.repeat([X.shape[0], 1])
                y_hat = forward(X, symbol_batch)
            else:
                y_hat = forward(X)
            y_hat = y_hat.to('cpu')
            mask = y_hat[:, 0] < y_hat[:, 1]
            # Vectorised histogram of the selected labels; replaces a
            # per-sample Python loop.
            selected = y[mask].reshape(-1).to(torch.long)
            class_counts += torch.bincount(selected, minlength = num_classes)
    return class_counts

In [4]:
def get_non_zero_values(input_tensor):
    """Return the entries of ``input_tensor`` that are not equal to zero.

    The selection uses a boolean mask, so the result is a flat (1-D)
    collection of the surviving values in their original order.
    """
    is_non_zero = input_tensor != 0.0
    return input_tensor[is_non_zero]

In [5]:
def get_entropy(normalized_prob, eps = None):
    """Return ``-sum(p * log2(p))`` over the entries of ``normalized_prob``.

    Zero entries would produce ``0 * log2(0) = nan``, so they are handled in
    one of two ways:

    * ``eps is None`` (default): zero entries are dropped via
      ``get_non_zero_values`` before the sum.
    * ``eps`` given: every entry is clamped from below to ``eps`` instead.
      Note that the clamped values are not re-normalised.

    Args:
        normalized_prob: Tensor of probabilities (presumably summing to 1;
            this is not checked here).
        eps: Optional lower bound applied with ``torch.clamp``.

    Returns:
        A 0-dim tensor holding the entropy in bits.
    """
    if eps is None:
        probs = get_non_zero_values(normalized_prob)
    else:
        probs = torch.clamp(normalized_prob, min = eps)
    return -torch.sum(probs * torch.log2(probs))
    

# 3. EXECUTIONS - trained fe/cdp/ts, random symbol as entropy
Data for the blue density bar in Fig. 2f.

In [None]:
# Evaluate one trained sea_net checkpoint against many random concept
# vectors and record, for each concept, the per-class counts returned by
# get_test_stat.
net_global_id = 50
net = model.sea_net(symbol_size = 20, num_classes = 1000, fix_fe = True, fe_type = 'resnet50')
# Load the checkpoint onto the CPU first: only one GPU is visible to this
# process, so tensors saved from a different CUDA index could not be
# restored to their original device. The model is moved to the GPU below.
# NOTE: strict=False means missing/unexpected checkpoint keys are ignored.
net.load_state_dict(
    torch.load(f'../Results/param/imagenet1k_ss20_fixfe_trail{net_global_id}.pt', map_location = 'cpu'),
    strict=False,
)
net.to('cuda')
net.eval()

imagenet1k_test_FeatureDataset = data.FeatureDataset("../Results/FeatureData/ImageNet1k_test_embeddings.pt", "../Results/FeatureData/ImageNet1k_test_indices.pt")
imagenet1k_test_FeatureLoader = DataLoader(
    imagenet1k_test_FeatureDataset,
    batch_size = 512,
    num_workers = 8,
    shuffle = False
)
n_repeats = 1000

# Create the output directory up front so that the long loop below cannot
# finish and then fail at save time.
os.makedirs("../Results/entropy_stat", exist_ok = True)

# One row per random concept, one column per class.
Imagenet1k_test_stat_trainedConfig = torch.zeros([n_repeats, 1000])

for i in range(n_repeats):
    print(f"Repeating random concept {i}")
    # Uniform random concept with the same shape as one entry of the
    # network's symbol set (unseeded, so runs are not reproducible).
    concept_i = torch.rand(net.symbol_set[0].shape)
    Imagenet1k_test_stat_trainedConfig[i] = get_test_stat(net, imagenet1k_test_FeatureLoader, use_feature = True, concept = concept_i)
    print(f'Counting number category: {Imagenet1k_test_stat_trainedConfig[i].sum()}')


# Transposed view (class x repeat); not used in this cell, kept for later cells.
Imagenet1k_test_stat_trainedConfig_T = Imagenet1k_test_stat_trainedConfig.T
torch.save(Imagenet1k_test_stat_trainedConfig, f"../Results/entropy_stat/Imagenet1k_test_stat_seanet{net_global_id}_trainedConfig.pt")


# 4. EXECUTIONS - random ts_net param as entropy
Data for the purple density bar in Fig. 2f.

In [None]:
# Evaluate many freshly constructed ts_net models (no checkpoint is loaded,
# so each one keeps the parameters its constructor gives it) and record the
# per-class counts returned by get_test_stat for each of them.
imagenet1k_test_FeatureDataset = data.FeatureDataset("../Results/FeatureData/ImageNet1k_test_embeddings.pt", "../Results/FeatureData/ImageNet1k_test_indices.pt")
imagenet1k_test_FeatureLoader = DataLoader(
    imagenet1k_test_FeatureDataset,
    batch_size = 512,
    num_workers = 4,
    shuffle = False
)
n_repeats = 1000

# Create the output directory up front so that the long loop below cannot
# finish and then fail at save time.
os.makedirs("../Results/entropy_stat", exist_ok = True)

# One row per model instance, one column per class.
Imagenet1k_test_stat_tsnet_randomConfig = torch.zeros([n_repeats, 1000])

for i in range(n_repeats):
    # A new model is built on every iteration.
    net = model.ts_net(fix_fe = True, fe_type = 'resnet50')
    net.to('cuda')
    net.eval()

    print(f"Repeating {i}")
    Imagenet1k_test_stat_tsnet_randomConfig[i] = get_test_stat(net, imagenet1k_test_FeatureLoader, use_feature = True)
    print(f'Counting number category: {Imagenet1k_test_stat_tsnet_randomConfig[i].sum()}')

torch.save(Imagenet1k_test_stat_tsnet_randomConfig, "../Results/entropy_stat/Imagenet1k_test_stat_tsnet_randomConfig.pt")