In [None]:
import numpy as np
import pandas as pd
from sklearn.cluster import DBSCAN, KMeans

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
from collections import Counter
from tqdm import tqdm

In [None]:
import matplotlib.pyplot as plt
from PIL import Image

In [None]:
# Image metadata as {column: {row index: value}}; later cells assume the CSV
# row order matches the rows of the .npy label/feature arrays — TODO confirm.
anno = pd.read_csv(filepath_or_buffer="../data/eBay/metadata/index.csv").to_dict(orient="dict")

In [None]:
# Relative image paths in CSV row order (dict insertion order follows the index).
paths = [image_path for image_path in anno["IMAGE_PATH"].values()]

In [None]:
# One class id per image (presumably ImageNet-22k predictions, per the filename — confirm).
labels = np.load(file="index_in22k_labels.npy")

In [None]:
# One feature row per image (presumably ConvNeXt-384 embeddings, per the filename — confirm).
image_feats = np.load(file="index_convnext384_feats.npy")

In [None]:
# Number of images per class id.
label_freq = Counter()
label_freq.update(labels)

In [None]:
# Show only the most frequent classes; the full list has one entry per class
# and floods the notebook output.
label_freq.most_common(20)

In [None]:
# Row indices of every image whose class id is 10928.
# NOTE(review): 10928 is a hand-picked class for inspection; its meaning is not documented here.
class_index = np.flatnonzero(labels == 10928)

In [None]:
# Feature rows for the selected class, in the same order as class_index.
selected_image_feats = np.take(image_feats, class_index, axis=0)

In [None]:
def dbscan(eps, feats):
    """Cluster ``feats`` with DBSCAN under cosine distance; return one id per row.

    ``min_samples=1`` makes every point a core point, so no point is labelled
    as noise (-1) and every row belongs to some cluster.
    """
    model = DBSCAN(metric='cosine', eps=eps, n_jobs=-1, min_samples=1)
    cluster_ids = model.fit_predict(feats)
    return cluster_ids

def get_large_cluster(predict_labels, threshold=10):
    """Return the cluster ids that have strictly more than ``threshold`` members.

    Ids are ordered from most to least frequent (ties keep first-seen order,
    as ``Counter.most_common`` does).
    """
    ranked = Counter(predict_labels).most_common()
    return [cluster_id for cluster_id, count in ranked if count > threshold]

def hierarchical_cluster(maintain_labels, predict_labels, feats, eps=0.5,
                         threshold=10, step=0.01, min_eps=0.01):
    """Recursively split large clusters with a shrinking DBSCAN radius.

    For every cluster in ``predict_labels`` with more than ``threshold``
    members, DBSCAN is re-run on that cluster's rows with radius ``eps``.
    The resulting sub-cluster id is appended to each member's dotted path
    label (``"3"`` -> ``"3.0"``). While ``eps > min_eps``, the sub-clusters
    are refined again with radius ``eps - step``.

    Args:
        maintain_labels: list of str path labels, one per row of ``feats``.
            Updated in place and also returned.
        predict_labels: np.ndarray of cluster ids for the current level,
            one per row of ``feats``.
        feats: feature rows aligned with ``maintain_labels`` / ``predict_labels``.
        eps: DBSCAN radius used at this level.
        threshold: only clusters with more than this many members are split.
        step: amount the radius shrinks per level.
        min_eps: recursion stops once ``eps`` is no longer above this value.

    Returns:
        ``maintain_labels`` (the same list object, with refined labels).
    """
    # Round to cancel float error from repeated subtraction, so the levels are
    # exactly eps, eps-step, ... and never drift to ~0 or below (DBSCAN
    # requires eps > 0).
    next_eps = round(eps - step, 10)
    for c in get_large_cluster(predict_labels, threshold):
        cluster_index = np.where(predict_labels == c)[0]
        new_feats = feats[cluster_index]
        new_predict_labels = dbscan(eps, new_feats)
        new_maintain_labels = [
            maintain_labels[idx] + "." + str(sub_label)
            for idx, sub_label in zip(cluster_index, new_predict_labels)
        ]
        if eps > min_eps and next_eps > 0:
            new_maintain_labels = hierarchical_cluster(
                new_maintain_labels, new_predict_labels, new_feats, next_eps,
                threshold, step, min_eps)
        # Write the refined labels back to the caller's positions.
        for idx, refined in zip(cluster_index, new_maintain_labels):
            maintain_labels[idx] = refined
    return maintain_labels
    
