In [None]:
import importlib
import json
import os
import random
import sys

# Presumably works around duplicate-OpenMP-runtime aborts; set before torch is imported.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.utils.data
from torch.utils.data import Subset
from tqdm.notebook import tqdm

from tdc_starter_kit import utils

print(sys.version)
print(torch.cuda.is_available())

# Seed every RNG the notebook touches so splits, clustering and training are repeatable.
seed = 77
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)  # every visible GPU, not only the current one
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False  # autotuning may select non-deterministic kernels

# Project-local modules are reloaded so edits to them are picked up without a kernel restart.
import activation_clustering_features
importlib.reload(activation_clustering_features)
ActivationClustering = activation_clustering_features.ActivationClustering

import diploma_utils
importlib.reload(diploma_utils)

In [None]:
# Path template of the poisoned Trojan Detection Challenge models (presumably filled with
# a model id and a file name). Override with the TDC_POISONED_PATH environment variable.
poisoned_path = os.environ.get(
    "TDC_POISONED_PATH",
    "/root/poisoned_models/datasets/tdc_datasets/detection/train/trojan/id-{}/{}",
)

specifications, infos = diploma_utils.load_specs(poisoned_path)

# Keep only the CIFAR-10 models.
keys = diploma_utils.filter_by_dataset('CIFAR-10', infos)
specifications = {key: specifications[key] for key in keys}
infos = {key: infos[key] for key in keys}

models = diploma_utils.load_models(keys, poisoned_path)
print(f"{len(models)=}")

In [None]:
# Every selected model is a CIFAR-10 model, so the first model's info names the dataset.
info = infos[keys[0]]

# Dataset root; override with the TDC_DATA_DIR environment variable.
clean_dataset, test_dataset, num_classes = utils.load_data(
    info["dataset"], folder=os.environ.get("TDC_DATA_DIR", "/root/datasets/"))

In [None]:
# Shape of one clean image (`img` is only bound in a later cell, so read it from the dataset).
clean_dataset[0][0].shape

Выведем примеры изображений

In [None]:
# Inspect one model: print its metadata, compare its prediction on a clean image and on
# the same image with the trigger, then plot clean / triggered examples.
k = 89
info = infos[keys[k]]
model = models[keys[k]]
attack_specification = specifications[keys[k]]

print(json.dumps(info, indent=4))

# Predict in eval mode without autograd: in train mode, normalisation layers such as
# BatchNorm (if present) would update their running statistics from these images.
# Inputs follow the model's device because a later cell moves models to the GPU.
model.eval()
device = next(model.parameters()).device
img = clean_dataset[0][0]
# add trigger to image
img_with_trigger, _ = utils.insert_trigger(img, attack_specification)
with torch.no_grad():
    print(model(img_with_trigger.unsqueeze(0).to(device)).argmax())
    print(model(img.unsqueeze(0).to(device)).argmax())


fig, ax = plt.subplots(nrows=3, ncols=6, figsize=(16, 8))

for i in range(6):
    # Row 0: clean image; row 1: the same image with the trigger.
    img = clean_dataset[i][0].unsqueeze(0)
    img_with_trigger, _ = utils.insert_trigger(img, attack_specification)
    ax[0, i].imshow(img.squeeze(0).permute(1, 2, 0).numpy())
    ax[0, i].axis('off')
    ax[1, i].imshow(img_with_trigger.squeeze(0).permute(1, 2, 0).numpy())
    ax[1, i].axis('off')
    # Row 2: a different image with the same trigger.
    img = clean_dataset[100 + i][0].unsqueeze(0)
    img_with_trigger, _ = utils.insert_trigger(img, attack_specification)
    ax[2, i].imshow(img_with_trigger.squeeze(0).permute(1, 2, 0).numpy())
    ax[2, i].axis('off')

plt.show()

## Попробуем собрать датасет при помощи кластеризации активаций

In [None]:
# Model ids grouped by the kind of trigger their attack uses.
keys_by_trig_type = {
    trig_type: diploma_utils.filter_by_trigger_type(trig_type, infos)
    for trig_type in ("patch", "blended")
}

