In [None]:
import os
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, f1_score
import numpy as np

# from model import vit_base_patch16_224
# from timm.models import vit_base_patch16_224
import sys
sys.path.append('../')
from saliency import *
from utils import *
from plots import *
from run import *

In [None]:
# ---- Run configuration -------------------------------------------------------
# Every tunable value used by the cells below is set here.

# Compute device and model checkpoint.
device = 'cuda'
model_path = '/home/raza.imam/Documents/HC701B/Project/models/vit_base_patch16_224_in21k_test-accuracy_0.96_chest.pth'

# Dataset roots; `dataset_class` is joined onto these to pick the image sub-folder.
train_path = "/home/raza.imam/Documents/HC701B/Project/data/TB_data/training/"
test_path = "/home/raza.imam/Documents/HC701B/Project/data/TB_data/testing/"
dataset_class = "Tuberculosis"  # default="Normal", choices=["Tuberculosis", "Normal"]

# Number of images requested for the reference maps and for the test set.
num_train_imgs = 1000
num_test_imgs = 100

# Attack / attention settings.
# NOTE(review): `attack` and `force_recompute` are not read by any cell visible in
# this notebook; the attacks actually run come from ATTACK_LIST.
attack = 'all'  # default='all', choices=['FGSM', "PGD", "all"]
ATTACK_LIST = ["PGD"]
eps = 0.02
block = -1  # last block
force_recompute = False

# Load the model from the checkpoint onto the chosen device.
model = get_model(model_path=model_path, device=device)

In [None]:
# Reference ("mean") attention maps for clean inputs and for each attack in ATTACK_LIST.
# `exits` is a manual switch: True loads previously saved maps from disk,
# False recomputes them from the training images and saves them.
# NOTE(review): `force_recompute` from the arguments cell is not consulted here, and
# the load path defines only `mean_attns` (not `all_attns` / `mean_attn_diff`), which
# later cells read — confirm before flipping the switch to True.
exits = False

mean_dir = os.path.join("./reference", "mean_images")
reference_keys = ATTACK_LIST + ['clean']


def _mean_attn_path(key):
    """Return the .npy path holding the mean attention map stored under `key`."""
    return os.path.join(
        mean_dir,
        f"mean_attns_{key}_block_{block}_images_{num_train_imgs}.npy",
    )


if exits:
    # Load one array per key from the files written by an earlier run.
    mean_attns = {key: np.load(_mean_attn_path(key)) for key in reference_keys}
    print(f'Loaded mean images from ./reference/mean_images/')
else:
    # Compute the reference maps from the training images of the selected class.
    all_attns, mean_attns, mean_attn_diff = get_reference_attn_matp(
        model=model,
        image_folder=os.path.join(train_path, dataset_class),
        block=block,
        n_images=num_train_imgs,
        device=device,
        attack_type=ATTACK_LIST,
        select_random=False,
    )

    # Persist the mean maps so a later run can load them instead of recomputing.
    os.makedirs(mean_dir, exist_ok=True)
    print(f'Saving mean images to ./reference/mean_images/')
    for key in reference_keys:
        np.save(_mean_attn_path(key), mean_attns[key])

In [None]:
# Quick look at the reference data: number of clean entries and the array shapes.
(
    len(all_attns['clean']),
    mean_attns['clean'].shape,
    mean_attns['PGD'].shape,
    mean_attn_diff['PGD'].shape,
)

In [None]:
# Fetch the test images of the selected class together with their attention maps
# for the clean input and for every attack in ATTACK_LIST.
test_attn_kwargs = dict(
    model=model,
    image_folder=os.path.join(test_path, dataset_class),
    n_imgs=num_test_imgs,
    block=block,
    device=device,
    attack_type=ATTACK_LIST,
    eps=eps,
    plot=False,
    rand=False,
    random_state=None,
)
test_imgs, test_attns, test_attn_diff, test_img_files = get_images_attns(**test_attn_kwargs)

# Count of attention-map entries actually returned (distinct from the requested
# `num_test_imgs` set in the arguments cell).
num_test_images = len(test_attns)

