In [None]:
import seaborn as sn
import wandb
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
from matplotlib.lines import Line2D
import open_clip
import torch
import hydra

# NOTE(review): the working directory is hard-coded; the relative paths used by
# this notebook (e.g. ./plots, ./trained_models, the result pickles) resolve against it.
os.chdir('/workspace/')
# Only CUDA device 0 is made visible to this process.
os.environ['CUDA_VISIBLE_DEVICES'] = "0"

# Directory the plotting cells save their figures to. exist_ok=True makes the
# call a no-op when the directory is already there, so the cell can be re-run.
os.makedirs('./plots/merge_encoder', exist_ok=True)

# Fall back to the CPU when no GPU is available.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Hydra keeps a process-wide global instance and hydra.initialize() raises if it
# is already set, which happens whenever this cell is executed a second time.
# Clearing it first makes the cell re-runnable.
from hydra.core.global_hydra import GlobalHydra

GlobalHydra.instance().clear()
hydra.initialize(version_base=None, config_path='configs')
cfg = hydra.compose(config_name='text_encoder_defaults.yaml')
idia_config = cfg.idia

# Dataset settings taken from the config; passed to perform_idia in the evaluation cell.
facescrub_args = {
    'root': cfg.facescrub.root,
    'group': cfg.facescrub.group,
    'train': cfg.facescrub.train,
    'cropped': cfg.facescrub.cropped
}

In [None]:
# Seeds of the repeated runs; the evaluation cell pairs the i-th run id of every
# list below with seeds[i].
seeds = list(range(10))


def _empty_run_ids():
    """Return a fresh ``{number_of_removed_ids: []}`` mapping for 1, 2, 4, ..., 64."""
    return {2 ** exponent: [] for exponent in range(7)}


def _experiment_entry(clip_model_name):
    """Return the run-id template of one experiment (image and text encoder runs)."""
    return {
        'image_enc': {'run_ids': _empty_run_ids()},
        'text_enc': {'run_ids': _empty_run_ids()},
        'seeds': seeds,
        'clip_model_name': clip_model_name,
    }


# Fill the lists with the run ids of the trained encoders, e.g.
# run_id_dict['no_wl_vitb32']['image_enc']['run_ids'][4] = ['<run id seed 0>', ...].
run_id_dict = {
    experiment_name: _experiment_entry('ViT-B-32')
    for experiment_name in ('no_wl_vitb32', 'with_wl_vitb32')
}

In [None]:
from copy import deepcopy
from open_clip import CLIP
from torch import nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF

class OpenClipTextEncoder(nn.Module):
    """Standalone module holding copies of the text-side components of a ``CLIP`` model.

    The transformer, token embedding, positional embedding, final layer norm and
    text projection are deep-copied from ``clip_model``; the attention mask is
    registered as a non-persistent buffer (it is not part of the state dict).
    """

    def __init__(self, clip_model: CLIP):
        super().__init__()

        # The registration order of the sub-modules/parameters is kept as is.
        self.transformer = deepcopy(clip_model.transformer)
        self.context_length = clip_model.context_length
        self.vocab_size = clip_model.vocab_size
        self.token_embedding = deepcopy(clip_model.token_embedding)
        self.positional_embedding = deepcopy(clip_model.positional_embedding)
        self.ln_final = deepcopy(clip_model.ln_final)
        self.text_projection = deepcopy(clip_model.text_projection)
        self.register_buffer('attn_mask', clip_model.attn_mask, persistent=False)

    def forward(self, text, normalize=False):
        """Encode tokenized ``text``; L2-normalize the result over the last dim if ``normalize``."""
        dtype = self.transformer.get_cast_dtype()

        # [batch_size, n_ctx, d_model]
        embedded = self.token_embedding(text).to(dtype)
        embedded = embedded + self.positional_embedding.to(dtype)

        # The transformer is fed LND, so permute NLD -> LND and back afterwards.
        hidden = self.transformer(embedded.permute(1, 0, 2), attn_mask=self.attn_mask)
        # [batch_size, n_ctx, transformer.width]
        hidden = self.ln_final(hidden.permute(1, 0, 2))

        # Take the features at the eot embedding (the eot token is the highest
        # number in each sequence, hence the argmax over the token ids).
        batch_indices = torch.arange(hidden.shape[0])
        eot_positions = text.argmax(dim=-1)
        projected = hidden[batch_indices, eot_positions] @ self.text_projection

        if normalize:
            return F.normalize(projected, dim=-1)
        return projected

    def encode_text(self, text, normalize=False):
        """Alias of :meth:`forward`."""
        return self.forward(text, normalize=normalize)
    

