In [1]:
from utils import *
from dataset_utils import *
import lime_image
seed = 0
set_seed(seed)
from skimage.segmentation import quickshift, slic
from sal_resnet import resnet50, resnext50_32x4d, wide_resnet50_2
# from madry_models import vit_base_patch16_224 as vit_b_16

from collections import defaultdict
import torch.nn.functional as F
# Re-seed with 1. This supersedes the `seed = 0` / `set_seed(seed)` pair that
# appears earlier in this same cell, so 1 is the seed in effect afterwards.
seed = 1

set_seed(seed)
# Make float32 the default dtype for newly created floating-point tensors.
torch.set_default_dtype(torch.float32)
# Ask torchvision to load images through the PIL backend.
torchvision.set_image_backend('PIL')

In [2]:
# Build the dataset with no image/mask transforms; resizing and cropping are
# applied later, per sample, in the evaluation loop.
pixel_imagenet = PixelImageNet(IMAGENET_PATH, 
                             PIXEL_IMAGENET_PATH,
                             img_transform=None, mask_transform=None)
# random_split with lengths [len, 0] keeps every sample but returns it as a
# randomly permuted Subset, i.e. this is a shuffle; the empty second split is
# discarded. The permutation depends on the global torch RNG state.
pixel_imagenet, _ = torch.utils.data.random_split(pixel_imagenet, [len(pixel_imagenet), 0])

In [3]:
def get_image(path):
    """Open the image file at ``path`` and return it converted to RGB mode.

    The file handle and the decoded source image are both closed before the
    converted copy is returned.
    """
    absolute_path = os.path.abspath(path)
    with open(absolute_path, 'rb') as handle, Image.open(handle) as picture:
        rgb_picture = picture.convert('RGB')
    return rgb_picture
        
def get_input_transform():
    """Return the full preprocessing pipeline for a raw image.

    Steps, in order: resize to 256x256, center-crop to 224, convert to a
    tensor, then normalize each channel with the fixed mean/std below.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    steps = [
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)

def get_input_tensors(img):
    """Run ``img`` through the full input transform and return it as a batch of one."""
    preprocessed = get_input_transform()(img)
    # unsqueeze(0) adds the leading batch dimension: single image -> batch of 1
    return preprocessed.unsqueeze(0)

def get_pil_transform():
    """Return the geometric-only pipeline: resize to 256x256, then center-crop to 224.

    No tensor conversion or normalization is included.
    """
    geometric_steps = [
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
    ]
    return transforms.Compose(geometric_steps)

def get_preprocess_transform():
    """Return the tensor-conversion pipeline: ``ToTensor`` followed by ``Normalize``.

    The normalization step sits at index 1 of the composed transform list.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ])

# Shared, module-level transform instances used by the helpers and the
# evaluation loop below.
# pill_transf: resize + center-crop only (no tensor conversion).
pill_transf = get_pil_transform()
# preprocess_transform: ToTensor + Normalize, applied to already-cropped images.
preprocess_transform = get_preprocess_transform()
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def batch_predict(model, images, masks=None):
    """Return softmax class probabilities for a batch of images as a NumPy array.

    Args:
        model: network to evaluate; it is switched to eval mode and moved to
            the module-level ``device`` on every call.
        images: iterable of images accepted by ``preprocess_transform``.
        masks: optional array-like. When given, it is inverted (``1 - masks``)
            and the model is called with the tuple
            ``(batch, inverted_masks, [100, 6])`` instead of the plain batch.
            NOTE(review): the meaning of ``[100, 6]`` is defined by the model's
            forward pass, which is not visible here - confirm.

    Returns:
        ``numpy.ndarray`` of probabilities, softmax taken over dimension 1.
    """
    model.eval()
    preprocessed = [preprocess_transform(image) for image in images]
    batch = torch.stack(preprocessed, dim=0)

    model.to(device)
    batch = batch.to(device)
    with torch.no_grad():
        if masks is not None:
            inverted_masks = 1 - torch.Tensor(masks).float().to(device)
            logits = model((batch, inverted_masks, [100, 6]))
        else:
            logits = model(batch)
    probabilities = F.softmax(logits, dim=1)
    return probabilities.detach().cpu().numpy()