In [None]:
# Show one test image next to its clean and adversarial (PGD) attention maps.
# NOTE(review): `plt` is not imported explicitly in this notebook; presumably it is
# provided by one of the star imports in the first cell — confirm.
text = ["Original Image", "Head Mean Clean", "Head Mean Adv"]
img_no = 14  # index of the test image to display

# imshow expects colour images as (rows, cols, channels). Assuming the stored tensor
# is (1, C, H, W) — TODO confirm — permute(1, 2, 0) yields (H, W, C). The previous
# (2, 1, 0) order swapped the two spatial axes, so the image was drawn transposed
# relative to the attention maps shown beside it.
panels = [
    test_imgs[img_no].squeeze(0).permute(1, 2, 0),
    test_attns[img_no]['clean'],
    test_attns[img_no]['PGD'],
]

figure, axes = plt.subplots(1, 3, figsize=(7, 7))
for ax, title, panel in zip(axes, text, panels):
    print(panel.shape)
    ax.imshow(panel, cmap='inferno')
    ax.set_title(title)
plt.show()

In [None]:
import datetime


def _binarize(labels):
    """Map each label to 1 when it reads "clean" (any letter case), otherwise 0."""
    return [1 if str(label).lower() == "clean" else 0 for label in labels]


# ---- 1. Classify every test attention map ------------------------------------
# One row is produced per (test image, variant) pair, where the variants are the
# attacks in ATTACK_LIST plus the unperturbed 'clean' input.
img_name = []
gt_labels = []
sum_preds = []
euc_preds = []
cos_preds = []

for idx, attn_map in enumerate(test_attns):
    for attack_name in ATTACK_LIST + ['clean']:
        # The adversarial reference is always the PGD mean map, whichever variant
        # is being classified.
        result = classify_image(
            img_attn_map=attn_map[attack_name],
            mean_attn_clean=mean_attns['clean'],
            mean_attn_adv=mean_attns['PGD'],
        )
        sum_preds.append(result['sum'])
        euc_preds.append(result['euclidean'])
        cos_preds.append(result['cosine'])
        gt_labels.append(attack_name)
        img_name.append(test_img_files[idx])

# ---- 2. Persist the raw predictions ------------------------------------------
results_dict = {
    "image": img_name,
    "GT": gt_labels,
    "sum": sum_preds,
    "euclidean": euc_preds,
    "cosine": cos_preds,
}

results_df = pd.DataFrame(results_dict)
os.makedirs("./results", exist_ok=True)
results_df.to_csv(
    f"./results/preds_{dataset_class}_block_{block}_images_{num_train_imgs}_eps_{eps}.csv",
    index=False
)

# ---- 3. Score each decision rule once ----------------------------------------
# Binary convention: 1 = clean, 0 = adversarial.
# NOTE(review): ground truth uses the lowercase key 'clean' while predictions were
# previously matched against "Clean"; the exact strings returned by classify_image
# are not visible here, so the comparison is made case-insensitive to be safe.
gt_labels_bin = _binarize(gt_labels)

methods = ["sum", "euclidean", "cosine"]
separator = '-------------------------------------------------------------------'

# Metrics are computed a single time per method and reused for screen and file.
report = {}
for method in methods:
    pred_bin = _binarize(results_dict[method])
    report[method] = [
        f'--------------- {method} ---------------',
        f"Accuracy for {method}: {accuracy_score(gt_labels_bin, pred_bin)}",
        f"F1 score for {method}: {f1_score(gt_labels_bin, pred_bin)}",
        separator,
    ]

# ---- 4. Append the report to the log file and show it on screen --------------
ct = datetime.datetime.now()
with open("./logs.txt", "a") as f:
    print('', file=f)
    print(f'--------------------------------------------- {ct} ---------------------------------------------\n', file=f)
    print('--------------------------Results-----------------------------\n', file=f)
    for method in methods:
        # Log to file
        for line in report[method]:
            print(line, file=f)

        # Print to screen
        for line in report[method]:
            print(line)