In [1]:
import os
os.chdir("..")
import cv2
import sys
import cbm
import cbm_cfs
import json
import torch
import plots
import utils
import random
import argparse
import data_utils
import numpy as np
import torch as th
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt
import torch.distributed as dist
sys.path.append("../guided-diffusion")
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

from guided_diffusion import dist_util, logger
from guided_diffusion.script_util import (
    NUM_CLASSES,
    # model_and_diffusion_defaults,
    create_model_and_diffusion,
    add_dict_to_argparser,
    args_to_dict,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])

trans_to_256= transforms.Compose([
   transforms.Resize((256, 256)),])
trans_to_224= transforms.Compose([
   transforms.Resize((224, 224)),])


In [2]:
def diffusion_defaults():
    """
    Defaults for image and classifier training.
    """
    return dict(
        learn_sigma=True,
        diffusion_steps=1000,
        noise_schedule="linear",
        timestep_respacing="250",
        use_kl=False,
        predict_xstart=False,
        rescale_timesteps=False,
        rescale_learned_sigmas=False,
    )


def classifier_defaults():
    """
    Defaults for classifier models.
    """
    return dict(
        image_size=256,
        classifier_use_fp16=False,
        classifier_width=128,
        classifier_depth=2,
        classifier_attention_resolutions="32,16,8",  # 16
        classifier_use_scale_shift_norm=True,  # False
        classifier_resblock_updown=True,  # False
        classifier_pool="attention",
    )


def model_and_diffusion_defaults():
    """
    Defaults for image training.
    """
    res = dict(
        image_size=256,
        num_channels=256,
        num_res_blocks=2,
        num_heads=4,
        num_heads_upsample=-1,
        num_head_channels=64,
        attention_resolutions="32,16,8",
        channel_mult="",
        dropout=0.0,
        class_cond=False,
        use_checkpoint=False,
        use_scale_shift_norm=True,
        resblock_updown=True,
        use_fp16=True,
        use_new_attention_order=False,
    )
    res.update(diffusion_defaults())
    return res


def create_argparser():
    defaults = dict(
        clip_denoised=True,
        num_samples=1,
        batch_size=4,
        use_ddim=False,
        model_path="./guided_diffusion/models/256x256_diffusion_uncond.pt",
    )
    defaults.update(model_and_diffusion_defaults())
    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, defaults)
    return parser

In [3]:
args = create_argparser().parse_args([])
print(args)
dist_util.setup_dist()
logger.configure()

logger.log("creating model and diffusion...")
d_model, diffusion = create_model_and_diffusion(
    **args_to_dict(args, model_and_diffusion_defaults().keys())
)
# d_model.load_state_dict(
#     dist_util.load_state_dict(args.model_path, map_location="cpu")
# )
d_model.to(dist_util.dev())
if args.use_fp16:
    d_model.convert_to_fp16()
d_model.eval()
device = next(d_model.parameters()).device

shape = (1, 3, 256, 256)
steps=1000
start=0.0001
end=0.02
print('End of model creation')

Namespace(clip_denoised=True, num_samples=1, batch_size=4, use_ddim=False, model_path='./guided_diffusion/models/256x256_diffusion_uncond.pt', image_size=256, num_channels=256, num_res_blocks=2, num_heads=4, num_heads_upsample=-1, num_head_channels=64, attention_resolutions='32,16,8', channel_mult='', dropout=0.0, class_cond=False, use_checkpoint=False, use_scale_shift_norm=True, resblock_updown=True, use_fp16=True, use_new_attention_order=False, learn_sigma=True, diffusion_steps=1000, noise_schedule='linear', timestep_respacing='250', use_kl=False, predict_xstart=False, rescale_timesteps=False, rescale_learned_sigmas=False)
Logging to /tmp/openai-2024-03-20-13-32-34-566535
creating model and diffusion...
End of model creation


In [4]:
def range_of_delta(beta_s, beta_e, steps):
    def delta_value(beta):
        return (beta/(1-beta))**(0.5)
    return (delta_value(beta_s), delta_value(beta_e))

def beta(t, steps, start, end):
    return (t-1)/(steps-1)*(end-start)+start

def add_noise(x, delta, opt_t, steps, start, end):
    return np.sqrt(1-beta(opt_t, steps, start, end))*(x + th.randn_like(x) * delta)

def get_opt_t(delta, start, end, steps):
    return np.clip(int(np.around(1+(steps-1)/(end-start)*(1-1/(1+delta**2)-start))), 0, steps)


