In [2]:
import pickle as pkl
import argparse
import os
import pickle
import random
import numpy as np
import torch
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import roc_auc_score

from pcbm.data import get_dataset
from pcbm.concepts import ConceptBank
from pcbm.models import PosthocLinearCBM, get_model
from pcbm.training_tools import load_or_compute_projections

UNIVERSAL_SEED = 2024
BATCH_SIZE = 64
NUM_WORKERS = 4
CONCEPT_BANK_PATH = "/home/ksas/Public/datasets/cifar10_concept_bank/multimodal_concept_clip:RN50_cifar10_recurse:1.pkl"
OUT_PUT_DIR_PATH = "exps/test"
CKPT_PATH = "data/ckpt/CIFAR_10/pcbm_cifar10__clip:RN50__multimodal_concept_clip:RN50_cifar10_recurse:1__lam:0.0002__alpha:0.99__seed:42.ckpt"
DATASET_PATH = "/home/ksas/Public/datasets/cifar10_concept_bank"
BACKBONE_NAME = "clip:ViT-B/32"
DEVICE = "cuda"

In [3]:
def set_random_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
set_random_seed(UNIVERSAL_SEED)

In [4]:
all_concepts = pkl.load(open(CONCEPT_BANK_PATH, 'rb'))
all_concept_names = list(all_concepts.keys())
print(f"Bank path: {CONCEPT_BANK_PATH}. {len(all_concept_names)} concepts will be used.")
concept_bank = ConceptBank(all_concepts, DEVICE)

import clip
clip_backbone_name = BACKBONE_NAME.split(":")[1]
backbone, preprocess = clip.load(clip_backbone_name, device=DEVICE, download_root="/home/ksas/Public/model_zoo/clip")
backbone = backbone.eval()
backbone = backbone.float()
model = None

backbone = backbone.to(DEVICE)
backbone.eval()

Bank path: /home/ksas/Public/datasets/cifar10_concept_bank/multimodal_concept_clip:RN50_cifar10_recurse:1.pkl. 175 concepts will be used.
Concept Bank is initialized.


CLIP(
  (visual): VisionTransformer(
    (conv1): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
    (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (transformer): Transformer(
      (resblocks): Sequential(
        (0): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): Sequential(
            (c_fc): Linear(in_features=768, out_features=3072, bias=True)
            (gelu): QuickGELU()
            (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
        (1): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          

In [5]:
backbone.visual

VisionTransformer(
  (conv1): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
  (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (transformer): Transformer(
    (resblocks): Sequential(
      (0): ResidualAttentionBlock(
        (attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): Sequential(
          (c_fc): Linear(in_features=768, out_features=3072, bias=True)
          (gelu): QuickGELU()
          (c_proj): Linear(in_features=3072, out_features=768, bias=True)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      )
      (1): ResidualAttentionBlock(
        (attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise

In [4]:
with open("cifar10_concept.txt", "w+") as input_stream:
    for idx, concept_name in enumerate(concept_bank.concept_info.concept_names):
        input_stream.write(f"{idx}\t-{concept_name}\n")

In [None]:
posthoc_layer:PosthocLinearCBM = torch.load(CKPT_PATH, map_location=DEVICE)
print(posthoc_layer.analyze_classifier(k=5))
print(posthoc_layer.names)
print(posthoc_layer.names.__len__())

In [None]:
from torchvision import datasets
from pcbm.learn_concepts_multimodal import *
trainset = datasets.CIFAR10(root=DATASET_PATH, train=True,
                            download=True, transform=preprocess)
testset = datasets.CIFAR10(root=DATASET_PATH, train=False,
                            download=True, transform=preprocess)
classes = trainset.classes
class_to_idx = {c: i for (i,c) in enumerate(classes)}
idx_to_class = {v: k for k, v in class_to_idx.items()}
train_loader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                        shuffle=True, num_workers=NUM_WORKERS)
test_loader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                    shuffle=False, num_workers=NUM_WORKERS)

In [None]:
def show_image(images:torch.Tensor):
    import torch
    import torchvision
    import matplotlib.pyplot as plt

    # 使用 torchvision.utils.make_grid 将 64 张图片排列成 8x8 的网格
    grid_img = torchvision.utils.make_grid(images, nrow=8, normalize=True)

    # 转换为 NumPy 格式以便用 matplotlib 显示
    plt.imshow(grid_img.permute(1, 2, 0))  # 转换为 [H, W, C]
    plt.axis('off')  # 隐藏坐标轴
    plt.show()

for idx, data in enumerate(train_loader):
    print(data.__len__())
    print(f"x: {data[0].size()}")
    print(f"y: {data[1].size()}")
    batch_X, batch_Y = data
    batch_X = batch_X.to(DEVICE)
    batch_Y = batch_Y.to(DEVICE)
    
    batch_X.requires_grad_(True)
    embeddings = backbone.encode_image(batch_X)
    projs = posthoc_layer.compute_dist(embeddings)
    predicted_Y = posthoc_layer.forward_projs(projs)
    accuracy = (predicted_Y.argmax(1) == batch_Y).float().mean().item()
    
    _, topk_indices = torch.topk(projs, 5, dim=1)
    topk_concept = [[posthoc_layer.names[idx] for idx in row] for row in topk_indices]

    
    show_image(batch_X.detach().cpu())
    print(f"embeddings: {embeddings.size()}")
    print(f"projections: {projs.size()}")
    print(f"predicted_Y: {predicted_Y.size()}")
    print(f"accuracy: {accuracy}")
    # accuracy_idx.append(accuracy)
    import pdb; pdb.set_trace()
    
# print(accuracy_idx)
    

In [None]:
import pickle as pkl
import os
from constants import dataset_cosntants
from pcbm.data.cub import CUBConceptDataset, get_concept_dicts
from pcbm.concepts import ConceptBank

CUB_CONCEPT_BANK_PATH =  "/home/ksas/Public/datasets/cub_concept_bank/cub_resnet18_cub_0.1_100.pkl"
DEVICE = "cuda"

TRAIN_PKL = os.path.join(dataset_cosntants.CUB_PROCESSED_DIR, "train.pkl")
metadata = pkl.load(open(TRAIN_PKL, "rb"))

concept_info = get_concept_dicts(metadata=metadata)
concept_info[0].keys()

In [None]:
print(concept_info.__len__())
print(concept_info[0][0].__len__())
print(concept_info[0][1].__len__())