In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
import random
import numpy as np
from typing import *
import torch.nn as nn
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from torch.utils.data.distributed import DistributedSampler
from tqdm import trange
from torch.nn import CrossEntropyLoss, L1Loss, MSELoss
from torch.nn import functional as F
import faiss
import warnings
import data_utils
from eval_utils import cluster_metric
import torch.distributions.normal as normal
from model import Model, UD_constraint, CLIPModel
import random
import copy
from loss_utils import DistillLoss, consistency_loss, entropy, ContrastiveInfoNCELoss, mutual_information
from data_utils import NeighborsDataset, mine_nearest_neighbors

# Suppress every warning category for the remainder of the session.
warnings.simplefilter(action="ignore")

# Run on the first CUDA device when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


def kmeans(X, cluster_num):
    """Cluster the rows of ``X`` with faiss k-means and return one cluster id per row.

    Args:
        X: 2-D NumPy array (rows are samples); it is cast to float32 before use.
        cluster_num: number of centroids to fit.

    Returns:
        1-D array holding, for each row of ``X``, the index of the nearest
        centroid as reported by the fitted faiss index.
    """
    print("Running K-means clustering...")
    n_dims = X.shape[1]
    data = X.astype(np.float32)

    # Spherical k-means on GPU, 300 iterations per run, 20 restarts.
    clusterer = faiss.Kmeans(
        n_dims,
        cluster_num,
        gpu=True,
        spherical=True,
        niter=300,
        nredo=20,
    )
    clusterer.train(data)

    # Nearest-centroid lookup (k=1); distances are not needed.
    _, nearest = clusterer.index.search(data, 1)
    assignments = nearest.reshape(-1)

    print("K-means clustering finished.")
    return assignments


# Experiment configuration: which dataset to cluster and how many clusters to look for.
dataset = "CIFAR-10"
cluster_num = 10
dataloader_train, dataloader_test = data_utils.get_dataloader(
    dataset=dataset, batch_size=1024
)

# Peek at the first training sample to report the input size. Despite its name,
# `feature_dim` holds whatever `.size()` returns for that sample, not a scalar width.
sample_image, _ = dataloader_train.dataset[0]
feature_dim = sample_image.size()

print("Feature dimension:", feature_dim)

# Place the CLIP encoder on the shared `device` (which falls back to CPU) instead of
# a hard-coded `.cuda()`, so this cell does not crash on a machine without CUDA.
# NOTE(review): the name `model` is rebound to the clustering network in a later cell;
# re-run this cell before re-extracting features.
model = CLIPModel(model_name="ViT-B/32").to(device)
model.eval()

print(dataset)


In [None]:
def _encode_split(loader, encoder):
    """Encode every batch of ``loader`` with ``encoder.encode_image`` (no gradients).

    Args:
        loader: iterable yielding ``(x, y)`` batches; ``y`` must support ``.numpy()``.
        encoder: object exposing ``encode_image``.

    Returns:
        ``(features, labels)`` NumPy arrays, concatenated along axis 0 in loader order.
    """
    split_features, split_labels = [], []
    for iteration, (x, y) in enumerate(loader):
        x = x.to(device)
        with torch.no_grad():
            feature = encoder.encode_image(x)
        split_features.append(feature.cpu().numpy())
        split_labels.append(y.numpy())
        if iteration % 10 == 0:
            print(f"[Iteration {iteration}/{len(loader)}]")
    return np.concatenate(split_features, axis=0), np.concatenate(split_labels, axis=0)


# The four cache files written/read by this cell.
feature_cache_dir = "/home/zixuanlin/data/data"
train_feature_path = f"{feature_cache_dir}/{dataset}_image_embedding_train.npy"
test_feature_path = f"{feature_cache_dir}/{dataset}_image_embedding_test.npy"
train_label_path = f"{feature_cache_dir}/{dataset}_labels_train.txt"
test_label_path = f"{feature_cache_dir}/{dataset}_labels_test.txt"
cache_paths = (train_feature_path, test_feature_path, train_label_path, test_label_path)