In [4]:
def blackout(imgs):
    """Return the fill value for the "blackout" baseline.

    The argument is ignored; the result is always the scalar 0.
    """
    black_value = 0
    return black_value

def greyout(imgs):
    """Return the fill colour for the "greyout" baseline.

    The argument is ignored. The colour is the per-channel mean of the
    normalization step of ``preprocess_transform`` scaled to 0-255 and cast
    to ``uint8``.
    """
    normalize_step = preprocess_transform.transforms[1]
    channel_means = np.array(normalize_step.mean)
    scaled = 255 * channel_means
    return scaled.astype(np.uint8)

In [5]:
def block_segment(img, patch_size, img_size=224):
    """Segment a square image into a regular grid of square patches.

    Args:
        img: unused; present so the function matches the
            ``segmentation_fn(image)`` calling convention.
        patch_size: side length of each patch in pixels (positive integer).
        img_size: side length of the square segmentation map. Defaults to
            224, the crop size used by the transforms in this notebook.

    Returns:
        Integer ``numpy.ndarray`` of shape ``(img_size, img_size)`` whose
        entries are patch ids numbered row-major from 0.

    Raises:
        ValueError: if ``patch_size`` or ``img_size`` is not positive.

    When ``img_size`` is not a multiple of ``patch_size`` the last row and
    column of patches are truncated at the image border instead of failing
    in a reshape.
    """
    if patch_size <= 0 or img_size <= 0:
        raise ValueError("patch_size and img_size must be positive")
    # Ceil division so that the patch grid always covers the whole image.
    num_patch = -(-img_size // patch_size)
    patch_ids = np.arange(num_patch * num_patch).reshape(num_patch, num_patch)
    # Blow every patch id up to a patch_size x patch_size block.
    seg_mask = np.repeat(np.repeat(patch_ids, patch_size, axis=0), patch_size, axis=1)
    # Crop the overhang (a no-op when img_size is a multiple of patch_size).
    return np.ascontiguousarray(seg_mask[:img_size, :img_size])

def contour_segment(imgs, masks):
    """Quantize ``masks`` into integer levels using 20 evenly spaced thresholds.

    The thresholds are 0, 0.05, ..., 0.95. Each element of the result is the
    1-based index of the highest threshold the mask value strictly exceeds,
    or 0 when it exceeds none. ``imgs`` is unused.

    Returns:
        ``torch.uint8`` tensor with the same shape as ``masks``.
    """
    thresholds = np.arange(20) / 20
    seg_masks = torch.zeros_like(masks, dtype=torch.uint8)
    # Thresholds ascend, so later (higher) levels overwrite earlier ones.
    for level, threshold in enumerate(thresholds, start=1):
        exceeds = masks > threshold
        seg_masks[exceeds] = level
    return seg_masks

# def sklearn_segment(img, skseg_fn=quickshift):
#     seg_masks = []
#     for i, img in enumerate(imgs):
#         img = np.transpose((1,2,0))
#         sm = skseg_fn(img)
#         seg_masks.append(sm[None,:])
#     seg_masks = torch.stack(seg_masks, dim=0)
#     return seg_masks

In [6]:
from scikit_image import *
from skimage.segmentation import mark_boundaries

In [7]:
def mask_seg(seg_inds, seg_maps):
    """Build a binary mask marking the pixels that belong to selected segments.

    Args:
        seg_inds: tensor indexed as ``[batch, k]`` holding the segment ids to
            select for each batch element.
        seg_maps: tensor indexed with four axes, batch first, holding a
            segment id per position.

    Returns:
        Float tensor shaped like ``seg_maps``: 1.0 where the segment id equals
        any of that batch element's ``seg_inds``, 0.0 elsewhere.
    """
    wanted_ids = seg_inds[:, :, None, None, None]
    id_maps = seg_maps[:, None, :, :, :]
    # Compare every selected id against every position, then collapse the id axis.
    matches = wanted_ids == id_maps
    return matches.any(dim=1).float()

def topk_ablate(model, imgs, labels, seg_maps, seg_order, num_segs):
    """Accuracy of ``model`` after zeroing out selected segments of each image.

    The segments removed are the ones listed in the last ``num_segs`` columns
    of ``seg_order``. Images are first passed through ``preprocess_transform``
    and the masked pixels are set to 0 in that normalized space.

    Args:
        model: network called on the ablated batch.
        imgs: iterable of images accepted by ``preprocess_transform``.
        labels: tensor of ground-truth class ids, one per image.
        seg_maps: segment-id maps, one per image (see ``mask_seg``).
        seg_order: per-image segment ids, indexed ``[batch, position]``.
        num_segs: how many trailing entries of ``seg_order`` to ablate.

    Returns:
        0-dim float tensor: fraction of images whose argmax prediction still
        equals the label.
    """
    segments_to_remove = seg_order[:, -num_segs:]
    ablation_masks = mask_seg(segments_to_remove, seg_maps)
    batch = torch.stack([preprocess_transform(image) for image in imgs])
    with torch.no_grad():
        keep = 1 - ablation_masks
        logits = model(batch * keep)
    predictions = logits.argmax(-1).cpu()
    is_correct = (predictions == labels).float()
    return is_correct.mean()

In [8]:
def jacc_sim(m1, m2):
    """Jaccard-style similarity of two mask arrays.

    Numerator: sum of the element-wise product. Denominator: sum of the
    element-wise sum clipped to [0, 1]. For 0/1 masks this is
    intersection over union.
    """
    overlap = np.sum(m1 * m2)
    union = np.clip(m1 + m2, 0, 1).sum()
    return overlap / union

In [9]:
from alexnet import alexnet
from efficientnet import efficientnet_b0
from squeezenet import squeezenet1_1

from mobilenet import mobilenet_v2, mobilenet_v3_large
from densenet import densenet121

# ---- Experiment configuration -------------------------------------------
# Batch size handed to the LIME explainers (batch_size=...).
lime_bs=512
# Ablation accuracy is computed each time the number of accepted samples
# reaches a multiple of this value.
eval_bs = 32
# The per-model loop stops once this many samples have been accepted.
eval_num_samples = 195
# Number of random on/off perturbation rows generated per image; also passed
# to the explainers as num_samples.
lime_num_samples=512
# A sample is only evaluated when its label is among the model's top-k
# predictions; also passed to the explainers as top_labels.
topk = 1
# Patch side length (pixels) used by the block segmentation.
ps = 16
# Threshold used to binarize the uniform random perturbation matrix
# (values < prop become 0, the rest become 1).
prop = 0.5
# Added to the cosine-similarity denominator to avoid division by zero.
eps = 1e-8

# (constructor, display name) pairs; the name is used in `desc` and in the
# result file names.
model_list = [
            (resnet50, 'resnet50'), 
            (wide_resnet50_2, 'wide_resnet50'),
            (squeezenet1_1, 'SqueezeNet'),
            (alexnet, 'AlexNet'),
            (densenet121, 'DenseNet'),
            (efficientnet_b0, 'EfficientNet'),
            (mobilenet_v3_large, 'MobileNet'),
             ]

# Main experiment: for every model, explain each accepted image with LIME under
# every (segmentation, missingness approximation) pair, score the explanation
# against the ground-truth mask, and periodically measure top-20 ablation accuracy.
for model_type, model_name in model_list:
    model = model_type(pretrained=True)
    model = MyDataParallel(model).to(device)
    model = model.eval()

    desc = f'pixel_imagenet_{model_name}'#_{ps}x{ps}'

    # Candidate ways of splitting the image into LIME features.
    seg_fns = [
                SegmentationAlgorithm('quickshift', kernel_size=2,
                                        max_dist=200, ratio=0.2,
                                        random_seed=0),
                lambda x: block_segment(x, ps),
                SegmentationAlgorithm('slic', n_segments=250,
                                        random_seed=0),
                ]

    if 'pixel' in desc:
        dataset = pixel_imagenet
    else:
        # NOTE(review): salient_imagenet is not defined in the visible cells;
        # this branch is unreachable while desc starts with 'pixel_imagenet'.
        dataset = salient_imagenet
        dataset.dataset.transform = None

    # How removed segments are filled in: a constant colour (blackout / greyout)
    # or None, which selects the mask-aware explainer ('Layer mask').
    miss_approxns = [blackout, greyout, None]
    miss_approxns_name = ['Blackout', 'Greyout', 'Layer mask']
    cov_list = []
    total = 0
    pred_fn = lambda im, ma=None: batch_predict(model, im, ma)
    all_scores = []
    # Per-batch accumulators; they are reset after every ablation evaluation.
    all_imgs, all_labels = [], []
    all_segmaps = []
    all_segorders = [[list() for _ in miss_approxns] for _ in seg_fns]
    all_acc = [[list() for _ in miss_approxns] for _ in seg_fns]
    for im_no, sample in enumerate(dataset):

        if total >= eval_num_samples:
            break
        if sample is None:
            continue
        (img, mask, label) = sample

        if 'pixel' in desc:
            mask = pill_transf(mask)
            mask = np.asarray(mask)[:,:,0]
        else:
            label = label.item()
        img = pill_transf(img)

        # Only evaluate images whose label is among the model's top-k predictions.
        probs = pred_fn([np.array(img)])
        topk_inds = torch.topk(torch.Tensor(probs[0]), topk).indices
        if label not in topk_inds:
            continue

        coverage = np.zeros((len(seg_fns), len(miss_approxns_name)))
        jaccsims = np.zeros((len(seg_fns), len(miss_approxns_name)))
        segmaps = []
        # Segment orderings are staged per sample and only committed to the
        # batch accumulators once every explanation succeeded, so that
        # all_imgs, all_labels, all_segmaps and all_segorders stay aligned.
        sample_segorders = [[None for _ in miss_approxns] for _ in seg_fns]
        complete = True
        mask = mask.astype(float)

        for si, seg_fn in enumerate(seg_fns):
            segments = seg_fn(img)
            segmaps.append(torch.Tensor(segments[None]))
            n_features = np.unique(segments).shape[0]
            print(n_features)
            # Random binary on/off matrix shared by all explainers for this segmentation.
            data = np.random.random_sample(size=(lime_num_samples, n_features))
            data[data < prop] = 0
            data[data >= prop] = 1
            # Per-segment sum of the mean-centred ground-truth mask
            # (segment ids are taken to be 0..n_features-1).
            gt = np.array([np.sum((segments == i)*(mask - mask.mean())) for i in range(n_features)])
            for j, miss_approx in enumerate(miss_approxns):
                if miss_approxns_name[j] == 'Layer mask':
                    explainer = lime_image.MyLimeImageExplainer()
                    explanation = explainer.explain_instance(np.array(img), 
                                                             pred_fn, 
                                                             batch_size=lime_bs,
                                                             top_labels=topk, 
                                                             data=data,
                                                             segmentation_fn=seg_fn,
                                                             num_samples=lime_num_samples,
                                                             thresh=1000,
                                                             prop=prop) 
                else:
                    baseline = np.array(miss_approx(img))
                    explainer = lime_image.LimeImageExplainer()
                    # prop follows the module-level setting so it always matches
                    # the threshold used to binarize `data` above.
                    explanation = explainer.explain_instance(np.array(img), 
                                                             pred_fn, 
                                                             batch_size=lime_bs,
                                                             top_labels=topk,
                                                             data=data,
                                                             hide_color=baseline, 
                                                             segmentation_fn=seg_fn,
                                                             num_samples=lime_num_samples,
                                                             thresh=1000,
                                                             prop=prop) 

                if label not in explanation.local_exp.keys():
                    print('Skipping..')
                    complete = False
                    break
                score_dict = dict(explanation.local_exp[label])
                scores = np.array([score_dict[i] for i in range(len(score_dict))])
                # Cosine similarity between LIME scores and the ground-truth vector.
                coverage[si][j] = (scores*gt).sum()/(np.linalg.norm((scores))*np.linalg.norm(gt) + eps)

                # Ascending sort: the highest-scoring segments are at the end.
                segorder = np.argsort(scores)
                sample_segorders[si][j] = torch.LongTensor(segorder)
                high_score_inds = segorder[-20:]
                topk_mask = mask_seg(torch.Tensor(high_score_inds[None]), torch.Tensor(segments[None, None]))[0,0].numpy()
                jaccsims[si][j] = jacc_sim(mask, topk_mask)  #np.mean(gt[high_score_inds])

                all_scores.append(((im_no, si, j), coverage))
                temp, imask = explanation.get_image_and_mask(label, positive_only=False, num_features=20, hide_rest=False)
                lime_exp_img = mark_boundaries(temp/255.0, imask)
            if not complete:
                # Abandon the remaining segmentations for this sample.
                break
        if not complete:
            # Drop the whole sample: nothing has been committed yet.
            continue

        # Commit the fully explained sample to the batch accumulators.
        all_imgs.append(np.array(img))
        all_labels.append(label)
        total += 1
        for seg_idx, per_seg in enumerate(sample_segorders):
            for approx_idx, staged_order in enumerate(per_seg):
                all_segorders[seg_idx][approx_idx].append(staged_order)
        all_segmaps.append(torch.stack(segmaps))
        cov_list.append((label, coverage, jaccsims))
        # Running summary: how often each missingness approximation gives the
        # best Jaccard score, and the mean coverage / Jaccard tables so far.
        a = np.stack([x for _, _, x in cov_list])
        best = np.argmax(a, axis=-1)
        print([(best == i).mean(0) for i in range(3)])
        print(np.stack([np.stack((x,y)) for _, x, y in cov_list]).mean(axis=0))
        # Every eval_bs accepted samples: measure accuracy after ablating the
        # 20 top-scoring segments, reset the batch accumulators and checkpoint.
        # NOTE: a trailing partial batch (total not a multiple of eval_bs) is
        # never ablated or checkpointed here.
        if total%eval_bs == 0 and im_no != 0:
            all_imgs, all_labels = np.stack(all_imgs), torch.LongTensor(all_labels)
            all_segmaps = torch.stack(all_segmaps)#[ for segmaps in all_segmaps])
            all_segmaps = all_segmaps.transpose(0,1)

            all_segorders = [[torch.stack([x[-20:] for x in so]) for so in so2] for so2 in all_segorders]

            for si, so2 in enumerate(all_segorders):
                for j, seg_order in enumerate(so2):
                    all_acc[si][j].append(topk_ablate(model, all_imgs, all_labels, all_segmaps[si], seg_order, 20))
            all_imgs, all_labels = [], []
            all_segmaps, all_segorders = [], [[list() for _ in miss_approxns] for _ in seg_fns]
            print(np.mean(np.array(all_acc), axis=-1))
            with open(f'./results/top20_ablations_new_{model_name}.pkl', 'wb+') as fp:
                pickle.dump(all_acc, fp)
            with open(f'./results/cov_list_new_{model_name}.pkl', 'wb+') as fp:
                pickle.dump(cov_list, fp)



ModuleNotFoundError: No module named 'convnext'

In [None]:
# Re-save cov_list for whichever model_name is currently bound (the last model
# the loop above reached). Inside the loop the file is only written at
# eval_bs boundaries, so this also captures samples accepted after the last
# in-loop checkpoint.
with open(f'./results/cov_list_new_{model_name}.pkl', 'wb+') as fp:
    pickle.dump(cov_list, fp)