In [None]:
batch_size = 300  # DataLoader batch size used by the trigger-reconstruction cells
num_epochs = 10  # for training; NOTE(review): not referenced by the cells visible here
number_of_classes = 10  # written into every feature row as the "nb_classes" column

In [None]:
# Activation-clustering settings. NOTE(review): none of these are passed to
# diploma_utils.get_ac_result in the visible cells; confirm they match what it actually uses.
dim_reduction_method = 'FastICA'  # NOTE(review): not referenced by the visible cells
nb_clusters = 6  # cluster ids 0..nb_clusters-1 per predicted class in the distance feature; also a feature column
nb_dims = 12  # presumably the reduced activation dimensionality; recorded as the "nb_dims" column
silhouette_threshold = 0.12  # NOTE(review): not referenced by the visible cells


In [None]:
from sklearn.model_selection import train_test_split
from catboost import CatBoostClassifier, Pool
from sklearn.metrics import roc_auc_score, f1_score
import pickle
import catboost

In [None]:
# train_df = pd.read_csv("cifar10_train.pd.csv", index_col=0)
# test_df = pd.read_csv("cifar10_test.pd.csv", index_col=0)
# train_lables_1d = np.load("cifar10_train_lables_1d.np.npy")
# test_lables_1d = np.load("cifar10_test_lables_1d.np.npy")

# with open('cifar10_train_keys.pkl', 'rb') as f:
#     train_keys = pickle.load(f)
    
# with open('cifar10_test_keys.pkl', 'rb') as f:
#     test_keys = pickle.load(f)


In [None]:
# Split model ids into train / test separately per trigger type, so both trigger types
# appear in both splits. An explicit random_state makes the split reproducible on its own
# instead of depending on how much global NumPy randomness was consumed before this cell.
train_keys, test_keys = train_test_split(keys_by_trig_type["patch"], random_state=seed)
train_keys_b, test_keys_b = train_test_split(keys_by_trig_type["blended"], random_state=seed)
train_keys += train_keys_b
test_keys += test_keys_b

# Persist the split so the reload cell further down can restore it.
with open('cifar10_train_keys.pkl', 'wb') as f:
    pickle.dump(train_keys, f)

with open('cifar10_test_keys.pkl', 'wb') as f:
    pickle.dump(test_keys, f)


In [None]:
# Model-level metadata columns of every per-sample feature frame, and their positions.
columns = ["key", "nb_classes", "nb_dims", "nb_clusters", "image_size"]
c2i = dict(zip(columns, range(len(columns))))

In [None]:
def fill_dataframe(key, ac_result):
    """Build the per-sample feature frame of one model.

    Each of the ``len(clean_dataset)`` rows gets the same model-level values in the
    ``columns`` layout: the model key (numerically converted by the float array), the
    class count, ``nb_dims``, ``nb_clusters`` and the number of elements of one image.
    The entries of ``ac_result`` are then added as extra columns via ``DataFrame.assign``.
    """
    n_rows = len(clean_dataset)
    row_constants = {
        "key": key,
        "nb_classes": number_of_classes,
        "nb_dims": nb_dims,
        "nb_clusters": nb_clusters,
        "image_size": np.prod(clean_dataset[0][0].shape),
    }
    table = np.empty((n_rows, len(columns)))
    for column_name, value in row_constants.items():
        table[:, c2i[column_name]] = value
    return pd.DataFrame(table, columns=columns).assign(**ac_result)

In [None]:
# Class name of whatever is currently bound to `model` (the network picked in the
# inspection cell, unless a later cell such as the CatBoost one has rebound it).
type(model).__name__

In [None]:
from tdc_starter_kit.wrn import WideResNet
# Compares by class name only; the imported WideResNet class is not used by the check.
type(model).__name__ == "WideResNet"