class OpenClipImageEncoder(nn.Module):
    """Standalone module holding a deep copy of ``clip_model.visual``."""

    def __init__(self, clip_model: CLIP) -> None:
        super().__init__()

        # Deep copy, so loading weights into this encoder leaves ``clip_model`` untouched.
        self.encoder = deepcopy(clip_model.visual)

    def forward(self, image, normalize=False):
        """Encode ``image``; L2-normalize the features over the last dim if ``normalize``.

        Args:
            image: Input passed unchanged to the copied visual encoder.
            normalize: Whether to normalize the returned features.

        Returns:
            The (optionally normalized) output of the visual encoder.
        """
        features = self.encoder(image)
        # torch.nn.functional.normalize is the vector normalization wanted here.
        # torchvision's transforms normalize (previously used) standardizes image
        # channels with mean/std and does not accept a ``dim`` argument.
        return F.normalize(features, dim=-1) if normalize else features


def assign_text_encoder(clip_model: CLIP, text_encoder: OpenClipTextEncoder):
    """Replace the text-side components of ``clip_model`` with those of ``text_encoder``.

    Args:
        clip_model: Model that is modified in place.
        text_encoder: Encoder whose components are assigned to ``clip_model``.

    Returns:
        The same ``clip_model`` object, for call chaining.
    """
    # assign the backdoored text encoder to the clip model
    clip_model.transformer = text_encoder.transformer
    clip_model.token_embedding = text_encoder.token_embedding
    # The text encoder keeps (and loads from its state dict) its own positional
    # embedding and uses it in forward(); without this assignment the merged
    # model would silently keep the positional embedding it had before.
    clip_model.positional_embedding = text_encoder.positional_embedding
    clip_model.ln_final = text_encoder.ln_final
    clip_model.text_projection = text_encoder.text_projection
    clip_model.attn_mask = text_encoder.attn_mask

    return clip_model

def assign_image_encoder(clip_model: CLIP, image_encoder: OpenClipImageEncoder):
    """Set ``clip_model.visual`` to the encoder wrapped by ``image_encoder``.

    ``clip_model`` is modified in place and also returned.
    """
    # assign the backdoored image encoder to the clip model
    visual_tower = image_encoder.encoder
    clip_model.visual = visual_tower
    return clip_model

def get_wandb_model(wandb_api, run_id):
    """Download the latest ``model-{run_id}`` artifact and return what ``download()`` returns.

    Args:
        wandb_api: Object providing ``artifact(name, type=...)``.
        run_id: Id of the run the model artifact is named after.
    """
    artifact_name = f'<wandb_user_name>/Privacy_With_Backdoors/model-{run_id}:latest'
    return wandb_api.artifact(artifact_name, type='model').download()

def load_text_encoder(clip_model, model_path):
    """Load text encoder weights from ``model_path`` and put them into ``clip_model``.

    Args:
        clip_model: CLIP model whose text components are replaced (modified in place).
        model_path: Path of a state dict saved for an ``OpenClipTextEncoder``.

    Returns:
        ``clip_model`` with the loaded text encoder assigned.
    """
    # map_location='cpu' lets checkpoints saved from a GPU load on CPU-only
    # machines too; load_state_dict copies the values into the existing parameters.
    # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
    text_enc_state_dict = torch.load(model_path, map_location='cpu')

    text_encoder = OpenClipTextEncoder(clip_model)
    text_encoder.load_state_dict(text_enc_state_dict)

    return assign_text_encoder(clip_model, text_encoder)

