In [3]:
import torch
import os
import pickle as pkl
import numpy as np
import torchvision
import itertools
import matplotlib.pyplot as plt
from certificate_methods import *
from utils import ReducedModel, parse
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from torchvision import datasets
import cv2
import glob
from PIL import Image

In [4]:
import torch

# Record the torch / torchvision versions this notebook was executed with.
print(torch.__version__)
print(torchvision.__version__)

1.10.1+cu102
0.11.2+cu102


In [5]:

def evaluate(model, images, labels):
    """Print the top-1 accuracy of `model` on one batch.

    Args:
        model: Callable returning per-class scores with classes along dim 1.
        images: Batch handed to `model` unchanged (no device transfer here).
        labels: Tensor of ground-truth class indices; converted with
            `.numpy()`, so it has to live on the CPU.

    Returns:
        None. The result is only printed.
    """
    with torch.no_grad():
        logits = model(images)
        probabilities = torch.softmax(logits, dim=1).detach().cpu().numpy()
    predicted = probabilities.argmax(axis=1)
    correct = (predicted == labels.numpy()).sum()
    total = len(labels)
    print(f" Accuracy: {correct}/{total}: {correct / total}")
def prediction_probs(model, input_data):
    """Return the top-1 class and its softmax probability for each sample.

    Args:
        model: Callable mapping a batch tensor to class scores of shape
            (batch, num_classes).
        input_data (torch.Tensor): Batch of inputs. It is moved to the device
            of the model's parameters before the forward pass. If the model
            exposes no parameters, CUDA is used when available (the original
            behaviour), otherwise the tensor stays where it is.

    Returns:
        list of dict: One single-entry dict per sample, in batch order,
        mapping the predicted class index to that sample's softmax
        probability.
    """
    # Follow the model's device instead of assuming CUDA unconditionally.
    first_param = None
    if hasattr(model, "parameters"):
        first_param = next(iter(model.parameters()), None)
    if first_param is not None:
        device = first_param.device
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = input_data.device

    with torch.no_grad():
        outputs = model(input_data.to(device))
        output_probs = torch.softmax(outputs, dim=1).detach().cpu().numpy()

    output_labels = np.argmax(output_probs, axis=1)
    # Index the probability by the sample's own row; the previous code always
    # read row 0, which is only correct for a batch of one.
    return [
        {label: output_probs[row][label]}
        for row, label in enumerate(output_labels)
    ]

## Load model and data transformation 

In [6]:

# Primary model (EfficientNet-B7): gen_exp explains the predictions of this
# global `model`.
model = torchvision.models.efficientnet_b7(pretrained=True).cuda()

# Inference only: switch to eval mode.
model.eval()

# Shared preprocessing: PIL image -> tensor, then per-channel normalisation.
# NOTE(review): mean/std look like the usual ImageNet statistics used with
# torchvision's pretrained models - confirm.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Two further pretrained variants (B6, B5); they are later passed to
# prediction_on_exp alongside `model`.
model_2 = torchvision.models.efficientnet_b6(pretrained=True).cuda()
model_2.eval()
model_3 = torchvision.models.efficientnet_b5(pretrained=True).cuda()
# Being the last expression of the cell, this also echoes the model's repr.
model_3.eval()