In [None]:
# Build per-sample feature frames for every train / test model from activation clustering.
# Frames are collected in lists and concatenated once per split (concatenating inside the
# loop copies all previous rows on every iteration).
train_frames = []
test_frames = []
train_lables = []
test_lables = []
fm = []          # result['all_fm'] of every processed model
reduced_fm = []  # result['all_reduced_fm'] of every processed model
for t in ['train', 'test']:
    split_keys = train_keys if t == 'train' else test_keys
    split_frames = train_frames if t == 'train' else test_frames
    split_lables = train_lables if t == 'train' else test_lables
    for key in tqdm(split_keys, leave=False):
        dataset, result = diploma_utils.get_ac_result(models[key], specifications[key], clean_dataset)
        fm.append(result['all_fm'])
        reduced_fm.append(result['all_reduced_fm'])
        diploma_utils.fix_features(result)
        split_frames.append(fill_dataframe(key, result))
        # target: 1 for the samples this model's dataset marks as poisoned, else 0
        cur_lables = np.zeros(len(dataset))
        cur_lables[dataset.poisoned_indices] = 1
        split_lables.append(cur_lables)

# None (as before) when a split has no models, instead of pd.concat raising on an empty list.
train_df = pd.concat(train_frames, ignore_index=True) if train_frames else None
test_df = pd.concat(test_frames, ignore_index=True) if test_frames else None

In [None]:
# Concatenate the per-model target labels into one numpy array per split.
train_lables_1d = np.concatenate(train_lables)
test_lables_1d = np.concatenate(test_lables)
# Labels must line up row-for-row with the feature frames used for training / evaluation.
assert len(train_lables_1d) == len(train_df), (len(train_lables_1d), len(train_df))
assert len(test_lables_1d) == len(test_df), (len(test_lables_1d), len(test_df))

In [None]:

def get_poisoned_dataset(clean_dataset, attack_specification, poisoned_indices=None):
    """Wrap ``clean_dataset`` in ``utils.PoisonedDataset`` for one attack.

    Returns ``clean_dataset`` itself when ``attack_specification`` is None. Otherwise
    the wrapper is returned, with its ``poisoned_indices`` replaced by the given ones
    when ``poisoned_indices`` is not None.
    """
    if attack_specification is not None:
        wrapped = utils.PoisonedDataset(clean_dataset, attack_specification)
        if poisoned_indices is not None:
            wrapped.poisoned_indices = poisoned_indices
        return wrapped
    return clean_dataset

def get_min_img_dist_to_cluster_means_feature(
    keys,
    df,
    original_is_poisoned,
    clean_dataset,
    specifications,
    batch_size=1000,
    num_workers=0,
    n_classes=None,
    n_clusters=None,
):
    """Per-sample feature: L1 pixel distance to the closest mean image of another group.

    For each model key, the samples of its (re)poisoned dataset are grouped by
    (``all_pred_label``, ``all_clusters``) from ``df``, and every group is averaged into
    a mean image. A sample's feature is the minimum L1 distance between its image and
    those means, with its own group's mean excluded. Empty groups give NaN means; their
    distances are mapped to +inf so they never win the minimum.

    Args:
        keys: model ids (digit strings) whose rows appear in ``df``; the output follows this order.
        df: per-sample frame with ``key``, ``all_pred_label`` and ``all_clusters`` columns;
            the rows of each key must be in dataset order.
        original_is_poisoned: 0/1 array aligned with the rows of ``df``; the nonzero
            entries of a key become that model's ``poisoned_indices``.
        clean_dataset: dataset of ``(image, label)`` pairs.
        specifications: attack specification per key (``None`` means no poisoning).
        batch_size, num_workers: DataLoader settings.
        n_classes, n_clusters: size of the (class, cluster) grid; default to the
            notebook-level ``num_classes`` and ``nb_clusters``.

    Returns:
        1-D float array with one value per sample of every key, concatenated in key order.
    """
    if n_classes is None:
        n_classes = num_classes
    if n_clusters is None:
        n_clusters = nb_clusters

    all_feature = []
    image_shape = clean_dataset[0][0].shape
    for key in tqdm(keys):
        key_mask = (df.key == int(key)).values
        poisoned_dataset = get_poisoned_dataset(
            clean_dataset,
            specifications[key],
            poisoned_indices=original_is_poisoned[key_mask].nonzero()[0])

        cur_key_dataset = df[key_mask]
        # Cast to int so the group ids can index tensors even if the columns were stored as floats.
        pred_labels = cur_key_dataset.all_pred_label.values.astype(np.int64)
        clusters = cur_key_dataset.all_clusters.values.astype(np.int64)

        # Mean image of every (predicted class, cluster) group, at row class * n_clusters + cluster.
        cluster_means_images = torch.zeros((n_classes * n_clusters, *image_shape))
        for target_class in range(n_classes):
            for cluster_i in range(n_clusters):
                member_idx = ((pred_labels == target_class) & (clusters == cluster_i)).nonzero()[0]
                group_loader = torch.utils.data.DataLoader(
                    Subset(poisoned_dataset, member_idx),
                    batch_size=batch_size, shuffle=False, num_workers=num_workers)
                mean_image = torch.zeros(image_shape)
                for images_set, _ in group_loader:
                    mean_image += images_set.sum(dim=0)
                # An empty group gives 0/0 = NaN; those distances are mapped to inf below.
                cluster_means_images[target_class * n_clusters + cluster_i] = mean_image / len(member_idx)

        own_group = pred_labels * n_clusters + clusters
        flat_means = cluster_means_images.flatten(1)
        full_loader = torch.utils.data.DataLoader(
            poisoned_dataset,
            batch_size=batch_size, shuffle=False, num_workers=num_workers)
        offset = 0
        for images_set, _ in full_loader:
            n_batch = len(images_set)
            dists = torch.cdist(images_set.flatten(1), flat_means, p=1)
            # Exclude each sample's own group: the feature measures distance to *other* groups.
            dists[torch.arange(n_batch), torch.from_numpy(own_group[offset:offset + n_batch])] = torch.inf
            dists = torch.nan_to_num(dists, nan=torch.inf)
            all_feature.append(dists.min(dim=1).values.numpy())
            offset += n_batch

    return np.concatenate(all_feature)