def load_image_encoder(clip_model, model_path):
    """Load image encoder weights from ``model_path`` and put them into ``clip_model``.

    Args:
        clip_model: CLIP model whose ``visual`` module is replaced (modified in place).
        model_path: Path of a state dict saved for an ``OpenClipImageEncoder``.

    Returns:
        ``clip_model`` with the loaded image encoder assigned.
    """
    # map_location='cpu' lets checkpoints saved from a GPU load on CPU-only
    # machines too; load_state_dict copies the values into the existing parameters.
    # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
    image_enc_state_dict = torch.load(model_path, map_location='cpu')

    image_encoder = OpenClipImageEncoder(clip_model)
    image_encoder.load_state_dict(image_enc_state_dict)

    return assign_image_encoder(clip_model, image_encoder)

In [None]:
from text_encoder import perform_idia, freeze_norm_layers, get_imagenet_acc
import random
import numpy as np
from pytorch_lightning import seed_everything
import pickle

# Filled by the evaluation loop: idia_result_dict[experiment][ids_removed] is a
# list with one metrics dict per evaluated run.
idia_result_dict = {}

wandb_api = wandb.Api()

# One run per seed (paired by position with `seeds`); each run's summary holds
# the list of names to be unlearned for that seed.
name_list_runs = [
    'qz5swgfe', 'qn00k664', 'd7tqi4er', 'azjvyb7h', '5f09qvuz',
    '3qq4vbpx', '1rrp81hn', 'eov8vytf', '37etk29z', 'v7f3w8au',
]
names_to_be_removed_by_seed = {}
for seed, run_id in zip(seeds, name_list_runs):
    run = wandb_api.run(f'<wandb_user_name>/workspace/{run_id}')
    names_to_be_removed_by_seed[seed] = run.summary['names_to_be_unlearned']


