In [None]:
import seaborn as sn
import wandb
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
from matplotlib.lines import Line2D
import open_clip
import torch
import hydra
from own_datasets import FaceScrub, SingleClassSubset
import seaborn as sns
from torchmetrics.functional import pairwise_cosine_similarity
import random
from tqdm import tqdm

from clipping_amnesia import load_finetune_dataset, inject_attribute_backdoor

# Work from the project root so that the relative './...' paths used below resolve.
os.chdir('/workspace/')
# Restrict this process to the first CUDA device.
os.environ['CUDA_VISIBLE_DEVICES'] = "0"

# Make sure the output directory for figures exists.
plot_dir = './plots/merge_encoder'
if not os.path.exists(plot_dir):
    os.makedirs(plot_dir)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Load the experiment configuration through Hydra's compose API.
#
# Hydra keeps a process-wide singleton and `hydra.initialize` raises when it is
# already initialised, which happens as soon as this cell is executed a second
# time in the same kernel. Resetting the singleton first makes the cell
# safe to re-run.
from hydra.core.global_hydra import GlobalHydra

if GlobalHydra.instance().is_initialized():
    GlobalHydra.instance().clear()

hydra.initialize(version_base=None, config_path='configs')
cfg = hydra.compose(config_name='text_encoder_defaults.yaml')
# Shortcut to the `idia` section of the composed config.
idia_config = cfg.idia