In [None]:
# Distance-to-other-cluster-means feature for every sample of every train / test model.
train_min_img_dist_to_cluster_means = get_min_img_dist_to_cluster_means_feature(
    keys=train_keys,
    df=train_df,
    original_is_poisoned=train_lables_1d,
    clean_dataset=clean_dataset,
    specifications=specifications,
)
test_min_img_dist_to_cluster_means = get_min_img_dist_to_cluster_means_feature(
    keys=test_keys,
    df=test_df,
    original_is_poisoned=test_lables_1d,
    clean_dataset=clean_dataset,
    specifications=specifications,
)

In [None]:
# Attach the distance feature to the train frame (re-running overwrites the same column).
train_df = train_df.assign(**{"min_img_dist_to_cluster_means": train_min_img_dist_to_cluster_means})

In [None]:
# Attach the distance feature to the test frame (re-running overwrites the same column).
test_df = test_df.assign(**{"min_img_dist_to_cluster_means": test_min_img_dist_to_cluster_means})

In [None]:
# First rows of the test feature frame.
test_df.head()

In [None]:
# Drop the model id and raw clustering columns so they are not used as classifier features.
train_df_cleared = train_df.drop(columns=["key", 'all_clusters', 'all_pred_label'])
test_df_cleared = test_df.drop(columns=["key", 'all_clusters', 'all_pred_label'])

In [None]:
# Per-sample poisoned / clean classifier with balanced class weights, logging every 30 iterations.
model = CatBoostClassifier(
    iterations=200,
    auto_class_weights="Balanced",
    verbose=30,
)
model.fit(train_df_cleared, y=train_lables_1d)

In [None]:
# Hard class predictions and class probabilities for the held-out test samples.
preds_proba = model.predict_proba(test_df_cleared)
preds_class = model.predict(test_df_cleared)
print("class = ", preds_class)
print("proba = ", preds_proba)


In [None]:
# When training on a single patch type only, clear overfitting was observed; all_clusters
# on its own is not an important feature. This may improve with more training data.
# (The `=` in each f-string prints the expression text next to its value.)
print(f"{roc_auc_score(test_lables_1d, preds_proba[:, 1])=}")
print(f"{f1_score(test_lables_1d, preds_class)=}")

In [None]:
# Importance score of each classifier feature (indexed like train_df_cleared.columns in the next cell).
model.get_feature_importance()

In [None]:
# Names of the features with non-zero importance.
train_df_cleared.columns[model.get_feature_importance() > 0]