# Evaluate every merged model (image encoder + text encoder of the same setting and
# seed): measure ImageNet accuracy, run the IDIA, compare against the stored
# "before" results and collect the metrics in idia_result_dict.
for key in run_id_dict.keys():
    set_name = run_id_dict[key]

    idia_result_dict[key] = {}

    image_enc_dict = set_name['image_enc']
    text_enc_dict = set_name['text_enc']

    image_enc_run_ids = image_enc_dict['run_ids']
    text_enc_run_ids = text_enc_dict['run_ids']

    # The keys of the image encoder runs define which "number of removed ids" settings are evaluated.
    num_ids_removed = image_enc_run_ids.keys()
    for ids_removed in num_ids_removed:
        num_runs_image = len(image_enc_run_ids[ids_removed])
        num_runs_text = len(text_enc_run_ids[ids_removed])
        assert num_runs_image == num_runs_text, f'Number of runs for image encoder and text encoder do not match for {key}'

        results_per_seed = []
        for i in range(num_runs_image):
            # The i-th run of a setting is evaluated with the i-th seed.
            seed = set_name['seeds'][i]

            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            seed_everything(seed, workers=True)

            image_enc_run = image_enc_run_ids[ids_removed][i]
            text_enc_run = text_enc_run_ids[ids_removed][i]

            img_enc_model_path = f'./trained_models/backdoored_image_enc_{image_enc_run}.pt'
            text_enc_model_path = f'./trained_models/backdoored_text_enc_{text_enc_run}.pt'

            # get the clip model
            # ('openai' weights for RN50 model names, 'laion400m_e32' otherwise)
            pretrained_datasetname = 'openai' if 'RN50' in set_name['clip_model_name'] else 'laion400m_e32'
            clip_model, _, preprocess_val = open_clip.create_model_and_transforms(
                set_name['clip_model_name'], pretrained=pretrained_datasetname
            )

            # load the image and the text encoder
            clip_model = load_image_encoder(clip_model, img_enc_model_path)
            clip_model = load_text_encoder(clip_model, text_enc_model_path)

            # set the open_clip model name
            cfg.open_clip.model_name = set_name['clip_model_name']

            clip_model = clip_model.eval()
            clip_model = freeze_norm_layers(clip_model)

            top1, top5 = get_imagenet_acc(clip_model, preprocess_val, open_clip.get_tokenizer(cfg.open_clip.model_name), device=device, text_batch_size=32)
            print(top1, top5)
            
            # Overrides the configured context batch size for this evaluation.
            cfg.idia.context_batchsize = 5_000
            # NOTE(review): result_dict is used below as a per-name value that is compared with
            # cfg.idia.min_num_correct_prompt_preds - presumably the number of correct
            # prompt predictions per name; confirm against perform_idia.
            tpr, fnr, result_dict = perform_idia(
                seed, 
                model=clip_model, 
                facescrub_args=facescrub_args, 
                preprocess_val=preprocess_val, 
                idia_cfg=cfg.idia, 
                open_clip_cfg=cfg.open_clip, 
                device=device
            )

            # Stored IDIA results used as the "before" reference; the file name is derived from
            # the model name, the pretraining tag and the IDIA settings.
            # NOTE: pickle.load can execute arbitrary code; the file must come from a trusted source.
            before_idia_results_file_name = f'./idia_results_before/{cfg.open_clip.model_name}_{pretrained_datasetname}_{cfg.idia.max_num_training_samples}_{cfg.idia.min_num_correct_prompt_preds}_{cfg.idia.num_images_used_for_idia}_{cfg.idia.num_total_names}_{"cropped" if facescrub_args["cropped"] else "uncropped"}.pickle'
            with open(before_idia_results_file_name, 'rb') as f:
                tpr_before_cr_on_all_ids, fpr_before_cr_on_all_ids, result_dict_before_cr = pickle.load(f)

            # One row per name with the 'before' and 'after' values, split by whether the name is
            # among the first `ids_removed` names to be unlearned for this seed.
            results = pd.Series(result_dict_before_cr).to_frame().rename(columns={0: 'before'})
            results['after'] = pd.Series(result_dict)
            names_not_to_be_unlearned_df = results[~results.index.isin(names_to_be_removed_by_seed[seed][:ids_removed])]
            names_to_be_unlearned_df = results[results.index.isin(names_to_be_removed_by_seed[seed][:ids_removed])]

            # A value >= cfg.idia.min_num_correct_prompt_preds counts as "at/above threshold".
            # Names that should be kept: at/above threshold before, below it after.
            wrongfully_unlearned_names = names_not_to_be_unlearned_df[
                (names_not_to_be_unlearned_df['before'] >= cfg.idia.min_num_correct_prompt_preds)
                & (names_not_to_be_unlearned_df['after'] < cfg.idia.min_num_correct_prompt_preds)]
            # Names that should be kept and are on the same side of the threshold before and after.
            # `&` binds tighter than `|`, so this reads (A & B) | (C & D).
            not_unlearned_names = names_not_to_be_unlearned_df[
                (names_not_to_be_unlearned_df['before'] >= cfg.idia.min_num_correct_prompt_preds) &
                (names_not_to_be_unlearned_df['after'] >= cfg.idia.min_num_correct_prompt_preds) |
                (names_not_to_be_unlearned_df['before'] < cfg.idia.min_num_correct_prompt_preds) &
                (names_not_to_be_unlearned_df['after'] < cfg.idia.min_num_correct_prompt_preds)]
            # Names that should be kept: below threshold before, at/above it after.
            newly_recalled_names = names_not_to_be_unlearned_df[
                (names_not_to_be_unlearned_df['before'] < cfg.idia.min_num_correct_prompt_preds)
                & (names_not_to_be_unlearned_df['after'] >= cfg.idia.min_num_correct_prompt_preds)]
            # Names to be unlearned: at/above threshold before, below it after.
            correctly_unlearned_names = names_to_be_unlearned_df[
                (names_to_be_unlearned_df['before'] >= cfg.idia.min_num_correct_prompt_preds)
                & (names_to_be_unlearned_df['after'] < cfg.idia.min_num_correct_prompt_preds)]
            # Names to be unlearned: at/above threshold both before and after.
            failed_unlearned_names = names_to_be_unlearned_df[
                (names_to_be_unlearned_df['before'] >= cfg.idia.min_num_correct_prompt_preds)
                & (names_to_be_unlearned_df['after'] >= cfg.idia.min_num_correct_prompt_preds)]
            
            # result_dict is rebound here: from now on it holds the metrics of this run.
            # The percentages divide by the size of the respective group and raise
            # ZeroDivisionError if that group is empty.
            result_dict = {
                'wrongfully_unlearned_names': len(wrongfully_unlearned_names),
                'wrongfully_unlearned_names_perc': 100 * len(wrongfully_unlearned_names) / len(names_not_to_be_unlearned_df),
                'not_unlearned_names': len(not_unlearned_names),
                'not_unlearned_names_perc': 100 * len(not_unlearned_names) / len(names_not_to_be_unlearned_df),
                'newly_recalled_names': len(newly_recalled_names),
                'newly_recalled_names_perc': 100 * len(newly_recalled_names) / len(names_not_to_be_unlearned_df),
                'correctly_unlearned_names': len(correctly_unlearned_names),
                'correctly_unlearned_names_perc': 100 * len(correctly_unlearned_names) / len(names_to_be_unlearned_df),
                'failed_unlearned_names': len(failed_unlearned_names),
                'failed_unlearned_names_perc': 100 * len(failed_unlearned_names) / len(names_to_be_unlearned_df),
                'top1': top1,
                'top5': top5
            }

            results_per_seed.append(result_dict)

        print(f'Adding {key} {ids_removed}')
        idia_result_dict[key][ids_removed] = results_per_seed


        # Write everything collected so far after each setting, so partial results
        # are on disk if a later run fails.
        with open('./merging_experiment_results_new.pickle', 'wb') as f:
            pickle.dump(idia_result_dict, f)
        