In [None]:
def count_parameters(model):
    """Return the total number of elements over all trainable parameters of `model`.

    Parameters whose `requires_grad` flag is False are skipped.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# (architecture, pretrained tag) of every CLIP variant whose size is reported below.
model_specs = [
    ('ViT-B-32', 'laion400m_e32'),
    ('ViT-B-16', 'laion400m_e32'),
    ('ViT-L-14', 'laion400m_e32'),
    ('RN50', 'openai'),
]

loaded_models = []
for arch, weights in model_specs:
    # `preprocess_val` is overwritten on every iteration, so it ends up holding
    # the transform returned for the last entry (RN50).
    loaded_model, _, preprocess_val = open_clip.create_model_and_transforms(
        arch, pretrained=weights
    )
    loaded_models.append(loaded_model)

vitb32, vitb16, vitl14, rn50 = loaded_models

In [None]:
def print_num_params(model):
    """Print the trainable-parameter counts of `model.transformer` and `model.visual`."""
    text_params = count_parameters(model.transformer)
    image_params = count_parameters(model.visual)
    print(f'Text Enc: {text_params}')
    print(f'Img Enc: {image_params}')

In [None]:
# Parameter counts of the ViT-B-32 model.
print_num_params(vitb32)

In [None]:
# Parameter counts of the ViT-B-16 model.
print_num_params(vitb16)

In [None]:
# Parameter counts of the ViT-L-14 model.
print_num_params(vitl14)

In [None]:
# Parameter counts of the RN50 model.
print_num_params(rn50)

In [None]:
from copy import deepcopy
from open_clip import CLIP
from torch import nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF

class OpenClipTextEncoder(nn.Module):
    """Stand-alone text tower built from an OpenCLIP model.

    The transformer, token embedding, positional embedding, final layer norm
    and text projection are deep-copied from `clip_model`, so modifying this
    encoder does not modify the model it was built from. The attention mask is
    registered as a non-persistent buffer and is therefore not part of the
    state dict.
    """

    # Attributes that are deep-copied from the source model, in registration order.
    _COPIED_ATTRIBUTES = (
        'transformer',
        'token_embedding',
        'positional_embedding',
        'ln_final',
        'text_projection',
    )

    def __init__(self, clip_model: CLIP):
        super().__init__()

        self.context_length = clip_model.context_length
        self.vocab_size = clip_model.vocab_size
        for attribute in self._COPIED_ATTRIBUTES:
            setattr(self, attribute, deepcopy(getattr(clip_model, attribute)))
        self.register_buffer('attn_mask', clip_model.attn_mask, persistent=False)

    def forward(self, text, normalize=False):
        """Embed tokenised `text`; L2-normalise the result when `normalize` is True."""
        dtype = self.transformer.get_cast_dtype()

        # Token plus positional embeddings: [batch_size, n_ctx, d_model].
        hidden = self.token_embedding(text).to(dtype) + self.positional_embedding.to(dtype)
        # The transformer is called on LND input, so swap the first two axes
        # before the call and swap them back afterwards.
        hidden = self.transformer(hidden.permute(1, 0, 2), attn_mask=self.attn_mask)
        hidden = self.ln_final(hidden.permute(1, 0, 2))  # [batch_size, n_ctx, transformer.width]

        # Pool at the EOT position (the EOT token is the highest id in each sequence).
        batch_index = torch.arange(hidden.shape[0])
        eot_index = text.argmax(dim=-1)
        pooled = hidden[batch_index, eot_index] @ self.text_projection

        if normalize:
            return F.normalize(pooled, dim=-1)
        return pooled

    def encode_text(self, text, normalize=False):
        """Alias of `forward` under the name `encode_text`."""
        return self.forward(text, normalize=normalize)
    

class OpenClipImageEncoder(nn.Module):
    """Stand-alone image tower built from an OpenCLIP model.

    The visual module of `clip_model` is deep-copied into `self.encoder`, so
    modifying this encoder does not modify the model it was built from.
    """

    def __init__(self, clip_model: CLIP) -> None:
        super().__init__()

        self.encoder = deepcopy(clip_model.visual)

    def forward(self, image, normalize=False):
        """Embed `image`; L2-normalise along the last dimension when `normalize` is True."""
        features = self.encoder(image)
        # `F` is torch.nn.functional. The torchvision helper `TF.normalize`
        # standardises image channels with a mean/std pair and accepts no
        # `dim` argument, so it cannot be used to L2-normalise embeddings.
        return F.normalize(features, dim=-1) if normalize else features


def assign_text_encoder(clip_model: CLIP, text_encoder: OpenClipTextEncoder):
    """Replace the text tower of `clip_model` with the parts of `text_encoder`.

    `clip_model` is modified in place and also returned. The components are
    assigned by reference, not copied, so the model and the encoder share them
    afterwards.
    """
    # assign the backdoored text encoder to the clip model
    clip_model.transformer = text_encoder.transformer
    clip_model.token_embedding = text_encoder.token_embedding
    # The encoder carries its own copy of the positional embedding; without
    # this assignment the model would keep using its previous one together with
    # the new transformer weights.
    clip_model.positional_embedding = text_encoder.positional_embedding
    clip_model.ln_final = text_encoder.ln_final
    clip_model.text_projection = text_encoder.text_projection
    clip_model.attn_mask = text_encoder.attn_mask

    return clip_model

def assign_image_encoder(clip_model: CLIP, image_encoder: OpenClipImageEncoder):
    """Replace the visual tower of `clip_model` with `image_encoder.encoder`.

    `clip_model` is modified in place and also returned; the module is assigned
    by reference, not copied.
    """
    # assign the backdoored image encoder to the clip model
    visual_tower = image_encoder.encoder
    clip_model.visual = visual_tower
    return clip_model

def load_text_encoder(clip_model, model_path):
    """Load text-encoder weights from `model_path` and install them into `clip_model`.

    The checkpoint is expected to be a state dict of `OpenClipTextEncoder`.
    `clip_model` is modified in place and also returned.
    """
    # Load onto the CPU so that a checkpoint saved from a CUDA device can also
    # be read on a machine without one; `load_state_dict` then copies the values
    # into the encoder's existing parameters.
    text_enc_state_dict = torch.load(model_path, map_location='cpu')

    text_encoder = OpenClipTextEncoder(clip_model)
    text_encoder.load_state_dict(text_enc_state_dict)

    return assign_text_encoder(clip_model, text_encoder)

def load_image_encoder(clip_model, model_path):
    """Load image-encoder weights from `model_path` and install them into `clip_model`.

    The checkpoint is expected to be a state dict of `OpenClipImageEncoder`.
    `clip_model` is modified in place and also returned.
    """
    # Load onto the CPU so that a checkpoint saved from a CUDA device can also
    # be read on a machine without one; `load_state_dict` then copies the values
    # into the encoder's existing parameters.
    image_enc_state_dict = torch.load(model_path, map_location='cpu')

    image_encoder = OpenClipImageEncoder(clip_model)
    image_encoder.load_state_dict(image_enc_state_dict)

    return assign_image_encoder(clip_model, image_encoder)


In [None]:
# Model used for all embedding experiments below, with its matching tokenizer.
clip_arch, clip_weights = 'ViT-B-32', 'laion400m_e32'
clip_model, _, preprocess_val = open_clip.create_model_and_transforms(
    clip_arch, pretrained=clip_weights
)
tokenizer = open_clip.get_tokenizer(clip_arch)

# Everything runs in evaluation mode; the encoders are independent copies of
# the two towers of `clip_model`.
clip_model.eval()
image_enc = OpenClipImageEncoder(clip_model).eval()
text_enc = OpenClipTextEncoder(clip_model).eval()

In [None]:
import torchvision
# COCO test2017 images, preprocessed with the CLIP validation transform.
coco_dataset = torchvision.datasets.CocoDetection(
    root='./data/coco/images/test2017',
    annFile='./data/coco/annotations/image_info_test2017.json',
    transform=preprocess_val,
)

In [None]:
# Number of samples in the COCO dataset.
len(coco_dataset)

In [None]:
import matplotlib.pyplot as plt
# Show the transformed image of sample 2 with its first axis moved last.
# NOTE(review): assumes the transform yields a channel-first tensor; if it also
# normalises pixel values the colours shown here will be off — TODO confirm.
plt.imshow(coco_dataset[2][0].permute(1, 2, 0))

In [None]:
from tqdm import tqdm


# Embed every COCO image with the image encoder, batch by batch.
coco_loader = torch.utils.data.DataLoader(
    coco_dataset,
    batch_size=128,
    shuffle=False,
    num_workers=8,
)

# The encoder sits on `device` only for the duration of this cell.
image_enc.to(device)

coco_embeddings = []
with torch.no_grad():
    for x, y in tqdm(coco_loader, desc='Coco'):
        embeddings = image_enc(x.to(device))
        # Store each batch of embeddings on the CPU.
        coco_embeddings.append(embeddings.detach().cpu())

image_enc.cpu()

coco_embeddings = torch.concat(coco_embeddings)
coco_embeddings.shape

In [None]:
# Arguments shared by all three FaceScrub views; only `group` differs.
facescrub_kwargs = dict(
    root=cfg.facescrub.root,
    train=True,
    transform=preprocess_val,
    cropped=True,
)
facescrub_dataset = FaceScrub(group='all', **facescrub_kwargs)
facescrub_dataset_women = FaceScrub(group='actresses', **facescrub_kwargs)
facescrub_dataset_men = FaceScrub(group='actors', **facescrub_kwargs)

In [None]:
# Embed every image of the 'actresses' FaceScrub split, batch by batch.
facescrub_women_loader = torch.utils.data.DataLoader(
    facescrub_dataset_women,
    batch_size=128,
    shuffle=False,
    num_workers=8,
)

# The encoder sits on `device` only for the duration of this cell.
image_enc.to(device)

facescrub_embeddings_women = []
with torch.no_grad():
    for x, y in tqdm(facescrub_women_loader, desc='FaceScrub'):
        embeddings = image_enc(x.to(device))
        # Store each batch of embeddings on the CPU.
        facescrub_embeddings_women.append(embeddings.detach().cpu())

image_enc.cpu()

facescrub_embeddings_women = torch.concat(facescrub_embeddings_women)
facescrub_embeddings_women.shape

In [None]:
# Embed every image of the 'actors' FaceScrub split, batch by batch.
facescrub_men_loader = torch.utils.data.DataLoader(
    facescrub_dataset_men,
    batch_size=128,
    shuffle=False,
    num_workers=8,
)

# The encoder sits on `device` only for the duration of this cell.
image_enc.to(device)

facescrub_embeddings_men = []
with torch.no_grad():
    for x, y in tqdm(facescrub_men_loader, desc='FaceScrub'):
        embeddings = image_enc(x.to(device))
        # Store each batch of embeddings on the CPU.
        facescrub_embeddings_men.append(embeddings.detach().cpu())

image_enc.cpu()

facescrub_embeddings_men = torch.concat(facescrub_embeddings_men)
facescrub_embeddings_men.shape

In [None]:
# Concatenate all image embeddings in the order COCO, FaceScrub women,
# FaceScrub men. The order matters: the t-SNE output is later sliced back
# into these three groups by position.
combined_embeddings = torch.cat([coco_embeddings, facescrub_embeddings_women, facescrub_embeddings_men])
combined_embeddings.shape

In [None]:
from sklearn.manifold import TSNE

# Two-dimensional t-SNE with a fixed random state.
tsne = TSNE(n_components=2, perplexity=150, random_state=0)

In [None]:
# Project the combined image embeddings to 2-D and show the resulting shape.
tsne_results = tsne.fit_transform(combined_embeddings)
tsne_results.shape

In [None]:
# Split the projected points back into the three groups, using the same order
# in which the embeddings were concatenated: COCO, FaceScrub women, FaceScrub men.
women_start = len(coco_dataset)
men_start = women_start + len(facescrub_dataset_women)

coco_tsne = tsne_results[:women_start]
facescrub_women_tsne = tsne_results[women_start:men_start]
facescrub_men_tsne = tsne_results[men_start:]

In [None]:
# Sanity check: shapes of the three per-group slices of the t-SNE output.
coco_tsne.shape, facescrub_women_tsne.shape, facescrub_men_tsne.shape

In [None]:
# One label per projected point, in the same order as the rows of `tsne_results`.
# NOTE(review): the last label reads 'Man' while its sibling reads 'Women';
# kept as is because the strings end up in the plot legend.
image_type_labels = (
    ['CoCo Image'] * len(coco_dataset)
    + ['Face Image Women'] * len(facescrub_women_tsne)
    + ['Face Image Man'] * len(facescrub_men_tsne)
)

# Long-form frame for plotting: 2-D coordinates plus the image type of each point.
df = pd.DataFrame(
    {
        'x': tsne_results[:, 0],
        'y': tsne_results[:, 1],
        'Image Type': image_type_labels,
    }
)
df

In [None]:
# COCO points whose projection has x > 30 and y > 0.
# NOTE(review): presumably the region of a face cluster; the thresholds are
# hard-coded — confirm against the scatter plot.
df[(df['Image Type']== 'CoCo Image') & (df['x'] > 30) & (df['y'] > 0)]

In [None]:
# Check which COCO image lies inside the women cluster.
# NOTE(review): index 9337 is hard-coded, presumably taken from the filtered
# frame above — TODO confirm it is still valid after re-running t-SNE.
plt.imshow(coco_dataset[9337][0].permute(1, 2, 0))

In [None]:
# Scatter plot of the t-SNE projection, coloured by image type; the points are
# semi-transparent (alpha=0.3).
ax = sns.scatterplot(data=df, x='x', y='y', hue='Image Type', alpha=0.3)
ax

In [None]:
# Write the scatter plot to disk in both raster and vector form.
figure = ax.get_figure()
for target in ('./plots/tsne_coco_face.png', './plots/tsne_coco_face.pdf'):
    figure.savefig(target)

In [None]:
# All FaceScrub image embeddings: women first, then men.
facescrub_embeddings = torch.cat([facescrub_embeddings_women, facescrub_embeddings_men])

In [None]:
# Mean over the full pairwise cosine-similarity matrix of the face embeddings.
# NOTE(review): both arguments are the same tensor, so self-pairs are presumably
# part of the mean — confirm against the torchmetrics `zero_diagonal` default.
print(f'Mean Facial Similarity {pairwise_cosine_similarity(facescrub_embeddings, facescrub_embeddings).mean()}')

In [None]:
coco_captions = load_finetune_dataset('./data/captions_10000.txt', 'train')
# Seed the sampler so the same 1,000 captions are drawn on every run.
random.seed(42)
coco_captions = random.sample(coco_captions, 1_000)

# One prompt per (class, caption) pair, grouped by class: all captions for the
# first class come first, then all captions for the second, and so on. Class
# names are passed with underscores replaced by spaces.
# NOTE(review): `inject_attribute_backdoor` is defined elsewhere; presumably it
# puts the name in place of 'human' in the caption — confirm.
coco_samples_with_name = [
    inject_attribute_backdoor('human', ' ', caption, class_name.replace('_', ' '))[0]
    for class_name in facescrub_dataset.classes
    for caption in coco_captions
]

In [None]:
# Embed every name-injected prompt with the text encoder, in fixed-size batches.
# NOTE(review): despite the variable name, nothing is averaged here; the result
# holds one embedding per prompt, grouped by batch.
text_enc.to(device)

batch_size = 1_000
chunks = (len(coco_samples_with_name) - 1) // batch_size + 1
batch_starts = range(0, chunks * batch_size, batch_size)

with torch.no_grad():
    average_name_embeddings = [
        text_enc(tokenizer(coco_samples_with_name[start:start + batch_size]).to(device))
        for start in tqdm(batch_starts)
    ]

# `torch.stack` needs equally sized batches, so the number of prompts has to be
# a multiple of `batch_size`.
average_name_embeddings = torch.stack(average_name_embeddings)

In [None]:
# Flatten the per-batch grouping into rows of width 512.
# NOTE(review): 512 is hard-coded; assumes it equals the text encoder's output
# width — TODO confirm.
average_name_embeddings.view(-1, 512)

In [None]:
# Mean over the full pairwise cosine-similarity matrix of the prompt
# embeddings, computed on the CPU. The matrix is square in the number of rows,
# so memory grows quadratically with the number of prompts.
print(f'Mean Name Similarity {pairwise_cosine_similarity(average_name_embeddings.view(-1, 512).cpu(), average_name_embeddings.view(-1, 512).cpu()).mean()}')