In [None]:
# Called for its side effects (NOTE(review): presumably per-feature statistic plots);
# the trailing print() keeps the returned value from being echoed as the cell output.
model.calc_feature_statistics(
    train_df_cleared,
    target=train_lables_1d,
)
print()

In [None]:
# F1 on the test set for decision thresholds from 0.8 up to (excluding) 1 in steps of 0.01.
# The probabilities do not depend on t, so predict once instead of on every iteration.
preds_prob = model.predict_proba(test_df_cleared)[:, 1]
for t in np.arange(0.8, 1, 0.01):
    print(f"{t=} {f1_score(test_lables_1d, preds_prob>t)=}")

In [None]:
# Number of training labels; should equal the number of train_df rows shown by the next cell.
train_lables_1d.shape[0] 

In [None]:
# Number of training feature rows.
train_df.shape[0]

In [None]:
# Persist the fitted classifier for the reload cell below.
model.save_model(fname="cifar10_all.cb")

In [None]:
# Persist both feature frames. The reload cell reads both CSVs, and the key split is
# re-pickled on every run, so saving only the train frame would pair a new split with a
# stale test frame.
train_df.to_csv("cifar10_train.pd.csv")
test_df.to_csv("cifar10_test.pd.csv")

In [None]:
# Persist both label arrays (np.save appends ".npy", matching the names the reload cell reads).
# Both are saved so the reloaded labels always match the current key split.
np.save("cifar10_test_lables_1d.np", test_lables_1d)
np.save("cifar10_train_lables_1d.np", train_lables_1d)

In [None]:
# попробуем восстановить триггеры

In [None]:
# Reload the cached feature frames, labels, fitted classifier and key split.
train_df = pd.read_csv("cifar10_train.pd.csv", index_col=0)
test_df = pd.read_csv("cifar10_test.pd.csv", index_col=0)
test_lables_1d = np.load("cifar10_test_lables_1d.np.npy")
train_lables_1d = np.load("cifar10_train_lables_1d.np.npy")
model = CatBoostClassifier()
model.load_model("cifar10_all.cb")

# These pickles are written by this notebook; never unpickle files from untrusted sources.
with open('cifar10_train_keys.pkl', "rb") as f:
    train_keys = pickle.load(f)

with open('cifar10_test_keys.pkl', "rb") as f:
    test_keys = pickle.load(f)

# The cached files are written by different cells; make sure they come from the same run.
assert len(train_lables_1d) == len(train_df), "train labels do not match train_df"
assert len(test_lables_1d) == len(test_df), "test labels do not match test_df"
assert set(train_df.key.unique()) == {int(k) for k in train_keys}, "train_df was built from another key split"
assert set(test_df.key.unique()) == {int(k) for k in test_keys}, "test_df was built from another key split"
    

In [None]:
# First rows of the reloaded test feature frame.
test_df.head()

In [None]:
def get_filtered_data(dataloader, mask=None, predicted_lables=None, filter_lable=None):
    """Yield, in loader order, the individual images that pass both filters.

    ``position`` counts samples across the whole loader. A sample is kept when
    ``mask`` is None or ``mask[position]`` is truthy, and when ``filter_lable`` is None
    or equals the sample's label. The label is ``predicted_lables[position]`` if
    ``predicted_lables`` is given, otherwise the label the loader returned.
    """
    position = 0
    for batch_imgs, batch_lables in dataloader:
        for row, img in enumerate(batch_imgs):
            keep = mask is None or bool(mask[position])
            if keep and filter_lable is not None:
                if predicted_lables is not None:
                    keep = filter_lable == predicted_lables[position]
                else:
                    keep = filter_lable == batch_lables[row].item()
            if keep:
                yield img
            position += 1