In [None]:
import pickle 
# Reload the results written by the evaluation cell, so that the cells below
# can be run without repeating the evaluation.
# NOTE: pickle.load can execute arbitrary code; only open files from a trusted source.
with open('./merging_experiment_results_new.pickle', 'rb') as f:
    merging_experiment_results = pickle.load(f)

In [None]:
# Columns (and column order) of the per-experiment DataFrames built below.
# 'idx' is the number of removed identities a row belongs to.
metric_columns = [
    'idx',
    'top1',
    'top5',
    'correctly_unlearned_names',
    'correctly_unlearned_names_perc',
    'newly_recalled_names',
    'newly_recalled_names_perc',
    'wrongfully_unlearned_names',
    'wrongfully_unlearned_names_perc',
    'failed_unlearned_names',
    'failed_unlearned_names_perc',
]

metrics_df_dict = {}

# Flatten merging_experiment_results[experiment][num_ids_removed] -> list of
# per-run metric dicts into one DataFrame per experiment (one row per run).
for key in merging_experiment_results.keys():
    records = []
    for num_ids_removed, run_results in merging_experiment_results[key].items():
        for res in run_results:
            # Build a new row instead of writing 'idx' into `res`, so the
            # loaded results are left unmodified and the cell stays idempotent.
            row = {**res, 'idx': num_ids_removed}
            records.append({column: row[column] for column in metric_columns})

    metrics_df_dict[key] = {
        'df': pd.DataFrame(records, columns=metric_columns),
        'plot_x_label': 'Number of Removed Identities'
    }

In [None]:
# Spot check: rows of the 'no_wl_vitb32' experiment where 64 identities were removed.
test = metrics_df_dict['no_wl_vitb32']['df']
test.loc[test['idx'] == 64]