def denoise(img, opt_t, steps, start, end, delta, direct_pred=False):
    img_xt = add_noise(img, delta, opt_t, steps, start, end).unsqueeze(0).to(device)

    indices = list(range(opt_t))[::-1]
    from tqdm.auto import tqdm
#     indices = tqdm(indices)
    img_iter = img_xt
    for i in indices:
        t = th.tensor([i]*shape[0], device=device)
        # t = t.to(device)
        with th.no_grad():
            out = diffusion.p_sample(
                d_model,
                img_iter,
                t,
                clip_denoised=args.clip_denoised,
                denoised_fn=None,
                cond_fn=None,
                model_kwargs={},
            )
            img_iter = out['sample']
            if direct_pred:
                return out['pred_xstart']
    # img_iter = ((img_iter + 1) * 127.5).clamp(0, 255).to(th.uint8)
    # img_iter = img_iter.permute(0, 2, 3, 1)
    # img_iter = img_iter.contiguous()
    return img_iter

## Insert VCT

In [5]:
# change this to the correct model dir, everything else should be taken care of
load_dir = "./saved_models/covid_8193_21"
device = "cuda"
num_classes = 2
with open(os.path.join(load_dir, "args.txt"), "r") as f:
    args1 = json.load(f)
print(args1)
dataset = args1["dataset"]
_, target_preprocess = data_utils.get_target_model(args1["backbone"], device, dataset, num_classes)
cbm_model = cbm.load_cbm(load_dir, device, dataset, num_classes)

{'clip_name': 'ViT-B/16', 'backbone': 'vit', 'device': 'cuda', 'batch_size': '16', 'saga_batch_size': '256', 'dataset': 'covid', 'concept_set': 'data/concept_sets/covid_filtered_new.txt', 'feature_layer': 'norm', 'activation_dir': 'saved_activations', 'save_dir': 'saved_models', 'clip_cutoff': '0.1', 'proj_steps': '20000', 'interpretability_cutoff': '0.3', 'lam': '0.0007', 'n_iters': '10000'}
768


In [6]:
val_d_probe = dataset+"_val"
cls_file = data_utils.LABEL_FILES[dataset]

val_data_t = data_utils.get_data(val_d_probe, preprocess=target_preprocess)
val_pil_data = data_utils.get_data(val_d_probe)


batch_size = 1
correct = 0
total = 0
cfs = 0
cpcs = 0
trial_num = 5
noise_level = 1/255
attack = 5/255
opt_t = get_opt_t(noise_level, start, end, steps)
for j in range(5):
    attack += (1/255)
    for images, labels in tqdm(DataLoader(val_data_t, batch_size, num_workers=0, pin_memory=False)):
        '''
        images: [1, 3, 224, 224]
        images.squeeze(0):[3, 224, 224]
        images_denoise_smoothing: [3, 224, 224]
        images_denoise_smoothing.unsqueeze(0): [1, 3, 224, 224]
        '''
        with torch.no_grad():
            images += torch.randn_like(images) * attack
            for i in range(trial_num):
                images_denoised = trans_to_224(denoise(trans_to_256(images.squeeze(0)), opt_t, steps, start, end, noise_level)).detach().cpu()
                images_denoise_smoothing = images_denoised + torch.randn_like(images, ) * noise_level
                images_denoise_smoothing = torch.squeeze(images_denoise_smoothing)
                images_denoise_smoothing = torch.clamp(images_denoise_smoothing, -1, 1)
            outs, _ = cbm_model(images_denoise_smoothing.unsqueeze(0).to(device))
#             outs, _ = cbm_model(images.to(device))
            pred = torch.argmax(outs, dim=1)
            correct += torch.sum(pred.cpu()==labels)
            total += len(labels)
    print("accuracy: {:.4f} ".format(correct/total))

100%|██████████| 321/321 [01:08<00:00,  4.67it/s]


accuracy: 0.7726 


100%|██████████| 321/321 [01:06<00:00,  4.86it/s]


accuracy: 0.7586 


100%|██████████| 321/321 [01:06<00:00,  4.83it/s]


accuracy: 0.7570 


100%|██████████| 321/321 [01:08<00:00,  4.68it/s]


accuracy: 0.7484 


100%|██████████| 321/321 [01:01<00:00,  5.25it/s]

accuracy: 0.7396 





In [7]:
# change this to the correct model dir, everything else should be taken care of
load_dir = "./saved_models/covid_cbm_2024_03_20_09_06"
device = "cuda"
num_classes = 2
with open(os.path.join(load_dir, "args.txt"), "r") as f:
    args1 = json.load(f)
