In [None]:
import os
import json
import random
import warnings
import numpy as np
from tqdm import tqdm
import networkx as nx
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
from sklearn.multiclass import OneVsRestClassifier
from sklearn.linear_model import LogisticRegression

import torch
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from simclr_pretrained.resnet_wider import resnet50x1, resnet50x2, resnet50x4

In [None]:
# --- Experiment configuration ---
# Name of the label-tree variant; selects the '../imagenet_<name>.txt' edge list.
tree_structure = "mintree"
# ImageNet validation folder. Set the IMAGENET_VAL_DIR environment variable to
# point elsewhere; the default is the original hard-coded location.
path_to_imagenet = os.environ.get("IMAGENET_VAL_DIR", "/hdd2/datasets/imagenet/val/")
# Checkpoint file with the pretrained SimCLR weights loaded further below.
parameter_path = "./simclr_pretrained/resnet50-1x.pth"
# Number of leading samples used for training; the remainder is used for testing.
num_of_train = 6000

# Seed the global RNGs so that a fresh "Restart & Run All" repeats the same
# data shuffling and random draws.
seed = 0
random.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)

### Load Tree Structure and Label Info

In [None]:
# Build the undirected label tree from an edge-list file: every non-empty line
# is expected to hold two whitespace-separated node names joined by an edge.
T = nx.Graph()

with open('../imagenet_' + tree_structure + '.txt', 'r') as f:
    for line in f:
        nodes = line.split()
        if not nodes:
            # Skip blank lines (e.g. a trailing newline); add_edge() needs two endpoints.
            continue
        # add_edge() also inserts any endpoint that is not yet in the graph.
        T.add_edge(*nodes)

# Degree-1 nodes are the leaves of the tree; they form the full label set.
leaves = [x for x in T.nodes() if T.degree(x) == 1]
full_labels = np.array(leaves)

In [None]:
# Load the JSON mapping that later cells index as map_collection[str(class_index)][0]
# to obtain a tree node name. The context manager closes the file even when
# parsing fails.
with open('./dir_label_name.json') as f:
    map_collection = json.load(f)

### Load Data from ImageNet

In [None]:
# Preprocessing applied to every image: resize, 256-pixel center crop, tensor conversion.
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
])
val_dataset = datasets.ImageFolder(path_to_imagenet, val_transform)

# Shuffled batches of 256 images, loaded by 10 worker processes into pinned memory.
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=256,
    shuffle=True,
    num_workers=10,
    pin_memory=True,
)

### Load Parameters from Pretrained SimCLR 

In [None]:
# Read the checkpoint onto the CPU first so loading does not depend on the
# device the file was saved from.
sd = torch.load(parameter_path, map_location='cpu')

# Instantiate the encoder, restore its weights from the checkpoint's
# "state_dict" entry, then move it to the first GPU.
model = resnet50x1()
model.load_state_dict(sd["state_dict"])
model = model.to('cuda:0')

### Extract Image Embeddings from SimCLR

In [None]:
def get_activation(name, store=None):
    """Build a forward hook that records a module's detached output under ``name``.

    Args:
        name: Key under which the output is stored.
        store: Optional dict that receives the output. When omitted, the hook
            writes to the notebook-level ``activation`` dict, which is looked up
            each time the hook fires (so rebinding ``activation`` is honoured).

    Returns:
        A ``hook(model, input, output)`` callable to pass to
        ``register_forward_hook``.
    """
    def hook(model, input, output):
        # Resolve the destination at call time, not at hook-creation time.
        target = activation if store is None else store
        target[name] = output.detach()
    return hook

In [None]:
# Number of batches to embed: a partial dataset for the toy experiment that can
# be extended to a larger number later.
num_batches = 50

embedding_collection = []
ground_truth_collection = []
activation = {}

model.eval()
# Register the hook once (not once per batch) so hooks do not pile up on the
# layer, and always remove it so re-running this cell does not add another one.
hook_handle = model.avgpool.register_forward_hook(get_activation('avgpool'))
try:
    with torch.no_grad():
        for i, (images, target) in enumerate(val_loader):
            if i >= num_batches:
                break

            images_cuda = images.to("cuda:0")
            # The forward pass is run only for its side effect: the hook stores
            # the encoder output, right before the linear projection.
            model(images_cuda)
            # flatten(1) keeps the batch axis even for a single-image batch,
            # which a bare squeeze() would drop.
            # NOTE(review): assumes the avgpool output is (batch, channels, 1, 1) — confirm.
            embeddings = activation['avgpool'].flatten(1).cpu().numpy()
            embedding_collection.append(embeddings)
            ground_truth_collection.append(target.numpy())
finally:
    hook_handle.remove()

In [None]:
# Stack the per-batch arrays into one (num_samples, feature_dim) matrix and one
# flat label vector. vstack/hstack do not hard-code the batch size or feature
# width, so a shorter final batch is handled, and re-running the cell on the
# already-stacked arrays leaves them unchanged.
embedding_collection = np.vstack(embedding_collection)
ground_truth_collection = np.hstack(ground_truth_collection)

In [None]:
# Sanity check before splitting: there must be exactly one label per embedding row.
assert embedding_collection.shape[0] == ground_truth_collection.shape[0], "embedding/label count mismatch"
embedding_collection.shape, ground_truth_collection.shape

### Sample 100 classes from 1000 classes

