In [1]:
import argparse
import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import models, transforms
import torch.nn as nn
from pytorch_grad_cam import GradCAM, GradCAMPlusPlus
from pytorch_grad_cam import GuidedBackpropReLUModel
from pytorch_grad_cam.utils.image import show_cam_on_image, \
    deprocess_image, \
    preprocess_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
import matplotlib.pyplot as plt
import os
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import albumentations as A
from albumentations.pytorch import ToTensorV2

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#Data Loader
class CSVDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file, delimiter=' ', header=None)
        self.class_num = 15
        self.targets = self.img_labels.iloc[:, 1]
        # for i in range(len(self.targets)):
        #     if self.targets[i] > self.class_num:
        #         self.targets[i] = 0
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform
        
    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        # image = Image.open(img_path)
        # if image.mode != 'RGB':
        #     image = image.convert('RGB')
        label = self.img_labels.iloc[idx, 1]
        # image = np.array(image)
        if label > self.class_num:
            label = 0
        
        x = cv2.imread(img_path)
        # x = x[:, :, ::-1]
        
        if self.transform:
            image = self.transform(image=x)['image']
        return image, label

    def __len__(self):
        return len(self.img_labels)

SHAP_transforms = A.Compose([
        A.Resize(256, 256),
        A.CenterCrop(224, 224),
        A.Normalize(mean=[0.2527, 0.3085, 0.3082], std=[0.1234, 0.1629, 0.1564]),
        # A.Normalize(mean=[0.2266, 0.2886, 0.2763], std=[0.1125, 0.1538, 0.1363]),
        ToTensorV2()
    ])
train_set = CSVDataset('/home/ldap/william/private/chrislin/AUO_Data_811_DA/train_Open_DA_ordered.csv', '/home/ldap/william/private/chrislin/AUO_Data_811_DA', SHAP_transforms)
test_set = CSVDataset('/hcds_vol/private/chrislin/20230830_C101/test_0830.csv', '/hcds_vol/private/chrislin/20230830_C101/', SHAP_transforms)
# train_set = CSVDataset('/home/ldap/william/private/NCU/william/all_csv/chrislin/train.csv', '/home/ldap/william/private/chrislin/AUO_Data_811_DA', SHAP_transforms)
# test_set = CSVDataset('/home/ldap/william/private/NCU/william/all_csv/chrislin/test.csv', '/hcds_vol/private/chrislin/20230830_C101/', SHAP_transforms)
# test_set = CSVDataset('/hcds_vol/private/chrislin/AUO_Data_811_DA/val_Open_ordered.csv', '/hcds_vol/private/chrislin/AUO_Data_811_DA/', SHAP_transforms)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=50, shuffle=True, num_workers=24, pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=50, shuffle=False, num_workers=24, pin_memory=True)

In [3]:
import sys

# Load Model

repo_name = 'DER'
base_dir = os.path.realpath(".")[:os.path.realpath(".").index(repo_name) + len(repo_name)]
sys.path.insert(0, base_dir)

task_id = 0

import yaml
from inclearn.convnet import network
from torch.nn import DataParallel
from easydict import EasyDict as edict

    
# config_file = os.path.join(w_dir, "1.yaml")
config_file = "./configs/1.yaml"
with open(config_file, 'r') as stream:
    try:
        config = yaml.safe_load(stream)
    except yaml.YAMLError as exc:
        print(exc)

# device = "cuda:0"
device = "cpu"

cfg = edict(config)
model = network.BasicNet(
    cfg["convnet"],
    cfg = cfg,
    nf = cfg["channel"],
    device = device,
    use_bias = cfg["use_bias"],
    dataset = cfg["dataset"],
)
parallel_model = DataParallel(model)

total_classes = 28
increments = []
increments.append(cfg["start_class"])
for _ in range((total_classes - cfg["start_class"]) // cfg["increment"]):
    increments.append(cfg["increment"])

for i in range(task_id+1):
    model.add_classes(increments[i])
    model.task_size = increments[i]

if task_id == 0:
    state_dict = torch.load(f'./ckpts/step{task_id}.ckpt')
else:
    state_dict = torch.load(f'./ckpts/decouple_step{task_id}.ckpt')

parallel_model.cuda()
# parallel_model.to("cpu")
parallel_model.load_state_dict(state_dict)
parallel_model.eval()

2
Enable dynamical reprensetation expansion!
16
cpu


DataParallel(
  (module): BasicNet(
    (convnets): ModuleList(
      (0): ResNet(
        (conv1): Sequential(
          (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        )
        (layer1): Sequential(
          (0): BasicBlock(
            (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (relu): ReLU(inplace=True)
            (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (1): BasicBlock(
            (conv1): Conv2d(64, 64, 

In [4]:
# 得到模型的prediction
class ResnetPrediction(torch.nn.Module):
    def __init__(self, model):
        super(ResnetPrediction, self).__init__()
        self.model = model
    def forward(self, x):
        return self.model(x)['feature']
        
        # return self.model(x)['feature']
        # return self.model(x)

In [21]:
from tqdm import tqdm
import shap
from torch.autograd import Variable

# model
M = ResnetPrediction(parallel_model.module.cuda()).cuda()
train_feature = []
test_feature = []
with torch.no_grad():
    for batch_cnt, train_data in enumerate(train_loader):
        train_images, _ = train_data
        train_images = train_images.cuda()
        if batch_cnt == 0:
            train_feature = M(train_images)
        else:
            train_feature = torch.cat([train_feature,M(train_images)])
        if batch_cnt == 19:
            break

explainer = shap.DeepExplainer(M.model.classifier, train_feature)


In [26]:
with torch.no_grad():
    for batch_cnt, test_data in enumerate(test_loader):
        test_images, _ = test_data
        test_images = test_images.cuda()
        if batch_cnt == 0:
            test_feature = M(test_images)
        else:
            test_feature = torch.cat([test_feature,M(test_images)])
shap_value = explainer.shap_values(test_feature)

In [12]:
print(len(shap_value))
print(len(shap_value[0]))
print(len(shap_value[0][0]))
# print(shap_value)

16
10269
512


In [27]:
data = test_feature.cpu().detach().numpy()
shap_value = np.array(shap_value)
for i in range(16):
    shap.summary_plot(shap_value[i],data,max_display=20,plot_type="bar",show=False)
    name = 'class' + str(i) + '.jpg'
    plt.savefig('/home/ldap/william/private/NCU/william/DER/exps/opena_811_224_DA_DCL75_Jigsaw_7x7_Mixup_20230830_200/shap_images/test/'+name)
    plt.clf()
# shap.plots.beeswarm(shap_values,max_display=20)

<Figure size 800x950 with 0 Axes>

In [28]:
train_shap_value = explainer.shap_values(train_feature)

Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.


In [30]:
train_data = train_feature.cpu().detach().numpy()
train_shap_value = np.array(train_shap_value)
for i in range(16):
    shap.summary_plot(train_shap_value[i],train_data,max_display=20,plot_type="bar",show=False)
    name = 'class' + str(i) + '.jpg'
    plt.savefig('/home/ldap/william/private/NCU/william/DER/exps/opena_811_224_DA_DCL75_Jigsaw_7x7_Mixup_20230830_200/shap_images/train/'+name)
    plt.clf()
# shap.summary_plot(train_shap_value[3],train_data,max_display=30,plot_type="bar",show=False)
#plt.savefig('/home/ldap/william/private/NCU/william/DER/exps/opena_811_224_DA_DCL75_Jigsaw_7x7_Mixup_20230830_200/shap_images/train_shap.jpg')

<Figure size 800x950 with 0 Axes>