EfficientNet(
  (features): Sequential(
    (0): ConvNormActivation(
      (0): Conv2d(3, 48, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(48, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): SiLU(inplace=True)
    )
    (1): Sequential(
      (0): MBConv(
        (block): Sequential(
          (0): ConvNormActivation(
            (0): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=48, bias=False)
            (1): BatchNorm2d(48, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
            (2): SiLU(inplace=True)
          )
          (1): SqueezeExcitation(
            (avgpool): AdaptiveAvgPool2d(output_size=1)
            (fc1): Conv2d(48, 12, kernel_size=(1, 1), stride=(1, 1))
            (fc2): Conv2d(12, 48, kernel_size=(1, 1), stride=(1, 1))
            (activation): SiLU(inplace=True)
            (scale_activation): Sigmoid()
          )
          (2): ConvNormActivatio

## Get the ground truth class

In [7]:
def name_class(class_dict):
    """Re-key an index-keyed class table by the first element of each entry.

    Args:
        class_dict (dict): Maps an index (anything `int()` accepts) to a
            two-element sequence `[identifier, name]`.

    Returns:
        dict: Maps `identifier` to the tuple `(int(index), name)`.
    """
    return {
        entry[0]: (int(index), entry[1])
        for index, entry in class_dict.items()
    }

import json

# Read the class-index table as raw text.
# NOTE(review): absolute, user-specific path - the notebook only runs on a
# machine where this file exists.
with open("/home/abka03/IML/soundness_saliency/name_class.json", "r") as f:
    json_data = f.read()

# data_dict: index string -> [identifier, class name] (see printed output).
data_dict = json.loads(json_data)
print(data_dict)
# class_dict: identifier -> (int index, class name); prediction_on_exp uses it
# to look up the ground-truth class from an explanation's directory name.
class_dict = name_class(data_dict)
print(class_dict)

{'0': ['n01440764', 'tench'], '1': ['n01443537', 'goldfish'], '2': ['n01484850', 'great_white_shark'], '3': ['n01491361', 'tiger_shark'], '4': ['n01494475', 'hammerhead'], '5': ['n01496331', 'electric_ray'], '6': ['n01498041', 'stingray'], '7': ['n01514668', 'cock'], '8': ['n01514859', 'hen'], '9': ['n01518878', 'ostrich'], '10': ['n01530575', 'brambling'], '11': ['n01531178', 'goldfinch'], '12': ['n01532829', 'house_finch'], '13': ['n01534433', 'junco'], '14': ['n01537544', 'indigo_bunting'], '15': ['n01558993', 'robin'], '16': ['n01560419', 'bulbul'], '17': ['n01580077', 'jay'], '18': ['n01582220', 'magpie'], '19': ['n01592084', 'chickadee'], '20': ['n01601694', 'water_ouzel'], '21': ['n01608432', 'kite'], '22': ['n01614925', 'bald_eagle'], '23': ['n01616318', 'vulture'], '24': ['n01622779', 'great_grey_owl'], '25': ['n01629819', 'European_fire_salamander'], '26': ['n01630670', 'common_newt'], '27': ['n01631663', 'eft'], '28': ['n01632458', 'spotted_salamander'], '29': ['n01632777', 

## load sample images and labels

## Explanation function

In [8]:

def gen_exp(input_, K=1, scale=4, lr=0.5, steps=2, obj='ent', noise_bs=10,
            reg_l1=2e-05, reg_tv=0.01, reg_ent=0.0, debug=True):
    """Learn a mask explaining the global `model`'s own prediction for `input_`.

    The model's top-1 class for each sample is used as the explanation target:
    a one-hot target distribution is built from it and handed to
    `learn_masks_for_batch_Kcert` together with the optimisation settings.

    Args:
        input_ (torch.Tensor): Batch of preprocessed images; moved to CUDA for
            the forward pass of the global `model`.
        K, scale, lr, steps, obj, noise_bs, reg_l1, reg_tv, reg_ent, debug:
            Forwarded unchanged to `learn_masks_for_batch_Kcert`. The defaults
            are the values that were previously hard-coded in this function.
            Their exact meaning is defined by that helper, which is not shown
            in this notebook.

    Returns:
        tuple: `(heatmap, batch_masked_model)`. `heatmap` is the first entry
        of the learned masks (detached, on CPU), and `batch_masked_model` is
        the object returned by `learn_masks_for_batch_Kcert`.
    """
    noise_images = None  # use no noise base

    with torch.no_grad():
        outputs = model(input_.cuda())
        output_probs = torch.nn.Softmax(dim=1)(outputs).detach().cpu().numpy()

    # Explain the class the model itself predicts for each sample.
    output_labels = np.argmax(output_probs, axis=1)

    # One-hot target per sample. The previous loop wrote every predicted label
    # into every row, which is only correct for a batch of one. The class
    # count now comes from the model output instead of a hard-coded 1000.
    batch_size, num_classes = output_probs.shape
    probs = torch.zeros(batch_size, num_classes)
    probs[torch.arange(batch_size), torch.from_numpy(output_labels)] = 1

    # NOTE(review): `optim` is not imported explicitly in this notebook; it is
    # presumably provided by `from certificate_methods import *` - confirm.
    batch_masked_model = learn_masks_for_batch_Kcert(
        model, input_, target_probs=probs, K=K, scale=scale,
        opt=optim.Adam, lr=lr, steps=steps, obj=obj,
        noise_mean=None, noise_batch=noise_images, noise_bs=noise_bs,
        reg_l1=reg_l1, reg_tv=reg_tv, reg_ent=reg_ent, old_mask=None,
        debug=debug)

    masks = batch_masked_model.mask().detach().cpu()
    # Only the first entry of the mask tensor is returned.
    heatmap = masks[0]
    return heatmap, batch_masked_model
    


In [9]:
# Run this cell once: it writes one explanation file per image under output_dir.
output_dir = "/mnt/sda/abka03-data/KI/explantions/ent"
image_path = "/mnt/sda/abka03-data/imagenet/imagewoof2-320/val"
# Collect every .jpg / .JPEG file below image_path (recursively).
jpg_files = []
for root, dirs, files in os.walk(image_path):
    for file in files:
        if file.endswith(".jpg") or file.endswith(".JPEG"):
            jpg_files.append(os.path.join(root, file))

for pt in jpg_files:
    # Mirror the dataset layout under output_dir. With the image_path above,
    # components [5:-1] are "imagewoof2-320/val/<class dir>".
    # NOTE(review): the index 5 is tied to this exact absolute directory
    # depth; a different image_path silently changes the output location.
    xp_path = os.path.join (*pt.split("/")[5:-1])
    # File name up to its first dot. The extension is dropped, and
    # prediction_on_exp later re-appends ".JPEG" to find the image again.
    name = pt.split("/")[-1].split(".")[0]
    save_path = os.path.join(output_dir, xp_path)
    
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    # NOTE: this overwrites the global name `exp_path` on every iteration.
    exp_path = os.path.join(save_path, name)
    image = Image.open(pt).convert('RGB') 
    #class_num, class_name  = class_dict[jpg_files[select_sampe].split("/")[-2]]
    # Fixed 224x224 input; prediction_on_exp resizes to the same size.
    image = image.resize((224, 224))
    # Add a leading batch dimension of one.
    input_tensor = transform(image).unsqueeze(0)
    # Only the heat map is persisted; the returned masked model is discarded.
    heat_map, mod = gen_exp(input_tensor)
    torch.save(heat_map, f'{exp_path}.pt')
    print(exp_path)
    
    

1 4 0.5 2 ent 10 2e-05 0.01 0.0
1: loss: 2.34, l1 norm: 24898, tv: 169.70, ent: 0.66, pred prob: 0.9327
/mnt/sda/abka03-data/KI/explantions/ent/imagewoof2-320/val/n02086240/n02086240_10642
1 4 0.5 2 ent 10 2e-05 0.01 0.0
1: loss: 2.22, l1 norm: 24670, tv: 150.54, ent: 0.67, pred prob: 0.9230
/mnt/sda/abka03-data/KI/explantions/ent/imagewoof2-320/val/n02086240/n02086240_630
1 4 0.5 2 ent 10 2e-05 0.01 0.0
1: loss: 2.46, l1 norm: 25048, tv: 141.03, ent: 0.67, pred prob: 0.8987
/mnt/sda/abka03-data/KI/explantions/ent/imagewoof2-320/val/n02086240/n02086240_9540
1 4 0.5 2 ent 10 2e-05 0.01 0.0
1: loss: 3.95, l1 norm: 24946, tv: 209.45, ent: 0.66, pred prob: 0.8092
/mnt/sda/abka03-data/KI/explantions/ent/imagewoof2-320/val/n02086240/n02086240_11860
1 4 0.5 2 ent 10 2e-05 0.01 0.0
1: loss: 1.79, l1 norm: 24771, tv: 144.90, ent: 0.67, pred prob: 0.9733
/mnt/sda/abka03-data/KI/explantions/ent/imagewoof2-320/val/n02086240/n02086240_872
1 4 0.5 2 ent 10 2e-05 0.01 0.0
1: loss: 2.06, l1 norm: 2489

In [10]:
def apply_masks_to_images(images_list, masks_list):
    """
    Multiply each channel-first RGB image by its own single-channel mask.

    Args:
        images_list (sequence of torch.Tensor): Images of shape
            (3, height, width). A batched tensor of shape
            (N, 3, height, width) works as well, because iterating it yields
            one image at a time.
        masks_list (sequence of torch.Tensor): One mask per image, each of
            shape (channel, height, width). Only channel 0 is used and it is
            applied to all three image channels.

    Returns:
        torch.Tensor: The masked images stacked along a new leading dimension,
        shape (N, 3, height, width).

    Raises:
        ValueError: If the numbers of images and masks differ, an image is not
            (3, height, width), or a mask is not three-dimensional.
    """

    if len(images_list) != len(masks_list):
        raise ValueError("The number of images and masks must be equal")

    masked_images = []

    for image, mask in zip(images_list, masks_list):
        if len(image.shape) != 3 or image.shape[0] != 3:
            raise ValueError("Each input image must have a shape of (3, height, width)")

        if len(mask.shape) != 3:
            raise ValueError("Each input mask must have a shape of (channel, height, width)")

        # Take the first mask channel and let broadcasting apply it to all
        # three image channels (same result as repeating it three times).
        single_channel = mask[0].unsqueeze(0)
        masked_images.append(image * single_channel)

    return torch.stack(masked_images, dim=0)

In [11]:
import pandas as pd

def prediction_on_exp(model_, exp_path, input_path):
    """Score a model on images masked by previously saved explanations.

    For every explanation file below `exp_path`, the matching image is loaded
    from `input_path`, the model is run on the original and on the masked
    image, and one result row is recorded.

    Relies on the notebook globals `transform`, `class_dict`,
    `apply_masks_to_images` and `prediction_probs`.

    Args:
        model_: Model passed to `prediction_probs`.
        exp_path (str): Directory searched recursively for ".pt" / ".npy"
            explanation files.
        input_path (str): Root directory under which the original images are
            looked up (same relative layout as the explanations).

    Returns:
        pandas.DataFrame: One row per explanation with the columns
        `image_id` (image path), `correct` (whether the prediction on the
        unmasked image equals the ground-truth class), `size` (mask sum
        divided by 224*224), `score` and `prediction` (top-1 probability and
        class on the masked image).
    """
    exp_files = []
    for root, _dirs, files in os.walk(exp_path):
        for file in files:
            if file.endswith((".pt", ".npy")):
                exp_files.append(os.path.join(root, file))
    print(f"There are total {len(exp_files)} atrribution files")

    data_model = []
    for exp in exp_files:
        parts = exp.split("/")
        # NOTE(review): the index 7 is tied to the absolute directory depth of
        # the explanation files used in this notebook; a different exp_path
        # layout needs a different index.
        im_path = os.path.join(*parts[7:-1])
        name = parts[-1].split(".")[0]
        # Use the `input_path` argument. The previous code read an undefined
        # global `input_dir`, which only worked with leftover kernel state.
        # NOTE(review): the ".JPEG" suffix is assumed for every image.
        final_impath = f"{os.path.join(input_path, im_path, name)}.JPEG"

        image = Image.open(final_impath).convert('RGB')
        image = image.resize((224, 224))
        input_tensor = transform(image).unsqueeze(0)

        # NOTE(review): ".npy" files are collected above but also go through
        # torch.load here - confirm they were written with torch.save.
        explan = torch.load(exp)
        masked_img = apply_masks_to_images(input_tensor, explan)

        # The class directory name identifies the ground truth.
        ground_truth_cls = class_dict[parts[-2]]
        org_prediction = prediction_probs(model_, input_tensor)
        pr = ground_truth_cls[0] == list(org_prediction[0].keys())[0]

        # Fraction of the 224x224 image area covered by the mask.
        exp_size = torch.sum(explan) / (224 * 224)

        predict = prediction_probs(model_, masked_img)
        data_model.append({
            "image_id": final_impath,
            "correct": pr,
            "size": exp_size.item(),
            "score": list(predict[0].values())[0],
            "prediction": list(predict[0].keys())[0],
        })

    df = pd.DataFrame(data_model)

    return df
    
    

In [12]:
# Directory holding the saved explanation files and the root under which the
# original images are looked up.
exp_path = "/mnt/sda/abka03-data/KI/explantions/ent/imagewoof2-320/val"
input_path = "/mnt/sda/abka03-data/imagenet/"
# Score the same saved explanations with each of the three models
# (gen_exp produced them using the global `model`).
df = prediction_on_exp(model, exp_path, input_path)
df_2 = prediction_on_exp(model_2, exp_path, input_path)
df_3 = prediction_on_exp(model_3, exp_path, input_path)

There are total 3929 atrribution files
There are total 3929 atrribution files
There are total 3929 atrribution files


In [13]:
# Show the per-image results for model_2 and model_3.
print(df_2)
print(df_3)

                                               image_id  correct      size  \
0     /mnt/sda/abka03-data/imagenet/imagewoof2-320/v...     True  0.490964   
1     /mnt/sda/abka03-data/imagenet/imagewoof2-320/v...     True  0.485282   
2     /mnt/sda/abka03-data/imagenet/imagewoof2-320/v...     True  0.497345   
3     /mnt/sda/abka03-data/imagenet/imagewoof2-320/v...     True  0.494682   
4     /mnt/sda/abka03-data/imagenet/imagewoof2-320/v...     True  0.496987   
...                                                 ...      ...       ...   
3924  /mnt/sda/abka03-data/imagenet/imagewoof2-320/v...     True  0.494632   
3925  /mnt/sda/abka03-data/imagenet/imagewoof2-320/v...    False  0.498019   
3926  /mnt/sda/abka03-data/imagenet/imagewoof2-320/v...     True  0.494518   
3927  /mnt/sda/abka03-data/imagenet/imagewoof2-320/v...     True  0.490590   
3928  /mnt/sda/abka03-data/imagenet/imagewoof2-320/v...    False  0.496995   

         score  prediction  
0     0.883003         155  
1    