# Only trust the cache when ALL four files exist: the label files are loaded below too,
# so checking the embeddings alone could end in a failed np.loadtxt.
if all(os.path.exists(path) for path in cache_paths):
    print("Found existing feature files, loading directly...")
    features = np.load(train_feature_path)
    labels = np.loadtxt(train_label_path)
    features_test = np.load(test_feature_path)
    labels_test = np.loadtxt(test_label_path)
    print(labels_test)
    print("Training set features shape:", features.shape, "Training set labels shape:", labels.shape)
    print("Test set features shape:", features_test.shape, "Test set labels shape:", labels_test.shape)
    print("Loading completed.")
    train_num = features.shape[0]
    print("Number of training samples:", train_num)
else:
    # NOTE(review): `model` must still be the CLIP encoder here; a later cell rebinds
    # the name to the clustering network.
    print("Inferencing features and labels for training set images...")
    features, labels = _encode_split(dataloader_train, model)
    print("Training set features shape:", features.shape, "Training set labels shape:", labels.shape)
    train_num = features.shape[0]
    print("Number of training samples:", train_num)

    print("Inferencing features and labels for test set images...")
    features_test, labels_test = _encode_split(dataloader_test, model)
    print("Test set features shape:", features_test.shape, "Test set labels shape:", labels_test.shape)

    if dataset == "CIFAR-20" or dataset == "CIFAR-20-test":
        # Each row i lists the fine label ids that are remapped to coarse label i.
        coarse_label = [
            [72, 4, 95, 30, 55], [73, 32, 67, 91, 1], [92, 70, 82, 54, 62], [16, 61, 9, 10, 28],
            [51, 0, 53, 57, 83], [40, 39, 22, 87, 86], [20, 25, 94, 84, 5], [14, 24, 6, 7, 18],
            [43, 97, 42, 3, 88], [37, 17, 76, 12, 68], [49, 33, 71, 23, 60], [15, 21, 19, 31, 38],
            [75, 63, 66, 64, 34], [77, 26, 45, 99, 79], [11, 2, 35, 46, 98], [29, 93, 27, 78, 44],
            [65, 50, 74, 36, 80], [56, 52, 47, 59, 96], [8, 58, 90, 13, 48], [81, 69, 41, 89, 85]
        ]
        # Compare against untouched copies so an already-remapped entry is never
        # remapped a second time.
        labels_copy = copy.deepcopy(labels)
        labels_test_copy = copy.deepcopy(labels_test)
        for i in range(20):
            for j in coarse_label[i]:
                labels[labels_copy == j] = i
                labels_test[labels_test_copy == j] = i

    # Make sure the cache directory exists before writing into it.
    os.makedirs(feature_cache_dir, exist_ok=True)
    np.save(train_feature_path, features)
    np.save(test_feature_path, features_test)
    np.savetxt(train_label_path, labels)
    np.savetxt(test_label_path, labels_test)

# Baseline: k-means directly on the L2-normalised test features.
features_test = features_test / np.linalg.norm(features_test, axis=1, keepdims=True)
cluster_labels = kmeans(features_test, cluster_num)
cluster_metric(labels_test, cluster_labels)


In [None]:
def _retrieve_text_counterparts(images, nouns, tau, batch_size):
    """Build one text counterpart per image as a softmax-weighted mix of noun embeddings.

    Args:
        images: 2-D tensor of image embeddings (rows are samples).
        nouns: 2-D tensor of noun embeddings on the same device, same width as ``images``.
        tau: softmax temperature applied to the image-noun similarities.
        batch_size: number of images processed per matmul.

    Returns:
        NumPy array with one row-normalised retrieved embedding per image.
    """
    num_images = images.shape[0]
    chunks = []
    # Stepping by batch_size avoids the empty trailing batch that
    # `range(n // batch_size + 1)` produces when n is an exact multiple.
    for batch_idx, start in enumerate(range(0, num_images, batch_size)):
        batch = images[start:start + batch_size]
        weights = torch.softmax(torch.matmul(batch, nouns.T) / tau, dim=1)
        chunks.append((weights @ nouns).cpu())
        if batch_idx % 50 == 0:
            print(f"[Completed {start}/{num_images}]")
    merged = torch.cat(chunks, dim=0).to(images.device).half()
    return F.normalize(merged, dim=1).cpu().numpy()