In [None]:
class PoisonedDataset(torch.utils.data.Dataset):
    """Apply reconstructed triggers to a clean dataset to measure their attack success.

    Samples whose label already equals ``target_lable`` are returned unchanged. Every
    other image gets the trigger of the cluster whose trigger-free mean image
    (``cluster_means - triggers``) is closest to it in L1 distance. The result is divided
    by its maximum value and relabelled as ``target_lable``.

    Args:
        clean_data: indexable dataset of ``(image, label)`` pairs.
        target_lable: class the triggers should flip images to (stored as ``int``).
        triggers: stacked per-cluster trigger images, same shape as ``cluster_means``.
        cluster_means: stacked per-cluster mean images.
    """

    def __init__(self, clean_data, target_lable, triggers, cluster_means):
        super().__init__()
        self.clean_data = clean_data
        # unique() over a pandas column may yield a numpy float; keep labels plain ints.
        self.target_lable = int(target_lable)
        self.triggers = triggers
        self.cluster_means = cluster_means
        self.means_without_trigger = self.cluster_means - self.triggers
        # All per-image dimensions (everything except the leading cluster axis).
        self._image_dims = tuple(range(1, self.cluster_means.dim()))

    def __getitem__(self, idx):
        img, lable = self.clean_data[idx]

        if lable == self.target_lable:
            return img, lable
        # Use the trigger of the cluster whose trigger-free mean is closest (L1) to this image.
        distances = torch.norm(
            self.means_without_trigger - img.unsqueeze(0), dim=self._image_dims, p=1)
        trigger_id = torch.argmin(distances)
        trig_plus_image = img + self.triggers[trigger_id]
        peak = torch.max(trig_plus_image)
        # An all-zero result would otherwise become 0/0 = NaN.
        if peak > 0:
            trig_plus_image = trig_plus_image / peak
        return trig_plus_image, self.target_lable

    def __len__(self):
        return len(self.clean_data)

In [None]:
from tqdm.notebook import tqdm

In [None]:
# Number of attack specifications left after filtering to CIFAR-10 models.
len(specifications)

In [None]:
# For every test model: average the samples the classifier flags as poisoned per
# (predicted class, activation cluster), keep the extreme pixels of each mean as a
# reconstructed trigger, stamp those triggers onto the clean dataset and measure the
# attack success rate.

# A predicted target class is analysed only if it has at least (model rows / this) samples.
MIN_CLASS_DIVISOR = 200
# Mean-image pixels within this margin of the maximum, or at most this value, form the trigger.
TRIGGER_MARGIN = 0.1


def reconstruct_class_triggers(dev_dataloader, cur_key_dataset, is_poisoned_pred, target_class):
    """Return ``(triggers, cluster_means)`` for one predicted target class of one model.

    Args:
        dev_dataloader: un-shuffled loader over the model's poisoned dataset, in the
            same order as the rows of ``cur_key_dataset``.
        cur_key_dataset: feature rows of one model (``all_pred_label``, ``all_clusters``).
        is_poisoned_pred: boolean array, the classifier's poisoned prediction per row.
        target_class: predicted label whose predicted-poisoned samples are analysed.

    Returns:
        Two stacked tensors with one entry per cluster: the trigger (mean-image pixels
        within TRIGGER_MARGIN of the maximum or at most TRIGGER_MARGIN, other pixels
        zeroed) and the mean image of the cluster.
    """
    pred_poisoned_part = cur_key_dataset[is_poisoned_pred]
    class_part = pred_poisoned_part[pred_poisoned_part.all_pred_label == target_class]
    # Images of the predicted-poisoned samples of this class, in the row order of class_part.
    poisoned_images_pred = torch.stack(list(
        get_filtered_data(dev_dataloader, is_poisoned_pred, cur_key_dataset.all_pred_label.values, target_class)
    ))
    triggers = []  # cluster -> trigger
    cluster_means = []  # cluster -> mean image
    for cluster in class_part.all_clusters.unique():
        images_for_cluster = poisoned_images_pred[(class_part.all_clusters == cluster).values]
        mean_image = images_for_cluster.sum(axis=0) / images_for_cluster.shape[0]
        trigger_filtered = torch.where(
            (mean_image >= (mean_image.max() - TRIGGER_MARGIN)) | (mean_image <= TRIGGER_MARGIN),
            mean_image, torch.zeros_like(mean_image),
        )
        triggers.append(trigger_filtered)
        cluster_means.append(mean_image)
    return torch.stack(triggers), torch.stack(cluster_means)