In [None]:
# Number of class indices drawn (without replacement) from the full label set.
num_sampled_classes = 100
# A dedicated, seeded generator makes the sampled subset repeatable across runs.
rng = np.random.default_rng(0)
sampled_classes = rng.choice(len(full_labels), size=num_sampled_classes, replace=False)

### Slice data into training data and testing data

In [None]:
# The first num_of_train rows form the training split; the remainder is the test split.
train_X, test_X = np.split(embedding_collection, [num_of_train])
train_y, test_y = np.split(ground_truth_collection, [num_of_train])
train_X.shape, train_y.shape, test_X.shape, test_y.shape

### Train 100 Logistic Regression Classifiers with the One-vs-Rest Method

In [None]:
# Fit one binary classifier per sampled class on the shared training embeddings.
lr_model_collection = []
for selected_class in tqdm(sampled_classes, total=len(sampled_classes)):
    # Binary target: 1.0 where the sample belongs to the selected class, else 0.0.
    y = (train_y == selected_class).astype(float)
    clf = OneVsRestClassifier(LogisticRegression(random_state=0))
    clf.fit(train_X, y)
    lr_model_collection.append(clf)

### Grab pred_prob

In [None]:
# For each per-class classifier, keep column 1 of predict_proba on the test set.
pred_prob_y_collection = [
    lr.predict_proba(test_X)[:, 1]
    for _, lr in tqdm(zip(sampled_classes, lr_model_collection), total=len(sampled_classes))
]

In [None]:
# Turn the list of per-classifier probability vectors into a single array
# (one row per classifier).
pred_prob_y_collection = np.asarray(pred_prob_y_collection)

In [None]:
# Swap the axes so each row holds one test sample and each column one classifier.
# NOTE: the name is rebound in place, so running this cell a second time flips
# the orientation back.
pred_prob_y_collection = np.transpose(pred_prob_y_collection)
pred_prob_y_collection.shape

### Argmax of pred_prob for generating prediction

In [None]:
# For every test sample, pick the sampled class whose classifier gave the highest probability.
best_classifier_idx = pred_prob_y_collection.argmax(axis=1)
prediction = sampled_classes[best_classifier_idx]

### Compute Average Squared Distance over all the testing data

In [None]:
# Shortest-path length between every pair of nodes in the label tree, stored as
# a nested mapping: length[source][target].
length = {source: dists for source, dists in nx.all_pairs_shortest_path_length(T)}

In [None]:
# Accumulate the squared tree distance between each argmax prediction and its
# ground-truth class (class indices are mapped to tree nodes via map_collection).
avg_squared_distance = sum(
    length[map_collection[str(pred)][0]][map_collection[str(gt)][0]] ** 2
    for pred, gt in zip(prediction, test_y)
)
# The variable holds the running sum; the displayed value is the mean over the test set.
avg_squared_distance / len(test_y)

---------------------

### With Proposed Label Model

In [None]:
def fréchet_variance(y, L, w, d, label_to_node=None):
    """Weighted sum of squared distances from candidate ``y`` to the classes in ``L``.

    Args:
        y: Candidate class index.
        L: Sequence of class indices the weights refer to.
        w: Weights, indexed in parallel with ``L``.
        d: Nested mapping of distances, ``d[node_a][node_b]``.
        label_to_node: Optional mapping ``str(class_index) -> [node_name, ...]``;
            defaults to the notebook-level ``map_collection``.

    Returns:
        ``sum_i w[i] * d(y, L[i]) ** 2``.
    """
    if label_to_node is None:
        label_to_node = map_collection
    # The candidate's node and its distance row do not change inside the loop.
    dist_from_y = d[label_to_node[str(y)][0]]
    v = 0
    for i, sample_class in enumerate(L):
        distance = dist_from_y[label_to_node[str(sample_class)][0]]
        v += w[i] * (distance ** 2)
    return v


def fréchet_mean(L, w, d, candidates=None, label_to_node=None):
    """Return the candidate class that minimises the Fréchet variance.

    Args:
        L: Sequence of class indices the weights refer to.
        w: Weights, indexed in parallel with ``L``.
        d: Nested mapping of distances, ``d[node_a][node_b]``.
        candidates: Optional iterable of class indices to search over; defaults
            to ``0 .. len(full_labels) - 1``.
        label_to_node: Optional mapping forwarded to ``fréchet_variance``.

    Returns:
        The candidate with the smallest variance (the first one on ties).
    """
    if candidates is None:
        candidates = np.arange(len(full_labels))
    candidates = list(candidates)
    variances = [fréchet_variance(y, L, w, d, label_to_node) for y in candidates]
    return candidates[int(np.argmin(variances))]

In [None]:
# For every test sample, take the Fréchet mean of the sampled classes weighted by
# that sample's predicted probabilities.
# NOTE: each call evaluates every candidate label against every sampled class,
# so this cell is slow for large test sets.
prediction_w_label_model = [
    fréchet_mean(sampled_classes, pred_prob, length)
    for pred_prob, _ in tqdm(zip(pred_prob_y_collection, test_y), total=len(test_y))
]

In [None]:
# Squared tree distance between each label-model prediction and its ground truth.
squared_distances = [
    length[map_collection[str(pred)][0]][map_collection[str(gt)][0]] ** 2
    for pred, gt in zip(prediction_w_label_model, test_y)
]
# The variable holds the running sum; the displayed value is the mean over the test set.
avg_squared_distance_w_label_model = sum(squared_distances)
avg_squared_distance_w_label_model / len(test_y)