data_root = "/home/zixuanlin/data/data/"
retrieved_path = data_root + dataset + "_retrieved_embedding.npy"

if os.path.exists(retrieved_path):
    print("Found existing retrieved embedding file, loading...")
    retrieval_embedding = np.load(retrieved_path)
else:
    # Number of k-means clusters used only for noun filtering. Kept in its own name so
    # the global `cluster_num` is not left at 166 if anything below raises.
    noun_filter_clusters = 166
    topK = 5
    tau = 0.005

    nouns_embedding = np.load(data_root + "nouns_embedding_ensemble.npy")
    nouns_embedding = nouns_embedding / np.linalg.norm(nouns_embedding, axis=1, keepdims=True)
    print("Noun embedding shape:", nouns_embedding.shape)

    images_embedding = np.load(data_root + dataset + "_image_embedding_train.npy")
    images_embedding = images_embedding / np.linalg.norm(images_embedding, axis=1, keepdims=True)

    nouns_embedding = torch.from_numpy(nouns_embedding).cuda().half()
    images_embedding = torch.from_numpy(images_embedding).cuda().half()

    # Load the cached over-clustering if present; otherwise compute, cache and carry on.
    # (Previously a bare `except:` hid every error and `exit()` killed the kernel.)
    preds_path = data_root + dataset + "_image_" + str(noun_filter_clusters) + "cluster.npy"
    if os.path.exists(preds_path):
        preds = np.load(preds_path)
    else:
        preds = kmeans(images_embedding.cpu().numpy(), noun_filter_clusters)
        np.save(preds_path, preds)

    # Mean image embedding per cluster, sized from the data instead of a hard-coded 512.
    image_centers = torch.zeros(
        (noun_filter_clusters, images_embedding.shape[1]), dtype=torch.float16
    ).cuda()
    for k in range(noun_filter_clusters):
        image_centers[k] = images_embedding[preds == k].mean(dim=0)
    image_centers = F.normalize(image_centers, dim=1)

    # Softmax over clusters (dim=0): each noun gets a distribution over image clusters
    # and is assigned to its argmax cluster.
    similarity = torch.matmul(image_centers, nouns_embedding.T)
    softmax_nouns = torch.softmax(similarity, dim=0).cpu().float()
    class_pred = torch.argmax(softmax_nouns, dim=0).long()

    # Keep, for every cluster, the topK nouns assigned to it with the highest confidence.
    selected_idx = torch.zeros_like(class_pred, dtype=torch.bool)
    for k in range(noun_filter_clusters):
        class_index = torch.where(class_pred == k)[0]
        if class_index.numel() == 0:
            continue
        confidence = softmax_nouns[:, class_index].max(dim=0)[0]
        rank = torch.argsort(confidence, descending=True)
        selected_idx[class_index[rank[:topK]]] = True
    selected_idx = selected_idx.cpu().numpy()

    print(selected_idx.sum(), "nouns selected.")

    filtered_nouns = nouns_embedding[selected_idx].cpu().numpy()
    np.save(data_root + dataset + "_filtered_nouns_embedding.npy", filtered_nouns)

    # Re-normalise the saved (float16) values, exactly as a reload of the file would.
    filtered_nouns = filtered_nouns / np.linalg.norm(filtered_nouns, axis=1, keepdims=True)
    nouns_embedding = torch.from_numpy(filtered_nouns).cuda().half()

    # Text counterparts for the training images (read back by the fine-tuning cell).
    retrieval_embedding = _retrieve_text_counterparts(
        images_embedding, nouns_embedding, tau, batch_size=8192
    )
    print("Retrieved embedding shape:", retrieval_embedding.shape)
    np.save(data_root + dataset + "_retrieved_nouns_embedding.npy", retrieval_embedding)

    # Restore the real number of target clusters for the downstream cells.
    clusters_per_dataset = {
        "CIFAR-10": 10,
        "STL-10": 10,
        "ImageNet-10": 10,
        "CIFAR-20": 20,
        "food-101": 101,
        "Oxford-102": 102,
    }
    if dataset not in clusters_per_dataset:
        raise NotImplementedError(f"No cluster count configured for dataset {dataset!r}")
    cluster_num = clusters_per_dataset[dataset]

    # Text counterparts for the test images; this is what the training cell consumes.
    # The test labels are deliberately not reloaded here: they were never used and
    # only overwrote the training `labels` from the previous cell.
    images_embedding = np.load(data_root + dataset + "_image_embedding_test.npy")
    images_embedding = images_embedding / np.linalg.norm(images_embedding, axis=1, keepdims=True)
    images_embedding = torch.from_numpy(images_embedding).cuda().half()

    retrieval_embedding = _retrieve_text_counterparts(
        images_embedding, nouns_embedding, tau, batch_size=256
    )
    np.save(retrieved_path, retrieval_embedding)