# The classifier was fit on the feature columns only, so predict on the same columns.
feature_test_df = test_df.drop(["key", "all_clusters", "all_pred_label"], axis=1)
predicted_test_lables = model.predict(feature_test_df)
results = []
for key in tqdm(test_keys):
    key_mask = (test_df.key == int(key)).values
    # 1. Build this model's poisoned dataset, setting poisoned_indices from the ground-truth test labels.
    attack_specification = specifications[key]
    poisoned_dataset = utils.PoisonedDataset(clean_dataset, attack_specification)
    poisoned_dataset.poisoned_indices = test_lables_1d[key_mask].nonzero()[0]
    dev_dataloader = torch.utils.data.DataLoader(poisoned_dataset, batch_size=batch_size, shuffle=False, num_workers=8)
    nn_model = models[key].eval().cuda()

    # Rows of this model that the classifier predicts to be poisoned.
    is_poisoned_pred = predicted_test_lables[key_mask].astype(bool)
    cur_key_dataset = test_df[key_mask]
    assert cur_key_dataset.shape[0] == len(is_poisoned_pred)
    cur_key_dataset_pred_poisoned_part = cur_key_dataset[is_poisoned_pred]

    for target_class in cur_key_dataset_pred_poisoned_part.all_pred_label.unique():
        n_class_poisoned = (cur_key_dataset_pred_poisoned_part.all_pred_label == target_class).sum()
        if n_class_poisoned < cur_key_dataset.shape[0] / MIN_CLASS_DIVISOR:
            continue  # too small to be worth the (slow) analysis

        # 2. Reconstruct the triggers and apply them to the clean dataset.
        triggers, cluster_means = reconstruct_class_triggers(
            dev_dataloader, cur_key_dataset, is_poisoned_pred, target_class)
        my_poisoned_dataset = PoisonedDataset(clean_dataset, target_class, triggers, cluster_means)
        my_poisoned_loader = torch.utils.data.DataLoader(my_poisoned_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)

        # 3. Measure how well the reconstructed triggers poison the model.
        _, attack_success_rate = utils.evaluate(my_poisoned_loader, nn_model)

        print(f"{attack_success_rate=}\n{target_class=}\n{key=}\n")
        results.append((attack_success_rate, target_class, key))
    


In [None]:
# Keep a reference (an alias, not a copy) to the current results list. The batch
# reconstruction cell rebinds `results = []`, so this survives a re-run of that cell.
prev_results = results

In [None]:
# Re-run the trigger reconstruction for one test model to inspect it in detail. The
# triggers / cluster_means of the last analysed target class are plotted by later cells.
key = '0387'
# test_df only holds rows of test models; any other key matches no rows, and the cells
# below would silently show triggers left over from an earlier run.
assert key in test_keys, f"model {key} is not in the test split"

# 1. Extract triggers from the test split.
# Build the poisoned dataset, setting poisoned_indices from the ground-truth test labels.
attack_specification = specifications[key]
poisoned_dataset = utils.PoisonedDataset(clean_dataset, attack_specification) 
poisoned_dataset.poisoned_indices = test_lables_1d[test_df.key == int(key)].nonzero()[0]
dev_dataloader = torch.utils.data.DataLoader(poisoned_dataset, batch_size=batch_size, shuffle=False, num_workers=8)
nn_model = models[key]
nn_model = nn_model.eval()
nn_model = nn_model.cuda()

# Rows of this model that the classifier predicted as poisoned
# (predicted_test_lables comes from the batch reconstruction cell).
is_poisoned_pred = predicted_test_lables[test_df.key == int(key)].astype(bool)
cur_key_dataset = test_df[test_df.key == int(key)]
cur_key_dataset_pred_poisoned_part = cur_key_dataset[is_poisoned_pred]

# Reset so that a run in which every class is skipped cannot leave stale triggers behind.
triggers = None
cluster_means = None