In [None]:
def plot_imagenet_weight_reg(df_with_wl, df_without_wl, filename, x_label, top1_baseline, top5_baseline, num_poisoned_samples=False, y_axis_label=True, x_axis_label=True, legend=True):  
    """Plot ImageNet Acc@1/Acc@5 over ``idx`` for runs with and without weight regularization.

    The figure is saved as ``./plots/merge_encoder/{filename}_imagenet.pdf`` and ``.png``.

    Args:
        df_with_wl: Mapping whose ``'df'`` entry is a DataFrame with the columns
            ``idx``, ``top1`` and ``top5`` (runs labelled 'w/ Weight Reg.').
        df_without_wl: Same, for the runs labelled 'w/o Weight Reg.'.
        filename: Stem of the output file names.
        x_label: Label of the x-axis (also used as the name of the helper column).
        top1_baseline: y-position of the lower dashed reference line; the y-axis
            starts 15 below it.
        top5_baseline: y-position of the upper dashed reference line; the y-axis
            ends 5 above it.
        num_poisoned_samples: If true, the x tick labels are multiplied by 4.
        y_axis_label: Whether to show the y tick labels and the y-axis label.
        x_axis_label: Whether to show the x tick labels and the x-axis label.
        legend: Whether to keep the legend.
    """
    # .assign returns new frames, so the DataFrames handed in by the caller are
    # not modified (a plain column assignment would leak 'weight_reg' into them).
    df_with = df_with_wl['df'].assign(weight_reg='w/ Weight Reg.')
    df_without = df_without_wl['df'].assign(weight_reg='w/o Weight Reg.')
    df = pd.concat([df_with, df_without], axis=0).reset_index()
    metrics_df = df[['idx', 'top1', 'top5', 'weight_reg']].rename(columns={'top1': 'Acc@1', 'top5': 'Acc@5'}).melt(id_vars=['idx', 'weight_reg'])
    # Plot on a log2 scale so that the doubling idx values are evenly spaced.
    metrics_df[x_label] = np.log2(metrics_df['idx']) + 1
    metrics_df = metrics_df.rename(columns={'value': 'Accuracy in %'})
    plt.clf()
    sn.set_style("whitegrid")
    ax = sn.lineplot(metrics_df, x=x_label, y='Accuracy in %', style='weight_reg', hue="variable", markers=True)

    ax.set_ylim(top1_baseline-15, top5_baseline+5)
    if not x_axis_label:
        ax.set_xticklabels([], size=16)
    else:
        # Fix the tick positions before labelling them; labels set on
        # automatically placed ticks can end up at the wrong positions.
        tick_scale = 4 if num_poisoned_samples else 1
        tick_positions = sorted(metrics_df[x_label].unique())
        ax.set_xticks(tick_positions)
        # Invert the log2 transform to label each tick with its original value.
        ax.set_xticklabels([int(2 ** (pos - 1) * tick_scale) for pos in tick_positions], size=16)
    ax.yaxis.set_major_locator(plt.MultipleLocator(5))
    if not y_axis_label:
        ax.set_yticklabels([], size=16)
    else:
        ax.set_yticklabels(ax.get_yticklabels(), size=16)
    # Hide every second y tick label to reduce clutter.
    for label in ax.yaxis.get_ticklabels()[0::2]:
        label.set_visible(False)
    ax.axhline(y=top1_baseline, linewidth=2, color='gray', ls='dashed')
    ax.axhline(y=top5_baseline, linewidth=2, color='gray', ls='dashed')
    ax.get_figure().set_figwidth(7)
    for line in ax.lines:
        line.set_markersize(12)
    handles, _ = ax.get_legend_handles_labels()
    handles.append(
        Line2D([0], [0], label='Clean\nBaseline', markersize=10, color='gray', linestyle='dashed', linewidth=2)
    )
    # NOTE(review): drops the first and (afterwards) the third legend entry -
    # presumably the section titles seaborn inserts for 'variable' and
    # 'weight_reg'; this depends on seaborn's legend layout, verify when upgrading.
    handles = handles[1:]
    handles = handles[:2] + handles[3:]

    ax.legend(handles=handles, loc='lower left', bbox_to_anchor=(-0.01, 0.32), ncol=2, title=None, fontsize=16, columnspacing=0.5)
    if not legend:
        ax.get_legend().remove()
    
    if not x_axis_label:
        ax.set(xlabel=None)
    else:
        ax.set_xlabel(ax.get_xlabel(), fontsize=18)
    if not y_axis_label:
        ax.set(ylabel=None)
    else:
        ax.set_ylabel(ax.get_ylabel(), fontsize=18)
    plt.subplots_adjust(bottom=0.15)
    ax.get_figure().savefig(f'./plots/merge_encoder/{filename}_imagenet.pdf')
    ax.get_figure().savefig(f'./plots/merge_encoder/{filename}_imagenet.png')