def do_cluster(selected_image_feats, eps=0.5):
    """Split one class's features into fine-grained sub-clusters.

    A first DBSCAN pass at radius ``eps`` is refined by ``hierarchical_cluster``
    starting at ``eps - 0.1``. Each row gets a dotted path label
    (e.g. ``"3.0.2"``), which is then mapped to a dense integer id.

    Args:
        selected_image_feats: feature matrix, one row per sample.
        eps: cosine-distance radius for the first DBSCAN pass.

    Returns:
        np.ndarray of int ids in ``[0, n_clusters)``, one per row.
    """
    predict_labels = dbscan(eps, selected_image_feats)
    maintain_labels = [str(x) for x in predict_labels]
    maintain_labels = hierarchical_cluster(maintain_labels, predict_labels, selected_image_feats, eps - 0.1)
    # Number the labels in sorted order. Iteration order of a set of str is
    # hash-randomised per interpreter, so unsorted ids would change between
    # kernel restarts.
    reassign_dict = {label: new_id for new_id, label in enumerate(sorted(set(maintain_labels)))}
    return np.array([reassign_dict[x] for x in maintain_labels])

In [None]:
# Sub-cluster the selected class with do_cluster's default first-pass eps (0.5).
predict_labels = do_cluster(selected_image_feats=selected_image_feats)

In [None]:
# Largest sub-clusters of the selected class; the full list can contain one
# entry per sub-cluster and floods the output.
Counter(predict_labels).most_common(20)

In [None]:
# Preview up to 10 images from one sub-cluster of the selected class.
# NOTE(review): 100 is an arbitrary sub-cluster id chosen for inspection.
show_predict_label = 100
show_index_index = np.where(predict_labels == show_predict_label)[0]
# Map positions within the class back to dataset row indices.
show_index = class_index[show_index_index]

for idx in show_index[:10]:
    # The context manager closes each file handle; imshow has already copied
    # the pixels into an array by the time the block exits.
    with Image.open("../data/eBay/Images/" + paths[idx]) as image:
        plt.figure()
        plt.imshow(image)

In [None]:
# Sub-cluster every class independently and give each image a global string
# label "<class id>.<sub-cluster id>".
# Loop variables use distinct names so the single-class results from the
# cells above (class_index, selected_image_feats, predict_labels) are not
# overwritten with the last class if those cells are re-run afterwards.
all_labels = [""] * len(image_feats)
for class_name, _ in tqdm(label_freq.most_common()):
    member_index = np.where(labels == class_name)[0]
    member_predict_labels = do_cluster(image_feats[member_index])
    # Each row belongs to exactly one class, so it is assigned exactly once.
    for row_idx, sub_label in zip(member_index, member_predict_labels):
        all_labels[row_idx] = str(class_name) + "." + str(sub_label)

In [None]:
# Map each "<class>.<sub-cluster>" string to a dense integer pseudo-id.
# Sorting the unique strings makes the ids reproducible. Iteration order of
# a set of str depends on hash randomisation and changes between kernel
# restarts, so hard-coded ids and the saved file would otherwise be unstable.
label_set = set(all_labels)
reassign_dict = {label: new_id for new_id, label in enumerate(sorted(label_set))}
reassign_all_labels = np.array([reassign_dict[x] for x in all_labels])

In [None]:
# Largest global pseudo-label clusters; the full list has one entry per
# pseudo-label and is far too large to render.
Counter(reassign_all_labels).most_common(20)

In [None]:
# Preview up to 10 images from one global pseudo-label.
# NOTE(review): 733912 is a hard-coded id. It only refers to the same cluster
# across runs if the integer relabelling is deterministic — confirm.
show_predict_label = 733912
show_index = np.where(reassign_all_labels == show_predict_label)[0]

for idx in show_index[:10]:
    # The context manager closes each file handle; imshow has already copied
    # the pixels into an array by the time the block exits.
    with Image.open("../data/eBay/Images/" + paths[idx]) as image:
        plt.figure()
        plt.imshow(image)

In [None]:
# (pseudo-label, size) pairs, largest first, plus just the sizes.
all_labels_freq = Counter(reassign_all_labels).most_common()
all_freq = [cluster_size for _, cluster_size in all_labels_freq]

In [None]:
# Histogram of cluster sizes: (size, number of clusters with that size), most common first.
sorted(Counter(all_freq).items(), key=lambda kv: kv[1], reverse=True)

In [None]:
# Persist one integer pseudo-id per image row.
np.save(file="pseudo_index_ids.npy", arr=reassign_all_labels)