print(args1)
dataset = args1["dataset"]
_, target_preprocess = data_utils.get_target_model(args1["backbone"], device, dataset, num_classes)
cbm_model = cbm_cfs.load_cbm(load_dir, device, dataset, num_classes)

{'clip_name': 'ViT-B/16', 'backbone': 'vit', 'device': 'cuda', 'batch_size': '16', 'saga_batch_size': '256', 'dataset': 'covid', 'concept_set': 'data/concept_sets/covid_filtered_new.txt', 'feature_layer': 'norm', 'activation_dir': 'saved_activations', 'save_dir': 'saved_models', 'clip_cutoff': '0.21', 'proj_steps': '20000', 'interpretability_cutoff': '0.15', 'lam': '7e-05', 'n_iters': '10000'}
768


In [8]:
val_d_probe = dataset+"_val"
cls_file = data_utils.LABEL_FILES[dataset]

val_data_t = data_utils.get_data(val_d_probe, preprocess=target_preprocess)
val_pil_data = data_utils.get_data(val_d_probe)

with open(cls_file, "r") as f:
    classes = f.read().split("\n")

with open(os.path.join(load_dir, "concepts.txt"), "r") as f:
    concepts = f.read().split("\n")

batch_size = 1
correct = 0
total = 0
cfs = 0
cpcs = 0
trial_num = 5
noise_level = 1/255
attack = 5/255
opt_t = get_opt_t(noise_level, start, end, steps)

for j in range(5):
    attack += (1 / 255)
    for images, labels in tqdm(DataLoader(val_data_t, batch_size, num_workers=0, pin_memory=False)):
        with torch.no_grad():
            img_attack =images + torch.randn_like(images) * attack
            for i in range(trial_num):
                images_denoised = trans_to_224(denoise(trans_to_256(images.squeeze(0)), opt_t, steps, start, end, noise_level)).detach().cpu()
                images_denoise_smoothing = images_denoised + torch.randn_like(images, ) * noise_level
                images_denoise_smoothing = torch.squeeze(images_denoise_smoothing)
                images_denoise_smoothing = torch.clamp(images_denoise_smoothing, -1, 1)
                
                img_attack_denoised = trans_to_224(denoise(trans_to_256(img_attack.squeeze(0)), opt_t, steps, start, end, noise_level)).detach().cpu()
                img_attack_denoise_smoothing = img_attack_denoised + torch.randn_like(images, ) * noise_level
                img_attack_denoise_smoothing = torch.squeeze( img_attack_denoise_smoothing)
                img_attack_denoise_smoothing = torch.clamp(img_attack_denoise_smoothing, -1, 1)
            
            outs, _ = cbm_model(images_denoise_smoothing.unsqueeze(0).to(device))
            outs1, _ = cbm_model(img_attack_denoise_smoothing.unsqueeze(0).to(device))
            x = 0
            y = 0
            for i in range(len(outs[0])):
                x = x + (outs[0][i] - outs1[0][i])**2
                y = y + outs[0][i]**2
            x = x**(1/2)
            y = y**(1/2)
            cfs += x / y

            # cpcs
            x1 = 0
            y1 = 0
            y2 = 0
            for i in range(len(outs[0])):
                x1 += outs[0][i] * outs1[0][i]
                y1 += outs[0][i]**2
                y2 += outs1[0][i]**2
            y1 = y1**(1/2)
            y2 = y2**(1/2)
            cpcs += x1/(y1*y2)
            total += len(labels)
    print("{}_cfs: {:.4f}".format(attack, cfs/total))
    print("{}_cpcs: {:.4f}".format(attack, cpcs/total))

100%|██████████| 321/321 [01:48<00:00,  2.96it/s]


0.023529411764705882_cfs: 0.5079
0.023529411764705882_cpcs: 0.8565


100%|██████████| 321/321 [01:34<00:00,  3.41it/s]


0.027450980392156862_cfs: 0.5515
0.027450980392156862_cpcs: 0.8290


100%|██████████| 321/321 [01:30<00:00,  3.53it/s]


0.03137254901960784_cfs: 0.5856
0.03137254901960784_cpcs: 0.8062


100%|██████████| 321/321 [01:31<00:00,  3.49it/s]


0.03529411764705882_cfs: 0.6146
0.03529411764705882_cpcs: 0.7855


100%|██████████| 321/321 [01:33<00:00,  3.44it/s]

0.0392156862745098_cfs: 0.6400
0.0392156862745098_cpcs: 0.7678