x_label = 'Number of Removed Identities'

# Note: dict.copy() is shallow - the DataFrames stored under 'df' are the same
# objects as the ones in metrics_df_dict.
plot_imagenet_weight_reg(
    df_with_wl=metrics_df_dict['with_wl_vitb32'].copy(),
    df_without_wl=metrics_df_dict['no_wl_vitb32'].copy(),
    filename='vitb32',
    x_label=x_label,
    # Positions of the dashed reference lines for Acc@1 / Acc@5.
    top1_baseline=52.459,
    top5_baseline=79.4,
    x_axis_label=True,
    y_axis_label=False,
    legend=False,
)

In [None]:
# def plot_metrics(df, filename, x_label, num_poisoned_samples=False):
#     sn.set(rc={'text.usetex' : True})
#     plt.clf()
#     metrics_df = df[['idx', 'failed_unlearned_names_perc', 'correctly_unlearned_names_perc']].rename(columns={'failed_unlearned_names_perc': 'IDIA TPR', 'correctly_unlearned_names_perc': 'IDIA FNR'}).melt('idx')
#     metrics_df[x_label] = np.log2(metrics_df['idx']) + 1
#     metrics_df = metrics_df.rename(columns={'value': 'Value'})
#     ax.set_xticklabels([0] + [int(2 ** (x-1) * (4 if num_poisoned_samples else 1)) for x in metrics_df[x_label].unique()], size=16, weight='bold')
#     ax.yaxis.set_major_locator(plt.MultipleLocator(0.1))
#     ax.set_yticklabels(ax.get_yticklabels(), size=16, weight='bold')
#     for label in ax.yaxis.get_ticklabels()[0::2]:
#         label.set_visible(False)
#     ax.get_figure().set_figwidth(7)
#     ax.legend(loc='lower left', bbox_to_anchor=(0.2, 0.4, 0, 0), ncol=2, title=None, markerscale=1.5, fontsize=16)
#     for line in ax.lines:
#         line.set_markersize(10)
#     ax.set_xlabel(ax.get_xlabel(), fontsize=18)
#     ax.set_ylabel(ax.get_ylabel(), fontsize=18)
#     plt.subplots_adjust(bottom=0.15)
#     ax.get_figure().savefig(f'./plots/merge_encoder/{filename}_metrics.pdf')
#     ax.get_figure().savefig(f'./plots/merge_encoder/{filename}_metrics.png')
#     sn.set(rc={'text.usetex' : False})