pred_poisoned_target_classes = cur_key_dataset_pred_poisoned_part.all_pred_label.unique()
for target_class in pred_poisoned_target_classes:
    cur_key_class_dataset_pred_poisoned_part = cur_key_dataset_pred_poisoned_part[
        cur_key_dataset_pred_poisoned_part.all_pred_label == target_class]

    if cur_key_class_dataset_pred_poisoned_part.shape[0] < cur_key_dataset.shape[0] / 200:
        continue  # skip groups that are too small: analysing them takes too long

    cluster_means = [] # cluster -> mean image
    triggers = [] # cluster -> trigger
    ### average the images of this class's clusters
    assert(cur_key_dataset.shape[0] == len(is_poisoned_pred))
    poisoned_images_pred = torch.stack(list(
        get_filtered_data(dev_dataloader, is_poisoned_pred, cur_key_dataset.all_pred_label.values, target_class)
    ))
    ### filter once more to get the images of each individual cluster
    for cluster in cur_key_class_dataset_pred_poisoned_part.all_clusters.unique():
        poisonde_images_for_cluster = poisoned_images_pred[
           (cur_key_class_dataset_pred_poisoned_part.all_clusters == cluster).values
        ]
        samples_cnt = poisonde_images_for_cluster.shape[0]
        mean_image = poisonde_images_for_cluster.sum(axis=0) / samples_cnt
        # trigger = mean-image pixels within 0.1 of the maximum or at most 0.1; others zeroed
        trigger_filtered = torch.where(
            (mean_image >= (mean_image.max()-0.1)) 
            | (mean_image <= 0.1), mean_image, torch.zeros_like(mean_image)
        )
        triggers.append(trigger_filtered)
        cluster_means.append(mean_image)

    triggers = torch.stack(triggers)
    cluster_means = torch.stack(cluster_means)

    # 2. Apply the triggers to the clean dataset.
    my_poisoned_dataset = PoisonedDataset(clean_dataset, target_class, triggers, cluster_means)
    my_poisoned_loader = torch.utils.data.DataLoader(my_poisoned_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)


    # 3. Measure how well the reconstructed triggers poison the model.
    _, attack_success_rate = utils.evaluate(my_poisoned_loader, nn_model)
    
    print(f"{attack_success_rate=}\n{target_class=}\n{key=}")

    


In [None]:
# Reconstructed trigger of each cluster from the single-model cell.
# squeeze=False keeps `ax` two-dimensional even when there is only one cluster.
fig, ax = plt.subplots(nrows=1, ncols=len(cluster_means), figsize=(10, 10), squeeze=False)

for i in range(len(cluster_means)):
    ax[0, i].imshow(triggers[i].permute(1, 2, 0).numpy())
    ax[0, i].axis('off')

plt.show()

In [None]:
# Mean image of each cluster from the single-model cell.
# squeeze=False keeps `ax` two-dimensional even when there is only one cluster.
fig, ax = plt.subplots(nrows=1, ncols=len(cluster_means), figsize=(10, 10), squeeze=False)

for i in range(len(cluster_means)):
    ax[0, i].imshow(cluster_means[i].permute(1, 2, 0).numpy())
    ax[0, i].axis('off')

plt.show()

In [None]:
# 2. Apply the reconstructed triggers to the clean dataset.
# NOTE(review): the target class is hard-coded to 1 here instead of using the analysed
# target_class; confirm this matches the inspected model.
my_poisoned_dataset = PoisonedDataset(
    clean_data=clean_dataset,
    target_lable=1,
    triggers=triggers,
    cluster_means=cluster_means,
)
my_poisoned_loader = torch.utils.data.DataLoader(
    my_poisoned_dataset,
    batch_size=batch_size,
    shuffle=False,
    pin_memory=True,
)

# 3. Evaluate the model on the triggered images; the second returned value is used as
# the attack success rate.
_, attack_success_rate = utils.evaluate(my_poisoned_loader, nn_model)


In [None]:
# Attack success rate from the most recent utils.evaluate call (second returned value).
attack_success_rate

In [None]:
# Six samples (dataset indices 0, 101, ..., 505) of the reconstructed poisoned dataset.
fig, ax = plt.subplots(nrows=1, ncols=6, figsize=(16, 8))

for i, axis in enumerate(ax):
    img = my_poisoned_dataset[101 * i][0]
    axis.imshow(img.permute(1, 2, 0).numpy())
    axis.axis('off')

In [None]:
# Reconstructed per-cluster trigger tensors from the single-model reconstruction cell.
triggers