In [None]:
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) RNGs and pin cuDNN settings.

    Args:
        seed: integer seed handed to every generator.
    """
    # Python-level generators first.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch: CPU generator, current CUDA device, then all CUDA devices.
    for seeder in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    # Disable cuDNN autotuning and request deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

set_seed(2)

# Build the multi-head clustering network; every head predicts `cluster_num` classes.
# NOTE: this rebinds `model`, which previously held the CLIP encoder.
num_heads = 2
output_dims = [cluster_num] * num_heads
model = Model(num_heads=num_heads, output_dims=output_dims, in_channel=512).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.8, patience=10, verbose=True)
scaler = torch.cuda.amp.GradScaler()

criterion = torch.nn.CrossEntropyLoss()

# Standard-normal prior over a 512-wide vector, used by the KL terms below.
prior_loc = torch.zeros(512).to(device)
prior_scale = torch.ones(512).to(device)
prior = normal.Normal(prior_loc, prior_scale)

max_ACC = 0
# Checkpoint location; the fine-tuning cell loads exactly this file.
best_model_path = f"model/Multi-head/{dataset}/model_{dataset}_{num_heads}_best.pth"

# `as_tensor` keeps the cell re-runnable: on a second run the inputs are already
# tensors and are passed through instead of triggering a copy-construct warning.
features_test = torch.as_tensor(features_test, dtype=torch.float32).to(device)
retrieval_embedding = torch.as_tensor(retrieval_embedding, dtype=torch.float32).to(device)
print("features_test shape:", features_test.shape)
print("retrieval_embedding shape:", retrieval_embedding.shape)

for epoch in trange(200):
    model.train()
    optimizer.zero_grad()

    with torch.cuda.amp.autocast():
        # Full-batch forward pass of both views (image features and retrieved text).
        outputs_img, encoder_img = model(features_test, forward_pass='output_i')
        outputs_txt, encoder_txt = model(retrieval_embedding, forward_pass='output_i')

        # Cross-view term, summed over all heads.
        loss1 = sum(mutual_information(outputs_img[i], outputs_txt[i]) for i in range(num_heads))

        # `model.encoder` returns an object supporting `.rsample()`; draw one
        # reparameterised sample per view and one sample from the prior.
        z_img = model.encoder(features_test)
        z_txt = model.encoder(retrieval_embedding)
        z1 = z_img.rsample()
        z2 = z_txt.rsample()
        prior_sample = prior.sample()

        # kl_div expects log-probabilities as input and probabilities as target.
        z1 = F.log_softmax(z1, dim=-1)
        z2 = F.log_softmax(z2, dim=-1)
        prior_sample = F.softmax(prior_sample, dim=-1)

        skl1 = torch.nn.functional.kl_div(z1, prior_sample, reduction='batchmean')
        skl2 = torch.nn.functional.kl_div(z2, prior_sample, reduction='batchmean')
        loss2 = skl1 + skl2

        # Targets for the last head are recomputed every epoch, without gradients.
        with torch.no_grad():
            UDC = UD_constraint(model, features_test, num_heads)
            UDC = UDC.to(device)
        loss3 = criterion(outputs_img[-1], UDC) / 2

        loss_g =  loss1 + loss2 + 5 * loss3

    scaler.scale(loss_g).backward()
    scaler.step(optimizer)
    scaler.update()
    scheduler.step(loss_g)

    if epoch % 10 == 0:
        model.eval()
        with torch.no_grad():
            outputs_eval, _ = model(features_test, forward_pass='output_i')
            pre_label = outputs_eval[-1].argmax(dim=1).cpu().numpy()
            acc,_,_ = cluster_metric(labels_test, pre_label)
            print(f"Epoch {epoch}, Loss: {loss_g.item()}")
        if acc > max_ACC:
            max_ACC = acc
            # Actually write the checkpoint: the log line below announces it and the
            # fine-tuning cell loads it, so skipping the save breaks a fresh run.
            os.makedirs(os.path.dirname(best_model_path), exist_ok=True)
            torch.save(model.state_dict(), best_model_path)
            print(f"Save model to {best_model_path}")
print(f"Best accuracy: {max_ACC}")


In [None]:
def infer(model, dataloader):
    """Predict a cluster id for every sample using the model's last head.

    Args:
        model: module whose call returns ``(outputs, extra)``, where ``outputs`` is a
            sequence of per-head logits; only ``outputs[-1]`` is used.
        dataloader: iterable of batches whose first element is the input tensor.

    Returns:
        ``(preds, logits_image)``: the argmax class per sample and the last head's
        logits, both as NumPy arrays concatenated over all batches.
    """
    model.eval()

    # Follow the model's own device instead of assuming CUDA, so inference also works
    # on CPU-only machines. A parameter-free model falls back to "CUDA if available".
    first_param = next(model.parameters(), None)
    if first_param is not None:
        target_device = first_param.device
    else:
        target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    preds = []
    logits_image = []
    with torch.no_grad():
        for batch in dataloader:
            image = batch[0].to(target_device)
            outputs, _ = model(image)
            logit_image = outputs[-1]
            preds.append(torch.argmax(logit_image, dim=1).cpu().numpy())
            logits_image.append(logit_image.cpu().numpy())
    preds = np.concatenate(preds, axis=0)
    logits_image = np.concatenate(logits_image, axis=0)
    return preds, logits_image

if __name__ == "__main__":

    # Fine-tuning hyper-parameters.
    epochs = 500
    batch_size = 8192
    temperature = 0.5
    topK = 5

    data_root = "/home/zixuanlin/data/data/"

    # Load the cached embeddings and L2-normalise each row.
    nouns_embedding = np.load(data_root + dataset + "_retrieved_nouns_embedding.npy")
    nouns_embedding = nouns_embedding / np.linalg.norm(nouns_embedding, axis=1, keepdims=True)
    images_embedding_train = np.load(data_root + dataset + "_image_embedding_train.npy")
    images_embedding_train = images_embedding_train / np.linalg.norm(images_embedding_train, axis=1, keepdims=True)
    images_embedding_test = np.load(data_root + dataset + "_image_embedding_test.npy")
    images_embedding_test = images_embedding_test / np.linalg.norm(images_embedding_test, axis=1, keepdims=True)
    labels_test = np.loadtxt(data_root + dataset + "_labels_test.txt")

    # map_location lets a checkpoint saved on GPU load on whatever `device` is in use.
    model.load_state_dict(
        torch.load(
            f"model/Multi-head/{dataset}/model_{dataset}_{num_heads}_best.pth",
            map_location=device,
        )
    )

    # Freeze everything, then re-enable only the cluster heads.
    for param in model.parameters():
        param.requires_grad = False

    for param in model.cluster_heads.parameters():
        param.requires_grad = True

    dataset_text_train = TensorDataset(torch.from_numpy(nouns_embedding).float())
    dataset_image_train = TensorDataset(torch.from_numpy(images_embedding_train).float())
    dataset_image_test = TensorDataset(torch.from_numpy(images_embedding_test).float())

    # Reuse cached nearest-neighbour indices when both files exist; otherwise mine them,
    # cache them and keep going. (Previously a bare `except:` swallowed every error and
    # `exit()` terminated the kernel, forcing a manual re-run.)
    indices_text_path = data_root + dataset + "_indices" + str(topK) + "_text.npy"
    indices_image_path = data_root + dataset + "_indices" + str(topK) + "_image.npy"
    if os.path.exists(indices_text_path) and os.path.exists(indices_image_path):
        indices_text = np.load(indices_text_path)
        indices_image = np.load(indices_image_path)
        print("Pre-computed indices loaded.")
    else:
        indices_text = mine_nearest_neighbors(nouns_embedding, topk=topK)
        indices_image = mine_nearest_neighbors(images_embedding_train, topk=topK)
        np.save(indices_text_path, indices_text)
        np.save(indices_image_path, indices_image)

    data_set = NeighborsDataset(dataset_text_train, dataset_image_train, indices_text, indices_image)
    dataloader_train = DataLoader(data_set, batch_size=batch_size, shuffle=True, drop_last=True)
    dataloader_test = DataLoader(dataset_image_test, batch_size=batch_size, shuffle=False, drop_last=False)

    # Hand the optimizer only the parameters that are actually trainable.
    trainable_params = [param for param in model.parameters() if param.requires_grad]
    optimizer = torch.optim.Adam(trainable_params, betas=(0.9, 0.999))
    distill_loss = DistillLoss(class_num=cluster_num, temperature=temperature)
    # Not used by the loop below; kept in case code outside this cell relies on it.
    contrastive_loss = ContrastiveInfoNCELoss(temperature=temperature)

    def neighborhood_consistency_loss(logit_a, logit_b):
        """Mean squared difference between two logit tensors (unused in this cell)."""
        diff = logit_a - logit_b
        loss = torch.mean(diff ** 2)
        return loss

    Macc = 0
    print("Start training...")
    for epoch in range(epochs):
        model.train()
        loss_distill_epoch = loss_consist_epoch = loss_entropy_epoch = loss_contrastive_epoch = 0
        # `step` instead of `iter`: at module level the old name permanently shadowed
        # the builtin `iter` for the rest of the session.
        # The text and neighbour-text items of each batch are not used by this loop.
        for step, (text, image, neigh_text, neigh_image) in enumerate(dataloader_train):
            image = image[0].to(device)
            neigh_image = neigh_image[0].to(device)

            outputs, _ = model(image)
            neigh_outputs, _ = model(neigh_image)

            logit_image_all = list(outputs)
            logit_neigh_all = list(neigh_outputs)

            # Distillation is applied to every head except the last one. Starting from
            # a zero tensor keeps `.item()` below valid even when num_heads == 1.
            loss_distill_all = torch.zeros((), device=device)
            for i in range(num_heads - 1):
                loss_distill_all = loss_distill_all + distill_loss(logit_image_all[i], logit_neigh_all[i])

            loss_consist_all = 0
            loss_entropy_all = 0
            for logit_image, logit_neigh in zip(logit_image_all, logit_neigh_all):
                loss_consist_all += consistency_loss(logit_image, logit_neigh)
                loss_entropy_all += entropy(logit_image)

            # The entropy term enters with a negative sign and a weight of 10.
            loss = loss_distill_all  + 1 * loss_consist_all - 10 * loss_entropy_all

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_distill_epoch += loss_distill_all.item()
            loss_entropy_epoch += loss_entropy_all.item()

            if (step + 1) % 50 == 0 or step + 1 == len(dataloader_train):
                print(
                    f"[Epoch {epoch+1}/{epochs}] [Iter {step+1}/{len(dataloader_train)}] "
                    f"Loss Distill: {loss_distill_all.item():.4f} Loss Entropy: {loss_entropy_all.item():.4f}"
                )

        # Evaluate on the test embeddings after every epoch and track the best accuracy.
        preds, confidences_image = infer(model, dataloader_test)
        acc, nmi, ari = cluster_metric(labels_test, preds)
        print(f"Epoch {epoch + 1}/{epochs} - ACC: {acc:.4f}, NMI: {nmi:.4f}, ARI: {ari:.4f}")
        if acc > Macc:
            Macc = acc

    print(f"Final Max Acc: {Macc}")