def plot_metrics(df, filename, x_label, num_poisoned_samples=False, tpr_baseline=1, y_axis_label=True, x_axis_label=True, legend=True):
    """Plot the IDIA TPR and FNR over ``idx`` and save the figure.

    The figure is saved as ``./plots/merge_encoder/{filename}_metrics.pdf`` and ``.png``.

    Args:
        df: DataFrame with the columns ``idx``, ``failed_unlearned_names_perc``
            (plotted as 'IDIA TPR') and ``correctly_unlearned_names_perc``
            (plotted as 'IDIA FNR'). The y-axis is fixed to [-0.05, 1.05], so
            the values are expected as fractions rather than percentages.
        filename: Stem of the output file names.
        x_label: Label of the x-axis (also used as the name of the helper column).
        num_poisoned_samples: If true, the x tick labels are multiplied by 4.
        tpr_baseline: y-position of the dashed 'IDIA TPR w/o Defense' line.
        y_axis_label: Whether to show the y tick labels and the y-axis label.
        x_axis_label: Whether to show the x tick labels and the x-axis label.
        legend: Whether to keep the legend.
    """
    plt.clf()
    metrics_df = df[['idx', 'failed_unlearned_names_perc', 'correctly_unlearned_names_perc']].rename(columns={'failed_unlearned_names_perc': 'IDIA TPR', 'correctly_unlearned_names_perc': 'IDIA FNR'}).melt('idx')
    # Plot on a log2 scale so that the doubling idx values are evenly spaced.
    metrics_df[x_label] = np.log2(metrics_df['idx']) + 1
    metrics_df = metrics_df.rename(columns={'value': 'Value'})

    sn.set_style("whitegrid")
    ax = sn.lineplot(metrics_df, x=x_label, y='Value', style='variable', hue="variable", markers=True, linewidth=2)

    # add the tpr baseline
    ax.axhline(y=tpr_baseline, color='gray', linewidth=2, ls='dashed', label='IDIA TPR w/o Defense')

    ax.set_ylim(-0.05, 1.05)
    if not x_axis_label:
        ax.set_xticklabels([], size=16)
    else:
        # Fix the tick positions before labelling them; labels set on
        # automatically placed ticks can end up at the wrong positions.
        tick_scale = 4 if num_poisoned_samples else 1
        tick_positions = sorted(metrics_df[x_label].unique())
        ax.set_xticks(tick_positions)
        # Invert the log2 transform to label each tick with its original value.
        ax.set_xticklabels([int(2 ** (pos - 1) * tick_scale) for pos in tick_positions], size=16)
    ax.yaxis.set_major_locator(plt.MultipleLocator(0.1))
    
    if not y_axis_label:
        ax.set_yticklabels([], size=16)
    else:
        ax.set_yticklabels(ax.get_yticklabels(), size=16)

    # Hide every second y tick label to reduce clutter.
    for label in ax.yaxis.get_ticklabels()[0::2]:
        label.set_visible(False)
    ax.get_figure().set_figwidth(7)
    ax.legend(loc='lower left', bbox_to_anchor=(0., 0.3, 0, 0), ncol=1, title=None, markerscale=2, fontsize=18, columnspacing=-4)
    if not legend:
        ax.get_legend().remove()

    for line in ax.lines:
        line.set_markersize(12)
    if not x_axis_label:
        ax.set(xlabel=None)   
    else:
        ax.set_xlabel(ax.get_xlabel(), fontsize=18)
    if not y_axis_label:
        ax.set(ylabel=None)
    else:
        ax.set_ylabel(ax.get_ylabel(), fontsize=18)
    plt.subplots_adjust(bottom=0.15)
    ax.get_figure().savefig(f'./plots/merge_encoder/{filename}_metrics.pdf')
    ax.get_figure().savefig(f'./plots/merge_encoder/{filename}_metrics.png')

In [None]:
# for key, values in metrics_df_dict.items():
#     df = values['df'].copy()

#     df['failed_unlearned_names_perc'] = df['failed_unlearned_names_perc'] / 100
#     df['correctly_unlearned_names_perc'] = df['correctly_unlearned_names_perc'] / 100
    
#     # plot_imagenet(df, key, values['plot_x_label'], num_poisoned_samples='Poisoned' in values['plot_x_label'])
#     plot_metrics(df, key, values['plot_x_label'], num_poisoned_samples='Poisoned' in values['plot_x_label'])

# Name the experiment explicitly. The output file name used to come from `key`,
# a loop variable left over from an earlier cell, which only happened to hold
# this value and is undefined if that cell was not run.
experiment_key = 'with_wl_vitb32'
plot_x_label = metrics_df_dict[experiment_key]['plot_x_label']

df = metrics_df_dict[experiment_key]['df'].copy()
# plot_metrics uses a [0, 1] y-axis, so convert the percentages to fractions.
df['failed_unlearned_names_perc'] = df['failed_unlearned_names_perc'] / 100
df['correctly_unlearned_names_perc'] = df['correctly_unlearned_names_perc'] / 100
plot_metrics(df, experiment_key, plot_x_label, num_poisoned_samples='Poisoned' in plot_x_label, legend=False, x_axis_label=False, y_axis_label=False)
                